geowatch.utils.lightning_ext.callbacks package

Module contents

class geowatch.utils.lightning_ext.callbacks.AutoResumer

Bases: Callback

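Automatically resumes training from the most recent checkpoint.

Warning: this callback seems to be broken with newer Lightning versions. The example below skips itself if constructing an AutoResumer raises NotImplementedError.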

Example

>>> from geowatch.utils.lightning_ext.callbacks.auto_resumer import AutoResumer
>>> from kwutil import util_path
>>> from geowatch.utils.lightning_ext.demo import LightningToyNet2d
>>> from geowatch.utils.lightning_ext.callbacks import StateLogger
>>> import pytorch_lightning as pl
>>> import ubelt as ub
>>> from geowatch.monkey import monkey_lightning
>>> monkey_lightning.disable_lightning_hardware_warnings()
>>> default_root_dir = ub.Path.appdir('lightning_ext/test/auto_resume')
>>> default_root_dir.delete()
>>> #
>>> # STEP 1:
>>> # Test starting a model without any existing checkpoints
>>> import pytest
>>> try:
>>>     AutoResumer()
>>> except NotImplementedError:
>>>     pytest.skip()
>>> trainer_orig = pl.Trainer(default_root_dir=default_root_dir, callbacks=[AutoResumer(), StateLogger()], max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_orig.fit(model)
>>> assert len(list((ub.Path(trainer_orig.logger.log_dir) / 'checkpoints').glob('*'))) > 0
>>> # See contents written
>>> print(ub.urepr(list(util_path.tree(default_root_dir)), sort=0))
>>> #
>>> # CHECK 1:
>>> # Make a new trainer that should auto-resume
>>> self = AutoResumer()
>>> trainer = trainer_resume1 = pl.Trainer(default_root_dir=default_root_dir, callbacks=[self, StateLogger()], max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_resume1.fit(model)
>>> print(ub.urepr(list(util_path.tree(default_root_dir)), sort=0))
>>> # max_epochs should prevent auto-resume from doing anything
>>> assert len(list((ub.Path(trainer_resume1.logger.log_dir) / 'checkpoints').glob('*'))) == 0
>>> #
>>> # CHECK 2:
>>> # Increasing max epochs will let it train for longer
>>> trainer_resume2 = pl.Trainer(default_root_dir=default_root_dir, callbacks=[AutoResumer(), StateLogger()], max_epochs=3, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_resume2.fit(model)
>>> print(ub.urepr(list(util_path.tree(ub.Path(default_root_dir))), sort=0))
>>> # With the larger max_epochs, training resumes and writes new checkpoints
>>> assert len(list((ub.Path(trainer_resume2.logger.log_dir) / 'checkpoints').glob('*'))) > 0

Todo

  • [ ] Configure how to find which checkpoint to resume from

on_init_start(trainer: Trainer) → None
recent_checkpoints(train_dpath)

Return a list of existing checkpoints in a Trainer root directory.
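A minimal sketch of such a lookup, assuming checkpoints live in `<log_dir>/checkpoints` (the layout the example above checks) with `log_dir` following Lightning's default `lightning_logs/version_*` naming, and treating the newest modification time as "most recent". The helper `find_recent_checkpoints` is hypothetical; the real method may search and order candidates differently.

```python
import os
import tempfile
import time
from pathlib import Path


def find_recent_checkpoints(train_dpath):
    """
    Hypothetical helper: return checkpoint files under a Trainer root
    directory, newest first. Assumes the default logger layout
    <train_dpath>/lightning_logs/version_*/checkpoints/*.ckpt
    """
    root = Path(train_dpath)
    candidates = root.glob('lightning_logs/version_*/checkpoints/*.ckpt')
    # Newest modification time first, so index 0 is the resume candidate
    return sorted(candidates, key=lambda p: p.stat().st_mtime, reverse=True)


# Usage: fake two runs whose checkpoints were written at different times
with tempfile.TemporaryDirectory() as tmp:
    for version, age_seconds in [(0, 200), (1, 100)]:
        ckpt_dpath = Path(tmp) / 'lightning_logs' / f'version_{version}' / 'checkpoints'
        ckpt_dpath.mkdir(parents=True)
        fpath = ckpt_dpath / 'last.ckpt'
        fpath.touch()
        stamp = time.time() - age_seconds
        os.utime(fpath, (stamp, stamp))
    found_names = [p.parent.parent.name for p in find_recent_checkpoints(tmp)]

print(found_names)  # ['version_1', 'version_0']
```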

class geowatch.utils.lightning_ext.callbacks.BatchPlotter(num_draw=2, draw_interval='5minutes', max_items=2, overlay_on_image=False)

Bases: Callback

A callback that monitors training by periodically drawing batches.

To use it, the trainer's datamodule must have a draw_batch method that returns an ndarray visualizing a batch.

See [LightningCallbacks].
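The draw_batch requirement is a duck-typed interface. The sketch below shows one way to express and check it with `typing.Protocol`. It assumes, hypothetically, that draw_batch receives the batch plus keyword options such as max_items; the protocol name, the `'im'` batch key, and the toy datamodule are all invented for illustration.

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class SupportsDrawBatch(Protocol):
    """Hypothetical interface: only the method name comes from the docs."""

    def draw_batch(self, batch, **kwargs) -> np.ndarray:
        ...


class ToyDataModule:
    """Hypothetical datamodule whose batches hold float images in [0, 1]."""

    def draw_batch(self, batch, max_items=2, **kwargs):
        # Tile up to max_items (H, W) images side by side
        items = [np.asarray(img) for img in batch['im'][:max_items]]
        canvas = np.concatenate(items, axis=1)
        # Scale to 8-bit so the result can be written out as an image
        return (np.clip(canvas, 0, 1) * 255).astype(np.uint8)


datamodule = ToyDataModule()
# isinstance() only checks that draw_batch exists, not its signature
assert isinstance(datamodule, SupportsDrawBatch)

rng = np.random.default_rng(0)
batch = {'im': rng.random((4, 8, 8))}
canvas = datamodule.draw_batch(batch, max_items=2)
print(canvas.shape, canvas.dtype)  # (8, 16) uint8
```

Because `runtime_checkable` only verifies that the attribute exists, a datamodule with a mismatched draw_batch signature would still pass the check and fail later at call time.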

Parameters:
  • num_draw (int) – number of batches to draw at the start of each epoch

  • draw_interval (datetime.timedelta | str | numbers.Number) – The amount of time to wait before drawing the next batch within an epoch. Can be given as a timedelta, a string parsable by coerce_timedelta (e.g. ‘1M’), or a number of seconds.

  • max_items (int) – Maximum number of items within this batch to draw in a single figure. Defaults to 2.

  • overlay_on_image (bool) – If True, overlay annotations on the image data for a more compact view. If False, draw annotations and images separately for a less cluttered view.
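How num_draw and draw_interval could interact is sketched below: draw the first num_draw batches of each epoch, then at most one batch per interval. This is a hedged illustration, not BatchPlotter's draw_if_ready. `DrawSchedule` and `parse_interval` are hypothetical, and `parse_interval` is a deliberately simplified stand-in that does not reproduce coerce_timedelta's grammar (it does not handle the ‘1M’ form, for example).

```python
import datetime
import re
import time

# Simplified, hypothetical stand-in for coerce_timedelta. It accepts
# timedeltas, plain numbers (seconds), and strings such as '30s',
# '5minutes', or '2h'. It does NOT reproduce coerce_timedelta's full grammar.
_UNIT_SECONDS = {
    's': 1, 'sec': 1, 'second': 1, 'seconds': 1,
    'min': 60, 'minute': 60, 'minutes': 60,
    'h': 3600, 'hour': 3600, 'hours': 3600,
}


def parse_interval(value):
    if isinstance(value, datetime.timedelta):
        return value
    if isinstance(value, (int, float)):
        return datetime.timedelta(seconds=value)
    match = re.fullmatch(r'\s*(\d+(?:\.\d+)?)\s*([a-zA-Z]+)\s*', value)
    if match is None or match.group(2).lower() not in _UNIT_SECONDS:
        raise ValueError(f'Cannot parse interval: {value!r}')
    number, unit = match.groups()
    return datetime.timedelta(seconds=float(number) * _UNIT_SECONDS[unit.lower()])


class DrawSchedule:
    """
    Hypothetical scheduler: draw the first ``num_draw`` batches of every
    epoch, then at most one batch per ``draw_interval``.
    """

    def __init__(self, num_draw=2, draw_interval='5minutes', clock=time.monotonic):
        self.num_draw = num_draw
        self.interval = parse_interval(draw_interval).total_seconds()
        self.clock = clock
        self.start_epoch()

    def start_epoch(self):
        # Reset the per-epoch counter and start the interval timer
        self.num_drawn = 0
        self.last_draw = self.clock()

    def ready(self):
        now = self.clock()
        if self.num_drawn < self.num_draw:
            due = True
        else:
            due = (now - self.last_draw) >= self.interval
        if due:
            self.num_drawn += 1
            self.last_draw = now
        return due


# Usage with a fake clock (seconds) so the decisions are deterministic
fake_now = [0.0]
schedule = DrawSchedule(num_draw=2, draw_interval='30s', clock=lambda: fake_now[0])
decisions = []
for t in [0, 1, 2, 10, 33, 40, 70]:
    fake_now[0] = t
    decisions.append(schedule.ready())
print(decisions)  # [True, True, False, False, True, False, True]
```

Injecting the clock keeps the throttling logic testable without sleeping; a real callback would call start_epoch from an epoch-start hook.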

FIXME:
  • [ ] This breaks when using strategy=DDP with multiple GPUs

Todo

  • [ ] Doctest

Example

>>> #
>>> from geowatch.utils.lightning_ext.callbacks.batch_plotter import *  # NOQA
>>> import pytorch_lightning as pl
>>> import ubelt as ub
>>> from geowatch.utils.lightning_ext import demo
>>> from geowatch.monkey import monkey_lightning
>>> monkey_lightning.disable_lightning_hardware_warnings()
>>> model = demo.LightningToyNet2d(num_train=55)
>>> default_root_dir = ub.Path.appdir('lightning_ext/tests/BatchPlotter').ensuredir()
>>> #
>>> trainer = pl.Trainer(callbacks=[BatchPlotter()],
>>>                      default_root_dir=default_root_dir,
>>>                      max_epochs=3, accelerator='cpu', devices=1)
>>> trainer.fit(model)
>>> import pathlib
>>> train_dpath = pathlib.Path(trainer.log_dir)
>>> list((train_dpath / 'monitor').glob('*'))
>>> print('train_dpath = {!r}'.format(train_dpath))

References

setup(trainer, pl_module, stage)
classmethod add_argparse_args(parent_parser)

Example

>>> from geowatch.utils.lightning_ext.callbacks.batch_plotter import *  # NOQA
>>> from geowatch.utils.configargparse_ext import ArgumentParser
>>> cls = BatchPlotter
>>> parent_parser = ArgumentParser(formatter_class='defaults')
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
>>> parent_parser.parse_known_args()
draw_batch(trainer, outputs, batch, batch_idx)
draw_if_ready(trainer, pl_module, outputs, batch, batch_idx)
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
class geowatch.utils.lightning_ext.callbacks.Packager(package_fpath='auto')

Bases: Callback

Packages the best checkpoint at the end of training and at various other key phases of the training loop.

The LightningModule must have a save_package method that can be called with a file path.

Todo

  • [ ] Package the “best” checkpoints according to the monitor

  • [ ] Package arbitrary checkpoints:
    • [ ] Package model topology without any weights

    • [ ] Copy checkpoint weights into a package to get a package with that “weight state”.

  • [ ] Initializer should be able to point at a package and use torch-liberator partial load to transfer the weights.

  • [ ] Replace print statements with logging statements

  • [ ] Create a trainer-level logger instance (similar to netharn)

  • [ ] What is the right way to handle running eval after fit?

    There may be multiple candidate models that need to be tested, so we can’t just specify one package, one prediction dumping ground, and one evaluation dataset. Maybe we should specify the paths where the “best” ones are written?

Parameters:

package_fpath (PathLike) – Path where a torch-packaged model will be written (or symlinked). Defaults to 'auto'.
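Because the package may be symlinked rather than copied, one hedged way to (re)point a stable path such as final_package.pt at the newest package is an atomic symlink swap. This sketch assumes a POSIX filesystem (creating symlinks on Windows may require extra privileges), and `link_package` is a hypothetical helper, not part of Packager.

```python
import os
import tempfile
from pathlib import Path


def link_package(package_fpath, link_fpath):
    """
    Hypothetical helper: make ``link_fpath`` a symlink to ``package_fpath``,
    replacing any previous link without a moment where it is missing.
    """
    link_fpath = Path(link_fpath)
    tmp_link = link_fpath.with_name(link_fpath.name + '.tmp')
    if tmp_link.is_symlink() or tmp_link.exists():
        tmp_link.unlink()
    # Build the new link under a temporary name, then rename it over the old
    # one; os.replace is atomic on POSIX and does not follow the old symlink.
    os.symlink(os.path.abspath(package_fpath), tmp_link)
    os.replace(tmp_link, link_fpath)
    return link_fpath


# Usage: point the same link at two successive packages
with tempfile.TemporaryDirectory() as tmp:
    tmp_dpath = Path(tmp)
    for name in ['epoch1.pt', 'epoch2.pt']:
        (tmp_dpath / name).write_text(name)
    link = link_package(tmp_dpath / 'epoch1.pt', tmp_dpath / 'final_package.pt')
    link_package(tmp_dpath / 'epoch2.pt', tmp_dpath / 'final_package.pt')
    final_text = link.read_text()

print(final_text)  # epoch2.pt
```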

References

https://discuss.pytorch.org/t/packaging-pytorch-topology-first-and-checkpoints-later/129478/2

Example

>>> from geowatch.utils.lightning_ext.callbacks.packager import *  # NOQA
>>> from geowatch.utils.lightning_ext.demo import LightningToyNet2d
>>> from geowatch.utils.lightning_ext.callbacks import StateLogger
>>> import ubelt as ub
>>> import pytorch_lightning as pl
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> disable_lightning_hardware_warnings()
>>> default_root_dir = ub.Path.appdir('lightning_ext/test/packager')
>>> default_root_dir.delete().ensuredir()
>>> # Test starting a model without any existing checkpoints
>>> trainer = pl.Trainer(default_root_dir=default_root_dir, callbacks=[
>>>     Packager(default_root_dir / 'final_package.pt'),
>>>     StateLogger()
>>> ], max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer.fit(model)
classmethod add_argparse_args(parent_parser)

Example

>>> from geowatch.utils.lightning_ext.callbacks.packager import *  # NOQA
>>> from geowatch.utils.configargparse_ext import ArgumentParser
>>> cls = Packager
>>> parent_parser = ArgumentParser(formatter_class='defaults')
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
>>> assert parent_parser.parse_known_args(None)[0].package_fpath == 'auto'
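The same add_argparse_args pattern can be reproduced with the standard library alone. In this sketch, `argparse.ArgumentDefaultsHelpFormatter` stands in for the project's `configargparse_ext.ArgumentParser(formatter_class='defaults')`, `PackagerArgs` is a hypothetical class, and the flag spelling `--package_fpath` is an assumption (the doctest above only confirms the parsed attribute name and its 'auto' default).

```python
import argparse


class PackagerArgs:
    """Hypothetical stand-in that registers Packager-style options."""

    @classmethod
    def add_argparse_args(cls, parent_parser):
        # Group the options under the class name in --help output
        group = parent_parser.add_argument_group(cls.__name__)
        group.add_argument(
            '--package_fpath', default='auto',
            help='Where the packaged model is written (or symlinked)')
        return parent_parser


parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
PackagerArgs.add_argparse_args(parser)

# An empty argument list is used instead of sys.argv to keep this deterministic
default_args, _ = parser.parse_known_args([])
print(default_args.package_fpath)  # auto

# Unrecognized options are returned separately instead of raising an error
custom_args, unknown = parser.parse_known_args(
    ['--package_fpath', 'model.pt', '--max_epochs=3'])
print(custom_args.package_fpath, unknown)  # model.pt ['--max_epochs=3']
```

Using parse_known_args lets several callbacks register their own groups on one shared parser without each one rejecting the others' flags.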
setup(trainer, pl_module, stage=None)

Finalize initialization step. Resolve the paths where files will be written.

Parameters:
  • trainer (pl.Trainer)

  • pl_module (pl.LightningModule)

  • stage (str | None)

Returns:

None

_after_initialization(trainer)
on_fit_start(trainer: Trainer, pl_module: LightningModule) → None

Todo

  • [ ] Write out the uninitialized topology

on_fit_end(trainer: Trainer, pl_module: LightningModule) → None

Create the final package (or a list of candidate packages) for evaluation and deployment.

Todo

  • [ ] How do we properly package all of the candidate checkpoints?

  • [ ] Symlink to the “BEST” package at the end.

  • [ ] Write a script so that any checkpoint can be packaged.

on_exception(trainer: Trainer, pl_module: LightningModule, *args, **kw) → None

Saving a package on keyboard interrupt is useful for manual early stopping.
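The pattern behind this hook can be sketched outside of Lightning as a try/except around the training loop that saves the current state and then re-raises, so the interrupt still stops the run. `run_training` and its arguments are hypothetical; in Packager the saving happens inside on_exception instead.

```python
def run_training(num_steps, train_step, save_package):
    """
    Hypothetical training loop: if the user presses Ctrl-C, save a package of
    the current state before re-raising, so a run can be stopped early by hand.
    """
    state = {'step': 0}
    try:
        for step in range(num_steps):
            train_step(state)
            state['step'] = step + 1
    except KeyboardInterrupt:
        # Save a copy of the state as it is right now, then let the
        # interrupt propagate so the process still stops.
        save_package(dict(state))
        raise
    return state


saved = []


def fake_step(state):
    # Simulate the user interrupting during the fourth step
    if state['step'] == 3:
        raise KeyboardInterrupt


try:
    run_training(10, fake_step, saved.append)
except KeyboardInterrupt:
    pass

print(saved)  # [{'step': 3}]
```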

Todo

  • [X] Package current model state

  • [ ] Package “best” model state

_make_package_fpath(trainer, dname='packages')
_save_package(model, package_fpath)
class geowatch.utils.lightning_ext.callbacks.StateLogger

Bases: Callback

Prints out which callback hooks are being called.

Deprecated: use TextLogger instead.
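StateLogger overrides each of the hooks listed below individually. As an illustration of the same idea, here is a hypothetical class decorator that wraps every `on_*` hook (plus setup and teardown) so that each call is printed and recorded. `ToyCallback` is a plain class standing in for a Lightning Callback.

```python
import functools


def _make_wrapper(name, original):
    @functools.wraps(original)
    def wrapper(self, *args, **kwargs):
        # Report the hook before delegating to the original implementation
        print(f'Called {name}')
        self.called.append(name)
        return original(self, *args, **kwargs)
    return wrapper


def trace_hooks(cls):
    """Hypothetical class decorator that traces hook methods defined on ``cls``."""
    hooks = [
        (name, value) for name, value in vars(cls).items()
        if callable(value) and (name.startswith('on_') or name in {'setup', 'teardown'})
    ]
    for name, original in hooks:
        setattr(cls, name, _make_wrapper(name, original))
    return cls


@trace_hooks
class ToyCallback:
    def __init__(self):
        self.called = []

    def setup(self, trainer, pl_module, stage=None):
        pass

    def on_fit_start(self, trainer, pl_module):
        pass

    def on_fit_end(self, trainer, pl_module):
        pass

    def teardown(self, trainer, pl_module, stage=None):
        pass


callback = ToyCallback()
callback.setup(None, None, stage='fit')
callback.on_fit_start(None, None)
callback.on_fit_end(None, None)
callback.teardown(None, None, stage='fit')
print(callback.called)  # ['setup', 'on_fit_start', 'on_fit_end', 'teardown']
```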

setup(trainer: Trainer, pl_module: LightningModule, stage: str | None = None) → None
teardown(trainer: Trainer, pl_module: LightningModule, stage: str | None = None) → None
on_fit_start(trainer: Trainer, pl_module: LightningModule) → None
on_fit_end(trainer: Trainer, pl_module: LightningModule) → None
on_save_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) → dict
on_sanity_check_start(trainer: Trainer, pl_module: LightningModule) → None
on_sanity_check_end(trainer: Trainer, pl_module: LightningModule) → None
on_exception(trainer: Trainer, pl_module: LightningModule, *args, **kw) → None
class geowatch.utils.lightning_ext.callbacks.TensorboardPlotter

Bases: Callback

Asynchronously dumps PNGs to disk that visualize the TensorBoard scalars.
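A hedged sketch of the asynchronous part: the dumping work is handed to a background worker so the training loop does not wait on disk I/O. To stay dependency-free it writes CSV files where the real callback renders PNGs, and it uses a single thread, which may not be the mechanism TensorboardPlotter actually uses. The `serial` flag only loosely mirrors the `serial` argument of `_on_epoch_end`; `AsyncDumper` and `dump_scalars` are hypothetical.

```python
import csv
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


def dump_scalars(scalars, out_dpath):
    """Write one CSV per scalar series (a stand-in for rendering a PNG plot)."""
    out_dpath = Path(out_dpath)
    out_dpath.mkdir(parents=True, exist_ok=True)
    for key, values in scalars.items():
        # Tags such as 'val/loss' contain slashes, which are not valid in file names
        fpath = out_dpath / (key.replace('/', '_') + '.csv')
        with open(fpath, 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['step', key])
            writer.writerows(enumerate(values))
    return sorted(p.name for p in out_dpath.glob('*.csv'))


class AsyncDumper:
    """Hypothetical: run dumps on one background thread unless serial=True."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)

    def on_epoch_end(self, scalars, out_dpath, serial=False):
        if serial:
            # Block the caller and return the written file names directly
            return dump_scalars(scalars, out_dpath)
        # Copy the data so the training loop can keep appending to its own lists
        snapshot = {key: list(values) for key, values in scalars.items()}
        return self._pool.submit(dump_scalars, snapshot, out_dpath)

    def close(self):
        # Wait for any pending dump before shutting the worker down
        self._pool.shutdown(wait=True)


with tempfile.TemporaryDirectory() as tmp:
    dumper = AsyncDumper()
    future = dumper.on_epoch_end({'train_loss': [0.9, 0.5], 'val/loss': [1.0, 0.7]}, tmp)
    # Training could continue here; result() waits for the background dump
    written = future.result()
    dumper.close()

print(written)  # ['train_loss.csv', 'val_loss.csv']
```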

CommandLine

xdoctest -m geowatch.utils.lightning_ext.callbacks.tensorboard_plotter TensorboardPlotter

Example

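>>> # xdoctest: +REQUIRES(module:tensorboard)
>>> from geowatch.utils.lightning_ext.callbacks.tensorboard_plotter import *  # NOQA
>>> from geowatch.utils.lightning_ext import demo
>>> import pandas as pd
>>> import pytorch_lightning as pl
>>> import ubelt as ub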
>>> from geowatch.monkey import monkey_lightning
>>> monkey_lightning.disable_lightning_hardware_warnings()
>>> self = demo.LightningToyNet2d(num_train=55)
>>> default_root_dir = ub.Path.appdir('lightning_ext/tests/TensorboardPlotter').ensuredir()
>>> #
>>> trainer = pl.Trainer(callbacks=[TensorboardPlotter()],
>>>                      default_root_dir=default_root_dir,
>>>                      max_epochs=3, accelerator='cpu', devices=1)
>>> trainer.fit(self)
>>> train_dpath = trainer.logger.log_dir
>>> print('trainer.logger.log_dir = {!r}'.format(train_dpath))
>>> data = read_tensorboard_scalars(train_dpath)
>>> for key in data.keys():
>>>     d = data[key]
>>>     df = pd.DataFrame({key: d['ydata'], 'step': d['xdata'], 'wall': d['wall']})
>>>     print(df)
_on_epoch_end(trainer, logs=None, serial=False)
on_train_epoch_end(trainer, logs=None)
on_validation_epoch_end(trainer, logs=None)
on_test_epoch_end(trainer, logs=None)
class geowatch.utils.lightning_ext.callbacks.TextLogger(args=None)

Bases: Callback

Writes logging information to text files.
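The core idea, a file-backed log that callback hooks append to, can be sketched with the standard logging module. The helper `make_text_logger`, the logger name, and the line format are assumptions for illustration and are not TextLogger's actual output format.

```python
import logging
import tempfile
from pathlib import Path


def make_text_logger(log_fpath, name='text_logger_sketch'):
    """
    Hypothetical helper: return a logger (and its handler) that appends
    timestamped lines to ``log_fpath``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # Do not forward records to the root logger (avoids duplicate console output)
    logger.propagate = False
    handler = logging.FileHandler(log_fpath, encoding='utf-8')
    handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
    logger.addHandler(handler)
    return logger, handler


with tempfile.TemporaryDirectory() as tmp:
    log_fpath = Path(tmp) / 'train_log.txt'
    logger, handler = make_text_logger(log_fpath)
    # A callback would emit lines like these from its hooks
    logger.info('on_train_epoch_start: epoch=0')
    logger.info('on_train_epoch_end: epoch=0')
    # Detach and close the handler so the file is flushed and released
    logger.removeHandler(handler)
    handler.close()
    lines = log_fpath.read_text(encoding='utf-8').splitlines()

print(len(lines))  # 2
```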

Example

>>> #
>>> from geowatch.utils.lightning_ext.callbacks.text_logger import *  # NOQA
>>> import pytorch_lightning as pl
>>> import ubelt as ub
>>> from geowatch.utils.lightning_ext import demo
>>> from geowatch.monkey import monkey_lightning
>>> monkey_lightning.disable_lightning_hardware_warnings()
>>> self = demo.LightningToyNet2d(num_train=55)
>>> default_root_dir = ub.Path.appdir('lightning_ext/tests/TextLogger').ensuredir()
>>> #
>>> trainer = pl.Trainer(callbacks=[TextLogger()],
>>>                      default_root_dir=default_root_dir,
>>>                      max_epochs=3, accelerator='cpu', devices=1)
>>> trainer.fit(self)
>>> text_logs = ub.Path(trainer.text_logger.log_fpath).read_text()
>>> print(text_logs)
setup(trainer: Trainer, pl_module: LightningModule, stage: str | None = None) → None
teardown(trainer: Trainer, pl_module: LightningModule, stage: str | None = None) → None
on_fit_start(trainer: Trainer, pl_module: LightningModule) → None
on_fit_end(trainer: Trainer, pl_module: LightningModule) → None
state_dict()
load_state_dict(checkpoint)
on_train_start(trainer: Trainer, pl_module: LightningModule) → None
on_train_end(trainer: Trainer, pl_module: LightningModule) → None
on_sanity_check_start(trainer: Trainer, pl_module: LightningModule) → None
on_sanity_check_end(trainer: Trainer, pl_module: LightningModule) → None
on_exception(trainer: Trainer, pl_module: LightningModule, *args, **kw) → None
on_train_epoch_start(trainer: Trainer, pl_module: LightningModule) → None
on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) → None
on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) → None
on_validation_epoch_start(trainer: Trainer, pl_module: LightningModule) → None
geowatch.utils.lightning_ext.callbacks.default_save_package(model, package_path, verbose=1)
class geowatch.utils.lightning_ext.callbacks.LightningTelemetry

Bases: Callback

Wraps a fit job with a ProcessContext to record telemetry about the run.
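The ProcessContext API is not shown in this section, so the following is only a hypothetical stand-in: a context manager that records a few facts about the job and writes them to JSON when the block exits, including on failure. In a callback, entering and exiting would roughly correspond to on_fit_start and on_fit_end or on_exception, with the JSON write playing the role of `_dump`. The function name and all field names are assumptions.

```python
import json
import os
import platform
import socket
import sys
import tempfile
import time
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def process_telemetry(dump_fpath, name='fit'):
    """
    Hypothetical stand-in for a process context: record basic facts about
    the job, then write them to JSON when the block exits (even on error).
    """
    info = {
        'name': name,
        'argv': list(sys.argv),
        'pid': os.getpid(),
        'hostname': socket.gethostname(),
        'python': platform.python_version(),
        'start_time': time.time(),
    }
    try:
        yield info
        info['status'] = 'completed'
    except BaseException as ex:
        # Record the failure (including KeyboardInterrupt) and re-raise it
        info['status'] = 'failed'
        info['error'] = repr(ex)
        raise
    finally:
        info['stop_time'] = time.time()
        info['duration'] = info['stop_time'] - info['start_time']
        Path(dump_fpath).write_text(json.dumps(info, indent=2))


with tempfile.TemporaryDirectory() as tmp:
    dump_fpath = Path(tmp) / 'telemetry.json'
    with process_telemetry(dump_fpath) as info:
        info['max_epochs'] = 2  # anything the job wants to record
    record = json.loads(dump_fpath.read_text())

print(record['status'], record['max_epochs'])  # completed 2
```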

Example

>>> from geowatch.utils.lightning_ext.callbacks.telemetry import *  # NOQA
>>> from geowatch.utils.lightning_ext.demo import LightningToyNet2d
>>> from geowatch.utils.lightning_ext.callbacks import StateLogger
>>> import ubelt as ub
>>> import pytorch_lightning as pl
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> disable_lightning_hardware_warnings()
>>> default_root_dir = ub.Path.appdir('lightning_ext/test/telemetry')
>>> default_root_dir.delete().ensuredir()
>>> self = LightningTelemetry()
>>> # Test starting a model without any existing checkpoints
>>> trainer = pl.Trainer(default_root_dir=default_root_dir, callbacks=[
>>>     self,
>>>     StateLogger()
>>> ], max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer.fit(model)
classmethod add_argparse_args(parent_parser)

Example

>>> from geowatch.utils.lightning_ext.callbacks.telemetry import *  # NOQA
>>> from geowatch.utils.configargparse_ext import ArgumentParser
>>> cls = LightningTelemetry
>>> parent_parser = ArgumentParser(formatter_class='defaults')
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
setup(trainer, pl_module, stage=None)

Finalize initialization step. Resolve the paths where files will be written.

Parameters:
  • trainer (pl.Trainer)

  • pl_module (pl.LightningModule)

  • stage (str | None)

Returns:

None

_after_initialization(trainer)
on_fit_start(trainer: Trainer, pl_module: LightningModule) → None

Todo

  • [ ] Write out the uninitialized topology

on_fit_end(trainer: Trainer, pl_module: LightningModule) → None
on_exception(trainer: Trainer, pl_module: LightningModule, *args, **kw) → None
_dump(trainer)