geowatch.utils.lightning_ext.callbacks.batch_plotter module
- class geowatch.utils.lightning_ext.callbacks.batch_plotter.BatchPlotter(num_draw=2, draw_interval='5minutes', max_items=2, overlay_on_image=False)
Bases: Callback
This callback monitors training by periodically drawing batches.
To use it, the trainer's datamodule must define a draw_batch method that returns an ndarray image of a batch.
See [LightningCallbacks].
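To illustrate the draw_batch requirement, the hypothetical ToyDataModule below returns a uint8 RGB ndarray for a batch. The batch layout (a dict with 'im' and 'mask' arrays) and the keyword arguments max_items and overlay_on_image are assumptions chosen to mirror this class's options. The actual geowatch datamodules may use a different signature, so treat this as a sketch of the contract only.

```python
import numpy as np


class ToyDataModule:
    """Hypothetical datamodule illustrating the draw_batch contract."""

    def draw_batch(self, batch, max_items=2, overlay_on_image=False):
        # Assumed batch layout: 'im' is (B, H, W) floats in [0, 1] and
        # 'mask' is (B, H, W) binary annotations.
        ims = batch['im'][:max_items]
        masks = batch['mask'][:max_items]
        rows = []
        for im, mask in zip(ims, masks):
            rgb = np.stack([im, im, im], axis=-1)  # grayscale -> RGB
            if overlay_on_image:
                # Compact view: tint annotated pixels red on the image.
                red = np.zeros_like(rgb)
                red[..., 0] = 1.0
                alpha = 0.5 * mask[..., None]
                row = (1 - alpha) * rgb + alpha * red
            else:
                # Less cluttered view: annotation panel beside the image.
                mask_rgb = np.stack([mask, mask, mask], axis=-1)
                row = np.concatenate([rgb, mask_rgb], axis=1)
            rows.append(row)
        # One row per drawn item, stacked vertically.
        canvas = np.concatenate(rows, axis=0)
        return (canvas * 255).astype(np.uint8)


rng = np.random.default_rng(0)
batch = {
    'im': rng.random((4, 8, 8)),
    'mask': (rng.random((4, 8, 8)) > 0.5).astype(float),
}
dm = ToyDataModule()
canvas = dm.draw_batch(batch, max_items=2)
print(canvas.shape, canvas.dtype)  # (16, 16, 3) uint8
```

A callback could check callable(getattr(datamodule, 'draw_batch', None)) before training starts so that a missing method fails early rather than at the first draw.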
- Parameters:
num_draw (int) – Number of batches to draw at the start of each epoch. Defaults to 2.
draw_interval (datetime.timedelta | str | numbers.Number) – Amount of time to wait before drawing the next batch within an epoch. Can be given as a timedelta, a string parsable by coerce_timedelta (e.g. ‘1M’), or a number of seconds. Defaults to '5minutes'.
max_items (int) – Maximum number of items within this batch to draw in a single figure. Defaults to 2.
overlay_on_image (bool) – If True, overlay annotations on the image data for a more compact view. If False, draw annotations and images separately for a less cluttered view. Defaults to False.
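The interplay of num_draw and draw_interval can be sketched as a small scheduler: draw the first num_draw batches of each epoch, then at most one batch per interval. This is one plausible reading of the parameter descriptions, not the actual BatchPlotter logic. DrawSchedule and coerce_interval are hypothetical names, and coerce_interval is a simplified stand-in for coerce_timedelta that does not reproduce its full format handling (for example, how it reads ‘1M’).

```python
import datetime
import numbers
import re

# Simplified unit table for the hypothetical coerce_interval; the real
# coerce_timedelta may accept more formats and interpret units differently.
_UNIT_SECONDS = {
    's': 1, 'sec': 1, 'second': 1, 'seconds': 1,
    'm': 60, 'min': 60, 'minute': 60, 'minutes': 60,
    'h': 3600, 'hour': 3600, 'hours': 3600,
}


def coerce_interval(value):
    """Convert a timedelta, a number of seconds, or an '<n><unit>' string."""
    if isinstance(value, datetime.timedelta):
        return value
    if isinstance(value, numbers.Number):
        return datetime.timedelta(seconds=value)
    match = re.fullmatch(r'\s*(\d+(?:\.\d+)?)\s*([A-Za-z]+)\s*', value)
    if match is None or match.group(2) not in _UNIT_SECONDS:
        raise ValueError(f'cannot parse interval: {value!r}')
    amount = float(match.group(1))
    return datetime.timedelta(seconds=amount * _UNIT_SECONDS[match.group(2)])


class DrawSchedule:
    """Hypothetical scheduler: always draw the first ``num_draw`` batches
    of an epoch, afterwards draw at most once per ``draw_interval``."""

    def __init__(self, num_draw=2, draw_interval='5minutes'):
        self.num_draw = num_draw
        self.interval = coerce_interval(draw_interval).total_seconds()
        self.last_draw = None  # time (seconds) of the most recent draw

    def should_draw(self, batch_idx, now):
        # ``now`` is passed in (in seconds) so the logic is easy to test.
        if batch_idx < self.num_draw:
            do_draw = True
        else:
            do_draw = (self.last_draw is None or
                       now - self.last_draw >= self.interval)
        if do_draw:
            self.last_draw = now
        return do_draw


sched = DrawSchedule(num_draw=2, draw_interval=10)
times = [0, 1, 2, 5, 12, 13, 30]
decisions = [sched.should_draw(i, t) for i, t in enumerate(times)]
print(decisions)  # [True, True, False, False, True, False, True]
```

Injecting the clock keeps the rule deterministic; a real callback would presumably read time.monotonic() inside a batch-end hook instead.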
- FIXME:
[ ] This breaks when using strategy=DDP and multiple GPUs.
Todo
[ ] Doctest
Example
>>> from geowatch.utils.lightning_ext.callbacks.batch_plotter import *  # NOQA
>>> from geowatch.utils.lightning_ext import demo
>>> from geowatch.monkey import monkey_lightning
>>> import pathlib
>>> import pytorch_lightning as pl
>>> import ubelt as ub
>>> monkey_lightning.disable_lightning_hardware_warnings()
>>> model = demo.LightningToyNet2d(num_train=55)
>>> default_root_dir = ub.Path.appdir('lightning_ext/tests/BatchPlotter').ensuredir()
>>> trainer = pl.Trainer(callbacks=[BatchPlotter()],
>>>                      default_root_dir=default_root_dir,
>>>                      max_epochs=3, accelerator='cpu', devices=1)
>>> trainer.fit(model)
>>> train_dpath = pathlib.Path(trainer.log_dir)
>>> list((train_dpath / 'monitor').glob('*'))
>>> print('trainer.logger.log_dir = {!r}'.format(train_dpath))
References
- classmethod add_argparse_args(parent_parser)
Example
>>> from geowatch.utils.lightning_ext.callbacks.batch_plotter import *  # NOQA
>>> from geowatch.utils.configargparse_ext import ArgumentParser
>>> cls = BatchPlotter
>>> parent_parser = ArgumentParser(formatter_class='defaults')
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
>>> parent_parser.parse_known_args()
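The add_argparse_args pattern can be sketched with the standard library alone: a classmethod registers the callback's options as an argument group on a parent parser, and the parsed values are later passed to the constructor. ToyPlotterCallback and its flag names are hypothetical, chosen to match this class's constructor arguments; the real method builds on geowatch's configargparse_ext.ArgumentParser and may register different options. ArgumentDefaultsHelpFormatter is used here as a presumed stdlib analogue of formatter_class='defaults'.

```python
import argparse


class ToyPlotterCallback:
    """Hypothetical stand-in showing the add_argparse_args pattern."""

    def __init__(self, num_draw=2, draw_interval='5minutes', max_items=2,
                 overlay_on_image=False):
        self.num_draw = num_draw
        self.draw_interval = draw_interval
        self.max_items = max_items
        self.overlay_on_image = overlay_on_image

    @classmethod
    def add_argparse_args(cls, parent_parser):
        # Group this callback's options so they appear together in --help.
        group = parent_parser.add_argument_group(cls.__name__)
        group.add_argument('--num_draw', type=int, default=2)
        group.add_argument('--draw_interval', type=str, default='5minutes')
        group.add_argument('--max_items', type=int, default=2)
        group.add_argument('--overlay_on_image', action='store_true')
        return parent_parser


parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
ToyPlotterCallback.add_argparse_args(parser)
# parse_known_args leaves options meant for other components untouched.
args, unknown = parser.parse_known_args(['--num_draw', '4', '--verbose'])
callback = ToyPlotterCallback(
    num_draw=args.num_draw, draw_interval=args.draw_interval,
    max_items=args.max_items, overlay_on_image=args.overlay_on_image)
print(callback.num_draw, unknown)  # 4 ['--verbose']
```

Using parse_known_args, as the example above also does, lets several components share one command line without rejecting each other's flags.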