geowatch.tasks.fusion.fit_lightning module

Main entrypoint for a fusion fit job.

For unit tests see:

../../../tests/test_lightning_cli_fit.py

For tutorials see:

../../../docs/source/manual/tutorial/tutorial1_rgb_network.sh

geowatch.tasks.fusion.fit_lightning.custom_yaml_load(stream)

geowatch.tasks.fusion.fit_lightning.custom_yaml_dump(data)

geowatch.tasks.fusion.fit_lightning.make_cli(config=None)

Main entrypoint that creates the CLI. It also works around issues that arise when the config is passed as a function parameter instead of being read from sys.argv.

Parameters:

config (None | Dict) – if specified, sys.argv is ignored and a training run is executed with the given config.

Returns:

SmartLightningCLI
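The config dict and sys.argv are two routes into the same parser. As a rough illustration (not geowatch code), the sketch below converts the flat dotted-key dict used in the examples for main into the argv-style tokens a LightningCLI parser with a fit subcommand would accept. It assumes the jsonargparse convention of writing nested keys as --group.key=value after the subcommand; the helper name flat_config_to_argv is hypothetical.

```python
def flat_config_to_argv(config):
    """Hypothetical helper: turn a flat dotted-key config into argv tokens.

    Assumes a jsonargparse-style parser (as used by LightningCLI) where the
    key ``fit.trainer.max_steps`` is written on the command line as
    ``fit --trainer.max_steps=2``.
    """
    config = dict(config)  # copy so the caller's dict is not mutated
    subcommand = config.pop('subcommand')
    prefix = subcommand + '.'
    argv = [subcommand]
    for key, value in config.items():
        if not key.startswith(prefix):
            raise KeyError(f'{key!r} is not under subcommand {subcommand!r}')
        # Strip the subcommand prefix and emit a --key=value token.
        argv.append(f'--{key[len(prefix):]}={value}')
    return argv


config = {
    'subcommand': 'fit',
    'fit.model': 'geowatch.tasks.fusion.methods.noop_model.NoopModel',
    'fit.trainer.accelerator': 'cpu',
    'fit.trainer.max_steps': 2,
}
for token in flat_config_to_argv(config):
    print(token)
# fit
# --model=geowatch.tasks.fusion.methods.noop_model.NoopModel
# --trainer.accelerator=cpu
# --trainer.max_steps=2
```

This only models the key mapping; the real parsing and type coercion are done by LightningCLI.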

Note

Currently, creating the CLI also invokes it. This function could be given an option to skip invocation by passing run=False to LightningCLI, but for some reason that changes the expected form of the config: subcommand must be specified when run=True and must not be specified when run=False. We need to understand exactly what is going on there before exposing a way to set run=False.
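To make the two config shapes in the Note concrete, here is a sketch under the assumption that the run=False form is simply the run=True form with the subcommand entry and the "fit." key prefix removed. strip_subcommand is a hypothetical helper, not part of geowatch, and it does not call LightningCLI.

```python
def strip_subcommand(config):
    """Hypothetical helper: reshape a run=True style config into a run=False style.

    Drops the ``subcommand`` entry and removes the ``<subcommand>.`` prefix
    from every remaining key. Only the key shapes are modeled here.
    """
    subcommand = config['subcommand']
    prefix = subcommand + '.'
    stripped = {}
    for key, value in config.items():
        if key == 'subcommand':
            continue
        if not key.startswith(prefix):
            raise KeyError(f'{key!r} is not under subcommand {subcommand!r}')
        stripped[key[len(prefix):]] = value
    return stripped


run_true_config = {
    'subcommand': 'fit',  # required when run=True
    'fit.trainer.max_steps': 2,
    'fit.data.chip_dims': 32,
}
run_false_config = strip_subcommand(run_true_config)
print(run_false_config)
# {'trainer.max_steps': 2, 'data.chip_dims': 32}
```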

geowatch.tasks.fusion.fit_lightning.main(config=None)

Thin wrapper around make_cli().

Parameters:

config (None | Dict) – if specified, sys.argv is ignored and a training run is executed with the given config.

CommandLine

xdoctest -m geowatch.tasks.fusion.fit_lightning main:0

Example

>>> import os
>>> import ubelt as ub
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> from geowatch.tasks.fusion.fit_lightning import *  # NOQA
>>> disable_lightning_hardware_warnings()
>>> dpath = ub.Path.appdir('geowatch/tests/test_fusion_fit/demo_main_noop').delete().ensuredir()
>>> config = {
>>>     'subcommand': 'fit',
>>>     'fit.model': 'geowatch.tasks.fusion.methods.noop_model.NoopModel',
>>>     'fit.trainer.default_root_dir': os.fspath(dpath),
>>>     'fit.data.train_dataset': 'special:vidshapes2-frames9-gsize32',
>>>     'fit.data.vali_dataset': 'special:vidshapes1-frames9-gsize32',
>>>     'fit.data.chip_dims': 32,
>>>     'fit.trainer.accelerator': 'cpu',
>>>     'fit.trainer.devices': 1,
>>>     'fit.trainer.max_steps': 2,
>>>     'fit.trainer.num_sanity_val_steps': 0,
>>>     'fit.trainer.add_to_registery': 0,
>>> }
>>> cli = main(config=config)

Example

>>> import os
>>> import ubelt as ub
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> from geowatch.tasks.fusion.fit_lightning import *  # NOQA
>>> disable_lightning_hardware_warnings()
>>> dpath = ub.Path.appdir('geowatch/tests/test_fusion_fit/demo_main_heterogeneous').delete().ensuredir()
>>> config = {
>>>     # 'model': 'geowatch.tasks.fusion.methods.MultimodalTransformer',
>>>     # 'model': 'geowatch.tasks.fusion.methods.UNetBaseline',
>>>     'subcommand': 'fit',
>>>     'fit.model.class_path': 'geowatch.tasks.fusion.methods.heterogeneous.HeterogeneousModel',
>>>     'fit.optimizer.class_path': 'torch.optim.SGD',
>>>     'fit.optimizer.init_args.lr': 1e-3,
>>>     'fit.trainer.default_root_dir': os.fspath(dpath),
>>>     'fit.data.train_dataset': 'special:vidshapes2-gsize64-frames9-speed0.5-multispectral',
>>>     'fit.data.vali_dataset': 'special:vidshapes1-gsize64-frames9-speed0.5-multispectral',
>>>     'fit.data.chip_dims': 64,
>>>     'fit.trainer.accelerator': 'cpu',
>>>     'fit.trainer.devices': 1,
>>>     'fit.trainer.max_steps': 2,
>>>     'fit.trainer.num_sanity_val_steps': 0,
>>>     'fit.trainer.add_to_registery': 0,
>>> }
>>> main(config=config)
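Both examples above repeat the same CPU, two-step trainer settings. A hypothetical helper such as make_demo_fit_config (not part of geowatch) can factor those shared keys out, with per-example keys like the model and optimizer merged on top through an extra dict (dotted keys cannot be passed as Python keyword arguments). Key names, including the add_to_registery spelling, are copied verbatim from the examples.

```python
import os


def make_demo_fit_config(dpath, train_dataset, vali_dataset, chip_dims, extra=None):
    """Hypothetical helper: build a flat ``main`` config from shared demo defaults."""
    config = {
        'subcommand': 'fit',
        'fit.trainer.default_root_dir': os.fspath(dpath),
        'fit.data.train_dataset': train_dataset,
        'fit.data.vali_dataset': vali_dataset,
        'fit.data.chip_dims': chip_dims,
        # Small CPU-only run shared by both examples.
        'fit.trainer.accelerator': 'cpu',
        'fit.trainer.devices': 1,
        'fit.trainer.max_steps': 2,
        'fit.trainer.num_sanity_val_steps': 0,
        'fit.trainer.add_to_registery': 0,
    }
    # Per-example keys (model, optimizer, ...) override or extend the defaults.
    config.update(extra or {})
    return config


config = make_demo_fit_config(
    '/tmp/demo_main_noop',
    'special:vidshapes2-frames9-gsize32',
    'special:vidshapes1-frames9-gsize32',
    chip_dims=32,
    extra={'fit.model': 'geowatch.tasks.fusion.methods.noop_model.NoopModel'},
)
# The result could then be passed to main(config=config), as in the first example.
```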