import pytorch_lightning as pl
import torch
from torch import nn
import torch.nn.functional as F
import torchmetrics
import einops
import kwcoco
import kwarray
import ubelt as ub
from geowatch import heuristics
from geowatch.tasks.fusion.methods.network_modules import coerce_criterion
from geowatch.tasks.fusion.methods.network_modules import RobustModuleDict
from geowatch.tasks.fusion.methods.watch_module_mixins import WatchModuleMixins
from geowatch.tasks.fusion.architectures import unet_blur
from geowatch.utils.util_netharn import InputNorm
from typing import Dict, Any, Optional
# Optional profiling hook: when ``xdev`` can be imported its ``profile``
# decorator is used; on any failure the name ``profile`` is bound to
# ``ub.identity`` instead (presumably a pass-through, so that ``@profile``
# leaves the decorated function unchanged -- confirm against ubelt docs).
try:
    import xdev
    profile = xdev.profile
except Exception:
    profile = ub.identity
[docs]
class NanToNum(nn.Module):
    """
    Layer that substitutes a fixed number for NaN entries of its input.

    Args:
        num (float): value written in place of every NaN. Defaults to 0.0.

    Note:
        Only the NaN replacement value is configured here; anything else
        is left to the defaults of :func:`torch.nan_to_num`.
    """

    def __init__(self, num=0.0):
        super().__init__()
        # Replacement value used by ``forward``.
        self.num = num

    def forward(self, x):
        """
        Return ``x`` with NaN values replaced by ``self.num``.
        """
        return torch.nan_to_num(x, nan=self.num)
[docs]
class UNetBaseline(pl.LightningModule, WatchModuleMixins):
_HANDLES_NANS = True
[docs]
def get_cfgstr(self):
    """
    Build a short configuration string for this model.

    Returns:
        str: the ``name`` hyperparameter followed by the ``_unet`` suffix.
    """
    return f'{self.hparams.name}_unet'
def __init__(
    self,
    classes=10,
    dataset_stats=None,
    input_sensorchan=None,
    token_dim: int = 32,
    name: str = "unnamed_model",
    class_weights: str = "auto",
    saliency_weights: str = "auto",
    positive_change_weight: float = 1.0,
    negative_change_weight: float = 1.0,
    global_class_weight: float = 1.0,
    global_change_weight: float = 1.0,
    global_saliency_weight: float = 1.0,
    change_loss: str = "cce",  # TODO: replace control string with a module, possibly a subclass
    class_loss: str = "focal",  # TODO: replace control string with a module, possibly a subclass
    saliency_loss: str = "focal",  # TODO: replace control string with a module, possibly a subclass
    ohem_ratio: Optional[float] = None,
    focal_gamma: Optional[float] = 2.0,
):
    """
    Build the per-sensor tokenizers, the temporal backbone, the task heads
    (with their criterions) and the per-stage metric collections.

    Args:
        classes: anything accepted by ``kwcoco.CategoryTree.coerce``;
            determines the number of output channels of the class head.
        dataset_stats: forwarded together with ``input_sensorchan`` to
            ``set_dataset_specific_attributes``; the value that call
            returns is used to look up ``mean`` / ``std`` entries for
            input normalization (keyed by ``(sensor, channels)``).
        input_sensorchan: sensor/channel specification forwarded to
            ``set_dataset_specific_attributes``.
        token_dim: number of channels produced by every tokenizer and
            used throughout the backbone and as the input of each head.
        name: Specify a name for the experiment. (Unsure if the Model is
            the place for this)
        class_weights: Class weighting strategy (passed to
            ``_coerce_class_weights``).
        saliency_weights: Class weighting strategy (passed to
            ``_coerce_saliency_weights``).
        positive_change_weight: second entry of the change weight tensor.
        negative_change_weight: first entry of the change weight tensor.
        global_class_weight: global weight of the class head; the head
            and its criterion are only created when this is ``> 0``.
        global_change_weight: same as above, for the change head.
        global_saliency_weight: same as above, for the saliency head.
        change_loss: loss code handed to ``coerce_criterion`` for the
            change head.
        class_loss: loss code handed to ``coerce_criterion`` for the
            class head.
        saliency_loss: loss code handed to ``coerce_criterion`` for the
            saliency head.
        ohem_ratio: forwarded unchanged to ``coerce_criterion``.
        focal_gamma: forwarded unchanged to ``coerce_criterion``.

    Example:
        >>> # Note: it is important that the non-kwargs are saved as hyperparams
        >>> from geowatch.tasks.fusion.methods.unet_baseline import UNetBaseline
        >>> model = UNetBaseline(
        >>>     input_sensorchan='r|g|b',
        >>> )
    """
    super().__init__()
    # Must run before any other locals exist so that exactly the constructor
    # arguments are recorded in ``self.hparams``.
    self.save_hyperparameters()
    input_stats = self.set_dataset_specific_attributes(input_sensorchan, dataset_stats)
    self.classes = kwcoco.CategoryTree.coerce(classes)
    self.num_classes = len(self.classes)
    # TODO: this data should be introspectable via the kwcoco file
    hueristic_background_keys = heuristics.BACKGROUND_CLASSES
    # FIXME: case sensitivity
    hueristic_ignore_keys = heuristics.IGNORE_CLASSNAMES
    # ``self.class_freq`` is not assigned in this method; presumably it is
    # set by ``set_dataset_specific_attributes`` above -- TODO confirm.
    if self.class_freq is not None:
        all_keys = set(self.class_freq.keys())
    else:
        all_keys = set(self.classes)
    # Partition the known category names into background / ignore /
    # foreground using the heuristic name sets.
    self.background_classes = all_keys & hueristic_background_keys
    self.ignore_classes = all_keys & hueristic_ignore_keys
    self.foreground_classes = (all_keys - self.background_classes) - self.ignore_classes
    # hueristic_ignore_keys.update(hueristic_occluded_keys)
    self.saliency_num_classes = 2
    # criterion and metrics
    # TODO: parameterize loss criterions
    # For loss function experiments, see and work in
    # ~/code/watch/watch/tks/fusion/methods/sequence_aware.py
    # self.change_criterion = monai.losses.FocalLoss(reduction='none', to_onehot_y=False)
    self.saliency_weights = self._coerce_saliency_weights(saliency_weights)
    self.class_weights = self._coerce_class_weights(class_weights)
    # Order is [negative, positive], matching the two change channels.
    self.change_weights = torch.FloatTensor([
        self.hparams.negative_change_weight,
        self.hparams.positive_change_weight
    ])
    self.sensor_channel_tokenizers = RobustModuleDict()
    # Unique sensor modes obviously isn't very correct here.
    # We should fix that, but let's hack it so it at least
    # includes all sensor modes we probably will need.
    if input_stats is not None:
        sensor_modes = set(self.unique_sensor_modes) | set(input_stats.keys())
    else:
        sensor_modes = set(self.unique_sensor_modes)
    # Each sensor mode is a ``(sensor, channels)`` pair.
    for s, c in sensor_modes:
        mode_code = kwcoco.FusedChannelSpec.coerce(c)
        # For each mode make a network that should learn to tokenize
        in_chan = mode_code.numel()
        # Fall back to a default ``InputNorm`` whenever no statistics are
        # available for this particular mode.
        if input_stats is None:
            input_norm = InputNorm()
        else:
            stats = input_stats.get((s, c), None)
            if stats is None:
                input_norm = InputNorm()
            else:
                # Only the ``mean`` and ``std`` entries are passed on.
                input_norm = InputNorm(
                    **(ub.udict(stats) & {'mean', 'std'}))
        # key = sanitize_key(str((s, c)))
        key = f'{s}:{c}'
        # Tokenizer pipeline: normalize -> replace NaNs with 0 -> UNet that
        # maps ``in_chan`` input channels to ``token_dim`` channels.
        self.sensor_channel_tokenizers[key] = nn.Sequential(
            input_norm,
            NanToNum(0.0),
            unet_blur.UNet(in_chan, token_dim),
        )
    # Backbone: three Conv3d/BatchNorm3d/ReLU stages that keep ``token_dim``
    # channels; "same" padding preserves the time/height/width extents.
    self.backbone = nn.Sequential(
        nn.ReLU(),
        nn.Conv3d(token_dim, token_dim, (2, 5, 5), padding="same"),
        nn.BatchNorm3d(token_dim),
        nn.ReLU(),
        nn.Conv3d(token_dim, token_dim, (2, 5, 5), padding="same"),
        nn.BatchNorm3d(token_dim),
        nn.ReLU(),
        nn.Conv3d(token_dim, token_dim, (2, 5, 5), padding="same"),
        nn.BatchNorm3d(token_dim),
        nn.ReLU(),
    )
    self.criterions = torch.nn.ModuleDict()
    self.heads = torch.nn.ModuleDict()
    # For every task: the frame keys holding its labels, its per-pixel
    # weights and its output dimensions.
    self.task_to_keynames = {
        'change': {
            'labels': 'change',
            'weights': 'change_weights',
            'output_dims': 'change_output_dims'
        },
        'saliency': {
            'labels': 'saliency',
            'weights': 'saliency_weights',
            'output_dims': 'saliency_output_dims'
        },
        'class': {
            'labels': 'class_idxs',
            'weights': 'class_weights',
            'output_dims': 'class_output_dims'
        },
    }
    head_properties = [
        {
            'name': 'change',
            'channels': 2,
            'loss': self.hparams.change_loss,
            'weights': self.change_weights,
        },
        {
            'name': 'saliency',
            'channels': self.saliency_num_classes,
            'loss': self.hparams.saliency_loss,
            'weights': self.saliency_weights,
        },
        {
            'name': 'class',
            'channels': self.num_classes,
            'loss': self.hparams.class_loss,
            'weights': self.class_weights,
        },
    ]
    self.global_head_weights = {
        'class': global_class_weight,
        'change': global_change_weight,
        'saliency': global_saliency_weight,
    }
    # A head (and its criterion) only exists when its global weight is
    # strictly positive; disabled tasks are skipped entirely.
    for prop in head_properties:
        head_name = prop['name']
        global_weight = self.global_head_weights[head_name]
        if global_weight > 0:
            self.criterions[head_name] = coerce_criterion(prop['loss'],
                                                          prop['weights'],
                                                          ohem_ratio=ohem_ratio,
                                                          focal_gamma=focal_gamma)
            self.heads[head_name] = unet_blur.UNet(token_dim, prop["channels"])
    FBetaScore = torchmetrics.FBetaScore
    class_metrics = torchmetrics.MetricCollection({
        "class_acc": torchmetrics.Accuracy(num_classes=self.num_classes, task='multiclass'),
        # "class_iou": torchmetrics.IoU(2),
        'class_f1_micro': FBetaScore(beta=1.0, threshold=0.5, average='micro', num_classes=self.num_classes, task='multiclass'),
        'class_f1_macro': FBetaScore(beta=1.0, threshold=0.5, average='macro', num_classes=self.num_classes, task='multiclass'),
    })
    change_metrics = torchmetrics.MetricCollection({
        "change_acc": torchmetrics.Accuracy(task="binary"),
        # "iou": torchmetrics.IoU(2),
        'change_f1': FBetaScore(beta=1.0, task="binary"),
    })
    saliency_metrics = torchmetrics.MetricCollection({
        'saliency_f1': FBetaScore(beta=1.0, task="binary"),
    })
    # Independent metric copies for each stage, prefixed with the stage name.
    # Note that only "train", "val" and "test" stages are created.
    self.head_metrics = nn.ModuleDict({
        f"{stage}_stage": nn.ModuleDict({
            "class": class_metrics.clone(prefix=f"{stage}_"),
            "change": change_metrics.clone(prefix=f"{stage}_"),
            "saliency": saliency_metrics.clone(prefix=f"{stage}_"),
        })
        for stage in ["train", "val", "test"]
    })
[docs]
def process_frame(self, frame) -> Dict[str, Dict[str, Any]]:
    """
    Reorganize one frame dictionary into per-task and per-mode entries.

    Args:
        frame (dict): must provide, for each of the change / saliency /
            class tasks, its label, weight and output-dims entries, plus
            ``"time_index"``, ``"modes"`` and (when modes exist)
            ``"sensor"``.

    Returns:
        Dict[str, Dict[str, Any]]: one entry per task (``"change"``,
        ``"saliency"``, ``"class"``) followed by one entry per mode keyed
        ``"<sensor>:<mode_name>"``. Every entry has the keys ``"data"``,
        ``"weights"``, ``"output_dims"`` and ``"time_index"``.
    """
    # (task name, frame key that stores the labels of that task)
    task_label_keys = (
        ("change", "change"),
        ("saliency", "saliency"),
        ("class", "class_idxs"),
    )

    processed = {}
    for task, label_key in task_label_keys:
        item = {
            "data": frame[label_key],
            "weights": frame[f"{task}_weights"],
            "output_dims": frame[f"{task}_output_dims"],
            "time_index": frame["time_index"],
        }
        if item["output_dims"] is None:
            # Prefer the shape of the labels; fall back to the frame-level
            # output dims when the task has no labels either.
            truth = item["data"]
            if truth is None:
                item["output_dims"] = frame["output_dims"]
            else:
                item["output_dims"] = list(truth.shape)
        processed[task] = item

    for mode_name, mode_val in frame["modes"].items():
        mode_key = f"{frame['sensor']}:{mode_name}"
        processed[mode_key] = {
            "data": mode_val,
            "weights": None,
            # Skip the leading axis of the mode data.
            "output_dims": list(mode_val.shape[1:]),
            "time_index": frame["time_index"],
        }
    return processed
[docs]
def process_example(self, example):
    """
    Apply :meth:`process_frame` to every frame of one example.

    Args:
        example (dict): must contain a ``"frames"`` iterable.

    Returns:
        list: the processed frames, in their original order.
    """
    processed_frames = []
    for frame in example["frames"]:
        processed_frames.append(self.process_frame(frame))
    return processed_frames
[docs]
def process_batch(self, batch):
    """
    Apply :meth:`process_example` to every example of a batch.

    Args:
        batch (iterable): examples; entries that are ``None`` are skipped.

    Returns:
        list: one processed example per non-``None`` entry, in order.
    """
    processed = []
    for example in batch:
        if example is None:
            continue
        processed.append(self.process_example(example))
    return processed
[docs]
def encode_frame(self, processed_frame):
    """
    Tokenize every entry of a processed frame that has a tokenizer.

    Args:
        processed_frame (dict): maps keys to dictionaries holding a
            ``"data"`` entry. Keys without a matching entry in
            ``self.sensor_channel_tokenizers`` (e.g. the task entries)
            are ignored.

    Returns:
        dict: tokenizer key -> tokenizer output with the temporary
        leading batch axis removed again.
    """
    tokenizers = self.sensor_channel_tokenizers
    encoded = {}  # length = num_modes
    for key, item in processed_frame.items():
        if key not in tokenizers.keys():
            continue
        # Add a batch axis of size one for the tokenizer, then strip it.
        batched = item["data"][None]
        encoded[key] = tokenizers[key](batched)[0]  # shape=[C, H, W]
    return encoded
[docs]
def encode_example(self, processed_example):
    """
    Encode all frames of an example into one token tensor.

    Each frame is encoded with :meth:`encode_frame`; the resulting
    per-mode tensors are averaged, and the per-frame averages are stacked
    along a new axis at position 1.

    Args:
        processed_example (iterable): processed frames of one example.

    Returns:
        torch.Tensor: shape ``[C, num_frames, H, W]``.
    """
    frame_tokens = []
    for processed_frame in processed_example:
        mode_tokens = self.encode_frame(processed_frame)
        # shape=[num_modes, C, H, W]
        stacked_modes = torch.stack(list(mode_tokens.values()), dim=0)
        # Average over the modes -> shape=[C, H, W]
        frame_tokens.append(stacked_modes.mean(dim=0))
    return torch.stack(frame_tokens, dim=1)  # shape=[C, num_frames, H, W]
[docs]
def encode_batch(self, processed_batch):
    """
    Encode every example and stack them into a single batch tensor.

    Examples may contain different numbers of frames; shorter examples
    are zero-padded at the end of the time axis so that all examples
    share the longest sequence length. Channel, height and width are not
    padded and therefore must already agree across examples.

    Args:
        processed_batch (iterable): processed examples, as returned by
            :meth:`process_batch`.

    Returns:
        torch.Tensor: shape ``[B, C, T, H, W]`` where ``T`` is the
        largest number of frames of any example.
    """
    encoded_examples = [
        self.encode_example(processed_example)
        for processed_example in processed_batch
    ]

    # Longest sequence in the batch, as a plain int. ``default=0`` keeps
    # the empty-batch case flowing into ``torch.stack`` below, which then
    # raises exactly as stacking an empty list always did.
    max_frames = max((ex.shape[1] for ex in encoded_examples), default=0)

    padded_examples = [
        F.pad(
            ex,
            # F.pad pairs padding values IN REVERSE ORDER, below is correct
            (
                0, 0,  # W
                0, 0,  # H
                0, max_frames - ex.shape[1],  # T: pad only at the end
                0, 0,  # C
            ),
            mode="constant", value=0,
        )
        for ex in encoded_examples
    ]
    return torch.stack(padded_examples, dim=0)
[docs]
def forward(self, batch):
    """
    Run the tokenizers, the backbone and every task head over a batch.

    Args:
        batch (list): examples, each a dict with a ``"frames"`` list.
            ``None`` entries are dropped before any processing, so the
            per-example lists in the result line up with the remaining
            (non-``None``) examples.

    Returns:
        dict: for every head ``<task>`` two entries with the nesting
        ``[example][frame]``:

            * ``<task>``: raw head output, shape ``[chan, height, width]``.
            * ``<task>_probs``: for ``"change"`` the sigmoid of channel 1,
              shape ``[height, width]``; for every other task a softmax
              over channels, shape ``[height, width, chan]``.

    Example:
        >>> from geowatch.tasks import fusion
        >>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
        >>> model = fusion.methods.UNetBaseline(
        >>>     classes=classes,
        >>>     dataset_stats=dataset_stats,
        >>>     input_sensorchan=channels,
        >>> )
        >>> batch = model.demo_batch(width=64, height=64)
        >>> outputs = model.forward(batch)
        >>> for task_key, task_outputs in outputs.items():
        >>>     if "probs" in task_key: continue
        >>>     if task_key == "class": task_key = "class_idxs"
        >>>     for task_pred, example in zip(task_outputs, batch):
        >>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
        >>>             if (frame_idx == 0) and task_key.startswith("change"): continue
        >>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
    """
    # ``process_batch`` silently skips ``None`` examples. Filter them here
    # as well, otherwise the encoded sequences would be zipped against the
    # wrong examples (or fail on ``None["frames"]``) further below.
    batch = [example for example in batch if example is not None]

    processed_batch = self.process_batch(batch)
    encoded_batch = self.encode_batch(processed_batch)
    output_seqs = self.backbone(encoded_batch)

    # decompose outputs
    outputs = dict()
    for task_name, task_head in self.heads.items():
        task_outputs = []
        task_probs = []
        for output_seq, example in zip(output_seqs, batch):
            # [chan, time, height, width] -> [time, chan, height, width]
            output_seq = output_seq.permute(1, 0, 2, 3)
            seq_outputs = []
            seq_probs = []
            # Zipping against the real frames drops the temporal padding
            # that ``encode_batch`` appended to shorter examples.
            for output, _frame in zip(output_seq, example["frames"]):
                output = task_head(output[None])[0]
                # [chan, height, width] -> [height, width, chan]
                probs = output.permute(1, 2, 0)
                if task_name == "change":
                    probs = probs.sigmoid()[..., 1]
                else:
                    probs = probs.softmax(dim=-1)
                seq_outputs.append(output)
                seq_probs.append(probs)
            task_outputs.append(seq_outputs)
            task_probs.append(seq_probs)
        outputs[task_name] = task_outputs
        outputs[f"{task_name}_probs"] = task_probs
    return outputs
[docs]
def shared_step(self, batch, batch_idx=None, stage="train", with_loss=True):
    """
    Run the forward pass on a batch and, optionally, compute the loss.

    Args:
        batch (list): examples (dicts with a ``"frames"`` list); ``None``
            entries are removed before the forward pass.
        batch_idx: not used inside this method.
        stage (str): prefix for logged values and selector of the metric
            collection ``self.head_metrics[f"{stage}_stage"]``.
        with_loss (bool): when False the forward outputs are returned
            immediately, without touching labels, losses or metrics.

    Returns:
        dict: the outputs of :meth:`forward`; when ``with_loss`` is True
        it additionally holds ``"loss"``, the mean over all per-frame,
        per-task losses (each scaled by the task's global head weight).

    Note:
        NOTE(review): if no frame of the batch has labels for any enabled
        head, ``frame_losses`` stays empty and the final division raises
        ``ZeroDivisionError``.

        NOTE(review): metric collections are only looked up when
        ``with_loss`` is True; ``__init__`` creates them for the "train",
        "val" and "test" stages only.

    Example:
        >>> # xdoctest: +REQUIRES(env:SLOW_TESTS)
        >>> from geowatch.tasks import fusion
        >>> import torch
        >>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
        >>> model = fusion.methods.UNetBaseline(
        >>>     classes=classes,
        >>>     dataset_stats=dataset_stats,
        >>>     input_sensorchan=channels,
        >>> )
        >>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
        >>> outputs = model.shared_step(batch)
        >>> optimizer = torch.optim.Adam(model.parameters())
        >>> optimizer.zero_grad()
        >>> loss = outputs["loss"]
        >>> loss.backward()
        >>> optimizer.step()

    Example:
        >>> # xdoctest: +REQUIRES(env:SLOW_TESTS)
        >>> from geowatch.tasks import fusion
        >>> import torch
        >>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
        >>> model = fusion.methods.UNetBaseline(
        >>>     classes=classes,
        >>>     dataset_stats=dataset_stats,
        >>>     input_sensorchan=channels,
        >>> )
        >>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
        >>> batch += [None]
        >>> outputs = model.shared_step(batch)
        >>> optimizer = torch.optim.Adam(model.parameters())
        >>> optimizer.zero_grad()
        >>> loss = outputs["loss"]
        >>> loss.backward()
        >>> optimizer.step()

    Example:
        >>> from geowatch.tasks import fusion
        >>> import torch
        >>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
        >>> model = fusion.methods.UNetBaseline(
        >>>     classes=classes, token_dim=2,
        >>>     dataset_stats=dataset_stats,
        >>>     input_sensorchan=channels,
        >>> )
        >>> batch = model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=0.1)
        >>> batch += model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=0.5)
        >>> batch += model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=1.0)
        >>> outputs = model.shared_step(batch)
        >>> optimizer = torch.optim.Adam(model.parameters())
        >>> optimizer.zero_grad()
        >>> loss = outputs["loss"]
        >>> loss.backward()
        >>> optimizer.step()
    """
    # FIXME: why are we getting nones here?
    batch = [
        ex
        for ex in batch
        if (ex is not None)
        # and (len(ex["frames"]) > 0)
    ]
    outputs = self(batch)
    if not with_loss:
        return outputs
    # One scalar per (task, example, frame) that has labels.
    frame_losses = []
    for task_name in self.heads:
        for pred_seq, example in zip(outputs[task_name], batch):
            for pred, frame in zip(pred_seq, example["frames"]):
                task_labels_key = self.task_to_keynames[task_name]["labels"]
                labels = frame[task_labels_key]
                # Logged for every frame, even those without labels.
                self.log(f"{stage}_{task_name}_logit_mean", pred.mean())
                if labels is None:
                    continue
                # FIXME: This is necessary because sometimes when data.input_space_scale==native, label shapes and output_dims dont match!
                if pred.shape[1:] != labels.shape:
                    # Resample the prediction to the spatial size of the labels.
                    pred = nn.functional.interpolate(
                        pred[None],
                        size=labels.shape,
                        mode="bilinear",
                    )[0]
                task_weights_key = self.task_to_keynames[task_name]["weights"]
                task_weights = frame[task_weights_key]
                # Only positions with a strictly positive weight contribute
                # to the loss; masking flattens the spatial axes.
                valid_mask = (task_weights > 0.)
                pred_ = pred[:, valid_mask]
                task_weights_ = task_weights[valid_mask]
                criterion = self.criterions[task_name]
                # The criterion declares which target layout it expects.
                if criterion.target_encoding == 'index':
                    loss_labels = labels.long()
                    loss_labels_ = loss_labels[valid_mask]
                elif criterion.target_encoding == 'onehot':
                    # Note: 1HE is much easier to work with
                    loss_labels = kwarray.one_hot_embedding(
                        labels.long(),
                        criterion.in_channels,
                        dim=0)
                    loss_labels_ = loss_labels[:, valid_mask]
                else:
                    raise KeyError(criterion.target_encoding)
                loss = criterion(
                    pred_[None],
                    loss_labels_[None],
                )
                # Debug output only; a NaN loss is still accumulated below.
                if loss.isnan().any():
                    print(loss)
                    print(pred)
                    print(frame)
                # Per-position weighting, applied in place to the criterion
                # output (assumes the criterion returns an unreduced loss
                # that broadcasts against the masked weights -- TODO confirm).
                loss *= task_weights_
                frame_losses.append(
                    self.global_head_weights[task_name] * loss.mean()
                )
                # Metrics use the (possibly resampled) full prediction and
                # all labels, i.e. they are not restricted to ``valid_mask``.
                self.log_dict(
                    self.head_metrics[f"{stage}_stage"][task_name](
                        pred.argmax(dim=0).flatten(),
                        # pred[None],
                        labels.flatten().long(),
                    ),
                    prog_bar=True,
                )
    # NOTE(review): raises ZeroDivisionError when ``frame_losses`` is empty.
    outputs["loss"] = sum(frame_losses) / len(frame_losses)
    self.log(f"{stage}_loss", outputs["loss"], prog_bar=True)
    return outputs
# def shared_step(self, batch, batch_idx=None, with_loss=True):
# outputs = {
# "change_probs": [
# [
# 0.5 * torch.ones(*frame["output_dims"])
# for frame in example["frames"]
# if frame["change"] != None
# ]
# for example in batch
# ],
# "saliency_probs": [
# [
# torch.ones(*frame["output_dims"], 2).sigmoid()
# for frame in example["frames"]
# ]
# for example in batch
# ],
# "class_probs": [
# [
# torch.ones(*frame["output_dims"], self.num_classes).softmax(dim=-1)
# for frame in example["frames"]
# ]
# for example in batch
# ],
# }
# if with_loss:
# outputs["loss"] = self.dummy_param
# return outputs
[docs]
@profile
def training_step(self, batch, batch_idx=None):
    """
    Training hook: delegate to :meth:`shared_step` with ``stage='train'``.

    Returns:
        dict: whatever ``shared_step`` returns for this batch.
    """
    return self.shared_step(batch, batch_idx=batch_idx, stage='train')
[docs]
@profile
def validation_step(self, batch, batch_idx=None):
    """
    Validation hook: delegate to :meth:`shared_step` with ``stage='val'``.

    Returns:
        dict: whatever ``shared_step`` returns for this batch.
    """
    return self.shared_step(batch, batch_idx=batch_idx, stage='val')
[docs]
@profile
def test_step(self, batch, batch_idx=None):
    """
    Test hook: delegate to :meth:`shared_step` with ``stage='test'``.

    Returns:
        dict: whatever ``shared_step`` returns for this batch.
    """
    return self.shared_step(batch, batch_idx=batch_idx, stage='test')
[docs]
@profile
def predict_step(self, batch, batch_idx=None):
    """
    Prediction hook: delegate to :meth:`shared_step` with
    ``stage='predict'`` and loss computation switched off.

    Returns:
        dict: the forward outputs returned by ``shared_step``.
    """
    return self.shared_step(
        batch,
        batch_idx=batch_idx,
        stage='predict',
        with_loss=False,
    )
# Alias kept as a special case for the predict step: ``forward_step`` is
# bound to the very same function as ``shared_step``.
forward_step = shared_step
[docs]
def save_package(self, package_path, verbose=1):
    """
    Save this model as a package at ``package_path``.

    Thin wrapper that forwards both arguments to ``self._save_package``
    (not defined in this class body; presumably supplied by
    ``WatchModuleMixins`` -- TODO confirm).

    Args:
        package_path: destination of the package file.
        verbose (int): verbosity level, passed through unchanged.

    CommandLine:
        xdoctest -m geowatch.tasks.fusion.methods.unet_baseline UNetBaseline.save_package

    Example:
        >>> # Test without datamodule
        >>> import ubelt as ub
        >>> from os.path import join
        >>> #from geowatch.tasks.fusion.methods.unet_baseline import *  # NOQA
        >>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
        >>> package_path = join(dpath, 'my_package.pt')
        >>> # Use one of our fusion.architectures in a test
        >>> from geowatch.tasks.fusion import methods
        >>> from geowatch.tasks.fusion import datamodules
        >>> model = self = methods.UNetBaseline(
        >>>     input_sensorchan=5,
        >>> )
        >>> # Save the model (TODO: need to save datamodule as well)
        >>> model.save_package(package_path)
        >>> # Test that the package can be reloaded
        >>> #recon = methods.UNetBaseline.load_package(package_path)
        >>> from geowatch.tasks.fusion.utils import load_model_from_package
        >>> recon = load_model_from_package(package_path)
        >>> # Check consistency and data is actually different
        >>> recon_state = recon.state_dict()
        >>> model_state = model.state_dict()
        >>> assert recon is not self
        >>> assert set(recon_state) == set(model_state)
        >>> from geowatch.utils.util_kwarray import torch_array_equal
        >>> for key in recon_state.keys():
        >>>     v1 = model_state[key]
        >>>     v2 = recon_state[key]
        >>>     if not torch.allclose(v1, v2, equal_nan=True):
        >>>         print('v1 = {}'.format(ub.urepr(v1, nl=1)))
        >>>         print('v2 = {}'.format(ub.urepr(v2, nl=1)))
        >>>         raise AssertionError(f'Difference in key={key}')
        >>>     assert v1 is not v2, 'should be distinct copies'
    """
    self._save_package(package_path, verbose=verbose)