geowatch.tasks.fusion.methods package¶
Submodules¶
- geowatch.tasks.fusion.methods.channelwise_transformer module
MultimodalTransformerConfig
MultimodalTransformer
MultimodalTransformer.add_argparse_args()
MultimodalTransformer.compatible()
MultimodalTransformer.configure_optimizers()
MultimodalTransformer.overfit()
MultimodalTransformer.prepare_item()
MultimodalTransformer.forward_step()
MultimodalTransformer.forward_item()
MultimodalTransformer.forward_foot()
MultimodalTransformer.training_step()
MultimodalTransformer.optimizer_step()
MultimodalTransformer.parameter_hacking()
MultimodalTransformer.validation_step()
MultimodalTransformer.test_step()
MultimodalTransformer.save_package()
MultimodalTransformer.forward()
MultimodalTransformer.get_cfgstr()
slice_to_agree()
perterb_params()
- geowatch.tasks.fusion.methods.efficientdet module
normal_init()
bias_init_with_prob()
conv_ws_2d()
ConvWS2d
build_norm_layer()
build_conv_layer()
ConvModule
multi_apply()
RetinaHead
calc_iou()
FocalLoss
drop_connect()
round_repeats()
round_filters()
load_pretrained_weights()
Identity
Conv2dStaticSamePadding
Conv2dDynamicSamePadding
get_same_padding_conv2d()
BlockArgs
GlobalParams
BlockDecoder
efficientnet_params()
efficientnet()
get_model_params()
Swish
SwishImplementation
MemoryEfficientSwish
MBConvBlock
EfficientNet
ClipBoxes
xavier_init()
BiFPNModule
BIFPN
BBoxTransform
shift()
generate_anchors()
Anchors
EfficientDetCoder
EfficientDet
- geowatch.tasks.fusion.methods.heads module
- geowatch.tasks.fusion.methods.heterogeneous module
to_next_multiple()
positions_from_shape()
PadToMultiple
NanToNum
ShapePreservingTransformerEncoder
ScaleAwarePositionalEncoder
MipNerfPositionalEncoder
ScaleAgnostictPositionalEncoder
ResNetShim
HeterogeneousModel
HeterogeneousModel.get_cfgstr()
HeterogeneousModel.process_input_tokens()
HeterogeneousModel.process_query_tokens()
HeterogeneousModel.forward()
HeterogeneousModel.shared_step()
HeterogeneousModel.training_step()
HeterogeneousModel.validation_step()
HeterogeneousModel.test_step()
HeterogeneousModel.predict_step()
HeterogeneousModel.forward_step()
HeterogeneousModel.log_grad_norm()
HeterogeneousModel.save_package()
HeterogeneousModel.configure_optimizers()
- geowatch.tasks.fusion.methods.loss module
- geowatch.tasks.fusion.methods.network_modules module
- geowatch.tasks.fusion.methods.noop_model module
- geowatch.tasks.fusion.methods.object_head module
- geowatch.tasks.fusion.methods.torchvision_nets module
- geowatch.tasks.fusion.methods.unet_baseline module
NanToNum
UNetBaseline
UNetBaseline.get_cfgstr()
UNetBaseline.process_frame()
UNetBaseline.process_example()
UNetBaseline.process_batch()
UNetBaseline.encode_frame()
UNetBaseline.encode_example()
UNetBaseline.encode_batch()
UNetBaseline.forward()
UNetBaseline.shared_step()
UNetBaseline.training_step()
UNetBaseline.validation_step()
UNetBaseline.test_step()
UNetBaseline.predict_step()
UNetBaseline.forward_step()
UNetBaseline.save_package()
- geowatch.tasks.fusion.methods.watch_module_mixins module
Module contents¶
mkinit -m geowatch.tasks.fusion.methods -w
- class geowatch.tasks.fusion.methods.MultimodalTransformer(classes=10, dataset_stats=None, input_sensorchan=None, input_channels=None, **kwargs)[source]¶
Bases: LightningModule, WatchModuleMixins
Todo
[ ] Change name MultimodalTransformer -> FusionModel
[ ] Move parent module methods -> models
CommandLine
xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> from geowatch.tasks.fusion import datamodules
>>> print('(STEP 0): SETUP THE DATA MODULE')
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes-geowatch', num_workers=4, channels='auto')
>>> datamodule.setup('fit')
>>> dataset = datamodule.torch_datasets['train']
>>> print('(STEP 1): ESTIMATE DATASET STATS')
>>> dataset_stats = dataset.cached_dataset_stats(num=3)
>>> print('dataset_stats = {}'.format(ub.urepr(dataset_stats, nl=3)))
>>> loader = datamodule.train_dataloader()
>>> print('(STEP 2): SAMPLE BATCH')
>>> batch = next(iter(loader))
>>> for item_idx, item in enumerate(batch):
>>>     print(f'item_idx={item_idx}')
>>>     item_summary = dataset.summarize_item(item)
>>>     print('item_summary = {}'.format(ub.urepr(item_summary, nl=2)))
>>> print('(STEP 3): THE REST OF THE TEST')
>>> #self = MultimodalTransformer(arch_name='smt_it_joint_p8')
>>> self = MultimodalTransformer(arch_name='smt_it_joint_p2',
>>>     dataset_stats=dataset_stats,
>>>     classes=datamodule.predictable_classes,
>>>     decoder='segmenter',
>>>     change_loss='dicefocal',
>>>     #attention_impl='performer'
>>>     attention_impl='exact'
>>> )
>>> device = torch.device('cpu')
>>> self = self.to(device)
>>> # Run forward pass
>>> from geowatch.utils import util_netharn
>>> num_params = util_netharn.number_of_parameters(self)
>>> print('num_params = {!r}'.format(num_params))
>>> output = self.forward_step(batch, with_loss=True)
>>> import torch.profiler
>>> from torch.profiler import profile, ProfilerActivity
>>> with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
>>>     with torch.profiler.record_function("model_inference"):
>>>         output = self.forward_step(batch, with_loss=True)
>>> print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
Example
>>> # Note: it is important that the non-kwargs are saved as hyperparams
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import MultimodalTransformer
>>> self = model = MultimodalTransformer(arch_name="smt_it_joint_p2", input_sensorchan='r|g|b')
>>> assert "classes" in model.hparams
>>> assert "dataset_stats" in model.hparams
>>> assert "input_sensorchan" in model.hparams
>>> assert "tokenizer" in model.hparams
- classmethod add_argparse_args(parent_parser)[source]¶
Only required for backwards compatibility until the Lightning CLI is the primary entry point.
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> from geowatch.utils.configargparse_ext import ArgumentParser
>>> cls = MultimodalTransformer
>>> parent_parser = ArgumentParser(formatter_class='defaults')
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
>>> parent_parser.parse_known_args()
print(scfg.Config.port_argparse(parent_parser, style='dataconf'))
- classmethod compatible(cfgdict)[source]¶
Given keyword arguments, find the subset that is compatible with this constructor. This is somewhat hacky because of the usage of scriptconfig, but could be made nicer by future updates. A future implementation could simply use init_kwargs = ub.compatible(config, cls.__init__).
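A minimal sketch of the idea (the toy function below is illustrative; ub.compatible is the ubelt helper the note above alludes to):

>>> import ubelt as ub
>>> def toy_init(classes=10, dataset_stats=None, input_sensorchan=None):
...     ...
>>> config = {'classes': 3, 'input_sensorchan': 'r|g|b', 'unrelated_key': 1}
>>> init_kwargs = ub.compatible(config, toy_init)
>>> assert init_kwargs == {'classes': 3, 'input_sensorchan': 'r|g|b'}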
- configure_optimizers()[source]¶
Todo
[ ] Enable use of other optimization algorithms on the CLI
[ ] Enable use of other scheduler algorithms on the CLI
Note
Is this even called when using LightningCLI? No: the LightningCLI overrides it.
References
https://pytorch-optimizer.readthedocs.io/en/latest/index.html
https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # noqa
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> disable_lightning_hardware_warnings()
>>> self = MultimodalTransformer(arch_name="smt_it_joint_p2", input_sensorchan='r|g|b')
>>> max_epochs = 80
>>> self.trainer = pl.Trainer(max_epochs=max_epochs)
>>> [opt], [sched] = self.configure_optimizers()
>>> rows = []
>>> # Inspect what the LR curve will look like
>>> for _ in range(max_epochs):
...     sched.last_epoch += 1
...     lr = sched.get_last_lr()[0]
...     rows.append({'lr': lr, 'last_epoch': sched.last_epoch})
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> import pandas as pd
>>> data = pd.DataFrame(rows)
>>> sns = kwplot.autosns()
>>> sns.lineplot(data=data, y='lr', x='last_epoch')
Example
>>> # Verify lr and decay are set correctly
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> my_lr = 2.3e-5
>>> my_decay = 2.3e-5
>>> kw = dict(arch_name="smt_it_joint_p2", input_sensorchan='r|g|b', learning_rate=my_lr, weight_decay=my_decay)
>>> self = MultimodalTransformer(**kw)
>>> [opt], [sched] = self.configure_optimizers()
>>> assert opt.param_groups[0]['lr'] == my_lr
>>> assert opt.param_groups[0]['weight_decay'] == my_decay
>>> #
>>> self = MultimodalTransformer(**kw, optimizer='sgd')
>>> [opt], [sched] = self.configure_optimizers()
>>> assert opt.param_groups[0]['lr'] == my_lr
>>> assert opt.param_groups[0]['weight_decay'] == my_decay
>>> #
>>> self = MultimodalTransformer(**kw, optimizer='AdamW')
>>> [opt], [sched] = self.configure_optimizers()
>>> assert opt.param_groups[0]['lr'] == my_lr
>>> assert opt.param_groups[0]['weight_decay'] == my_decay
>>> #
>>> # self = MultimodalTransformer(**kw, optimizer='MADGRAD')
>>> # [opt], [sched] = self.configure_optimizers()
>>> # assert opt.param_groups[0]['lr'] == my_lr
>>> # assert opt.param_groups[0]['weight_decay'] == my_decay
- forward(batch)[source]¶
Example
>>> import pytest
>>> pytest.skip('not currently used')
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> from geowatch.tasks.fusion import datamodules
>>> channels = 'B1,B8|B8a,B10|B11'
>>> channels = 'B1|B8|B10|B8a|B11'
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes8-multispectral', num_workers=0, channels=channels)
>>> datamodule.setup('fit')
>>> train_dataset = datamodule.torch_datasets['train']
>>> dataset_stats = train_dataset.cached_dataset_stats()
>>> loader = datamodule.train_dataloader()
>>> tokenizer = 'convexpt-v1'
>>> tokenizer = 'dwcnn'
>>> batch = next(iter(loader))
>>> #self = MultimodalTransformer(arch_name='smt_it_joint_p8')
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_joint_p8',
>>>     dataset_stats=dataset_stats,
>>>     change_loss='dicefocal',
>>>     decoder='dicefocal',
>>>     attention_impl='performer',
>>>     tokenizer=tokenizer,
>>> )
>>> #images = torch.stack([ub.peek(f['modes'].values()) for f in batch[0]['frames']])[None, :]
>>> #images.shape
>>> #self.forward(images)
- forward_item(item, with_loss=False)[source]¶
CommandLine
xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.forward_item:1
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='segmenter', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels)
>>> item = self.demo_batch(width=64, height=65)[0]
>>> outputs = self.forward_item(item, with_loss=True)
>>> print('item')
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(item))
>>> print('outputs')
>>> print(_debug_inbatch_shapes(outputs))
Example
>>> # Box head
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels,
>>>     decouple_resolution=False, global_box_weight=1)
>>> batch = self.demo_batch(width=64, height=64, num_timesteps=3)
>>> item = batch[0]
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(batch))
>>> result1 = self.forward_step(batch, with_loss=True)
>>> assert len(result1['box'])
>>> assert 'box_ltrb' in result1['box'][0]
>>> assert len(result1['box'][0]['box_ltrb'].shape) == 3
>>> print(_debug_inbatch_shapes(result1))
>>> # Check we can go backward
>>> result1['loss'].backward()
Example
>>> # Decoupled resolutions
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels,
>>>     decouple_resolution=True, global_box_weight=0)
>>> batch = self.demo_batch(width=(11, 21), height=(16, 64), num_timesteps=3)
>>> item = batch[0]
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(batch))
>>> result1 = self.forward_step(batch, with_loss=True)
>>> print(_debug_inbatch_shapes(result1))
>>> # Check we can go backward
>>> result1['loss'].backward()
- forward_step(batch, with_loss=False, stage='unspecified')[source]¶
Generic forward step used for test / train / validation
- Returns:
a dictionary with keys for the various predictions and losses
- Return type:
Dict
CommandLine
xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.forward_step
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> import geowatch
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes-geowatch',
>>>     num_workers=0, chip_size=96, time_steps=4,
>>>     normalize_inputs=8, neg_to_pos_ratio=0, batch_size=5,
>>>     channels='auto',
>>> )
>>> datamodule.setup('fit')
>>> train_dset = datamodule.torch_datasets['train']
>>> loader = datamodule.train_dataloader()
>>> batch = next(iter(loader))
>>> # Test with "failed samples"
>>> batch[0] = None
>>> batch[2] = None
>>> batch[3] = None
>>> batch[4] = None
>>> if 1:
>>>     from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>>     print(_debug_inbatch_shapes(batch))
>>> # Choose subclass to test this with (does not cover all cases)
>>> self = model = methods.MultimodalTransformer(
>>>     arch_name='smt_it_joint_p8', tokenizer='rearrange',
>>>     decoder='segmenter',
>>>     dataset_stats=datamodule.dataset_stats, global_saliency_weight=1.0, global_change_weight=1.0, global_class_weight=1.0,
>>>     classes=datamodule.predictable_classes, input_sensorchan=datamodule.input_sensorchan)
>>> with_loss = True
>>> outputs = self.forward_step(batch, with_loss=with_loss)
>>> canvas = datamodule.draw_batch(batch, outputs=outputs, max_items=3, overlay_on_image=False)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas)
>>> kwplot.show_if_requested()
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='segmenter', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels)
>>> batch = self.demo_batch()
>>> outputs = self.forward_step(batch, with_loss=True)
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(batch))
>>> print(_debug_inbatch_shapes(outputs))
Example
>>> # Test learned_linear multimodal reduce
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels, multimodal_reduce='learned_linear')
>>> batch = self.demo_batch()
>>> outputs = self.forward_step(batch, with_loss=True)
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(batch))
>>> print(_debug_inbatch_shapes(outputs))
>>> # outputs['loss'].backward()
- overfit(batch)[source]¶
Overfit script and demo
CommandLine
python -m xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.overfit --overfit-demo
Example
>>> # xdoctest: +REQUIRES(--overfit-demo)
>>> # ============
>>> # DEMO OVERFIT:
>>> # ============
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.utils.util_data import find_dvc_dpath
>>> import geowatch
>>> import kwcoco
>>> from os.path import join
>>> import os
>>> if 1:
>>>     print('''
...     # Generate toy datasets
...     DATA_DPATH=$HOME/data/work/toy_change
...     TRAIN_FPATH=$DATA_DPATH/vidshapes_msi_train/data.kwcoco.json
...     mkdir -p "$DATA_DPATH"
...     kwcoco toydata --key=vidshapes-videos8-frames5-randgsize-speed0.2-msi-multisensor --bundle_dpath "$DATA_DPATH/vidshapes_msi_train" --verbose=5
...     ''')
>>> coco_fpath = ub.expandpath('$HOME/data/work/toy_change/vidshapes_msi_train/data.kwcoco.json')
>>> coco_fpath = 'vidshapes-videos8-frames5-randgsize-speed0.2-msi-multisensor'
>>> coco_dset = kwcoco.CocoDataset.coerce(coco_fpath)
>>> channels="B11,r|g|b,B1|B8|B11"
>>> if 0:
>>>     dvc_dpath = geowatch.find_dvc_dpath(tags='phase2_data', hardware='auto')
>>>     coco_dset = (dvc_dpath / 'Drop4-BAS') / 'data_vali.kwcoco.json'
>>>     channels='swir16|swir22|blue|green|red|nir'
>>>     coco_dset = (dvc_dpath / 'Drop4-BAS') / 'combo_vali_I2.kwcoco.json'
>>>     channels='blue|green|red|nir,invariants.0:17'
>>> if 0:
>>>     coco_dset = geowatch.demo.demo_kwcoco_multisensor(max_speed=0.5)
>>>     # coco_dset = 'special:vidshapes8-frames9-speed0.5-multispectral'
>>>     #channels='B1|B11|B8|r|g|b|gauss'
>>>     channels='X.2|Y:2:6,B1|B8|B8a|B10|B11,r|g|b,disparity|gauss,flowx|flowy|distri'
>>> coco_dset = kwcoco.CocoDataset.coerce(coco_dset)
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset=coco_dset,
>>>     chip_size=128, batch_size=1, time_steps=5,
>>>     channels=channels,
>>>     normalize_peritem='blue|green|red|nir',
>>>     normalize_inputs=32, neg_to_pos_ratio=0,
>>>     num_workers='avail/2',
>>>     mask_low_quality=True,
>>>     observable_threshold=0.6,
>>>     use_grid_positives=False, use_centered_positives=True,
>>> )
>>> datamodule.setup('fit')
>>> dataset = torch_dset = datamodule.torch_datasets['train']
>>> torch_dset.disable_augmenter = True
>>> dataset_stats = datamodule.dataset_stats
>>> input_sensorchan = datamodule.input_sensorchan
>>> classes = datamodule.classes
>>> print('dataset_stats = {}'.format(ub.urepr(dataset_stats, nl=3)))
>>> print('input_sensorchan = {}'.format(input_sensorchan))
>>> print('classes = {}'.format(classes))
>>> # Choose subclass to test this with (does not cover all cases)
>>> self = methods.MultimodalTransformer(
>>>     # ===========
>>>     # Backbone
>>>     #arch_name='smt_it_joint_p2',
>>>     arch_name='smt_it_stm_p8',
>>>     stream_channels = 16,
>>>     #arch_name='deit',
>>>     optimizer='AdamW',
>>>     learning_rate=1e-5,
>>>     weight_decay=1e-3,
>>>     #attention_impl='performer',
>>>     attention_impl='exact',
>>>     #decoder='segmenter',
>>>     #saliency_head_hidden=4,
>>>     decoder='mlp',
>>>     change_loss='dicefocal',
>>>     #class_loss='cce',
>>>     class_loss='dicefocal',
>>>     #saliency_loss='dicefocal',
>>>     saliency_loss='focal',
>>>     # ===========
>>>     # Change Loss
>>>     global_change_weight=1e-5,
>>>     positive_change_weight=1.0,
>>>     negative_change_weight=0.5,
>>>     # ===========
>>>     # Class Loss
>>>     global_class_weight=1e-5,
>>>     class_weights='auto',
>>>     # ===========
>>>     # Saliency Loss
>>>     global_saliency_weight=1.00,
>>>     # ===========
>>>     # Domain Metadata (Look Ma, not hard coded!)
>>>     dataset_stats=dataset_stats,
>>>     classes=classes,
>>>     input_sensorchan=input_sensorchan,
>>>     #tokenizer='dwcnn',
>>>     tokenizer='linconv',
>>>     multimodal_reduce='learned_linear',
>>>     #tokenizer='rearrange',
>>>     # normalize_perframe=True,
>>>     window_size=8,
>>> )
>>> self.datamodule = datamodule
>>> datamodule._notify_about_tasks(model=self)
>>> # Run one visualization
>>> loader = datamodule.train_dataloader()
>>> # Load one batch and show it before we do anything
>>> batch = next(iter(loader))
>>> print(ub.urepr(dataset.summarize_item(batch[0]), nl=3))
>>> import kwplot
>>> plt = kwplot.autoplt(force='Qt5Agg')
>>> plt.ion()
>>> canvas = datamodule.draw_batch(batch, max_channels=5, overlay_on_image=0)
>>> kwplot.imshow(canvas, fnum=1)
>>> # Run overfit
>>> device = 0
>>> self.overfit(batch)
- save_package(package_path, context=None, verbose=1)[source]¶
Save model architecture and checkpoint as a torch package.
- Parameters:
package_path (str | PathLike) – where to save the package
context (Any) – custom json-serializable data to save in the header
verbose (int) – verbosity level
CommandLine
xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.save_package
Example
>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> model = self = methods.MultimodalTransformer(
>>>     arch_name="smt_it_joint_p2", input_sensorchan=5,
>>>     change_head_hidden=0, saliency_head_hidden=0,
>>>     class_head_hidden=0)
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> recon = methods.MultimodalTransformer.load_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
- class geowatch.tasks.fusion.methods.HeterogeneousModel(classes=10, dataset_stats=None, input_sensorchan=None, name: str = 'unnamed_model', position_encoder: str | ScaleAwarePositionalEncoder = 'auto', backbone: str | BackboneEncoderDecoder = 'auto', token_width: int = 10, token_dim: int = 16, spatial_scale_base: float = 1.0, temporal_scale_base: float = 1.0, class_weights: str = 'auto', saliency_weights: str = 'auto', positive_change_weight: float = 1.0, negative_change_weight: float = 1.0, global_class_weight: float = 1.0, global_change_weight: float = 1.0, global_saliency_weight: float = 1.0, change_loss: str = 'cce', class_loss: str = 'focal', saliency_loss: str = 'focal', tokenizer: str = 'simple_conv', decoder: str = 'upsample', ohem_ratio: float | None = None, focal_gamma: float | None = 2.0)[source]¶
Bases: LightningModule, WatchModuleMixins
- Parameters:
name – Specify a name for the experiment. (Unsure if the Model is the place for this)
token_width – Width of each square token.
token_dim – Dimensionality of each computed token.
spatial_scale_base – The spatial scale assigned to each token equals spatial_scale_base / token_density, where the token density is the number of tokens along the corresponding spatial axis.
temporal_scale_base – The temporal scale assigned to each token equals temporal_scale_base / token_density, where the token density is the number of tokens along the time axis.
class_weights – Weighting strategy for the class head (e.g. 'auto').
saliency_weights – Weighting strategy for the saliency head (e.g. 'auto').
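The scale formula above can be made concrete with a toy computation (illustrative numbers only, not library code):

>>> # A 64-pixel-wide window covered by 10-pixel tokens has ~6.4 tokens
>>> # along the spatial axis, so each token gets scale 1.0 / 6.4.
>>> token_width = 10
>>> window_width = 64
>>> spatial_scale_base = 1.0
>>> token_density = window_width / token_width
>>> spatial_scale = spatial_scale_base / token_density
>>> assert abs(spatial_scale - 0.15625) < 1e-9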
Example
>>> # Note: it is important that the non-kwargs are saved as hyperparams
>>> from geowatch.tasks.fusion.methods.heterogeneous import HeterogeneousModel, ScaleAgnostictPositionalEncoder
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = ScaleAgnostictPositionalEncoder(3, 8)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = HeterogeneousModel(
>>>     input_sensorchan='r|g|b',
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>> )
- configure_optimizers()¶
Note: this is only a fallback for testing purposes. It should be overridden in your module or configured via the Lightning CLI.
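A minimal sketch of overriding this fallback in a subclass (the optimizer and hyperparameters here are illustrative assumptions, not project defaults):

>>> import torch
>>> from geowatch.tasks.fusion import methods
>>> class MyModel(methods.HeterogeneousModel):
...     def configure_optimizers(self):
...         # Replace the testing fallback with an explicit optimizer.
...         return torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=1e-5)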
- forward(batch)[source]¶
Example
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     backbone=backbone,
>>>     position_encoder=position_encoder,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="simple_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="trans_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=0,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=0,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="trans_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example
>>> # xdoctest: +REQUIRES(module:mmseg)
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import MM_VITEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = MM_VITEncoderDecoder(
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     backbone=backbone,
>>>     position_encoder=position_encoder,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example
>>> # xdoctest: +REQUIRES(module:mmseg)
>>> from geowatch.tasks import fusion
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> self = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     #token_dim=708,
>>>     token_dim=768 - 60,
>>>     backbone='vit_B_16_imagenet1k',
>>>     position_encoder=position_encoder,
>>> )
>>> batch = self.demo_batch(width=64, height=65)
>>> batch += self.demo_batch(width=55, height=75)
>>> outputs = self.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
- forward_step(batch, batch_idx=None, stage='train', with_loss=True)¶
Example
>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
Example
>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> batch += [None]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
Example
>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> for cutoff in [-1, -2]:
>>>     degraded_example = model.demo_batch(width=55, height=75, num_timesteps=3)[0]
>>>     degraded_example["frames"] = degraded_example["frames"][:cutoff]
>>>     batch += [degraded_example]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
Example
>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.1)
>>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.5)
>>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=1.0)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
- log_grad_norm(grad_norm_dict) → None[source]¶
Override this method to change the default behaviour of log_grad_norm. This overload suppresses the batch_size warning.
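A sketch of the intent (assumed, not the verbatim source): log each gradient norm with an explicit batch_size so Lightning does not warn about having to infer it.

>>> import pytorch_lightning as pl
>>> class _DemoModule(pl.LightningModule):
...     def log_grad_norm(self, grad_norm_dict) -> None:
...         # Passing batch_size explicitly suppresses the warning.
...         for name, value in grad_norm_dict.items():
...             self.log(name, value, batch_size=1)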
- process_input_tokens(example)[source]¶
Example
>>> from geowatch.tasks import fusion
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>> )
>>> example = model.demo_batch(width=64, height=65)[0]
>>> input_tokens = model.process_input_tokens(example)
>>> assert len(input_tokens) == len(example["frames"])
>>> assert len(input_tokens[0]) == len(example["frames"][0]["modes"])
- process_query_tokens(example)[source]¶
Example
>>> from geowatch.tasks import fusion
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>> )
>>> example = model.demo_batch(width=64, height=65)[0]
>>> query_tokens = model.process_query_tokens(example)
>>> assert len(query_tokens) == len(example["frames"])
- save_package(package_path, context=None, verbose=1)[source]¶
CommandLine
xdoctest -m geowatch.tasks.fusion.methods.heterogeneous HeterogeneousModel.save_package
Example
>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.heterogeneous import * # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = self = methods.HeterogeneousModel(
>>>     position_encoder=position_encoder,
>>>     input_sensorchan=5,
>>>     decoder="upsample",
>>>     backbone=backbone,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.HeterogeneousModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
Example
>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.heterogeneous import * # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = self = methods.HeterogeneousModel(
>>>     position_encoder=position_encoder,
>>>     input_sensorchan=5,
>>>     decoder="simple_conv",
>>>     backbone=backbone,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.HeterogeneousModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
Example
>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.heterogeneous import * # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = self = methods.HeterogeneousModel(
>>>     position_encoder=position_encoder,
>>>     input_sensorchan=5,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.HeterogeneousModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
- class geowatch.tasks.fusion.methods.UNetBaseline(classes=10, dataset_stats=None, input_sensorchan=None, token_dim: int = 32, name: str = 'unnamed_model', class_weights: str = 'auto', saliency_weights: str = 'auto', positive_change_weight: float = 1.0, negative_change_weight: float = 1.0, global_class_weight: float = 1.0, global_change_weight: float = 1.0, global_saliency_weight: float = 1.0, change_loss: str = 'cce', class_loss: str = 'focal', saliency_loss: str = 'focal', ohem_ratio: float | None = None, focal_gamma: float | None = 2.0)[source]¶
Bases: LightningModule, WatchModuleMixins
- Parameters:
name – Specify a name for the experiment. (Unsure if the Model is the place for this)
token_dim – Dimensionality of each computed token.
class_weights – Weighting strategy for the class head (e.g. 'auto').
saliency_weights – Weighting strategy for the saliency head (e.g. 'auto').
Example
>>> # Note: it is important that the non-kwargs are saved as hyperparams
>>> from geowatch.tasks.fusion.methods.unet_baseline import UNetBaseline
>>> model = UNetBaseline(
>>>     input_sensorchan='r|g|b',
>>> )
- forward(batch)[source]¶
Example
>>> from geowatch.tasks import fusion
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(width=64, height=64)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
- forward_step(batch, batch_idx=None, stage='train', with_loss=True)¶
Example
>>> # xdoctest: +REQUIRES(env:SLOW_TESTS)
>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
Example
>>> # xdoctest: +REQUIRES(env:SLOW_TESTS)
>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> batch += [None]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
Example
>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes, token_dim=2,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=0.1)
>>> batch += model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=0.5)
>>> batch += model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=1.0)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
- save_package(package_path, context=None, verbose=1)[source]¶
CommandLine
xdoctest -m geowatch.tasks.fusion.methods.unet_baseline UNetBaseline.save_package
Example
>>> # Test without datamodule
>>> import ubelt as ub
>>> import torch
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.unet_baseline import * # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> model = self = methods.UNetBaseline(
>>>     input_sensorchan=5,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.UNetBaseline.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not self
>>> assert set(recon_state) == set(model_state)
>>> from geowatch.utils.util_kwarray import torch_array_equal
>>> for key in recon_state.keys():
>>>     v1 = model_state[key]
>>>     v2 = recon_state[key]
>>>     if not torch.allclose(v1, v2, equal_nan=True):
>>>         print('v1 = {}'.format(ub.urepr(v1, nl=1)))
>>>         print('v2 = {}'.format(ub.urepr(v2, nl=1)))
>>>         raise AssertionError(f'Difference in key={key}')
>>>     assert v1 is not v2, 'should be distinct copies'
Example
>>> # Test without datamodule
>>> import ubelt as ub
>>> import torch
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.unet_baseline import * # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> model = self = methods.UNetBaseline(
>>>     input_sensorchan=5,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.UNetBaseline.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = self.state_dict()
>>> assert recon is not self
>>> assert set(recon_state) == set(model_state)
>>> from geowatch.utils.util_kwarray import torch_array_equal
>>> for key in recon_state.keys():
>>>     v1 = model_state[key]
>>>     v2 = recon_state[key]
>>>     if not torch.allclose(v1, v2, equal_nan=True):
>>>         print('v1 = {}'.format(ub.urepr(v1, nl=1)))
>>>         print('v2 = {}'.format(ub.urepr(v2, nl=1)))
>>>         raise AssertionError(f'Difference in key={key}')
>>>     assert v1 is not v2, 'should be distinct copies'
Example
>>> # xdoctest: +REQUIRES(env:SLOW_TESTS)
>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
Example
>>> # xdoctest: +REQUIRES(env:SLOW_TESTS)
>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> batch += [None]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
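The batch += [None] line is what distinguishes this example from the previous one: it checks that shared_step tolerates items the dataloader failed to produce. A minimal sketch of such a guard over a simple list batch (drop_missing_items is a hypothetical helper, not the geowatch implementation):

def drop_missing_items(batch):
    # Items can be None when sampling or collation fails;
    # skip them rather than crashing the training step.
    return [item for item in batch if item is not None]

assert drop_missing_items([{'x': 1}, None, {'x': 2}]) == [{'x': 1}, {'x': 2}]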
- class geowatch.tasks.fusion.methods.NoopModel(classes=10, dataset_stats=None, input_sensorchan=None, name: str = 'unnamed_model')[source]¶
Bases: LightningModule, WatchModuleMixins
No-op example model. Contains a dummy parameter to satisfy the optimizer and trainer.
Todo
[ ] Minimize even further.
[ ] Identify mandatory steps in __init__ and move to a parent class.
- Parameters:
name – Specify a name for the experiment. (It is unclear whether the model is the right place for this.)
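Because a torch optimizer raises an error when given an empty parameter list, even a model that computes nothing needs one trainable tensor. A minimal sketch of that dummy-parameter pattern (TinyNoop is hypothetical, not the geowatch class):

import torch

class TinyNoop(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # One learnable scalar so the optimizer has something to update.
        self.dummy = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Touch the dummy so the graph stays connected to a parameter.
        return x + 0 * self.dummy

model = TinyNoop()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(3)).sum()
loss.backward()
optimizer.step()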
- forward_step(batch, batch_idx=None, with_loss=True)¶
- save_package(package_path, context=None, verbose=1)[source]¶
CommandLine
xdoctest -m geowatch.tasks.fusion.methods.noop_model NoopModel.save_package
Example
>>> # Test without datamodule
>>> import ubelt as ub
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = dpath / 'my_package.pt'
>>> # Use one of our fusion architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> model = self = methods.NoopModel(
>>>     input_sensorchan=5,
>>> )
>>> # Save the model (TODO: need to save the datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and that the weights are distinct copies
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
>>> # Check what's inside of the package
>>> import zipfile
>>> import json
>>> zfile = zipfile.ZipFile(package_path)
>>> header_file = zfile.open('my_package/package_header/package_header.json')
>>> package_header = json.loads(header_file.read())
>>> print('package_header = {}'.format(ub.urepr(package_header, nl=1)))
>>> assert 'version' in package_header
>>> assert 'arch_name' in package_header
>>> assert 'module_name' in package_header
>>> assert 'packaging_time' in package_header
>>> assert 'git_hash' in package_header
>>> assert 'module_path' in package_header
Example
>>> # Test with datamodule
>>> import ubelt as ub
>>> import torch
>>> import pytorch_lightning as pl
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion.methods.noop_model import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = dpath / 'my_package.pt'
>>> datamodule = datamodules.kwcoco_video_data.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes8-multispectral-multisensor', chip_size=32,
>>>     batch_size=1, time_steps=2, num_workers=2, normalize_inputs=10)
>>> datamodule.setup('fit')
>>> dataset_stats = datamodule.torch_datasets['train'].cached_dataset_stats(num=3)
>>> classes = datamodule.torch_datasets['train'].classes
>>> # Use one of our fusion architectures in a test
>>> self = methods.NoopModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats, input_sensorchan=datamodule.input_sensorchan)
>>> # We have to run an input through the module because it is lazy
>>> batch = ub.peek(iter(datamodule.train_dataloader()))
>>> outputs = self.training_step(batch)
>>> trainer = pl.Trainer(max_steps=0)
>>> trainer.fit(model=self, datamodule=datamodule)
>>> # Save the model
>>> self.save_package(package_path)
>>> # Test that the package can be reloaded
>>> recon = methods.NoopModel.load_package(package_path)
>>> # Check consistency and that the weights are distinct copies
>>> recon_state = recon.state_dict()
>>> model_state = self.state_dict()
>>> assert recon is not self
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     v1 = model_state[key]
>>>     v2 = recon_state[key]
>>>     if not torch.allclose(v1, v2, equal_nan=True):
>>>         print('v1 = {}'.format(ub.urepr(v1, nl=1)))
>>>         print('v2 = {}'.format(ub.urepr(v2, nl=1)))
>>>         raise AssertionError(f'Difference in key={key}')
>>>     assert v1 is not v2, 'should be distinct copies'
- training_step(batch, batch_idx=None, with_loss=True)¶