geowatch.tasks.mae.predictV3 module

Baseline Example:

DVC_DATA_DPATH=$(geowatch_dvc --tags='phase2_data' --hardware=auto)
DVC_EXPT_DPATH=$(geowatch_dvc --tags='phase2_expt' --hardware=auto)
MAE_MODEL_FPATH="$DVC_EXPT_DPATH/models/wu/wu_mae_2023_04_21/Drop6-epoch=01-val_loss=0.20.ckpt"
KWCOCO_BUNDLE_DPATH=$DVC_DATA_DPATH/Drop7-MedianNoWinter10GSD

python -m geowatch.utils.simple_dvc request "$MAE_MODEL_FPATH"

# NOTE: different predict files correspond to different models
# TODO: make the model size a parameter (or better yet inferred)

python -m geowatch.tasks.mae.predictV3 \
    --device="cuda:0" \
    --mae_ckpt_path="$MAE_MODEL_FPATH" \
    --input_kwcoco="$KWCOCO_BUNDLE_DPATH/imganns-KR_R001.kwcoco.zip" \
    --output_kwcoco="$KWCOCO_BUNDLE_DPATH/imganns-KR_R001-testmae.kwcoco.zip" \
    --window_space_scale=1.0 \
    --workers=8 \
    --io_workers=8

# After your model predicts the outputs, you should be able to use the
# geowatch visualize tool to inspect your features.
python -m geowatch visualize \
    "$KWCOCO_BUNDLE_DPATH/imganns-KR_R001-testmae.kwcoco.zip" \
    --channels "red|green|blue,mae.8:11,mae.14:17" \
    --stack=only --workers=avail --animate=True --draw_anns=False

class geowatch.tasks.mae.predictV3.WatchDataset(coco_dset, sensor=['S2'], bands=['shared'], segmentation=False, patch_size=224, mask_patch_size=16, num_images=2, mode='train', patch_overlap=0.25, bas=True, rng=None, mask_pct=0.5, mask_time_width=2, temporal_mode='cat', window_space_scale=1.0)[source]

Bases: Dataset

S2_l2a_channel_names = ['B02.tif', 'B01.tif', 'B03.tif', 'B04.tif', 'B05.tif', 'B06.tif', 'B07.tif', 'B08.tif', 'B09.tif', 'B11.tif', 'B12.tif', 'B8A.tif']
S2_channel_names = ['coastal', 'blue', 'green', 'red', 'B05', 'B06', 'B07', 'nir', 'B09', 'cirrus', 'swir16', 'swir22', 'B8A']
L8_channel_names = ['coastal', 'lwir11', 'lwir12', 'blue', 'green', 'red', 'nir', 'swir16', 'swir22', 'pan', 'cirrus']
update_target_properties(target)[source]

Populate the target so it has the correct input scale and bands.

class geowatch.tasks.mae.predictV3.MAEPredictConfig(*args, **kwargs)[source]

Bases: DataConfig

Configuration for WashU MAE models

Valid options: []

Parameters:
  • *args – positional arguments for this data config

  • **kwargs – keyword arguments for this data config

default = {
    'assets_dname': <Value('_assets')>,
    'bands': <Value(['shared'])>,
    'batch_size': <Value(1)>,
    'device': <Value('cuda:0')>,
    'input_kwcoco': <Value(None)>,
    'io_workers': <Value(8)>,
    'mae_ckpt_path': <Value(None)>,
    'output_kwcoco': <Value(None)>,
    'patch_overlap': <Value(0.25)>,
    'sensor': <Value(['S2', 'L8'])>,
    'window_resolution': <Value(1.0)>,
    'workers': <Value(4)>,
}
geowatch.tasks.mae.predictV3.pair(t)[source]
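A pair helper like this appears in most ViT implementations; a minimal sketch of the typical behavior (an assumption about this function, not the verified geowatch source):

```python
def pair(t):
    """Return t unchanged if it is already a tuple, else duplicate it.

    Lets size arguments such as image_size accept either a single int
    (e.g. 224 for a square input) or an explicit (height, width) tuple.
    """
    return t if isinstance(t, tuple) else (t, t)
```

For example, `pair(224)` yields `(224, 224)` while `pair((128, 256))` is returned unchanged.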
class geowatch.tasks.mae.predictV3.PreNorm(dim, fn)[source]

Bases: Module

forward(x, **kwargs)[source]
class geowatch.tasks.mae.predictV3.FeedForward(dim, hidden_dim, dropout=0.0)[source]

Bases: Module

forward(x)[source]
class geowatch.tasks.mae.predictV3.Attention(dim, heads=8, dim_head=64, dropout=0.0)[source]

Bases: Module

forward(x)[source]
class geowatch.tasks.mae.predictV3.Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=0.0)[source]

Bases: Module

forward(x)[source]
class geowatch.tasks.mae.predictV3.ViT(*, image_size, image_patch_size, frames, frame_patch_size, dim, depth, heads, mlp_dim, channels=6, dim_head=64, dropout=0.0, emb_dropout=0.0)[source]

Bases: Module

forward(video)[source]
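The ViT constructor's size parameters fix the token count the transformer sees. Assuming the standard video-ViT patching scheme (spatial sizes divisible by `image_patch_size`, frame count divisible by `frame_patch_size`), the arithmetic can be sketched as:

```python
def num_tokens(image_size, image_patch_size, frames, frame_patch_size):
    # Each token covers an image_patch_size x image_patch_size spatial patch
    # spanning frame_patch_size consecutive frames (a hypothetical helper
    # illustrating the usual video-ViT tokenization, not the geowatch code).
    ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
    ph, pw = ((image_patch_size, image_patch_size)
              if isinstance(image_patch_size, int) else image_patch_size)
    assert ih % ph == 0 and iw % pw == 0 and frames % frame_patch_size == 0
    return (ih // ph) * (iw // pw) * (frames // frame_patch_size)
```

With the dataset defaults above (patch_size=224, num_images=2), a 16-pixel, 2-frame patching would give `num_tokens(224, 16, 2, 2) == 196` tokens.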
class geowatch.tasks.mae.predictV3.MAE(*, encoder, decoder_dim, masking_ratio=0.75, decoder_depth=8, decoder_heads=8, decoder_dim_head=64)[source]

Bases: Module

forward(img)[source]
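With the default masking_ratio=0.75, the MAE encoder sees only a quarter of the tokens and the decoder reconstructs the rest. A sketch of the usual random split (a hypothetical helper in the style of He et al.'s masked autoencoder, not the geowatch implementation):

```python
import random

def split_masked(num_patches, masking_ratio=0.75, rng=None):
    # Randomly partition patch indices into masked (reconstruction targets)
    # and unmasked (encoder inputs).
    rng = rng or random.Random(0)
    num_masked = int(masking_ratio * num_patches)
    indices = list(range(num_patches))
    rng.shuffle(indices)
    return indices[:num_masked], indices[num_masked:]
```

For 196 patches this masks 147 and feeds the remaining 49 to the encoder, which is where most of the MAE pretraining speedup comes from.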
class geowatch.tasks.mae.predictV3.MaeCityscape(dataset, **kwargs)[source]

Bases: LightningModule

forward(x)[source]
shared_step(batch, batch_idx)[source]
geowatch.tasks.mae.predictV3.sigmoid(a)[source]
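This is presumably the standard logistic function used to squash raw feature activations into (0, 1); a minimal, numerically stable sketch (an assumption about its behavior, not the verified source):

```python
import math

def sigmoid(a):
    # Logistic function 1 / (1 + exp(-a)), split by sign so that
    # exp() is only ever called on a non-positive argument (avoids overflow).
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    ea = math.exp(a)
    return ea / (1.0 + ea)
```

Note `sigmoid(0) == 0.5` and `sigmoid(-a) == 1 - sigmoid(a)` for any finite `a`.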
class geowatch.tasks.mae.predictV3.Predict(args)[source]

Bases: object

geowatch.tasks.mae.predictV3.main()[source]