"""
https://github.com/Lightning-AI/lightning/issues/10894
"""
import pytorch_lightning as pl
import ubelt as ub
from packaging.version import parse as Version
from typing import Optional # NOQA
__all__ = ['AutoResumer']
[docs]
class AutoResumer(pl.callbacks.Callback):
"""
Auto-resumes from the most recent checkpoint
THIS SEEMS TO BE BROKEN WITH NEW LIGHTNING VERSIONS
Example:
>>> from geowatch.utils.lightning_ext.callbacks.auto_resumer import AutoResumer
>>> from kwutil import util_path
>>> from geowatch.utils.lightning_ext.demo import LightningToyNet2d
>>> from geowatch.utils.lightning_ext.callbacks import StateLogger
>>> import pytorch_lightning as pl
>>> import ubelt as ub
>>> from geowatch.monkey import monkey_lightning
>>> monkey_lightning.disable_lightning_hardware_warnings()
>>> default_root_dir = ub.Path.appdir('lightning_ext/test/auto_resume')
>>> default_root_dir.delete()
>>> #
>>> # STEP 1:
>>> # Test starting a model without any existing checkpoints
>>> import pytest
>>> try:
>>> AutoResumer()
>>> except NotImplementedError:
>>> pytest.skip()
>>> trainer_orig = pl.Trainer(default_root_dir=default_root_dir, callbacks=[AutoResumer(), StateLogger()], max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_orig.fit(model)
>>> assert len(list((ub.Path(trainer_orig.logger.log_dir) / 'checkpoints').glob('*'))) > 0
>>> # See contents written
>>> print(ub.urepr(list(util_path.tree(default_root_dir)), sort=0))
>>> #
>>> # CHECK 1:
>>> # Make a new trainer that should auto-resume
>>> self = AutoResumer()
>>> trainer = trainer_resume1 = pl.Trainer(default_root_dir=default_root_dir, callbacks=[self, StateLogger()], max_epochs=2, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_resume1.fit(model)
>>> print(ub.urepr(list(util_path.tree(default_root_dir)), sort=0))
>>> # max_epochs should prevent auto-resume from doing anything
>>> assert len(list((ub.Path(trainer_resume1.logger.log_dir) / 'checkpoints').glob('*'))) == 0
>>> #
>>> # CHECK 2:
>>> # Increasing max epochs will let it train for longer
>>> trainer_resume2 = pl.Trainer(default_root_dir=default_root_dir, callbacks=[AutoResumer(), StateLogger()], max_epochs=3, accelerator='cpu', devices=1)
>>> model = LightningToyNet2d()
>>> trainer_resume2.fit(model)
>>> print(ub.urepr(list(util_path.tree(ub.Path(default_root_dir))), sort=0))
>>> # max_epochs should prevent auto-resume from doing anything
>>> assert len(list((ub.Path(trainer_resume2.logger.log_dir) / 'checkpoints').glob('*'))) > 0
"""
def __init__(self):
"""
TODO:
- [ ] Configure how to find which checkpoint to resume from
"""
if Version(pl.__version__) >= Version('1.8.0'):
raise NotImplementedError(
'Lightning 1.8.0 broke on_init_start, and we havent fixed it. '
'This component should be non-critical. Avoid using in the '
'meantime'
)
# @classmethod
# def add_argparse_args(cls, parent_parser):
# """
# Example:
# >>> from geowatch.utils.lightning_ext.callbacks.auto_resumer import * # NOQA
# >>> from geowatch.utils.configargparse_ext import ArgumentParser
# >>> cls = AutoResumer
# >>> parent_parser = ArgumentParser(formatter_class='defaults')
# >>> cls.add_argparse_args(parent_parser)
# >>> parent_parser.print_help()
# """
# from geowatch.utils.lightning_ext import argparse_ext
# arg_infos = argparse_ext.parse_docstring_args(cls)
# argparse_ext.add_arginfos_to_parser(parent_parser, arg_infos)
# return parent_parser
# FIXME: this doesn't work in new lightning versions
# def setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', stage: Optional[str] = None) -> None:
[docs]
def on_init_start(self, trainer: 'pl.Trainer') -> None:
if trainer.global_rank != 0:
return
train_dpath = trainer.default_root_dir
prev_states = self.recent_checkpoints(train_dpath)
print('There are {} existing checkpoints'.format(len(prev_states)))
if len(prev_states) > 0:
resume_from_checkpoint = prev_states[-1]
print('Will resume from {!r}'.format(resume_from_checkpoint))
# A Trainer will construct its checkpoint connector before this is
# called, but it wont use it, so it is currently safe to overwrite
# it with a new one that has the "correct" resume_from_checkpoint
# argument.
attrname = '_checkpoint_connector'
if not hasattr(trainer, attrname):
# lightning < 1.6
attrname = 'checkpoint_connector'
checkpoint_connector = getattr(trainer, attrname)
CheckpointConnector = checkpoint_connector.__class__
setattr(trainer, attrname, CheckpointConnector(
trainer, resume_from_checkpoint))
[docs]
def recent_checkpoints(self, train_dpath):
"""
Return a list of existing checkpoints in some Trainer root directory
"""
train_dpath = ub.Path(train_dpath)
candidates = sorted(train_dpath.glob('*/*/checkpoints/*.ckpt'))
return candidates