geowatch.tasks.fusion.architectures.wu_mae module

A liberated version of WU's MAE pretrained backbone


import sys
sys.path.append('/data/joncrall/dvc-repos/smart_expt_dvc/models/wu/MAE-2023-02-09')
import pred_features

import liberator
lib = liberator.Liberator()
lib.add_dynamic(pred_features.ViT)
lib.expand(['pred_features'])
print(lib.current_sourcecode())

class geowatch.tasks.fusion.architectures.wu_mae.PreNorm(dim, fn)[source]

Bases: Module
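
A minimal usage sketch, assuming the common ViT pre-norm pattern in which PreNorm applies a LayerNorm over dim before calling the wrapped fn and forwards keyword arguments through; the exact internals are not documented here.

Example

>>> import torch
>>> dim = 4
>>> fn = FeedForward(dim, hidden_dim=8)
>>> self = PreNorm(dim, fn)
>>> x = torch.rand(2, 3, dim)
>>> mask = torch.rand(2, 3) > 0
>>> out = self.forward(x, mask=mask)
>>> print(f'out.shape={out.shape}')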

forward(x, **kwargs)[source]
class geowatch.tasks.fusion.architectures.wu_mae.FeedForward(dim, hidden_dim, dropout=0.0)[source]

Bases: Module
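
A minimal usage sketch, assuming the standard ViT MLP block (two Linear layers with a nonlinearity and dropout) that preserves the input shape; how mask is used internally is an assumption.

Example

>>> import torch
>>> dim = 4
>>> hidden_dim = 8
>>> self = FeedForward(dim, hidden_dim)
>>> x = torch.rand(2, 3, dim)
>>> mask = torch.rand(2, 3) > 0
>>> out = self.forward(x, mask)
>>> print(f'out.shape={out.shape}')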

forward(x, mask)[source]
class geowatch.tasks.fusion.architectures.wu_mae.Attention(dim, heads=8, dim_head=64, dropout=0.0)[source]

Bases: Module
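
A minimal usage sketch, assuming standard multi-head self-attention where mask marks valid tokens per batch item and the output shape matches the input; this mirrors the mask usage in the Transformer example below.

Example

>>> import torch
>>> dim = 4
>>> self = Attention(dim, heads=2, dim_head=2)
>>> x = torch.rand(2, 3, dim)
>>> mask = torch.rand(2, 3) > 0
>>> out = self.forward(x, mask)
>>> print(f'out.shape={out.shape}')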

forward(x, mask)[source]
geowatch.tasks.fusion.architectures.wu_mae.pair(t)[source]
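
In common ViT implementations, pair(t) returns t unchanged if it is already a tuple and otherwise duplicates the scalar into one; that behavior is assumed in this sketch.

Example

>>> pair(8)       # assumed result: (8, 8)
>>> pair((8, 4))  # assumed result: (8, 4)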
class geowatch.tasks.fusion.architectures.wu_mae.Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=0.0)[source]

Bases: Module

Example

>>> dim = 4
>>> depth = 3
>>> heads = 2
>>> dim_head = 2
>>> mlp_dim = 2
>>> self = Transformer(dim, depth, heads, dim_head, mlp_dim)
>>> x = torch.rand(2, 3, dim)
>>> mask = torch.rand(2, 3) > 0
>>> out = self.forward(x, mask)
>>> print(f'out.shape={out.shape}')
forward(x, mask)[source]
class geowatch.tasks.fusion.architectures.wu_mae.ViT(*, image_size, image_patch_size, frames, frame_patch_size, dim, depth, heads, mlp_dim, channels=6, dim_head=64, dropout=0.0, emb_dropout=0.0)[source]

Bases: Module
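
A minimal usage sketch; the (batch, channels, frames, height, width) video layout and the parameter values below are assumptions based on common video-ViT conventions, not values verified against the pretrained checkpoint. Note that image_size must be divisible by image_patch_size, and frames by frame_patch_size.

Example

>>> import torch
>>> self = ViT(image_size=32, image_patch_size=8, frames=4,
...     frame_patch_size=2, dim=16, depth=1, heads=2, mlp_dim=16)
>>> video = torch.rand(2, 6, 4, 32, 32)  # channels=6 is the default
>>> out = self.forward(video)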

forward(video)[source]
geowatch.tasks.fusion.architectures.wu_mae.wu_backbone()[source]

from torch_liberator import Pretrained
ckpt_fpath = '/home/joncrall/remote/toothbrush/data/dvc-repos/smart_expt_dvc/models/wu/MAE-2023-02-09/goldenMae-epoch=07-val_loss=0.23.ckpt'
initializer = Pretrained(ckpt_fpath, association='embedding')
vit = wu_backbone()
result = initializer.forward(vit)