geowatch.tasks.fusion.methods.loss module

geowatch.tasks.fusion.methods.loss.coerce_criterion(loss_code, weights, ohem_ratio=None, focal_gamma=2.0, spatial_dims='legacy')

Builds a loss function and augments the returned criterion with extra attributes describing the target encoding and the logit and target shapes that the specific loss expects.

Parameters:
  • loss_code (str) – The code identifying which loss function to build. One of [‘cce’, ‘focal’, ‘dicefocal’].

  • weights (torch.Tensor) – Per-class weights. Note: only used for ‘cce’ and ‘focal’ losses.

  • ohem_ratio (float) – Fraction of hard examples to sample when computing the loss (online hard example mining). Note: only applies to focal losses.

  • focal_gamma (float) – Focal loss gamma parameter.

  • spatial_dims (str) – A code indicating which spatial dimensions this loss expects. The value ‘legacy’ maintains backwards compatibility with the multimodal transformer. For spacetime segmentation this should usually be ‘t h w’; for nonlocal predictions it should be ‘’ (no spatial dimensions).
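The interaction of weights, focal_gamma and ohem_ratio can be illustrated with a small dependency-free sketch. It assumes the standard focal-loss form, in which each element contributes -w_c * (1 - p_c)**gamma * log(p_c), where p_c is the softmax probability of the element's true class c. It also assumes that OHEM keeps the hardest ohem_ratio fraction of per-element losses and averages them. The helper names (softmax, focal_loss_terms, ohem_mean) are hypothetical; the actual geowatch criteria operate on torch tensors and may reduce differently.

```python
import math
from typing import List, Optional, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one vector of class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def focal_loss_terms(logits_per_elem: Sequence[Sequence[float]],
                     targets: Sequence[int],
                     weights: Optional[Sequence[float]] = None,
                     gamma: float = 2.0) -> List[float]:
    """Per-element weighted focal loss: -w[c] * (1 - p_c)**gamma * log(p_c)."""
    losses = []
    for logits, c in zip(logits_per_elem, targets):
        p_c = softmax(logits)[c]  # predicted probability of the true class
        w_c = 1.0 if weights is None else weights[c]  # per-class weight
        losses.append(-w_c * (1.0 - p_c) ** gamma * math.log(p_c))
    return losses


def ohem_mean(losses: Sequence[float], ohem_ratio: Optional[float] = None) -> float:
    """Average all losses, or only the hardest ohem_ratio fraction (at least one)."""
    if ohem_ratio is None:
        return sum(losses) / len(losses)
    k = max(1, int(len(losses) * ohem_ratio))
    hardest = sorted(losses, reverse=True)[:k]  # largest losses are the "hard" ones
    return sum(hardest) / k


# Example: two elements, three classes, with class 2 up-weighted.
losses = focal_loss_terms([[2.0, 0.5, 0.1], [0.2, 0.3, 0.1]], [0, 2],
                          weights=[1.0, 1.0, 3.0], gamma=2.0)
print(losses, ohem_mean(losses, ohem_ratio=0.5))
```

Setting gamma to 0 with no weights reduces each term to ordinary cross-entropy, which makes a convenient sanity check; larger gamma values shrink the contribution of elements the model already classifies confidently.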

Raises:

KeyError – if loss_code is not recognized.

Returns:

The loss function.

The returned criterion has these additional attributes:

  • target_encoding – the encoding of the truth targets, either index or onehot.

  • logit_shape – the expected shape of the predicted logits.

  • target_shape – the expected shape of the truth targets.
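The target_encoding attribute determines how truth labels must be laid out before they are passed to the criterion. The plain-Python sketch below converts integer class ids into either encoding and derives logit and target shape strings from spatial_dims. The helpers encode_targets and shape_patterns, and the einops-style pattern format ('b' for batch, 'c' for classes), are assumptions made for this illustration, not values taken from geowatch.

```python
from typing import List, Sequence, Tuple, Union

Target = Union[int, List[float]]


def encode_targets(class_ids: Sequence[int], num_classes: int,
                   target_encoding: str) -> List[Target]:
    """Convert integer class ids into the target encoding a criterion expects."""
    if target_encoding == 'index':
        # Index encoding: one integer class id per element.
        return list(class_ids)
    if target_encoding == 'onehot':
        # One-hot encoding: a length-num_classes indicator vector per element.
        return [[1.0 if k == c else 0.0 for k in range(num_classes)]
                for c in class_ids]
    raise KeyError(f'unknown target_encoding: {target_encoding!r}')


def shape_patterns(spatial_dims: str, target_encoding: str) -> Tuple[str, str]:
    """Hypothetical (logit_shape, target_shape) pattern strings.

    'b' is the batch axis and 'c' the class axis; the spatial_dims tokens are
    placed between them. Illustrative only, not geowatch's actual values.
    """
    dims = ' '.join(['b'] + spatial_dims.split())
    logit_shape = f'{dims} c'
    # Onehot targets keep the class axis; index targets drop it.
    target_shape = logit_shape if target_encoding == 'onehot' else dims
    return logit_shape, target_shape


print(encode_targets([2, 0], num_classes=3, target_encoding='onehot'))
# -> [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
print(shape_patterns('t h w', 'index'))
# -> ('b t h w c', 'b t h w')
```

Under these assumptions, an index-encoded target lacks the class axis that the logits carry, while a onehot target keeps it. A caller should therefore check target_encoding before building its truth tensor.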

Return type:

torch.nn.modules.loss._Loss