import pytorch_lightning as pl
import ubelt as ub
from typing import Dict, Any, Optional
class StateLogger(pl.callbacks.Callback):
    """
    Prints out which Lightning callback hooks are being called.

    :meth:`setup` and :meth:`on_exception` always print. The remaining
    hooks defined here (teardown, fit start/end, checkpoint save, sanity
    check start/end) only print when ``verbose`` is enabled.

    DEPRECATE: Use text_logger

    Args:
        verbose (bool):
            if True, also print a line from the otherwise-silent hooks.
            Defaults to False, which keeps those hooks silent.
    """

    def __init__(self, verbose: bool = False):
        super().__init__()
        self.verbose = verbose

    def setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule',
              stage: Optional[str] = None) -> None:
        """
        Announce setup and report the trainer's ``default_root_dir``.
        """
        print('setup state logger')
        print('trainer.default_root_dir = {!r}'.format(trainer.default_root_dir))

    def teardown(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule',
                 stage: Optional[str] = None) -> None:
        """
        Announce teardown when ``verbose`` is enabled.
        """
        if self.verbose:
            print('teardown state logger')

    # Disabled hooks: the ``if 0:`` guard means these functions are never
    # bound as class attributes, so the trainer never sees them.
    if 0:
        def on_init_start(self, trainer: 'pl.Trainer') -> None:
            if self.verbose:
                print('on_init_start')

        def on_init_end(self, trainer: 'pl.Trainer') -> None:
            if self.verbose:
                print('on_init_end')

    def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        """
        Announce the start of fitting when ``verbose`` is enabled.
        """
        if self.verbose:
            print('on_fit_start')

    def on_fit_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        """
        Announce the end of fitting when ``verbose`` is enabled.
        """
        if self.verbose:
            print('on_fit_end')

    # def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    #     print('on_train_start')
    # def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    #     print('on_train_end')

    def on_save_checkpoint(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule',
                           checkpoint: Dict[str, Any]) -> None:
        """
        List the top-level keys of the checkpoint being saved when
        ``verbose`` is enabled.

        The ``checkpoint`` dictionary is only read, never modified, and
        nothing is returned.
        """
        if self.verbose:
            print('on_save_checkpoint - checkpoint = {}'.format(
                ub.urepr(checkpoint.keys(), nl=1)))

    # def load_state_dict...
    # def on_load_checkpoint(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', callback_state: Dict[str, Any]) -> None:
    #     if 0:
    #         print('on_load_checkpoint - callback_state = {}'.format(ub.urepr(callback_state.keys(), nl=1)))

    def on_sanity_check_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        """
        Announce the start of the sanity check when ``verbose`` is enabled.
        """
        if self.verbose:
            print('on_sanity_check_start')

    def on_sanity_check_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        """
        Announce the end of the sanity check when ``verbose`` is enabled.
        """
        if self.verbose:
            print('on_sanity_check_end')

    def on_exception(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule',
                     *args, **kw) -> None:
        """
        Report that the exception hook fired, echo whatever extra arguments
        the trainer passed, and print the trainer's output directories.

        Extra positional/keyword arguments are accepted generically so the
        hook tolerates different trainer call signatures (presumably the
        exception instance is among them -- confirm against the installed
        Lightning version).
        """
        print('on_exception')
        print('kw = {!r}'.format(kw))
        print('args = {!r}'.format(args))
        print('INTERUPT')
        print('trainer.default_root_dir = {!r}'.format(trainer.default_root_dir))
        print('trainer.log_dir = {!r}'.format(trainer.log_dir))