geowatch.tasks.fusion.methods.heads module
TODO:
- Generalize the task heads and hook them up to off-the-shelf backbones; we will almost never be able to use those models as-is.
- Port _build_item_loss_parts
- Port forward_item
- Port forward_step
- Port forward_foot
- class geowatch.tasks.fusion.methods.heads.TaskHeads(heads_config, feat_dim=None, classes=None)
Bases: ModuleDict
Experimental feature. Not finished.
Sends features to task-specific heads.
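Because TaskHeads subclasses ModuleDict, it holds one module per head name. The sketch below illustrates the dispatch idea under that assumption: every head receives the same shared features, and the results come back keyed by head name. It is deliberately torch-free so it runs standalone; the helpers linear_head and send_to_heads are hypothetical and not part of geowatch, and in the real class each head would be a torch module rather than a plain function.

```python
from typing import Callable, Dict, List

Vector = List[float]
Head = Callable[[Vector], Vector]


def linear_head(weights: List[List[float]], bias: Vector) -> Head:
    """Hypothetical toy head computing ``weights @ x + bias`` for one feature vector."""
    def head(x: Vector) -> Vector:
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(weights, bias)]
    return head


def send_to_heads(features: Vector, heads: Dict[str, Head]) -> Dict[str, Vector]:
    """Apply every named head to the same shared features, keyed by head name."""
    return {name: head(features) for name, head in heads.items()}


# Two heads with different output sizes share one 2-dimensional feature vector.
heads = {
    'class': linear_head([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.0, 0.0]),
    'saliency': linear_head([[1.0, -1.0], [-1.0, 1.0]], [0.5, 0.5]),
}
outputs = send_to_heads([2.0, 3.0], heads)
# outputs == {'class': [2.0, 3.0, 5.0], 'saliency': [-0.5, 1.5]}
```

Keying outputs by head name keeps each task's loss computation independent of how many heads are configured.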
- Parameters:
heads_config (str | List) – YAML-coercible config containing the user-level configuration for the heads.
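Since heads_config may be either YAML text or an already-parsed list, a constructor has to normalize it before building heads. The following is a minimal sketch of one way to do that, not geowatch's actual code. It assumes that "classes: auto" (used in the example config) means "inherit the classes passed to TaskHeads" and that a missing out_channels defaults to the number of classes; both are assumptions. The function name coerce_heads_config is hypothetical, and the string path relies on PyYAML's yaml.safe_load.

```python
import copy
from typing import Any, Dict, List, Optional, Sequence, Union


def coerce_heads_config(heads_config: Union[str, List[Dict[str, Any]]],
                        classes: Optional[Sequence[str]] = None) -> List[Dict[str, Any]]:
    """Hypothetical coercion: parse YAML text if needed, then resolve ``classes: auto``.

    ASSUMPTION: ``auto`` means "use the classes passed to TaskHeads", and an
    unspecified ``out_channels`` defaults to the number of classes.
    """
    if isinstance(heads_config, str):
        import yaml  # PyYAML; only needed when the config is given as text
        heads_config = yaml.safe_load(heads_config)
    if not isinstance(heads_config, list):
        raise TypeError('heads_config must be YAML text or a list of head dicts')
    resolved = []
    for head in heads_config:
        head = copy.deepcopy(head)  # never mutate the caller's config
        if head.get('classes') == 'auto':
            if classes is None:
                raise ValueError(
                    f"head {head.get('name')!r} uses classes: auto, "
                    "but no classes were given")
            head['classes'] = list(classes)
            head.setdefault('out_channels', len(classes))
        resolved.append(head)
    return resolved


heads = coerce_heads_config(
    [{'name': 'class', 'classes': 'auto'},
     {'name': 'saliency', 'out_channels': 2}],
    classes=['a', 'b', 'c'])
# heads[0] == {'name': 'class', 'classes': ['a', 'b', 'c'], 'out_channels': 3}
```

Copying each entry keeps the user's config reusable across several TaskHeads instances.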
Example
>>> from geowatch.tasks.fusion.methods.heads import *  # NOQA
>>> import ubelt as ub
>>> heads_config = ub.codeblock(
>>>     '''
>>>     - name: class
>>>       type: MultiLayerPerceptron
>>>       hidden_channels: 3
>>>       classes: auto
>>>       loss:
>>>         type: dicefocal
>>>         gamma: 2.0
>>>       head_weight: 1.0
>>>     - name: nonlocal_class
>>>       type: MultiLayerPerceptron
>>>       hidden_channels: 3
>>>       classes: auto
>>>       loss:
>>>         type: dicefocal
>>>         gamma: 2.0
>>>       head_weight: 1.0
>>>     #
>>>     # Mirrors the simple FCNHead in torchvision
>>>     - name: saliency
>>>       type: mlp
>>>       hidden_channels: [256]
>>>       out_channels: 2
>>>       dropout: 0.1
>>>       norm: batch
>>>       loss:
>>>         type: focal
>>>         gamma: 2.0
>>>       head_weight: 1.0
>>>     ''')
>>> classes = ['a', 'b', 'c']
>>> feat_dim = 1024
>>> heads = TaskHeads(heads_config, feat_dim, classes=classes)
>>> print(heads)
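The saliency head in the example is described as mirroring torchvision's simple FCNHead. That head uses a hidden width of in_channels // 4 followed by batch norm, ReLU, and dropout of 0.1 before the final projection, so with feat_dim = 1024 the hidden width is 256, matching hidden_channels: [256]. The sketch below makes that arithmetic explicit. The layer ordering in mlp_layer_plan is an assumption modeled on FCNHead (which uses a 3x3 convolution for its first layer rather than a per-feature linear map), not a description of geowatch's MultiLayerPerceptron, and both helper names are hypothetical.

```python
from typing import List, Sequence, Tuple


def fcn_head_inter_channels(in_channels: int) -> int:
    """Hidden width used by torchvision's FCNHead: in_channels // 4."""
    return in_channels // 4


def mlp_layer_plan(in_channels: int, hidden_channels: Sequence[int],
                   out_channels: int, dropout: float = 0.0,
                   norm: str = 'batch') -> List[Tuple]:
    """Hypothetical layout: (linear, norm, relu, dropout) per hidden layer,
    then a final linear projection to ``out_channels``."""
    plan: List[Tuple] = []
    prev = in_channels
    for width in hidden_channels:
        plan.append(('linear', prev, width))
        if norm:
            plan.append((norm + '_norm', width))
        plan.append(('relu',))
        if dropout > 0:
            plan.append(('dropout', dropout))
        prev = width
    plan.append(('linear', prev, out_channels))
    return plan


# Settings from the saliency head in the example, with feat_dim = 1024.
plan = mlp_layer_plan(1024, [256], 2, dropout=0.1, norm='batch')
# plan == [('linear', 1024, 256), ('batch_norm', 256), ('relu',),
#          ('dropout', 0.1), ('linear', 256, 2)]
assert plan[0][2] == fcn_head_inter_channels(1024)
```

If feat_dim changes, the FCNHead-equivalent hidden width changes with it, so hidden_channels: [256] only reproduces FCNHead exactly when feat_dim is 1024.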