geowatch.mlops.repackager module

Script for converting a checkpoint (that lives in a training directory) into a torch package with appropriate metadata.

class geowatch.mlops.repackager.RepackageConfig(*args, **kwargs)[source]

Bases: DataConfig

Convert a raw torch checkpoint into a torch package.

Attempts to combine checkpoint weights with its associated model code in a standalone torch package.

To do this we must be able to infer how to construct an instance of the model to load the weights into. Currently we implement hard-coded heuristics that only work for specific fusion models.

Note

The output filenames are chosen automatically. In the future we may give the user more control here. We may also look for ways to provide more hints for determining how to construct model instances, either from context or via these configuration arguments.

Valid options: []

Parameters:
  • *args – positional arguments for this data config

  • **kwargs – keyword arguments for this data config

default = {'checkpoint_fpath': <Value(None)>, 'force': <Value(False)>, 'strict': <Value(False)>}
main(**kwargs)
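The configuration keys above (checkpoint_fpath, force, strict) map onto command-line flags. A hypothetical invocation, assuming the module is runnable with `-m` as is typical for geowatch CLI scripts (the checkpoint path is illustrative, not real):

```shell
# Repackage a raw Lightning checkpoint into a torch package.
# The output package filename is chosen automatically.
python -m geowatch.mlops.repackager \
    --checkpoint_fpath=/path/to/train_dir/checkpoints/epoch=3-step=1000.ckpt \
    --force=False \
    --strict=False
```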
geowatch.mlops.repackager.main(cmdline=True, **kwargs)[source]
geowatch.mlops.repackager.repackage(checkpoint_fpath, force=False, strict=False, dry=False)[source]

Logic for repackaging multiple checkpoints at a time. Automatically chooses the new package name for each.

Todo

Generalize this beyond the fusion model; also refactor.
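The batch behavior described above can be sketched with only the standard library. This is an independent illustration of the idea (find raw checkpoints, pair each with an automatically chosen package path, skip existing outputs unless forced), not the function's actual implementation, which also inspects the training directory for context:

```python
from pathlib import Path


def sketch_repackage_many(checkpoint_dpath, force=False):
    """
    Pair each raw Lightning-style checkpoint in a directory with a
    suggested package path. Illustrative only; the real ``repackage``
    uses richer heuristics to choose distinguishable names.
    """
    tasks = []
    for ckpt_fpath in sorted(Path(checkpoint_dpath).glob('*.ckpt')):
        pkg_fpath = ckpt_fpath.with_suffix('.pt')
        if pkg_fpath.exists() and not force:
            continue  # already repackaged; skip unless forced
        tasks.append((ckpt_fpath, pkg_fpath))
    return tasks
```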

geowatch.mlops.repackager.looks_like_training_directory(candidate_dpath)[source]
geowatch.mlops.repackager.inspect_checkpoint_context(checkpoint_fpath)[source]

Use heuristics to attempt to find the context in which this checkpoint was trained.
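The heuristics themselves are internal to geowatch, but the general idea can be sketched: walk upward from the checkpoint looking for files a Lightning-style training run typically leaves behind. The marker filenames below are illustrative assumptions, not the exact ones the library checks:

```python
from pathlib import Path


def sketch_looks_like_training_directory(candidate_dpath):
    """
    Heuristic: a training directory usually contains hyperparameter or
    config files alongside a checkpoints subdirectory. Marker names
    here are illustrative.
    """
    dpath = Path(candidate_dpath)
    markers = ['hparams.yaml', 'config.yaml', 'checkpoints']
    return any((dpath / name).exists() for name in markers)


def sketch_inspect_checkpoint_context(checkpoint_fpath):
    """Walk up from the checkpoint looking for a plausible training dir."""
    fpath = Path(checkpoint_fpath)
    for dpath in fpath.parents:
        if sketch_looks_like_training_directory(dpath):
            return {'checkpoint_fpath': fpath, 'train_dpath': dpath}
    return {'checkpoint_fpath': fpath, 'train_dpath': None}
```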

geowatch.mlops.repackager.suggest_package_name_for_checkpoint(context)[source]

Suggest a more distinguishable name for the checkpoint based on its context.
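For intuition: Lightning checkpoints are typically named like `epoch=3-step=1000.ckpt`, which is hard to tell apart across experiments. A simplified sketch of a name suggester that folds in an experiment name (the naming scheme here is an illustrative assumption, not geowatch's actual one):

```python
import re
from pathlib import Path


def sketch_suggest_package_name(checkpoint_fpath, expt_name='unknown_expt'):
    """
    Build a more distinguishable package name from an experiment name
    plus any epoch/step info embedded in the checkpoint filename.
    Illustrative scheme only.
    """
    stem = Path(checkpoint_fpath).stem
    epoch = re.search(r'epoch=(\d+)', stem)
    step = re.search(r'step=(\d+)', stem)
    parts = ['package', expt_name]
    if epoch:
        parts.append(f'epoch{epoch.group(1)}')
    if step:
        parts.append(f'step{step.group(1)}')
    return '_'.join(parts) + '.pt'
```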

geowatch.mlops.repackager.parse_and_init_config(config)[source]
geowatch.mlops.repackager.torch_load_cpu(checkpoint_fpath)[source]
geowatch.mlops.repackager.repackage_single_checkpoint(checkpoint_fpath, package_fpath, train_dpath_hint=None, model_config_fpath=None)[source]

Primary logic for repackaging a checkpoint into a torch package.

To do this we need to have some information about how to construct the specific module to associate with the weights. We have some heuristics built in to take care of this for specific known models, but new models will need new logic to handle them. It would be nice to find a way to generalize this.

CommandLine

xdoctest -m geowatch.mlops.repackager repackage_single_checkpoint

Example

>>> import ubelt as ub
>>> import torch
>>> dpath = ub.Path.appdir('geowatch/tests/repackage').delete().ensuredir()
>>> package_fpath = dpath / 'my_package.pt'
>>> checkpoint_fpath = dpath / 'my_checkpoint.ckpt'
>>> assert not package_fpath.exists()
>>> # Create an instance of a model, and save a checkpoint to disk
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> model = self = methods.MultimodalTransformer(
>>>     arch_name="smt_it_joint_p2", input_sensorchan=5,
>>>     change_head_hidden=0, saliency_head_hidden=0,
>>>     class_head_hidden=0)
>>> # Save a checkpoint to disk.
>>> model_state = model.state_dict()
>>> # (fixme: how to get a lightning style checkpoint structure?)
>>> checkpoint = {
>>>     'state_dict': model_state,
>>>     'hyper_parameters': model._hparams,
>>> }
>>> with open(checkpoint_fpath, 'wb') as file:
...     torch.save(checkpoint, file)
>>> from geowatch.mlops.repackager import *  # NOQA
>>> repackage_single_checkpoint(checkpoint_fpath, package_fpath)
>>> assert package_fpath.exists()
>>> # Test we can reload the package
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> model2 = load_model_from_package(package_fpath)
>>> # TODO: get allclose working on the nested dict
>>> params1 = dict(model.named_parameters())
>>> params2 = dict(model2.named_parameters())
>>> k = 'encoder.layers.0.mlp.3.weight'
>>> assert torch.allclose(params1[k], params2[k])
>>> assert params1[k] is not params2[k]
>>> params1 = ub.IndexableWalker(dict(model.named_parameters()))
>>> params2 = ub.IndexableWalker(dict(model2.named_parameters()))
>>> for k, v in params1:
...     assert torch.allclose(params1[k], params2[k])
...     assert params1[k] is not params2[k]
>>> # Test that we can get model stats
>>> from geowatch.cli import torch_model_stats
>>> row = torch_model_stats.torch_model_stats(package_fpath)
>>> print(f'row = {ub.urepr(row, nl=2)}')
geowatch.mlops.repackager.load_meta(fpath)[source]