"""
LightningTelemetry callback that records process telemetry (via ProcessContext)
and writes it to the trainer's log directory as ``telemetry.json``.
"""
import pytorch_lightning as pl
import ubelt as ub
# from pytorch_lightning.utilities.rank_zero import rank_zero_only
class LightningTelemetry(pl.callbacks.Callback):
    """
    Record process telemetry for a fit job and write it to the log directory.

    The idea is that we wrap a fit job with ProcessContext. The context is
    started when fitting begins, and its flushed state is written to
    ``<log_dir>/telemetry.json`` at fit start, after every training epoch, at
    fit end, and when an exception interrupts training (only if the context
    was started). All hooks do their work on the global rank zero process
    only.

    Example:
        >>> from geowatch.utils.lightning_ext.callbacks.telemetry import * # NOQA
        >>> from geowatch.utils.lightning_ext.demo import LightningToyNet2d
        >>> from geowatch.utils.lightning_ext.callbacks import StateLogger
        >>> import ubelt as ub
        >>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
        >>> disable_lightning_hardware_warnings()
        >>> default_root_dir = ub.Path.appdir('geowatch/lightning_ext/test/telemetry')
        >>> default_root_dir.delete().ensuredir()
        >>> self = LightningTelemetry()
        >>> # Test starting a model without any existing checkpoints
        >>> trainer = pl.Trainer(default_root_dir=default_root_dir, callbacks=[
        >>>     self,
        >>>     StateLogger()
        >>> ], max_epochs=2, accelerator='cpu', devices=1)
        >>> model = LightningToyNet2d()
        >>> trainer.fit(model)
    """

    def __init__(self):
        from geowatch.utils.process_context import ProcessContext
        super().__init__()
        self.context = ProcessContext(
            name='lightning_fit',
            # TODO: how to get the config here?
            # config=config,
            # track_emissions='offline'
        )

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """
        Add the arguments parsed from this class's docstring to a parser.

        Args:
            parent_parser (ArgumentParser): parser to extend in place

        Returns:
            ArgumentParser: the same ``parent_parser``

        Example:
            >>> from geowatch.utils.lightning_ext.callbacks.telemetry import * # NOQA
            >>> from geowatch.utils.configargparse_ext import ArgumentParser
            >>> cls = LightningTelemetry
            >>> parent_parser = ArgumentParser(formatter_class='defaults')
            >>> cls.add_argparse_args(parent_parser)
            >>> parent_parser.print_help()
        """
        from geowatch.utils.lightning_ext import argparse_ext
        arg_infos = argparse_ext.parse_docstring_args(cls)
        argparse_ext.add_arginfos_to_parser(parent_parser, arg_infos)
        return parent_parser

    def setup(self, trainer, pl_module, stage=None):
        """
        Finalize initialization step.

        Resolve the paths where files will be written. Only runs on the
        global rank zero process.

        Args:
            trainer (pl.Trainer): the trainer that owns this callback
            pl_module (pl.LightningModule): the module being fit (unused)
            stage (str | None): the stage being set up (unused)

        Returns:
            None
        """
        if trainer.is_global_zero:
            self._after_initialization(trainer)

    def _after_initialization(self, trainer):
        """
        Register the trainer's ``default_root_dir`` with the process context
        (via ``add_disk_info``). No-op on non-zero ranks.
        """
        if not trainer.is_global_zero:
            return
        print('initialize process context')
        root_dir = ub.Path(trainer.default_root_dir)
        print('root_dir = {!r}'.format(root_dir))
        self.context.add_disk_info(root_dir)
        # TODO: consider whether the logger's log_dir should also be recorded.

    def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        """
        Start the process context and write an initial telemetry snapshot.

        TODO:
            - [ ] Write out the uninitialized topology
        """
        if not trainer.is_global_zero:
            return
        self.context.start()
        self._dump(trainer)

    def on_fit_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        """
        Write a final telemetry snapshot after fitting completes.
        """
        if not trainer.is_global_zero:
            return
        self._dump(trainer)

    # NOTE(review): this hook was previously annotated "Causes ddp hang".
    # A plausible cause is that ``Trainer.log_dir`` performs a collective
    # broadcast in Lightning, which deadlocks when only rank zero calls it;
    # ``_dump`` no longer uses it. TODO: confirm under DDP.
    def on_train_epoch_end(self, trainer, logs=None):
        """
        Write a telemetry snapshot after each training epoch.

        Args:
            trainer (pl.Trainer): the trainer that owns this callback
            logs: Lightning's Callback API calls this hook as
                ``(trainer, pl_module)``, so this receives the LightningModule.
                It is unused; the name is kept for backwards compatibility.
        """
        if not trainer.is_global_zero:
            return
        self._dump(trainer)

    def on_exception(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', *args, **kw) -> None:
        """
        Write a telemetry snapshot when fitting is interrupted by an exception.

        Nothing is written if the process context was never started. Errors
        raised while dumping are printed and suppressed so they do not mask
        the exception that interrupted training.
        """
        if not trainer.is_global_zero:
            return
        if hasattr(self.context, 'is_started'):
            is_started = self.context.is_started
        else:
            # old non-public API, remove after kwutil 0.3.5 is min dep
            is_started = self.context._started
        if is_started:
            print('Telemetry encountered exception, dumping...')
            try:
                self._dump(trainer)
            except Exception as dump_ex:
                # Best effort: never replace the original training error.
                print(f'Failed to dump telemetry: {dump_ex!r}')
        else:
            print('Telemetry encountered exception, but not dumping because telemetry was not started')

    @staticmethod
    def _resolve_log_dpath(trainer):
        """
        Determine the directory telemetry is written to.

        Uses the active logger's ``log_dir``, or the trainer's
        ``default_root_dir`` when no logger is configured. ``Trainer.log_dir``
        is deliberately not used because it may broadcast across ranks, and
        this is only ever called on rank zero.

        Args:
            trainer (pl.Trainer): the trainer that owns this callback

        Returns:
            ub.Path | None: the directory, or None if it cannot be determined
        """
        logger = trainer.logger
        if logger is None:
            log_dir = trainer.default_root_dir
        else:
            # Not every logger exposes ``log_dir``, and some return None.
            log_dir = getattr(logger, 'log_dir', None)
        if log_dir is None:
            return None
        return ub.Path(log_dir)

    def _dump(self, trainer):
        """
        Flush the process context and write it as JSON to
        ``<log_dir>/telemetry.json``.

        Prints a message and writes nothing if the log directory cannot be
        determined. No-op on non-zero ranks.
        """
        if not trainer.is_global_zero:
            return
        import json
        log_dpath = self._resolve_log_dpath(trainer)
        if log_dpath is None:
            print('Trainer run without a log_dir, cannot dump telemetry')
            return
        obj = self.context.flush()
        # The logger may not have created its directory yet (e.g. at fit start).
        log_dpath.ensuredir()
        tel_fpath = log_dpath / 'telemetry.json'
        tel_fpath.write_text(json.dumps(obj))