import torch
# [docs]
def coerce_criterion(loss_code, weights, ohem_ratio=None, focal_gamma=2.0,
                     spatial_dims='legacy'):
    """
    Build a loss function and annotate it with the shapes it expects.

    The returned criterion is a regular loss module that has been augmented
    with extra attributes describing the target encoding and the einops-style
    shape patterns of the logits / targets it should be called with.

    Args:
        loss_code (str): The code that corresponds to loss function call.
            One of ['cce', 'bce', 'focal_multiclass', 'focal', 'dicefocal'].

        weights (torch.Tensor): Per class weights. These are forwarded to
            every supported loss, and ``len(weights)`` is recorded as the
            number of input channels, so this must be a sized object.

        ohem_ratio (float): Ratio of hard examples to sample to compute loss.
            Note: Only forwarded to the focal-based losses
            ('focal_multiclass', 'focal', and 'dicefocal').

        focal_gamma (float): Focal loss gamma parameter. Only forwarded to
            the focal-based losses.

        spatial_dims (str):
            A code indicating which spatial dimension we are expecting in this
            loss. The "legacy" maintains backwards compat with the multimodal
            transformer, and resolves to 't h w' for the flattened losses
            ('cce', 'bce', 'focal_multiclass') and to 'h w t' for the
            channel-first losses ('focal', 'dicefocal'). For spacetime
            segmentation this should usually be 't h w'. For nonlocal it
            should be ''.

    Raises:
        NotImplementedError: if loss_code is not recognized.

    Returns:
        torch.nn.modules.loss._Loss: The loss function.
            The loss criterion will contain variables:
                target_encoding: which is either index or onehot
                logit_shape: the expected shape of the predicted logits.
                target_shape: the expected shape of the truth targets.
                in_channels: the number of classes, i.e. ``len(weights)``.
    """
    # import monai
    if loss_code == 'cce':
        criterion = torch.nn.CrossEntropyLoss(
            weight=weights, reduction='none')
        criterion.target_encoding = 'index'
        if spatial_dims == 'legacy':
            spatial_dims = 't h w'
        criterion.logit_shape = f'(b {spatial_dims}) c'
        criterion.target_shape = f'(b {spatial_dims})'
    elif loss_code == 'bce':
        criterion = torch.nn.BCELoss(
            weight=weights, reduction='none')
        criterion.target_encoding = 'onehot'
        if spatial_dims == 'legacy':
            spatial_dims = 't h w'
        # Both patterns must interpolate spatial_dims; a plain (non-f) string
        # here would leak the literal "{spatial_dims}" into the pattern.
        criterion.logit_shape = f'(b {spatial_dims}) c'
        criterion.target_shape = f'(b {spatial_dims}) c'
    elif loss_code == 'focal_multiclass':
        # Imported lazily so the torch-only losses work without geowatch.
        from geowatch.utils.ext_monai import FocalLoss
        # from monai.losses import FocalLoss
        criterion = FocalLoss(
            reduction='none',
            to_onehot_y=True,
            weight=weights,
            ohem_ratio=ohem_ratio,
            gamma=focal_gamma)
        criterion.target_encoding = 'index'
        if spatial_dims == 'legacy':
            spatial_dims = 't h w'
        criterion.logit_shape = f'(b {spatial_dims}) c'
        criterion.target_shape = f'(b {spatial_dims})'
    elif loss_code == 'focal':
        from geowatch.utils.ext_monai import FocalLoss
        # from monai.losses import FocalLoss
        criterion = FocalLoss(
            reduction='none',
            to_onehot_y=False,
            weight=weights,
            ohem_ratio=ohem_ratio,
            gamma=focal_gamma)
        criterion.target_encoding = 'onehot'
        # Note: the legacy ordering for the channel-first losses is 'h w t',
        # which differs from the flattened losses above.
        if spatial_dims == 'legacy':
            spatial_dims = 'h w t'
        criterion.logit_shape = f'b c {spatial_dims}'
        criterion.target_shape = f'b c {spatial_dims}'
    elif loss_code == 'dicefocal':
        from geowatch.utils.ext_monai import DiceFocalLoss
        # from monai.losses import DiceFocalLoss
        criterion = DiceFocalLoss(
            sigmoid=True,
            to_onehot_y=False,
            focal_weight=weights,
            reduction='none',
            ohem_ratio_focal=ohem_ratio,
            gamma=focal_gamma)
        criterion.target_encoding = 'onehot'
        if spatial_dims == 'legacy':
            spatial_dims = 'h w t'
        criterion.logit_shape = f'b c {spatial_dims}'
        criterion.target_shape = f'b c {spatial_dims}'
    else:
        # self.class_criterion = nn.CrossEntropyLoss()
        # self.class_criterion = nn.BCEWithLogitsLoss()
        raise NotImplementedError(loss_code)

    # One weight per class, so the weight count doubles as the channel count.
    criterion.in_channels = len(weights)
    return criterion