geowatch.tasks.fusion.architectures.wu_mae module
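A liberated copy of WU's pretrained MAE backbone: the model code was extracted into this standalone module with liberator, using the following commands: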
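    import sys
    # Make WU's original MAE code importable (path from the original setup)
    sys.path.append('/data/joncrall/dvc-repos/smart_expt_dvc/models/wu/MAE-2023-02-09')
    import pred_features

    import liberator
    lib = liberator.Liberator()
    # Register the ViT source, then inline the definitions it imports from pred_features
    lib.add_dynamic(pred_features.ViT)
    lib.expand(['pred_features'])
    print(lib.current_sourcecode())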
- class geowatch.tasks.fusion.architectures.wu_mae.FeedForward(dim, hidden_dim, dropout=0.0)
Bases: torch.nn.Module
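This entry documents only the constructor. As a rough illustration of what a `FeedForward(dim, hidden_dim, dropout)` block usually computes in a ViT, here is a minimal sketch assuming the common pre-norm MLP layout (LayerNorm, Linear to `hidden_dim`, GELU, dropout, Linear back to `dim`). `FeedForwardSketch` is a hypothetical name; the real class may differ in details such as where normalization is applied.

```python
import torch
from torch import nn


class FeedForwardSketch(nn.Module):
    """Hypothetical position-wise MLP; the real FeedForward may differ."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),           # assumed pre-norm
            nn.Linear(dim, hidden_dim),  # expand to hidden_dim
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # project back to dim
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: (batch, tokens, dim) -> (batch, tokens, dim)
        return self.net(x)


ff = FeedForwardSketch(dim=4, hidden_dim=8)
out = ff(torch.rand(2, 3, 4))
print(out.shape)  # torch.Size([2, 3, 4])
```

Because the output keeps the input shape, a block like this can sit inside a residual connection.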
- class geowatch.tasks.fusion.architectures.wu_mae.Attention(dim, heads=8, dim_head=64, dropout=0.0)
Bases: torch.nn.Module
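The `Attention(dim, heads, dim_head, dropout)` signature suggests standard multi-head self-attention in which the inner width is `heads * dim_head` and need not equal `dim`. A minimal sketch under that assumption (hypothetical `AttentionSketch`, pre-norm, scaled dot-product attention with scale `dim_head ** -0.5`):

```python
import torch
from torch import nn


class AttentionSketch(nn.Module):
    """Hypothetical multi-head self-attention; the real Attention may differ."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x):
        b, n, _ = x.shape
        # Project once, then split into queries, keys and values.
        q, k, v = self.to_qkv(self.norm(x)).chunk(3, dim=-1)
        # (b, n, heads * dim_head) -> (b, heads, n, dim_head)
        q, k, v = (t.reshape(b, n, self.heads, -1).transpose(1, 2) for t in (q, k, v))
        attn = (q @ k.transpose(-1, -2) * self.scale).softmax(dim=-1)
        attn = self.dropout(attn)
        out = attn @ v                               # (b, heads, n, dim_head)
        out = out.transpose(1, 2).reshape(b, n, -1)  # merge the heads
        return self.to_out(out)                      # back to (b, n, dim)


attn = AttentionSketch(dim=4, heads=2, dim_head=2)
out = attn(torch.rand(2, 3, 4))
print(out.shape)  # torch.Size([2, 3, 4])
```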
- class geowatch.tasks.fusion.architectures.wu_mae.Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=0.0)
Bases: torch.nn.Module
Example
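>>> import torch
>>> from geowatch.tasks.fusion.architectures.wu_mae import Transformer
>>> dim = 4
>>> depth = 3
>>> heads = 2
>>> dim_head = 2
>>> mlp_dim = 2
>>> self = Transformer(dim, depth, heads, dim_head, mlp_dim)
>>> x = torch.rand(2, 3, dim)
>>> # torch.rand samples from [0, 1), so this boolean mask is almost surely all True
>>> mask = torch.rand(2, 3) > 0
>>> out = self.forward(x, mask)
>>> print(f'out.shape={out.shape}')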
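The doctest passes a boolean `mask` of shape `(batch, tokens)` to `forward`. The sketch below shows one common way such a mask can be applied in a pre-norm transformer stack, assuming `True` marks tokens that may be attended to and `False` marks tokens ignored as keys. That convention and the class `MaskedTransformerSketch` are assumptions, not the module's documented behavior.

```python
import torch
from torch import nn


class MaskedTransformerSketch(nn.Module):
    """Hypothetical pre-norm transformer with a key mask; not the real class."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner = heads * dim_head
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                nn.LayerNorm(dim),                       # attention pre-norm
                nn.Linear(dim, inner * 3, bias=False),   # q, k, v projection
                nn.Linear(inner, dim),                   # attention output
                nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, mlp_dim), nn.GELU(),
                              nn.Dropout(dropout), nn.Linear(mlp_dim, dim),
                              nn.Dropout(dropout)),      # feed-forward block
            ]))
        self.norm = nn.LayerNorm(dim)

    def attend(self, x, norm, to_qkv, to_out, mask):
        b, n, _ = x.shape
        q, k, v = to_qkv(norm(x)).chunk(3, dim=-1)
        q, k, v = (t.reshape(b, n, self.heads, -1).transpose(1, 2) for t in (q, k, v))
        dots = q @ k.transpose(-1, -2) * self.scale  # (b, heads, n, n)
        if mask is not None:
            # Assumed convention: mask[b, j] is True where key j may be attended to.
            dots = dots.masked_fill(~mask[:, None, None, :], -torch.finfo(dots.dtype).max)
        out = dots.softmax(dim=-1) @ v
        return to_out(out.transpose(1, 2).reshape(b, n, -1))

    def forward(self, x, mask=None):
        for norm, to_qkv, to_out, ff in self.layers:
            x = x + self.attend(x, norm, to_qkv, to_out, mask)  # residual attention
            x = x + ff(x)                                       # residual MLP
        return self.norm(x)


torch.manual_seed(0)
model = MaskedTransformerSketch(dim=4, depth=3, heads=2, dim_head=2, mlp_dim=2)
x = torch.rand(2, 3, 4)
mask = torch.tensor([[True, True, False], [True, True, True]])
out = model(x, mask)
print(out.shape)  # torch.Size([2, 3, 4])
```

Filling masked scores with the most negative finite float instead of `-inf` avoids NaNs when every key in a row is masked, since the softmax then degrades to a uniform distribution rather than dividing by zero.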
- class geowatch.tasks.fusion.architectures.wu_mae.ViT(*, image_size, image_patch_size, frames, frame_patch_size, dim, depth, heads, mlp_dim, channels=6, dim_head=64, dropout=0.0, emb_dropout=0.0)
Bases: torch.nn.Module
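The ViT signature (`image_size`, `image_patch_size`, `frames`, `frame_patch_size`, `channels=6`) suggests that inputs are cut into space-time patches before being embedded. A minimal sketch of that patching step, assuming a `(batch, channels, frames, height, width)` layout and square spatial patches; `patchify_3d` is a hypothetical helper, not part of this module:

```python
import torch


def patchify_3d(video, image_patch_size, frame_patch_size):
    """Hypothetical helper: split (B, C, F, H, W) into flattened 3D patches.

    Returns (B, num_patches, patch_dim) where
    num_patches = (F // pf) * (H // p) * (W // p) and patch_dim = pf * p * p * C.
    """
    b, c, f, h, w = video.shape
    p, pf = image_patch_size, frame_patch_size
    assert h % p == 0 and w % p == 0, 'image size must be divisible by image_patch_size'
    assert f % pf == 0, 'frames must be divisible by frame_patch_size'
    x = video.reshape(b, c, f // pf, pf, h // p, p, w // p, p)
    # Move the patch-grid axes forward and the within-patch axes (pf, p, p, c) last.
    x = x.permute(0, 2, 4, 6, 3, 5, 7, 1)
    return x.reshape(b, (f // pf) * (h // p) * (w // p), pf * p * p * c)


video = torch.rand(2, 6, 4, 32, 32)  # 6 channels, matching the default channels=6
tokens = patchify_3d(video, image_patch_size=8, frame_patch_size=2)
print(tokens.shape)  # torch.Size([2, 32, 768])
```

In a ViT each flattened patch would then be projected to `dim` with a linear layer. With these example sizes there are (4 / 2) * (32 / 8) * (32 / 8) = 32 patches, each of length 2 * 8 * 8 * 6 = 768.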
- geowatch.tasks.fusion.architectures.wu_mae.wu_backbone()
Construct the ViT backbone architecture. Pretrained weights can be loaded into it with torch_liberator:
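    from torch_liberator import Pretrained
    from geowatch.tasks.fusion.architectures.wu_mae import wu_backbone
    ckpt_fpath = '/home/joncrall/remote/toothbrush/data/dvc-repos/smart_expt_dvc/models/wu/MAE-2023-02-09/goldenMae-epoch=07-val_loss=0.23.ckpt'
    # association='embedding' selects how checkpoint keys are matched to model layers
    initializer = Pretrained(ckpt_fpath, association='embedding')
    # Construct the backbone, then copy the matched checkpoint weights into it
    vit = wu_backbone()
    result = initializer.forward(vit)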
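If torch_liberator is unavailable, a rough plain-PyTorch alternative is to copy only the checkpoint entries whose names and shapes match the model. The sketch below assumes a Lightning-style checkpoint that nests weights under a `'state_dict'` key with a module prefix; the prefix `'model.'` and the helper `load_matching_weights` are hypothetical. torch_liberator's association strategies are meant to match layers even when key names differ, whereas this sketch only handles exact names after stripping one prefix.

```python
import os
import tempfile

import torch
from torch import nn


def load_matching_weights(model, ckpt_fpath, prefix='model.'):
    """Hypothetical helper: copy checkpoint tensors whose name and shape match."""
    # Recent PyTorch versions default to weights_only=True; checkpoints that
    # pickle extra objects may need weights_only=False.
    ckpt = torch.load(ckpt_fpath, map_location='cpu')
    # Assumption: Lightning-style checkpoints nest the weights under 'state_dict'.
    state = ckpt.get('state_dict', ckpt)
    state = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state.items()}
    own = model.state_dict()
    matched = {k: v for k, v in state.items() if k in own and own[k].shape == v.shape}
    model.load_state_dict(matched, strict=False)
    return sorted(matched)


# Toy demonstration: the second layer differs in shape, so only layer 0 is copied.
src = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
dst = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 3))
with tempfile.TemporaryDirectory() as dpath:
    fpath = os.path.join(dpath, 'toy.ckpt')
    torch.save({'state_dict': {'model.' + k: v for k, v in src.state_dict().items()}}, fpath)
    loaded = load_matching_weights(dst, fpath)
print(loaded)  # ['0.bias', '0.weight']
```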