geowatch.tasks.fusion.methods.watch_module_mixins module

class geowatch.tasks.fusion.methods.watch_module_mixins.ExtendTorchMixin

Bases: object

reset_weights()

devices()

Returns all devices this module's state is mounted on.

Returns:

set of devices used by this model

Return type:

Set[torch.device]

property main_device

The main (source) torch device used by this model.

class geowatch.tasks.fusion.methods.watch_module_mixins.MSIDemoDataMixin

Bases: object

classmethod demo_dataset_stats()

Mock data that mimics a dataset summary that a kwcoco dataloader could provide.

demo_batch(batch_size=1, num_timesteps=3, width=8, height=8, nans=0, rng=None, new_mode_sample=0)

Example

>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels)
>>> batch = self.demo_batch()
>>> if 1:
>>>     print(_debug_inbatch_shapes(batch))
>>> result = self.forward_step(batch)
>>> if 1:
>>>     print(_debug_inbatch_shapes(batch))

Example

>>> # With nans
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels)
>>> batch = self.demo_batch(nans=0.5, num_timesteps=2)
>>> item = batch[0]
>>> if 1:
>>>     print(_debug_inbatch_shapes(batch))
>>> result1 = self.forward_step(batch)
>>> result2 = self.forward_step(batch, with_loss=0)
>>> if 1:
>>>     print(_debug_inbatch_shapes(result1))
>>>     print(_debug_inbatch_shapes(result2))

Example

>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels)
>>> batch = self.demo_batch(new_mode_sample=1)
>>> print(_debug_inbatch_shapes(batch))
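The nans argument of demo_batch appears to control the fraction of input values replaced with NaN (the second example above uses nans=0.5), and rng seeds the randomness. The exact masking strategy demo_batch uses is not documented here, so the sketch below is only one plausible approach; inject_nans is a hypothetical helper operating on a flattened frame.

```python
import math
import random
from typing import List, Optional, Union


def inject_nans(values: List[float], frac: float,
                rng: Optional[Union[int, random.Random]] = None) -> List[float]:
    """Hypothetical helper: return a copy of ``values`` with ``frac`` of them set to NaN."""
    if not 0.0 <= frac <= 1.0:
        raise ValueError("frac must be in [0, 1]")
    # Accept either a seed (or None) or an existing Random instance
    rng = rng if isinstance(rng, random.Random) else random.Random(rng)
    out = list(values)
    num_nan = int(round(frac * len(out)))
    # Sample distinct positions so exactly num_nan values are masked
    for idx in rng.sample(range(len(out)), num_nan):
        out[idx] = math.nan
    return out


# An 8x8 single-band frame, flattened, with half of its pixels masked out
frame = [float(i) for i in range(8 * 8)]
noisy = inject_nans(frame, frac=0.5, rng=0)
print(sum(math.isnan(v) for v in noisy))  # 32
```

Passing a fixed seed makes the mask reproducible, which is useful when comparing forward_step outputs with and without the loss.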

class geowatch.tasks.fusion.methods.watch_module_mixins.LightningModelMixin

Bases: object

property has_trainer

class geowatch.tasks.fusion.methods.watch_module_mixins.DeprecatedMixin

Bases: object

configure_optimizers()

Note: this is only a fallback for testing purposes. It should be overridden in your module or configured via the Lightning CLI.

class geowatch.tasks.fusion.methods.watch_module_mixins.OverfitMixin

Bases: object

overfit(batch)

Overfit the model on a single batch (script and demo).

CommandLine

python -m xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.overfit --overfit-demo

Example

>>> # xdoctest: +REQUIRES(--overfit-demo)
>>> # ============
>>> # DEMO OVERFIT:
>>> # ============
>>> from geowatch.tasks.fusion.methods.heterogeneous import *  # NOQA
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.utils.util_data import find_dvc_dpath
>>> import geowatch
>>> import kwcoco
>>> import ubelt as ub
>>> from os.path import join
>>> import os
>>> if 0:
>>>     '''
>>>     # Generate toy datasets
>>>     DATA_DPATH=$HOME/data/work/toy_change
>>>     TRAIN_FPATH=$DATA_DPATH/vidshapes_msi_train/data.kwcoco.json
>>>     mkdir -p "$DATA_DPATH"
>>>     kwcoco toydata --key=vidshapes-videos8-frames5-randgsize-speed0.2-msi-multisensor --bundle_dpath "$DATA_DPATH/vidshapes_msi_train" --verbose=5
>>>     '''
>>>     coco_fpath = ub.expandpath('$HOME/data/work/toy_change/vidshapes_msi_train/data.kwcoco.json')
>>>     coco_dset = kwcoco.CocoDataset.coerce(coco_fpath)
>>>     channels="B11,r|g|b,B1|B8|B11"
>>> if 1:
>>>     dvc_dpath = geowatch.find_dvc_dpath(tags='phase3_data', hardware='auto')
>>>     coco_dset = (dvc_dpath / 'Drop8-ARA-Median10GSD-V1') / 'KR_R001/imganns-KR_R001-rawbands.kwcoco.zip'
>>>     channels='blue|green|red|nir'
>>> if 0:
>>>     coco_dset = geowatch.demo.demo_kwcoco_multisensor(max_speed=0.5)
>>>     # coco_dset = 'special:vidshapes8-frames9-speed0.5-multispectral'
>>>     #channels='B1|B11|B8|r|g|b|gauss'
>>>     channels='X.2|Y:2:6,B1|B8|B8a|B10|B11,r|g|b,disparity|gauss,flowx|flowy|distri'
>>> coco_dset = kwcoco.CocoDataset.coerce(coco_dset)
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset=coco_dset,
>>>     chip_size=128, batch_size=1, time_steps=5,
>>>     channels=channels,
>>>     normalize_peritem='blue|green|red|nir',
>>>     normalize_inputs=32, neg_to_pos_ratio=0,
>>>     num_workers='avail/2',
>>>     mask_low_quality=True,
>>>     observable_threshold=0.6,
>>>     use_grid_positives=False, use_centered_positives=True,
>>> )
>>> datamodule.setup('fit')
>>> dataset = torch_dset = datamodule.torch_datasets['train']
>>> torch_dset.disable_augmenter = True
>>> dataset_stats = datamodule.dataset_stats
>>> input_sensorchan = datamodule.input_sensorchan
>>> classes = datamodule.classes
>>> print('dataset_stats = {}'.format(ub.urepr(dataset_stats, nl=3)))
>>> print('input_sensorchan = {}'.format(input_sensorchan))
>>> print('classes = {}'.format(classes))
>>> # Choose subclass to test this with (does not cover all cases)
>>> self = methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     #token_dim=708,
>>>     #token_dim=768 - 60,
>>>     #backbone='vit_B_16_imagenet1k',
>>>     #token_dim=208,
>>>     #backbone='sits-former',
>>>     )
>>> # Alternatively, test with MultimodalTransformer (this replaces the model defined above)
>>> self = methods.MultimodalTransformer(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     #token_dim=708,
>>>     #token_dim=768 - 60,
>>>     #backbone='vit_B_16_imagenet1k',
>>>     #token_dim=208,
>>>     #backbone='sits-former',
>>>     )
>>> self.datamodule = datamodule
>>> datamodule._notify_about_tasks(model=self)
>>> # Run one visualization
>>> loader = datamodule.train_dataloader()
>>> # Load one batch and show it before we do anything
>>> batch = next(iter(loader))
>>> print(ub.urepr(dataset.summarize_item(batch[0]), nl=3))
>>> import kwplot
>>> plt = kwplot.autoplt(force='Qt5Agg')
>>> plt.ion()
>>> canvas = datamodule.draw_batch(batch, max_channels=5, overlay_on_image=0)
>>> kwplot.imshow(canvas, fnum=1)
>>> # Run overfit
>>> device = 0
>>> self.overfit(batch)
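Overfitting a single batch is a sanity check: a model and loss that can learn at all should drive the training loss toward zero when optimized repeatedly on one fixed batch. The sketch below shows that idea with plain-Python gradient descent on a one-parameter linear model. overfit_single_batch, the learning rate, and the step count are hypothetical and unrelated to the loop geowatch actually runs.

```python
from typing import List, Tuple


def overfit_single_batch(batch: List[Tuple[float, float]], lr: float = 0.05,
                         steps: int = 200) -> List[float]:
    """Hypothetical sketch: fit y = w * x + b on one fixed batch; return the loss history."""
    w, b = 0.0, 0.0
    history = []
    n = len(batch)
    for _ in range(steps):
        # Mean squared error and its gradients, computed on the same batch every step
        errors = [(w * x + b) - y for x, y in batch]
        loss = sum(e * e for e in errors) / n
        grad_w = 2.0 * sum(e * x for e, (x, _) in zip(errors, batch)) / n
        grad_b = 2.0 * sum(errors) / n
        w -= lr * grad_w
        b -= lr * grad_b
        history.append(loss)
    return history


batch = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]  # points lying exactly on y = 2x + 1
losses = overfit_single_batch(batch)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.6f}")
```

If the loss plateaus well above zero in a check like this, the problem usually lies in the model, the loss, or the data pipeline rather than in generalization.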

class geowatch.tasks.fusion.methods.watch_module_mixins.PackageMixin

Bases: object

classmethod load_package(package_path, verbose=1)

To be deprecated in favor of geowatch.tasks.fusion.utils.load_model_from_package.

Todo

  • [ ] Make the logic that defines the save_package and load_package methods (with appropriate package header data) a Lightning abstraction.

class geowatch.tasks.fusion.methods.watch_module_mixins.CoerceMixins

Bases: object

class geowatch.tasks.fusion.methods.watch_module_mixins.DatasetStatsMixin

Bases: object

set_dataset_specific_attributes(input_sensorchan, dataset_stats)

Set module attributes based on dataset stats it will be trained on.

Parameters:
  • input_sensorchan (str | kwcoco.SensorchanSpec | None) – The input sensor channels the model should expect.

  • dataset_stats (Dict | None) – See demo_dataset_stats() for an example of this structure.

Returns:

input_stats

Return type:

None | Dict

The following attributes will be set after calling this method.

  • self.class_freq

  • self.dataset_stats

  • self.input_sensorchan

  • self.unique_sensor_modes
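The class_freq attribute records per-class frequencies. One common way to consume such counts, offered here only as an assumption about how a subclass might use them, is inverse-frequency loss weighting. The dict layout (class name to count), the helper name inverse_frequency_weights, and the counts below are all hypothetical.

```python
from typing import Dict


def inverse_frequency_weights(class_freq: Dict[str, int],
                              eps: float = 1.0) -> Dict[str, float]:
    """Hypothetical helper: weight each class by 1 / (count + eps), normalized to mean 1."""
    if not class_freq:
        raise ValueError("class_freq must not be empty")
    raw = {name: 1.0 / (count + eps) for name, count in class_freq.items()}
    mean = sum(raw.values()) / len(raw)
    return {name: w / mean for name, w in raw.items()}


# Made-up counts, for illustration only
weights = inverse_frequency_weights({"background": 999, "change": 99})
print(weights)
```

The eps term keeps classes with zero observations from producing an infinite weight, and normalizing to a mean of 1 keeps the overall loss scale comparable to the unweighted case.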

This method also returns input_stats, which should be used to configure model-dependent handling of input normalization.
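A minimal sketch of per-channel standardization driven by statistics like input_stats. The layout assumed here, a mapping from a (sensor, channel) pair to a dict with "mean" and "std" entries, is hypothetical and may not match the structure this method actually returns. The sensor name and values are made up.

```python
from typing import Dict, List, Tuple

# Hypothetical layout: (sensor, channel) -> {"mean": ..., "std": ...}
Stats = Dict[Tuple[str, str], Dict[str, float]]


def normalize_channel(values: List[float], sensor: str, channel: str,
                      input_stats: Stats, eps: float = 1e-8) -> List[float]:
    """Hypothetical helper: standardize one channel with its dataset mean and std."""
    stats = input_stats[(sensor, channel)]
    mean, std = stats["mean"], stats["std"]
    # eps guards against division by zero for constant channels
    return [(v - mean) / (std + eps) for v in values]


# Made-up statistics for a single red band, for illustration only
input_stats = {("sensorA", "red"): {"mean": 100.0, "std": 20.0}}
print(normalize_channel([80.0, 100.0, 140.0], "sensorA", "red", input_stats))
```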

The handling of dataset_stats and input_sensorchan is coupled in an unusual way for legacy reasons and is duplicated across several modules. This mixin provides a common location for that code so that it can be refactored and simplified more easily later.

class geowatch.tasks.fusion.methods.watch_module_mixins.WatchModuleMixins

Bases: MSIDemoDataMixin, ExtendTorchMixin, LightningModelMixin, CoerceMixins, PackageMixin, OverfitMixin, DatasetStatsMixin, DeprecatedMixin

Mixin methods for geowatch Lightning modules.
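WatchModuleMixins bundles eight mixins through multiple inheritance. The toy classes below, all hypothetical, illustrate how Python resolves attributes in such a composition: lookups follow the method resolution order (MRO), so when two bases define the same name, the one listed first wins.

```python
class DemoDataMixin:
    def describe(self) -> str:
        return "demo data"


class TorchMixin:
    def describe(self) -> str:
        return "torch helpers"

    def device_summary(self) -> str:
        return "cpu"


class ComposedMixins(DemoDataMixin, TorchMixin):
    """Toy analogue of a class that bundles several mixins."""


obj = ComposedMixins()
print([cls.__name__ for cls in ComposedMixins.__mro__])
# ['ComposedMixins', 'DemoDataMixin', 'TorchMixin', 'object']
print(obj.describe())        # demo data (the first listed base wins)
print(obj.device_summary())  # cpu (only TorchMixin defines it)
```

Keeping each mixin's method names distinct, as the mixins listed above largely do, means the order of bases rarely changes behavior.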