"""
Wrappers that adapt off-the-shelf (OTS) torchvision models to the geowatch
fusion training code.

Example:
    Rough usage sketch (plain text, not executed as a doctest).

    Notes:
      * ``FCNResNet50`` requires a ``heads`` config; see its class docstring
        for one.
      * ``FCNResNet50.forward`` currently raises ``NotImplementedError``.

    from geowatch.tasks.fusion.datamodules.network_io import RGBImageBatchItem
    item0 = RGBImageBatchItem.demo(index=0)
    item1 = RGBImageBatchItem.demo(index=1)
    batch_items = [item0, item1]
    item0.imdata_chw.shape
    item0.channels
    from geowatch.tasks.fusion.methods.torchvision_nets import *  # NOQA
    self = FCNResNet50(heads=...)  # see the FCNResNet50 docstring for a heads config
    batch = self.collate(batch_items)
    out = self.forward(batch)  # currently raises NotImplementedError
    self = EfficientNetB7()
    out = self.forward(batch)
"""
import torch
import pytorch_lightning as pl
# from torch.nn import functional as F
from geowatch.tasks.fusion.methods import heads as heads_module
from geowatch.tasks.fusion.methods.watch_module_mixins import WatchModuleMixins
from geowatch.utils.util_netharn import InputNorm, Identity
import ubelt as ub # NOQA
from geowatch.tasks.fusion.datamodules import network_io
from typing import Iterator # NOQA
# from geowatch.tasks.fusion.methods.loss import coerce_criterion
# from geowatch.tasks.fusion.methods.heads import TaskHeads
try:
from line_profiler import profile
except Exception:
profile = ub.identity
class TorchvisionWrapper(pl.LightningModule):
    """
    Base class for LightningModules that wrap an off-the-shelf (OTS)
    torchvision model stored in ``self.ots_model``.

    This class does not build the model itself. Subclasses implement
    :func:`define_ots_model` and must assign its result to
    ``self.ots_model`` (e.g. in their ``__init__``) before :func:`forward`
    is used.
    """

    def define_ots_model(self):
        """
        Build and return the wrapped torchvision model (the backbone).

        Raises:
            NotImplementedError: always; child classes must override this.
        """
        raise NotImplementedError('Child class must define this')

    def forward(self, batch):
        """
        Run the wrapped model on ``batch['imdata_bchw']``.

        Args:
            batch (dict): a collated batch, e.g. the result of
                :func:`collate`.

        Returns:
            Whatever ``self.ots_model`` returns for that input.
        """
        imdata_bchw = batch['imdata_bchw']
        # Invoke the module itself (not ``.forward``) so that any registered
        # forward pre-hooks / hooks are executed.
        out = self.ots_model(imdata_bchw)
        return out

    def collate(self, batch_items):
        """
        Stack a list of per-item objects into a batch dictionary.

        Args:
            batch_items (List): items exposing ``imdata_chw`` and
                ``nonlocal_class_ohe`` tensors. Each attribute must have the
                same shape across items so ``torch.stack`` can combine them.

        Returns:
            Dict[str, Tensor]: ``'imdata_bchw'`` holds the stacked
            ``imdata_chw`` tensors and ``'nonlocal_class_ohe'`` the stacked
            ``nonlocal_class_ohe`` tensors (new leading batch dimension).
        """
        imdata_bchw = torch.stack([item.imdata_chw for item in batch_items])
        nonlocal_class_ohe = torch.stack(
            [item.nonlocal_class_ohe for item in batch_items])
        batch = {
            'imdata_bchw': imdata_bchw,
            'nonlocal_class_ohe': nonlocal_class_ohe,
        }
        return batch
class TorchvisionSegmentationWrapper(TorchvisionWrapper):
    """
    Marker base class for wrappers around torchvision segmentation models
    (for example ``FCNResNet50``).

    It adds no behavior beyond :class:`TorchvisionWrapper`.
    """
class TorchvisionClassificationWrapper(TorchvisionWrapper):
    """
    Marker base class for wrappers around torchvision classification models
    (for example ``EfficientNetB7`` and ``Resnet50``).

    It adds no behavior beyond :class:`TorchvisionWrapper`.
    """
class TorchvisionDetectionWrapper(TorchvisionWrapper):
    """
    Marker base class, presumably intended for wrappers around torchvision
    detection models.

    It adds no behavior beyond :class:`TorchvisionWrapper`.
    """
class EfficientNetB7(TorchvisionClassificationWrapper):
    """
    Torchvision EfficientNet-B7 with ImageNet-1k weights, used through the
    generic :class:`TorchvisionWrapper` ``forward`` / ``collate`` interface.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The base class never builds the backbone. Without this assignment
        # the inherited ``forward`` would fail on a missing ``ots_model``.
        self.ots_model = self.define_ots_model()

    def define_ots_model(self):
        """
        Build ``torchvision.models.efficientnet_b7`` with the
        ``EfficientNet_B7_Weights.IMAGENET1K_V1`` pretrained weights.
        """
        import torchvision
        ots_model = torchvision.models.efficientnet_b7(
            weights='EfficientNet_B7_Weights.IMAGENET1K_V1')
        # Structure notes:
        #   ots_model.features    -> convolutional backbone (features[0] is
        #                            the 3-channel input stem)
        #   ots_model.avgpool     -> adaptive pooling
        #   ots_model.classifier  -> ImageNet classifier head
        return ots_model
class FCNResNet50(TorchvisionSegmentationWrapper, WatchModuleMixins):
    """
    Work-in-progress wrapper around torchvision's ``fcn_resnet50`` that
    attaches geowatch task heads to its 2048-channel backbone features.

    Note:
        :func:`forward` and :func:`forward_step` are not implemented yet and
        raise ``NotImplementedError``.

    Args:
        heads: task-head configuration passed to ``heads_module.TaskHeads``
            (e.g. a YAML string as in the example below).
        classes: anything accepted by ``kwcoco.CategoryTree.coerce``
            (default 10).
        dataset_stats: currently unused.

    Ignore:
        from geowatch.tasks.fusion.datamodules.network_io import RGBImageBatchItem
        item1 = RGBImageBatchItem.demo()
        item2 = RGBImageBatchItem.demo()
        batch_items = [item1, item2]
        from geowatch.tasks.fusion.methods.torchvision_nets import * # NOQA
        heads = ub.codeblock(
            '''
            feat_dim: 2048
            tasks:
                # Mirrors the simple FCNHead in torchvision
                - name: saliency
                  type: mlp
                  hidden_channels: [256]
                  out_channels: 2
                  dropout: 0.1
                  norm: batch
                  loss:
                      type: focal
                      gamma: 2.0
                  global_weight: 1.0
            ''')
        self = FCNResNet50(
            heads=heads,
        )
        batch = self.collate(batch_items)

    Ignore:
        >>> # Test with datamodule
        >>> import ubelt as ub
        >>> from geowatch.tasks.fusion import datamodules
        >>> datamodule = datamodules.kwcoco_video_data.KWCocoVideoDataModule(
        >>>     train_dataset='special:vidshapes8', chip_size=32,
        >>>     batch_size=1, time_steps=1, num_workers=2, normalize_inputs=10)
        >>> datamodule.setup('fit')
        >>> dataset_stats = datamodule.torch_datasets['train'].cached_dataset_stats(num=3)
        >>> classes = datamodule.torch_datasets['train'].classes
        >>> # Use one of our fusion.architectures in a test
        >>> # ``heads`` is a config such as the YAML in the example above
        >>> self = FCNResNet50(
        >>>     heads=heads,
        >>>     classes=classes,
        >>>     dataset_stats=dataset_stats)
    """

    def __init__(self, heads, classes=10, dataset_stats=None):
        super().__init__()
        import kwcoco
        self.ots_model = self.define_ots_model()
        # Width of the last bottleneck block in layer4 == backbone feature dim.
        feat_dim = self.ots_model.backbone.layer4[2].conv3.out_channels
        self.automatic_optimization = True
        assert feat_dim == 2048, 'hard coded sanity check'
        # ``self.classes`` must be assigned before ``num_classes`` is derived
        # from it; nothing else in this class sets it.
        self.classes = kwcoco.CategoryTree.coerce(classes)
        self.num_classes = len(self.classes)
        self.heads = heads_module.TaskHeads(heads)

    def define_ots_model(self):
        """
        Build ``torchvision.models.segmentation.fcn_resnet50`` with the
        ``COCO_WITH_VOC_LABELS_V1`` pretrained weights.
        """
        import torchvision
        ots_model = torchvision.models.segmentation.fcn_resnet50(
            weights='FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1')
        return ots_model

    def forward(self, batch):
        """
        Not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        # TODO: intended sketch (untested):
        #     imdata_bchw = batch['imdata_bchw']
        #     input_hw = imdata_bchw.shape[-2:]
        #     feats = self.ots_model.backbone(imdata_bchw)['out']
        #     outputs = self.heads(feats)
        #     # then upsample (F.interpolate, bilinear) back to input_hw
        raise NotImplementedError

    def save_package(self, package_path, verbose=1):
        """
        Save this model as a package via ``self._save_package``.
        """
        self._save_package(package_path, verbose=verbose)

    def forward_step(self, batch, batch_idx=None, with_loss=True):
        """
        Not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
class Resnet50(TorchvisionClassificationWrapper, WatchModuleMixins):
    """
    Torchvision ResNet-50 backbone (ImageNet weights, ``fc`` replaced by
    ``Identity``) followed by geowatch task heads.

    Args:
        heads: task-head configuration forwarded to
            ``heads_module.TaskHeads`` together with the backbone
            ``feat_dim`` (2048).
        classes: anything accepted by ``kwcoco.CategoryTree.coerce``.
        dataset_stats: dataset statistics handed to
            ``set_dataset_specific_attributes``. Exactly one input stream,
            with ``'mean'`` and ``'std'`` entries, is expected.

    Ignore:
        >>> import ubelt as ub
        >>> from geowatch.tasks.fusion import datamodules
        >>> datamodule = datamodules.kwcoco_video_data.KWCocoVideoDataModule(
        >>>     train_dataset='special:vidshapes8', chip_size=32,
        >>>     batch_size=1, time_steps=1, num_workers=2, normalize_inputs=10)
        >>> datamodule.setup('fit')
        >>> dataset_stats = datamodule.torch_datasets['train'].cached_dataset_stats(num=3)
        >>> dataset = datamodule.torch_datasets['train']
        >>> classes = dataset.predictable_classes
        >>> dataset.requested_tasks['nonlocal_class'] = True
        >>> item1 = dataset[0]
        >>> item2 = dataset[1]
        >>> batch_items = [item1, item2]
        >>> from geowatch.tasks.fusion.methods.torchvision_nets import *  # NOQA
        >>> # Use one of our fusion.architectures in a test
        >>> heads = ub.codeblock(
        >>>     '''
        >>>     # Mirrors the simple FCNHead in torchvision
        >>>     - name: nonlocal_class
        >>>       type: mlp
        >>>       hidden_channels: [256]
        >>>       out_channels: 4
        >>>       dropout: 0.1
        >>>       norm: batch
        >>>       loss:
        >>>           type: focal
        >>>           gamma: 2.0
        >>>       head_weight: 1.0
        >>>     ''')
        >>> self = Resnet50(
        >>>     heads=heads,
        >>>     classes=classes,
        >>>     dataset_stats=dataset_stats)
        >>> # forward_step expects a collated batch, not raw items
        >>> batch = self._to_collated(batch_items)
        >>> outputs = self.forward_step(batch, with_loss=True)
        >>> canvas = datamodule.draw_batch(batch_items, outputs=outputs)
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> kwplot.autompl()
        >>> kwplot.imshow(canvas)
        >>> kwplot.show_if_requested()
    """

    def __init__(self, heads, classes=None, dataset_stats=None):
        super().__init__()
        input_stats = self.set_dataset_specific_attributes(None, dataset_stats)
        assert len(input_stats) == 1
        stats = list(input_stats.values())[0]
        # ``[None, ...]`` prepends a batch axis so the stats broadcast over B.
        norm_kw = {'mean': stats['mean'][None, ...],
                   'std': stats['std'][None, ...]}
        self.input_norms = InputNorm(**norm_kw)
        self.ots_model = ots_model = self.define_ots_model()
        feat_dim = ots_model.layer4[-1].conv3.out_channels
        import kwcoco
        self.classes = kwcoco.CategoryTree.coerce(classes)
        self.num_classes = len(self.classes)
        assert feat_dim == 2048, 'hard coded sanity check'
        self.heads = heads_module.TaskHeads(heads, feat_dim=feat_dim)
        self.automatic_optimization = True

    def forward(self, batch):
        """
        Normalize the input, compute backbone features, and apply the heads.

        Args:
            batch: collated batch with an ``'imdata_bchw'`` entry.

        Returns:
            The output of ``self.heads`` for the backbone features.
        """
        imdata_bchw = batch['imdata_bchw']
        # Mean/Std normalize the input batch.
        imdata_bchw = self.input_norms(imdata_bchw)
        # Replace NaN / inf values in place (presumably nodata pixels)
        # before they reach the backbone.
        imdata_bchw.nan_to_num_()
        # Compute backbone features (the ``fc`` layer is an Identity).
        feats = self.ots_model(imdata_bchw)
        outputs = self.heads(feats)
        return outputs

    def forward_step(self, batch, with_loss=False, stage='unspecified'):
        """
        Run the model on a collated batch and optionally compute losses.

        Args:
            batch: collated batch with an ``'imdata_bchw'`` entry (e.g. the
                output of :func:`on_before_batch_transfer`).
            with_loss (bool): if True, merge the head losses into the outputs
                and log ``'{stage}_loss'``.
            stage (str): used as the logging prefix. ``'train'`` also enables
                the manual optimization path when ``automatic_optimization``
                is False (which then requires ``with_loss=True``).

        Returns:
            network_io.CollatedNetworkOutputs
        """
        manual_opt = stage == 'train' and not self.automatic_optimization
        if manual_opt:
            # With manual optimization Lightning does not zero / step the
            # optimizer for us.
            # https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
            opt = self.optimizers()
            opt.zero_grad()
        batch_size = len(batch['imdata_bchw'])
        outputs = self.forward(batch)
        if with_loss:
            losses = self.heads.compute_loss(outputs, batch)
            outputs.update(losses)
            total_loss = losses['loss']
            self.log(f'{stage}_loss', total_loss, prog_bar=True,
                     batch_size=batch_size)
        outputs = network_io.CollatedNetworkOutputs(outputs)
        if manual_opt:
            # ``manual_backward`` is the LightningModule API. The optimizer
            # must be stepped explicitly or the weights never change.
            self.manual_backward(outputs['loss'])
            opt.step()
        return outputs

    def define_ots_model(self):
        """
        Build ``torchvision.models.resnet50`` with ``IMAGENET1K_V2`` weights
        and its final ``fc`` classifier replaced by ``Identity``.
        """
        import torchvision
        ots_model = torchvision.models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V2')
        # Hack off the head
        ots_model.fc = Identity()
        return ots_model

    @profile
    def _collate_items(self, batch_items):
        """
        Remember the raw (CPU) items and collate them into a single batch.
        """
        self._cpu_batch_items = batch_items
        uncollated = network_io.UncollatedRGBImageBatch.from_items(batch_items)
        return uncollated.collate()

    def _to_collated(self, batch_items):
        """
        Collate ``batch_items`` and move the result to ``self.device``.
        """
        batch = self._collate_items(batch_items)
        return batch.to(self.device)

    def _grab_batch_from_dataloader(self, dataloader_iter):
        """
        Pull the next raw batch from ``dataloader_iter`` and collate it
        onto ``self.device``.
        """
        return self._to_collated(next(dataloader_iter))

    # NOTE: the train / val / test specific methods should be moved to a
    # mixin. Only ``training_step`` is implemented here; validation / test
    # steps are not defined by this class.
    def training_step(self, batch):
        """
        Lightning training hook: forward pass with losses in the ``'train'``
        stage.
        """
        outputs = self.forward_step(batch, with_loss=True, stage='train')
        return outputs

    @profile
    def on_before_batch_transfer(self, batch_items, dataloader_idx):
        """
        Lightning hook: collate the raw list of items on the CPU before
        Lightning moves the batch to the target device.
        """
        return self._collate_items(batch_items)