geowatch.tasks.fusion.fit_lightning module

Main entrypoint for a fusion fit job.

For unit tests see:

../../../tests/test_lightning_cli_fit.py

For tutorials see:

../../../docs/source/manual/tutorial/tutorial1_rgb_network.sh

geowatch.tasks.fusion.fit_lightning.custom_yaml_load(stream)[source]

geowatch.tasks.fusion.fit_lightning.custom_yaml_dump(data)[source]

class geowatch.tasks.fusion.fit_lightning.SmartTrainer(*args, add_to_registery=True, **kwargs)[source]

Bases: Trainer

Simple Trainer subclass that ensures a message is printed directly before the training loop starts. (So annoying that callbacks cannot be reordered; hence the subclass.)

property log_dpath

Get the path to the log directory, if it exists.
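
A minimal usage sketch (hedged: arguments besides the documented add_to_registery flag are standard Trainer arguments, and the directory value is illustrative):

>>> # xdoctest: +SKIP
>>> from geowatch.tasks.fusion.fit_lightning import SmartTrainer
>>> trainer = SmartTrainer(default_root_dir='./runs', accelerator='cpu', devices=1)
>>> print(trainer.log_dpath)  # the log directory, once it exists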

class geowatch.tasks.fusion.fit_lightning.TorchGlobals(float32_matmul_precision='default')[source]

Bases: Callback

Callback to set up torch global settings.

Parameters:

float32_matmul_precision (str) – can be ‘medium’, ‘high’, ‘default’, or ‘auto’. The ‘default’ value does not change any setting. The ‘auto’ value defaults to ‘medium’ if the training devices have Ampere cores. A sketch of this logic is shown below.

setup(trainer, pl_module, stage)[source]
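
A hedged sketch of the precision handling described above (the Ampere check via CUDA compute capability, and the fallback behavior, are assumptions; see the source for the real logic):

>>> import torch
>>> def _apply_float32_matmul_precision(precision='default'):
>>>     if precision == 'default':
>>>         return  # leave torch's global setting untouched
>>>     if precision == 'auto':
>>>         # assumption: Ampere and newer GPUs report compute capability >= 8
>>>         has_ampere = (torch.cuda.is_available() and
>>>                       torch.cuda.get_device_capability()[0] >= 8)
>>>         if not has_ampere:
>>>             return  # assumption: leave the setting unchanged otherwise
>>>         precision = 'medium'
>>>     torch.set_float32_matmul_precision(precision)
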
class geowatch.tasks.fusion.fit_lightning.WeightInitializer(init='noop', association='embedding', remember_initial_state=True, verbose=1)[source]

Bases: Callback

Network weight initializer with support for partial weight loading.

Variables:
  • init (str | PathLike) – either “noop” to use default weight initialization, or a path to a pretrained model that has at least some similarity to the model to be trained.

  • association (str) – Either “embedding” or “isomorphism”. The “embedding” case is more flexible, allowing similar subcomponents of the network to be disconnected. In the “isomorphism” case, the transferred part must be a proper subgraph in both models. See torch-liberator’s partial weight initialization for more details.

  • remember_initial_state (bool) – If True, saves the initial state in an “analysis_checkpoints” folder. Defaults to True.

  • verbose (int) – If 0, prints nothing. If 1, prints some info. If 3, prints the explicit association found.

A construction sketch is shown below.

setup(trainer, pl_module, stage)[source]
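
A minimal construction sketch based on the documented signature (the checkpoint path is hypothetical):

>>> # xdoctest: +SKIP
>>> from geowatch.tasks.fusion.fit_lightning import WeightInitializer
>>> initializer = WeightInitializer(
>>>     init='/path/to/pretrained_package.pt',  # hypothetical pretrained model
>>>     association='embedding',  # allow disconnected subcomponents to match
>>>     remember_initial_state=True,
>>>     verbose=1,
>>> )
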
class geowatch.tasks.fusion.fit_lightning.SmartLightningCLI(model_class: type[LightningModule] | Callable[..., LightningModule] | None = None, datamodule_class: type[LightningDataModule] | Callable[..., LightningDataModule] | None = None, save_config_callback: type[SaveConfigCallback] | None = SaveConfigCallback, save_config_kwargs: dict[str, Any] | None = None, trainer_class: type[Trainer] | Callable[..., Trainer] = Trainer, trainer_defaults: dict[str, Any] | None = None, seed_everything_default: bool | int = True, parser_kwargs: dict[str, Any] | dict[str, dict[str, Any]] | None = None, subclass_mode_model: bool = False, subclass_mode_data: bool = False, args: list[str] | dict[str, Any] | Namespace | None = None, run: bool = True, auto_configure_optimizers: bool = True)[source]

Bases: LightningCLI_Extension

Our extension of the LightningCLI class that adds custom arguments and functionality to the CLI. See add_arguments_to_parser() for more details.

Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are called / instantiated using a parsed configuration file and / or command line args.

Parsing of configuration from environment variables can be enabled by setting parser_kwargs={"default_env": True}. A full configuration yaml would be parsed from PL_CONFIG if set. Individual settings are likewise parsed from correspondingly named variables, for example PL_TRAINER__MAX_EPOCHS.
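
For example (a hedged sketch; the variable name follows the PL_* convention above, and the construction is not meant as a complete training invocation):

>>> # xdoctest: +SKIP
>>> import os
>>> os.environ['PL_TRAINER__MAX_EPOCHS'] = '5'
>>> cli = SmartLightningCLI(parser_kwargs={'default_env': True})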

For more info, read the CLI docs.

Parameters:
  • model_class – An optional LightningModule class to train on or a callable which returns a LightningModule instance when called. If None, you can pass a registered model with --model=MyModel.

  • datamodule_class – An optional LightningDataModule class or a callable which returns a LightningDataModule instance when called. If None, you can pass a registered datamodule with --data=MyDataModule.

  • save_config_callback – A callback class to save the config.

  • save_config_kwargs – Parameters that will be used to instantiate the save_config_callback.

  • trainer_class – An optional subclass of the Trainer class or a callable which returns a Trainer instance when called.

  • trainer_defaults – Set to override Trainer defaults or add persistent callbacks. The callbacks added through this argument will not be configurable from a configuration file and will always be present for this particular CLI. Alternatively, configurable callbacks can be added as explained in the CLI docs.

  • seed_everything_default – Number for the seed_everything() seed value. Set to True to automatically choose a seed value. Setting it to False will avoid calling seed_everything.

  • parser_kwargs – Additional arguments to instantiate each LightningArgumentParser.

  • subclass_mode_model – Whether model can be any subclass of the given class.

  • subclass_mode_data – Whether datamodule can be any subclass of the given class.

  • args – Arguments to parse. If None the arguments are taken from sys.argv. Command line style arguments can be given in a list. Alternatively, structured config options can be given in a dict or jsonargparse.Namespace.

  • run – Whether subcommands should be added to run a Trainer method. If set to False, the trainer and model classes will be instantiated only.

static configure_optimizers(lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler=None) → Any[source]

Override to customize the configure_optimizers() method.

TODO: why is this overloaded?

Parameters:
  • lightning_module – A reference to the model.

  • optimizer – The optimizer.

  • lr_scheduler – The learning rate scheduler (if used).
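
For reference, the standard LightningCLI contract for this hook returns either the optimizer alone or a dict pairing it with the scheduler. A hedged sketch of that shape (not necessarily what this particular override does differently):

>>> def configure_optimizers(lightning_module, optimizer, lr_scheduler=None):
>>>     if lr_scheduler is None:
>>>         return optimizer
>>>     return {'optimizer': optimizer,
>>>             'lr_scheduler': {'scheduler': lr_scheduler}}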

add_arguments_to_parser(parser: LightningArgumentParser)[source]
  1. Adds custom extensions like “initializer” and “torch_globals”.

  2. Adds the packager callback (it is unclear why this has problems when used as a regular callback).

  3. Helps the dataset / model notify each other about relevant settings; a sketch of this pattern is shown below.
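
A hedged sketch of the pattern (the argument group names and linked keys are illustrative; add_lightning_class_args and link_arguments are the standard LightningArgumentParser APIs):

>>> # xdoctest: +SKIP
>>> def add_arguments_to_parser(self, parser):
>>>     # register the custom extensions as configurable groups
>>>     parser.add_lightning_class_args(WeightInitializer, 'initializer')
>>>     parser.add_lightning_class_args(TorchGlobals, 'torch_globals')
>>>     # link settings so the data and model agree (keys are illustrative)
>>>     parser.link_arguments('data.chip_dims', 'model.init_args.chip_dims')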

geowatch.tasks.fusion.fit_lightning.instantiate_datamodule(cls, *args, **kwargs)[source]

Custom instantiator for the datamodule that simply calls setup after creating the instance.
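
A minimal sketch of the behavior described above (the stage passed to setup is an assumption):

>>> def instantiate_datamodule(cls, *args, **kwargs):
>>>     self = cls(*args, **kwargs)
>>>     self.setup('fit')  # assumption: the 'fit' stage is requested
>>>     return self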

geowatch.tasks.fusion.fit_lightning.make_cli(config=None)[source]

Main entrypoint that creates the CLI and works around issues that arise when the config is passed as a parameter rather than via sys.argv itself.

Parameters:

config (None | Dict) – if specified, disables sys.argv usage and executes a training run with the specified config.

Returns:

SmartLightningCLI

Note

Currently, creating the CLI will also invoke it. We could modify this function to optionally skip invocation by passing run=False to LightningCLI, but for some reason that changes the expected form of the config (a subcommand must be specified when run=True, but must not be when run=False). We need to understand exactly what is going on there before we expose a way to set run=False.
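
A hedged usage sketch (the config values are lifted from the doctest examples for main() below; a real run needs the additional trainer and data settings shown there):

>>> # xdoctest: +SKIP
>>> cli = make_cli(config={
>>>     'subcommand': 'fit',
>>>     'fit.model': 'geowatch.tasks.fusion.methods.noop_model.NoopModel',
>>>     'fit.data.train_dataset': 'special:vidshapes2-frames9-gsize32',
>>> })
>>> # cli is the invoked SmartLightningCLI instance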

geowatch.tasks.fusion.fit_lightning.main(config=None)[source]

Thin wrapper around make_cli().

Parameters:

config (None | Dict) – if specified, disables sys.argv usage and executes a training run with the specified config.

CommandLine

xdoctest -m geowatch.tasks.fusion.fit_lightning main:0

Example

>>> import os
>>> import ubelt as ub
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> from geowatch.tasks.fusion.fit_lightning import *  # NOQA
>>> disable_lightning_hardware_warnings()
>>> dpath = ub.Path.appdir('geowatch/tests/test_fusion_fit/demo_main_noop').delete().ensuredir()
>>> config = {
>>>     'subcommand': 'fit',
>>>     'fit.model': 'geowatch.tasks.fusion.methods.noop_model.NoopModel',
>>>     'fit.trainer.default_root_dir': os.fspath(dpath),
>>>     'fit.data.train_dataset': 'special:vidshapes2-frames9-gsize32',
>>>     'fit.data.vali_dataset': 'special:vidshapes1-frames9-gsize32',
>>>     'fit.data.chip_dims': 32,
>>>     'fit.trainer.accelerator': 'cpu',
>>>     'fit.trainer.devices': 1,
>>>     'fit.trainer.max_steps': 2,
>>>     'fit.trainer.num_sanity_val_steps': 0,
>>>     'fit.trainer.add_to_registery': 0,
>>> }
>>> cli = main(config=config)

Example

>>> import os
>>> import ubelt as ub
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> from geowatch.tasks.fusion.fit_lightning import *  # NOQA
>>> disable_lightning_hardware_warnings()
>>> dpath = ub.Path.appdir('geowatch/tests/test_fusion_fit/demo_main_heterogeneous').delete().ensuredir()
>>> config = {
>>>     # 'model': 'geowatch.tasks.fusion.methods.MultimodalTransformer',
>>>     #'model': 'geowatch.tasks.fusion.methods.UNetBaseline',
>>>     'subcommand': 'fit',
>>>     'fit.model.class_path': 'geowatch.tasks.fusion.methods.heterogeneous.HeterogeneousModel',
>>>     'fit.optimizer.class_path': 'torch.optim.SGD',
>>>     'fit.optimizer.init_args.lr': 1e-3,
>>>     'fit.trainer.default_root_dir': os.fspath(dpath),
>>>     'fit.data.train_dataset': 'special:vidshapes2-gsize64-frames9-speed0.5-multispectral',
>>>     'fit.data.vali_dataset': 'special:vidshapes1-gsize64-frames9-speed0.5-multispectral',
>>>     'fit.data.chip_dims': 64,
>>>     'fit.trainer.accelerator': 'cpu',
>>>     'fit.trainer.devices': 1,
>>>     'fit.trainer.max_steps': 2,
>>>     'fit.trainer.num_sanity_val_steps': 0,
>>>     'fit.trainer.add_to_registery': 0,
>>> }
>>> main(config=config)