geowatch.tasks.fusion.methods.heterogeneous module¶
- geowatch.tasks.fusion.methods.heterogeneous.to_next_multiple(n, mult)[source]¶
Example
>>> from geowatch.tasks.fusion.methods.heterogeneous import to_next_multiple
>>> x = to_next_multiple(11, 4)
>>> assert x == 1, f"x = {x}, should be 1"
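The behaviour implied by the doctest above (11 with multiple 4 yields 1) suggests the function returns the amount that must be added to reach the next multiple, not the multiple itself. A minimal pure-Python sketch under that assumption:

```python
def to_next_multiple(n, mult):
    # Amount to add to ``n`` so it becomes a multiple of ``mult``.
    # Already a multiple -> 0 (the outer ``% mult`` handles that case).
    return (mult - (n % mult)) % mult
```

For example, `to_next_multiple(11, 4)` is `1` (padding 11 up to 12) and `to_next_multiple(12, 4)` is `0`.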
- geowatch.tasks.fusion.methods.heterogeneous.positions_from_shape(shape, dtype='float32', device='cpu')[source]¶
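No example accompanies this function. As an illustration only, a coordinate-grid helper of this shape typically returns one coordinate array per axis; the sketch below is a hypothetical pure-Python analogue (the real function presumably returns torch tensors with the given dtype/device, and its normalization is not documented here):

```python
def positions_from_shape(shape):
    # Hypothetical sketch: per-axis integer coordinate grids for a 2D shape.
    h, w = shape
    ys = [[y for _ in range(w)] for y in range(h)]   # row index at each cell
    xs = [list(range(w)) for _ in range(h)]          # column index at each cell
    return ys, xs
```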
- class geowatch.tasks.fusion.methods.heterogeneous.PadToMultiple(multiple: int, mode: str = 'constant', value=0.0)[source]¶
Bases:
Module
Pads input image-shaped tensors following strategy defined by mode/value. All padding appended to bottom and right of input.
- Parameters:
multiple – (int)
mode – (str, default: ‘constant’) Padding strategy. One of (‘constant’, ‘reflect’, ‘replicate’, ‘circular’). See: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html#torch.nn.functional.pad
value – (float, default: 0.0) Fill value used with the ‘constant’ mode; if value=None is passed, it is set to 0 automatically. See: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html#torch.nn.functional.pad
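Since all padding goes to the bottom and right, the pad amounts per spatial axis follow directly from the next-multiple arithmetic. A small sketch (the helper name `pad_amounts` is illustrative, not part of the module):

```python
def pad_amounts(height, width, multiple):
    # Padding appended to bottom and right only; each axis is rounded
    # up to the next multiple.
    pad_h = (multiple - height % multiple) % multiple
    pad_w = (multiple - width % multiple) % multiple
    return pad_h, pad_w
```

This matches the first example below: a (10, 11) input with multiple 4 needs (2, 1) padding, producing a (12, 12) output.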
Example
>>> from geowatch.tasks.fusion.methods.heterogeneous import PadToMultiple
>>> import torch
>>> pad_module = PadToMultiple(4)
>>> inputs = torch.randn(1, 3, 10, 11)
>>> outputs = pad_module(inputs)
>>> assert outputs.shape == (1, 3, 12, 12), f"outputs.shape actually {outputs.shape}"
Example
>>> from geowatch.tasks.fusion.methods.heterogeneous import PadToMultiple
>>> import torch
>>> pad_module = PadToMultiple(4)
>>> inputs = torch.randn(3, 10, 11)
>>> outputs = pad_module(inputs)
>>> assert outputs.shape == (3, 12, 12), f"outputs.shape actually {outputs.shape}"
Example
>>> from geowatch.tasks.fusion.methods.heterogeneous import PadToMultiple
>>> from torch import nn
>>> import torch
>>> token_width = 10
>>> pad_module = nn.Sequential(
>>>     PadToMultiple(token_width, value=0.0),
>>>     nn.Conv2d(
>>>         3,
>>>         16,
>>>         kernel_size=token_width,
>>>         stride=token_width,
>>>     )
>>> )
>>> inputs = torch.randn(3, 64, 65)
>>> outputs = pad_module(inputs)
>>> assert outputs.shape == (16, 7, 7), f"outputs.shape actually {outputs.shape}"
- class geowatch.tasks.fusion.methods.heterogeneous.NanToNum(num=0.0)[source]¶
Bases:
Module
Module which converts NaN values in input tensors to numbers.
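Conceptually this is element-wise NaN replacement (the real module presumably wraps `torch.nan_to_num` on input tensors). A pure-Python sketch of the same idea, with the default `num=0.0` from the signature above:

```python
import math

def nan_to_num(values, num=0.0):
    # Replace every NaN entry with ``num``; pass other values through.
    return [num if math.isnan(v) else v for v in values]
```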
- class geowatch.tasks.fusion.methods.heterogeneous.ShapePreservingTransformerEncoder(token_dim, num_layers, batch_dim=0, chan_dim=1)[source]¶
Bases:
Module
- class geowatch.tasks.fusion.methods.heterogeneous.ScaleAwarePositionalEncoder[source]¶
Bases:
object
- class geowatch.tasks.fusion.methods.heterogeneous.MipNerfPositionalEncoder(in_dims: int, num_freqs: int = 10, max_freq: float = 4.0)[source]¶
Bases:
Module, ScaleAwarePositionalEncoder
Module that computes mip-NeRF-based positional encoding vectors from tensors of mean and scale values.
out_dims = 2 * in_dims * num_freqs
- Parameters:
in_dims – (int) number of input dimensions to expect for future calls to .forward(). Currently only needed for computing .output_dim
num_freqs – (int) number of frequencies to project dimensions onto.
Example
>>> from geowatch.tasks.fusion.methods.heterogeneous import MipNerfPositionalEncoder
>>> import torch
>>> pos_enc = MipNerfPositionalEncoder(3, 4)
>>> input_means = torch.randn(1, 3, 10, 10)
>>> input_scales = torch.randn(1, 3, 10, 10)
>>> outputs = pos_enc(input_means, input_scales)
>>> assert outputs.shape == (1, pos_enc.output_dim, 10, 10)
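The output-dimension formula stated above, out_dims = 2 * in_dims * num_freqs, can be checked against this example: with in_dims=3 and num_freqs=4 the encoder produces 24 channels (one sine and one cosine component per frequency per input dimension), so the doctest's output shape is (1, 24, 10, 10). A one-line check:

```python
def mipnerf_out_dims(in_dims, num_freqs):
    # out_dims = 2 * in_dims * num_freqs (sin and cos per frequency per dim)
    return 2 * in_dims * num_freqs
```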
- class geowatch.tasks.fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(in_dims: int, num_freqs: int = 10, max_freq: float = 4.0)[source]¶
Bases:
Module, ScaleAwarePositionalEncoder
Scale-agnostic variant of the mip-NeRF-based positional encoder; computes positional encoding vectors from tensors of mean and scale values.
out_dims = 2 * in_dims * num_freqs
- Parameters:
in_dims – (int) number of input dimensions to expect for future calls to .forward(). Currently only needed for computing .output_dim
num_freqs – (int) number of frequencies to project dimensions onto.
Example
>>> from geowatch.tasks.fusion.methods.heterogeneous import ScaleAgnostictPositionalEncoder
>>> import torch
>>> pos_enc = ScaleAgnostictPositionalEncoder(3, 4)
>>> input_means = torch.randn(1, 3, 10, 10)
>>> input_scales = torch.randn(1, 3, 10, 10)
>>> outputs = pos_enc(input_means, input_scales)
>>> assert outputs.shape == (1, pos_enc.output_dim, 10, 10)
- class geowatch.tasks.fusion.methods.heterogeneous.HeterogeneousModel(classes=10, dataset_stats=None, input_sensorchan=None, name: str = 'unnamed_model', position_encoder: str | ScaleAwarePositionalEncoder = 'auto', backbone: str | BackboneEncoderDecoder = 'auto', token_width: int = 10, token_dim: int = 16, spatial_scale_base: float = 1.0, temporal_scale_base: float = 1.0, class_weights: str = 'auto', saliency_weights: str = 'auto', positive_change_weight: float = 1.0, negative_change_weight: float = 1.0, global_class_weight: float = 1.0, global_change_weight: float = 1.0, global_saliency_weight: float = 1.0, change_loss: str = 'cce', class_loss: str = 'focal', saliency_loss: str = 'focal', tokenizer: str = 'simple_conv', decoder: str = 'upsample', ohem_ratio: float | None = None, focal_gamma: float | None = 2.0)[source]¶
Bases:
LightningModule, WatchModuleMixins
- Parameters:
name – Specify a name for the experiment. (Unsure if the Model is the place for this)
token_width – Width of each square token.
token_dim – Dimensionality of each computed token.
spatial_scale_base – The spatial scale assigned to each token equals spatial_scale_base / token_density, where the token density is the number of tokens along a spatial axis.
temporal_scale_base – The temporal scale assigned to each token equals temporal_scale_base / token_density, where the token density is the number of tokens along the time axis.
class_weights – Class weighting strategy.
saliency_weights – Weighting strategy for the saliency head.
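The scale formula above can be made concrete with the default token_width of 10: a 64-pixel axis tokenizes into ceil(64 / 10) = 7 tokens, so with scale_base 1.0 each token is assigned scale 1/7. The helper name below is illustrative, not part of the API:

```python
import math

def token_scale(scale_base, extent, token_width):
    # Hypothetical helper: token density is the number of tokens along an
    # axis; the assigned scale is scale_base / token_density.
    token_density = math.ceil(extent / token_width)
    return scale_base / token_density
```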
Example
>>> # Note: it is important that the non-kwargs are saved as hyperparams
>>> from geowatch.tasks.fusion.methods.heterogeneous import HeterogeneousModel, ScaleAgnostictPositionalEncoder
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = ScaleAgnostictPositionalEncoder(3, 8)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = HeterogeneousModel(
>>>     input_sensorchan='r|g|b',
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>> )
- process_input_tokens(example)[source]¶
Example
>>> from geowatch.tasks import fusion
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>> )
>>> example = model.demo_batch(width=64, height=65)[0]
>>> input_tokens = model.process_input_tokens(example)
>>> assert len(input_tokens) == len(example["frames"])
>>> assert len(input_tokens[0]) == len(example["frames"][0]["modes"])
- process_query_tokens(example)[source]¶
Example
>>> from geowatch.tasks import fusion
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>> )
>>> example = model.demo_batch(width=64, height=65)[0]
>>> query_tokens = model.process_query_tokens(example)
>>> assert len(query_tokens) == len(example["frames"])
- forward(batch)[source]¶
Example
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     backbone=backbone,
>>>     position_encoder=position_encoder,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="simple_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="trans_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=0,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=0,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="trans_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example
>>> # xdoctest: +REQUIRES(module:mmseg)
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import MM_VITEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = MM_VITEncoderDecoder(
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     backbone=backbone,
>>>     position_encoder=position_encoder,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example
>>> # xdoctest: +REQUIRES(module:mmseg)
>>> from geowatch.tasks import fusion
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> self = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     #token_dim=708,
>>>     token_dim=768 - 60,
>>>     backbone='vit_B_16_imagenet1k',
>>>     position_encoder=position_encoder,
>>> )
>>> batch = self.demo_batch(width=64, height=65)
>>> batch += self.demo_batch(width=55, height=75)
>>> outputs = self.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example
>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
Example
>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> batch += [None]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
Example
>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> for cutoff in [-1, -2]:
>>>     degraded_example = model.demo_batch(width=55, height=75, num_timesteps=3)[0]
>>>     degraded_example["frames"] = degraded_example["frames"][:cutoff]
>>>     batch += [degraded_example]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
Example
>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.1)
>>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.5)
>>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=1.0)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
- forward_step(batch, batch_idx=None, stage='train', with_loss=True)¶
Example
>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
Example
>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> batch += [None]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
Example
>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> for cutoff in [-1, -2]:
>>>     degraded_example = model.demo_batch(width=55, height=75, num_timesteps=3)[0]
>>>     degraded_example["frames"] = degraded_example["frames"][:cutoff]
>>>     batch += [degraded_example]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
Example
>>> from geowatch.tasks import fusion
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>>     position_encoder=position_encoder,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> batch = model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.1)
>>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.5)
>>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=1.0)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
- log_grad_norm(grad_norm_dict) → None[source]¶
Override this method to change the default behaviour of log_grad_norm.
Overloads log_grad_norm so we can suppress the batch_size warning.
- save_package(package_path, context=None, verbose=1)[source]¶
CommandLine
xdoctest -m geowatch.tasks.fusion.methods.heterogeneous HeterogeneousModel.save_package
Example
>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.heterogeneous import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = self = methods.HeterogeneousModel(
>>>     position_encoder=position_encoder,
>>>     input_sensorchan=5,
>>>     decoder="upsample",
>>>     backbone=backbone,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.HeterogeneousModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
Example
>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.heterogeneous import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = self = methods.HeterogeneousModel(
>>>     position_encoder=position_encoder,
>>>     input_sensorchan=5,
>>>     decoder="simple_conv",
>>>     backbone=backbone,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.HeterogeneousModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
Example
>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.heterogeneous import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = self = methods.HeterogeneousModel(
>>>     position_encoder=position_encoder,
>>>     input_sensorchan=5,
>>>     decoder="trans_conv",
>>>     backbone=backbone,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.HeterogeneousModel.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
- configure_optimizers()¶
Note: this is only a fallback for testing purposes. It should be overwritten in your module or configured via the Lightning CLI.