geowatch.tasks.fusion.methods package

Submodules

Module contents

mkinit -m geowatch.tasks.fusion.methods -w

class geowatch.tasks.fusion.methods.MultimodalTransformer(classes=10, dataset_stats=None, input_sensorchan=None, input_channels=None, **kwargs)[source]

Bases: LightningModule, WatchModuleMixins

Todo

  • [ ] Change name MultimodalTransformer -> FusionModel

  • [ ] Move parent module methods -> models

CommandLine

xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer

Example

>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.tasks.fusion import datamodules
>>> print('(STEP 0): SETUP THE DATA MODULE')
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes-geowatch', num_workers=4, channels='auto')
>>> datamodule.setup('fit')
>>> dataset = datamodule.torch_datasets['train']
>>> print('(STEP 1): ESTIMATE DATASET STATS')
>>> dataset_stats = dataset.cached_dataset_stats(num=3)
>>> print('dataset_stats = {}'.format(ub.urepr(dataset_stats, nl=3)))
>>> loader = datamodule.train_dataloader()
>>> print('(STEP 2): SAMPLE BATCH')
>>> batch = next(iter(loader))
>>> for item_idx, item in enumerate(batch):
>>>     print(f'item_idx={item_idx}')
>>>     item_summary = dataset.summarize_item(item)
>>>     print('item_summary = {}'.format(ub.urepr(item_summary, nl=2)))
>>> print('(STEP 3): THE REST OF THE TEST')
>>> #self = MultimodalTransformer(arch_name='smt_it_joint_p8')
>>> self = MultimodalTransformer(arch_name='smt_it_joint_p2',
>>>                              dataset_stats=dataset_stats,
>>>                              classes=datamodule.predictable_classes,
>>>                              decoder='segmenter',
>>>                              change_loss='dicefocal',
>>>                              #attention_impl='performer'
>>>                              attention_impl='exact'
>>>                              )
>>> device = torch.device('cpu')
>>> self = self.to(device)
>>> # Run forward pass
>>> from geowatch.utils import util_netharn
>>> num_params = util_netharn.number_of_parameters(self)
>>> print('num_params = {!r}'.format(num_params))
>>> output = self.forward_step(batch, with_loss=True)
>>> import torch.profiler
>>> from torch.profiler import profile, ProfilerActivity
>>> with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
>>>     with torch.profiler.record_function("model_inference"):
>>>         output = self.forward_step(batch, with_loss=True)
>>> print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

Example

>>> # Note: it is important that the non-kwargs are saved as hyperparams
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import MultimodalTransformer
>>> self = model = MultimodalTransformer(arch_name="smt_it_joint_p2", input_sensorchan='r|g|b')
>>> assert "classes" in model.hparams
>>> assert "dataset_stats" in model.hparams
>>> assert "input_sensorchan" in model.hparams
>>> assert "tokenizer" in model.hparams
classmethod add_argparse_args(parent_parser)[source]

Only required for backwards compatibility until lightning CLI is the primary entry point.

Example

>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.utils.configargparse_ext import ArgumentParser
>>> cls = MultimodalTransformer
>>> parent_parser = ArgumentParser(formatter_class='defaults')
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
>>> parent_parser.parse_known_args()

print(scfg.Config.port_argparse(parent_parser, style='dataconf'))

classmethod compatible(cfgdict)[source]

Given keyword arguments, find the subset that is compatible with this constructor. This is somewhat hacky because it relies on scriptconfig, and could be made nicer in future updates. Related: # init_kwargs = ub.compatible(config, cls.__init__)

configure_optimizers()[source]

Todo

  • [ ] Enable use of other optimization algorithms on the CLI

  • [ ] Enable use of other scheduler algorithms on the CLI

Note

Is this even called when using LightningCLI? Nope, the LightningCLI overwrites it.

References

https://pytorch-optimizer.readthedocs.io/en/latest/index.html https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html

Example

>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # noqa
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> disable_lightning_hardware_warnings()
>>> self = MultimodalTransformer(arch_name="smt_it_joint_p2", input_sensorchan='r|g|b')
>>> max_epochs = 80
>>> self.trainer = pl.Trainer(max_epochs=max_epochs)
>>> [opt], [sched] = self.configure_optimizers()
>>> rows = []
>>> # Inspect what the LR curve will look like
>>> for _ in range(max_epochs):
...     sched.last_epoch += 1
...     lr = sched.get_last_lr()[0]
...     rows.append({'lr': lr, 'last_epoch': sched.last_epoch})
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> import pandas as pd
>>> data = pd.DataFrame(rows)
>>> sns = kwplot.autosns()
>>> sns.lineplot(data=data, y='lr', x='last_epoch')

Example

>>> # Verify lr and decay is set correctly
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> my_lr = 2.3e-5
>>> my_decay = 2.3e-5
>>> kw = dict(arch_name="smt_it_joint_p2", input_sensorchan='r|g|b', learning_rate=my_lr, weight_decay=my_decay)
>>> self = MultimodalTransformer(**kw)
>>> [opt], [sched] = self.configure_optimizers()
>>> assert opt.param_groups[0]['lr'] == my_lr
>>> assert opt.param_groups[0]['weight_decay'] == my_decay
>>> #
>>> self = MultimodalTransformer(**kw, optimizer='sgd')
>>> [opt], [sched] = self.configure_optimizers()
>>> assert opt.param_groups[0]['lr'] == my_lr
>>> assert opt.param_groups[0]['weight_decay'] == my_decay
>>> #
>>> self = MultimodalTransformer(**kw, optimizer='AdamW')
>>> [opt], [sched] = self.configure_optimizers()
>>> assert opt.param_groups[0]['lr'] == my_lr
>>> assert opt.param_groups[0]['weight_decay'] == my_decay
>>> #
>>> # self = MultimodalTransformer(**kw, optimizer='MADGRAD')
>>> # [opt], [sched] = self.configure_optimizers()
>>> # assert opt.param_groups[0]['lr'] == my_lr
>>> # assert opt.param_groups[0]['weight_decay'] == my_decay
forward(batch)[source]

Example

>>> import pytest
>>> pytest.skip('not currently used')
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.tasks.fusion import datamodules
>>> channels = 'B1,B8|B8a,B10|B11'
>>> channels = 'B1|B8|B10|B8a|B11'
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes8-multispectral', num_workers=0, channels=channels)
>>> datamodule.setup('fit')
>>> train_dataset = datamodule.torch_datasets['train']
>>> dataset_stats = train_dataset.cached_dataset_stats()
>>> loader = datamodule.train_dataloader()
>>> tokenizer = 'convexpt-v1'
>>> tokenizer = 'dwcnn'
>>> batch = next(iter(loader))
>>> #self = MultimodalTransformer(arch_name='smt_it_joint_p8')
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_joint_p8',
>>>     dataset_stats=dataset_stats,
>>>     change_loss='dicefocal',
>>>     decoder='dicefocal',
>>>     attention_impl='performer',
>>>     tokenizer=tokenizer,
>>> )
>>> #images = torch.stack([ub.peek(f['modes'].values()) for f in batch[0]['frames']])[None, :]
>>> #images.shape
>>> #self.forward(images)
forward_foot(sensor, chan_code, mode_val: Tensor, frame_enc)[source]
forward_item(item, with_loss=False)[source]

CommandLine

xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.forward_item:1

Example

>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='segmenter', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels)
>>> item = self.demo_batch(width=64, height=65)[0]
>>> outputs = self.forward_item(item, with_loss=True)
>>> print('item')
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(item))
>>> print('outputs')
>>> print(_debug_inbatch_shapes(outputs))

Example

>>> # Box head
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels,
>>>     decouple_resolution=False, global_box_weight=1)
>>> batch = self.demo_batch(width=64, height=64, num_timesteps=3)
>>> item = batch[0]
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(batch))
>>> result1 = self.forward_step(batch, with_loss=True)
>>> assert len(result1['box'])
>>> assert 'box_ltrb' in result1['box'][0]
>>> assert len(result1['box'][0]['box_ltrb'].shape) == 3
>>> print(_debug_inbatch_shapes(result1))
>>> # Check we can go backward
>>> result1['loss'].backward()

Example

>>> # Decoupled resolutions
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels,
>>>     decouple_resolution=True, global_box_weight=0)
>>> batch = self.demo_batch(width=(11, 21), height=(16, 64), num_timesteps=3)
>>> item = batch[0]
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(batch))
>>> result1 = self.forward_step(batch, with_loss=True)
>>> print(_debug_inbatch_shapes(result1))
>>> # Check we can go backward
>>> result1['loss'].backward()
forward_step(batch, with_loss=False, stage='unspecified')[source]

Generic forward step used for test / train / validation

Returns:

with keys for various predictions / losses

Return type:

Dict

CommandLine

xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.forward_step

Example

>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> import geowatch
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes-geowatch',
>>>     num_workers=0, chip_size=96, time_steps=4,
>>>     normalize_inputs=8, neg_to_pos_ratio=0, batch_size=5,
>>>     channels='auto',
>>> )
>>> datamodule.setup('fit')
>>> train_dset = datamodule.torch_datasets['train']
>>> loader = datamodule.train_dataloader()
>>> batch = next(iter(loader))
>>> # Test with "failed samples"
>>> batch[0] = None
>>> batch[2] = None
>>> batch[3] = None
>>> batch[4] = None
>>> if 1:
>>>     from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>>     print(_debug_inbatch_shapes(batch))
>>> # Choose subclass to test this with (does not cover all cases)
>>> self = model = methods.MultimodalTransformer(
>>>     arch_name='smt_it_joint_p8', tokenizer='rearrange',
>>>     decoder='segmenter',
>>>     dataset_stats=datamodule.dataset_stats, global_saliency_weight=1.0, global_change_weight=1.0, global_class_weight=1.0,
>>>     classes=datamodule.predictable_classes, input_sensorchan=datamodule.input_sensorchan)
>>> with_loss = True
>>> outputs = self.forward_step(batch, with_loss=with_loss)
>>> canvas = datamodule.draw_batch(batch, outputs=outputs, max_items=3, overlay_on_image=False)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas)
>>> kwplot.show_if_requested()

Example

>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='segmenter', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels)
>>> batch = self.demo_batch()
>>> outputs = self.forward_step(batch, with_loss=True)
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(batch))
>>> print(_debug_inbatch_shapes(outputs))

Example

>>> # Test learned_linear multimodal reduce
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels, multimodal_reduce='learned_linear')
>>> batch = self.demo_batch()
>>> outputs = self.forward_step(batch, with_loss=True)
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(batch))
>>> print(_debug_inbatch_shapes(outputs))
>>> # outputs['loss'].backward()
get_cfgstr()[source]
optimizer_step(*args, **kwargs)[source]
overfit(batch)[source]

Overfit script and demo

CommandLine

python -m xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.overfit --overfit-demo

Example

>>> # xdoctest: +REQUIRES(--overfit-demo)
>>> # ============
>>> # DEMO OVERFIT:
>>> # ============
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.utils.util_data import find_dvc_dpath
>>> import geowatch
>>> import kwcoco
>>> from os.path import join
>>> import os
>>> if 1:
>>>     print('''
...     # Generate toy datasets
...     DATA_DPATH=$HOME/data/work/toy_change
...     TRAIN_FPATH=$DATA_DPATH/vidshapes_msi_train/data.kwcoco.json
...     mkdir -p "$DATA_DPATH"
...     kwcoco toydata --key=vidshapes-videos8-frames5-randgsize-speed0.2-msi-multisensor --bundle_dpath "$DATA_DPATH/vidshapes_msi_train" --verbose=5
...     ''')
>>>     coco_fpath = ub.expandpath('$HOME/data/work/toy_change/vidshapes_msi_train/data.kwcoco.json')
>>>     coco_fpath = 'vidshapes-videos8-frames5-randgsize-speed0.2-msi-multisensor'
>>>     coco_dset = kwcoco.CocoDataset.coerce(coco_fpath)
>>>     channels="B11,r|g|b,B1|B8|B11"
>>> if 0:
>>>     dvc_dpath = geowatch.find_dvc_dpath(tags='phase2_data', hardware='auto')
>>>     coco_dset = (dvc_dpath / 'Drop4-BAS') / 'data_vali.kwcoco.json'
>>>     channels='swir16|swir22|blue|green|red|nir'
>>>     coco_dset = (dvc_dpath / 'Drop4-BAS') / 'combo_vali_I2.kwcoco.json'
>>>     channels='blue|green|red|nir,invariants.0:17'
>>> if 0:
>>>     coco_dset = geowatch.demo.demo_kwcoco_multisensor(max_speed=0.5)
>>>     # coco_dset = 'special:vidshapes8-frames9-speed0.5-multispectral'
>>>     #channels='B1|B11|B8|r|g|b|gauss'
>>>     channels='X.2|Y:2:6,B1|B8|B8a|B10|B11,r|g|b,disparity|gauss,flowx|flowy|distri'
>>> coco_dset = kwcoco.CocoDataset.coerce(coco_dset)
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset=coco_dset,
>>>     chip_size=128, batch_size=1, time_steps=5,
>>>     channels=channels,
>>>     normalize_peritem='blue|green|red|nir',
>>>     normalize_inputs=32, neg_to_pos_ratio=0,
>>>     num_workers='avail/2',
>>>     mask_low_quality=True,
>>>     observable_threshold=0.6,
>>>     use_grid_positives=False, use_centered_positives=True,
>>> )
>>> datamodule.setup('fit')
>>> dataset = torch_dset = datamodule.torch_datasets['train']
>>> torch_dset.disable_augmenter = True
>>> dataset_stats = datamodule.dataset_stats
>>> input_sensorchan = datamodule.input_sensorchan
>>> classes = datamodule.classes
>>> print('dataset_stats = {}'.format(ub.urepr(dataset_stats, nl=3)))
>>> print('input_sensorchan = {}'.format(input_sensorchan))
>>> print('classes = {}'.format(classes))
>>> # Choose subclass to test this with (does not cover all cases)
>>> self = methods.MultimodalTransformer(
>>>     # ===========
>>>     # Backbone
>>>     #arch_name='smt_it_joint_p2',
>>>     arch_name='smt_it_stm_p8',
>>>     stream_channels = 16,
>>>     #arch_name='deit',
>>>     optimizer='AdamW',
>>>     learning_rate=1e-5,
>>>     weight_decay=1e-3,
>>>     #attention_impl='performer',
>>>     attention_impl='exact',
>>>     #decoder='segmenter',
>>>     #saliency_head_hidden=4,
>>>     decoder='mlp',
>>>     change_loss='dicefocal',
>>>     #class_loss='cce',
>>>     class_loss='dicefocal',
>>>     #saliency_loss='dicefocal',
>>>     saliency_loss='focal',
>>>     # ===========
>>>     # Change Loss
>>>     global_change_weight=1e-5,
>>>     positive_change_weight=1.0,
>>>     negative_change_weight=0.5,
>>>     # ===========
>>>     # Class Loss
>>>     global_class_weight=1e-5,
>>>     class_weights='auto',
>>>     # ===========
>>>     # Saliency Loss
>>>     global_saliency_weight=1.00,
>>>     # ===========
>>>     # Domain Metadata (Look Ma, not hard coded!)
>>>     dataset_stats=dataset_stats,
>>>     classes=classes,
>>>     input_sensorchan=input_sensorchan,
>>>     #tokenizer='dwcnn',
>>>     tokenizer='linconv',
>>>     multimodal_reduce='learned_linear',
>>>     #tokenizer='rearrange',
>>>     # normalize_perframe=True,
>>>     window_size=8,
>>>     )
>>> self.datamodule = datamodule
>>> datamodule._notify_about_tasks(model=self)
>>> # Run one visualization
>>> loader = datamodule.train_dataloader()
>>> # Load one batch and show it before we do anything
>>> batch = next(iter(loader))
>>> print(ub.urepr(dataset.summarize_item(batch[0]), nl=3))
>>> import kwplot
>>> plt = kwplot.autoplt(force='Qt5Agg')
>>> plt.ion()
>>> canvas = datamodule.draw_batch(batch, max_channels=5, overlay_on_image=0)
>>> kwplot.imshow(canvas, fnum=1)
>>> # Run overfit
>>> device = 0
>>> self.overfit(batch)
parameter_hacking(optimizer)[source]
prepare_item(item)[source]
save_package(package_path, context=None, verbose=1)[source]

Save model architecture and checkpoint as a torch package.

Parameters:
  • package_path (str | PathLike) – where to save the package

  • context (Any) – custom json-serializable data to save in the header

  • verbose (int) – verbosity level

CommandLine

xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.save_package

Example

>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> model = self = methods.MultimodalTransformer(
>>>     arch_name="smt_it_joint_p2", input_sensorchan=5,
>>>     change_head_hidden=0, saliency_head_hidden=0,
>>>     class_head_hidden=0)
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> recon = methods.MultimodalTransformer.load_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
test_step(batch, batch_idx=None)[source]
training_step(batch, batch_idx=None)[source]
validation_step(batch, batch_idx=None)[source]
class geowatch.tasks.fusion.methods.HeterogeneousModel(classes=10, dataset_stats=None, input_sensorchan=None, name: str = 'unnamed_model', position_encoder: str | ScaleAwarePositionalEncoder = 'auto', backbone: str | BackboneEncoderDecoder = 'auto', token_width: int = 10, token_dim: int = 16, spatial_scale_base: float = 1.0, temporal_scale_base: float = 1.0, class_weights: str = 'auto', saliency_weights: str = 'auto', positive_change_weight: float = 1.0, negative_change_weight: float = 1.0, global_class_weight: float = 1.0, global_change_weight: float = 1.0, global_saliency_weight: float = 1.0, change_loss: str = 'cce', class_loss: str = 'focal', saliency_loss: str = 'focal', tokenizer: str = 'simple_conv', decoder: str = 'upsample', ohem_ratio: float | None = None, focal_gamma: float | None = 2.0)[source]

Bases: LightningModule, WatchModuleMixins

Parameters:
  • name – Specify a name for the experiment. (Unsure if the Model is the place for this)

  • token_width – Width of each square token.

  • token_dim – Dimensionality of each computed token.

  • spatial_scale_base – The scale assigned to each token equals scale_base / token_density, where the token density is the number of tokens along a given axis.

  • temporal_scale_base – The scale assigned to each token equals scale_base / token_density, where the token density is the number of tokens along a given axis.

  • class_weights – Class weighting strategy.

  • saliency_weights – Weighting strategy for the saliency task.

Example

>>> # Note: it is important that the non-kwargs are saved as hyperparams
>>> from geowatch.tasks.fusion.methods.heterogeneous import HeterogeneousModel, ScaleAgnostictPositionalEncoder
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = ScaleAgnostictPositionalEncoder(3, 8)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = HeterogeneousModel(
>>>   input_sensorchan='r|g|b',
>>>   position_encoder=position_encoder,
>>>   backbone=backbone,
>>> )
configure_optimizers()

Note: this is only a fallback for testing purposes. This should be overwritten in your module or done via lightning CLI.

forward(batch)[source]

Example

>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     backbone=backbone,
>>>     position_encoder=position_encoder,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"

Example

>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="simple_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"

Example

>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="trans_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"

Example

>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=0,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=0,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="trans_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"

Example

>>> # xdoctest: +REQUIRES(module:mmseg)
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import MM_VITEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = MM_VITEncoderDecoder(
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     backbone=backbone,
>>>     position_encoder=position_encoder,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"

Example

>>> # xdoctest: +REQUIRES(module:mmseg)
>>> from geowatch.tasks import fusion
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> self = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     #token_dim=708,
>>>     token_dim=768 - 60,
>>>     backbone='vit_B_16_imagenet1k',
>>>     position_encoder=position_encoder,
>>> )
>>> batch = self.demo_batch(width=64, height=65)
>>> batch += self.demo_batch(width=55, height=75)
>>> outputs = self.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
forward_step(batch, batch_idx=None, stage='train', with_loss=True)

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> batch += [None]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> for cutoff in [-1, -2]:
>>>     degraded_example = model.demo_batch(width=55, height=75, num_timesteps=3)[0]
>>>     degraded_example["frames"] = degraded_example["frames"][:cutoff]
>>>     batch += [degraded_example]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.1)
>>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.5)
>>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=1.0)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
get_cfgstr()[source]
log_grad_norm(grad_norm_dict) None[source]

Override this method to change the default behaviour of log_grad_norm.

Overloads log_grad_norm so we can suppress the batch_size warning.

predict_step(batch, batch_idx=None)[source]
process_input_tokens(example)[source]

Example

>>> from geowatch.tasks import fusion
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>> )
>>> example = model.demo_batch(width=64, height=65)[0]
>>> input_tokens = model.process_input_tokens(example)
>>> assert len(input_tokens) == len(example["frames"])
>>> assert len(input_tokens[0]) == len(example["frames"][0]["modes"])
process_query_tokens(example)[source]

Example

>>> from geowatch.tasks import fusion
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>> )
>>> example = model.demo_batch(width=64, height=65)[0]
>>> query_tokens = model.process_query_tokens(example)
>>> assert len(query_tokens) == len(example["frames"])
save_package(package_path, context=None, verbose=1)[source]

CommandLine

xdoctest -m geowatch.tasks.fusion.methods.heterogeneous HeterogeneousModel.save_package

Example

>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.heterogeneous import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = self = methods.HeterogeneousModel(
>>>     position_encoder=position_encoder,
>>>     input_sensorchan=5,
>>>     decoder="upsample",
>>>     backbone=backbone,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.HeterogeneousModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(recon_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]

Example

>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.heterogeneous import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = self = methods.HeterogeneousModel(
>>>     position_encoder=position_encoder,
>>>     input_sensorchan=5,
>>>     decoder="simple_conv",
>>>     backbone=backbone,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.HeterogeneousModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(recon_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]

Example

>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.heterogeneous import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = self = methods.HeterogeneousModel(
>>>     position_encoder=position_encoder,
>>>     input_sensorchan=5,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.HeterogeneousModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(recon_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
shared_step(batch, batch_idx=None, stage='train', with_loss=True)[source]

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> batch += [None]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> for cutoff in [-1, -2]:
>>>     degraded_example = model.demo_batch(width=55, height=75, num_timesteps=3)[0]
>>>     degraded_example["frames"] = degraded_example["frames"][:cutoff]
>>>     batch += [degraded_example]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.1)
>>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.5)
>>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=1.0)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
test_step(batch, batch_idx=None)[source]
training_step(batch, batch_idx=None)[source]
validation_step(batch, batch_idx=None)[source]
class geowatch.tasks.fusion.methods.UNetBaseline(classes=10, dataset_stats=None, input_sensorchan=None, token_dim: int = 32, name: str = 'unnamed_model', class_weights: str = 'auto', saliency_weights: str = 'auto', positive_change_weight: float = 1.0, negative_change_weight: float = 1.0, global_class_weight: float = 1.0, global_change_weight: float = 1.0, global_saliency_weight: float = 1.0, change_loss: str = 'cce', class_loss: str = 'focal', saliency_loss: str = 'focal', ohem_ratio: float | None = None, focal_gamma: float | None = 2.0)[source]

Bases: LightningModule, WatchModuleMixins

Parameters:
  • name – Specify a name for the experiment. (It is unclear whether the model is the right place for this.)

  • token_width – Width of each square token.

  • token_dim – Dimensionality of each computed token.

  • spatial_scale_base – The scale assigned to each token equals spatial_scale_base / token_density, where the token density is the number of tokens along a given spatial axis.

  • temporal_scale_base – The scale assigned to each token equals temporal_scale_base / token_density, where the token density is the number of tokens along the temporal axis.

  • class_weights – Class weighting strategy.

  • saliency_weights – Saliency weighting strategy.

Example

>>> # Note: it is important that the non-kwargs are saved as hyperparams
>>> from geowatch.tasks.fusion.methods.unet_baseline import UNetBaseline
>>> model = UNetBaseline(
>>>   input_sensorchan='r|g|b',
>>> )
encode_batch(processed_batch)[source]
encode_example(processed_example)[source]
encode_frame(processed_frame)[source]
forward(batch)[source]

Example

>>> from geowatch.tasks import fusion
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(width=64, height=64)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
forward_step(batch, batch_idx=None, stage='train', with_loss=True)

Example

>>> # xdoctest: +REQUIRES(env:SLOW_TESTS)
>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> # xdoctest: +REQUIRES(env:SLOW_TESTS)
>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> batch += [None]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes, token_dim=2,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=0.1)
>>> batch += model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=0.5)
>>> batch += model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=1.0)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
get_cfgstr()[source]
predict_step(batch, batch_idx=None)[source]
process_batch(batch)[source]
process_example(example)[source]
process_frame(frame) Dict[str, Dict[str, Any]][source]
save_package(package_path, context=None, verbose=1)[source]

CommandLine

xdoctest -m geowatch.tasks.fusion.methods.unet_baseline UNetBaseline.save_package

Example

>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.unet_baseline import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> model = self = methods.UNetBaseline(
>>>     input_sensorchan=5,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.UNetBaseline.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not self
>>> assert set(recon_state) == set(recon_state)
>>> from geowatch.utils.util_kwarray import torch_array_equal
>>> for key in recon_state.keys():
>>>     v1 = model_state[key]
>>>     v2 = recon_state[key]
>>>     if not torch.allclose(v1, v2, equal_nan=True):
>>>         print('v1 = {}'.format(ub.urepr(v1, nl=1)))
>>>         print('v2 = {}'.format(ub.urepr(v2, nl=1)))
>>>         raise AssertionError(f'Difference in key={key}')
>>>     assert v1 is not v2, 'should be distinct copies'

Example

>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.unet_baseline import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> model = self = methods.UNetBaseline(
>>>     input_sensorchan=5,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.UNetBaseline.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = self.state_dict()
>>> assert recon is not self
>>> assert set(recon_state) == set(recon_state)
>>> from geowatch.utils.util_kwarray import torch_array_equal
>>> for key in recon_state.keys():
>>>     v1 = model_state[key]
>>>     v2 = recon_state[key]
>>>     if not torch.allclose(v1, v2, equal_nan=True):
>>>         print('v1 = {}'.format(ub.urepr(v1, nl=1)))
>>>         print('v2 = {}'.format(ub.urepr(v2, nl=1)))
>>>         raise AssertionError(f'Difference in key={key}')
>>>     assert v1 is not v2, 'should be distinct copies'

Example

>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.unet_baseline import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> model = self = methods.UNetBaseline(
>>>     input_sensorchan=5,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.UNetBaseline.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = self.state_dict()
>>> assert recon is not self
>>> assert set(recon_state) == set(recon_state)
>>> from geowatch.utils.util_kwarray import torch_array_equal
>>> for key in recon_state.keys():
>>>     v1 = model_state[key]
>>>     v2 = recon_state[key]
>>>     if not torch.allclose(v1, v2, equal_nan=True):
>>>         print('v1 = {}'.format(ub.urepr(v1, nl=1)))
>>>         print('v2 = {}'.format(ub.urepr(v2, nl=1)))
>>>         raise AssertionError(f'Difference in key={key}')
>>>     assert v1 is not v2, 'should be distinct copies'
shared_step(batch, batch_idx=None, stage='train', with_loss=True)[source]

Example

>>> # xdoctest: +REQUIRES(env:SLOW_TESTS)
>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> # xdoctest: +REQUIRES(env:SLOW_TESTS)
>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> batch += [None]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes, token_dim=2,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=0.1)
>>> batch += model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=0.5)
>>> batch += model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=1.0)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
test_step(batch, batch_idx=None)[source]
training_step(batch, batch_idx=None)[source]
validation_step(batch, batch_idx=None)[source]
class geowatch.tasks.fusion.methods.NoopModel(classes=10, dataset_stats=None, input_sensorchan=None, name: str = 'unnamed_model')[source]

Bases: LightningModule, WatchModuleMixins

No-op example model. Contains a dummy parameter to satisfy the optimizer and trainer.

Todo

  • [ ] Minimize even further.

  • [ ] Identify mandatory steps in __init__ and move to a parent class.

Parameters:

name – A name for the experiment. (It is unclear whether the model is the right place for this setting.)

configure_optimizers()[source]
forward(x)[source]
forward_step(batch, batch_idx=None, with_loss=True)
get_cfgstr()[source]
save_package(package_path, context=None, verbose=1)[source]

CommandLine

xdoctest -m geowatch.tasks.fusion.methods.noop_model NoopModel.save_package

Example

>>> # Test without datamodule
>>> import ubelt as ub
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = dpath / 'my_package.pt'
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> model = self = methods.NoopModel(
>>>     input_sensorchan=5,)
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.NoopModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check that the state dicts are consistent and that the tensors are distinct copies
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
>>> # Check what's inside of the package
>>> import zipfile
>>> import json
>>> zfile = zipfile.ZipFile(package_path)
>>> header_file = zfile.open('my_package/package_header/package_header.json')
>>> package_header = json.loads(header_file.read())
>>> print('package_header = {}'.format(ub.urepr(package_header, nl=1)))
>>> assert 'version' in package_header
>>> assert 'arch_name' in package_header
>>> assert 'module_name' in package_header
>>> assert 'packaging_time' in package_header
>>> assert 'git_hash' in package_header
>>> assert 'module_path' in package_header

Example

>>> # Test with datamodule
>>> import ubelt as ub
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion.methods.noop_model import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = dpath / 'my_package.pt'
>>> datamodule = datamodules.kwcoco_video_data.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes8-multispectral-multisensor', chip_size=32,
>>>     batch_size=1, time_steps=2, num_workers=2, normalize_inputs=10)
>>> datamodule.setup('fit')
>>> dataset_stats = datamodule.torch_datasets['train'].cached_dataset_stats(num=3)
>>> classes = datamodule.torch_datasets['train'].classes
>>> # Use one of our fusion.architectures in a test
>>> self = methods.NoopModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats, input_sensorchan=datamodule.input_sensorchan)
>>> # We have to run an input through the module because it is lazy
>>> batch = ub.peek(iter(datamodule.train_dataloader()))
>>> outputs = self.training_step(batch)
>>> trainer = pl.Trainer(max_steps=0)
>>> trainer.fit(model=self, datamodule=datamodule)
>>> # Save the model
>>> self.save_package(package_path)
>>> # Test that the package can be reloaded
>>> recon = methods.NoopModel.load_package(package_path)
>>> # Check that the state dicts are consistent and that the tensors are distinct copies
>>> recon_state = recon.state_dict()
>>> model_state = self.state_dict()
>>> assert recon is not self
>>> assert set(recon_state) == set(model_state)
>>> from geowatch.utils.util_kwarray import torch_array_equal
>>> for key in recon_state.keys():
>>>     v1 = model_state[key]
>>>     v2 = recon_state[key]
>>>     if not torch.allclose(v1, v2, equal_nan=True):
>>>         print('v1 = {}'.format(ub.urepr(v1, nl=1)))
>>>         print('v2 = {}'.format(ub.urepr(v2, nl=1)))
>>>         raise AssertionError(f'Difference in key={key}')
>>>     assert v1 is not v2, 'should be distinct copies'
shared_step(batch, batch_idx=None, with_loss=True)[source]
training_step(batch, batch_idx=None, with_loss=True)