#!/usr/bin/env python3
"""
TODO:
- [ ] Add relevant documentation to point at
~/code/geowatch/geowatch/mlops/repackager.py so it is easy for the user to
unpack / repack models.
"""
import scriptconfig as scfg
import ubelt as ub
class UnpackModelCLI(scfg.DataConfig):
    """
    Unpack the core components of a torch package to make them suitable for
    environment-agnostic use or repackaging.

    SeeAlso
    -------
    python -m geowatch.mlops.repackager
    """
    # Positional argument: the torch package that will be unpacked.
    fpath = scfg.Value(None, help='the path to a torch package', position=1)
    # Optional output directory; ``unpack_model`` picks one when this is None.
    dst_dpath = scfg.Value(None, help='Path to a destination directory to write to. If unspecified chooses one.')

    @classmethod
    def main(cls, cmdline=1, **kwargs):
        """
        Resolve the configuration, print it, and unpack the requested package.

        Args:
            cmdline (int | bool): forwarded unchanged to ``cls.cli``.

            **kwargs: explicit configuration values, forwarded to ``cls.cli``
                as ``data``.

        Example:
            >>> # xdoctest: +SKIP
            >>> from geowatch.cli.experimental.unpack_model import *  # NOQA
            >>> cmdline = 0
            >>> kwargs = dict()
            >>> kwargs['fpath'] = '/data/joncrall/dvc-repos/smart_phase3_expt/models/fusion/Drop8-ARA-Cropped2GSD-V1/packages/Drop8-ARA-Cropped2GSD-V1_allsensors_V001/Drop8-ARA-Cropped2GSD-V1_allsensors_V001_epoch0_step21021.pt'
            >>> cls = UnpackModelCLI
            >>> cls.main(cmdline=cmdline, **kwargs)
        """
        # Imported lazily so that importing this module does not require rich.
        import rich
        from rich.markup import escape
        config = cls.cli(cmdline=cmdline, data=kwargs, strict=True)
        # Escape the repr so brackets in it are not interpreted as rich markup.
        rich.print('config = ' + escape(ub.urepr(config, nl=1)))
        package_fpath = config.fpath
        unpack_model(package_fpath, dst_dpath=config.dst_dpath)
def unpack_model(package_fpath, dst_dpath=None):
    """
    Unpack a torch package into a config file and a checkpoint file on disk.

    The ``'config'`` entry of the extracted package contents is dumped as YAML
    to ``<dst_dpath>/config.yaml`` and the ``'checkpoint'`` entry is written
    with ``torch.save`` to ``<dst_dpath>/checkpoints/<package stem>.ckpt``.

    Args:
        package_fpath (str | PathLike): path to the torch package to unpack.

        dst_dpath (str | PathLike | None):
            directory to write into. When None, a directory named after the
            package file's stem, located next to the package, is used.

    Returns:
        Dict[str, ub.Path]: mapping from keys to written paths on disk.
            The keys are ``dpath``, ``ckpt_fpath`` and ``config_fpath``.

    Ignore:
        package_fpath = '/data/joncrall/dvc-repos/smart_phase3_expt/models/fusion/Drop8-ARA-Cropped2GSD-V1/packages/Drop8-ARA-Cropped2GSD-V1_allsensors_V001/Drop8-ARA-Cropped2GSD-V1_allsensors_V001_epoch0_step21021.pt'
        result = unpack_model(package_fpath)
        checkpoint_fpath = result['ckpt_fpath']
        from geowatch.mlops.repackager import repackage
        new_package_fpath = repackage(checkpoint_fpath=checkpoint_fpath)[0]
        round2_result = unpack_model(new_package_fpath)
        import kwutil
        config1 = kwutil.Yaml.load(result['config_fpath'])
        config2 = kwutil.Yaml.load(round2_result['config_fpath'])
        config1 == config2
        from kwcoco.util.util_json import indexable_diff
        info = indexable_diff(config1, config2)
        assert info['similarity'] == 1.0

    Example:
        >>> import ubelt as ub
        >>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
        >>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
        >>> package_fpath = dpath / 'my_package.pt'
        >>> dst_dpath = dpath / 'unpacked_packages'
        >>> if not package_fpath.exists():
        >>>     # Use one of our fusion.architectures in a test
        >>>     from geowatch.tasks.fusion import methods
        >>>     from geowatch.tasks.fusion import datamodules
        >>>     model = self = methods.MultimodalTransformer(
        >>>         arch_name="smt_it_joint_p2", input_sensorchan=5,
        >>>         change_head_hidden=0, saliency_head_hidden=0,
        >>>         class_head_hidden=0)
        >>>     # Save the model (TODO: need to save datamodule as well)
        >>>     model.save_package(package_fpath)
        >>> results = unpack_model(package_fpath)
        >>> # Test repackage
        >>> from geowatch.mlops.repackager import repackage_single_checkpoint
        >>> repackage_single_checkpoint
    """
    from kwutil.util_yaml import Yaml
    import torch

    src_fpath = ub.Path(package_fpath)
    # NOTE(review): ``extract_package_contents`` is neither defined nor
    # imported in the visible part of this file -- confirm where it lives.
    contents = extract_package_contents(src_fpath)

    # Fall back to a sibling directory named after the package stem.
    out_dpath = (
        src_fpath.parent / src_fpath.stem
        if dst_dpath is None
        else ub.Path(dst_dpath)
    )
    out_dpath.ensuredir()

    # Dump the model configuration as YAML.
    config_fpath = out_dpath / 'config.yaml'
    config_fpath.write_text(Yaml.dumps(contents['config']))

    # Write the checkpoint into a dedicated "checkpoints" subdirectory.
    checkpoint_fpath = (
        (out_dpath / 'checkpoints').ensuredir() / (src_fpath.stem + '.ckpt')
    )
    with open(checkpoint_fpath, 'wb') as ckpt_file:
        torch.save(contents['checkpoint'], ckpt_file)

    return {
        'dpath': out_dpath,
        'ckpt_fpath': checkpoint_fpath,
        'config_fpath': config_fpath,
    }
# def unpack_model_backup(package_fpath):
# import zipfile
# zfile = zipfile.ZipFile(package_fpath, 'r')
# names = zfile.namelist()
# cand_model_pkl = [n for n in names if n.endswith('model.pkl')]
# assert len(cand_model_pkl) == 1
# model_pkl = cand_model_pkl[0]
# file = zfile.open(model_pkl)
# data = file.read()
# import pickle
# loaded = pickle.loads(data)
# Module-level alias for this module's CLI class.
__cli__ = UnpackModelCLI


if __name__ == '__main__':
    # CommandLine:
    #     python ~/code/geowatch/geowatch/cli/experimental/unpack_model.py
    #     python -m geowatch.cli.experimental.unpack_model
    __cli__.main()