geowatch.utils.lightning_ext.callbacks package¶
Submodules¶
- geowatch.utils.lightning_ext.callbacks.auto_resumer module
- geowatch.utils.lightning_ext.callbacks.batch_plotter module
- geowatch.utils.lightning_ext.callbacks.packager module
- geowatch.utils.lightning_ext.callbacks.state_logger module
- geowatch.utils.lightning_ext.callbacks.telemetry module
- geowatch.utils.lightning_ext.callbacks.tensorboard_plotter module
- geowatch.utils.lightning_ext.callbacks.text_logger module
TextLogger
TextLogger.setup()
TextLogger.teardown()
TextLogger.on_fit_start()
TextLogger.on_fit_end()
TextLogger.state_dict()
TextLogger.load_state_dict()
TextLogger.on_train_start()
TextLogger.on_train_end()
TextLogger.on_sanity_check_start()
TextLogger.on_sanity_check_end()
TextLogger.on_exception()
TextLogger.on_train_epoch_start()
TextLogger.on_train_epoch_end()
TextLogger.on_validation_epoch_end()
TextLogger.on_validation_epoch_start()
_InstanceLogger
_strip_ansi()
Module contents¶
- class geowatch.utils.lightning_ext.callbacks.AutoResumer[source]¶
Bases:
Callback
Auto-resumes from the most recent checkpoint
THIS SEEMS TO BE BROKEN WITH NEW LIGHTNING VERSIONS
Example
>>> from geowatch.utils.lightning_ext.callbacks.auto_resumer import AutoResumer
>>> from kwutil import util_path
>>> from geowatch.utils.lightning_ext.demo import LightningToyNet2d
>>> from geowatch.utils.lightning_ext.callbacks import StateLogger
>>> import pytorch_lightning as pl
>>> import ubelt as ub
>>> from geowatch.monkey import monkey_lightning
>>> monkey_lightning.disable_lightning_hardware_warnings()
>>> default_root_dir = ub.Path.appdir('lightning_ext/test/auto_resume')
>>> default_root_dir.delete()
>>> #
>>> # STEP 1:
>>> # Test starting a model without any existing checkpoints
>>> import pytest
>>> try:
>>>     AutoResumer()
>>> except NotImplementedError:
>>>     pytest.skip()
>>> trainer_orig = pl.Trainer(default_root_dir=default_root_dir, callbacks=[AutoResumer(), StateLogger()], max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_orig.fit(model)
>>> assert len(list((ub.Path(trainer_orig.logger.log_dir) / 'checkpoints').glob('*'))) > 0
>>> # See contents written
>>> print(ub.urepr(list(util_path.tree(default_root_dir)), sort=0))
>>> #
>>> # CHECK 1:
>>> # Make a new trainer that should auto-resume
>>> self = AutoResumer()
>>> trainer = trainer_resume1 = pl.Trainer(default_root_dir=default_root_dir, callbacks=[self, StateLogger()], max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_resume1.fit(model)
>>> print(ub.urepr(list(util_path.tree(default_root_dir)), sort=0))
>>> # max_epochs should prevent auto-resume from doing anything
>>> assert len(list((ub.Path(trainer_resume1.logger.log_dir) / 'checkpoints').glob('*'))) == 0
>>> #
>>> # CHECK 2:
>>> # Increasing max epochs will let it train for longer
>>> trainer_resume2 = pl.Trainer(default_root_dir=default_root_dir, callbacks=[AutoResumer(), StateLogger()], max_epochs=3, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_resume2.fit(model)
>>> print(ub.urepr(list(util_path.tree(ub.Path(default_root_dir))), sort=0))
>>> # The increased max_epochs lets the resumed run write new checkpoints
>>> assert len(list((ub.Path(trainer_resume2.logger.log_dir) / 'checkpoints').glob('*'))) > 0
Todo
[ ] Configure how to find which checkpoint to resume from
- class geowatch.utils.lightning_ext.callbacks.BatchPlotter(num_draw=2, draw_interval='5minutes', max_items=2, overlay_on_image=False)[source]¶
Bases:
Callback
A callback used to monitor training by periodically drawing visualizations of batches.
To be used, the trainer's datamodule must have a draw_batch method that returns an ndarray visualizing a batch.
See [LightningCallbacks].
- Parameters:
num_draw (int) – number of batches to draw at the start of each epoch
draw_interval (datetime.timedelta | str | numbers.Number) – This is the amount of time to wait before drawing the next batch item within an epoch. Can be given as a timedelta, a string parsable by coerce_timedelta (e.g. ‘1M’) or a numeric number of seconds.
max_items (int) – Maximum number of items within this batch to draw in a single figure. Defaults to 2.
overlay_on_image (bool) – if True, overlay annotations on the image data for a more compact view; if False, draw annotations and images separately for a less cluttered view.
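The draw_batch contract described above can be sketched as follows. Everything here is hypothetical except the requirement itself: the datamodule class name, item shapes, and rendering logic are stand-ins; the only documented constraint is that draw_batch returns an ndarray visualizing a batch.

```python
import numpy as np


class ToyDataModule:
    """Hypothetical datamodule exposing the draw_batch interface."""

    def draw_batch(self, batch, max_items=2):
        """Render up to ``max_items`` batch items as a single ndarray canvas."""
        images = []
        for item in batch[:max_items]:
            # Normalize each item into a uint8 RGB tile (stand-in rendering logic)
            arr = np.asarray(item, dtype=np.float32)
            arr = (255 * (arr - arr.min()) / (np.ptp(arr) + 1e-8)).astype(np.uint8)
            if arr.ndim == 2:
                arr = np.stack([arr] * 3, axis=-1)
            images.append(arr)
        # Stack the tiles vertically into one canvas for the plotter to write out
        return np.concatenate(images, axis=0)


batch = [np.random.rand(8, 8) for _ in range(4)]
canvas = ToyDataModule().draw_batch(batch, max_items=2)
print(canvas.shape)  # (16, 8, 3)
```

With a method like this in place, BatchPlotter can call it on a sampled batch and save the returned canvas under the trainer's monitor directory.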
- FIXME:
[ ] This breaks when using strategy=DDP and multiple gpus
Todo
[ ] Doctest
Example
>>> #
>>> from geowatch.utils.lightning_ext.callbacks.batch_plotter import *  # NOQA
>>> from geowatch.utils.lightning_ext import demo
>>> import ubelt as ub
>>> import pytorch_lightning as pl
>>> from geowatch.monkey import monkey_lightning
>>> monkey_lightning.disable_lightning_hardware_warnings()
>>> model = demo.LightningToyNet2d(num_train=55)
>>> default_root_dir = ub.Path.appdir('lightning_ext/tests/BatchPlotter').ensuredir()
>>> #
>>> trainer = pl.Trainer(callbacks=[BatchPlotter()],
>>>                      default_root_dir=default_root_dir,
>>>                      max_epochs=3, accelerator='cpu', devices=1)
>>> trainer.fit(model)
>>> import pathlib
>>> train_dpath = pathlib.Path(trainer.log_dir)
>>> list((train_dpath / 'monitor').glob('*'))
>>> print('trainer.logger.log_dir = {!r}'.format(train_dpath))
References
- classmethod add_argparse_args(parent_parser)[source]¶
Example
>>> from geowatch.utils.lightning_ext.callbacks.batch_plotter import *  # NOQA
>>> from geowatch.utils.configargparse_ext import ArgumentParser
>>> cls = BatchPlotter
>>> parent_parser = ArgumentParser(formatter_class='defaults')
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
>>> parent_parser.parse_known_args()
- class geowatch.utils.lightning_ext.callbacks.Packager(package_fpath='auto')[source]¶
Bases:
Callback
Packages the best checkpoint at the end of training and at various other key phases of the training loop.
The lightning module must have a “save_package” method that can be called with a filepath.
Todo
[ ] Package the “best” checkpoints according to the monitor
- [ ] Package arbitrary checkpoints:
[ ] Package model topology without any weights
- [ ] Copy checkpoint weights into a package to get a package with
that “weight state”.
- [ ] Initializer should be able to point at a package and use
torch-liberator partial load to transfer the weights.
[ ] Replace print statements with logging statements
[ ] Create a trainer-level logger instance (similar to netharn)
- [ ] what is the right way to handle running eval after fit?
There may be multiple candidate models that need to be tested, so we can’t just specify one package, one prediction dumping ground, and one evaluation dataset; maybe we specify the paths where the “best” ones are written?
- Parameters:
package_fpath (PathLike) – Specifies a path where a torch packaged model will be written (or symlinked) to.
References
https://discuss.pytorch.org/t/packaging-pytorch-topology-first-and-checkpoints-later/129478/2
Example
>>> from geowatch.utils.lightning_ext.callbacks.packager import *  # NOQA
>>> from geowatch.utils.lightning_ext.demo import LightningToyNet2d
>>> from geowatch.utils.lightning_ext.callbacks import StateLogger
>>> import ubelt as ub
>>> import pytorch_lightning as pl
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> disable_lightning_hardware_warnings()
>>> default_root_dir = ub.Path.appdir('lightning_ext/test/packager')
>>> default_root_dir.delete().ensuredir()
>>> # Test starting a model without any existing checkpoints
>>> trainer = pl.Trainer(default_root_dir=default_root_dir, callbacks=[
>>>     Packager(default_root_dir / 'final_package.pt'),
>>>     StateLogger()
>>> ], max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer.fit(model)
- classmethod add_argparse_args(parent_parser)[source]¶
Example
>>> from geowatch.utils.lightning_ext.callbacks.packager import *  # NOQA
>>> from geowatch.utils.configargparse_ext import ArgumentParser
>>> cls = Packager
>>> parent_parser = ArgumentParser(formatter_class='defaults')
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
>>> assert parent_parser.parse_known_args(None)[0].package_fpath == 'auto'
- setup(trainer, pl_module, stage=None)[source]¶
Finalize initialization step. Resolve the paths where files will be written.
- Parameters:
trainer (pl.Trainer)
pl_module (pl.LightningModule)
stage (str | None)
- Returns:
None
- on_fit_start(trainer: Trainer, pl_module: LightningModule) → None[source]¶
Todo
[ ] Write out the uninitialized topology
- on_fit_end(trainer: Trainer, pl_module: LightningModule) → None[source]¶
Create the final package (or a list of candidate packages) for evaluation and deployment.
Todo
[ ] how do we properly package all of the candidate checkpoints?
[ ] Symlink to “BEST” package at the end.
[ ] write some script such that any checkpoint can be packaged.
- class geowatch.utils.lightning_ext.callbacks.StateLogger[source]¶
Bases:
Callback
Prints out which callbacks are being called.
DEPRECATED: use the text_logger module instead.
- class geowatch.utils.lightning_ext.callbacks.TensorboardPlotter[source]¶
Bases:
Callback
Asynchronously dumps PNGs to disk to visualize tensorboard scalars.
CommandLine
xdoctest -m geowatch.utils.lightning_ext.callbacks.tensorboard_plotter TensorboardPlotter
Example
>>> # xdoctest: +REQUIRES(module:tensorboard)
>>> from geowatch.utils.lightning_ext.callbacks.tensorboard_plotter import *  # NOQA
>>> from geowatch.utils.lightning_ext import demo
>>> import ubelt as ub
>>> import pytorch_lightning as pl
>>> import pandas as pd
>>> from geowatch.monkey import monkey_lightning
>>> monkey_lightning.disable_lightning_hardware_warnings()
>>> self = demo.LightningToyNet2d(num_train=55)
>>> default_root_dir = ub.Path.appdir('lightning_ext/tests/TensorboardPlotter').ensuredir()
>>> #
>>> trainer = pl.Trainer(callbacks=[TensorboardPlotter()],
>>>                      default_root_dir=default_root_dir,
>>>                      max_epochs=3, accelerator='cpu', devices=1)
>>> trainer.fit(self)
>>> train_dpath = trainer.logger.log_dir
>>> print('trainer.logger.log_dir = {!r}'.format(train_dpath))
>>> data = read_tensorboard_scalars(train_dpath)
>>> for key in data.keys():
>>>     d = data[key]
>>>     df = pd.DataFrame({key: d['ydata'], 'step': d['xdata'], 'wall': d['wall']})
>>>     print(df)
- class geowatch.utils.lightning_ext.callbacks.TextLogger(args=None)[source]¶
Bases:
Callback
Writes logging information to text files.
Example
>>> #
>>> from geowatch.utils.lightning_ext.callbacks.text_logger import *  # NOQA
>>> from geowatch.utils.lightning_ext import demo
>>> import ubelt as ub
>>> import pytorch_lightning as pl
>>> from geowatch.monkey import monkey_lightning
>>> monkey_lightning.disable_lightning_hardware_warnings()
>>> self = demo.LightningToyNet2d(num_train=55)
>>> default_root_dir = ub.Path.appdir('lightning_ext/tests/TextLogger').ensuredir()
>>> #
>>> trainer = pl.Trainer(callbacks=[TextLogger()],
>>>                      default_root_dir=default_root_dir,
>>>                      max_epochs=3, accelerator='cpu', devices=1)
>>> trainer.fit(self)
>>> text_logs = ub.Path(trainer.text_logger.log_fpath).read_text()
>>> print(text_logs)
- geowatch.utils.lightning_ext.callbacks.default_save_package(model, package_path, verbose=1)[source]¶
- class geowatch.utils.lightning_ext.callbacks.LightningTelemetry[source]¶
Bases:
Callback
Wraps a fit job with a ProcessContext so that telemetry about the training run is recorded.
Example
>>> from geowatch.utils.lightning_ext.callbacks.telemetry import *  # NOQA
>>> from geowatch.utils.lightning_ext.demo import LightningToyNet2d
>>> from geowatch.utils.lightning_ext.callbacks import StateLogger
>>> import ubelt as ub
>>> import pytorch_lightning as pl
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> disable_lightning_hardware_warnings()
>>> default_root_dir = ub.Path.appdir('lightning_ext/test/telemetry')
>>> default_root_dir.delete().ensuredir()
>>> self = LightningTelemetry()
>>> # Test starting a model without any existing checkpoints
>>> trainer = pl.Trainer(default_root_dir=default_root_dir, callbacks=[
>>>     self,
>>>     StateLogger()
>>> ], max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer.fit(model)
- classmethod add_argparse_args(parent_parser)[source]¶
Example
>>> from geowatch.utils.lightning_ext.callbacks.telemetry import *  # NOQA
>>> from geowatch.utils.configargparse_ext import ArgumentParser
>>> cls = LightningTelemetry
>>> parent_parser = ArgumentParser(formatter_class='defaults')
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
- setup(trainer, pl_module, stage=None)[source]¶
Finalize initialization step. Resolve the paths where files will be written.
- Parameters:
trainer (pl.Trainer)
pl_module (pl.LightningModule)
stage (str | None)
- Returns:
None