geowatch.tasks.fusion.fit_lightning module¶
Main entrypoint for a fusion fit job.
- For unit tests see:
../../../tests/test_lightning_cli_fit.py
- For tutorials see:
../../../docs/source/manual/tutorial/tutorial1_rgb_network.sh
- class geowatch.tasks.fusion.fit_lightning.SmartTrainer(*args, add_to_registery=True, **kwargs)[source]¶
Bases:
Trainer
Simple Trainer subclass so we can ensure a print happens directly before the training loop (it is annoying that callbacks cannot be reordered).
- property log_dpath¶
Get the path to the log directory if it exists.
- class geowatch.tasks.fusion.fit_lightning.TorchGlobals(float32_matmul_precision='default')[source]¶
Bases:
Callback
Callback to setup torch globals
- Parameters:
float32_matmul_precision (str) – can be ‘medium’, ‘high’, ‘default’, or ‘auto’. The ‘default’ value does not change any setting. The ‘auto’ value defaults to ‘medium’ if the training devices have Ampere cores.
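A minimal sketch of how the ‘auto’ resolution described above could work. The compute-capability threshold of (8, 0) for Ampere is an assumption for illustration; the real callback queries torch and sets the global directly:

```python
def resolve_matmul_precision(setting, device_capability):
    """Map a TorchGlobals-style precision setting to a concrete value.

    Returns None when no torch global should be changed. 'auto' picks
    'medium' when the device reports Ampere-class compute capability
    (assumed here to be >= (8, 0)).
    """
    if setting == 'default':
        return None  # do not change any setting
    if setting == 'auto':
        return 'medium' if device_capability >= (8, 0) else None
    return setting  # 'medium' or 'high' pass through unchanged


# On an Ampere GPU (capability (8, 6)), 'auto' resolves to 'medium'
print(resolve_matmul_precision('auto', (8, 6)))
```

The resolved value would then be passed to torch.set_float32_matmul_precision.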
- class geowatch.tasks.fusion.fit_lightning.WeightInitializer(init='noop', association='embedding', remember_initial_state=True, verbose=1)[source]¶
Bases:
Callback
Network weight initializer with support for partial weight loading.
- Variables:
init (str | PathLike) – either “noop” to use default weight initialization, or a path to a pretrained model that has at least some similarity to the model to be trained.
association (str) – Either “embedding” or “isomorphism”. The “embedding” case is more flexible, allowing similar subcomponents of the network to be disconnected. In the “isomorphism” case, the transferred part must be a proper subgraph in both models. See torch-liberator’s partial weight initialization for more details.
remember_initial_state (bool) – If True, saves the initial state in an “analysis_checkpoints” folder. Defaults to True.
verbose (int) – If 1, prints some info. If 3, prints the explicit association found. If 0, prints nothing.
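To illustrate the idea of partial weight loading, here is a heavily simplified sketch using plain dicts in place of torch state dicts (the real callback delegates to torch-liberator’s association logic, which also handles renamed and re-nested keys):

```python
def partial_transfer(src_state, dst_state):
    """Copy entries from src_state into dst_state when the key exists in
    both and the stored shapes agree; leave all other weights untouched.

    Each state maps a parameter name to a (shape, value) pair.
    """
    out = dict(dst_state)
    transferred = []
    for key, (shape, value) in src_state.items():
        if key in dst_state and dst_state[key][0] == shape:
            out[key] = (shape, value)
            transferred.append(key)
    return out, transferred


# The backbone matches in name and shape, so it transfers; the head
# has a different shape in the new model, so it keeps random init.
src = {'backbone.w': ((3, 3), 'pretrained'), 'head.w': ((10,), 'pretrained')}
dst = {'backbone.w': ((3, 3), 'random'), 'head.w': ((5,), 'random')}
new_state, moved = partial_transfer(src, dst)
print(moved)
```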
- class geowatch.tasks.fusion.fit_lightning.SmartLightningCLI(model_class: type[~pytorch_lightning.core.module.LightningModule] | ~typing.Callable[[...], ~pytorch_lightning.core.module.LightningModule] | None = None, datamodule_class: type[~pytorch_lightning.core.datamodule.LightningDataModule] | ~typing.Callable[[...], ~pytorch_lightning.core.datamodule.LightningDataModule] | None = None, save_config_callback: type[~pytorch_lightning.cli.SaveConfigCallback] | None = <class 'pytorch_lightning.cli.SaveConfigCallback'>, save_config_kwargs: dict[str, ~typing.Any] | None = None, trainer_class: type[~pytorch_lightning.trainer.trainer.Trainer] | ~typing.Callable[[...], ~pytorch_lightning.trainer.trainer.Trainer] = <class 'pytorch_lightning.trainer.trainer.Trainer'>, trainer_defaults: dict[str, ~typing.Any] | None = None, seed_everything_default: bool | int = True, parser_kwargs: dict[str, ~typing.Any] | dict[str, dict[str, ~typing.Any]] | None = None, subclass_mode_model: bool = False, subclass_mode_data: bool = False, args: list[str] | dict[str, ~typing.Any] | ~jsonargparse_fork._namespace.Namespace | None = None, run: bool = True, auto_configure_optimizers: bool = True)[source]¶
Bases:
LightningCLI_Extension
Our extension of the LightningCLI class that adds custom arguments and functionality to the CLI. See
add_arguments_to_parser()
for more details.
Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are called / instantiated using a parsed configuration file and / or command line args.
Parsing of configuration from environment variables can be enabled by setting
parser_kwargs={"default_env": True}
. A full configuration yaml would be parsed from PL_CONFIG if set. Individual settings are so parsed from variables named for example PL_TRAINER__MAX_EPOCHS. For more info, read the CLI docs.
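The PL_TRAINER__MAX_EPOCHS example follows a mechanical naming scheme inferred from the example above (this helper is illustrative, not part of the library):

```python
def env_var_name(setting_path, prefix='PL'):
    """Convert a dotted CLI setting like 'trainer.max_epochs' into its
    environment variable form, e.g. PL_TRAINER__MAX_EPOCHS: uppercase
    the path and replace each dot with a double underscore."""
    return prefix + '_' + setting_path.upper().replace('.', '__')


print(env_var_name('trainer.max_epochs'))  # PL_TRAINER__MAX_EPOCHS
```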
- Parameters:
model_class – An optional LightningModule class to train on, or a callable which returns a LightningModule instance when called. If None, you can pass a registered model with --model=MyModel.
datamodule_class – An optional LightningDataModule class, or a callable which returns a LightningDataModule instance when called. If None, you can pass a registered datamodule with --data=MyDataModule.
save_config_callback – A callback class to save the config.
save_config_kwargs – Parameters that will be used to instantiate the save_config_callback.
trainer_class – An optional subclass of the Trainer class, or a callable which returns a Trainer instance when called.
trainer_defaults – Set to override Trainer defaults or add persistent callbacks. The callbacks added through this argument will not be configurable from a configuration file and will always be present for this particular CLI. Alternatively, configurable callbacks can be added as explained in the CLI docs.
seed_everything_default – Number for the seed_everything() seed value. Set to True to automatically choose a seed value. Setting it to False will avoid calling seed_everything.
parser_kwargs – Additional arguments to instantiate each LightningArgumentParser.
subclass_mode_model – Whether model can be any subclass of the given class.
subclass_mode_data – Whether datamodule can be any subclass of the given class.
args – Arguments to parse. If None, the arguments are taken from sys.argv. Command line style arguments can be given in a list. Alternatively, structured config options can be given in a dict or jsonargparse.Namespace.
run – Whether subcommands should be added to run a Trainer method. If set to False, the trainer and model classes will be instantiated only.
- static configure_optimizers(lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler=None) Any [source]¶
Override to customize the
configure_optimizers()
method.
TODO: why is this overloaded?
- Parameters:
lightning_module – A reference to the model.
optimizer – The optimizer.
lr_scheduler – The learning rate scheduler (if used).
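A sketch of what a hook with this signature plausibly returns, following Lightning’s convention that configure_optimizers may return either a bare optimizer or an optimizer/scheduler pairing (the exact return shape used here is an assumption, not the confirmed implementation):

```python
def configure_optimizers_sketch(lightning_module, optimizer, lr_scheduler=None):
    """Return the optimizer alone when no scheduler is configured,
    otherwise pair the optimizer with its learning rate scheduler."""
    if lr_scheduler is None:
        return optimizer
    return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}


# With no scheduler the optimizer passes through unchanged
print(configure_optimizers_sketch(None, 'sgd'))
```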
- geowatch.tasks.fusion.fit_lightning.instantiate_datamodule(cls, *args, **kwargs)[source]¶
Custom instantiator for the datamodule that simply calls setup after creating the instance.
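The behavior described above amounts to the following pattern (DemoDataModule is a stand-in for a real LightningDataModule, used only to show the flow):

```python
class DemoDataModule:
    """Stand-in for a LightningDataModule with a setup() hook."""

    def __init__(self, root='.'):
        self.root = root
        self.was_setup = False

    def setup(self, stage=None):
        self.was_setup = True


def instantiate_and_setup(cls, *args, **kwargs):
    # build the datamodule, then immediately call setup so the
    # instance is fully initialized before training begins
    self = cls(*args, **kwargs)
    self.setup()
    return self


dm = instantiate_and_setup(DemoDataModule, root='/tmp/data')
print(dm.was_setup)  # True
```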
- geowatch.tasks.fusion.fit_lightning.make_cli(config=None)[source]¶
Main entrypoint that creates the CLI and works around issues when config is passed as a parameter rather than via sys.argv itself.
- Parameters:
config (None | Dict) – if specified disables sys.argv usage and executes a training run with the specified config.
- Returns:
SmartLightningCLI
Note
Currently, creating the CLI will invoke it. We could modify this function to have the option to not invoke by specifying run=False to LightningCLI, but for some reason that changes the expected form of the config (you must specify subcommand if run=True, but must not if run=False). We need to understand exactly what’s going on there before we expose a way to set run=False.
- geowatch.tasks.fusion.fit_lightning.main(config=None)[source]¶
Thin wrapper around make_cli().
- Parameters:
config (None | Dict) – if specified disables sys.argv usage and executes a training run with the specified config.
CommandLine
xdoctest -m geowatch.tasks.fusion.fit_lightning main:0
Example
>>> import os
>>> import ubelt as ub
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> from geowatch.tasks.fusion.fit_lightning import *  # NOQA
>>> disable_lightning_hardware_warnings()
>>> dpath = ub.Path.appdir('geowatch/tests/test_fusion_fit/demo_main_noop').delete().ensuredir()
>>> config = {
>>>     'subcommand': 'fit',
>>>     'fit.model': 'geowatch.tasks.fusion.methods.noop_model.NoopModel',
>>>     'fit.trainer.default_root_dir': os.fspath(dpath),
>>>     'fit.data.train_dataset': 'special:vidshapes2-frames9-gsize32',
>>>     'fit.data.vali_dataset': 'special:vidshapes1-frames9-gsize32',
>>>     'fit.data.chip_dims': 32,
>>>     'fit.trainer.accelerator': 'cpu',
>>>     'fit.trainer.devices': 1,
>>>     'fit.trainer.max_steps': 2,
>>>     'fit.trainer.num_sanity_val_steps': 0,
>>>     'fit.trainer.add_to_registery': 0,
>>> }
>>> cli = main(config=config)
Example
>>> import os
>>> import ubelt as ub
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> from geowatch.tasks.fusion.fit_lightning import *  # NOQA
>>> disable_lightning_hardware_warnings()
>>> dpath = ub.Path.appdir('geowatch/tests/test_fusion_fit/demo_main_heterogeneous').delete().ensuredir()
>>> config = {
>>>     # 'model': 'geowatch.tasks.fusion.methods.MultimodalTransformer',
>>>     # 'model': 'geowatch.tasks.fusion.methods.UNetBaseline',
>>>     'subcommand': 'fit',
>>>     'fit.model.class_path': 'geowatch.tasks.fusion.methods.heterogeneous.HeterogeneousModel',
>>>     'fit.optimizer.class_path': 'torch.optim.SGD',
>>>     'fit.optimizer.init_args.lr': 1e-3,
>>>     'fit.trainer.default_root_dir': os.fspath(dpath),
>>>     'fit.data.train_dataset': 'special:vidshapes2-gsize64-frames9-speed0.5-multispectral',
>>>     'fit.data.vali_dataset': 'special:vidshapes1-gsize64-frames9-speed0.5-multispectral',
>>>     'fit.data.chip_dims': 64,
>>>     'fit.trainer.accelerator': 'cpu',
>>>     'fit.trainer.devices': 1,
>>>     'fit.trainer.max_steps': 2,
>>>     'fit.trainer.num_sanity_val_steps': 0,
>>>     'fit.trainer.add_to_registery': 0,
>>> }
>>> main(config=config)