geowatch.tasks.fusion.methods.heads module

# TODO:

  • Generalize task heads and hook them up to off-the-shelf backbones. We will almost never be able to use those models as-is.

  • Port _build_item_loss_parts

  • Port forward_item

  • Port forward_step

  • Port forward_foot

class geowatch.tasks.fusion.methods.heads.TaskHeads(heads_config, feat_dim=None, classes=None)

Bases: ModuleDict

Experimental feature. Not finished.

Sends features to task-specific heads.

Parameters:

heads_config (str | List) – YAML-coercible, user-level configuration for the heads, given as a list with one entry per head.
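
A hypothetical sketch of how a parsed heads config could be resolved against feat_dim and classes. It assumes that `classes: auto` means a head predicts one channel per entry in classes, that every head reads feat_dim input channels, and that a missing head_weight defaults to 1.0. `resolve_heads_config` and the `in_channels` key it writes are illustrative only, not geowatch's API.

```python
# Hypothetical sketch: resolve a parsed heads config into per-head specs.
# Assumptions (not taken from geowatch): "classes: auto" expands to the given
# class list, out_channels = len(classes), in_channels = feat_dim, and
# head_weight defaults to 1.0 when omitted.
from typing import Any, Dict, List, Optional


def resolve_heads_config(
    heads_config: List[Dict[str, Any]],
    feat_dim: int,
    classes: Optional[List[str]] = None,
) -> Dict[str, Dict[str, Any]]:
    """Map each head name to a resolved spec (illustrative helper)."""
    resolved: Dict[str, Dict[str, Any]] = {}
    for head in heads_config:
        spec = dict(head)  # copy so the caller's config is not mutated
        name = spec.pop('name')
        if name in resolved:
            raise ValueError(f'duplicate head name: {name!r}')
        if spec.get('classes') == 'auto':
            if classes is None:
                raise ValueError(f'head {name!r} needs classes but none were given')
            spec['classes'] = list(classes)
            spec['out_channels'] = len(classes)
        spec['in_channels'] = feat_dim
        spec.setdefault('head_weight', 1.0)
        resolved[name] = spec
    return resolved


# Usage with an already-parsed subset of the config from the example below.
config = [
    {'name': 'class', 'type': 'MultiLayerPerceptron', 'hidden_channels': 3,
     'classes': 'auto', 'head_weight': 1.0},
    {'name': 'saliency', 'type': 'mlp', 'hidden_channels': [256],
     'out_channels': 2},
]
specs = resolve_heads_config(config, feat_dim=1024, classes=['a', 'b', 'c'])
print(specs['class']['out_channels'], specs['saliency']['head_weight'])
```

Returning a name-keyed dict mirrors the name-keyed layout of a ModuleDict, so each resolved spec could be turned into a module under the same key.
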

Example

>>> from geowatch.tasks.fusion.methods.heads import *  # NOQA
>>> import ubelt as ub
>>> heads_config = ub.codeblock(
>>>     '''
>>>     - name: class
>>>       type: MultiLayerPerceptron
>>>       hidden_channels: 3
>>>       classes: auto
>>>       loss:
>>>           type: dicefocal
>>>           gamma: 2.0
>>>       head_weight: 1.0
>>>     - name: nonlocal_class
>>>       type: MultiLayerPerceptron
>>>       hidden_channels: 3
>>>       classes: auto
>>>       loss:
>>>           type: dicefocal
>>>           gamma: 2.0
>>>       head_weight: 1.0
>>>     #
>>>     # Mirrors the simple FCNHead in torchvision
>>>     - name: saliency
>>>       type: mlp
>>>       hidden_channels: [256]
>>>       out_channels: 2
>>>       dropout: 0.1
>>>       norm: batch
>>>       loss:
>>>           type: focal
>>>           gamma: 2.0
>>>       head_weight: 1.0
>>>     ''')
>>> classes = ['a', 'b', 'c']
>>> feat_dim = 1024
>>> heads = TaskHeads(heads_config, feat_dim, classes=classes)
>>> print(heads)
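
Several heads in the example set `loss: type: focal` (or `dicefocal`) with `gamma: 2.0`. To illustrate what gamma controls, here is the standard binary focal loss (Lin et al., 2017) for a single prediction. Whether geowatch's loss adds class weighting or other terms is not stated in this section, and `focal_loss` is a hypothetical helper.

```python
# Sketch of the standard binary focal loss for one prediction:
#   FL(p_t) = -(1 - p_t) ** gamma * log(p_t)
# where p_t is the predicted probability of the true class. gamma plays the
# role of the `gamma: 2.0` config entry; geowatch's exact loss may differ.
import math


def focal_loss(prob_true_class: float, gamma: float = 2.0) -> float:
    """Focal loss of a single prediction (illustrative helper)."""
    p_t = min(max(prob_true_class, 1e-7), 1.0)  # clamp to avoid log(0)
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


easy = focal_loss(0.9)  # confident, correct prediction: heavily down-weighted
hard = focal_loss(0.1)  # poor prediction: keeps most of its cross-entropy
print(easy, hard)
```

With gamma = 0 the modulating factor (1 - p_t) ** gamma equals 1 and the loss reduces to ordinary cross-entropy; larger gamma shifts the training signal toward hard examples.
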

forward(feats)
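
A minimal sketch of the dispatch the class description implies, assuming forward applies every head to the same feats and returns a dict keyed by head name. `forward_heads` is a hypothetical helper, and plain callables stand in for the torch modules that the ModuleDict would hold.

```python
# Hypothetical sketch: every head receives the same shared features and its
# output is stored under the head's name. Plain callables stand in for the
# torch modules a ModuleDict would hold; the return layout is an assumption.
from typing import Any, Callable, Dict


def forward_heads(heads: Dict[str, Callable[[Any], Any]], feats: Any) -> Dict[str, Any]:
    """Apply each head to the shared features (illustrative only)."""
    return {name: head(feats) for name, head in heads.items()}


# Toy "heads" operating on a list of per-pixel feature values.
heads = {
    'class': lambda feats: [f * 2 for f in feats],
    'saliency': lambda feats: [f > 0 for f in feats],
}
outputs = forward_heads(heads, [-1.0, 0.5, 2.0])
print(outputs)  # {'class': [-2.0, 1.0, 4.0], 'saliency': [False, True, True]}
```

Keying outputs by head name lets a loss step look up each head's prediction alongside that head's config entry.
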
compute_loss(outputs, batch)
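
The config gives each head a head_weight. A hypothetical sketch of one way such weights could be applied, assuming each head's loss is scaled by its head_weight and the results are summed. `combine_head_losses` is not a geowatch function, and this aggregation rule is an assumption rather than the documented behavior of compute_loss.

```python
# Hypothetical sketch: weight each head's loss by its head_weight and sum.
# The aggregation rule and the 1.0 default weight are assumptions.
from typing import Dict, Mapping, Tuple


def combine_head_losses(
    head_losses: Mapping[str, float],
    head_weights: Mapping[str, float],
) -> Tuple[float, Dict[str, float]]:
    """Return the weighted total and the per-head weighted parts."""
    # Heads without an explicit weight default to 1.0 (assumption).
    parts = {name: head_weights.get(name, 1.0) * loss
             for name, loss in head_losses.items()}
    total = sum(parts.values())
    return total, parts


total, parts = combine_head_losses(
    {'class': 0.75, 'saliency': 0.5},
    {'class': 1.0, 'saliency': 0.5},
)
print(total, parts)  # 1.0 {'class': 0.75, 'saliency': 0.25}
```

Returning the per-head parts along with the total keeps individual head losses available for logging.
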