geowatch.utils.lightning_ext.callbacks.auto_resumer module

https://github.com/Lightning-AI/lightning/issues/10894

class geowatch.utils.lightning_ext.callbacks.auto_resumer.AutoResumer[source]

Bases: Callback

Auto-resumes from the most recent checkpoint

THIS SEEMS TO BE BROKEN WITH NEW LIGHTNING VERSIONS

Example

>>> from geowatch.utils.lightning_ext.callbacks.auto_resumer import AutoResumer
>>> from kwutil import util_path
>>> from geowatch.utils.lightning_ext.demo import LightningToyNet2d
>>> from geowatch.utils.lightning_ext.callbacks import StateLogger
>>> import pytorch_lightning as pl
>>> import ubelt as ub
>>> from geowatch.monkey import monkey_lightning
>>> monkey_lightning.disable_lightning_hardware_warnings()
>>> default_root_dir = ub.Path.appdir('lightning_ext/test/auto_resume')
>>> default_root_dir.delete()
>>> #
>>> # STEP 1:
>>> # Test starting a model without any existing checkpoints
>>> import pytest
>>> try:
>>>     AutoResumer()
>>> except NotImplementedError:
>>>     pytest.skip()
>>> trainer_orig = pl.Trainer(default_root_dir=default_root_dir, callbacks=[AutoResumer(), StateLogger()], max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_orig.fit(model)
>>> assert len(list((ub.Path(trainer_orig.logger.log_dir) / 'checkpoints').glob('*'))) > 0
>>> # See contents written
>>> print(ub.urepr(list(util_path.tree(default_root_dir)), sort=0))
>>> #
>>> # CHECK 1:
>>> # Make a new trainer that should auto-resume
>>> self = AutoResumer()
>>> trainer = trainer_resume1 = pl.Trainer(default_root_dir=default_root_dir, callbacks=[self, StateLogger()], max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_resume1.fit(model)
>>> print(ub.urepr(list(util_path.tree(default_root_dir)), sort=0))
>>> # max_epochs should prevent auto-resume from doing anything
>>> assert len(list((ub.Path(trainer_resume1.logger.log_dir) / 'checkpoints').glob('*'))) == 0
>>> #
>>> # CHECK 2:
>>> # Increasing max epochs will let it train for longer
>>> trainer_resume2 = pl.Trainer(default_root_dir=default_root_dir, callbacks=[AutoResumer(), StateLogger()], max_epochs=3, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_resume2.fit(model)
>>> print(ub.urepr(list(util_path.tree(ub.Path(default_root_dir))), sort=0))
>>> # With the larger max_epochs, auto-resume continues training and writes new checkpoints
>>> assert len(list((ub.Path(trainer_resume2.logger.log_dir) / 'checkpoints').glob('*'))) > 0

Todo

  • [ ] Configure how to find which checkpoint to resume from

on_init_start(trainer: Trainer) → None[source]
recent_checkpoints(train_dpath)[source]

Return a list of existing checkpoints in some Trainer root directory