import pytorch_lightning as pl
import torch
from geowatch.utils.util_netharn import ToyNet2d, ToyData2d
[docs]
class LightningToyNet2d(pl.LightningModule):
    """
    Toydata lightning module.

    Wraps a :class:`ToyNet2d` network and serves :class:`ToyData2d` datasets
    through the standard Lightning step / dataloader hooks.

    Args:
        num_train (int): passed as ``n`` to the training ``ToyData2d``.
        num_val (int): passed as ``n`` to the validation ``ToyData2d``.
        batch_size (int): batch size used by both data loaders.
    """

    def __init__(self, num_train=100, num_val=10, batch_size=4):
        super().__init__()
        # Records the constructor arguments on the module (Lightning feature).
        self.save_hyperparameters()
        self.num_train = num_train
        self.num_val = num_val
        self.batch_size = batch_size
        self.model = ToyNet2d()

    def forward(self, x):
        """
        Run the wrapped network on ``x`` and return its raw output.
        """
        return self.model(x)

    def get_cfgstr(self):
        """
        Return a fixed identifier string (present for BatchPlotter).
        """
        return 'This is for BatchPlotter'

    def forward_step(self, batch, batch_idx):
        """
        Shared implementation of the training and validation steps.

        Args:
            batch (Tuple): an ``(inputs, targets)`` pair.
            batch_idx (int): index of the batch (unused).

        Returns:
            Dict[str, torch.Tensor]: dictionary holding the ``'loss'``.
        """
        try:
            trainer = self.trainer
        except RuntimeError:
            # Newer Lightning releases raise RuntimeError when the module is
            # not attached to a Trainer; older releases returned None. Treat
            # both cases as "disconnected" instead of crashing.
            trainer = None

        if trainer is None:
            stage = 'disconnected'
        else:
            stage = trainer.state.stage.lower()

        inputs, targets = batch
        logits = self.forward(inputs)
        log_probs = logits.log_softmax(dim=1)
        loss = torch.nn.functional.nll_loss(log_probs, targets)
        # https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html
        # The logged key is prefixed by the current stage name.
        self.log(f'{stage}_loss', loss)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """
        Lightning validation hook; delegates to :func:`forward_step`.
        """
        return self.forward_step(batch, batch_idx)

    def training_step(self, batch, batch_idx):
        """
        Lightning training hook; delegates to :func:`forward_step`.
        """
        return self.forward_step(batch, batch_idx)

    def _toy_loader(self, num_items):
        """
        Build a single-process ``ToyData2d`` loader with ``num_items`` items.
        """
        dataset = ToyData2d(n=num_items)
        return dataset.make_loader(batch_size=self.batch_size, num_workers=0)

    def train_dataloader(self):
        """
        Return the loader over ``num_train`` toy training items.
        """
        return self._toy_loader(self.num_train)

    def val_dataloader(self):
        """
        Return the loader over ``num_val`` toy validation items.
        """
        return self._toy_loader(self.num_val)
[docs]
def demo_trainer():
    """
    Build a CPU-only :class:`pl.Trainer` paired with a toy model.

    The trainer uses the directory given by
    ``ub.Path.appdir('lightning_ext/demo_trainer')`` as its root, is capped at
    100 epochs, and is configured with the ``lightning_ext`` callbacks plus
    learning-rate monitors and a checkpoint callback that monitors
    ``train_loss``. The toy model is assigned to ``trainer.model``; fitting is
    left to the caller.

    Returns:
        pl.Trainer: the configured trainer.

    Notes:
        Source files relevant to the trainer:
            ~/.pyenv/versions/3.8.6/envs/pyenv3.8.6/lib/python3.8/site-packages/pytorch_lightning/__init__.py
            ~/.pyenv/versions/3.8.6/envs/pyenv3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py
            ~/.pyenv/versions/3.8.6/envs/pyenv3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
            ~/.pyenv/versions/3.8.6/envs/pyenv3.8.6/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py
            ~/.pyenv/versions/3.8.6/envs/pyenv3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py

    Example:
        >>> # xdoctest: +SKIP
        >>> from geowatch.utils.lightning_ext.demo import *  # NOQA
        >>> trainer = demo_trainer()
        >>> print('trainer.log_dir = {!r}'.format(trainer.log_dir))
        >>> trainer.fit(trainer.model)
        >>> print('trainer.log_dir = {!r}'.format(trainer.log_dir))
    """
    import ubelt as ub
    root_dpath = ub.Path.appdir('lightning_ext/demo_trainer').ensuredir()
    toy_model = LightningToyNet2d(num_train=55)

    from geowatch.utils import lightning_ext as pl_ext

    # Stand-in for user supplied options; it is empty, so the fallback
    # values given to ``.get`` below are always the ones used.
    options = {}

    # Callbacks are appended in the order they should be registered.
    callbacks = [
        pl_ext.callbacks.AutoResumer(),
        pl_ext.callbacks.StateLogger(),
    ]
    callbacks.append(
        pl_ext.callbacks.BatchPlotter(
            num_draw=options.get('num_draw', 4),
            draw_interval=options.get('draw_interval', '10m'),
        )
    )
    callbacks.append(pl_ext.callbacks.TensorboardPlotter())  # draw tensorboard
    for interval in ('epoch', 'step'):
        callbacks.append(
            pl.callbacks.LearningRateMonitor(
                logging_interval=interval, log_momentum=True)
        )
    callbacks.append(
        pl.callbacks.ModelCheckpoint(
            monitor='train_loss', mode='min', save_top_k=2)
    )
    # callbacks.append(pl.callbacks.GPUStatsMonitor())

    # Disabled validation-driven callbacks, kept for reference:
    # # if args.vali_dataset is not None:
    # if kwargs.get('vali_dataset', None) is not None:
    #     callbacks += [
    #         pl.callbacks.EarlyStopping(
    #             monitor='val_loss', mode='min',
    #             patience=kwargs.get('patience', 10),
    #             verbose=True),
    #         pl.callbacks.ModelCheckpoint(
    #             monitor='val_loss', mode='min', save_top_k=2),
    #     ]
    # kwargs['callbacks'] = callbacks

    trainer_config = dict(
        default_root_dir=root_dpath,
        max_epochs=100,
        callbacks=callbacks,
        accelerator='cpu',
        devices=1,
    )
    trainer = pl.Trainer(**trainer_config)
    trainer.model = toy_model
    return trainer