geowatch.tasks.fusion.methods.noop_model module

class geowatch.tasks.fusion.methods.noop_model.NoopModel(classes=10, dataset_stats=None, input_sensorchan=None, name: str = 'unnamed_model')[source]

Bases: LightningModule, WatchModuleMixins

No-op example model. Contains a dummy parameter to satisfy the optimizer and trainer.

Todo

  • [ ] Minimize even further.

  • [ ] Identify mandatory steps in __init__ and move to a parent class.

Parameters:

name – A name for the experiment. (It is unclear whether the model is the right place for this.)

get_cfgstr()[source]
forward(x)[source]
shared_step(batch, batch_idx=None, with_loss=True)[source]
training_step(batch, batch_idx=None, with_loss=True)
forward_step(batch, batch_idx=None, with_loss=True)
configure_optimizers()[source]
save_package(package_path, context=None, verbose=1)[source]

CommandLine

xdoctest -m geowatch.tasks.fusion.methods.noop_model NoopModel.save_package

Example

>>> # Test without datamodule
>>> import ubelt as ub
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = dpath / 'my_package.pt'
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> model = self = methods.NoopModel(
>>>     input_sensorchan=5,)
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.NoopModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check that the states match and that the tensors are distinct copies
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
>>> # Check what's inside of the package
>>> import zipfile
>>> import json
>>> zfile = zipfile.ZipFile(package_path)
>>> header_file = zfile.open('my_package/package_header/package_header.json')
>>> package_header = json.loads(header_file.read())
>>> print('package_header = {}'.format(ub.urepr(package_header, nl=1)))
>>> assert 'version' in package_header
>>> assert 'arch_name' in package_header
>>> assert 'module_name' in package_header
>>> assert 'packaging_time' in package_header
>>> assert 'git_hash' in package_header
>>> assert 'module_path' in package_header

Example

>>> # Test with datamodule
>>> import ubelt as ub
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion.methods.noop_model import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = dpath / 'my_package.pt'
>>> datamodule = datamodules.kwcoco_video_data.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes8-multispectral-multisensor', chip_size=32,
>>>     batch_size=1, time_steps=2, num_workers=2, normalize_inputs=10)
>>> datamodule.setup('fit')
>>> dataset_stats = datamodule.torch_datasets['train'].cached_dataset_stats(num=3)
>>> classes = datamodule.torch_datasets['train'].classes
>>> # Use one of our fusion.architectures in a test
>>> self = methods.NoopModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats, input_sensorchan=datamodule.input_sensorchan)
>>> # We have to run an input through the module because it is lazy
>>> batch = ub.peek(iter(datamodule.train_dataloader()))
>>> outputs = self.training_step(batch)
>>> trainer = pl.Trainer(max_steps=0)
>>> trainer.fit(model=self, datamodule=datamodule)
>>> # Save the model
>>> self.save_package(package_path)
>>> # Test that the package can be reloaded
>>> recon = methods.NoopModel.load_package(package_path)
>>> # Check that the states match and that the tensors are distinct copies
>>> recon_state = recon.state_dict()
>>> model_state = self.state_dict()
>>> assert recon is not self
>>> assert set(recon_state) == set(model_state)
>>> from geowatch.utils.util_kwarray import torch_array_equal
>>> for key in recon_state.keys():
>>>     v1 = model_state[key]
>>>     v2 = recon_state[key]
>>>     if not torch.allclose(v1, v2, equal_nan=True):
>>>         print('v1 = {}'.format(ub.urepr(v1, nl=1)))
>>>         print('v2 = {}'.format(ub.urepr(v2, nl=1)))
>>>         raise AssertionError(f'Difference in key={key}')
>>>     assert v1 is not v2, 'should be distinct copies'