Source code for geowatch.tasks.fusion.methods.network_modules

"""
This module consists of smaller modular network components and should
eventually be reorganized into ``architectures``.

Ignore:
    import liberator
    lib = liberator.Liberator()
    from timm.models.layers import drop_path
    lib.add_dynamic(drop_path)
    lib.expand(['timm'])
    print(lib.current_sourcecode())
"""
import torch
from torch.nn.modules.container import Module
from torch._jit_internal import _copy_to_script_wrapper
import einops
import numpy as np
from torch import nn
from geowatch.utils import util_netharn

from geowatch.tasks.fusion.methods.loss import coerce_criterion  # backwards compat  # NOQA


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc
    networks, however, the original name is misleading as 'Drop Connect' is a
    different form of dropout in a separate paper... See discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    I've opted for changing the layer and argument names to 'drop path' rather
    than mix DropConnect as a layer name and use 'survival rate' as the
    argument.

    From: from timm.models.layers import drop_path
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output
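

# Illustrative usage (a minimal sketch added for clarity; tensor sizes are
# arbitrary): drop_path keeps each sample with probability 1 - drop_prob and
# rescales survivors, so shapes are preserved and eval-mode calls are a no-op.
#
#   x = torch.rand(4, 3, 8, 8)
#   out = drop_path(x, drop_prob=0.2, training=True)
#   assert out.shape == x.shape
#   assert drop_path(x, drop_prob=0.2, training=False) is x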


class RobustModuleDict(torch.nn.ModuleDict):
    """
    Regular torch.nn.ModuleDict doesn't allow empty str. Hack around this.

    Example:
        >>> from geowatch.tasks.fusion.methods.network_modules import *  # NOQA
        >>> import string
        >>> torch_dict = RobustModuleDict()
        >>> # All printable characters should be usable as keys
        >>> # If they are not, hack it.
        >>> failed = []
        >>> for c in list(string.printable) + ['']:
        >>>     try:
        >>>         torch_dict[c] = torch.nn.Linear(1, 1)
        >>>     except KeyError:
        >>>         failed.append(c)
        >>> assert len(failed) == 0
    """
    repl_dot = '#D#'
    repl_empty = '__EMPTY'

    def _normalize_key(self, key):
        if key is None:
            return '*'  # HACK
        key = self.repl_empty if key == '' else key.replace('.', self.repl_dot)
        return key

    @classmethod
    def _unnormalize_key(cls, key):
        if key == cls.repl_empty:
            return ''
        else:
            return key.replace(cls.repl_dot, '.')

    @_copy_to_script_wrapper
    def __getitem__(self, key: str) -> Module:
        key = self._normalize_key(key)
        return self._modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        key = self._normalize_key(key)
        self.add_module(key, module)

    def __delitem__(self, key: str) -> None:
        key = self._normalize_key(key)
        del self._modules[key]

    @_copy_to_script_wrapper
    def __contains__(self, key: str) -> bool:
        key = self._normalize_key(key)
        return key in self._modules

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Args:
            key (string): key to pop from the ModuleDict
        """
        key = self._normalize_key(key)
        v = self[key]
        del self[key]
        return v
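

# Illustrative usage (a minimal sketch; keys are hypothetical): dotted and
# empty keys are transparently remapped, and _unnormalize_key recovers the
# original spelling.
#
#   mdict = RobustModuleDict()
#   mdict['sensor.band'] = torch.nn.Linear(1, 1)
#   mdict[''] = torch.nn.Linear(1, 1)
#   assert 'sensor.band' in mdict and '' in mdict
#   assert RobustModuleDict._unnormalize_key('sensor#D#band') == 'sensor.band'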


class RobustParameterDict(torch.nn.ParameterDict):
    """
    Regular torch.nn.ParameterDict doesn't allow empty str. Hack around this.

    Example:
        >>> from geowatch.tasks.fusion.methods.network_modules import *  # NOQA
        >>> import string
        >>> torch_dict = RobustParameterDict()
        >>> # All printable characters should be usable as keys
        >>> # If they are not, hack it.
        >>> failed = []
        >>> for c in list(string.printable) + ['']:
        >>>     try:
        >>>         torch_dict[c] = torch.nn.Parameter(torch.ones((1, 1)))
        >>>     except KeyError:
        >>>         failed.append(c)
        >>> assert len(failed) == 0
        >>> for v in torch_dict.values():
        >>>     assert list(v.shape) == [1, 1]
    """
    repl_dot = '#D#'
    repl_empty = '__EMPTY'

    def _normalize_key(self, key):
        key = self.repl_empty if key == '' else key.replace('.', self.repl_dot)
        return key

    @classmethod
    def _unnormalize_key(cls, key):
        if key == cls.repl_empty:
            return ''
        else:
            return key.replace(cls.repl_dot, '.')

    def __getitem__(self, key: str) -> Module:
        key = self._normalize_key(key)
        return super().__getitem__(key)

    def __setitem__(self, key: str, value) -> None:
        key = self._normalize_key(key)
        super().__setitem__(key, value)

    def __delitem__(self, key: str) -> None:
        key = self._normalize_key(key)
        super().__delitem__(key)

    def __contains__(self, key: str) -> bool:
        key = self._normalize_key(key)
        return super().__contains__(key)

    def pop(self, key: str) -> Module:
        key = self._normalize_key(key)
        return super().pop(key)
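

# Illustrative usage (a minimal sketch; keys are hypothetical): the same key
# remapping as RobustModuleDict, but values are registered parameters.
#
#   pdict = RobustParameterDict()
#   pdict['sensor.band'] = torch.nn.Parameter(torch.zeros(3))
#   assert 'sensor.band' in pdict
#   assert isinstance(pdict['sensor.band'], torch.nn.Parameter)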


class OurDepthwiseSeparableConv(nn.Module):
    """ DepthwiseSeparable block

    Used for DS convs in MobileNet-V1 and in the place of IR blocks that have
    no expansion (factor of 1.0). This is an alternative to having an IR block
    with an optional first pw conv.

    From timm

    Ignore:
        from geowatch.tasks.fusion.methods.network_modules import *  # NOQA
        tokenizer = nn.Sequential(*[
            OurDepthwiseSeparableConv(in_modes, in_modes, kernel_size=3, stride=1, padding=1, residual=1, norm=None, noli=None),
            OurDepthwiseSeparableConv(in_modes, in_modes * 2, kernel_size=3, stride=2, padding=1, residual=0, norm=None),
            OurDepthwiseSeparableConv(in_modes * 2, in_modes * 4, kernel_size=3, stride=2, padding=1, residual=0),
            OurDepthwiseSeparableConv(in_modes * 4, in_modes * 8, kernel_size=3, stride=2, padding=1, residual=0),
        ])

        tokenizer = nn.Sequential(*[
            OurDepthwiseSeparableConv(in_modes, in_modes, kernel_size=3, stride=1, padding=1, residual=1),
            OurDepthwiseSeparableConv(in_modes, in_modes * 2, kernel_size=3, stride=2, padding=1, residual=0),
            OurDepthwiseSeparableConv(in_modes * 2, in_modes * 4, kernel_size=3, stride=2, padding=1, residual=0),
            OurDepthwiseSeparableConv(in_modes * 4, in_modes * 8, kernel_size=3, stride=2, padding=1, residual=0),
        ])
    """

    def __init__(
            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=1,
            padding=0, residual=False, pw_kernel_size=1, norm='group',
            noli='swish', drop_path_rate=0.):
        super().__init__()
        if norm == 'auto':
            norm = {'type': 'group', 'num_groups': 'auto'}
        self.has_residual = (stride == 1 and in_chs == out_chs) and residual
        self.drop_path_rate = drop_path_rate

        conv_cls = util_netharn.rectify_conv(dim=2)
        # self.conv_dw = create_conv2d(
        #     in_chs, in_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
        self.conv_dw = conv_cls(
            in_chs, in_chs, kernel_size, stride=stride, dilation=dilation,
            padding=padding, groups=in_chs)  # depthwise

        self.bn1 = util_netharn.rectify_normalizer(in_channels=in_chs, key=norm)
        if self.bn1 is None:
            self.bn1 = util_netharn.Identity()
        self.act1 = util_netharn.rectify_nonlinearity(noli)
        if self.act1 is None:
            self.act1 = util_netharn.Identity()

        self.conv_pw = conv_cls(in_chs, out_chs, pw_kernel_size, padding=0)
        # self.bn2 = norm_layer(out_chs)
        self.bn2 = util_netharn.rectify_normalizer(in_channels=out_chs, key=norm)
        if self.bn2 is None:
            self.bn2 = util_netharn.Identity()

    def feature_info(self, location):
        if location == 'expansion':  # after SE, input to PW
            info = dict(module='conv_pw', hook_type='forward_pre',
                        num_chs=self.conv_pw.in_channels)
        else:  # location == 'bottleneck', block output
            info = dict(module='', hook_type='',
                        num_chs=self.conv_pw.out_channels)
        return info

    def forward(self, x):
        shortcut = x
        x = self.conv_dw(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.conv_pw(x)
        x = self.bn2(x)
        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x += shortcut
        return x
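

# Illustrative shape check (a minimal sketch; channel and image sizes are
# hypothetical, and norm/noli are disabled to avoid assumptions about the
# normalizer defaults). The residual path is only active when stride == 1 and
# in_chs == out_chs.
#
#   block = OurDepthwiseSeparableConv(16, 16, kernel_size=3, stride=1,
#                                     padding=1, residual=True,
#                                     norm=None, noli=None)
#   x = torch.rand(2, 16, 32, 32)
#   assert block(x).shape == (2, 16, 32, 32)
#   down = OurDepthwiseSeparableConv(16, 32, kernel_size=3, stride=2,
#                                    padding=1, norm=None, noli=None)
#   assert down(x).shape == (2, 32, 16, 16)  # stride 2 halves H and W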


class DWCNNTokenizer(torch.nn.Sequential):
    """
    self = DWCNNTokenizer(13, 2)
    inputs = torch.rand(2, 13, 16, 16)
    self(inputs)
    """

    def __init__(self, in_chn, out_chn, norm='auto'):
        super().__init__()
        if norm == 'none':
            norm = None
        self.norm = norm
        super().__init__(*[
            OurDepthwiseSeparableConv(in_chn, in_chn, kernel_size=3, stride=1, padding=1, residual=1, norm=None, noli=None),
            OurDepthwiseSeparableConv(in_chn, in_chn * 4, kernel_size=3, stride=2, padding=1, residual=0, norm=norm),
            OurDepthwiseSeparableConv(in_chn * 4, in_chn * 8, kernel_size=3, stride=2, padding=1, residual=0, norm=norm),
            OurDepthwiseSeparableConv(in_chn * 8, out_chn, kernel_size=3, stride=2, padding=1, residual=0, norm=norm),
        ])
        self.in_channels = in_chn
        self.out_channels = out_chn
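

# Illustrative shape check (a minimal sketch mirroring the docstring example;
# assumes the default 'auto' group-norm key resolves for these channel counts):
# the four stages use strides 1, 2, 2, 2, so spatial dims shrink by 8x.
#
#   self = DWCNNTokenizer(13, 2)
#   inputs = torch.rand(2, 13, 16, 16)
#   assert self(inputs).shape == (2, 2, 2, 2)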


class LinearConvTokenizer(torch.nn.Sequential):
    """
    Example:
        >>> from geowatch.tasks.fusion.methods.network_modules import *  # NOQA
        >>> LinearConvTokenizer(1, 512)
    """

    def __init__(self, in_channels, out_channels):
        # import math
        c1 = in_channels * 1
        c2 = in_channels * 4
        c3 = in_channels * 16
        c4 = in_channels * 8
        # final_groups = math.gcd(104, out_channels)
        final_groups = 1

        super().__init__(
            util_netharn.ConvNormNd(
                dim=2, in_channels=c1, out_channels=c2, groups=c1, norm=None,
                noli=None, kernel_size=3, stride=2, padding=1,
            ).conv,
            util_netharn.ConvNormNd(
                dim=2, in_channels=c2, out_channels=c3, groups=c2, norm=None,
                noli=None, kernel_size=3, stride=2, padding=1,
            ).conv,
            util_netharn.ConvNormNd(
                dim=2, in_channels=c3, out_channels=c4, groups=min(c3, c4),
                norm=None, noli=None, kernel_size=3, stride=2, padding=1,
            ).conv,
            util_netharn.ConvNormNd(
                dim=2, in_channels=c4, out_channels=out_channels,
                groups=final_groups, norm=None, noli=None, kernel_size=1,
                stride=1, padding=0,
            ).conv,
        )
        self.in_channels = in_channels
        self.out_channels = out_channels
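

# Illustrative shape check (a minimal sketch; sizes are hypothetical, assuming
# ConvNormNd(...).conv is a plain Conv2d): three stride-2 stages reduce the
# spatial dims by 8x, and the final 1x1 conv maps to out_channels.
#
#   self = LinearConvTokenizer(13, 512)
#   assert self(torch.rand(1, 13, 64, 64)).shape == (1, 512, 8, 8)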


class ConvTokenizer(nn.Module):
    """
    Example:
        from geowatch.tasks.fusion.methods.network_modules import *  # NOQA
        self = ConvTokenizer(13, 64)
        print('self = {!r}'.format(self))
        inputs = torch.rand(2, 13, 128, 128)
        tokens = self(inputs)
        print('inputs.shape = {!r}'.format(inputs.shape))
        print('tokens.shape = {!r}'.format(tokens.shape))

    Benchmark:
        in_channels = 13
        tokenizer1 = ConvTokenizer(in_channels, 512)
        tokenizer2 = RearrangeTokenizer(in_channels, 8, 8)
        tokenizer3 = DWCNNTokenizer(in_channels, 512)
        tokenizer4 = LinearConvTokenizer(in_channels, 512)
        print(util_netharn.number_of_parameters(tokenizer1))
        print(util_netharn.number_of_parameters(tokenizer2))
        print(util_netharn.number_of_parameters(tokenizer3))
        print(util_netharn.number_of_parameters(tokenizer4))

        print(util_netharn.number_of_parameters(tokenizer4[0]))
        print(util_netharn.number_of_parameters(tokenizer4[1]))
        print(util_netharn.number_of_parameters(tokenizer4[2]))
        print(util_netharn.number_of_parameters(tokenizer4[3]))

        inputs = torch.rand(1, in_channels, 128, 128)

        import timerit
        ti = timerit.Timerit(100, bestof=1, verbose=2)

        tokenizer1(inputs).shape
        tokenizer2(inputs).shape

        for timer in ti.reset('tokenizer1'):
            with timer:
                tokenizer1(inputs)

        for timer in ti.reset('tokenizer2'):
            with timer:
                tokenizer2(inputs)

        for timer in ti.reset('tokenizer3'):
            with timer:
                tokenizer3(inputs)

        for timer in ti.reset('tokenizer4'):
            with timer:
                tokenizer4(inputs)

        input_shape = (1, in_channels, 64, 64)
        print(tokenizer2(torch.rand(*input_shape)).shape)

        downsampler1 = torch.nn.Sequential(*[
            util_netharn.ConvNormNd(
                dim=2, in_channels=in_channels, out_channels=in_channels,
                groups=in_channels, norm=None, noli=None, kernel_size=3,
                stride=2, padding=1,
            ),
            util_netharn.ConvNormNd(
                dim=2, in_channels=in_channels, out_channels=in_channels,
                groups=in_channels, norm=None, noli=None, kernel_size=3,
                stride=2, padding=1,
            ),
            util_netharn.ConvNormNd(
                dim=2, in_channels=in_channels, out_channels=in_channels,
                groups=in_channels, norm=None, noli=None, kernel_size=3,
                stride=2, padding=1,
            ),
        ])

        downsampler2 = torch.nn.Sequential(*[
            util_netharn.ConvNormNd(
                dim=2, in_channels=in_channels, out_channels=in_channels,
                groups=in_channels, norm=None, noli=None, kernel_size=7,
                stride=5, padding=3,
            ),
        ])
        print(ub.urepr(downsampler1.output_shape_for(input_shape).hidden.shallow(30), nl=1))
        print(ub.urepr(downsampler2.output_shape_for(input_shape).hidden.shallow(30), nl=1))
    """

    def __init__(self, in_chn, out_chn, norm=None):
        super().__init__()
        self.down = util_netharn.ConvNormNd(
            dim=2, in_channels=in_chn, out_channels=in_chn, groups=in_chn,
            norm=norm, noli=None, kernel_size=7, stride=5, padding=3,
        )
        self.one_by_one = util_netharn.ConvNormNd(
            dim=2, in_channels=in_chn, out_channels=out_chn, groups=1,
            norm=norm, noli=None, kernel_size=1, stride=1, padding=0,
        )
        self.out_channels = out_chn

    def forward(self, inputs):
        # b, t, c, h, w = inputs.shape
        b, c, h, w = inputs.shape
        # inputs2d = einops.rearrange(inputs, 'b t c h w -> (b t) c h w')
        inputs2d = inputs
        tokens2d = self.down(inputs2d)
        tokens2d = self.one_by_one(tokens2d)
        tokens = tokens2d
        # tokens = einops.rearrange(tokens2d, '(b t) c h w -> b t c h w 1', b=b, t=t)
        return tokens
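

# Illustrative shape check (a minimal sketch; sizes are hypothetical, assuming
# the standard Conv2d output-size formula): with kernel 7, stride 5, padding 3,
# a 128x128 input maps to floor((128 + 6 - 7) / 5) + 1 == 26 per side.
#
#   self = ConvTokenizer(13, 64)
#   tokens = self(torch.rand(2, 13, 128, 128))
#   assert tokens.shape == (2, 64, 26, 26)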


class RearrangeTokenizer(nn.Module):
    """
    A mapping to a common number of channels followed by a rearrange.

    Not quite a pure rearrange, but kept this way for backwards compat.
    """

    def __init__(self, in_channels, agree, window_size):
        super().__init__()
        self.window_size = window_size
        self.foot = util_netharn.MultiLayerPerceptronNd(
            dim=2, in_channels=in_channels, hidden_channels=3,
            out_channels=agree, residual=True, norm=None)
        self.out_channels = agree * window_size * window_size

    def forward(self, x):
        mixed_mode = self.foot(x)
        ws = self.window_size
        # HACK: reorganize and fix
        mode_vals_tokens = einops.rearrange(
            mixed_mode, 'b c (h hs) (w ws) -> b (ws hs c) h w', hs=ws, ws=ws)
        return mode_vals_tokens
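

# Illustrative shape check (a minimal sketch; sizes are hypothetical, assuming
# the MLP foot preserves spatial dims): the foot maps to ``agree`` channels,
# then each window_size x window_size window is folded into channels, so
# out_channels == agree * window_size ** 2.
#
#   self = RearrangeTokenizer(in_channels=13, agree=8, window_size=8)
#   tokens = self(torch.rand(2, 13, 64, 64))
#   assert tokens.shape == (2, 512, 8, 8)
#   assert self.out_channels == 512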


def _torch_meshgrid(*basis_dims):
    """
    References:
        https://zhaoyu.li/post/how-to-implement-meshgrid-in-pytorch/
    """
    basis_lens = list(map(len, basis_dims))
    new_dims = []
    for i, basis in enumerate(basis_dims):
        # Probably a more efficient way to do this, but it's right
        newshape = [1] * len(basis_dims)
        reps = list(basis_lens)
        newshape[i] = -1
        reps[i] = 1
        dd = basis.view(*newshape).repeat(*reps)
        new_dims.append(dd)
    return new_dims


def _class_weights_from_freq(total_freq, mode='median-idf'):
    """
    Example:
        >>> from geowatch.tasks.fusion.methods.network_modules import _class_weights_from_freq
        >>> total_freq = np.array([19503736, 92885, 883379, 0, 0])
        >>> print(_class_weights_from_freq(total_freq, mode='idf'))
        >>> print(_class_weights_from_freq(total_freq, mode='median-idf'))
        >>> print(_class_weights_from_freq(total_freq, mode='log-median-idf'))

        >>> total_freq = np.array([19503736, 92885, 883379, 0, 0, 0, 0, 0, 0, 0, 0])
        >>> print(_class_weights_from_freq(total_freq, mode='idf'))
        >>> print(_class_weights_from_freq(total_freq, mode='median-idf'))
        >>> print(_class_weights_from_freq(total_freq, mode='log-median-idf'))
    """
    def logb(arr, base):
        if base == 'e':
            return np.log(arr)
        elif base == 2:
            return np.log2(arr)
        elif base == 10:
            return np.log10(arr)
        else:
            out = np.log(arr)
            out /= np.log(base)
            return out

    freq = total_freq.copy()
    is_natural = (total_freq > 0) & np.isfinite(total_freq)
    natural_freq = freq[is_natural]
    mask = is_natural.copy()

    if len(natural_freq):
        _min, _max = np.quantile(natural_freq, [0.05, 0.95])
        is_robust = (_max >= freq) & (freq >= _min)
        if np.any(is_robust):
            middle_value = np.median(freq[is_robust])
        else:
            middle_value = np.median(natural_freq)
        freq[~is_natural] = natural_freq.min() / 2
    else:
        middle_value = 2

    # variant of median-inverse-frequency
    if mode == 'idf':
        # There is no difference between this and median after reweighting
        weights = (1 / freq)
        mask &= np.isfinite(weights)
    elif mode == 'name-me':
        z = freq[mask]
        a = ((1 - np.eye(len(z))) * z[:, None]).sum(axis=0)
        b = a / z
        c = b / b.max()
        weights = np.zeros(len(freq))
        weights[mask] = c
    elif mode == 'median-idf':
        weights = (middle_value / freq)
        mask &= np.isfinite(weights)
    elif mode == 'log-median-idf':
        weights = (middle_value / freq)
        mask &= np.isfinite(weights)
        weights[~np.isfinite(weights)] = 1.0
        base = 2
        base = np.exp(1)
        weights = logb(weights + (base - 1), base)
        weights = np.maximum(weights, .1)
        weights = np.minimum(weights, 10)
    else:
        raise KeyError('mode = {!r}'.format(mode))

    # unseen classes should probably get a reasonably high weight in case we do
    # see them and need to learn them, but my intuition is to give them
    # less weight than things we have a shot of learning well
    # so they don't mess up the main categories
    natural_weights = weights[mask]
    if len(natural_weights):
        denom = natural_weights.max()
    else:
        denom = 1
    weights[mask] = weights[mask] / denom
    if np.any(mask):
        weights[~mask] = weights[mask].max() / 7
    else:
        weights[~mask] = 1e-1
    weights = np.round(weights, 6)
    return weights
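

# Illustrative usage of _torch_meshgrid (a minimal sketch; the basis tensors
# are arbitrary): each returned tensor is broadcast to the full grid shape.
#
#   ys, xs = _torch_meshgrid(torch.arange(3), torch.arange(2))
#   assert ys.shape == xs.shape == (3, 2)
#   assert ys[:, 0].tolist() == [0, 1, 2]
#   assert xs[0, :].tolist() == [0, 1]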


def torch_safe_stack(tensors, dim=0, *, out=None, item_shape=None,
                     dtype=None, device=None):
    """
    Behaves like torch.stack, but does not error when tensors is empty.

    When tensors are not empty this is exactly :func:`torch.stack`. When
    tensors are empty, it constructs an empty output tensor based on the
    explicit expected item shape if available, otherwise it assumes items
    would have had a shape of ``[0]``. Likewise dtype and device should be
    specified, otherwise they use :func:`torch.empty` defaults.

    Args:
        tensors (List[Tensor]): sequence of tensors to concatenate.
            Passed to :func:`torch.stack`.

        dim (int): dimension to insert. Has to be between 0 and the number of
            dimensions of concatenated tensors (inclusive).
            Passed to :func:`torch.stack`.

        out (Tensor): passed to :func:`torch.stack`.

        item_shape (Tuple[int, ...]): what the shape of an item should be.
            Used to construct a default output.

        dtype (torch.dtype): the expected output datatype when tensors is empty.

        device (torch.device | str | int | None): the expected output device
            when tensors is empty.

    Example:
        >>> from geowatch.tasks.fusion.methods.network_modules import *  # NOQA
        >>> import ubelt as ub
        >>> grid = list(ub.named_product({
        >>>     # 'num': [0, 1, 2, 3],
        >>>     'num': [0, 7],
        >>>     'item_shape': ['auto', None],
        >>>     'shape': [[], [0], [2], [2, 3], [2, 0, 3]],
        >>>     'dim': [0, 1],
        >>> }))
        >>> results = []
        >>> for item in grid:
        >>>     print(f'item={item}')
        >>>     dim = item['dim']
        >>>     shape = item['shape']
        >>>     item['shape'] = tuple(item['shape'])
        >>>     if item['item_shape'] == 'auto':
        >>>         item['item_shape'] = item['shape']
        >>>     num = item['num']
        >>>     tensors = [torch.empty(shape)] * num
        >>>     if dim >= len(shape):
        >>>         continue
        >>>     out = torch_safe_stack(tensors, dim=dim,
        >>>                            item_shape=item['item_shape'])
        >>>     row = {
        >>>         **item,
        >>>         'out.shape': out.shape,
        >>>     }
        >>>     print(f'row={row}')
        >>>     results.append(row)
        >>> import pandas as pd
        >>> import rich
        >>> df = pd.DataFrame(results)
        >>> for _, subdf in df.groupby('shape'):
        >>>     subdf = subdf.sort_values(['shape', 'dim', 'item_shape', 'num'])
        >>>     print('')
        >>>     rich.print(subdf.to_string())
    """
    if len(tensors) == 0:
        if item_shape is None:
            # TODO: WARN HERE, THE USER SHOULD PROVIDE A DEFAULT SHAPE
            # OTHERWISE THE FUNCTION MAY NOT PRODUCE COMPATIBLE RESULTS WITH
            # POPULATED VARIANTS
            item_shape = [0]
        out_shape = list(item_shape)
        if dim > len(out_shape):
            raise IndexError(
                f'Dimension out of range (expected to be in range of '
                f'[-1, {len(out_shape)}], but got {dim})'
            )
        out_shape.insert(dim, 0)
        return torch.empty(out_shape, dtype=dtype, device=device)
    else:
        return torch.stack(tensors, dim=dim, out=out)
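

# Illustrative usage (a minimal sketch; shapes are hypothetical): empty inputs
# fall back to an empty tensor built from item_shape, while non-empty inputs
# behave exactly like torch.stack.
#
#   out = torch_safe_stack([], dim=0, item_shape=(4, 3))
#   assert out.shape == (0, 4, 3)
#   out = torch_safe_stack([torch.zeros(4, 3)] * 2, dim=0)
#   assert out.shape == (2, 4, 3)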