import pytorch_lightning as pl
import torch
from torch import nn
import torchmetrics
import einops
from einops.layers.torch import Rearrange
from torchvision import models as tv_models
from torchvision.models import feature_extraction
from typing import Union, Optional
import numpy as np
import kwcoco
import kwarray
import ubelt as ub
from geowatch import heuristics
from geowatch.tasks.fusion.methods.network_modules import coerce_criterion
from geowatch.tasks.fusion.methods.network_modules import RobustModuleDict
from geowatch.tasks.fusion.methods.watch_module_mixins import WatchModuleMixins
from geowatch.tasks.fusion.architectures.transformer import BackboneEncoderDecoder, TransformerEncoderDecoder
from geowatch.tasks.fusion.architectures import transformer
from geowatch.utils.util_netharn import InputNorm
from abc import ABCMeta, abstractmethod
try:
    import xdev
    profile = xdev.profile
except Exception:
    # xdev is optional: when it (or its ``profile`` attribute) is unavailable
    # fall back to ``ub.identity`` so ``@profile`` decorations are no-ops.
    profile = ub.identity

# Names of the encoder configurations registered in
# ``transformer.encoder_configs``; ``HeterogeneousModel`` accepts any of these
# strings as its ``backbone`` argument.
SPLIT_ATTENTION_ENCODERS = [*transformer.encoder_configs]
[docs]
def to_next_multiple(n, mult):
    """
    Compute how much must be added to ``n`` to reach the next multiple of
    ``mult``. Returns zero when ``n`` is already a multiple.

    Args:
        n (int): value to be rounded up.
        mult (int): the step size.

    Returns:
        int: the padding amount; for positive integer ``mult`` this lies in
        ``[0, mult)``.

    Example:
        >>> from geowatch.tasks.fusion.methods.heterogeneous import to_next_multiple
        >>> x = to_next_multiple(11, 4)
        >>> assert x == 1, f"x = {x}, should be 1"
        >>> assert to_next_multiple(12, 4) == 0
    """
    shortfall = mult - (n % mult)
    # A shortfall of a full ``mult`` means ``n`` was already aligned.
    return 0 if shortfall == mult else shortfall
[docs]
def positions_from_shape(shape, dtype="float32", device="cpu"):
positions = torch.stack(torch.meshgrid(*[
torch.linspace(-1, 1, size + 1, dtype=dtype, device=device)[:-1]
for size in shape
]), dim=0)
mean_dims = list(range(1, len(shape) + 1))
positions -= positions.mean(dim=mean_dims, keepdims=True)
return positions
[docs]
class PadToMultiple(nn.Module):
    """
    Pad the two trailing (height, width) dimensions of a tensor up to the next
    multiple of ``multiple``.

    All padding is appended to the bottom and right edges, so the original
    content keeps its top-left alignment.
    """

    def __init__(self, multiple: int, mode: str = 'constant', value=0.):
        """
        Args:
            multiple (int):
                the padded height and width will both be divisible by this.

            mode (str):
                padding strategy forwarded to ``torch.nn.functional.pad``;
                one of ``'constant'``, ``'reflect'``, ``'replicate'`` or
                ``'circular'``. Defaults to ``'constant'``.
                See: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html#torch.nn.functional.pad

            value (float | None):
                fill value used by ``'constant'`` padding. Defaults to ``0.``;
                ``None`` is treated as zero by ``torch.nn.functional.pad``.

        Example:
            >>> from geowatch.tasks.fusion.methods.heterogeneous import PadToMultiple
            >>> import torch
            >>> pad_module = PadToMultiple(4)
            >>> inputs = torch.randn(1, 3, 10, 11)
            >>> outputs = pad_module(inputs)
            >>> assert outputs.shape == (1, 3, 12, 12), f"outputs.shape actually {outputs.shape}"

        Example:
            >>> from geowatch.tasks.fusion.methods.heterogeneous import PadToMultiple
            >>> import torch
            >>> pad_module = PadToMultiple(4)
            >>> inputs = torch.randn(3, 10, 11)
            >>> outputs = pad_module(inputs)
            >>> assert outputs.shape == (3, 12, 12), f"outputs.shape actually {outputs.shape}"

        Example:
            >>> from geowatch.tasks.fusion.methods.heterogeneous import PadToMultiple
            >>> from torch import nn
            >>> import torch
            >>> token_width = 10
            >>> pad_module = nn.Sequential(
            >>>     PadToMultiple(token_width, value=0.0),
            >>>     nn.Conv2d(
            >>>         3,
            >>>         16,
            >>>         kernel_size=token_width,
            >>>         stride=token_width,
            >>>     )
            >>> )
            >>> inputs = torch.randn(3, 64, 65)
            >>> outputs = pad_module(inputs)
            >>> assert outputs.shape == (16, 7, 7), f"outputs.shape actually {outputs.shape}"
        """
        super().__init__()
        self.multiple = multiple
        self.mode = mode
        self.value = value

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): tensor whose last two dims are (height, width).

        Returns:
            torch.Tensor: ``x`` padded so both trailing dims are multiples of
            ``self.multiple``.
        """
        *_, height, width = x.shape
        extra_right = to_next_multiple(width, self.multiple)
        extra_bottom = to_next_multiple(height, self.multiple)
        # F.pad consumes (before, after) pairs starting from the LAST dim:
        # (left, right) for width, then (top, bottom) for height.
        padding = (0, extra_right, 0, extra_bottom)
        return nn.functional.pad(x, padding, mode=self.mode, value=self.value)
[docs]
class NanToNum(nn.Module):
    """
    Replace NaN entries of the input tensor with a fixed number.

    This is a thin wrapper around ``torch.nan_to_num``. Positive and negative
    infinities are therefore also replaced, by the largest and smallest
    finite values of the tensor's dtype.
    """

    def __init__(self, num=0.0):
        """
        Args:
            num (float): value substituted for every NaN. Defaults to ``0.0``.
        """
        super().__init__()
        self.num = num

    def forward(self, x):
        return torch.nan_to_num(x, nan=self.num)
[docs]
class ScaleAwarePositionalEncoder(metaclass=ABCMeta):
    """
    Abstract interface for positional encoders that receive both a position
    ("mean") tensor and a corresponding "scale" tensor.

    Subclasses must implement :meth:`forward`. The concrete encoders in this
    module also inherit from ``nn.Module`` so that calling them dispatches to
    ``forward``.
    """

    @abstractmethod
    def forward(self, mean, scale):
        """
        Encode positions given their means and scales.

        Args:
            mean: position tensor.
            scale: scale tensor; presumably laid out like ``mean`` (confirm
                against each implementation).
        """
[docs]
class MipNerfPositionalEncoder(nn.Module, ScaleAwarePositionalEncoder):
    """
    Integrated (MipNeRF-style) positional encoding computed from tensors of
    mean and scale values.

    Each input channel is projected onto ``num_freqs`` frequencies ``2 ** f``,
    with ``f`` evenly spaced in ``[0, max_freq]``. The sin and cos of every
    projected mean are multiplied by ``exp(-(2 ** (2 * f - 1)) * scale)``.
    For positive scales, this damps the higher frequencies more strongly.
    """

    def __init__(self, in_dims: int, num_freqs: int = 10, max_freq: float = 4.):
        """
        The encoding has ``out_dims = 2 * in_dims * num_freqs`` channels.

        Args:
            in_dims (int): number of input channels expected by future calls
                to ``.forward()``. Currently only used to compute
                ``.output_dim``.
            num_freqs (int): number of frequencies each channel is projected
                onto.
            max_freq (float): exponent of the highest frequency; frequencies
                are ``2 ** linspace(0, max_freq, num_freqs)``.

        Example:
            >>> from geowatch.tasks.fusion.methods.heterogeneous import MipNerfPositionalEncoder
            >>> import torch
            >>> pos_enc = MipNerfPositionalEncoder(3, 4)
            >>> input_means = torch.randn(1, 3, 10, 10)
            >>> input_scales = torch.randn(1, 3, 10, 10)
            >>> outputs = pos_enc(input_means, input_scales)
            >>> assert outputs.shape == (1, pos_enc.output_dim, 10, 10)
        """
        super().__init__()
        exponents = torch.linspace(0, max_freq, num_freqs)
        # Fixed, non-trainable projection weights.
        self.mean_weights = nn.Parameter(2. ** exponents, requires_grad=False)
        # Negative log-attenuation per frequency: -(2 ** (2f - 1)).
        self.scale_weights = nn.Parameter(-(2. ** (2. * exponents - 1)), requires_grad=False)
        # ``weight`` aliases ``mean_weights``; HeterogeneousModel reads it to
        # pick the dtype/device of the positions it builds.
        self.weight = self.mean_weights
        self.output_dim = 2 * in_dims * num_freqs

    def forward(self, mean, scale):
        """
        Args:
            mean (torch.Tensor): positions, shaped ``(batch, in_dims, ...)``.
            scale (torch.Tensor): scales, laid out like ``mean``.

        Returns:
            torch.Tensor: shape ``(batch, 2 * in_dims * num_freqs, ...)``. The
            first half of the channels holds sin terms and the second half
            holds cos terms.
        """
        # (batch, x, ...) -> (batch, x, freq, ...) -> (batch, x * freq, ...)
        phase = torch.einsum("y,bx...->bxy...", self.mean_weights, mean).flatten(1, 2)
        log_decay = torch.einsum("y,bx...->bxy...", self.scale_weights, scale).flatten(1, 2)
        decay = log_decay.exp()
        return torch.concat([phase.sin() * decay, phase.cos() * decay], dim=1)
[docs]
class ScaleAgnostictPositionalEncoder(nn.Module, ScaleAwarePositionalEncoder):
    """
    Sinusoidal (Fourier-feature) positional encoding that ignores scale.

    It exposes the same ``forward(mean, scale)`` interface as the scale-aware
    encoders, but ``scale`` is not used. Each input channel is projected onto
    ``num_freqs`` frequencies ``2 ** f`` (with ``f`` evenly spaced in
    ``[0, max_freq]``), and the sin and cos of those projections are returned
    without any attenuation.
    """

    def __init__(self, in_dims: int, num_freqs: int = 10, max_freq: float = 4.):
        """
        The encoding has ``out_dims = 2 * in_dims * num_freqs`` channels.

        Args:
            in_dims (int): number of input channels expected by future calls
                to ``.forward()``. Currently only used to compute
                ``.output_dim``.
            num_freqs (int): number of frequencies each channel is projected
                onto.
            max_freq (float): exponent of the highest frequency; frequencies
                are ``2 ** linspace(0, max_freq, num_freqs)``.

        Example:
            >>> from geowatch.tasks.fusion.methods.heterogeneous import ScaleAgnostictPositionalEncoder
            >>> import torch
            >>> pos_enc = ScaleAgnostictPositionalEncoder(3, 4)
            >>> input_means = torch.randn(1, 3, 10, 10)
            >>> input_scales = torch.randn(1, 3, 10, 10)
            >>> outputs = pos_enc(input_means, input_scales)
            >>> assert outputs.shape == (1, pos_enc.output_dim, 10, 10)
        """
        super().__init__()
        exponents = torch.linspace(0, max_freq, num_freqs)
        # Fixed, non-trainable projection weights.
        self.mean_weights = nn.Parameter(2. ** exponents, requires_grad=False)
        # ``weight`` aliases ``mean_weights``; HeterogeneousModel reads it to
        # pick the dtype/device of the positions it builds.
        self.weight = self.mean_weights
        self.output_dim = 2 * in_dims * num_freqs

    def forward(self, mean, scale):
        """
        Args:
            mean (torch.Tensor): positions, shaped ``(batch, in_dims, ...)``.
            scale: accepted for interface compatibility; unused.

        Returns:
            torch.Tensor: shape ``(batch, 2 * in_dims * num_freqs, ...)``. The
            first half of the channels holds sin terms and the second half
            holds cos terms.
        """
        # (batch, x, ...) -> (batch, x, freq, ...) -> (batch, x * freq, ...)
        phase = torch.einsum("y,bx...->bxy...", self.mean_weights, mean).flatten(1, 2)
        return torch.concat([phase.sin(), phase.cos()], dim=1)
[docs]
class ResNetShim(nn.Module):
    """
    Adapt a batched feature extractor so it can be applied to one unbatched
    input.

    ``submodule`` is called with a batch of size one and must return a
    mapping with a ``"layer4"`` entry, as produced by
    ``feature_extraction.create_feature_extractor(..., return_nodes=["layer4"])``.
    The single batch item of that entry is returned.
    """

    def __init__(self, submodule):
        super().__init__()
        self.submodule = submodule

    def forward(self, x):
        batched = x.unsqueeze(0)
        features = self.submodule(batched)
        return features["layer4"][0]
[docs]
class HeterogeneousModel(pl.LightningModule, WatchModuleMixins):
_HANDLES_NANS = True
[docs]
def get_cfgstr(self):
    """
    Return a short configuration string for this model: the ``name``
    hyperparameter followed by the ``_heterogeneous`` suffix.
    """
    return '{}_heterogeneous'.format(self.hparams.name)
def __init__(
    self,
    classes=10,
    dataset_stats=None,
    input_sensorchan=None,
    name: str = "unnamed_model",
    position_encoder: Union[str, ScaleAwarePositionalEncoder] = 'auto',
    backbone: Union[str, BackboneEncoderDecoder] = 'auto',
    token_width: int = 10,
    token_dim: int = 16,
    spatial_scale_base: float = 1.,
    temporal_scale_base: float = 1.,
    class_weights: str = "auto",
    saliency_weights: str = "auto",
    positive_change_weight: float = 1.0,
    negative_change_weight: float = 1.0,
    global_class_weight: float = 1.0,
    global_change_weight: float = 1.0,
    global_saliency_weight: float = 1.0,
    change_loss: str = "cce",  # TODO: replace control string with a module, possibly a subclass
    class_loss: str = "focal",  # TODO: replace control string with a module, possibly a subclass
    saliency_loss: str = "focal",  # TODO: replace control string with a module, possibly a subclass
    tokenizer: str = "simple_conv",  # TODO: replace control string with a module, possibly a subclass
    decoder: str = "upsample",  # TODO: replace control string with a module, possibly a subclass
    ohem_ratio: Optional[float] = None,
    focal_gamma: Optional[float] = 2.0,
):
    """
    Args:
        classes: Category specification, coerced with
            ``kwcoco.CategoryTree.coerce``.
        dataset_stats: Dataset statistics, forwarded together with
            ``input_sensorchan`` to ``set_dataset_specific_attributes``.
        input_sensorchan: Sensor/channel specification for the inputs.
        name: Specify a name for the experiment. It is used by
            ``get_cfgstr``. (Unsure if the Model is the place for this.)
        position_encoder: A scale-aware positional encoder exposing
            ``output_dim`` and ``weight``, or ``'auto'`` for
            ``ScaleAgnostictPositionalEncoder(3, 8)``. Any other string
            raises KeyError. It is excluded from the saved hyperparameters.
        backbone: A backbone module, or one of the strings ``'auto'``,
            ``'small'``, ``'wu-vit'``, ``'sits-former'``,
            ``'vit_B_16_imagenet1k'``, ``'vit_B_16'``, or a key of
            ``transformer.encoder_configs``. Any other string raises
            KeyError.
        token_width: Width of each square token.
        token_dim: Dimensionality of each computed token.
        spatial_scale_base: The spatial scale of a token along an axis is
            ``spatial_scale_base / n``, where ``n`` is the number of tokens
            along that axis.
        temporal_scale_base: Constant scale assigned to the temporal
            coordinate of every query token. It is not divided by any token
            count.
        class_weights: Class weighting strategy, passed to
            ``_coerce_class_weights``.
        saliency_weights: Saliency weighting strategy, passed to
            ``_coerce_saliency_weights``.
        positive_change_weight: Weight of the positive change class.
        negative_change_weight: Weight of the negative change class.
        global_class_weight: Global weight of the class head. The class
            head and its criterion are only built when this is > 0.
        global_change_weight: Global weight of the change head. The change
            head and its criterion are only built when this is > 0.
        global_saliency_weight: Global weight of the saliency head. The
            saliency head and its criterion are only built when this is > 0.
        change_loss: Loss name for the change head (see
            ``coerce_criterion``).
        class_loss: Loss name for the class head (see ``coerce_criterion``).
        saliency_loss: Loss name for the saliency head (see
            ``coerce_criterion``).
        tokenizer: Either ``"simple_conv"`` or ``"resnet18"``.
        decoder: One of ``"upsample"``, ``"simple_conv"`` or
            ``"trans_conv"``.
        ohem_ratio: Forwarded to ``coerce_criterion``.
        focal_gamma: Forwarded to ``coerce_criterion``.

    Raises:
        AssertionError: if ``tokenizer`` or ``decoder`` is not supported.
        KeyError: if ``position_encoder`` or ``backbone`` is an unknown
            string.

    Example:
        >>> # Note: it is important that the non-kwargs are saved as hyperparams
        >>> from geowatch.tasks.fusion.methods.heterogeneous import HeterogeneousModel, ScaleAgnostictPositionalEncoder
        >>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
        >>> position_encoder = ScaleAgnostictPositionalEncoder(3, 8)
        >>> backbone = TransformerEncoderDecoder(
        >>>     encoder_depth=1,
        >>>     decoder_depth=1,
        >>>     dim=position_encoder.output_dim + 16,
        >>>     queries_dim=position_encoder.output_dim,
        >>>     logits_dim=16,
        >>>     cross_heads=1,
        >>>     latent_heads=1,
        >>>     cross_dim_head=1,
        >>>     latent_dim_head=1,
        >>> )
        >>> model = HeterogeneousModel(
        >>>     input_sensorchan='r|g|b',
        >>>     position_encoder=position_encoder,
        >>>     backbone=backbone,
        >>> )
    """
    # assert position_encoder is not None
    assert tokenizer in {"simple_conv", "resnet18"}, "Tokenizer not implemented yet."
    assert decoder in {"upsample", "simple_conv", "trans_conv"}, "Decoder not implemented yet."
    if isinstance(position_encoder, str):
        if position_encoder == 'auto':
            position_encoder = ScaleAgnostictPositionalEncoder(3, 8)
        else:
            raise KeyError(position_encoder)

    # Optional projections around the backbone (only set by the 'wu-vit' branch).
    pre_backbone = None
    post_backbone = None
    if isinstance(backbone, str):
        if backbone == 'auto':
            # TODO: set this to a "reasonable" default.
            backbone = TransformerEncoderDecoder(
                encoder_depth=3,
                decoder_depth=3,
                dim=position_encoder.output_dim + token_dim,
                queries_dim=position_encoder.output_dim,
                logits_dim=token_dim,
                cross_heads=1,
                latent_heads=1,
                cross_dim_head=1,
                latent_dim_head=1,
            )
        elif backbone == 'small':
            # This should be a reasonable small network for testing
            backbone = TransformerEncoderDecoder(
                encoder_depth=1,
                decoder_depth=1,
                dim=position_encoder.output_dim + token_dim,
                queries_dim=position_encoder.output_dim,
                logits_dim=token_dim,
                cross_heads=1,
                latent_heads=1,
                cross_dim_head=1,
                latent_dim_head=1,
            )
        elif backbone == 'wu-vit':
            """
            import geowatch
            from geowatch.utils.simple_dvc import SimpleDVC
            expt_dvc_dpath = geowatch.find_dvc_dpath(tags='phase2_expt')
            expt_dvc = SimpleDVC(expt_dvc_dpath)
            ckpt_fpath = expt_dvc_dpath / 'models/wu/MAE-2023-02-09/goldenMae-epoch=07-val_loss=0.23.ckpt'
            from geowatch.tasks.fusion.methods.heterogeneous import HeterogeneousModel, ScaleAgnostictPositionalEncoder
            from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
            position_encoder = ScaleAgnostictPositionalEncoder(3, 8)
            channels, classes, dataset_stats = HeterogeneousModel.demo_dataset_stats()
            model = HeterogeneousModel(
                # token_dim=768,
                # token_dim=768,
                input_sensorchan=channels,
                classes=classes,
                dataset_stats=dataset_stats,
                position_encoder=position_encoder,
                backbone='wu-vit',
            )
            from geowatch.tasks.fusion.fit import coerce_initializer
            from kwutil import util_pattern
            initializer = coerce_initializer(str(ckpt_fpath))
            initializer.forward(model)
            batch = model.demo_batch(width=64, height=65)
            batch += model.demo_batch(width=55, height=75)
            outputs = model.forward(batch)
            """
            from geowatch.tasks.fusion.architectures import wu_mae
            pre_backbone = nn.Linear(token_dim + position_encoder.output_dim, 16)
            # post_backbone = nn.Linear(16, token_dim + position_encoder.output_dim)
            post_backbone = nn.Linear(16, token_dim)
            backbone = wu_mae.wu_backbone().transformer
        elif backbone == 'sits-former':
            """
            Ignore:
                import geowatch
                from geowatch.utils.simple_dvc import SimpleDVC
                expt_dvc_dpath = geowatch.find_dvc_dpath(tags='phase2_expt')
                expt_dvc = SimpleDVC(expt_dvc_dpath)
                pretrained_fpath = expt_dvc_dpath / 'models/pretrained/sits-former/checkpoint.bert.tar'
                import torch
                model_state = torch.load(pretrained_fpath)
                from geowatch.tasks.fusion.methods.heterogeneous import HeterogeneousModel, ScaleAgnostictPositionalEncoder
                from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
                position_encoder = ScaleAgnostictPositionalEncoder(3, 8)
                channels, classes, dataset_stats = HeterogeneousModel.demo_dataset_stats()
                model = HeterogeneousModel(
                    token_dim=208,
                    input_sensorchan=channels,
                    classes=classes,
                    dataset_stats=dataset_stats,
                    position_encoder=position_encoder,
                    backbone='sits-former',
                )
                from geowatch.tasks.fusion.fit import coerce_initializer
                from kwutil import util_pattern
                initializer = coerce_initializer(str(pretrained_fpath))
                initializer.forward(model)
                batch = model.demo_batch(width=64, height=65)
                batch += model.demo_batch(width=55, height=75)
                outputs = model.forward(batch)
            """
            from geowatch.tasks.fusion.architectures import sits
            bert_config = {
                'num_features': 10,
                'hidden': 256,
                'n_layers': 3,
                'attn_heads': 8,
                'dropout': 0.1,
            }
            bert = sits.BERT(**bert_config)
            backbone = bert.transformer_encoder
            # Hack to denote that we need to not use batch first for this
            # model.
            backbone.is_sits_bert = True
            backbone.batch_first = backbone.layers[0].self_attn.batch_first
            # sits_config = {
            #     'patch_size': 5,
            #     'num_classes': 15,
            # }
            # sits.BERTClassification()
        elif backbone == 'vit_B_16_imagenet1k':
            """
            pip install pytorch_pretrained_vit
            """
            from pytorch_pretrained_vit import ViT
            vit_model = ViT('B_16_imagenet1k', pretrained=True)
            backbone = vit_model.transformer
            # assert token_dim == 708
        elif backbone == 'vit_B_16':
            from pytorch_pretrained_vit import ViT
            vit_model = ViT('B_16', pretrained=True)
            backbone = vit_model.transformer
            # assert token_dim == 708
        elif backbone in SPLIT_ATTENTION_ENCODERS:
            encoder_config = transformer.encoder_configs[backbone]
            backbone = transformer.FusionEncoder(
                **encoder_config,
                in_features=position_encoder.output_dim + token_dim,
                # attention_impl=self.hparams.attention_impl,
                dropout=0.1,
            )
        else:
            raise KeyError(backbone)

    super().__init__()
    self.save_hyperparameters(ignore=["position_encoder"])

    input_stats = self.set_dataset_specific_attributes(input_sensorchan, dataset_stats)

    self.classes = kwcoco.CategoryTree.coerce(classes)
    self.num_classes = len(self.classes)

    # TODO: this data should be introspectable via the kwcoco file
    hueristic_background_keys = heuristics.BACKGROUND_CLASSES

    # FIXME: case sensitivity
    hueristic_ignore_keys = heuristics.IGNORE_CLASSNAMES
    if self.class_freq is not None:
        all_keys = set(self.class_freq.keys())
    else:
        all_keys = set(self.classes)

    self.background_classes = all_keys & hueristic_background_keys
    self.ignore_classes = all_keys & hueristic_ignore_keys
    self.foreground_classes = (all_keys - self.background_classes) - self.ignore_classes
    # hueristic_ignore_keys.update(hueristic_occluded_keys)

    self.saliency_num_classes = 2

    # criterion and metrics
    # TODO: parametarize loss criterions
    # For loss function experiments, see and work in
    # ~/code/watch/geowatch/tasks/fusion/methods/sequence_aware.py
    # self.change_criterion = monai.losses.FocalLoss(reduction='none', to_onehot_y=False)
    # Use the requested saliency weighting strategy; previously the literal
    # 'auto' was passed here, silently ignoring this argument.
    self.saliency_weights = self._coerce_saliency_weights(saliency_weights)
    self.class_weights = self._coerce_class_weights(class_weights)
    # Index 0 weights the negative (no change) class, index 1 the positive.
    self.change_weights = torch.FloatTensor([
        self.hparams.negative_change_weight,
        self.hparams.positive_change_weight
    ])

    self.sensor_channel_tokenizers = RobustModuleDict()

    # Unique sensor modes obviously isn't very correct here.
    # We should fix that, but let's hack it so it at least
    # includes all sensor modes we probably will need.
    if input_stats is not None:
        sensor_modes = set(self.unique_sensor_modes) | set(input_stats.keys())
    else:
        sensor_modes = set(self.unique_sensor_modes)

    for s, c in sorted(sensor_modes):
        mode_code = kwcoco.FusedChannelSpec.coerce(c)
        # For each mode make a network that should learn to tokenize
        in_chan = mode_code.numel()
        if input_stats is None:
            input_norm = InputNorm()
        else:
            stats = input_stats.get((s, c), None)
            if stats is None:
                input_norm = InputNorm()
            else:
                input_norm = InputNorm(
                    **(ub.udict(stats) & {'mean', 'std'}))

        if tokenizer == "simple_conv":
            tokenizer_layer = nn.Sequential(
                PadToMultiple(token_width, value=0.0),
                nn.Conv2d(
                    in_chan,
                    token_dim,
                    kernel_size=token_width,
                    stride=token_width,
                ),
            )
        elif tokenizer == "resnet18":
            resnet = tv_models.resnet18(tv_models.ResNet18_Weights.IMAGENET1K_V1)
            # Replace the stem so it accepts this mode's channel count.
            resnet.conv1 = nn.Conv2d(in_chan, resnet.conv1.out_channels, kernel_size=7, stride=2, padding=3, bias=False)
            resnet = feature_extraction.create_feature_extractor(resnet, return_nodes=["layer4"])
            tokenizer_layer = nn.Sequential(
                ResNetShim(resnet),
                nn.Conv2d(512, token_dim, 1),
            )
        else:
            raise NotImplementedError(tokenizer)

        # key = sanitize_key(str((s, c)))
        key = f'{s}:{c}'
        self.sensor_channel_tokenizers[key] = nn.Sequential(
            input_norm,
            NanToNum(0.0),
            tokenizer_layer,
        )

    self.position_encoder = position_encoder
    # self.position_encoder = RandomFourierPositionalEncoder(3, 16)
    # position_dim = self.position_encoder.output_dim

    self.pre_backbone = pre_backbone
    self.backbone = backbone
    self.post_backbone = post_backbone
    # self.backbone = TransformerEncoderDecoder(
    #     encoder_depth=backbone_encoder_depth,
    #     decoder_depth=backbone_decoder_depth,
    #     dim=token_dim + position_dim,
    #     queries_dim=position_dim,
    #     logits_dim=token_dim,
    #     cross_heads=backbone_cross_heads,
    #     latent_heads=backbone_latent_heads,
    #     cross_dim_head=backbone_cross_dim_head,
    #     latent_dim_head=backbone_latent_dim_head,
    #     weight_tie_layers=backbone_weight_tie_layers,
    # )

    self.criterions = torch.nn.ModuleDict()
    self.heads = torch.nn.ModuleDict()

    self.task_to_keynames = {
        'change': {
            'labels': 'change',
            'weights': 'change_weights',
            'output_dims': 'change_output_dims'
        },
        'saliency': {
            'labels': 'saliency',
            'weights': 'saliency_weights',
            'output_dims': 'saliency_output_dims'
        },
        'class': {
            'labels': 'class_idxs',
            'weights': 'class_weights',
            'output_dims': 'class_output_dims'
        },
    }

    head_properties = [
        {
            'name': 'change',
            'channels': 2,
            'loss': self.hparams.change_loss,
            'weights': self.change_weights,
        },
        {
            'name': 'saliency',
            'channels': self.saliency_num_classes,
            'loss': self.hparams.saliency_loss,
            'weights': self.saliency_weights,
        },
        {
            'name': 'class',
            'channels': self.num_classes,
            'loss': self.hparams.class_loss,
            'weights': self.class_weights,
        },
    ]

    self.global_head_weights = {
        'class': global_class_weight,
        'change': global_change_weight,
        'saliency': global_saliency_weight,
    }

    self.magic_padding_value = -99999999.0  # Magic placeholder value

    for prop in head_properties:
        head_name = prop['name']
        global_weight = self.global_head_weights[head_name]
        # Heads with a non-positive global weight are not built at all.
        if global_weight > 0:
            self.criterions[head_name] = coerce_criterion(prop['loss'],
                                                          prop['weights'],
                                                          ohem_ratio=ohem_ratio,
                                                          focal_gamma=focal_gamma)
            if self.hparams.decoder == "upsample":
                self.heads[head_name] = nn.Sequential(
                    nn.Upsample(scale_factor=(token_width, token_width), mode="bilinear"),
                    nn.Conv2d(
                        token_dim,
                        prop['channels'],
                        kernel_size=5,
                        padding="same",
                        bias=False),
                )
            elif self.hparams.decoder == "trans_conv":
                self.heads[head_name] = nn.Sequential(
                    # ShapePreservingTransformerEncoder(
                    #     nn.TransformerEncoderLayer(token_dim, 8, dim_feedforward=512, dropout=0.1, activation="gelu", batch_first=True, norm_first=True),
                    #     num_layers=2,
                    # ),
                    ShapePreservingTransformerEncoder(
                        token_dim,
                        num_layers=2,
                        batch_dim=0,
                        chan_dim=1,
                    ),
                    nn.Conv2d(token_dim, token_width * token_width * prop['channels'], 1, bias=False),
                    # Unfold each token's channels into a token_width x token_width patch.
                    Rearrange(
                        "batch (chan dh dw) height width -> batch chan (height dh) (width dw)",
                        dh=token_width, dw=token_width),
                )
            elif self.hparams.decoder == "simple_conv":
                self.heads[head_name] = nn.Sequential(
                    nn.Conv2d(token_dim, token_width * token_width * prop['channels'], 1, bias=False),
                    # Unfold each token's channels into a token_width x token_width patch.
                    Rearrange(
                        "batch (chan dh dw) height width -> batch chan (height dh) (width dw)",
                        dh=token_width, dw=token_width),
                )
            else:
                raise NotImplementedError(decoder)

    FBetaScore = torchmetrics.FBetaScore
    class_metrics = torchmetrics.MetricCollection({
        "class_acc": torchmetrics.Accuracy(num_classes=self.num_classes, task='multiclass'),
        # "class_iou": torchmetrics.IoU(2),
        'class_f1_micro': FBetaScore(beta=1.0, threshold=0.5, average='micro', num_classes=self.num_classes, task='multiclass'),
        'class_f1_macro': FBetaScore(beta=1.0, threshold=0.5, average='macro', num_classes=self.num_classes, task='multiclass'),
    })
    change_metrics = torchmetrics.MetricCollection({
        "change_acc": torchmetrics.Accuracy(task="binary"),
        # "iou": torchmetrics.IoU(2),
        'change_f1': FBetaScore(beta=1.0, task="binary"),
    })
    saliency_metrics = torchmetrics.MetricCollection({
        'saliency_f1': FBetaScore(beta=1.0, task="binary"),
    })

    # One independent copy of every metric collection per stage.
    self.head_metrics = nn.ModuleDict({
        f"{stage}_stage": nn.ModuleDict({
            "class": class_metrics.clone(prefix=f"{stage}_"),
            "change": change_metrics.clone(prefix=f"{stage}_"),
            "saliency": saliency_metrics.clone(prefix=f"{stage}_"),
        })
        for stage in ["train", "val", "test"]
    })

    self._prev_batch_size = None
[docs]
def process_query_tokens(self, example):
    """
    Build the positional-encoding query tokens for every frame of an example.

    For each frame, the output grid given by ``frame["output_dims"]`` is split
    into ``token_width`` x ``token_width`` cells, with partial cells rounded
    up. Each cell gets a 3-channel position: two zero-centered spatial
    coordinates (from ``positions_from_shape``) and the frame's
    ``time_index``.

    The matching scales are:

    * ``spatial_scale_base / n`` for each spatial axis, where ``n`` is the
      number of tokens along that axis;
    * the constant ``temporal_scale_base`` for time.

    Positions and scales are then passed through ``self.position_encoder``.

    Args:
        example (dict): must contain ``"frames"``; each frame needs
            ``"output_dims"`` (height, width) and ``"time_index"``.

    Returns:
        List[torch.Tensor]: one encoded tensor per frame. With the encoders
        defined in this module, each has shape
        ``(position_encoder.output_dim, tokens_high, tokens_wide)``.

    Example:
        >>> from geowatch.tasks import fusion
        >>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
        >>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
        >>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
        >>> backbone = TransformerEncoderDecoder(
        >>>     encoder_depth=1,
        >>>     decoder_depth=1,
        >>>     dim=position_encoder.output_dim + 16,
        >>>     queries_dim=position_encoder.output_dim,
        >>>     logits_dim=16,
        >>>     cross_heads=1,
        >>>     latent_heads=1,
        >>>     cross_dim_head=1,
        >>>     latent_dim_head=1,
        >>> )
        >>> model = fusion.methods.HeterogeneousModel(
        >>>     classes=classes,
        >>>     dataset_stats=dataset_stats,
        >>>     input_sensorchan=channels,
        >>>     position_encoder=position_encoder,
        >>>     backbone=backbone,
        >>> )
        >>> example = model.demo_batch(width=64, height=65)[0]
        >>> query_tokens = model.process_query_tokens(example)
        >>> assert len(query_tokens) == len(example["frames"])
    """
    per_frame_queries = []
    for frame in example["frames"]:
        tw = self.hparams.token_width
        reference = self.position_encoder.weight

        # ---- space ----
        out_h, out_w = frame["output_dims"]
        # Pad each axis up to a multiple of the token width, then count tokens.
        grid_h = (out_h + to_next_multiple(out_h, tw)) // tw
        grid_w = (out_w + to_next_multiple(out_w, tw)) // tw
        spatial_pos = positions_from_shape(
            (grid_h, grid_w),
            dtype=reference.dtype,
            device=reference.device,
        )
        like = dict(dtype=spatial_pos.dtype, device=spatial_pos.device)
        axis_counts = torch.tensor(spatial_pos.shape[1:], **like)
        spatial_scale = einops.repeat(
            self.hparams.spatial_scale_base / axis_counts,
            "chan -> chan height width",
            height=grid_h, width=grid_w,
        )

        # ---- time ----
        time_pos = frame["time_index"] * torch.ones_like(spatial_pos[0], **like)[None]
        time_scale = self.hparams.temporal_scale_base * torch.ones(1, grid_h, grid_w, **like)

        # ---- combine: channel order is (spatial..., time) ----
        positions = torch.concat([spatial_pos, time_pos])
        scales = torch.concat([spatial_scale, time_scale])
        encoded = self.position_encoder(positions[None], scales[None])[0]
        per_frame_queries.append(encoded)
    return per_frame_queries
[docs]
def forward(self, batch):
"""
Example:
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>> encoder_depth=1,
>>> decoder_depth=1,
>>> dim=position_encoder.output_dim + 16,
>>> queries_dim=position_encoder.output_dim,
>>> logits_dim=16,
>>> cross_heads=1,
>>> latent_heads=1,
>>> cross_dim_head=1,
>>> latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>> classes=classes,
>>> dataset_stats=dataset_stats,
>>> input_sensorchan=channels,
>>> backbone=backbone,
>>> position_encoder=position_encoder,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>> if "probs" in task_key: continue
>>> if task_key == "class": task_key = "class_idxs"
>>> for task_pred, example in zip(task_outputs, batch):
>>> for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>> if (frame_idx == 0) and task_key.startswith("change"): continue
>>> assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example:
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>> encoder_depth=1,
>>> decoder_depth=1,
>>> dim=position_encoder.output_dim + 16,
>>> queries_dim=position_encoder.output_dim,
>>> logits_dim=16,
>>> cross_heads=1,
>>> latent_heads=1,
>>> cross_dim_head=1,
>>> latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>> classes=classes,
>>> dataset_stats=dataset_stats,
>>> input_sensorchan=channels,
>>> position_encoder=position_encoder,
>>> backbone=backbone,
>>> decoder="simple_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>> if "probs" in task_key: continue
>>> if task_key == "class": task_key = "class_idxs"
>>> for task_pred, example in zip(task_outputs, batch):
>>> for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>> if (frame_idx == 0) and task_key.startswith("change"): continue
>>> assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example:
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>> encoder_depth=1,
>>> decoder_depth=1,
>>> dim=position_encoder.output_dim + 16,
>>> queries_dim=position_encoder.output_dim,
>>> logits_dim=16,
>>> cross_heads=1,
>>> latent_heads=1,
>>> cross_dim_head=1,
>>> latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>> classes=classes,
>>> dataset_stats=dataset_stats,
>>> input_sensorchan=channels,
>>> position_encoder=position_encoder,
>>> backbone=backbone,
>>> decoder="trans_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>> if "probs" in task_key: continue
>>> if task_key == "class": task_key = "class_idxs"
>>> for task_pred, example in zip(task_outputs, batch):
>>> for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>> if (frame_idx == 0) and task_key.startswith("change"): continue
>>> assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example:
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>> encoder_depth=1,
>>> decoder_depth=0,
>>> dim=position_encoder.output_dim + 16,
>>> queries_dim=0,
>>> logits_dim=16,
>>> cross_heads=1,
>>> latent_heads=1,
>>> cross_dim_head=1,
>>> latent_dim_head=1,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>> classes=classes,
>>> dataset_stats=dataset_stats,
>>> input_sensorchan=channels,
>>> position_encoder=position_encoder,
>>> backbone=backbone,
>>> decoder="trans_conv",
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>> if "probs" in task_key: continue
>>> if task_key == "class": task_key = "class_idxs"
>>> for task_pred, example in zip(task_outputs, batch):
>>> for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>> if (frame_idx == 0) and task_key.startswith("change"): continue
>>> assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example:
>>> # xdoctest: +REQUIRES(module:mmseg)
>>> from geowatch.tasks import fusion
>>> from geowatch.tasks.fusion.architectures.transformer import MM_VITEncoderDecoder
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = MM_VITEncoderDecoder(
>>> dim=position_encoder.output_dim + 16,
>>> queries_dim=position_encoder.output_dim,
>>> logits_dim=16,
>>> )
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> model = fusion.methods.HeterogeneousModel(
>>> classes=classes,
>>> dataset_stats=dataset_stats,
>>> input_sensorchan=channels,
>>> backbone=backbone,
>>> position_encoder=position_encoder,
>>> )
>>> batch = model.demo_batch(width=64, height=65)
>>> batch += model.demo_batch(width=55, height=75)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>> if "probs" in task_key: continue
>>> if task_key == "class": task_key = "class_idxs"
>>> for task_pred, example in zip(task_outputs, batch):
>>> for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>> if (frame_idx == 0) and task_key.startswith("change"): continue
>>> assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Example:
>>> # xdoctest: +REQUIRES(module:mmseg)
>>> from geowatch.tasks import fusion
>>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
>>> self = fusion.methods.HeterogeneousModel(
>>> classes=classes,
>>> dataset_stats=dataset_stats,
>>> input_sensorchan=channels,
>>> #token_dim=708,
>>> token_dim=768 - 60,
>>> backbone='vit_B_16_imagenet1k',
>>> position_encoder=position_encoder,
>>> )
>>> batch = self.demo_batch(width=64, height=65)
>>> batch += self.demo_batch(width=55, height=75)
>>> outputs = self.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>> if "probs" in task_key: continue
>>> if task_key == "class": task_key = "class_idxs"
>>> for task_pred, example in zip(task_outputs, batch):
>>> for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>> if (frame_idx == 0) and task_key.startswith("change"): continue
>>> assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
Ignore:
from geowatch.tasks import fusion
from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
position_encoder = geowatch.tasks.fusion.methods.heterogeneous.MipNerfPositionalEncoder(in_dims=3, max_freq=3, num_freqs=16)
token_dim = 256
backbone = TransformerEncoderDecoder(
encoder_depth=6,
decoder_depth=0,
dim=position_encoder.output_dim + token_dim,
queries_dim=position_encoder.output_dim,
logits_dim=token_dim,
latent_dim_head=1024,
)
channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
model = fusion.methods.HeterogeneousModel(
token_dim=token_dim,
token_width=8,
classes=classes,
dataset_stats=dataset_stats,
input_sensorchan=channels,
position_encoder=position_encoder,
backbone=backbone,
spatial_scale_base=1,
global_change_weight=0,
global_class_weight=0,
global_saliency_weight=1,
decoder="simple_conv",
)
batch = model.demo_batch(width=64, height=65)
batch += model.demo_batch(width=55, height=75)
outputs = model.forward(batch)
"""
# ==================
# Compute input sequences and shapes
# ==================
# Lists to stash sequences and shapes
orig_input_shapes = []
orig_input_seqs = []
for example in batch:
# Each example, containing potentially more than one mode,
# is stemmed and then we save its original shape
input_tokens = self.process_input_tokens(example)
input_shapes = [
[
mode_tokens.shape[1:]
for mode_tokens in frame_tokens
]
for frame_tokens in input_tokens
]
orig_input_shapes.append(input_shapes)
# For the downstream transformer, we flatten and concatenate the
# stemmed tokens
input_token_seq = torch.concat([
torch.concat([
einops.rearrange(mode_tokens, "chan ... -> (...) chan")
for mode_tokens in frame_tokens
])
for frame_tokens in input_tokens
])
orig_input_seqs.append(input_token_seq)
if len(orig_input_seqs) == 0:
print(f'batch={batch}')
print('Skipping batch')
return None
self._prev_batch_size = len(orig_input_seqs)
# Each example may have a different number of tokens, so we perform
# some padding and compute a mask of where those padded tokens are
input_seqs = nn.utils.rnn.pad_sequence(
orig_input_seqs,
batch_first=True,
padding_value=self.magic_padding_value,
)
# Remove the placeholder
input_masks = input_seqs[..., 0] > self.magic_padding_value
input_seqs[~input_masks] = 0.
# ==================
# Compute query sequences and shapes
# (Should be similar/identical to the input proceedure)
# BUT only if we need to
# ==================
B, S, D = input_seqs.shape
if self.pre_backbone is not None:
# Fixup dims for the backbone
input_seqs = self.pre_backbone(input_seqs.view(-1, D)).view(B, S, -1)
has_decoder = getattr(self.backbone, 'has_decoder', False)
self.backbone.has_decoder = has_decoder
if has_decoder:
# Lists to stash sequences and shapes
orig_query_shapes = []
orig_query_seqs = []
for example in batch:
# Each example, containing potentially more than one task,
# a map of query position tokens are computed and then we save
# their original shape
query_tokens = self.process_query_tokens(example)
query_shapes = [
frame_tokens.shape[1:]
for frame_tokens in query_tokens
]
orig_query_shapes.append(query_shapes)
# For the downstream transformer, we flatten and concatenate
# the position embeddings
query_token_seq = torch.concat([
einops.rearrange(frame_tokens, "chan ... -> (...) chan")
for frame_tokens in query_tokens
])
orig_query_seqs.append(query_token_seq)
# Each example may have a different number of queries, so we perform
# some padding and compute a mask of where those padded tokens are
query_seqs = nn.utils.rnn.pad_sequence(
orig_query_seqs,
batch_first=True,
padding_value=self.magic_padding_value,
)
# Remove the placeholder
query_masks = query_seqs[..., 0] > self.magic_padding_value
query_seqs[~query_masks] = 0.
# ==================
# Forward pass!
# ==================
if self.backbone.has_decoder:
output_seqs = self.backbone(
input_seqs,
mask=input_masks,
queries=query_seqs,
)
output_shapes = orig_query_shapes
output_masks = query_masks
else:
# batch_first = getattr(self.backbone, 'batch_first', True)
is_sits_bert = getattr(self.backbone, 'is_sits_bert', False)
if is_sits_bert:
# Special case for pretrained BERT
# TODO: wrap the model to conform to the API here instead
# of directly hacking this function.
_input_seqs = input_seqs.transpose(0, 1)
# _input_masks = input_masks.transpose(0, 1)
_output_seqs = self.backbone(
_input_seqs,
src_key_padding_mask=~input_masks,
)
output_seqs = _output_seqs.transpose(0, 1)
output_shapes = orig_input_shapes
output_masks = input_masks
else:
# Normal case.
output_seqs = self.backbone(
input_seqs,
mask=input_masks,
)
output_shapes = orig_input_shapes
output_masks = input_masks
# Uncomment if old sits models need repackaing
HACK_TOKEN_DIMS = self.post_backbone is None
if HACK_TOKEN_DIMS:
# hack for VIT. Drops feature dims to allow for running
if output_seqs.shape[2] != self.hparams.token_dim:
output_seqs = output_seqs[:, :, 0:self.hparams.token_dim]
if self.post_backbone is not None:
# Fixup dims for the backbone
output_seqs = self.post_backbone(output_seqs.view(-1, output_seqs.shape[-1]))
output_seqs = output_seqs.view(B, S, -1)
# ==================
# Decompose outputs into the appropriate output shape
# ==================
# The container for all of our outputs
outputs = dict()
for task_name, task_head in self.heads.items():
task_outputs = []
task_probs = []
for output_seq, query_mask, frame_shapes, example in zip(output_seqs, output_masks, output_shapes, batch):
output_seq = output_seq[query_mask] # only want valid values we actually requested
seq_outputs = []
seq_probs = []
frame_sizes = [
np.reshape(pos_shape_seq, [-1, 2]).prod(axis=1).sum()
for pos_shape_seq in frame_shapes
]
output_frame_seqs = torch.split(output_seq, frame_sizes)
for output_frame_seq, frame_shape, frame in zip(output_frame_seqs, frame_shapes, example["frames"]):
if self.backbone.has_decoder:
# Rearrange token subsequence into image shaped tensor
height, width = frame_shape
output = einops.rearrange(
output_frame_seq,
"(height width) chan -> chan height width",
height=height, width=width,
)
else:
max_mode_size = (
max([h for h, _ in frame_shape]),
max([w for _, w in frame_shape]),
)
mode_sizes = [h * w for h, w in frame_shape]
output_mode_seqs = torch.split(output_frame_seq, mode_sizes)
output = torch.mean(torch.concat([
torch.nn.functional.interpolate(
einops.rearrange(
mode_seq,
"(height width) chan -> 1 chan height width",
height=mode_height, width=mode_width,
),
size=max_mode_size,
mode='bilinear',
align_corners=True,
)
for mode_seq, (mode_height, mode_width) in zip(output_mode_seqs, frame_shape)
], dim=0), dim=0)
# # If we might need to upsample our predictions
# output = nn.functional.upsample_bilinear(output[None], size=[tar_height, tar_width])[0]
# Compute task preds
output = task_head(output[None])[0]
# Clip to desired shape
tar_height, tar_width = frame["output_dims"]
output = output[:, :tar_height, :tar_width]
probs = einops.rearrange(output, "chan height width -> height width chan")
if task_name == "change":
probs = probs.sigmoid()[..., 1]
else:
probs = probs.softmax(dim=-1)
seq_outputs.append(output)
seq_probs.append(probs)
task_outputs.append(seq_outputs)
task_probs.append(seq_probs)
outputs[task_name] = task_outputs
outputs[f"{task_name}_probs"] = task_probs
return outputs
[docs]
def shared_step(self, batch, batch_idx=None, stage="train", with_loss=True):
    """
    Run the forward pass on a batch and, optionally, compute and log the loss.

    Args:
        batch (List[Dict | None]): list of examples; ``None`` entries are
            dropped before the forward pass.
        batch_idx (int | None): not used by this method; accepted so the
            signature matches the Lightning step hooks that call it.
        stage (str): prefix used for logged metric names
            (e.g. "train", "val", "test").
        with_loss (bool): if False, the forward outputs are returned without
            computing or logging any loss.

    Returns:
        Dict | None: the outputs of ``self.forward`` (per-task predictions and
        ``{task}_probs`` entries), plus a ``"loss"`` entry when ``with_loss``
        is True. Returns None when the forward pass returns None (no usable
        examples), or when ``with_loss`` is True but no frame in the batch has
        labels for any task head, so there is nothing to average a loss over.

    Example:
        >>> from geowatch.tasks import fusion
        >>> import torch
        >>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
        >>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
        >>> backbone = TransformerEncoderDecoder(
        >>>     encoder_depth=1,
        >>>     decoder_depth=1,
        >>>     dim=position_encoder.output_dim + 16,
        >>>     queries_dim=position_encoder.output_dim,
        >>>     logits_dim=16,
        >>>     cross_heads=1,
        >>>     latent_heads=1,
        >>>     cross_dim_head=1,
        >>>     latent_dim_head=1,
        >>> )
        >>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
        >>> model = fusion.methods.HeterogeneousModel(
        >>>     classes=classes,
        >>>     dataset_stats=dataset_stats,
        >>>     input_sensorchan=channels,
        >>>     position_encoder=position_encoder,
        >>>     decoder="trans_conv",
        >>>     backbone=backbone,
        >>> )
        >>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
        >>> outputs = model.shared_step(batch)
        >>> optimizer = torch.optim.Adam(model.parameters())
        >>> optimizer.zero_grad()
        >>> loss = outputs["loss"]
        >>> loss.backward()
        >>> optimizer.step()

    Example:
        >>> from geowatch.tasks import fusion
        >>> import torch
        >>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
        >>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
        >>> backbone = TransformerEncoderDecoder(
        >>>     encoder_depth=1,
        >>>     decoder_depth=1,
        >>>     dim=position_encoder.output_dim + 16,
        >>>     queries_dim=position_encoder.output_dim,
        >>>     logits_dim=16,
        >>>     cross_heads=1,
        >>>     latent_heads=1,
        >>>     cross_dim_head=1,
        >>>     latent_dim_head=1,
        >>> )
        >>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
        >>> model = fusion.methods.HeterogeneousModel(
        >>>     classes=classes,
        >>>     dataset_stats=dataset_stats,
        >>>     input_sensorchan=channels,
        >>>     position_encoder=position_encoder,
        >>>     decoder="trans_conv",
        >>>     backbone=backbone,
        >>> )
        >>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
        >>> batch += [None]
        >>> outputs = model.shared_step(batch)
        >>> optimizer = torch.optim.Adam(model.parameters())
        >>> optimizer.zero_grad()
        >>> loss = outputs["loss"]
        >>> loss.backward()
        >>> optimizer.step()

    Example:
        >>> from geowatch.tasks import fusion
        >>> import torch
        >>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
        >>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
        >>> backbone = TransformerEncoderDecoder(
        >>>     encoder_depth=1,
        >>>     decoder_depth=1,
        >>>     dim=position_encoder.output_dim + 16,
        >>>     queries_dim=position_encoder.output_dim,
        >>>     logits_dim=16,
        >>>     cross_heads=1,
        >>>     latent_heads=1,
        >>>     cross_dim_head=1,
        >>>     latent_dim_head=1,
        >>> )
        >>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
        >>> model = fusion.methods.HeterogeneousModel(
        >>>     classes=classes,
        >>>     dataset_stats=dataset_stats,
        >>>     input_sensorchan=channels,
        >>>     position_encoder=position_encoder,
        >>>     decoder="trans_conv",
        >>>     backbone=backbone,
        >>> )
        >>> batch = model.demo_batch(width=64, height=65)
        >>> for cutoff in [-1, -2]:
        >>>     degraded_example = model.demo_batch(width=55, height=75, num_timesteps=3)[0]
        >>>     degraded_example["frames"] = degraded_example["frames"][:cutoff]
        >>>     batch += [degraded_example]
        >>> outputs = model.shared_step(batch)
        >>> optimizer = torch.optim.Adam(model.parameters())
        >>> optimizer.zero_grad()
        >>> loss = outputs["loss"]
        >>> loss.backward()
        >>> optimizer.step()

    Example:
        >>> from geowatch.tasks import fusion
        >>> import torch
        >>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
        >>> position_encoder = fusion.methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
        >>> backbone = TransformerEncoderDecoder(
        >>>     encoder_depth=1,
        >>>     decoder_depth=1,
        >>>     dim=position_encoder.output_dim + 16,
        >>>     queries_dim=position_encoder.output_dim,
        >>>     logits_dim=16,
        >>>     cross_heads=1,
        >>>     latent_heads=1,
        >>>     cross_dim_head=1,
        >>>     latent_dim_head=1,
        >>> )
        >>> channels, classes, dataset_stats = fusion.methods.HeterogeneousModel.demo_dataset_stats()
        >>> model = fusion.methods.HeterogeneousModel(
        >>>     classes=classes,
        >>>     dataset_stats=dataset_stats,
        >>>     input_sensorchan=channels,
        >>>     position_encoder=position_encoder,
        >>>     decoder="trans_conv",
        >>>     backbone=backbone,
        >>> )
        >>> batch = model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.1)
        >>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=0.5)
        >>> batch += model.demo_batch(batch_size=1, width=64, height=65, num_timesteps=3, nans=1.0)
        >>> outputs = model.shared_step(batch)
        >>> optimizer = torch.optim.Adam(model.parameters())
        >>> optimizer.zero_grad()
        >>> loss = outputs["loss"]
        >>> loss.backward()
        >>> optimizer.step()
    """
    # FIXME: why are we getting nones here?
    batch = [
        ex
        for ex in batch
        if (ex is not None)
        # and (len(ex["frames"]) > 0)
    ]
    batch_size = len(batch)
    outputs = self(batch)
    if outputs is None:
        return None
    if not with_loss:
        return outputs

    # One scalar loss per (task, labeled frame); averaged at the end.
    frame_losses = []
    for task_name in self.heads:
        for pred_seq, example in zip(outputs[task_name], batch):
            for pred, frame in zip(pred_seq, example["frames"]):
                task_labels_key = self.task_to_keynames[task_name]["labels"]
                labels = frame[task_labels_key]
                self.log(f"{stage}_{task_name}_logit_mean", pred.mean(),
                         batch_size=batch_size, rank_zero_only=True)
                if labels is None:
                    continue

                # FIXME: This is necessary because sometimes when data.input_space_scale==native, label shapes and output_dims dont match!
                if pred.shape[1:] != labels.shape:
                    pred = nn.functional.interpolate(
                        pred[None],
                        size=labels.shape,
                        mode="bilinear",
                    )[0]

                task_weights_key = self.task_to_keynames[task_name]["weights"]
                task_weights = frame[task_weights_key]
                # Only positions with a positive weight contribute to the loss.
                valid_mask = (task_weights > 0.)
                pred_ = pred[:, valid_mask]
                task_weights_ = task_weights[valid_mask]

                criterion = self.criterions[task_name]
                if criterion.target_encoding == 'index':
                    loss_labels = labels.long()
                    loss_labels_ = loss_labels[valid_mask]
                elif criterion.target_encoding == 'onehot':
                    # Note: 1HE is much easier to work with
                    labels_ = labels.long()
                    has_ignore = labels_.min() < 0
                    if has_ignore:
                        # Negative labels are routed to one extra class
                        # channel, which is sliced off after encoding.
                        ignore_flags = labels_ < 0
                        ohe_size = criterion.in_channels + 1
                        labels_ = labels_.clone()
                        labels_[ignore_flags] = criterion.in_channels
                    else:
                        ohe_size = criterion.in_channels
                    loss_labels = kwarray.one_hot_embedding(
                        labels_,
                        ohe_size,
                        dim=0)
                    loss_labels_ = loss_labels[:, valid_mask]
                    if has_ignore:
                        # FIXME: inefficient to just drop these, but it
                        # should work.
                        # could improve kwarray.one_hot_embedding to
                        # allow the user to specify an ignore_index
                        loss_labels_ = loss_labels_[:-1, ...]
                else:
                    raise KeyError(criterion.target_encoding)

                loss = criterion(
                    pred_[None],
                    loss_labels_[None],
                )
                if loss.isnan().any():
                    print('!!!!!!!!!!!!!!!!!!!')
                    print('!!!!!!!!!!!!!!!!!!!')
                    print('Discovered NaN loss')
                    print('loss = {}'.format(ub.urepr(loss, nl=1)))
                    print('pred = {}'.format(ub.urepr(pred, nl=1)))
                    print('frame = {}'.format(ub.urepr(frame, nl=1)))
                    print('!!!!!!!!!!!!!!!!!!!')
                    print('!!!!!!!!!!!!!!!!!!!')
                # Weight out-of-place: mutating the criterion's output in
                # place can break autograd if that tensor was saved for the
                # backward pass.
                loss = loss * task_weights_
                frame_losses.append(
                    self.global_head_weights[task_name] * loss.mean()
                )

                LOG_METRICS = 0  # FIXME: recent (30c8974d6d6) update broke this, why?
                if LOG_METRICS:
                    metric_values = self.head_metrics[f"{stage}_stage"][task_name](
                        pred.argmax(dim=0).flatten(),
                        # pred[None],
                        labels.flatten().long(),
                    )
                    self.log_dict(
                        metric_values,
                        prog_bar=True,
                        batch_size=batch_size,
                    )

    if not frame_losses:
        # No head had labels for any frame (or there are no heads), so there
        # is nothing to average; avoid dividing by zero and skip the batch,
        # mirroring the ``outputs is None`` handling above.
        print('No labeled frames in batch; skipping loss')
        return None

    outputs["loss"] = sum(frame_losses) / len(frame_losses)
    self.log(f"{stage}_loss", outputs["loss"], prog_bar=True, batch_size=batch_size)
    return outputs
# def shared_step(self, batch, batch_idx=None, with_loss=True):
# outputs = {
# "change_probs": [
# [
# 0.5 * torch.ones(*frame["output_dims"])
# for frame in example["frames"]
# if frame["change"] != None
# ]
# for example in batch
# ],
# "saliency_probs": [
# [
# torch.ones(*frame["output_dims"], 2).sigmoid()
# for frame in example["frames"]
# ]
# for example in batch
# ],
# "class_probs": [
# [
# torch.ones(*frame["output_dims"], self.num_classes).softmax(dim=-1)
# for frame in example["frames"]
# ]
# for example in batch
# ],
# }
# if with_loss:
# outputs["loss"] = self.dummy_param
# return outputs
[docs]
@profile
def training_step(self, batch, batch_idx=None):
    """
    Lightning training hook.

    Delegates to ``shared_step`` with ``stage='train'`` and returns its
    result unchanged.
    """
    return self.shared_step(batch, batch_idx=batch_idx, stage='train')
[docs]
@profile
def validation_step(self, batch, batch_idx=None):
    """
    Lightning validation hook.

    Delegates to ``shared_step`` with ``stage='val'`` and returns its
    result unchanged.
    """
    return self.shared_step(batch, batch_idx=batch_idx, stage='val')
[docs]
@profile
def test_step(self, batch, batch_idx=None):
    """
    Lightning test hook.

    Delegates to ``shared_step`` with ``stage='test'`` and returns its
    result unchanged.
    """
    return self.shared_step(batch, batch_idx=batch_idx, stage='test')
[docs]
@profile
def predict_step(self, batch, batch_idx=None):
    """
    Lightning prediction hook.

    Delegates to ``shared_step`` with ``stage='predict'`` and
    ``with_loss=False``, so only the forward outputs are produced.
    """
    return self.shared_step(
        batch,
        batch_idx=batch_idx,
        stage='predict',
        with_loss=False,
    )
# Expose ``shared_step`` under the name ``forward_step`` as well; per the
# original note this is a special hook for the predict step.
# NOTE(review): the code that looks up ``forward_step`` is not visible here
# (presumably an external prediction script) - confirm before renaming.
forward_step = shared_step
[docs]
def log_grad_norm(self, grad_norm_dict) -> None:
    """
    Log gradient norms, overriding the default ``log_grad_norm`` hook.

    The override exists so that an explicit ``batch_size`` is passed to
    ``log_dict``, which suppresses the warning about the batch size having
    to be inferred. The value used is the batch size recorded in
    ``self._prev_batch_size`` by the most recent forward pass.

    Args:
        grad_norm_dict (dict): mapping from metric name to gradient-norm value.
    """
    log_kwargs = dict(
        on_step=True,
        on_epoch=True,
        prog_bar=False,
        logger=True,
        batch_size=self._prev_batch_size,
    )
    self.log_dict(grad_norm_dict, **log_kwargs)
[docs]
def save_package(self, package_path, context=None, verbose=1):
    """
    Save this model as a package by forwarding to ``self._save_package``.

    Args:
        package_path (str | PathLike): destination path of the package file.
        context (Any | None): passed through unchanged to ``_save_package``.
        verbose (int): verbosity level, passed through unchanged to
            ``_save_package``.

    CommandLine:
        xdoctest -m geowatch.tasks.fusion.methods.heterogeneous HeterogeneousModel.save_package

    Example:
        >>> # Test without datamodule
        >>> import ubelt as ub
        >>> from os.path import join
        >>> #from geowatch.tasks.fusion.methods.heterogeneous import * # NOQA
        >>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
        >>> package_path = join(dpath, 'my_package.pt')
        >>> # Use one of our fusion.architectures in a test
        >>> from geowatch.tasks.fusion import methods
        >>> from geowatch.tasks.fusion import datamodules
        >>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
        >>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
        >>> backbone = TransformerEncoderDecoder(
        >>>     encoder_depth=1,
        >>>     decoder_depth=1,
        >>>     dim=position_encoder.output_dim + 16,
        >>>     queries_dim=position_encoder.output_dim,
        >>>     logits_dim=16,
        >>>     cross_heads=1,
        >>>     latent_heads=1,
        >>>     cross_dim_head=1,
        >>>     latent_dim_head=1,
        >>> )
        >>> model = self = methods.HeterogeneousModel(
        >>>     position_encoder=position_encoder,
        >>>     input_sensorchan=5,
        >>>     decoder="upsample",
        >>>     backbone=backbone,
        >>> )
        >>> # Save the model (TODO: need to save datamodule as well)
        >>> model.save_package(package_path)
        >>> # Test that the package can be reloaded
        >>> #recon = methods.HeterogeneousModel.load_package(package_path)
        >>> from geowatch.tasks.fusion.utils import load_model_from_package
        >>> recon = load_model_from_package(package_path)
        >>> # Same keys, equal values, but distinct tensor objects
        >>> recon_state = recon.state_dict()
        >>> model_state = model.state_dict()
        >>> assert recon is not model
        >>> assert set(recon_state) == set(model_state)
        >>> for key in recon_state.keys():
        >>>     assert (model_state[key] == recon_state[key]).all()
        >>>     assert model_state[key] is not recon_state[key]

    Example:
        >>> # Test without datamodule
        >>> import ubelt as ub
        >>> from os.path import join
        >>> #from geowatch.tasks.fusion.methods.heterogeneous import * # NOQA
        >>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
        >>> package_path = join(dpath, 'my_package.pt')
        >>> # Use one of our fusion.architectures in a test
        >>> from geowatch.tasks.fusion import methods
        >>> from geowatch.tasks.fusion import datamodules
        >>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
        >>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
        >>> backbone = TransformerEncoderDecoder(
        >>>     encoder_depth=1,
        >>>     decoder_depth=1,
        >>>     dim=position_encoder.output_dim + 16,
        >>>     queries_dim=position_encoder.output_dim,
        >>>     logits_dim=16,
        >>>     cross_heads=1,
        >>>     latent_heads=1,
        >>>     cross_dim_head=1,
        >>>     latent_dim_head=1,
        >>> )
        >>> model = self = methods.HeterogeneousModel(
        >>>     position_encoder=position_encoder,
        >>>     input_sensorchan=5,
        >>>     decoder="simple_conv",
        >>>     backbone=backbone,
        >>> )
        >>> # Save the model (TODO: need to save datamodule as well)
        >>> model.save_package(package_path)
        >>> # Test that the package can be reloaded
        >>> #recon = methods.HeterogeneousModel.load_package(package_path)
        >>> from geowatch.tasks.fusion.utils import load_model_from_package
        >>> recon = load_model_from_package(package_path)
        >>> # Same keys, equal values, but distinct tensor objects
        >>> recon_state = recon.state_dict()
        >>> model_state = model.state_dict()
        >>> assert recon is not model
        >>> assert set(recon_state) == set(model_state)
        >>> for key in recon_state.keys():
        >>>     assert (model_state[key] == recon_state[key]).all()
        >>>     assert model_state[key] is not recon_state[key]

    Example:
        >>> # Test without datamodule
        >>> import ubelt as ub
        >>> from os.path import join
        >>> #from geowatch.tasks.fusion.methods.heterogeneous import * # NOQA
        >>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
        >>> package_path = join(dpath, 'my_package.pt')
        >>> # Use one of our fusion.architectures in a test
        >>> from geowatch.tasks.fusion import methods
        >>> from geowatch.tasks.fusion import datamodules
        >>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
        >>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
        >>> backbone = TransformerEncoderDecoder(
        >>>     encoder_depth=1,
        >>>     decoder_depth=1,
        >>>     dim=position_encoder.output_dim + 16,
        >>>     queries_dim=position_encoder.output_dim,
        >>>     logits_dim=16,
        >>>     cross_heads=1,
        >>>     latent_heads=1,
        >>>     cross_dim_head=1,
        >>>     latent_dim_head=1,
        >>> )
        >>> model = self = methods.HeterogeneousModel(
        >>>     position_encoder=position_encoder,
        >>>     input_sensorchan=5,
        >>>     decoder="trans_conv",
        >>>     backbone=backbone,
        >>> )
        >>> # Save the model (TODO: need to save datamodule as well)
        >>> model.save_package(package_path)
        >>> # Test that the package can be reloaded
        >>> #recon = methods.HeterogeneousModel.load_package(package_path)
        >>> from geowatch.tasks.fusion.utils import load_model_from_package
        >>> recon = load_model_from_package(package_path)
        >>> # Same keys, equal values, but distinct tensor objects
        >>> recon_state = recon.state_dict()
        >>> model_state = model.state_dict()
        >>> assert recon is not model
        >>> assert set(recon_state) == set(model_state)
        >>> for key in recon_state.keys():
        >>>     assert (model_state[key] == recon_state[key]).all()
        >>>     assert model_state[key] is not recon_state[key]

    Ignore:
        7z l $HOME/.cache/geowatch/tests/package/my_package.pt
    """
    self._save_package(package_path, context=context, verbose=verbose)
# HACK: bind the mixin's ``configure_optimizers`` directly on this class
# because of inheritance (method-resolution-order) rules.
# NOTE(review): presumably another base class would otherwise supply its own
# ``configure_optimizers`` first; the class bases are not visible here - confirm.
configure_optimizers = WatchModuleMixins.configure_optimizers