geowatch.tasks.fusion.architectures.transformer module¶
Current encoder config names are:
‘smt_it_joint_p8’, ‘smt_it_joint_n12’, ‘smt_it_joint_t12’, ‘smt_it_joint_t24’, ‘smt_it_joint_s12’, ‘smt_it_joint_s24’, ‘smt_it_joint_m24’, ‘smt_it_joint_l24’,
‘smt_it_stm_p8’, ‘smt_it_stm_n12’, ‘smt_it_stm_t12’, ‘smt_it_stm_t24’, ‘smt_it_stm_s12’, ‘smt_it_stm_s24’, ‘smt_it_stm_m24’, ‘smt_it_stm_l24’,
‘smt_it_sm_p8’, ‘smt_it_sm_n12’, ‘smt_it_sm_t12’, ‘smt_it_sm_t24’, ‘smt_it_sm_s12’, ‘smt_it_sm_s24’, ‘smt_it_sm_m24’, ‘smt_it_sm_l24’,
‘smt_it_st_p8’, ‘smt_it_st_n12’, ‘smt_it_st_t12’, ‘smt_it_st_t24’, ‘smt_it_st_s12’, ‘smt_it_st_s24’, ‘smt_it_st_m24’, ‘smt_it_st_l24’,
‘smt_it_tm_p8’, ‘smt_it_tm_n12’, ‘smt_it_tm_t12’, ‘smt_it_tm_t24’, ‘smt_it_tm_s12’, ‘smt_it_tm_s24’, ‘smt_it_tm_m24’, ‘smt_it_tm_l24’,
‘smt_it_s_p8’, ‘smt_it_s_n12’, ‘smt_it_s_t12’, ‘smt_it_s_t24’, ‘smt_it_s_s12’, ‘smt_it_s_s24’, ‘smt_it_s_m24’, ‘smt_it_s_l24’,
‘smt_it_t_p8’, ‘smt_it_t_n12’, ‘smt_it_t_t12’, ‘smt_it_t_t24’, ‘smt_it_t_s12’, ‘smt_it_t_s24’, ‘smt_it_t_m24’, ‘smt_it_t_l24’,
‘smt_it_hwtm_p8’, ‘smt_it_hwtm_n12’, ‘smt_it_hwtm_t12’, ‘smt_it_hwtm_t24’, ‘smt_it_hwtm_s12’, ‘smt_it_hwtm_s24’, ‘smt_it_hwtm_m24’, ‘smt_it_hwtm_l24’,
‘smt_it_m_p8’, ‘smt_it_m_n12’, ‘smt_it_m_t12’, ‘smt_it_m_t24’, ‘smt_it_m_s12’, ‘smt_it_m_s24’, ‘smt_it_m_m24’, ‘smt_it_m_l24’,
‘sm_it_joint_p8’, ‘sm_it_joint_n12’, ‘sm_it_joint_t12’, ‘sm_it_joint_t24’, ‘sm_it_joint_s12’, ‘sm_it_joint_s24’, ‘sm_it_joint_m24’, ‘sm_it_joint_l24’,
‘sm_it_sm_p8’, ‘sm_it_sm_n12’, ‘sm_it_sm_t12’, ‘sm_it_sm_t24’, ‘sm_it_sm_s12’, ‘sm_it_sm_s24’, ‘sm_it_sm_m24’, ‘sm_it_sm_l24’
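These names are consumed by whatever builds the fusion model's encoder. As a hypothetical sketch only (the mapping name encoder_configs and its contents are assumptions, not confirmed by this page), a lookup might look like:
>>> # Hypothetical: assumes a module-level mapping from config name to
>>> # FusionEncoder keyword arguments; the real attribute name may differ.
>>> from geowatch.tasks.fusion.architectures import transformer
>>> cfg = transformer.encoder_configs['smt_it_stm_p8']         # assumed lookup
>>> encoder = transformer.FusionEncoder(in_features=7, **cfg)  # assumed kwargs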
Notes
pip install reformer_pytorch
pip install performer-pytorch  <- this one
- class geowatch.tasks.fusion.architectures.transformer.ResidualSequential(*args)[source]¶
Bases: Sequential
A Sequential layer with a residual operation at the end
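A minimal usage sketch (assuming the wrapped layers preserve the input shape, so the residual addition x + seq(x) is well-defined):
>>> import torch
>>> from torch import nn
>>> from geowatch.tasks.fusion.architectures.transformer import ResidualSequential
>>> block = ResidualSequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
>>> x = torch.rand(2, 8)
>>> y = block(x)  # behaves like x + nn.Sequential(...)(x)
>>> assert y.shape == x.shape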
- class geowatch.tasks.fusion.architectures.transformer.ResidualAttentionSequential(norm, attention)[source]¶
Bases: ResidualSequential
Special case of ResidualSequential to support masking
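A construction sketch using the norm/attention pair from this module; the way the mask is forwarded is an assumption based on the class description, not a documented signature:
>>> import torch
>>> from torch import nn
>>> from geowatch.tasks.fusion.architectures.transformer import (
>>>     MultiheadSelfAttention, ResidualAttentionSequential)
>>> embed_dim = 4
>>> block = ResidualAttentionSequential(
>>>     nn.LayerNorm(embed_dim), MultiheadSelfAttention(embed_dim, 1))
>>> x = torch.rand(7, 3, embed_dim)            # (seq, batch, feature)
>>> key_padding_mask = torch.rand(3, 7) > 0.5  # True means "ignore this token"
>>> y = block(x, key_padding_mask=key_padding_mask)  # assumed mask pass-through
>>> assert y.shape == x.shape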
- geowatch.tasks.fusion.architectures.transformer.assert_allclose(a, b, rtol=1e-05, atol=1e-08)[source]¶
TODO: integrate with kwcoco.coco_sql_dataset.assert_dsets_allclose(). Add to kwarray.
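A small usage sketch; the behavior assumed here is the usual one for this kind of helper (raise if the tensors differ beyond the given tolerances):
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import assert_allclose
>>> a = torch.tensor([1.0, 2.0, 3.0])
>>> b = a + 1e-9  # well within the default rtol/atol
>>> assert_allclose(a, b)                        # passes silently
>>> assert_allclose(a, b, rtol=1e-3, atol=1e-6)  # looser tolerances also pass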
- class geowatch.tasks.fusion.architectures.transformer.MultiheadSelfAttention(embed_dim, num_heads, *args, **kwargs)[source]¶
Bases: MultiheadAttention
Inherits from torch.nn.MultiheadAttention.
- Parameters:
embed_dim – total dimension of the model.
num_heads – parallel attention heads.
dropout – a Dropout layer on attn_output_weights. Default: 0.0.
bias – add bias as module parameter. Default: True.
add_bias_kv – add bias to the key and value sequences at dim=0.
add_zero_attn – add a new batch of zeros to the key and value sequences at dim=1.
kdim – total number of features in key. Default: None.
vdim – total number of features in value. Default: None.
batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).
CommandLine
xdoctest -m geowatch.tasks.fusion.architectures.transformer MultiheadSelfAttention
Example
>>> from geowatch.tasks.fusion.architectures.transformer import * # NOQA
>>> self = MultiheadSelfAttention(4, 1).eval()
>>> S, B, F = (7, 3, 4)
>>> x = (torch.rand(S, B, F) * 10).round()
>>> # Results should be independent of the batch dim
>>> y = self.forward(x)
>>> y0 = self.forward(x[:, 0:1, :])
>>> y1 = self.forward(x[:, 1:2, :])
>>> y2 = self.forward(x[:, 2:3, :])
>>> assert_allclose(y[:, 0:1, :], y0, rtol=1e-3, atol=1e-6)
>>> assert_allclose(y[:, 1:2, :], y1, rtol=1e-3, atol=1e-6)
>>> assert_allclose(y[:, 2:3, :], y2, rtol=1e-3, atol=1e-6)
>>> key_padding_mask = torch.rand(B, S) > 0.5
>>> masked_result = self.forward(x, key_padding_mask=key_padding_mask)
- class geowatch.tasks.fusion.architectures.transformer.MetaModuleProperties(name, bases, namespace, *args, **kwargs)[source]¶
Bases: type
Experimental way to concisely get property-like behavior at the module level.
This defines code that runs whenever the user defines a class that inherits from MetaModuleProperties.
- class geowatch.tasks.fusion.architectures.transformer.ModuleProperties[source]¶
Bases: object
Experimental way to concisely get property-like behavior at the module level.
Inherit from this class. This class forwards metaclass magic that allows us to register any function decorated with @property. It also creates the __getattr__ method that will be assigned to the module.
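For orientation, the general pattern this enables is module-level attributes that are computed on access. The sketch below illustrates only the underlying Python mechanism (a module-level __getattr__, PEP 562), not the actual internals of ModuleProperties:

# illustration.py -- conceptual sketch, not this module's implementation
_cache = {}

def _build_expensive_table():
    # stand-in for work that should only happen on first access
    return {'built': True}

def __getattr__(name):
    # Python calls this for attribute lookups on the module that fail normally,
    # e.g. ``import illustration; illustration.expensive_table``
    if name == 'expensive_table':
        if name not in _cache:
            _cache[name] = _build_expensive_table()
        return _cache[name]
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')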
- geowatch.tasks.fusion.architectures.transformer.new_attention_layer(embedding_size, n_heads, attention_impl='exact', **kwargs)[source]¶
Example
>>> from geowatch.tasks.fusion.architectures.transformer import * # NOQA
>>> import torch
>>> batch_size = 1
>>> embedding_size = 4
>>> n_heads = 2
>>> num_tokens = 3
>>> input_shape = (num_tokens, batch_size, embedding_size)
>>> inputs = torch.rand(*input_shape)
>>> layer1 = new_attention_layer(embedding_size, n_heads, 'exact')
>>> outputs1 = layer1(inputs)
>>> assert outputs1.shape == input_shape
>>> # xdoctest: +REQUIRES(module:performer_pytorch)
>>> layer2 = new_attention_layer(embedding_size, n_heads, 'performer')
>>> outputs2 = layer2(inputs)
>>> assert outputs2.shape == input_shape
Example
>>> # Test with a mask
>>> from geowatch.tasks.fusion.architectures.transformer import * # NOQA
>>> import torch
>>> batch_size = 1
>>> embedding_size = 4
>>> n_heads = 2
>>> num_tokens = 3
>>> input_shape = (num_tokens, batch_size, embedding_size)
>>> inputs = torch.rand(*input_shape)
>>> key_padding_mask = torch.rand(batch_size, num_tokens) > 0.5
>>> layer1 = new_attention_layer(embedding_size, n_heads, 'exact')
>>> outputs1 = layer1(inputs, key_padding_mask=key_padding_mask)
- geowatch.tasks.fusion.architectures.transformer.new_mlp_layer(embedding_size, dropout, **kwargs)[source]¶
Example
>>> import torch
>>> embedding_size = 3
>>> batch_size = 1
>>> layer = new_mlp_layer(embedding_size, dropout=0)
>>> input_shape = (batch_size, embedding_size)
>>> inputs = torch.rand(*input_shape)
>>> outputs = layer(inputs)
>>> assert outputs.shape == (batch_size, embedding_size)
- class geowatch.tasks.fusion.architectures.transformer.ChannelwiseTransformerEncoderLayer(axes, embedding_size, n_heads, dropout=0.0, default_shape=('batch', 'time', 'mode', 'height', 'width', 'feature'), feature_axis='feature', batch_axis='batch', attention_impl='exact', attention_kwargs=None)[source]¶
Bases: Module
Todo
- [ ] Can we restrict how far the spatial window looks, so it only sees neighboring spatial regions?
- [ ] Flatten tokens completely and have a mask that indicates which tokens are allowed to see each other in each step.
Notes
Currently ‘mode’ might indicate something like a sensor or special computation. Each ‘mode’ might have a different number of ‘features’. In the future this might be better specified as a dictionary that maps ‘mode’-codes to a tensor containing only the ‘features’ for that mode. E.g.:
inputs = {
    'S2': Tensor([B, T, H, W, 13]),
    'WV': Tensor([B, T, H, W, 8]),
    'Latent': Tensor([B, T, H, W, 512]),
    'Materials': Tensor([B, T, H, W, 16]),
}
Currently these are all stacked into a B x T x M x H x W x max(F) tensor and padded with zeros.
Correction: the last statement is not correct. Currently F is hard-coded to F = 1 * ws * ws (where ws is the window size), so the features are really spatial positions in a window, and the ‘width’ and ‘height’ here refer to the number of windows in the area.
Example
>>> from geowatch.tasks.fusion.architectures.transformer import * # NOQA
>>> import torch
>>> image_size = 128
>>> #
>>> ws = window_size = 32
>>> W = H = image_size // ws
>>> B = batch_size = 2
>>> T = num_times = 3
>>> M = num_modes = 13  # hack for number of features in S2
>>> F = 1 * ws * ws  # hack to use spatial positions in a window as features
>>> input_shape = (B, T, M, H, W, F)
>>> x = torch.rand(*input_shape)
>>> embedding_size = F  # Embedding size must be equal to F
>>> #
>>> # ================================
>>> # Joint Attentions Across all Axes
>>> self = ChannelwiseTransformerEncoderLayer(
>>>     axes=[('time', 'mode', 'height', 'width')],
>>>     default_shape=['batch', 'time', 'mode', 'height', 'width', 'feature'],
>>>     feature_axis='feature',
>>>     batch_axis='batch',
>>>     embedding_size=embedding_size,
>>>     n_heads=4
>>> )
>>> print(self)
>>> outputs = self(x)
>>> assert tuple(outputs.shape) == (2, 3, 13, 4, 4, 1024)
>>> #
>>> # ================================
>>> # Separable Attentions Across Time, Mode, and then Space
>>> self = ChannelwiseTransformerEncoderLayer(
>>>     axes=[('time', 'mode'), ('height', 'width')],
>>>     default_shape=['batch', 'time', 'mode', 'height', 'width', 'feature'],
>>>     feature_axis='feature',
>>>     batch_axis='batch',
>>>     embedding_size=embedding_size,
>>>     n_heads=4
>>> )
>>> print(self)
>>> outputs = self(x)
>>> assert tuple(outputs.shape) == (2, 3, 13, 4, 4, 1024)
>>> #
>>> # ================================
>>> # Space Only Attention
>>> self = ChannelwiseTransformerEncoderLayer(
>>>     axes=[('height', 'width')],
>>>     default_shape=['batch', 'time', 'mode', 'height', 'width', 'feature'],
>>>     feature_axis='feature',
>>>     batch_axis='batch',
>>>     embedding_size=embedding_size,
>>>     n_heads=4
>>> )
>>> print(self)
>>> outputs = self(x)
>>> assert tuple(outputs.shape) == (2, 3, 13, 4, 4, 1024)
- forward(inputs, flat_coordinates=None, key_padding_mask=None)[source]¶
- Parameters:
inputs (Tensor) – of shape (B, T, M, H, W, F) if flat_coordinates is unspecified; otherwise of shape (N, F), where N is the total number of tokens.
flat_coordinates (Dict[str, Tensor]) – the time, mode, height, and width coordinate of each token. If specified, batches are unsupported.
key_padding_mask (Tensor) – of shape (B, T, M, H, W) if flat_coordinates is unspecified; otherwise of shape (N,). A True value means ignore the token.
CommandLine
xdoctest -m geowatch.tasks.fusion.architectures.transformer ChannelwiseTransformerEncoderLayer.forward
Example
>>> # Test that coordinate aware implementation exactly reproduces aligned variant
>>> from geowatch.tasks.fusion.architectures.transformer import * # NOQA
>>> import numpy as np
>>> F = embedding_size = 4
>>> B, T, M, H, W = 1, 3, 5, 7, 11
>>> aligned_inputs = aligned_x = (torch.rand(B, T, M, H, W, F) * 100).round() / 10
>>> key_padding_mask = torch.rand(B, T, M, H, W) > 0.9
>>> flat_inputs = flat_x = aligned_inputs.view(-1, embedding_size)
>>> flat_kpm = key_padding_mask.view(-1)
>>> #inputs = flat_inputs
>>> flat_coordinates = None
>>> inputs = aligned_inputs
>>> # Test that coordinate-aware flat attention works
>>> t_coords, m_coords, h_coords, w_coords = np.meshgrid(np.arange(T), np.arange(M), np.arange(H), np.arange(W), indexing='ij')
>>> flat_coordinates = {
>>>     'time': t_coords.ravel(),
>>>     'mode': m_coords.ravel(),
>>>     'height': h_coords.ravel(),
>>>     'width': w_coords.ravel(),
>>> }
>>> flat_coordinates = ub.map_vals(torch.from_numpy, flat_coordinates)
>>> self = ChannelwiseTransformerEncoderLayer(
>>>     #axes=[('height', 'width'), ('time',)],
>>>     axes=[('time',)],
>>>     default_shape=['batch', 'time', 'mode', 'height', 'width', 'feature'],
>>>     feature_axis='feature',
>>>     batch_axis='batch',
>>>     embedding_size=embedding_size,
>>>     n_heads=1
>>> )
>>> self = self.eval()
>>> with torch.set_grad_enabled(False):
>>>     print('----')
>>>     flat_y = self.forward(flat_inputs, flat_coordinates)
>>>     print('----')
>>>     aligned_y = self.forward(aligned_inputs)
>>>     print('----')
>>>     aligned_y_mask = self.forward(aligned_inputs, key_padding_mask=key_padding_mask)
>>>     print('----')
>>>     flat_y_mask = self.forward(flat_inputs, flat_coordinates, key_padding_mask=flat_kpm)
>>>     print('----====-')
>>> recon_y1 = aligned_y.view(-1, embedding_size)
>>> recon_y1_mask = aligned_y_mask.view(-1, embedding_size)
>>> print('flat_y=\n{!r}'.format(flat_y))
>>> print('recon_y1=\n{!r}'.format(recon_y1))
>>> abs_diff = (flat_y - recon_y1).abs().max()
>>> print('abs_diff = {!r}'.format(abs_diff))
>>> assert abs_diff < 1e-5
>>> #
>>> flat_y_mask.nan_to_num_()
>>> recon_y1_mask.nan_to_num_()
>>> abs_diff_mask = (flat_y_mask - recon_y1_mask).abs().max()
>>> print('abs_diff_mask = {!r}'.format(abs_diff_mask))
>>> assert abs_diff_mask < 1e-5
>>> #flags = torch.isclose(flat_y, recon_y1)
>>> #assert flags.all()
- class geowatch.tasks.fusion.architectures.transformer.TimmEncoder(arch_name='vit_base_patch16_224', pretrained=True, dropout=0.0, attention_impl='exact', in_features=None)[source]¶
Bases: object
Example
>>> # xdoctest: +REQUIRES(module:timm)
>>> from geowatch.tasks.fusion.architectures.transformer import * # NOQA
>>> import torch
>>> in_features = 7
>>> input_shape = B, T, M, H, W, F = (2, 3, 5, 2, 2, in_features)
>>> inputs = torch.rand(*input_shape)
>>> arch_name = 'vit_base_patch16_224'
>>> self = TimmEncoder(arch_name)
- class geowatch.tasks.fusion.architectures.transformer.MM_VITEncoder[source]¶
Bases: Module
mmsegmentation variant of VIT
Needs 768 features.
Notes
https://github.com/open-mmlab/mmsegmentation/tree/master/configs/vit
- Results:
# 1
https://github.com/open-mmlab/mmsegmentation/tree/master/configs/vit#ade20k
https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py
https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/models/upernet_vit-b16_ln_mln.py
- class geowatch.tasks.fusion.architectures.transformer.DeiTEncoder(in_features, pretrained=True)[source]¶
Bases: Module
https://github.com/rishikksh20/ViViT-pytorch
https://pytorch.org/tutorials/beginner/vt_tutorial.html
Example
>>> # xdoctest: +REQUIRES(module:timm)
>>> # xdoctest: +SKIP('can cause network issues on CI')
>>> from geowatch.tasks.fusion.architectures.transformer import * # NOQA
>>> import torch
>>> in_features = 7
>>> input_shape = B, T, M, H, W, F = (2, 3, 5, 2, 2, in_features)
>>> inputs = torch.rand(*input_shape)
>>> self = DeiTEncoder(in_features, pretrained=False)
>>> outputs = self.forward(inputs)
- class geowatch.tasks.fusion.architectures.transformer.PerceiverEncoder(in_features, depth=4, dropout=0.0)[source]¶
Bases: Module
https://github.com/lucidrains/perceiver-pytorch/blob/main/perceiver_pytorch/perceiver_io.py
Example
>>> # xdoctest: +REQUIRES(module:perceiver_pytorch)
>>> from geowatch.tasks.fusion.architectures.transformer import PerceiverEncoder # NOQA
>>> import torch
>>> B, T, M, H, W, F = 1, 2, 3, 5, 8, 13
>>> self = PerceiverEncoder(F, dropout=0.1)
>>> inputs = torch.rand(B, T, M, H, W, F)
>>> outputs = self(inputs)
>>> assert outputs.shape == (B, T, M, H, W, F)
- class geowatch.tasks.fusion.architectures.transformer.FusionEncoder(axes, default_shape=('batch', 'sequence', 'feature'), feature_axis='feature', batch_axis='batch', embedding_size=128, n_layers=4, n_heads=8, dropout=0.0, attention_impl='exact', attention_kwargs={}, in_features=None)[source]¶
Bases: Module
Primary entry point to create a feature transformer.
Performs multiple “channelwise” (maybe rename to axial?) attention encodings in a row.
Example
>>> from geowatch.tasks.fusion.architectures.transformer import * # NOQA
>>> import torch
>>> in_features = 7
>>> input_shape = B, T, M, H, W, F = (2, 3, 5, 2, 2, in_features)
>>> inputs = torch.rand(*input_shape)
>>> model = FusionEncoder(
>>>     in_features=in_features,
>>>     axes=[("time", "mode", "height", "width")],
>>>     default_shape=["batch", "time", "mode", "height", "width", "feature"],
>>>     feature_axis="feature",
>>>     batch_axis="batch",
>>>     n_layers=8,
>>>     embedding_size=256,
>>>     n_heads=4
>>> )
>>> model(inputs)
>>> output = model(inputs)
>>> assert output.shape == (2, 3, 5, 2, 2, 256)
>>> #
>>> # Test Lazy variant
>>> model = FusionEncoder(
>>>     in_features=None,
>>>     axes=[("time", "mode", "height", "width")],
>>>     default_shape=["batch", "time", "mode", "height", "width", "feature"],
>>>     feature_axis="feature",
>>>     batch_axis="batch",
>>>     n_layers=8,
>>>     embedding_size=256,
>>>     n_heads=4
>>> )
>>> print(model)
>>> inputs = torch.rand(*input_shape)
>>> output = model(inputs)
>>> assert output.shape == (2, 3, 5, 2, 2, 256)
- class geowatch.tasks.fusion.architectures.transformer.PreNorm(dim, fn, context_dim=None)[source]¶
Bases: Module
- class geowatch.tasks.fusion.architectures.transformer.GEGLU(*args, **kwargs)[source]¶
Bases: Module
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- class geowatch.tasks.fusion.architectures.transformer.FeedForward(dim, mult=4)[source]¶
Bases: Module
- class geowatch.tasks.fusion.architectures.transformer.Attention(query_dim, context_dim=None, output_dim=None, heads=8, dim_head=64)[source]¶
Bases: Module
- class geowatch.tasks.fusion.architectures.transformer.TransformerEncoderDecoder(encoder_depth: int = 2, decoder_depth: int = 1, dim: int = 128, queries_dim: int = 96, logits_dim: int = 32, decode_cross_every: int = 1, cross_heads: int = 1, latent_heads: int = 8, cross_dim_head: int = 64, latent_dim_head: int = 64, weight_tie_layers: bool = False)[source]¶
Bases: Module, BackboneEncoderDecoder
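A minimal construction sketch from the signature above (the forward interface is not documented on this page, so it is omitted rather than guessed):
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> model = TransformerEncoderDecoder(
>>>     encoder_depth=2, decoder_depth=1, dim=128,
>>>     queries_dim=96, logits_dim=32)
>>> print(model)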
- class geowatch.tasks.fusion.architectures.transformer.TransformerEncoderLayerExtended(d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: str = 'relu', layer_norm_eps: float = 1e-05, batch_first: bool = False, norm_first: bool = False, mha_kwargs=None, device=None, dtype=None)[source]¶
Bases: TransformerEncoderLayer
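Since this extends torch.nn.TransformerEncoderLayer (adding mha_kwargs), a basic call should follow the parent API. A sketch, assuming the subclass keeps the parent forward signature:
>>> import torch
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderLayerExtended
>>> layer = TransformerEncoderLayerExtended(d_model=16, nhead=4, dim_feedforward=32)
>>> src = torch.rand(5, 2, 16)  # (seq, batch, feature) because batch_first=False
>>> out = layer(src)            # parent-class forward; shape is preserved
>>> assert out.shape == src.shape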
- class geowatch.tasks.fusion.architectures.transformer.VanillaTransformerEncoder(num_layers: int, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.0, activation: str = 'gelu', layer_norm_eps: float = 1e-05, batch_first: bool = True, norm_first: bool = True, mha_kwargs=None)[source]¶
Bases: Module, BackboneEncoderDecoder
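A construction sketch from the signature above; note batch_first=True and norm_first=True by default. The forward interface follows BackboneEncoderDecoder and is not documented here, so only construction is shown:
>>> from geowatch.tasks.fusion.architectures.transformer import VanillaTransformerEncoder
>>> model = VanillaTransformerEncoder(num_layers=2, d_model=16, nhead=4, dim_feedforward=32)
>>> print(model)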
- class geowatch.tasks.fusion.architectures.transformer.MM_VITEncoderDecoder(dim, logits_dim, queries_dim=None, pretrained=None)[source]¶
Bases: Module, BackboneEncoderDecoder
mmsegmentation variant of VIT
Needs 768 features.
Notes
https://github.com/open-mmlab/mmsegmentation/tree/master/configs/vit
- Results:
# 1
https://github.com/open-mmlab/mmsegmentation/tree/master/configs/vit#ade20k
https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py
https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/models/upernet_vit-b16_ln_mln.py
Example
>>> # xdoctest: +REQUIRES(module:mmseg)
>>> from geowatch.tasks.fusion.architectures.transformer import * # NOQA
>>> self = MM_VITEncoderDecoder(16, 16, 16)
>>> x = torch.rand(2, 3, 16)
>>> self.forward(x)
- pretrained_fpath_shortnames = {'upernet_vit-b16_mln_512x512_80k_ade20k': 'https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_mln_512x512_80k_ade20k/upernet_vit-b16_mln_512x512_80k_ade20k_20210624_130547-0403cee1.pth'}¶