geowatch.tasks.fusion.methods.watch_module_mixins module¶
- class geowatch.tasks.fusion.methods.watch_module_mixins.ExtendTorchMixin[source]¶
Bases: object
- devices()[source]¶
Returns all devices this module state is mounted on
- Returns:
set of devices used by this model
- Return type:
Set[torch.device]
- property main_device¶
The main/src torch device used by this model
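A minimal sketch of querying these helpers on a concrete model; it assumes MultimodalTransformer (which picks up this mixin through WatchModuleMixins) and the same constructor arguments used in the demo_batch examples below.
>>> # Sketch: inspect which devices hold a model's state.
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels)
>>> print(self.devices())    # set of torch.device objects holding the state
>>> print(self.main_device)  # the primary device (e.g. cpu before moving to a GPU)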
- class geowatch.tasks.fusion.methods.watch_module_mixins.MSIDemoDataMixin[source]¶
Bases: object
- classmethod demo_dataset_stats()[source]¶
Mock data that mimics the dataset summary a kwcoco dataloader could provide.
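A hedged sketch of inspecting the returned structure; it assumes the (channels, classes, dataset_stats) unpacking used in the examples below.
>>> # Sketch: look at the mock dataset summary this classmethod produces.
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> print(channels)                    # sensor/channel specification for model inputs
>>> print(classes)                     # demo category labels
>>> print(list(dataset_stats.keys()))  # summary statistics a kwcoco dataloader would provide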
- demo_batch(batch_size=1, num_timesteps=3, width=8, height=8, nans=0, rng=None, new_mode_sample=0)[source]¶
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> channels, clases, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=clases, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels)
>>> batch = self.demo_batch()
>>> if 1:
>>>     print(_debug_inbatch_shapes(batch))
>>> result = self.forward_step(batch)
>>> if 1:
>>>     print(_debug_inbatch_shapes(batch))
Example
>>> # With nans
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> channels, clases, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=clases, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels)
>>> batch = self.demo_batch(nans=0.5, num_timesteps=2)
>>> item = batch[0]
>>> if 1:
>>>     print(_debug_inbatch_shapes(batch))
>>> result1 = self.forward_step(batch)
>>> result2 = self.forward_step(batch, with_loss=0)
>>> if 1:
>>>     print(_debug_inbatch_shapes(result1))
>>>     print(_debug_inbatch_shapes(result2))
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> channels, clases, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=clases, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels)
>>> batch = self.demo_batch(new_mode_sample=1)
>>> print(_debug_inbatch_shapes(batch))
- class geowatch.tasks.fusion.methods.watch_module_mixins.LightningModelMixin[source]¶
Bases: object
- property has_trainer¶
- class geowatch.tasks.fusion.methods.watch_module_mixins.OverfitMixin[source]¶
Bases: object
- overfit(batch)[source]¶
Overfit script and demo
CommandLine
python -m xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.overfit --overfit-demo
Example
>>> # xdoctest: +REQUIRES(--overfit-demo)
>>> # ============
>>> # DEMO OVERFIT:
>>> # ============
>>> from geowatch.tasks.fusion.methods.heterogeneous import *  # NOQA
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.utils.util_data import find_dvc_dpath
>>> import geowatch
>>> import kwcoco
>>> import ubelt as ub
>>> from os.path import join
>>> import os
>>> if 0:
>>>     '''
>>>     # Generate toy datasets
>>>     DATA_DPATH=$HOME/data/work/toy_change
>>>     TRAIN_FPATH=$DATA_DPATH/vidshapes_msi_train/data.kwcoco.json
>>>     mkdir -p "$DATA_DPATH"
>>>     kwcoco toydata --key=vidshapes-videos8-frames5-randgsize-speed0.2-msi-multisensor --bundle_dpath "$DATA_DPATH/vidshapes_msi_train" --verbose=5
>>>     '''
>>>     coco_fpath = ub.expandpath('$HOME/data/work/toy_change/vidshapes_msi_train/data.kwcoco.json')
>>>     coco_dset = kwcoco.CocoDataset.coerce(coco_fpath)
>>>     channels="B11,r|g|b,B1|B8|B11"
>>> if 1:
>>>     dvc_dpath = geowatch.find_dvc_dpath(tags='phase3_data', hardware='auto')
>>>     coco_dset = (dvc_dpath / 'Drop8-ARA-Median10GSD-V1') / 'KR_R001/imganns-KR_R001-rawbands.kwcoco.zip'
>>>     channels='blue|green|red|nir'
>>> if 0:
>>>     coco_dset = geowatch.demo.demo_kwcoco_multisensor(max_speed=0.5)
>>>     # coco_dset = 'special:vidshapes8-frames9-speed0.5-multispectral'
>>>     #channels='B1|B11|B8|r|g|b|gauss'
>>>     channels='X.2|Y:2:6,B1|B8|B8a|B10|B11,r|g|b,disparity|gauss,flowx|flowy|distri'
>>> coco_dset = kwcoco.CocoDataset.coerce(coco_dset)
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset=coco_dset,
>>>     chip_size=128, batch_size=1, time_steps=5,
>>>     channels=channels,
>>>     normalize_peritem='blue|green|red|nir',
>>>     normalize_inputs=32, neg_to_pos_ratio=0,
>>>     num_workers='avail/2',
>>>     mask_low_quality=True,
>>>     observable_threshold=0.6,
>>>     use_grid_positives=False, use_centered_positives=True,
>>> )
>>> datamodule.setup('fit')
>>> dataset = torch_dset = datamodule.torch_datasets['train']
>>> torch_dset.disable_augmenter = True
>>> dataset_stats = datamodule.dataset_stats
>>> input_sensorchan = datamodule.input_sensorchan
>>> classes = datamodule.classes
>>> print('dataset_stats = {}'.format(ub.urepr(dataset_stats, nl=3)))
>>> print('input_sensorchan = {}'.format(input_sensorchan))
>>> print('classes = {}'.format(classes))
>>> # Choose subclass to test this with (does not cover all cases)
>>> self = methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     #token_dim=708,
>>>     #token_dim=768 - 60,
>>>     #backbone='vit_B_16_imagenet1k',
>>>     #token_dim=208,
>>>     #backbone='sits-former',
>>> )
>>> # Choose subclass to test this with (does not cover all cases)
>>> self = methods.MultimodalTransformer(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     #token_dim=708,
>>>     #token_dim=768 - 60,
>>>     #backbone='vit_B_16_imagenet1k',
>>>     #token_dim=208,
>>>     #backbone='sits-former',
>>> )
>>> self.datamodule = datamodule
>>> datamodule._notify_about_tasks(model=self)
>>> # Run one visualization
>>> loader = datamodule.train_dataloader()
>>> # Load one batch and show it before we do anything
>>> batch = next(iter(loader))
>>> print(ub.urepr(dataset.summarize_item(batch[0]), nl=3))
>>> import kwplot
>>> plt = kwplot.autoplt(force='Qt5Agg')
>>> plt.ion()
>>> canvas = datamodule.draw_batch(batch, max_channels=5, overlay_on_image=0)
>>> kwplot.imshow(canvas, fnum=1)
>>> # Run overfit
>>> device = 0
>>> self.overfit(batch)
- class geowatch.tasks.fusion.methods.watch_module_mixins.DatasetStatsMixin[source]¶
Bases: object
- set_dataset_specific_attributes(input_sensorchan, dataset_stats)[source]¶
Set module attributes based on the stats of the dataset it will be trained on.
- Parameters:
input_sensorchan (str | kwcoco.SensorchanSpec | None) – The input sensor channels the model should expect
dataset_stats (Dict | None) – See demo_dataset_stats() for an example of this structure.
- Returns:
input_stats
- Return type:
None | Dict
The following attributes will be set after calling this method:
self.class_freq
self.dataset_stats
self.input_sensorchan
self.unique_sensor_modes
This method also returns an input_stats variable, which should be used to configure model-dependent handling of input normalization.
The handling of dataset_stats and input_sensorchan is coupled for legacy reasons and duplicated across several modules. This mixin is a common location for that code so it can be more easily refactored and simplified at a later date.
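A hedged sketch of how a module might call this helper during construction; DemoFusionModule is a hypothetical class, and the inputs come from demo_dataset_stats() as in the examples above.
>>> # Sketch: a hypothetical module wiring up dataset-specific attributes.
>>> import torch
>>> from geowatch.tasks.fusion.methods.watch_module_mixins import DatasetStatsMixin
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> class DemoFusionModule(DatasetStatsMixin, torch.nn.Module):  # hypothetical
>>>     def __init__(self, input_sensorchan=None, dataset_stats=None):
>>>         super().__init__()
>>>         # input_stats (None | Dict) would drive input normalization layers
>>>         self.input_stats = self.set_dataset_specific_attributes(
>>>             input_sensorchan, dataset_stats)
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> module = DemoFusionModule(input_sensorchan=channels, dataset_stats=dataset_stats)
>>> print(module.input_sensorchan)
>>> print(module.unique_sensor_modes)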
- class geowatch.tasks.fusion.methods.watch_module_mixins.WatchModuleMixins[source]¶
Bases: MSIDemoDataMixin, ExtendTorchMixin, LightningModelMixin, CoerceMixins, PackageMixin, OverfitMixin, DatasetStatsMixin, DeprecatedMixin
Mixin methods for geowatch lightning modules
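A hedged sketch of how a new lightning module might compose these mixins; DemoLightningModel is hypothetical and assumes pytorch_lightning is the Lightning flavor in use.
>>> # Sketch: a hypothetical lightning module that picks up the mixin API.
>>> import torch
>>> import pytorch_lightning as pl
>>> from geowatch.tasks.fusion.methods.watch_module_mixins import WatchModuleMixins
>>> class DemoLightningModel(WatchModuleMixins, pl.LightningModule):  # hypothetical
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.layer = torch.nn.Linear(3, 3)
>>> self = DemoLightningModel()
>>> print(self.devices())    # from ExtendTorchMixin
>>> print(self.main_device)  # from ExtendTorchMixin
>>> # demo_batch, overfit, has_trainer, etc. are also inherited from the mixins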