geowatch.tasks.fusion.architectures.transformer module

Current encoder config names are:

‘smt_it_joint_p8’, ‘smt_it_joint_n12’, ‘smt_it_joint_t12’, ‘smt_it_joint_t24’, ‘smt_it_joint_s12’, ‘smt_it_joint_s24’, ‘smt_it_joint_m24’, ‘smt_it_joint_l24’,

‘smt_it_stm_p8’, ‘smt_it_stm_n12’, ‘smt_it_stm_t12’, ‘smt_it_stm_t24’, ‘smt_it_stm_s12’, ‘smt_it_stm_s24’, ‘smt_it_stm_m24’, ‘smt_it_stm_l24’,

‘smt_it_sm_p8’, ‘smt_it_sm_n12’, ‘smt_it_sm_t12’, ‘smt_it_sm_t24’, ‘smt_it_sm_s12’, ‘smt_it_sm_s24’, ‘smt_it_sm_m24’, ‘smt_it_sm_l24’,

‘smt_it_st_p8’, ‘smt_it_st_n12’, ‘smt_it_st_t12’, ‘smt_it_st_t24’, ‘smt_it_st_s12’, ‘smt_it_st_s24’, ‘smt_it_st_m24’, ‘smt_it_st_l24’,

‘smt_it_tm_p8’, ‘smt_it_tm_n12’, ‘smt_it_tm_t12’, ‘smt_it_tm_t24’, ‘smt_it_tm_s12’, ‘smt_it_tm_s24’, ‘smt_it_tm_m24’, ‘smt_it_tm_l24’,

‘smt_it_s_p8’, ‘smt_it_s_n12’, ‘smt_it_s_t12’, ‘smt_it_s_t24’, ‘smt_it_s_s12’, ‘smt_it_s_s24’, ‘smt_it_s_m24’, ‘smt_it_s_l24’,

‘smt_it_t_p8’, ‘smt_it_t_n12’, ‘smt_it_t_t12’, ‘smt_it_t_t24’, ‘smt_it_t_s12’, ‘smt_it_t_s24’, ‘smt_it_t_m24’, ‘smt_it_t_l24’,

‘smt_it_hwtm_p8’, ‘smt_it_hwtm_n12’, ‘smt_it_hwtm_t12’, ‘smt_it_hwtm_t24’, ‘smt_it_hwtm_s12’, ‘smt_it_hwtm_s24’, ‘smt_it_hwtm_m24’, ‘smt_it_hwtm_l24’,

‘smt_it_m_p8’, ‘smt_it_m_n12’, ‘smt_it_m_t12’, ‘smt_it_m_t24’, ‘smt_it_m_s12’, ‘smt_it_m_s24’, ‘smt_it_m_m24’, ‘smt_it_m_l24’,

‘sm_it_joint_p8’, ‘sm_it_joint_n12’, ‘sm_it_joint_t12’, ‘sm_it_joint_t24’, ‘sm_it_joint_s12’, ‘sm_it_joint_s24’, ‘sm_it_joint_m24’, ‘sm_it_joint_l24’,

‘sm_it_sm_p8’, ‘sm_it_sm_n12’, ‘sm_it_sm_t12’, ‘sm_it_sm_t24’, ‘sm_it_sm_s12’, ‘sm_it_sm_s24’, ‘sm_it_sm_m24’, ‘sm_it_sm_l24’
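The names above follow a regular pattern: a backbone family (‘smt_it’ or ‘sm_it’), an axes-grouping code (e.g. ‘joint’, ‘stm’, ‘sm’, ‘st’, ‘tm’, ‘s’, ‘t’, ‘hwtm’, ‘m’), and a size code (e.g. ‘p8’, ‘n12’, ‘t12’, ‘t24’, ‘s12’, ‘s24’, ‘m24’, ‘l24’). The sketch below only illustrates that decomposition; it does not use the module’s actual config-lookup machinery.

    # Illustrative only: split a config name into family, axes code, and size code.
    def parse_encoder_config_name(name):
        parts = name.split('_')
        family = '_'.join(parts[:2])   # e.g. 'smt_it' or 'sm_it'
        axes_code = parts[2]           # e.g. 'joint', 'stm', 'hwtm', ...
        size_code = parts[3]           # e.g. 'p8', 'n12', 't24', 'l24', ...
        return family, axes_code, size_code

    print(parse_encoder_config_name('smt_it_stm_p8'))    # ('smt_it', 'stm', 'p8')
    print(parse_encoder_config_name('sm_it_joint_n12'))  # ('sm_it', 'joint', 'n12')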

Notes

pip install reformer_pytorch
pip install performer-pytorch  <- this one

class geowatch.tasks.fusion.architectures.transformer.ResidualSequential(*args)[source]

Bases: Sequential

A Sequential layer with a residual operation at the end

forward(x)[source]
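A minimal sketch of this behavior, assuming the residual is a plain additive skip connection (the class below is illustrative and not the module’s implementation):

    import torch
    from torch import nn

    class ResidualSequentialSketch(nn.Sequential):
        # Apply the wrapped layers in order, then add the original input back.
        def forward(self, x):
            return super().forward(x) + x

    layer = ResidualSequentialSketch(nn.Linear(4, 4), nn.GELU(), nn.Linear(4, 4))
    y = layer(torch.rand(2, 4))  # output has the same shape as the input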
class geowatch.tasks.fusion.architectures.transformer.ResidualAttentionSequential(norm, attention)[source]

Bases: ResidualSequential

Special case of ResidualSequential to support masking

forward(x, key_padding_mask=None)[source]
geowatch.tasks.fusion.architectures.transformer.assert_allclose(a, b, rtol=1e-05, atol=1e-08)[source]

TODO: integrate with kwcoco.coco_sql_dataset.assert_dsets_allclose().

Add to kwarray

class geowatch.tasks.fusion.architectures.transformer.MultiheadSelfAttention(embed_dim, num_heads, *args, **kwargs)[source]

Bases: MultiheadAttention

Inherits from torch.nn.MultiheadAttention

Parameters:
  • embed_dim – total dimension of the model.

  • num_heads – parallel attention heads.

  • dropout – a Dropout layer on attn_output_weights. Default: 0.0.

  • bias – add bias as module parameter. Default: True.

  • add_bias_kv – add bias to the key and value sequences at dim=0.

  • add_zero_attn – add a new batch of zeros to the key and value sequences at dim=1.

  • kdim – total number of features in key. Default: None.

  • vdim – total number of features in value. Default: None.

  • batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).

CommandLine

xdoctest -m geowatch.tasks.fusion.architectures.transformer MultiheadSelfAttention

Example

>>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
>>> self = MultiheadSelfAttention(4, 1).eval()
>>> S, B, F = (7, 3, 4)
>>> x  = (torch.rand(S, B, F) * 10).round()
>>> # Results should be independent of the batch dim
>>> y  = self.forward(x)
>>> y0 = self.forward(x[:, 0:1, :])
>>> y1 = self.forward(x[:, 1:2, :])
>>> y2 = self.forward(x[:, 2:3, :])
>>> assert_allclose(y[:, 0:1, :], y0, rtol=1e-3, atol=1e-6)
>>> assert_allclose(y[:, 1:2, :], y1, rtol=1e-3, atol=1e-6)
>>> assert_allclose(y[:, 2:3, :], y2, rtol=1e-3, atol=1e-6)
>>> key_padding_mask = torch.rand(B, S) > 0.5
>>> masked_result = self.forward(x, key_padding_mask=key_padding_mask)
forward(x, key_padding_mask=None)[source]
Parameters:
  • x (Tensor) – of shape (seq, batch, feature)

  • key_padding_mask (Tensor) – of shape (batch, seq). A value of True means we will ignore the token.

Returns:

attn_out (Tensor) – of shape (seq, batch, feature)

class geowatch.tasks.fusion.architectures.transformer.MetaModuleProperties(name, bases, namespace, *args, **kwargs)[source]

Bases: type

Experimental way to get concise property-like behavior at the module level.

This defines code that runs whenever the user DEFINES a class that inherits from MetaModuleProperties.

class geowatch.tasks.fusion.architectures.transformer.ModuleProperties[source]

Bases: object

Experimental way to get concise property-like behavior at the module level.

Inherit from this class. This class forwards metaclass magic that allows us to register any function decorated with @property. It also creates the __getattr__ method that will be assigned to the module.

classmethod getattr(name)[source]
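For context, a minimal sketch of the general mechanism this builds on (PEP 562 module-level __getattr__); the module, attribute name, and helper below are hypothetical and not part of this package’s API:

    # Contents of a hypothetical module, e.g. mymodule.py
    def _build_encoder_configs():
        # Pretend this is expensive, so it only runs on first access.
        return {'smt_it_stm_p8': {'n_layers': 8}}

    def __getattr__(name):
        # Called only when normal attribute lookup on the module fails.
        if name == 'encoder_configs':
            return _build_encoder_configs()
        raise AttributeError(name)

    # Elsewhere: `import mymodule; mymodule.encoder_configs` resolves lazily.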
geowatch.tasks.fusion.architectures.transformer.new_attention_layer(embedding_size, n_heads, attention_impl='exact', **kwargs)[source]

Example

>>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
>>> import torch
>>> batch_size = 1
>>> embedding_size = 4
>>> n_heads = 2
>>> num_tokens = 3
>>> input_shape = (num_tokens, batch_size, embedding_size)
>>> inputs = torch.rand(*input_shape)
>>> layer1 = new_attention_layer(embedding_size, n_heads, 'exact')
>>> outputs1 = layer1(inputs)
>>> assert outputs1.shape == input_shape
>>> # xdoctest: +REQUIRES(module:performer_pytorch)
>>> layer2 = new_attention_layer(embedding_size, n_heads, 'performer')
>>> outputs2 = layer2(inputs)
>>> assert outputs2.shape == input_shape

Example

>>> # Test with a mask
>>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
>>> import torch
>>> batch_size = 1
>>> embedding_size = 4
>>> n_heads = 2
>>> num_tokens = 3
>>> input_shape = (num_tokens, batch_size, embedding_size)
>>> inputs = torch.rand(*input_shape)
>>> key_padding_mask = torch.rand(batch_size, num_tokens) > 0.5
>>> layer1 = new_attention_layer(embedding_size, n_heads, 'exact')
>>> outputs1 = layer1(inputs, key_padding_mask=key_padding_mask)
geowatch.tasks.fusion.architectures.transformer.new_mlp_layer(embedding_size, dropout, **kwargs)[source]

Example

>>> import torch
>>> embedding_size = 3
>>> batch_size = 1
>>> layer = new_mlp_layer(embedding_size, dropout=0)
>>> input_shape = (batch_size, embedding_size)
>>> inputs = torch.rand(*input_shape)
>>> outputs = layer(inputs)
>>> assert outputs.shape == (batch_size, embedding_size)
class geowatch.tasks.fusion.architectures.transformer.ChannelwiseTransformerEncoderLayer(axes, embedding_size, n_heads, dropout=0.0, default_shape=('batch', 'time', 'mode', 'height', 'width', 'feature'), feature_axis='feature', batch_axis='batch', attention_impl='exact', attention_kwargs=None)[source]

Bases: Module

Todo

  • [ ] Can we restrict how far the spatial window looks, so it only sees neighboring spatial regions?

  • [ ] Flatten tokens completely and have a mask that indicates which tokens are allowed to see each other in each step.

Notes

  • Currently ‘mode’ might indicate something like a sensor or special computation. Each ‘mode’ might have a different number of ‘features’. In the future this might be better specified as a dictionary that maps ‘mode’-codes to a tensor containing only the ‘features’ for that mode. E.g.:

    inputs = {
        'S2':        Tensor([B, T, H, W, 13]),
        'WV':        Tensor([B, T, H, W, 8]),
        'Latent':    Tensor([B, T, H, W, 512]),
        'Materials': Tensor([B, T, H, W, 16]),
    }

Currently these are all stacked into a B x T x M x H x W x max(F) and padded with zeros.

Correction: the last statement is not correct. Currently F is hard-coded to be F = 1 * ws * ws (where ws is the window size), so the features are really spatial positions in a window, and the ‘width’ and ‘height’ here refer to the number of windows in the area.
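A hedged sketch of that windowing, using einops (an assumption; the actual tokenization code may differ): the pixels inside each ws x ws window become the ‘feature’ axis, while ‘height’ and ‘width’ count windows.

    import torch
    import einops

    B, T, M, ws = 2, 3, 1, 32
    image = torch.rand(B, T, M, 128, 128)   # raw pixels for a single band
    tokens = einops.rearrange(
        image, 'b t m (h hw) (w ww) -> b t m h w (hw ww)', hw=ws, ww=ws)
    print(tokens.shape)   # (2, 3, 1, 4, 4, 1024), i.e. F = 1 * ws * ws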

Example

>>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
>>> import torch
>>> image_size = 128
>>> #
>>> ws = window_size = 32
>>> W = H = image_size // ws
>>> B = batch_size = 2
>>> T = num_times = 3
>>> M = num_modes = 13  # hack for number of features in S2
>>> F = 1 * ws * ws # hack to use spatial positions in a windows as features
>>> input_shape = (B, T, M, H, W, F)
>>> x = torch.rand(*input_shape)
>>> embedding_size = F   # Embedding size must be equal to F
>>> #
>>> # ================================
>>> # Joint Attentions Across all Axes
>>> self = ChannelwiseTransformerEncoderLayer(
>>>     axes=[('time', 'mode', 'height', 'width')],
>>>     default_shape=['batch', 'time', 'mode', 'height', 'width', 'feature'],
>>>     feature_axis='feature',
>>>     batch_axis='batch',
>>>     embedding_size=embedding_size,
>>>     n_heads=4
>>> )
>>> print(self)
>>> outputs = self(x)
>>> assert tuple(outputs.shape) == (2, 3, 13, 4, 4, 1024)
>>> #
>>> # ================================
>>> # Separable Attentions Across Time, Mode, and then Space
>>> self = ChannelwiseTransformerEncoderLayer(
>>>     axes=[('time', 'mode'), ('height', 'width')],
>>>     default_shape=['batch', 'time', 'mode', 'height', 'width', 'feature'],
>>>     feature_axis='feature',
>>>     batch_axis='batch',
>>>     embedding_size=embedding_size,
>>>     n_heads=4
>>> )
>>> print(self)
>>> outputs = self(x)
>>> assert tuple(outputs.shape) == (2, 3, 13, 4, 4, 1024)
>>> #
>>> # ================================
>>> # Space Only Attention
>>> self = ChannelwiseTransformerEncoderLayer(
>>>     axes=[('height', 'width')],
>>>     default_shape=['batch', 'time', 'mode', 'height', 'width', 'feature'],
>>>     feature_axis='feature',
>>>     batch_axis='batch',
>>>     embedding_size=embedding_size,
>>>     n_heads=4
>>> )
>>> print(self)
>>> outputs = self(x)
>>> assert tuple(outputs.shape) == (2, 3, 13, 4, 4, 1024)
forward(inputs, flat_coordinates=None, key_padding_mask=None)[source]
Parameters:
  • x (Tensor) – of shape (B, T, M, H, W, F) if flat_coordinates is unspecified; otherwise it should be of shape (N, F), where N is the total number of tokens.

  • flat_coordinates (Dict[str, Tensor]) – the time, mode, height, and width coordinate of each token. If specified, batches are unsupported.

  • key_padding_mask (Tensor) – of shape (B, T, M, H, W) if flat_coordinates is unspecified; otherwise it should be of shape (N,). A value of True means the token is ignored.

CommandLine

xdoctest -m geowatch.tasks.fusion.architectures.transformer ChannelwiseTransformerEncoderLayer.forward

Example

>>> # Test that coordinate aware implementation exactly reproduces aligned variant
>>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
>>> import numpy as np
>>> import torch
>>> import ubelt as ub
>>> F = embedding_size = 4
>>> B, T, M, H, W = 1, 3, 5, 7, 11
>>> aligned_inputs = aligned_x = (torch.rand(B, T, M, H, W, F) * 100).round() / 10
>>> key_padding_mask = torch.rand(B, T, M, H, W) > 0.9
>>> flat_inputs = flat_x = aligned_inputs.view(-1, embedding_size)
>>> flat_kpm = key_padding_mask.view(-1)
>>> #inputs = flat_inputs
>>> flat_coordinates = None
>>> inputs = aligned_inputs
>>> # Test that coordinate-aware flat attention works
>>> t_coords, m_coords, h_coords, w_coords = np.meshgrid(np.arange(T), np.arange(M), np.arange(H), np.arange(W), indexing='ij')
>>> flat_coordinates = {
>>>     'time':   t_coords.ravel(),
>>>     'mode':   m_coords.ravel(),
>>>     'height': h_coords.ravel(),
>>>     'width':  w_coords.ravel(),
>>> }
>>> flat_coordinates = ub.map_vals(torch.from_numpy, flat_coordinates)
>>> self = ChannelwiseTransformerEncoderLayer(
>>>     #axes=[('height', 'width'), ('time',)],
>>>     axes=[('time',)],
>>>     default_shape=['batch', 'time', 'mode', 'height', 'width', 'feature'],
>>>     feature_axis='feature',
>>>     batch_axis='batch',
>>>     embedding_size=embedding_size,
>>>     n_heads=1
>>> )
>>> self = self.eval()
>>> with torch.set_grad_enabled(False):
>>>     print('----')
>>>     flat_y = self.forward(flat_inputs, flat_coordinates)
>>>     print('----')
>>>     aligned_y = self.forward(aligned_inputs)
>>>     print('----')
>>>     aligned_y_mask = self.forward(aligned_inputs, key_padding_mask=key_padding_mask)
>>>     print('----')
>>>     flat_y_mask = self.forward(flat_inputs, flat_coordinates, key_padding_mask=flat_kpm)
>>> print('----====-')
>>> recon_y1 = aligned_y.view(-1, embedding_size)
>>> recon_y1_mask = aligned_y_mask.view(-1, embedding_size)
>>> print('flat_y=\n{!r}'.format(flat_y))
>>> print('recon_y1=\n{!r}'.format(recon_y1))
>>> abs_diff = (flat_y - recon_y1).abs().max()
>>> print('abs_diff = {!r}'.format(abs_diff))
>>> assert abs_diff < 1e-5
>>> #
>>> flat_y_mask.nan_to_num_()
>>> recon_y1_mask.nan_to_num_()
>>> abs_diff_mask = (flat_y_mask - recon_y1_mask).abs().max()
>>> print('abs_diff_mask = {!r}'.format(abs_diff_mask))
>>> assert abs_diff_mask < 1e-5
>>> #flags = torch.isclose(flat_y, recon_y1)
>>> #assert flags.all()
class geowatch.tasks.fusion.architectures.transformer.TimmEncoder(arch_name='vit_base_patch16_224', pretrained=True, dropout=0.0, attention_impl='exact', in_features=None)[source]

Bases: object

Example

>>> # xdoctest: +REQUIRES(module:timm)
>>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
>>> import torch
>>> in_features = 7
>>> input_shape = B, T, M, H, W, F = (2, 3, 5, 2, 2, in_features)
>>> inputs = torch.rand(*input_shape)
>>> arch_name = 'vit_base_patch16_224'
>>> self = TimmEncoder(arch_name)
class geowatch.tasks.fusion.architectures.transformer.MM_VITEncoder[source]

Bases: Module

mmsegmentation variant of VIT

Needs 768 features.

Notes

https://github.com/open-mmlab/mmsegmentation/tree/master/configs/vit

Results:

# 1
https://github.com/open-mmlab/mmsegmentation/tree/master/configs/vit#ade20k
https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py
https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/models/upernet_vit-b16_ln_mln.py

initialize_from_pretrained()[source]
forward(x)[source]
class geowatch.tasks.fusion.architectures.transformer.DeiTEncoder(in_features, pretrained=True)[source]

Bases: Module

https://github.com/rishikksh20/ViViT-pytorch
https://pytorch.org/tutorials/beginner/vt_tutorial.html

Example

>>> # xdoctest: +REQUIRES(module:timm)
>>> # xdoctest: +SKIP('can cause network issues on CI')
>>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
>>> import torch
>>> in_features = 7
>>> input_shape = B, T, M, H, W, F = (2, 3, 5, 2, 2, in_features)
>>> inputs = torch.rand(*input_shape)
>>> self = DeiTEncoder(in_features, pretrained=False)
>>> outputs = self.forward(inputs)
forward(inputs)[source]
class geowatch.tasks.fusion.architectures.transformer.PerceiverEncoder(in_features, depth=4, dropout=0.0)[source]

Bases: Module

https://github.com/lucidrains/perceiver-pytorch/blob/main/perceiver_pytorch/perceiver_io.py

Example

>>> # xdoctest: +REQUIRES(module:perceiver_pytorch)
>>> from geowatch.tasks.fusion.architectures.transformer import PerceiverEncoder  # NOQA
>>> import torch
>>> B, T, M, H, W, F = 1, 2, 3, 5, 8, 13
>>> self = PerceiverEncoder(F, dropout=0.1)
>>> inputs = torch.rand(B, T, M, H, W, F)
>>> outputs = self(inputs)
>>> assert outputs.shape == (B, T, M, H, W, F)
forward(inputs)[source]
class geowatch.tasks.fusion.architectures.transformer.FusionEncoder(axes, default_shape=('batch', 'sequence', 'feature'), feature_axis='feature', batch_axis='batch', embedding_size=128, n_layers=4, n_heads=8, dropout=0.0, attention_impl='exact', attention_kwargs={}, in_features=None)[source]

Bases: Module

Primary entry point to create a feature transformer

Performs multiple “channelwise” (maybe rename to axial?) attention encodings in a row.

Example

>>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
>>> import torch
>>> in_features = 7
>>> input_shape = B, T, M, H, W, F = (2, 3, 5, 2, 2, in_features)
>>> inputs = torch.rand(*input_shape)
>>> model = FusionEncoder(
>>>     in_features=in_features,
>>>     axes=[("time", "mode", "height", "width")],
>>>     default_shape=["batch", "time", "mode", "height", "width", "feature"],
>>>     feature_axis="feature",
>>>     batch_axis="batch",
>>>     n_layers=8,
>>>     embedding_size=256,
>>>     n_heads=4
>>> )
>>> model(inputs)
>>> output = model(inputs)
>>> assert output.shape == (2, 3, 5, 2, 2, 256)
>>> #
>>> # Test Lazy variant
>>> model = FusionEncoder(
>>>     in_features=None,
>>>     axes=[("time", "mode", "height", "width")],
>>>     default_shape=["batch", "time", "mode", "height", "width", "feature"],
>>>     feature_axis="feature",
>>>     batch_axis="batch",
>>>     n_layers=8,
>>>     embedding_size=256,
>>>     n_heads=4
>>> )
>>> print(model)
>>> inputs = torch.rand(*input_shape)
>>> output = model(inputs)
>>> assert output.shape == (2, 3, 5, 2, 2, 256)
forward(x, flat_coordinates=None, key_padding_mask=None, mask=None)[source]
geowatch.tasks.fusion.architectures.transformer.default(val, d)[source]
geowatch.tasks.fusion.architectures.transformer.cache_fn(f)[source]
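These two helpers follow the small utility pattern common in perceiver-pytorch style code; a minimal sketch of the conventional behavior (assumed here, not copied from this module):

    import functools

    def default_sketch(val, d):
        # Return val unless it is None, in which case fall back to d.
        return d if val is None else val

    def cache_fn_sketch(f):
        # Memoize a function so repeated calls with the same arguments reuse the first result.
        cache = {}
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            key = (args, tuple(sorted(kwargs.items())))
            if key not in cache:
                cache[key] = f(*args, **kwargs)
            return cache[key]
        return wrapper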
class geowatch.tasks.fusion.architectures.transformer.PreNorm(dim, fn, context_dim=None)[source]

Bases: Module

forward(x, **kwargs)[source]
class geowatch.tasks.fusion.architectures.transformer.GEGLU(*args, **kwargs)[source]

Bases: Module

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)[source]
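GEGLU is conventionally a gated GELU: the last dimension is split in half and one half gates the other. A sketch of that conventional form (assumed; this module’s exact version may differ):

    import torch
    from torch import nn
    import torch.nn.functional as F

    class GEGLUSketch(nn.Module):
        def forward(self, x):
            # Split the feature dim in two; gate one half with GELU of the other.
            x, gates = x.chunk(2, dim=-1)
            return x * F.gelu(gates)

    y = GEGLUSketch()(torch.rand(2, 8))  # feature dim is halved: (2, 4)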
class geowatch.tasks.fusion.architectures.transformer.FeedForward(dim, mult=4)[source]

Bases: Module

forward(x)[source]
class geowatch.tasks.fusion.architectures.transformer.Attention(query_dim, context_dim=None, output_dim=None, heads=8, dim_head=64)[source]

Bases: Module

forward(x, context=None, mask=None)[source]
class geowatch.tasks.fusion.architectures.transformer.BackboneEncoderDecoder[source]

Bases: object

class geowatch.tasks.fusion.architectures.transformer.TransformerEncoderDecoder(encoder_depth: int = 2, decoder_depth: int = 1, dim: int = 128, queries_dim: int = 96, logits_dim: int = 32, decode_cross_every: int = 1, cross_heads: int = 1, latent_heads: int = 8, cross_dim_head: int = 64, latent_dim_head: int = 64, weight_tie_layers: bool = False)[source]

Bases: Module, BackboneEncoderDecoder

forward(x, mask=None, queries=None)[source]
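A hedged usage sketch based only on the constructor signature above; the input layouts and the output-shape comment are assumptions in the perceiver-io style, not documented behavior:

    import torch
    from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder

    model = TransformerEncoderDecoder(
        encoder_depth=2, decoder_depth=1, dim=128, queries_dim=96, logits_dim=32)
    x = torch.rand(2, 10, 128)      # assumed (batch, tokens, dim)
    queries = torch.rand(2, 4, 96)  # assumed (batch, queries, queries_dim)
    out = model(x, queries=queries)
    print(out.shape)                # likely (batch, queries, logits_dim)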
class geowatch.tasks.fusion.architectures.transformer.TransformerEncoderLayerExtended(d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: str = 'relu', layer_norm_eps: float = 1e-05, batch_first: bool = False, norm_first: bool = False, mha_kwargs=None, device=None, dtype=None)[source]

Bases: TransformerEncoderLayer

class geowatch.tasks.fusion.architectures.transformer.VanillaTransformerEncoder(num_layers: int, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.0, activation: str = 'gelu', layer_norm_eps: float = 1e-05, batch_first: bool = True, norm_first: bool = True, mha_kwargs=None)[source]

Bases: Module, BackboneEncoderDecoder

forward(x, mask=None, queries=None)[source]
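Similarly, a hedged usage sketch for the vanilla encoder wrapper; with batch_first=True by default, a (batch, tokens, d_model) input is assumed:

    import torch
    from geowatch.tasks.fusion.architectures.transformer import VanillaTransformerEncoder

    model = VanillaTransformerEncoder(num_layers=2, d_model=128, nhead=8)
    x = torch.rand(2, 10, 128)  # assumed (batch, tokens, d_model)
    out = model(x)              # expected to preserve the input shape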
class geowatch.tasks.fusion.architectures.transformer.MM_VITEncoderDecoder(dim, logits_dim, queries_dim=None, pretrained=None)[source]

Bases: Module, BackboneEncoderDecoder

mmsegmentation variant of VIT

Needs 768 features.

Notes

https://github.com/open-mmlab/mmsegmentation/tree/master/configs/vit

Results:

# 1
https://github.com/open-mmlab/mmsegmentation/tree/master/configs/vit#ade20k
https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py
https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/models/upernet_vit-b16_ln_mln.py

Example

>>> # xdoctest: +REQUIRES(module:mmseg)
>>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
>>> self = MM_VITEncoderDecoder(16, 16, 16)
>>> x = torch.rand(2, 3, 16)
>>> self.forward(x)
pretrained_fpath_shortnames = {'upernet_vit-b16_mln_512x512_80k_ade20k': 'https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_mln_512x512_80k_ade20k/upernet_vit-b16_mln_512x512_80k_ade20k_20210624_130547-0403cee1.pth'}
initialize_from_pretrained(fpath)[source]
forward(x, mask=None, queries=None)[source]