geowatch.utils.lightning_ext.callbacks.packager module
Packager callback to interface with torch.package
- class geowatch.utils.lightning_ext.callbacks.packager.Packager(package_fpath='auto')
Bases: Callback
Packages the best checkpoint at the end of training and at various other key phases of the training loop.
The lightning module must have a “save_package” method that can be called with a filepath.
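The requirement above is a duck-typed contract: the callback only needs the module to expose a callable “save_package” that accepts a file path. The following is a minimal, framework-free sketch of that contract, assuming the packaging happens in an `on_fit_end(trainer, pl_module)` hook as in Lightning callbacks. `MiniPackager` and `ToyModule` are hypothetical stand-ins, not geowatch or Lightning classes, and the toy module writes a placeholder file instead of a real torch package.

```python
import tempfile
from pathlib import Path


class MiniPackager:
    """Hypothetical stand-in for the Packager callback's core contract."""

    def __init__(self, package_fpath):
        self.package_fpath = Path(package_fpath)

    def on_fit_end(self, trainer, pl_module):
        # The only requirement on the module: a callable ``save_package``
        # that accepts a file path.
        save_package = getattr(pl_module, "save_package", None)
        if not callable(save_package):
            raise TypeError("module must define save_package(fpath)")
        self.package_fpath.parent.mkdir(parents=True, exist_ok=True)
        save_package(self.package_fpath)


class ToyModule:
    """Hypothetical module; a real one would export a torch package here."""

    def save_package(self, fpath):
        Path(fpath).write_text("packaged model placeholder")


with tempfile.TemporaryDirectory() as tmp:
    fpath = Path(tmp) / "final_package.pt"
    MiniPackager(fpath).on_fit_end(trainer=None, pl_module=ToyModule())
    created = fpath.exists()

print(created)
```

Checking for the method at call time (rather than requiring a base class) mirrors the loose requirement stated above: any module that can write itself to a path satisfies it.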
Todo
- [ ] Package the “best” checkpoints according to the monitor.
- [ ] Package arbitrary checkpoints:
  - [ ] Package the model topology without any weights.
  - [ ] Copy checkpoint weights into a package to get a package with that “weight state”.
- [ ] The initializer should be able to point at a package and use torch-liberator partial loading to transfer the weights.
- [ ] Replace print statements with logging statements.
- [ ] Create a trainer-level logger instance (similar to netharn).
- [ ] What is the right way to handle running evaluation after fit? There may be multiple candidate models that need to be tested, so we can’t just specify one package, one prediction dumping ground, and one evaluation dataset. Maybe we specify the paths where the “best” ones are written?
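Packaging the “best” checkpoints according to the monitor amounts to ranking checkpoints by a monitored metric. A minimal sketch, assuming each checkpoint is described by a path plus a dict of logged metrics, and that `mode` follows the "min"/"max" convention of Lightning's `ModelCheckpoint`. `CheckpointRecord` and `select_best_checkpoints` are hypothetical helpers, not part of geowatch.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class CheckpointRecord:
    """Hypothetical record: a checkpoint path plus its logged metrics."""
    path: str
    metrics: Dict[str, float] = field(default_factory=dict)


def select_best_checkpoints(records: List[CheckpointRecord], monitor: str,
                            mode: str = "min",
                            top_k: int = 1) -> List[CheckpointRecord]:
    """Return up to ``top_k`` records ranked by the ``monitor`` metric.

    Records that never logged the monitored metric are skipped.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    scored = [r for r in records if monitor in r.metrics]
    # Ascending for "min" (lower is better), descending for "max".
    scored.sort(key=lambda r: r.metrics[monitor], reverse=(mode == "max"))
    return scored[:top_k]


records = [
    CheckpointRecord("epoch=0.ckpt", {"val_loss": 0.9}),
    CheckpointRecord("epoch=1.ckpt", {"val_loss": 0.4}),
    CheckpointRecord("epoch=2.ckpt", {"val_loss": 0.6}),
    CheckpointRecord("last.ckpt", {}),
]
best = select_best_checkpoints(records, monitor="val_loss", mode="min", top_k=2)
print([r.path for r in best])  # ['epoch=1.ckpt', 'epoch=2.ckpt']
```

Returning a list rather than a single path leaves room for the multiple-candidate case raised in the last Todo item: each selected checkpoint could be loaded and passed through the module's “save_package” to yield one candidate package per checkpoint.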
- Parameters:
package_fpath (PathLike) – Specifies a path where a torch packaged model will be written (or symlinked) to.
References
https://discuss.pytorch.org/t/packaging-pytorch-topology-first-and-checkpoints-later/129478/2
Example
>>> from geowatch.utils.lightning_ext.callbacks.packager import *  # NOQA
>>> from geowatch.utils.lightning_ext.demo import LightningToyNet2d
>>> from geowatch.utils.lightning_ext.callbacks import StateLogger
>>> import ubelt as ub
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> disable_lightning_hardware_warnings()
>>> default_root_dir = ub.Path.appdir('lightning_ext/test/packager')
>>> default_root_dir.delete().ensuredir()
>>> # Test starting a model without any existing checkpoints
>>> trainer = pl.Trainer(default_root_dir=default_root_dir, callbacks=[
>>>     Packager(default_root_dir / 'final_package.pt'),
>>>     StateLogger()
>>> ], max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer.fit(model)
- classmethod add_argparse_args(parent_parser)
Example
>>> from geowatch.utils.lightning_ext.callbacks.packager import *  # NOQA
>>> from geowatch.utils.configargparse_ext import ArgumentParser
>>> cls = Packager
>>> parent_parser = ArgumentParser(formatter_class='defaults')
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
>>> assert parent_parser.parse_known_args(None)[0].package_fpath == 'auto'
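The example above registers a `package_fpath` option whose default is `'auto'`. The same pattern can be sketched with the standard-library `argparse` module. This is an illustration under assumptions: geowatch's `configargparse_ext.ArgumentParser` and its `formatter_class='defaults'` shorthand are replaced by `argparse.ArgumentParser` with `ArgumentDefaultsHelpFormatter`, and the class name, group title, and help text are hypothetical. An explicit empty argument list is parsed instead of `None` so the result does not depend on `sys.argv`.

```python
import argparse


class PackagerArgsSketch:
    """Hypothetical stand-in showing the add_argparse_args pattern."""

    @classmethod
    def add_argparse_args(cls, parent_parser):
        # Group the callback's options so they are listed together in --help.
        group = parent_parser.add_argument_group("Packager")
        group.add_argument(
            "--package_fpath", type=str, default="auto",
            help="Path where the torch-packaged model is written (or symlinked)")
        return parent_parser


parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
PackagerArgsSketch.add_argparse_args(parser)
args, unknown = parser.parse_known_args([])
print(args.package_fpath)  # auto
```

Using `parse_known_args` rather than `parse_args` lets each callback pick out its own options while ignoring flags that belong to other components.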
- setup(trainer, pl_module, stage=None)
Finalize initialization step. Resolve the paths where files will be written.
- Parameters:
trainer (pl.Trainer)
pl_module (pl.LightningModule)
stage (str | None)
- Returns:
None
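This page does not say how `setup` interprets `package_fpath='auto'`, so the following is a hypothetical sketch of the path-resolution step only. It assumes `'auto'` means a file named `final_package.pt` (the name used in the class example) inside a `packages` subdirectory of the trainer's log directory; the helper `resolve_package_fpath` and the `packages` directory name are assumptions, not geowatch behavior.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_package_fpath(package_fpath: Union[str, Path],
                          log_dir: Optional[Union[str, Path]]) -> Path:
    """Hypothetical resolution rule for the package path.

    An explicit path is used as-is; ``'auto'`` is assumed to mean a default
    location under the trainer's log directory.
    """
    if str(package_fpath) != "auto":
        return Path(package_fpath)
    if log_dir is None:
        raise ValueError("package_fpath='auto' requires a log directory")
    return Path(log_dir) / "packages" / "final_package.pt"


print(resolve_package_fpath("auto", "runs/version_0"))
print(resolve_package_fpath("/tmp/my_model.pt", "runs/version_0"))
```

Deferring this to `setup` rather than `__init__` matters because the log directory is typically only known once the callback is attached to a trainer (in Lightning, via `trainer.log_dir`).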
- on_fit_start(trainer: Trainer, pl_module: LightningModule) → None
Todo
- [ ] Write out the uninitialized topology.
- on_fit_end(trainer: Trainer, pl_module: LightningModule) → None
Create the final package (or a list of candidate packages) for evaluation and deployment.
Todo
- [ ] How do we properly package all of the candidate checkpoints?
- [ ] Symlink to the “BEST” package at the end.
- [ ] Write a script so that any checkpoint can be packaged.
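For the “Symlink to BEST package” item, one common approach is to keep a stable link name that is repointed at whichever candidate package scored best. A minimal POSIX-oriented sketch, assuming the candidate packages already exist on disk; the link name `best_package.pt` and the helper `update_best_symlink` are hypothetical. The new link is created under a temporary name and then renamed over the old one with `os.replace`, so the stable name is never missing mid-update.

```python
import os
import tempfile
from pathlib import Path


def update_best_symlink(best_package: Path, link_path: Path) -> Path:
    """Point ``link_path`` at ``best_package``, replacing any existing link.

    A relative target keeps the link valid if the whole directory is moved.
    """
    target = os.path.relpath(best_package, start=link_path.parent)
    tmp_link = link_path.with_name(link_path.name + ".tmp")
    # Clear a stale temporary link left over from an interrupted update.
    if tmp_link.is_symlink() or tmp_link.exists():
        tmp_link.unlink()
    os.symlink(target, tmp_link)
    # Rename the new link over the old one in a single step.
    os.replace(tmp_link, link_path)
    return link_path


with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    for name in ("epoch=1.pt", "epoch=3.pt"):
        (tmp / name).write_text(name)
    link = tmp / "best_package.pt"
    update_best_symlink(tmp / "epoch=1.pt", link)
    first = link.read_text()
    update_best_symlink(tmp / "epoch=3.pt", link)
    second = link.read_text()
    is_link = link.is_symlink()

print(first, second, is_link)
```

Because `os.replace` renames the link itself rather than following it, the candidate packages are never modified; only the stable name moves.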