geowatch.tasks.fusion.fit_lightning module
Main entrypoint for a fusion fit job.
- For unit tests see:
../../../tests/test_lightning_cli_fit.py
- For tutorials see:
../../../docs/source/manual/tutorial/tutorial1_rgb_network.sh
- geowatch.tasks.fusion.fit_lightning.make_cli(config=None)
Main entrypoint that creates the CLI and works around issues that arise when the config is passed as a parameter rather than via sys.argv itself.
- Parameters:
config (None | Dict) – If specified, disables sys.argv usage and executes a training run with the specified config.
- Returns:
SmartLightningCLI
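Passing config replaces what would otherwise be read from sys.argv. As a rough illustration of that correspondence, the sketch below turns a flat config dict (in the form used by the examples further down) into an argv-style list. The helper name config_to_argv is hypothetical and not part of geowatch, and it assumes that each key, minus its "<subcommand>." prefix, is the matching command-line option path.

```python
from typing import Any, Dict, List


def config_to_argv(config: Dict[str, Any]) -> List[str]:
    """
    Hypothetical helper (not part of geowatch): convert a flat config dict
    such as {'subcommand': 'fit', 'fit.trainer.max_steps': 2} into an
    argv-style list such as ['fit', '--trainer.max_steps=2'].

    Assumes every non-subcommand key starts with '<subcommand>.' and that the
    remainder of the key is the command-line option path.
    """
    config = dict(config)  # copy so the caller's dict is not mutated
    subcommand = config.pop('subcommand')
    prefix = subcommand + '.'
    argv = [subcommand]
    for key, value in config.items():
        if not key.startswith(prefix):
            raise KeyError(f'Expected key to start with {prefix!r}: {key!r}')
        option = key[len(prefix):]
        argv.append(f'--{option}={value}')
    return argv
```

For instance, config_to_argv({'subcommand': 'fit', 'fit.trainer.max_steps': 2}) gives ['fit', '--trainer.max_steps=2']. Value formatting here is naive (str() of each value); real command-line parsing may treat types such as lists or booleans differently.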
Note
Currently, creating the CLI also invokes it. We could modify this function to optionally skip invocation by passing run=False to LightningCLI, but for some reason that changes the expected form of the config (you must specify the subcommand if run=True, but must not if run=False). We need to understand exactly what is going on there before we expose a way to set run=False.
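To make the difference described in the Note concrete, the sketch below converts a config in the run=True form (with a 'subcommand' entry and "fit."-prefixed keys) into a form without the subcommand. The helper to_run_false_form is hypothetical and not part of geowatch; that the prefix should also be stripped is an assumption that follows from the Note, not confirmed behavior.

```python
from typing import Any, Dict


def to_run_false_form(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Hypothetical helper (not part of geowatch): drop the 'subcommand' entry
    and strip the '<subcommand>.' prefix from the remaining keys.

    This only illustrates the config shape difference described in the Note;
    the exact form LightningCLI expects with run=False is an assumption.
    """
    prefix = config['subcommand'] + '.'
    out: Dict[str, Any] = {}
    for key, value in config.items():
        if key == 'subcommand':
            continue  # run=False is said to reject an explicit subcommand
        if key.startswith(prefix):
            key = key[len(prefix):]
        out[key] = value
    return out
```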
- geowatch.tasks.fusion.fit_lightning.main(config=None)
Thin wrapper around make_cli().
- Parameters:
config (None | Dict) – If specified, disables sys.argv usage and executes a training run with the specified config.
CommandLine
xdoctest -m geowatch.tasks.fusion.fit_lightning main:0
Example
>>> import os
>>> import ubelt as ub
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> from geowatch.tasks.fusion.fit_lightning import *  # NOQA
>>> disable_lightning_hardware_warnings()
>>> dpath = ub.Path.appdir('geowatch/tests/test_fusion_fit/demo_main_noop').delete().ensuredir()
>>> config = {
>>>     'subcommand': 'fit',
>>>     'fit.model': 'geowatch.tasks.fusion.methods.noop_model.NoopModel',
>>>     'fit.trainer.default_root_dir': os.fspath(dpath),
>>>     'fit.data.train_dataset': 'special:vidshapes2-frames9-gsize32',
>>>     'fit.data.vali_dataset': 'special:vidshapes1-frames9-gsize32',
>>>     'fit.data.chip_dims': 32,
>>>     'fit.trainer.accelerator': 'cpu',
>>>     'fit.trainer.devices': 1,
>>>     'fit.trainer.max_steps': 2,
>>>     'fit.trainer.num_sanity_val_steps': 0,
>>>     'fit.trainer.add_to_registery': 0,
>>> }
>>> cli = main(config=config)
Example
>>> import os
>>> import ubelt as ub
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> from geowatch.tasks.fusion.fit_lightning import *  # NOQA
>>> disable_lightning_hardware_warnings()
>>> dpath = ub.Path.appdir('geowatch/tests/test_fusion_fit/demo_main_heterogeneous').delete().ensuredir()
>>> config = {
>>>     # 'model': 'geowatch.tasks.fusion.methods.MultimodalTransformer',
>>>     # 'model': 'geowatch.tasks.fusion.methods.UNetBaseline',
>>>     'subcommand': 'fit',
>>>     'fit.model.class_path': 'geowatch.tasks.fusion.methods.heterogeneous.HeterogeneousModel',
>>>     'fit.optimizer.class_path': 'torch.optim.SGD',
>>>     'fit.optimizer.init_args.lr': 1e-3,
>>>     'fit.trainer.default_root_dir': os.fspath(dpath),
>>>     'fit.data.train_dataset': 'special:vidshapes2-gsize64-frames9-speed0.5-multispectral',
>>>     'fit.data.vali_dataset': 'special:vidshapes1-gsize64-frames9-speed0.5-multispectral',
>>>     'fit.data.chip_dims': 64,
>>>     'fit.trainer.accelerator': 'cpu',
>>>     'fit.trainer.devices': 1,
>>>     'fit.trainer.max_steps': 2,
>>>     'fit.trainer.num_sanity_val_steps': 0,
>>>     'fit.trainer.add_to_registery': 0,
>>> }
>>> main(config=config)
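The dotted keys in these examples are a flattened view of a nested configuration; for instance, fit.optimizer.class_path and fit.optimizer.init_args.lr follow the class_path / init_args convention of jsonargparse, which LightningCLI is built on. The sketch below expands such keys into nested dicts so that structure is visible. The helper unflatten is hypothetical and not part of geowatch, and it only reshapes the dict; it does not validate anything against the CLI.

```python
from typing import Any, Dict


def unflatten(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Hypothetical helper (not part of geowatch): expand dotted keys into
    nested dicts, e.g. {'fit.optimizer.init_args.lr': 1e-3} becomes
    {'fit': {'optimizer': {'init_args': {'lr': 1e-3}}}}.
    """
    nested: Dict[str, Any] = {}
    for key, value in config.items():
        *parents, leaf = key.split('.')
        node = nested
        for part in parents:
            child = node.setdefault(part, {})
            if not isinstance(child, dict):
                # e.g. 'fit.model' set to a string and 'fit.model.class_path' also given
                raise ValueError(f'Key {key!r} conflicts with a scalar at {part!r}')
            node = child
        if isinstance(node.get(leaf), dict):
            raise ValueError(f'Key {key!r} would overwrite a nested group')
        node[leaf] = value
    return nested
```

Applied to the second example, the optimizer entries become {'optimizer': {'class_path': 'torch.optim.SGD', 'init_args': {'lr': 1e-3}}} under 'fit'.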