"""
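Encoder config names include the following (not exhaustive; the full set is generated programmatically and stored in ``encoder_configs``):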
'smt_it_joint_p8', 'smt_it_joint_n12', 'smt_it_joint_t12', 'smt_it_joint_t24',
'smt_it_joint_s12', 'smt_it_joint_s24', 'smt_it_joint_m24', 'smt_it_joint_l24',
'smt_it_stm_p8', 'smt_it_stm_n12', 'smt_it_stm_t12', 'smt_it_stm_t24',
'smt_it_stm_s12', 'smt_it_stm_s24', 'smt_it_stm_m24', 'smt_it_stm_l24',
'smt_it_sm_p8', 'smt_it_sm_n12', 'smt_it_sm_t12', 'smt_it_sm_t24',
'smt_it_sm_s12', 'smt_it_sm_s24', 'smt_it_sm_m24', 'smt_it_sm_l24',
'smt_it_st_p8', 'smt_it_st_n12', 'smt_it_st_t12', 'smt_it_st_t24',
'smt_it_st_s12', 'smt_it_st_s24', 'smt_it_st_m24', 'smt_it_st_l24',
'smt_it_tm_p8', 'smt_it_tm_n12', 'smt_it_tm_t12', 'smt_it_tm_t24',
'smt_it_tm_s12', 'smt_it_tm_s24', 'smt_it_tm_m24', 'smt_it_tm_l24',
'smt_it_s_p8', 'smt_it_s_n12', 'smt_it_s_t12', 'smt_it_s_t24',
'smt_it_s_s12', 'smt_it_s_s24', 'smt_it_s_m24', 'smt_it_s_l24',
'smt_it_t_p8', 'smt_it_t_n12', 'smt_it_t_t12', 'smt_it_t_t24',
'smt_it_t_s12', 'smt_it_t_s24', 'smt_it_t_m24', 'smt_it_t_l24',
'smt_it_hwtm_p8', 'smt_it_hwtm_n12', 'smt_it_hwtm_t12', 'smt_it_hwtm_t24',
'smt_it_hwtm_s12', 'smt_it_hwtm_s24', 'smt_it_hwtm_m24', 'smt_it_hwtm_l24',
'smt_it_m_p8', 'smt_it_m_n12', 'smt_it_m_t12', 'smt_it_m_t24',
'smt_it_m_s12', 'smt_it_m_s24', 'smt_it_m_m24', 'smt_it_m_l24',
'sm_it_joint_p8', 'sm_it_joint_n12', 'sm_it_joint_t12', 'sm_it_joint_t24',
'sm_it_joint_s12', 'sm_it_joint_s24', 'sm_it_joint_m24', 'sm_it_joint_l24',
'sm_it_sm_p8', 'sm_it_sm_n12', 'sm_it_sm_t12', 'sm_it_sm_t24',
'sm_it_sm_s12', 'sm_it_sm_s24', 'sm_it_sm_m24', 'sm_it_sm_l24'
Notes:
pip install reformer_pytorch
pip install performer-pytorch <- this one
"""
from functools import wraps
import torch
from torch import nn, einsum
import torch.nn.functional as F
import einops
from einops import rearrange, repeat
import ubelt as ub # NOQA
try:
import xdev
profile = xdev.profile
except Exception:
profile = ub.identity
[docs]
class ResidualSequential(nn.Sequential):
    """
    Sequential container whose output is added back onto its input.

    Computes ``x + layers(x)``, so the wrapped layers must produce a tensor
    that can be added to the input.
    """

    def __init__(self, *args):
        super().__init__(*args)

    def forward(self, x):
        """Run the wrapped layers on ``x`` and add the skip connection."""
        branch = super().forward(x)
        return x + branch
[docs]
class ResidualAttentionSequential(ResidualSequential):
    """
    Residual block made of a normalization layer followed by an attention
    layer, where the attention layer additionally receives a padding mask.
    """

    def __init__(self, norm, attention):
        super().__init__(norm, attention)

    def forward(self, x, key_padding_mask=None):
        """
        Apply ``norm`` then ``attention`` (forwarding ``key_padding_mask`` to
        the attention layer) and add the result to the unmodified input.
        """
        norm_layer, attention_layer = self[0], self[1]
        attended = attention_layer(
            norm_layer(x), key_padding_mask=key_padding_mask)
        return x + attended
[docs]
def assert_allclose(a, b, rtol=1e-05, atol=1e-08):
    """
    Assert that two array-likes have the same shape and are elementwise close.

    TODO: integrate with :func:`kwcoco.coco_sql_dataset.assert_dsets_allclose`.
    Add to kwarray

    Args:
        a: first array-like; must expose ``.shape`` and be convertible by
            ``kwarray.ArrayAPI.numpy``.
        b: second array-like, same requirements as ``a``.
        rtol (float): relative tolerance forwarded to numpy.
        atol (float): absolute tolerance forwarded to numpy.

    Raises:
        AssertionError: if the ranks differ, the shapes differ, or any entry
            is not close under the given tolerances.

    Ignore:
        >>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
        >>> import numpy as np
        >>> import pytest
        >>> a = np.random.rand(1, 2, 3)
        >>> b = a + 0
        >>> assert_allclose(a, b)
        >>> b = np.random.rand(1, 2, 3)
        >>> with pytest.raises(AssertionError):
        >>>     assert_allclose(a, b)
        >>> b = a.copy()
        >>> b.ravel()[0] += 1
        >>> with pytest.raises(AssertionError):
        >>>     assert_allclose(a, b)
    """
    a_shape = a.shape
    b_shape = b.shape
    if len(b_shape) != len(a_shape):
        raise AssertionError(f'len(a.shape:={a_shape}) != len(b.shape:={b.shape})')
    if b_shape != a_shape:
        raise AssertionError(f'a.shape:={a_shape} != b.shape:={b.shape}')
    # Deferred imports: only needed once the cheap shape checks have passed.
    import kwarray
    import numpy as np
    a = kwarray.ArrayAPI.numpy(a)
    b = kwarray.ArrayAPI.numpy(b)
    if np.allclose(a, b, rtol=rtol, atol=atol):
        return
    # Use the SAME tolerances as the check above so that the reported number
    # of mismatches is consistent with the reason the check failed.
    flags = np.isclose(a, b, rtol=rtol, atol=atol)
    num_close = flags.sum()
    num_total = flags.size
    num_not_close = num_total - num_close
    a_stats = kwarray.stats_dict(a)
    b_stats = kwarray.stats_dict(b)
    msg = ub.codeblock(
        f'''
        Failed closeness check
        Found not close entries: {num_not_close} / {num_total}
        a_stats = {ub.urepr(a_stats, nl=0, precision=4)}
        b_stats = {ub.urepr(b_stats, nl=0, precision=4)}
        ''')
    raise AssertionError(msg)
[docs]
class MultiheadSelfAttention(torch.nn.MultiheadAttention):
    """
    Self-attention convenience wrapper around
    :class:`torch.nn.MultiheadAttention`: the single input is used as query,
    key and value, and only the attended output is returned.

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    CommandLine:
        xdoctest -m geowatch.tasks.fusion.architectures.transformer MultiheadSelfAttention

    Example:
        >>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
        >>> self = MultiheadSelfAttention(4, 1).eval()
        >>> S, B, F = (7, 3, 4)
        >>> x  = (torch.rand(S, B, F) * 10).round()
        >>> # Results should be independent of the batch dim
        >>> y  = self.forward(x)
        >>> y0 = self.forward(x[:, 0:1, :])
        >>> y1 = self.forward(x[:, 1:2, :])
        >>> y2 = self.forward(x[:, 2:3, :])
        >>> assert_allclose(y[:, 0:1, :], y0, rtol=1e-3, atol=1e-6)
        >>> assert_allclose(y[:, 1:2, :], y1, rtol=1e-3, atol=1e-6)
        >>> assert_allclose(y[:, 2:3, :], y2, rtol=1e-3, atol=1e-6)
        >>> key_padding_mask = torch.rand(B, S) > 0.5
        >>> masked_result = self.forward(x, key_padding_mask=key_padding_mask)
    """

    def __init__(self, embed_dim, num_heads, *args, **kwargs):
        super().__init__(embed_dim, num_heads, *args, **kwargs)

    def forward(self, x, key_padding_mask=None):
        """
        Args:
            x (Tensor) : of shape (seq, batch, feature)

            key_padding_mask (Tensor) : of shape (batch, seq).
                A value of True means we will **ignore** the token.

        Returns:
            attn_out : of shape (seq, batch, feature)
        """
        # The parent returns ``(output, weights)``; the weights are discarded.
        attn_out, _attn_weights = super().forward(
            x, x, x, key_padding_mask=key_padding_mask)
        return attn_out
[docs]
class ModuleProperties(metaclass=MetaModuleProperties):
    """
    Experimental way to get concisely property like behavior at a module level.

    Inherit from this class. This class forwards metaclass magic that allows us
    to register any function decorated with ``@property``. It also creates the
    ``__getattr__`` method that will be assigned to the module.
    """

    @classmethod
    def getattr(cls, name):
        """
        Resolve a lazily-defined module attribute.

        Args:
            name (str): attribute name requested from the module.

        Returns:
            Whatever the function registered under ``name`` returns.

        Raises:
            AttributeError: if ``name`` is not a key of ``cls._property_lut``
                (presumably populated by the metaclass - confirm there).
        """
        # Only the lookup is guarded: a KeyError raised *inside* the
        # registered function is a real error and must not be reported as a
        # missing module attribute.
        try:
            getter = cls._property_lut[name]
        except KeyError:
            module_name = __name__  # generalize
            raise AttributeError(f'Module {module_name!r} has no attribute {name!r}')
        return getter()
class __module_properties__(ModuleProperties):
    """
    Lazily-imported module attributes for the optional attention backends.

    Each entry defers importing its optional submodule until the attribute is
    actually requested from this module.

    CommandLine:
        xdoctest -m geowatch.tasks.fusion.architectures.transformer __module_properties__

    Example:
        >>> # xdoctest: +SKIP
        >>> from geowatch.tasks.fusion.architectures import transformer as mod
        >>> attr = mod.FastMultiheadSelfAttention
        >>> print(f'attr={attr}')
        >>> attr = mod.ReformerMultiheadedSelfAttention
        >>> print(f'attr={attr}')
    """

    # NOTE: these functions intentionally take no ``self``. Per the
    # ModuleProperties docstring the metaclass registers ``@property``
    # functions; ``ModuleProperties.getattr`` then calls the registered entry
    # with no arguments.
    @property
    def FastMultiheadSelfAttention():
        # Deferred import of the optional performer-based implementation.
        from geowatch.tasks.fusion.architectures.optional.performer_attention import FastMultiheadSelfAttention
        return FastMultiheadSelfAttention

    @property
    def ReformerMultiheadedSelfAttention():
        # Deferred import of the optional reformer-based implementation.
        from geowatch.tasks.fusion.architectures.optional.reformer_attention import ReformerMultiheadedSelfAttention
        return ReformerMultiheadedSelfAttention


# Module-level ``__getattr__`` hook: attribute lookups that fail on this module
# are routed through the lazy lookup above.
__getattr__ = __module_properties__.getattr
[docs]
def new_attention_layer(embedding_size, n_heads, attention_impl='exact', **kwargs):
    """
    Build a pre-norm residual self-attention block.

    Args:
        embedding_size (int): feature dimension of the tokens.
        n_heads (int): number of attention heads.
        attention_impl (str): one of ``'exact'``, ``'performer'`` or
            ``'reformer'``.
        **kwargs: forwarded to the selected attention class.

    Returns:
        ResidualAttentionSequential: ``LayerNorm`` followed by the attention
        layer, wrapped with a skip connection.

    Raises:
        KeyError: if ``attention_impl`` is not a known implementation.

    Example:
        >>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
        >>> import torch
        >>> batch_size = 1
        >>> embedding_size = 4
        >>> n_heads = 2
        >>> num_tokens = 3
        >>> input_shape = (num_tokens, batch_size, embedding_size)
        >>> inputs = torch.rand(*input_shape)
        >>> layer1 = new_attention_layer(embedding_size, n_heads, 'exact')
        >>> outputs1 = layer1(inputs)
        >>> assert outputs1.shape == input_shape
        >>> # xdoctest: +REQUIRES(module:performer_pytorch)
        >>> layer2 = new_attention_layer(embedding_size, n_heads, 'performer')
        >>> outputs2 = layer2(inputs)
        >>> assert outputs2.shape == input_shape

    Example:
        >>> # Test with a mask
        >>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
        >>> import torch
        >>> batch_size = 1
        >>> embedding_size = 4
        >>> n_heads = 2
        >>> num_tokens = 3
        >>> input_shape = (num_tokens, batch_size, embedding_size)
        >>> inputs = torch.rand(*input_shape)
        >>> key_padding_mask = torch.rand(batch_size, num_tokens) > 0.5
        >>> layer1 = new_attention_layer(embedding_size, n_heads, 'exact')
        >>> outputs1 = layer1(inputs, key_padding_mask=key_padding_mask)
    """
    # Select the attention class first; optional backends are imported lazily
    # so that they are only required when actually requested.
    if attention_impl == 'exact':
        attention_cls = MultiheadSelfAttention
    elif attention_impl == 'performer':
        import performer_pytorch  # NOQA
        from geowatch.tasks.fusion.architectures.optional.performer_attention import FastMultiheadSelfAttention as attention_cls
    elif attention_impl == 'reformer':
        from geowatch.tasks.fusion.architectures.optional.reformer_attention import ReformerMultiheadedSelfAttention as attention_cls
    else:
        raise KeyError(attention_impl)

    attention = attention_cls(embedding_size, n_heads, **kwargs)
    norm = nn.LayerNorm(embedding_size)
    return ResidualAttentionSequential(norm, attention)
[docs]
def new_mlp_layer(embedding_size, dropout, **kwargs):
    """
    Build a residual two-layer perceptron that preserves the feature size.

    Args:
        embedding_size (int): input and output feature dimension.
        dropout (float): dropout probability applied after the first linear.
        **kwargs: forwarded to both :class:`torch.nn.Linear` layers.

    Returns:
        ResidualSequential: ``Linear -> Dropout -> GELU -> Linear`` with a
        skip connection.

    Example:
        >>> import torch
        >>> embedding_size = 3
        >>> batch_size = 1
        >>> layer = new_mlp_layer(embedding_size, dropout=0)
        >>> input_shape = (batch_size, embedding_size)
        >>> inputs = torch.rand(*input_shape)
        >>> outputs = layer(inputs)
        >>> assert outputs.shape == (batch_size, embedding_size)
    """
    stages = [
        nn.Linear(embedding_size, embedding_size, **kwargs),
        nn.Dropout(dropout),
        nn.GELU(),
        nn.Linear(embedding_size, embedding_size, **kwargs),
    ]
    return ResidualSequential(*stages)
[docs]
class TimmEncoder:
    """
    Thin holder around a model created by ``timm.create_model``.

    Args:
        arch_name (str): timm architecture name.
        pretrained (bool): forwarded to ``timm.create_model``.
        dropout (float): currently unused.
        attention_impl (str): currently unused.
        in_features (int | None): currently unused.

    Example:
        >>> # xdoctest: +REQUIRES(module:timm)
        >>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
        >>> import torch
        >>> in_features = 7
        >>> input_shape = B, T, M, H, W, F = (2, 3, 5, 2, 2, in_features)
        >>> inputs = torch.rand(*input_shape)
        >>> arch_name = 'vit_base_patch16_224'
        >>> self = TimmEncoder(arch_name)
    """

    def __init__(self, arch_name='vit_base_patch16_224', pretrained=True,
                 dropout=0.0, attention_impl='exact', in_features=None):
        import timm
        # Honor the caller's ``pretrained`` choice (previously hard-coded to
        # True, which forced pretrained weights even when they were declined).
        self.timm_model = timm.create_model(arch_name, pretrained=pretrained)
[docs]
class MM_VITEncoder(nn.Module):
    """
    mmsegmentation variant of VIT

    Needs 768 features.

    Only the transformer layers of the mmseg ``VisionTransformer`` are kept.
    ``forward`` treats the first axis as batch and the last axis as features;
    every axis in between is flattened into one token axis and restored on
    output.

    Notes:
        https://github.com/open-mmlab/mmsegmentation/tree/master/configs/vit

    Results:
        # 1
        https://github.com/open-mmlab/mmsegmentation/tree/master/configs/vit#ade20k

        https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py

        https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/models/upernet_vit-b16_ln_mln.py

    Ignore:
        >>> from mmseg.models.backbones import vit
        >>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
        >>> self = MM_VITEncoder()
        >>> x = torch.rand(2, 3, 768)
        >>> self.forward(x)
    """

    def __init__(self):
        super().__init__()
        from mmseg.models.backbones.vit import VisionTransformer
        kwargs = dict(
            img_size=(512, 512),
            patch_size=16,
            in_channels=3,
            embed_dims=768,
            num_layers=12,
            num_heads=12,
            mlp_ratio=4,
            out_indices=(2, 5, 8, 11),
            qkv_bias=True,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            drop_path_rate=0.0,
            with_cls_token=True,
            norm_cfg=dict(type='LN', eps=1e-6),
            act_cfg=dict(type='GELU'),
            norm_eval=False,
            interpolate_mode='bicubic')
        vit_model = VisionTransformer(**kwargs)
        # We only need the encoder
        self.layers = vit_model.layers
        self.initialize_from_pretrained()
        self.in_features = self.layers[0].ln1.weight.shape[0]
        self.out_features = self.layers[-1].ffn.layers[1].out_features

    def initialize_from_pretrained(self):
        """Placeholder: loading pretrained weights is currently disabled."""
        # pretrained_fpath = ub.grabdata('https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_mln_512x512_80k_ade20k/upernet_vit-b16_mln_512x512_80k_ade20k_20210624_130547-0403cee1.pth')
        # FIXME: Having this import here breaks torch.package
        # not exactly sure why
        # from geowatch.tasks.fusion.fit import coerce_initializer
        # initializer = coerce_initializer(pretrained_fpath)
        # info = initializer.forward(self, verbose=0)  # NOQA
        ...

    def forward(self, x):
        """
        Args:
            x (Tensor): shape ``(batch, *token_dims, feature)``.

        Returns:
            Tensor: shape ``(batch, *token_dims, out_feature)`` where
            ``out_feature`` is the last dimension produced by the layers.
        """
        orig_shape = x.shape
        # Flatten everything between batch and feature into one token axis.
        x = x.reshape(orig_shape[0], -1, orig_shape[-1])
        for layer in self.layers:
            x = layer(x)
        # Restore the leading dims. The previous ``*orig_shape[0]`` tried to
        # star-unpack an int and always raised TypeError.
        x = x.reshape(*orig_shape[:-1], x.shape[-1])
        return x
[docs]
class DeiTEncoder(nn.Module):
    """
    Encoder that reuses the transformer blocks of ``deit_base_patch16_224``
    (fetched through ``torch.hub``) behind a linear input projection.

    https://github.com/rishikksh20/ViViT-pytorch
    https://pytorch.org/tutorials/beginner/vt_tutorial.html

    Example:
        >>> # xdoctest: +REQUIRES(module:timm)
        >>> # xdoctest: +SKIP('can cause network issues on CI')
        >>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
        >>> import torch
        >>> in_features = 7
        >>> input_shape = B, T, M, H, W, F = (2, 3, 5, 2, 2, in_features)
        >>> inputs = torch.rand(*input_shape)
        >>> self = DeiTEncoder(in_features, pretrained=False)
        >>> outputs = self.forward(inputs)
    """

    def __init__(self, in_features, pretrained=True):
        super().__init__()
        hub_model = torch.hub.load(
            'facebookresearch/deit:main', 'deit_base_patch16_224',
            pretrained=pretrained)
        transformer_blocks = hub_model.blocks
        # Width expected by the blocks, read off the first block's norm.
        token_dim = transformer_blocks[0].norm1.weight.shape[0]
        self.first = nn.Linear(in_features, out_features=token_dim)
        self.blocks = transformer_blocks
        self.in_features = in_features
        self.out_features = transformer_blocks[-1].mlp.fc2.out_features

    def forward(self, inputs):
        """
        Flatten the (t, m, h, w) axes of a 6-D ``(b, t, m, h, w, f)`` input
        into tokens, run the projection and blocks, and unflatten again.
        """
        _, T, M, H, W, _ = inputs.shape
        tokens = einops.rearrange(inputs, 'b t m h w f -> b (t m h w) f')
        tokens = self.blocks(self.first(tokens))
        return einops.rearrange(
            tokens, 'b (t m h w) f -> b t m h w f', t=T, m=M, h=H, w=W)
[docs]
class PerceiverEncoder(nn.Module):
    """
    Encoder built on ``perceiver_pytorch.PerceiverIO`` where the input tokens
    are also used as the output queries, so the feature size is preserved.

    https://github.com/lucidrains/perceiver-pytorch/blob/main/perceiver_pytorch/perceiver_io.py

    Example:
        >>> # xdoctest: +REQUIRES(module:perceiver_pytorch)
        >>> from geowatch.tasks.fusion.architectures.transformer import PerceiverEncoder  # NOQA
        >>> import torch
        >>> B, T, M, H, W, F = 1, 2, 3, 5, 8, 13
        >>> self = PerceiverEncoder(F, dropout=0.1)
        >>> inputs = torch.rand(B, T, M, H, W, F)
        >>> outputs = self(inputs)
        >>> assert outputs.shape == (B, T, M, H, W, F)
    """

    def __init__(self, in_features, depth=4, dropout=0.0):
        super().__init__()
        import perceiver_pytorch as perceiver

        # No dropout in perceiver? Perform it on input tokens.
        self.dropout = nn.Dropout(dropout)

        perceiver_config = {
            'depth': depth,
            'dim': in_features,
            'queries_dim': in_features,
            'num_latents': 512,
            'latent_dim': 256,
            'cross_heads': 1,
            'latent_heads': 8,
            'cross_dim_head': 64,
            'latent_dim_head': 64,
            'weight_tie_layers': False,
            'decoder_ff': False,
            'logits_dim': None,
        }
        self.perceiver = perceiver.PerceiverIO(**perceiver_config)

        self.in_features = in_features
        self.out_features = in_features

    def forward(self, inputs):
        """
        Flatten the (t, m, h, w) axes of a 6-D ``(b, t, m, h, w, f)`` input
        into tokens, apply dropout and the perceiver, and unflatten again.
        """
        _, T, M, H, W, _ = inputs.shape
        tokens = einops.rearrange(inputs, 'b t m h w f -> b (t m h w) f')
        tokens = self.dropout(tokens)
        tokens = self.perceiver(tokens, queries=tokens)
        return einops.rearrange(
            tokens, 'b (t m h w) f -> b t m h w f', t=T, m=M, h=H, w=W)
[docs]
class FusionEncoder(nn.Module):
    """
    Primary entry point to create a feature transformer

    Performs multiple "channelwise" (maybe rename to axil?) attention
    encodings in a row

    Args:
        axes: attention axis groups, forwarded to every
            ``ChannelwiseTransformerEncoderLayer``.
        default_shape: names of the input axes, forwarded to every layer.
        feature_axis (str): name of the feature axis, forwarded to every layer.
        batch_axis (str): name of the batch axis, forwarded to every layer.
        embedding_size (int): output size of the input projection and the
            value reported as ``out_features``.
        n_layers (int): number of encoder layers.
        n_heads (int): forwarded to every layer.
        dropout (float): forwarded to every layer.
        attention_impl (str): forwarded to every layer.
        attention_kwargs (dict | None): forwarded to every layer; ``None``
            means an empty dict.
        in_features (int | None): size of the input feature axis. If None a
            ``LazyLinear`` infers it from the first batch.

    Example:
        >>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
        >>> import torch
        >>> in_features = 7
        >>> input_shape = B, T, M, H, W, F = (2, 3, 5, 2, 2, in_features)
        >>> inputs = torch.rand(*input_shape)
        >>> model = FusionEncoder(
        >>>     in_features=in_features,
        >>>     axes=[("time", "mode", "height", "width")],
        >>>     default_shape=["batch", "time", "mode", "height", "width", "feature"],
        >>>     feature_axis="feature",
        >>>     batch_axis="batch",
        >>>     n_layers=8,
        >>>     embedding_size=256,
        >>>     n_heads=4
        >>> )
        >>> model(inputs)
        >>> output = model(inputs)
        >>> assert output.shape == (2, 3, 5, 2, 2, 256)
        >>> #
        >>> # Test Lazy variant
        >>> model = FusionEncoder(
        >>>     in_features=None,
        >>>     axes=[("time", "mode", "height", "width")],
        >>>     default_shape=["batch", "time", "mode", "height", "width", "feature"],
        >>>     feature_axis="feature",
        >>>     batch_axis="batch",
        >>>     n_layers=8,
        >>>     embedding_size=256,
        >>>     n_heads=4
        >>> )
        >>> print(model)
        >>> inputs = torch.rand(*input_shape)
        >>> output = model(inputs)
        >>> assert output.shape == (2, 3, 5, 2, 2, 256)

    Ignore:
        traced = torch.jit.trace(model, inputs)
        import timerit
        ti = timerit.Timerit(5, bestof=1, verbose=2)
        for timer in ti.reset('time'):
            model(inputs)
        for timer in ti.reset('time'):
            traced(inputs)

    Ignore:
        >>> # Get a sense of the arch size
        >>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
        >>> rows = []
        >>> from geowatch.utils import util_netharn
        >>> for key, config in ub.ProgIter(list(encoder_configs.items())):
        >>>     self = FusionEncoder(in_features=256, **config)
        >>>     num_params = util_netharn.number_of_parameters(self)
        >>>     row = {'arch': key, 'num_params': num_params}
        >>>     row.update(config)
        >>>     print('row = {}'.format(ub.urepr(row, nl=0, sort=0)))
        >>>     rows.append(row)
        >>> import pandas as pd
        >>> data = pd.DataFrame(rows).sort_values('num_params')
        >>> print(data.to_string())
        >>> # Look at only smt configs
        >>> flags = data['axes'].apply(lambda x: x == [("height", "width"), ("time",), ("mode",)])
        >>> print(data[flags].to_string())
    """

    def __init__(self, axes,
                 default_shape=('batch', 'sequence', 'feature'),
                 feature_axis='feature',
                 batch_axis='batch',
                 embedding_size=128,
                 n_layers=4,
                 n_heads=8,
                 dropout=0.0,
                 attention_impl='exact',
                 attention_kwargs=None,
                 in_features=None):
        super().__init__()
        # Avoid a shared mutable default: every instance gets its own dict
        # unless the caller explicitly supplies one.
        if attention_kwargs is None:
            attention_kwargs = {}
        if in_features is None:
            # Use lazy linear to allow data to specify the channel dims
            first = nn.LazyLinear(embedding_size)
        else:
            first = nn.Linear(in_features=in_features, out_features=embedding_size)
        _layers = [
            ChannelwiseTransformerEncoderLayer(
                axes,
                embedding_size=embedding_size,
                n_heads=n_heads,
                dropout=dropout,
                default_shape=default_shape,
                feature_axis=feature_axis,
                batch_axis=batch_axis,
                attention_impl=attention_impl,
                attention_kwargs=attention_kwargs,
            )
            for _ in range(n_layers)
        ]
        self.in_features = in_features
        self.out_features = embedding_size
        self.first = first
        self.layers = nn.Sequential(*_layers)

    def forward(self, x, flat_coordinates=None, key_padding_mask=None, mask=None):
        """
        Project the input features and run every encoder layer in order.

        Args:
            x (Tensor): input whose last axis is the feature axis consumed by
                the input projection.
            flat_coordinates: forwarded unchanged to every layer.
            key_padding_mask: forwarded unchanged to every layer.
            mask: alias for ``key_padding_mask``; takes precedence when given.

        Returns:
            Tensor: output of the final layer.
        """
        if mask is not None:
            key_padding_mask = mask
        x = self.first(x)
        for layer in self.layers:
            # Can't use sequential because we need extra args
            x = layer(x, flat_coordinates=flat_coordinates,
                      key_padding_mask=key_padding_mask)
        return x
def _build_global_configs():
"""
Previously we manually defined a bunch of functions, now it
is defined programatically
"""
# dont define tons of functions, use a configuration dictionary
_smt_axes_basis = dict(
joint=[('time', 'mode', 'height', 'width')],
stm=[('height', 'width'), ('time',), ('mode',)],
sm=[('height', 'width'), ('mode',)],
st=[('height', 'width'), ('time',)],
tm=[('time',), ('mode',)],
s=[('height', 'width')],
t=[('time',)],
hwtm=[('height',), ('width',), ('time',), ('mode',)],
m=[('mode',)],
)
# Names are inspired by:
# pico
# nano
# small
# medium
# large
# But they really dont line up with the intuitions
_encoder_size_basis = {
'p1': dict(n_layers=1, embedding_size=64, n_heads=4),
'p2': dict(n_layers=2, embedding_size=64, n_heads=4),
'p2w': dict(n_layers=2, embedding_size=128, n_heads=8),
'p3': dict(n_layers=3, embedding_size=128, n_heads=4),
'p4': dict(n_layers=4, embedding_size=128, n_heads=4),
'p8': dict(n_layers=8, embedding_size=128, n_heads=4),
'b8': dict(n_layers=8, embedding_size=384, n_heads=4),
'm8': dict(n_layers=8, embedding_size=512, n_heads=4),
'p16': dict(n_layers=16, embedding_size=128, n_heads=4),
'p24': dict(n_layers=24, embedding_size=128, n_heads=4),
'p32': dict(n_layers=32, embedding_size=128, n_heads=4),
'n12': dict(n_layers=12, embedding_size=128, n_heads=4),
't12': dict(n_layers=12, embedding_size=192, n_heads=4),
't24': dict(n_layers=24, embedding_size=192, n_heads=4),
's12': dict(n_layers=12, embedding_size=384, n_heads=8),
's24': dict(n_layers=24, embedding_size=384, n_heads=8),
'm24': dict(n_layers=24, embedding_size=512, n_heads=8),
'l24': dict(n_layers=24, embedding_size=768, n_heads=8),
}
# space-mode-time transformer params
_smt_value = dict(
default_shape=['batch', 'time', 'mode', 'height', 'width', 'feature'],
feature_axis='feature',
batch_axis='batch',
)
encoder_configs = {}
for axes_code, axes_value in _smt_axes_basis.items():
for size_code, size_value in _encoder_size_basis.items():
code = f'smt_it_{axes_code}_{size_code}'
encoder_configs[code] = ub.dict_union(
size_value, _smt_value, dict(axes=axes_value))
# space-mode transformer params
_sm_value = dict(
default_shape=['batch', 'mode', 'height', 'width', 'feature'],
feature_axis='feature',
batch_axis='batch',
)
_sm_axes_basis = {
'joint': [('mode', 'height', 'width')],
'sm': [('height', 'width'), ('mode',)],
}
for axes_code, axes_value in _sm_axes_basis.items():
for size_code, size_value in _encoder_size_basis.items():
code = f'sm_it_{axes_code}_{size_code}'
encoder_configs[code] = ub.dict_union(
size_value, _sm_value, dict(axes=axes_value))
return encoder_configs
# Table mapping a config name (e.g. 'smt_it_stm_p8') to a dict of keyword
# arguments; built once when this module is imported.
encoder_configs = _build_global_configs()
# print('encoder_configs = {}'.format(ub.urepr(list(encoder_configs.keys(), nl=1)))
# ========================================
# Below is an implementation of the transformer architecture that uses the
# same base components and follows the same patterns as the perceiver. This
# is a 1-to-1 drop-in replacement for perceiver models (minus latent_dim/num_latents options).
# ========================================
[docs]
def default(val, d):
    """Return ``val`` unless it is None, in which case return ``d``."""
    if val is None:
        return d
    return val
[docs]
def cache_fn(f):
    """
    Decorate ``f`` so its first non-None result is remembered and returned on
    every later call, regardless of the arguments passed.

    Passing ``_cache=False`` bypasses the stored value: ``f`` is called and
    its result returned without being stored.
    """
    memo = None

    @wraps(f)
    def cached_fn(*args, _cache=True, **kwargs):
        nonlocal memo
        if _cache:
            if memo is None:
                memo = f(*args, **kwargs)
            return memo
        return f(*args, **kwargs)

    return cached_fn
# helper classes
[docs]
class PreNorm(nn.Module):
    """
    Apply LayerNorm to the input (and optionally to a ``context`` keyword
    argument) before delegating to the wrapped callable ``fn``.
    """

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        if context_dim is None:
            self.norm_context = None
        else:
            self.norm_context = nn.LayerNorm(context_dim)

    def forward(self, x, **kwargs):
        """
        Normalize ``x``; when a context norm exists, ``kwargs['context']`` is
        required and is replaced by its normalized version before calling
        ``fn``.
        """
        normed = self.norm(x)
        if self.norm_context is not None:
            kwargs['context'] = self.norm_context(kwargs['context'])
        return self.fn(normed, **kwargs)
[docs]
class GEGLU(nn.Module):
    """
    GELU-gated linear unit: the last axis is split in half and the first half
    is multiplied by ``gelu`` of the second half.
    """

    def forward(self, x):
        value, gates = torch.chunk(x, 2, dim=-1)
        return value * F.gelu(gates)
[docs]
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network with a GEGLU activation:
    ``Linear(dim -> 2 * dim * mult) -> GEGLU -> Linear(dim * mult -> dim)``.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden_dim = dim * mult
        # The first projection is twice as wide because GEGLU halves it.
        expand = nn.Linear(dim, hidden_dim * 2)
        gate = GEGLU()
        contract = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(expand, gate, contract)

    def forward(self, x):
        return self.net(x)
[docs]
class Attention(nn.Module):
    """
    Multi-head scaled dot-product attention of ``x`` over ``context``
    (self-attention when no context is given).
    """

    def __init__(self, query_dim, context_dim=None, output_dim=None, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        # Both optional sizes fall back to the query size.
        if context_dim is None:
            context_dim = query_dim
        if output_dim is None:
            output_dim = query_dim

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, output_dim)

    def forward(self, x, context=None, mask=None):
        """
        Args:
            x (Tensor): queries, laid out as ``(b, n, d)`` per the rearrange
                patterns below.
            context (Tensor | None): source of keys/values; ``x`` if None.
            mask (Tensor | None): inverted with ``~`` before filling, so
                entries that are False are excluded from attention.
        """
        n_heads = self.heads

        def split_heads(t):
            # Fold the head axis into the batch axis.
            return rearrange(t, 'b n (h d) -> (b h) n d', h=n_heads)

        queries = self.to_q(x)
        if context is None:
            context = x
        keys, values = self.to_kv(context).chunk(2, dim=-1)
        queries, keys, values = (
            split_heads(queries), split_heads(keys), split_heads(values))

        sim = einsum('b i d, b j d -> b i j', queries, keys) * self.scale

        if mask is not None:
            flat_mask = rearrange(mask, 'b ... -> b (...)')
            fill_value = -torch.finfo(sim.dtype).max
            head_mask = repeat(flat_mask, 'b j -> (b h) () j', h=n_heads)
            sim.masked_fill_(~head_mask, fill_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        merged = einsum('b i j, b j d -> b i d', attn, values)
        merged = rearrange(merged, '(b h) n d -> b n (h d)', h=n_heads)
        return self.to_out(merged)
[docs]
class BackboneEncoderDecoder:
    """Empty marker base class for encoder/decoder backbones."""
[docs]
class MM_VITEncoderDecoder(nn.Module, BackboneEncoderDecoder):
    """
    mmsegmentation variant of VIT

    Needs 768 features.

    The transformer layers of the mmseg ``VisionTransformer`` are wrapped
    with a linear input projection and, when ``queries_dim`` is given, a
    one-layer transformer decoder followed by an output projection.

    Notes:
        https://github.com/open-mmlab/mmsegmentation/tree/master/configs/vit

    Results:
        # 1
        https://github.com/open-mmlab/mmsegmentation/tree/master/configs/vit#ade20k

        https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py

        https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/models/upernet_vit-b16_ln_mln.py

    Example:
        >>> # xdoctest: +REQUIRES(module:mmseg)
        >>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
        >>> self = MM_VITEncoderDecoder(16, 16, 16)
        >>> x = torch.rand(2, 3, 16)
        >>> self.forward(x)

    Ignore:
        >>> # xdoctest: +REQUIRES(module:mmseg)
        >>> # This tests downloading weights from the MM repo
        >>> from geowatch.tasks.fusion.architectures.transformer import *  # NOQA
        >>> self = MM_VITEncoderDecoder(16, 16, 16, pretrained="upernet_vit-b16_mln_512x512_80k_ade20k")
        >>> x = torch.rand(2, 3, 16)
        >>> self.forward(x)
    """

    pretrained_fpath_shortnames = {
        "upernet_vit-b16_mln_512x512_80k_ade20k":
            'https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_mln_512x512_80k_ade20k/upernet_vit-b16_mln_512x512_80k_ade20k_20210624_130547-0403cee1.pth',
    }

    def __init__(
        self,
        dim,
        logits_dim,
        queries_dim=None,
        pretrained=None,
    ):
        from mmseg.models.backbones.vit import VisionTransformer
        super().__init__()

        # A known short name is expanded to its URL and downloaded; any other
        # value is passed to the backbone untouched.
        known_shortnames = MM_VITEncoderDecoder.pretrained_fpath_shortnames
        if pretrained in known_shortnames:
            pretrained = ub.grabdata(known_shortnames[pretrained])

        backbone_config = {
            'pretrained': pretrained,
            'img_size': (512, 512),
            'patch_size': 16,
            'in_channels': 3,
            'embed_dims': 768,
            'num_layers': 12,
            'num_heads': 12,
            'mlp_ratio': 4,
            'out_indices': (2, 5, 8, 11),
            'qkv_bias': True,
            'drop_rate': 0.0,
            'attn_drop_rate': 0.0,
            'drop_path_rate': 0.0,
            'with_cls_token': True,
            'norm_cfg': dict(type='LN', eps=1e-6),
            'act_cfg': dict(type='GELU'),
            'norm_eval': False,
            'interpolate_mode': 'bicubic',
        }
        backbone = VisionTransformer(**backbone_config)

        # We only need the encoder
        self.layers = backbone.layers

        first_layer, last_layer = self.layers[0], self.layers[-1]
        self.encoder_in_features = first_layer.ln1.weight.shape[0]
        self.encoder_out_features = last_layer.ffn.layers[1].out_features

        self.input_projector = nn.Linear(dim, self.encoder_in_features)
        self.output_projector = nn.Linear(self.encoder_out_features, logits_dim)

        # The decoder branch only exists for a positive query dimension.
        self.has_decoder = (queries_dim is not None) and (queries_dim > 0)
        if self.has_decoder:
            self.query_projector = nn.Linear(queries_dim, self.encoder_out_features)
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=self.encoder_out_features, nhead=8,
                dim_feedforward=512, batch_first=True)
            self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=1)

    def initialize_from_pretrained(self, fpath):
        """Placeholder: loading weights from ``fpath`` is currently disabled."""
        # FIXME: Having this import here breaks torch.package
        # initializer = coerce_initializer(fpath)
        # info = initializer.forward(self, verbose=0)  # NOQA
        pass

    def forward(self, x, mask=None, queries=None):
        """
        Args:
            x (Tensor): input tokens whose last axis has size ``dim``.
            mask: accepted for interface compatibility; not used here.
            queries (Tensor | None): decoder queries.

        Returns:
            Tensor: the raw encoder output when ``queries`` is None or there
            is no decoder; otherwise the decoded queries passed through the
            output projection.
        """
        x = self.input_projector(x)
        for layer in self.layers:
            x = layer(x)

        if queries is not None and self.has_decoder:
            projected_queries = self.query_projector(queries)
            decoded = self.decoder(projected_queries, x)
            return self.output_projector(decoded)
        return x