geowatch.utils.lightning_ext.callbacks.auto_resumer module
https://github.com/Lightning-AI/lightning/issues/10894
- class geowatch.utils.lightning_ext.callbacks.auto_resumer.AutoResumer
Bases:
Callback
Auto-resumes from the most recent checkpoint
THIS SEEMS TO BE BROKEN WITH NEW LIGHTNING VERSIONS
Example
>>> from geowatch.utils.lightning_ext.callbacks.auto_resumer import AutoResumer
>>> from kwutil import util_path
>>> from geowatch.utils.lightning_ext.demo import LightningToyNet2d
>>> from geowatch.utils.lightning_ext.callbacks import StateLogger
>>> import pytorch_lightning as pl
>>> import ubelt as ub
>>> from geowatch.monkey import monkey_lightning
>>> monkey_lightning.disable_lightning_hardware_warnings()
>>> default_root_dir = ub.Path.appdir('lightning_ext/test/auto_resume')
>>> default_root_dir.delete()
>>> #
>>> # STEP 1:
>>> # Test starting a model without any existing checkpoints
>>> import pytest
>>> try:
>>>     AutoResumer()
>>> except NotImplementedError:
>>>     pytest.skip()
>>> trainer_orig = pl.Trainer(default_root_dir=default_root_dir,
>>>                           callbacks=[AutoResumer(), StateLogger()],
>>>                           max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_orig.fit(model)
>>> assert len(list((ub.Path(trainer_orig.logger.log_dir) / 'checkpoints').glob('*'))) > 0
>>> # See contents written
>>> print(ub.urepr(list(util_path.tree(default_root_dir)), sort=0))
>>> #
>>> # CHECK 1:
>>> # Make a new trainer that should auto-resume
>>> self = AutoResumer()
>>> trainer = trainer_resume1 = pl.Trainer(default_root_dir=default_root_dir,
>>>                                        callbacks=[self, StateLogger()],
>>>                                        max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_resume1.fit(model)
>>> print(ub.urepr(list(util_path.tree(default_root_dir)), sort=0))
>>> # max_epochs should prevent auto-resume from doing anything
>>> assert len(list((ub.Path(trainer_resume1.logger.log_dir) / 'checkpoints').glob('*'))) == 0
>>> #
>>> # CHECK 2:
>>> # Increasing max epochs will let it train for longer
>>> trainer_resume2 = pl.Trainer(default_root_dir=default_root_dir,
>>>                              callbacks=[AutoResumer(), StateLogger()],
>>>                              max_epochs=3, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_resume2.fit(model)
>>> print(ub.urepr(list(util_path.tree(ub.Path(default_root_dir))), sort=0))
>>> # With the higher max_epochs, auto-resume should continue training and write new checkpoints
>>> assert len(list((ub.Path(trainer_resume2.logger.log_dir) / 'checkpoints').glob('*'))) > 0
Todo
[ ] Configure how to find which checkpoint to resume from