geowatch.utils.lightning_ext.callbacks.batch_plotter module

class geowatch.utils.lightning_ext.callbacks.batch_plotter.BatchPlotter(num_draw=2, draw_interval='5minutes', max_items=2, overlay_on_image=False)

Bases: Callback

A callback that monitors training by periodically drawing batches.

To use it, the trainer's datamodule must have a draw_batch method that returns an ndarray visualizing a batch.
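As a rough illustration of that requirement, the sketch below shows a datamodule-like class whose draw_batch method returns an ndarray canvas. The class name ToyDataModule, the batch layout (a list of HxWx3 float images in [0, 1]), and the keyword arguments are all assumptions made for this example; the arguments that BatchPlotter actually passes are not stated here.

```python
import numpy as np


class ToyDataModule:
    """Hypothetical stand-in for a datamodule; not part of geowatch."""

    def draw_batch(self, batch, outputs=None, max_items=2, **kwargs):
        # Assumed batch layout: a list of HxWx3 float images in [0, 1].
        items = batch[:max_items]
        # Place the selected items side by side on a single canvas.
        canvas = np.concatenate(items, axis=1)
        # Convert to uint8 so the result can be written as an image.
        return (canvas * 255).clip(0, 255).astype(np.uint8)


batch = [np.zeros((4, 6, 3)), np.ones((4, 6, 3)), np.ones((4, 6, 3))]
canvas = ToyDataModule().draw_batch(batch, max_items=2)
```

Only the first max_items entries are drawn, which mirrors the documented intent of the max_items parameter.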

See [LightningCallbacks].

Parameters:
  • num_draw (int) – Number of batches to draw at the start of each epoch.

  • draw_interval (datetime.timedelta | str | numbers.Number) – The amount of time to wait before drawing the next batch item within an epoch. Can be given as a timedelta, a string parsable by coerce_timedelta (e.g. ‘1M’), or a number of seconds. Defaults to '5minutes'.

  • max_items (int) – Maximum number of items within this batch to draw in a single figure. Defaults to 2.

  • overlay_on_image (bool) – If True, overlay annotations on the image data for a more compact view. If False, draw annotations and images separately for a less cluttered view.
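To make the three accepted forms of draw_interval concrete, here is a minimal sketch of normalizing them to a datetime.timedelta. The function coerce_interval and its small unit table are hypothetical and exist only for this illustration; the real coerce_timedelta helper is a separate function and accepts other spellings (such as ‘1M’) that this sketch rejects.

```python
import datetime
import numbers
import re

# Hypothetical unit table for this sketch only; the real coerce_timedelta
# helper understands a different set of spellings.
_UNIT_SECONDS = {'seconds': 1, 'minutes': 60, 'hours': 3600}


def coerce_interval(value):
    """Return a timedelta from a timedelta, a number of seconds, or a simple string."""
    if isinstance(value, datetime.timedelta):
        return value
    if isinstance(value, numbers.Number):
        # Bare numbers are interpreted as seconds.
        return datetime.timedelta(seconds=float(value))
    match = re.fullmatch(r'\s*(\d+(?:\.\d+)?)\s*([a-z]+)\s*', value)
    if match is None or match.group(2) not in _UNIT_SECONDS:
        raise ValueError(f'unsupported interval: {value!r}')
    magnitude, unit = match.groups()
    return datetime.timedelta(seconds=float(magnitude) * _UNIT_SECONDS[unit])


default_interval = coerce_interval('5minutes')
```

Normalizing once, up front, means later comparisons only ever deal with timedelta values.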

FIXME:
  • [ ] This breaks when using strategy=DDP and multiple gpus

Todo

  • [ ] Doctest

Example

>>> from geowatch.utils.lightning_ext.callbacks.batch_plotter import *  # NOQA
>>> import pytorch_lightning as pl
>>> import ubelt as ub
>>> from geowatch.utils.lightning_ext import demo
>>> from geowatch.monkey import monkey_lightning
>>> monkey_lightning.disable_lightning_hardware_warnings()
>>> model = demo.LightningToyNet2d(num_train=55)
>>> default_root_dir = ub.Path.appdir('lightning_ext/tests/BatchPlotter').ensuredir()
>>> #
>>> trainer = pl.Trainer(callbacks=[BatchPlotter()],
>>>                      default_root_dir=default_root_dir,
>>>                      max_epochs=3, accelerator='cpu', devices=1)
>>> trainer.fit(model)
>>> import pathlib
>>> train_dpath = pathlib.Path(trainer.log_dir)
>>> list((train_dpath / 'monitor').glob('*'))
>>> print('trainer.logger.log_dir = {!r}'.format(train_dpath))

References

setup(trainer, pl_module, stage)

classmethod add_argparse_args(parent_parser)

Example

>>> from geowatch.utils.lightning_ext.callbacks.batch_plotter import *  # NOQA
>>> from geowatch.utils.configargparse_ext import ArgumentParser
>>> cls = BatchPlotter
>>> parent_parser = ArgumentParser(formatter_class='defaults')
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
>>> parent_parser.parse_known_args()

draw_batch(trainer, outputs, batch, batch_idx)

draw_if_ready(trainer, pl_module, outputs, batch, batch_idx)

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
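The batch-end hooks above all receive a batch index, and the parameter descriptions suggest a time-throttled drawing policy: draw the first num_draw batches of an epoch, then draw again only after draw_interval has elapsed. The sketch below illustrates that policy in isolation. The DrawThrottle class is hypothetical and is an assumption about the behavior, not the actual draw_if_ready implementation.

```python
import datetime


class DrawThrottle:
    """Hypothetical helper that decides when a batch should be drawn."""

    def __init__(self, num_draw=2, draw_interval=datetime.timedelta(minutes=5)):
        self.num_draw = num_draw
        self.draw_interval = draw_interval
        self.last_draw = None

    def ready(self, batch_idx, now):
        if batch_idx < self.num_draw:
            # Always draw the first num_draw batches of an epoch.
            due = True
        else:
            # Otherwise draw only if enough time has passed since the last draw.
            due = (self.last_draw is None
                   or (now - self.last_draw) >= self.draw_interval)
        if due:
            self.last_draw = now
        return due


t0 = datetime.datetime(2024, 1, 1)
throttle = DrawThrottle(num_draw=2, draw_interval=datetime.timedelta(minutes=5))
decisions = [
    throttle.ready(0, t0),
    throttle.ready(1, t0 + datetime.timedelta(seconds=1)),
    throttle.ready(2, t0 + datetime.timedelta(seconds=2)),
    throttle.ready(3, t0 + datetime.timedelta(minutes=5, seconds=1)),
]
```

Passing the current time in as an argument, rather than reading the clock inside ready, keeps the policy easy to test.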