geowatch.utils.lightning_ext.demo module

class geowatch.utils.lightning_ext.demo.LightningToyNet2d(num_train=100, num_val=10, batch_size=4)

Bases: LightningModule

A LightningModule for toy data.
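The constructor arguments set the dataset sizes and the batch size. As a worked example, here is how many batches one epoch would produce with the defaults. This assumes that train_dataloader and val_dataloader both batch their datasets with batch_size and keep PyTorch's DataLoader default of drop_last=False; the helper batches_per_epoch is hypothetical and not part of geowatch.

```python
import math


def batches_per_epoch(num_items: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: batches a DataLoader yields over a map-style dataset."""
    if drop_last:
        # The incomplete trailing batch is discarded.
        return num_items // batch_size
    # The incomplete trailing batch is kept (PyTorch's default behavior).
    return math.ceil(num_items / batch_size)


# Defaults taken from the LightningToyNet2d signature.
num_train, num_val, batch_size = 100, 10, 4
print(batches_per_epoch(num_train, batch_size))  # 25
print(batches_per_epoch(num_val, batch_size))    # 3
```

Under these assumptions the validation set ends with a partial batch of 2 samples (10 = 4 + 4 + 2).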

forward(x)
get_cfgstr()
forward_step(batch, batch_idx)
validation_step(batch, batch_idx)
training_step(batch, batch_idx)
train_dataloader()
val_dataloader()
configure_optimizers()
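The method names suggest that forward_step holds the logic shared by training_step and validation_step; that is an inference from the names, not something these docs state. The pure-Python sketch below illustrates the pattern without importing Lightning. ToyStepModule, its doubling "model", the mean-squared-error loss, and the logged dictionary are all hypothetical stand-ins, not geowatch code.

```python
class ToyStepModule:
    """Hypothetical stand-in showing one shared step used by train and val hooks."""

    def __init__(self):
        # Stand-in for Lightning's self.log(...): last loss per stage.
        self.logged = {}

    def forward(self, x):
        # Placeholder model: doubles every input value.
        return [2 * v for v in x]

    def forward_step(self, batch, batch_idx):
        # Shared logic: unpack the batch, run the model, compute the loss.
        inputs, targets = batch
        preds = self.forward(inputs)
        # Mean squared error computed by hand.
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)

    def training_step(self, batch, batch_idx):
        loss = self.forward_step(batch, batch_idx)
        self.logged['train_loss'] = loss
        return loss  # Lightning optimizes the loss returned here.

    def validation_step(self, batch, batch_idx):
        loss = self.forward_step(batch, batch_idx)
        self.logged['val_loss'] = loss
        return loss


module = ToyStepModule()
print(module.training_step(([1.0, 2.0], [2.0, 5.0]), batch_idx=0))  # 0.5
```

Keeping the forward pass and loss in one method means training and validation cannot drift apart; only the logging key differs between the two hooks.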
geowatch.utils.lightning_ext.demo.demo_trainer()

Notes with respect to the trainer: the following pytorch_lightning source files are relevant (paths are relative to the site-packages directory of the author's Python 3.8 pyenv environment):

pytorch_lightning/__init__.py

pytorch_lightning/trainer/trainer.py

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

pytorch_lightning/callbacks/model_checkpoint.py

pytorch_lightning/trainer/connectors/checkpoint_connector.py
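The file locations in these notes come from one specific pyenv installation. A minimal sketch for resolving the same files in whatever environment is active uses importlib.util.find_spec, which locates a package without importing it. The helper name locate_lightning_sources is hypothetical, and files may have moved or been renamed in other pytorch_lightning versions, in which case the sketch reports None for them.

```python
import importlib.util
import os

# Files from the notes, relative to the pytorch_lightning package root.
RELATIVE_PATHS = [
    '__init__.py',
    'trainer/trainer.py',
    'trainer/connectors/logger_connector/logger_connector.py',
    'callbacks/model_checkpoint.py',
    'trainer/connectors/checkpoint_connector.py',
]


def locate_lightning_sources(package='pytorch_lightning'):
    """Map each relative path to an absolute path, or None if the file is missing.

    Returns an empty dict when the package is not installed.
    """
    spec = importlib.util.find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        return {}
    # Directory containing the package's __init__.py.
    root = list(spec.submodule_search_locations)[0]
    found = {}
    for rel in RELATIVE_PATHS:
        path = os.path.join(root, *rel.split('/'))
        found[rel] = path if os.path.exists(path) else None
    return found


for rel, path in locate_lightning_sources().items():
    print(rel, '->', path)
```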

Example

>>> # xdoctest: +SKIP
>>> from geowatch.utils.lightning_ext.demo import *  # NOQA
>>> trainer = demo_trainer()
>>> print('trainer.log_dir = {!r}'.format(trainer.log_dir))
>>> trainer.fit(trainer.model)
>>> print('trainer.log_dir = {!r}'.format(trainer.log_dir))
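The example prints trainer.log_dir before and after fitting. To see which checkpoints a run produced, one option is to search that directory for .ckpt files. This sketch assumes the default ModelCheckpoint behavior, where checkpoints are written as .ckpt files under a checkpoints/ subdirectory of the log directory when a logger is configured; the helper find_checkpoints is hypothetical and searches recursively so it does not depend on the exact subfolder name.

```python
from pathlib import Path


def find_checkpoints(log_dir):
    """Hypothetical helper: return sorted .ckpt paths anywhere under log_dir."""
    root = Path(log_dir)
    if not root.is_dir():
        # No log directory yet (e.g. before fit has run): nothing to report.
        return []
    return sorted(root.rglob('*.ckpt'))
```

After trainer.fit(trainer.model), calling find_checkpoints(trainer.log_dir) would list whatever checkpoint files the run saved.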