"""
# TODO:
* Generalize Task Heads, and hook up to off-the-shelf backbones. We almost
never will be able to use those models as-is.
* Port _build_item_loss_parts
* Port forward_item
* Port forward_step
* Port forward_foot
"""
import torch
import kwutil
import einops
import kwarray
from geowatch.utils.util_netharn import MultiLayerPerceptronNd
[docs]
class TaskHeads(torch.nn.ModuleDict):
"""
Experimental feature. Not finished.
Sends features to task specific heads.
"""
def __init__(heads, heads_config, feat_dim=None, classes=None):
"""
Args:
heads_config (str | List):
yaml coercable config containing the user-level configuration
for the heads.
Example:
>>> from geowatch.tasks.fusion.methods.heads import * # NOQA
>>> import ubelt as ub
>>> heads_config = ub.codeblock(
>>> '''
>>> - name: class
>>> type: MultiLayerPerceptron
>>> hidden_channels: 3
>>> classes: auto
>>> loss:
>>> type: dicefocal
>>> gamma: 2.0
>>> head_weight: 1.0
>>> - name: nonlocal_class
>>> type: MultiLayerPerceptron
>>> hidden_channels: 3
>>> classes: auto
>>> loss:
>>> type: dicefocal
>>> gamma: 2.0
>>> head_weight: 1.0
>>> #
>>> # Mirrors the simple FCNHead in torchvision
>>> - name: saliency
>>> type: mlp
>>> hidden_channels: [256]
>>> out_channels: 2
>>> dropout: 0.1
>>> norm: batch
>>> loss:
>>> type: focal
>>> gamma: 2.0
>>> head_weight: 1.0
>>> ''')
>>> classes = ['a', 'b', 'c']
>>> feat_dim = 1024
>>> heads = TaskHeads(heads_config, feat_dim, classes=classes)
>>> print(heads)
"""
super().__init__()
heads_config = kwutil.util_yaml.Yaml.coerce(heads_config)
if isinstance(heads_config, dict):
raise NotImplementedError
task_configs = heads_config
for task_config in task_configs:
task_config = task_config.copy()
head_name = task_config.pop('name')
head_type = task_config.pop('type')
head_weight = task_config.pop('head_weight', 1.0)
if head_weight > 0:
if head_type == 'box':
from geowatch.tasks.fusion.methods.object_head import DetrDecoderForObjectDetection
from transformers import DetrConfig
detr_config = DetrConfig(
d_model=feat_dim,
num_labels=1,
dropout=0.0,
eos_coefficient=1.0,
num_queries=20
)
heads[head_name] = head = DetrDecoderForObjectDetection(
config=detr_config,
d_model=feat_dim,
d_hidden=feat_dim
)
head.head_weight = head_weight
raise NotImplementedError
if 'loss' in task_config:
raise ValueError('DETR-HEAD handles its own loss')
elif head_type == 'nonlocal_clf':
raise NotImplementedError
elif head_type in {'mlp', 'MultiLayerPerceptron'}:
loss_config = task_config.pop('loss', {})
if 'norm' not in task_config:
task_config['norm'] = None
if 'classes' in task_config:
task_classes = task_config.pop('classes')
if task_classes == 'auto':
task_classes = classes
assert task_classes is not None
task_config['out_channels'] = len(task_classes)
heads[head_name] = head = MultiLayerPerceptronNd(
dim=0,
in_channels=feat_dim,
**task_config,
)
head.head_weight = head_weight
head.criterion = _coerce_loss(head.out_channels, loss_config)
elif head_type == 'segmenter':
from geowatch.tasks.fusion.architectures import segmenter_decoder
heads[head_name] = segmenter_decoder.MaskTransformerDecoder(
d_model=feat_dim,
n_layers=task_config['hidden_channels'],
n_cls=task_config['out_channels'],
)
head.head_weight = head_weight
raise NotImplementedError
else:
raise KeyError(head_type)
[docs]
def forward(heads, feats):
logits = {}
probs = {}
with_probs = True
for head_key, head in heads.items():
logits[head_key] = head(feats)
# Convert logits into probabilities for output
if with_probs:
# not a huge fan of modifying the key, but doing it to agree
# with existing code. Might want to refactor later.
probs_key = head_key + '_probs'
criterion_encoding = head.criterion.target_encoding
_logits = logits[head_key].detach()
if criterion_encoding == "onehot":
probs[probs_key] = _logits.sigmoid()
elif criterion_encoding == "index":
probs[probs_key] = _logits.softmax(dim=-1)
else:
raise NotImplementedError
head_outputs = {}
head_outputs['logits'] = logits
if with_probs:
head_outputs['probs'] = probs
return head_outputs
[docs]
def compute_loss(heads, outputs, batch):
head_loss_parts = {}
# Compute criterion loss for each head
resampled_logits = outputs['logits']
for head_key, head_logits in resampled_logits.items():
assert head_key == 'nonlocal_class'
head_truth = batch['nonlocal_class_ohe']
truth_encoding = 'ohe'
head = heads[head_key]
criterion = head.criterion
head_pred_input = einops.rearrange(head_logits, '(b t h w) c -> ' + criterion.logit_shape, t=1, h=1, w=1).contiguous()
if criterion.target_encoding == 'index':
head_true_idxs = head_truth.long()
head_true_input = einops.rearrange(head_true_idxs, '(b t h w) -> ' + criterion.target_shape, t=1, h=1, w=1).contiguous()
elif criterion.target_encoding == 'onehot':
# Note: 1HE is much easier to work with
if truth_encoding == 'index':
head_true_ohe = kwarray.one_hot_embedding(head_truth.long(), criterion.in_channels, dim=-1)
elif truth_encoding == 'ohe':
head_true_ohe = head_truth
else:
raise KeyError(truth_encoding)
head_true_input = einops.rearrange(head_true_ohe, '(b t h w) c -> ' + criterion.target_shape, t=1, h=1, w=1).contiguous()
else:
raise KeyError(criterion.target_encoding)
unreduced_head_loss = criterion(head_pred_input, head_true_input)
# full_head_weight = torch.broadcast_to(head_weights_input, unreduced_head_loss.shape)
# Weighted reduction
# EPS_F32 = 1.1920929e-07
# weighted_head_loss = (full_head_weight * unreduced_head_loss).sum() / (full_head_weight.sum() + EPS_F32)
weighted_head_loss = unreduced_head_loss
global_head_weight = 1
head_loss = global_head_weight * weighted_head_loss
head_loss_parts[head_key] = head_loss
total_loss = sum(t.sum() for t in head_loss_parts.values())
losses = {}
losses['head_loss_parts'] = head_loss_parts
losses['loss'] = total_loss
return losses
def _coerce_loss(out_channels, loss_config):
from geowatch.tasks.fusion.methods.loss import coerce_criterion
# hack
loss_config['type']
loss_config['loss_code'] = loss_config.pop('type')
if 'gamma' in loss_config:
loss_config['focal_gamma'] = loss_config.pop('gamma')
if 'weights' not in loss_config:
# TODO: determine a good way for the heads to be
# notified about changes in class / saliency weighting.
loss_config['weights'] = torch.ones(out_channels)
criterion = coerce_criterion(**loss_config)
return criterion
# class MLPHead(MultiLayerPerceptronNd):
# """
# Unfinished
# Example:
# >>> from geowatch.tasks.fusion.methods.heads import * # NOQA
# >>> head = MLPHead(0, 3, [], 5)
# """
# def __init__(self, dim, in_channels, hidden_channels, out_channels,
# bias=True, dropout=None, norm=None, noli='relu',
# residual=False, loss=None, **kwargs):
# super().__init__(dim=dim, in_channels=in_channels,
# hidden_channels=hidden_channels,
# out_channels=out_channels, bias=bias,
# dropout=dropout, noli=noli, residual=residual,
# **kwargs)
# if loss is None:
# loss = {
# 'type': 'dicefocal',
# 'gamma': 2.0,
# }
# self.criterion = _coerce_loss(self.out_channels, loss)
# @classmethod
# def coerce(cls, config):
# out_channels = config.get('out_channels', None)
# if out_channels is None:
# _config_classes = config.get('classes', None)
# if _config_classes is not None:
# if _config_classes == 'auto':
# raise NotImplementedError
# else:
# out_channels = len(_config_classes)
# ...
# def forward(self, inputs):
# ...