geowatch.tasks.fusion.methods.heterogeneous module

geowatch.tasks.fusion.methods.heterogeneous.to_next_multiple(n, mult)

Example

>>> from geowatch.tasks.fusion.methods.heterogeneous import to_next_multiple
>>> x = to_next_multiple(11, 4)
>>> assert x == 1, f"x = {x}, should be 1"
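As the doctest shows, to_next_multiple(11, 4) returns 1, the amount that must be added to 11 to reach 12, the next multiple of 4, rather than the multiple itself. The pure-Python sketch below restates that contract. The name pad_to_next_multiple is hypothetical, and returning 0 when n is already a multiple is our assumption, not documented behavior.

```python
def pad_to_next_multiple(n: int, mult: int) -> int:
    """Hypothetical restatement of the to_next_multiple contract.

    Returns how much must be added to ``n`` so that it becomes a multiple
    of ``mult``. Assumption: an ``n`` that is already a multiple needs 0.
    """
    if mult <= 0:
        raise ValueError("mult must be positive")
    # For a positive modulus, Python's % is never negative, so (-n) % mult
    # is the distance from n up to the next multiple of mult.
    return (-n) % mult


print(pad_to_next_multiple(11, 4))  # 1, matching the doctest above
print(pad_to_next_multiple(10, 4))  # 2
print(pad_to_next_multiple(12, 4))  # 0 (assumed behavior)
```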
geowatch.tasks.fusion.methods.heterogeneous.positions_from_shape(shape, dtype='float32', device='cpu')
class geowatch.tasks.fusion.methods.heterogeneous.PadToMultiple(multiple: int, mode: str = 'constant', value=0.0)

Bases: Module

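Pads image-shaped tensors so that their last two (height and width) dimensions are divisible by multiple, following the padding strategy defined by mode and value. All padding is appended to the bottom and right of the input.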

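Parameters:
  • multiple – (int) height and width are padded up to the next multiple of this value.

  • mode – (str) padding strategy; defaults to 'constant'.

  • value – (float) fill value used for padding; defaults to 0.0.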

Example

>>> from geowatch.tasks.fusion.methods.heterogeneous import PadToMultiple
>>> import torch
>>> pad_module = PadToMultiple(4)
>>> inputs = torch.randn(1, 3, 10, 11)
>>> outputs = pad_module(inputs)
>>> assert outputs.shape == (1, 3, 12, 12), f"outputs.shape actually {outputs.shape}"

Example

>>> from geowatch.tasks.fusion.methods.heterogeneous import PadToMultiple
>>> import torch
>>> pad_module = PadToMultiple(4)
>>> inputs = torch.randn(3, 10, 11)
>>> outputs = pad_module(inputs)
>>> assert outputs.shape == (3, 12, 12), f"outputs.shape actually {outputs.shape}"
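Because all padding goes to the bottom and right, only the trailing side of each of the last two dimensions grows, for both 3D and 4D inputs. The pure-Python sketch below computes the padded shape together with a pad specification in the (left, right, top, bottom) order that torch.nn.functional.pad uses for the last two dimensions. The helper name padded_shape_and_spec is hypothetical and not part of geowatch.

```python
from typing import Tuple


def padded_shape_and_spec(shape: Tuple[int, ...], multiple: int):
    """Return (padded_shape, pad_spec) for bottom/right padding.

    pad_spec follows torch.nn.functional.pad's ordering for the last two
    dimensions: (left, right, top, bottom).
    """
    *leading, height, width = shape
    pad_bottom = (-height) % multiple  # rows appended below the image
    pad_right = (-width) % multiple    # columns appended to the right
    padded = (*leading, height + pad_bottom, width + pad_right)
    return padded, (0, pad_right, 0, pad_bottom)


print(padded_shape_and_spec((1, 3, 10, 11), 4))  # ((1, 3, 12, 12), (0, 1, 0, 2))
print(padded_shape_and_spec((3, 10, 11), 4))     # ((3, 12, 12), (0, 1, 0, 2))
```

With torch installed, such a spec could be applied as torch.nn.functional.pad(x, pad_spec, mode='constant', value=0.0); whether PadToMultiple calls it exactly this way internally is not stated here.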

Example

>>> from geowatch.tasks.fusion.methods.heterogeneous import PadToMultiple
>>> from torch import nn
>>> import torch
>>> token_width = 10
>>> pad_module = nn.Sequential(
>>>         PadToMultiple(token_width, value=0.0),
>>>         nn.Conv2d(
>>>             3,
>>>             16,
>>>             kernel_size=token_width,
>>>             stride=token_width,
>>>         )
>>> )
>>> inputs = torch.randn(3, 64, 65)
>>> outputs = pad_module(inputs)
>>> assert outputs.shape == (16, 7, 7), f"outputs.shape actually {outputs.shape}"
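The Sequential example produces a 7×7 grid because padding 64×65 up to 70×70 makes both sides divisible by the kernel, and a convolution with kernel_size == stride == token_width then emits one output per non-overlapping 10×10 patch. The standard Conv2d output-size formula for padding=0 and dilation=1, floor((size - kernel) / stride) + 1, confirms the (16, 7, 7) shape. The helpers below are a hypothetical pure-Python check, not geowatch code.

```python
def conv_out_size(size: int, kernel: int, stride: int) -> int:
    # Output length of one Conv2d axis with padding=0 and dilation=1.
    return (size - kernel) // stride + 1


def token_grid(height: int, width: int, token_width: int):
    # Pad each side up to a multiple of token_width (bottom/right), then
    # count the patches produced by a kernel_size == stride == token_width conv.
    padded_h = height + (-height) % token_width
    padded_w = width + (-width) % token_width
    return (conv_out_size(padded_h, token_width, token_width),
            conv_out_size(padded_w, token_width, token_width))


print(token_grid(64, 65, 10))  # (7, 7), matching outputs.shape[1:] above
```

Without the padding, the same convolution would yield a 6×6 grid and silently ignore the last 4 rows and 5 columns, since (64 - 10) // 10 + 1 and (65 - 10) // 10 + 1 both equal 6.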
forward(x)
class geowatch.tasks.fusion.methods.heterogeneous.NanToNum(num=0.0)

Bases: Module

Module that replaces NaN values in input tensors with the number num (default 0.0).
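Conceptually, each NaN is swapped for the fill value while other values pass through unchanged. The pure-Python sketch below shows this on a flat list; the function name is hypothetical, and since the module's treatment of infinities is not documented here, the sketch leaves them alone. For tensors, torch.nan_to_num(x, nan=num) performs the NaN replacement, but note that it also replaces infinities with large finite values by default.

```python
import math
from typing import List


def nan_to_num(values: List[float], num: float = 0.0) -> List[float]:
    # Replace every NaN with ``num``; all other values (including inf) are kept.
    return [num if math.isnan(v) else v for v in values]


print(nan_to_num([1.0, float("nan"), -2.5]))  # [1.0, 0.0, -2.5]
```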

forward(x)
class geowatch.tasks.fusion.methods.heterogeneous.ShapePreservingTransformerEncoder(token_dim, num_layers, batch_dim=0, chan_dim=1)

Bases: Module

forward(src, mask=None)
class geowatch.tasks.fusion.methods.heterogeneous.ScaleAwarePositionalEncoder

Bases: object

abstractmethod forward(mean, scale)
class geowatch.tasks.fusion.methods.heterogeneous.MipNerfPositionalEncoder(in_dims: int, num_freqs: int = 10, max_freq: float = 4.0)

Bases: Module, ScaleAwarePositionalEncoder

Module which computes Mip-NeRF-based positional encoding vectors from tensors of mean and scale values.

The output dimensionality is output_dim = 2 * in_dims * num_freqs.

Parameters:
  • in_dims – (int) number of input dimensions expected by later calls to .forward(). Currently it is only used to compute .output_dim.

  • num_freqs – (int) number of frequencies to project dimensions onto.
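For reference, Mip-NeRF's integrated positional encoding treats each coordinate as a Gaussian with mean μ and variance σ² and uses the closed-form expectations E[sin(f·x)] = sin(f·μ)·exp(-f²σ²/2) (and likewise for cosine), so high frequencies are damped when the scale is large. The pure-Python sketch below applies that formula per coordinate and reproduces the 2 * in_dims * num_freqs output size. Assumptions not taken from geowatch: scale is read as a variance, and the frequencies are 2**k for k evenly spaced on [0, max_freq]. The function name is hypothetical.

```python
import math
from typing import List, Sequence


def mipnerf_encode(mean: Sequence[float], scale: Sequence[float],
                   num_freqs: int = 10, max_freq: float = 4.0) -> List[float]:
    """Encode one point; assumes ``scale`` holds per-dimension variances."""
    if num_freqs < 1:
        raise ValueError("num_freqs must be at least 1")
    step = max_freq / (num_freqs - 1) if num_freqs > 1 else 0.0
    freqs = [2.0 ** (k * step) for k in range(num_freqs)]  # assumed schedule
    sines, cosines = [], []
    for f in freqs:
        for mu, var in zip(mean, scale):
            # Expected value of sin/cos(f * x) for x ~ N(mu, var).
            damping = math.exp(-0.5 * f * f * var)
            sines.append(math.sin(f * mu) * damping)
            cosines.append(math.cos(f * mu) * damping)
    return sines + cosines  # length 2 * len(mean) * num_freqs


enc = mipnerf_encode([0.1, -0.3, 0.5], [0.0, 0.0, 0.0], num_freqs=4)
print(len(enc))  # 24 == 2 * 3 * 4, as for MipNerfPositionalEncoder(3, 4)
```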

Example

>>> from geowatch.tasks.fusion.methods.heterogeneous import MipNerfPositionalEncoder
>>> import torch
>>> pos_enc = MipNerfPositionalEncoder(3, 4)
>>> input_means = torch.randn(1, 3, 10, 10)
>>> input_scales = torch.randn(1, 3, 10, 10)
>>> outputs = pos_enc(input_means, input_scales)
>>> assert outputs.shape == (1, pos_enc.output_dim, 10, 10)
forward(mean, scale)
class geowatch.tasks.fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(in_dims: int, num_freqs: int = 10, max_freq: float = 4.0)

Bases: Module, ScaleAwarePositionalEncoder

Module which computes Mip-NeRF-based positional encoding vectors from tensors of mean and scale values.

The output dimensionality is output_dim = 2 * in_dims * num_freqs.

Parameters:
  • in_dims – (int) number of input dimensions expected by later calls to .forward(). Currently it is only used to compute .output_dim.

  • num_freqs – (int) number of frequencies to project dimensions onto.

Example

>>> from geowatch.tasks.fusion.methods.heterogeneous import ScaleAgnostictPositionalEncoder
>>> import torch
>>> pos_enc = ScaleAgnostictPositionalEncoder(3, 4)
>>> input_means = torch.randn(1, 3, 10, 10)
>>> input_scales = torch.randn(1, 3, 10, 10)
>>> outputs = pos_enc(input_means, input_scales)
>>> assert outputs.shape == (1, pos_enc.output_dim, 10, 10)
forward(mean, scale)
class geowatch.tasks.fusion.methods.heterogeneous.ResNetShim(submodule)

Bases: Module

forward(x)
class geowatch.tasks.fusion.methods.heterogeneous.HeterogeneousModel(classes=10, dataset_stats=None, input_sensorchan=None, name: str = 'unnamed_model', position_encoder: str | ScaleAwarePositionalEncoder = 'auto', backbone: str | BackboneEncoderDecoder = 'auto', token_width: int = 10, token_dim: int = 16, spatial_scale_base: float = 1.0, temporal_scale_base: float = 1.0, class_weights: str = 'auto', saliency_weights: str = 'auto', positive_change_weight: float = 1.0, negative_change_weight: float = 1.0, global_class_weight: float = 1.0, global_change_weight: float = 1.0, global_saliency_weight: float = 1.0, change_loss: str = 'cce', class_loss: str = 'focal', saliency_loss: str = 'focal', tokenizer: str = 'simple_conv', decoder: str = 'upsample', ohem_ratio: float | None = None, focal_gamma: float | None = 2.0)

Bases: LightningModule, WatchModuleMixins

Parameters:
  • name – A name for the experiment. (It is unclear whether the model is the right place for this.)

  • token_width – Width, in pixels, of each square token.

  • token_dim – Dimensionality of each computed token.

  • spatial_scale_base – The spatial scale assigned to each token equals spatial_scale_base / token_density, where token_density is the number of tokens along a given spatial axis.

  • temporal_scale_base – The temporal scale assigned to each token equals temporal_scale_base / token_density, where token_density is the number of tokens along the temporal axis.

  • class_weights – Class weighting strategy.

  • saliency_weights – Class weighting strategy for the saliency task.
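As a worked example of the scale rule for spatial_scale_base and temporal_scale_base, assume the spatial token density is the number of tokens along an axis after padding up to a multiple of token_width, and the temporal density is the number of frames (both are our assumptions, consistent with the PadToMultiple example). A 64×65 input with token_width=10 then has 7 tokens per spatial axis, so each token's spatial scale is spatial_scale_base / 7. The helper name token_scales is hypothetical.

```python
import math


def token_scales(height: int, width: int, num_frames: int, token_width: int = 10,
                 spatial_scale_base: float = 1.0, temporal_scale_base: float = 1.0):
    # Assumption: one token per token_width x token_width patch (after padding)
    # and one token per frame along the temporal axis.
    tokens_h = math.ceil(height / token_width)
    tokens_w = math.ceil(width / token_width)
    return (spatial_scale_base / tokens_h,
            spatial_scale_base / tokens_w,
            temporal_scale_base / num_frames)


print(token_scales(64, 65, num_frames=3))
```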

Example

>>> # Note: it is important that the non-kwargs are saved as hyperparams
>>> from geowatch.tasks.fusion.methods.heterogeneous import HeterogeneousModel, ScaleAgnostictPositionalEncoder
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = ScaleAgnostictPositionalEncoder(3, 8)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = HeterogeneousModel(
>>>   input_sensorchan='r|g|b',
>>>   position_encoder=position_encoder,
>>>   backbone=backbone,
>>> )
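The backbone sizes used throughout these examples follow a simple rule, inferred from the examples rather than documented: each input token presumably combines a token_dim-dimensional feature (16 by default) with the position encoding, so the backbone's dim is position_encoder.output_dim + token_dim, while queries_dim is position_encoder.output_dim alone. With output_dim = 2 * in_dims * num_freqs, the ViT example's token_dim = 768 - 60 fits this rule: ScaleAgnostictPositionalEncoder(3) with its default num_freqs=10 gives 60, and 60 + 708 = 768, the hidden size of ViT-B/16. The helper below is hypothetical.

```python
def backbone_dims(in_dims: int, num_freqs: int = 10, token_dim: int = 16):
    # Assumed rule from the examples: token = [features ; position encoding].
    pos_dim = 2 * in_dims * num_freqs  # position_encoder.output_dim
    return {"dim": pos_dim + token_dim, "queries_dim": pos_dim}


print(backbone_dims(3, num_freqs=8))         # {'dim': 64, 'queries_dim': 48}
print(backbone_dims(3, token_dim=768 - 60))  # {'dim': 768, 'queries_dim': 60}
```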
get_cfgstr()
process_input_tokens(example)

Example

>>> from geowatch.tasks import fusion
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>> )
>>> example = model.demo_batch(width=64, height=65)[0]
>>> input_tokens = model.process_input_tokens(example)
>>> assert len(input_tokens) == len(example["frames"])
>>> assert len(input_tokens[0]) == len(example["frames"][0]["modes"])
process_query_tokens(example)

Example

>>> from geowatch.tasks import fusion
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>> )
>>> example = model.demo_batch(width=64, height=65)[0]
>>> query_tokens = model.process_query_tokens(example)
>>> assert len(query_tokens) == len(example["frames"])
forward(batch)

Example

>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     backbone=backbone,
>>>     position_encoder=position_encoder,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"

Example

>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="simple_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"

Example

>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="trans_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"

Example

>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=0,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=0,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="trans_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"

Example

>>> # xdoctest: +REQUIRES(module:mmseg)
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import MM_VITEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = MM_VITEncoderDecoder(
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     backbone=backbone,
>>>     position_encoder=position_encoder,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"

Example

>>> # xdoctest: +REQUIRES(module:mmseg)
>>> from geowatch.tasks import fusion
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> self = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     #token_dim=708,
>>>     token_dim=768 - 60,
>>>     backbone='vit_B_16_imagenet1k',
>>>     position_encoder=position_encoder,
>>> )
>>> batch = self.demo_batch(width=64, height=65)
>>> batch += self.demo_batch(width=55, height=75)
>>> outputs = self.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
shared_step(batch, batch_idx=None, stage='train', with_loss=True)

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> batch += [None]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> for cutoff in [-1, -2]:
>>>     degraded_example = model.demo_batch(width=55, height=75, num_timesteps=3)[0]
>>>     degraded_example["frames"] = degraded_example["frames"][:cutoff]
>>>     batch += [degraded_example]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.1)
>>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.5)
>>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=1.0)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
training_step(batch, batch_idx=None)
validation_step(batch, batch_idx=None)
test_step(batch, batch_idx=None)
predict_step(batch, batch_idx=None)
forward_step(batch, batch_idx=None, stage='train', with_loss=True)

Same signature and examples as shared_step().
log_grad_norm(grad_norm_dict) → None

Override this method to change the default behaviour of log_grad_norm.

Overrides log_grad_norm to suppress the batch_size warning.

save_package(package_path, context=None, verbose=1)

CommandLine

xdoctest -m geowatch.tasks.fusion.methods.heterogeneous HeterogeneousModel.save_package

Example

>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.heterogeneous import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = self = methods.HeterogeneousModel(
>>>     position_encoder=position_encoder,
>>>     input_sensorchan=5,
>>>     decoder="upsample",
>>>     backbone=backbone,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.HeterogeneousModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]

Example

>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.heterogeneous import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = self = methods.HeterogeneousModel(
>>>     position_encoder=position_encoder,
>>>     input_sensorchan=5,
>>>     decoder="simple_conv",
>>>     backbone=backbone,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.HeterogeneousModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]

Example

>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.heterogeneous import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = self = methods.HeterogeneousModel(
>>>     position_encoder=position_encoder,
>>>     input_sensorchan=5,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.HeterogeneousModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
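The round-trip checks in these examples amount to three properties: the reloaded state has the same parameter names as the original, every value is equal, and no value is the very same object (so the data really was reloaded rather than aliased). The sketch below expresses them with plain dicts of lists standing in for state_dict() tensors; for real tensors the equality test would need .all() or torch.equal. The function name check_roundtrip is hypothetical.

```python
import copy
from typing import Dict, List


def check_roundtrip(original: Dict[str, List[float]],
                    restored: Dict[str, List[float]]) -> None:
    # Same parameter names on both sides (compared against the *original*).
    assert set(restored) == set(original), "parameter names differ"
    for key, value in restored.items():
        # Equal values...
        assert value == original[key], f"values differ for {key!r}"
        # ...stored in distinct objects, i.e. really reloaded, not aliased.
        assert value is not original[key], f"{key!r} was not copied"


state = {"encoder.weight": [0.5, -1.0], "encoder.bias": [0.0]}
check_roundtrip(state, copy.deepcopy(state))
print("round trip ok")
```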
configure_optimizers()

Note: this is only a fallback for testing purposes. It should be overridden in your module or configured via the Lightning CLI.