geowatch.tasks.mae.predictV3 module
Baseline Example:
DVC_DATA_DPATH=$(geowatch_dvc --tags='phase2_data' --hardware=auto)
DVC_EXPT_DPATH=$(geowatch_dvc --tags='phase2_expt' --hardware=auto)
MAE_MODEL_FPATH="$DVC_EXPT_DPATH/models/wu/wu_mae_2023_04_21/Drop6-epoch=01-val_loss=0.20.ckpt"
KWCOCO_BUNDLE_DPATH=$DVC_DATA_DPATH/Drop7-MedianNoWinter10GSD

python -m geowatch.utils.simple_dvc request "$MAE_MODEL_FPATH"

# NOTE: different predict files correspond to different models
# TODO: make the model size a parameter (or better yet inferred)
python -m geowatch.tasks.mae.predictV3 \
    --device="cuda:0" \
    --mae_ckpt_path="$MAE_MODEL_FPATH" \
    --input_kwcoco="$KWCOCO_BUNDLE_DPATH/imganns-KR_R001.kwcoco.zip" \
    --output_kwcoco="$KWCOCO_BUNDLE_DPATH/imganns-KR_R001-testmae.kwcoco.zip" \
    --window_space_scale=1.0 \
    --workers=8 \
    --io_workers=8

# After your model predicts the outputs, you should be able to use the
# geowatch visualize tool to inspect your features.
python -m geowatch visualize "$KWCOCO_BUNDLE_DPATH/imganns-KR_R001-testmae.kwcoco.zip" \
    --channels "red|green|blue,mae.8:11,mae.14:17" \
    --stack=only \
    --workers=avail \
    --animate=True \
    --draw_anns=False
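When the prediction step is driven from Python rather than a shell, the same invocation can be assembled as an argument list. The sketch below is only an illustration: `build_predict_argv` is a hypothetical helper, not part of geowatch, and the paths are placeholders. It uses only flags that appear in the baseline example.

```python
import shlex
import sys


def build_predict_argv(mae_ckpt_path, input_kwcoco, output_kwcoco,
                       device="cuda:0", workers=8, io_workers=8):
    """Hypothetical helper: assemble the predictV3 command line as a list."""
    return [
        sys.executable, "-m", "geowatch.tasks.mae.predictV3",
        f"--device={device}",
        f"--mae_ckpt_path={mae_ckpt_path}",
        f"--input_kwcoco={input_kwcoco}",
        f"--output_kwcoco={output_kwcoco}",
        f"--workers={workers}",
        f"--io_workers={io_workers}",
    ]


# Placeholder paths; substitute the real checkpoint and kwcoco bundle.
argv = build_predict_argv(
    mae_ckpt_path="/path/to/model.ckpt",
    input_kwcoco="/path/to/input.kwcoco.zip",
    output_kwcoco="/path/to/output.kwcoco.zip",
)

# Print a shell-quoted form for inspection.
print(shlex.join(argv))
# To actually launch it: subprocess.run(argv, check=True)
```

Passing a list to `subprocess.run` avoids shell quoting problems with paths that contain `=` or spaces, such as the checkpoint name in the example.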
- class geowatch.tasks.mae.predictV3.WatchDataset(coco_dset, sensor=['S2'], bands=['shared'], segmentation=False, patch_size=224, mask_patch_size=16, num_images=2, mode='train', patch_overlap=0.25, bas=True, rng=None, mask_pct=0.5, mask_time_width=2, temporal_mode='cat', window_space_scale=1.0)
Bases: Dataset
- S2_l2a_channel_names = ['B02.tif', 'B01.tif', 'B03.tif', 'B04.tif', 'B05.tif', 'B06.tif', 'B07.tif', 'B08.tif', 'B09.tif', 'B11.tif', 'B12.tif', 'B8A.tif']
- S2_channel_names = ['coastal', 'blue', 'green', 'red', 'B05', 'B06', 'B07', 'nir', 'B09', 'cirrus', 'swir16', 'swir22', 'B8A']
- L8_channel_names = ['coastal', 'lwir11', 'lwir12', 'blue', 'green', 'red', 'nir', 'swir16', 'swir22', 'pan', 'cirrus']
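To see which channel names the two sensors have in common, the lists above can be intersected by name. This is a minimal sketch: `common_channels` is a hypothetical helper, and the result is simply the name overlap. It is not claimed to be what `bands=['shared']` selects inside the module, which this page does not specify.

```python
# Channel name lists copied from the class attributes listed above.
S2_channel_names = ['coastal', 'blue', 'green', 'red', 'B05', 'B06', 'B07',
                    'nir', 'B09', 'cirrus', 'swir16', 'swir22', 'B8A']
L8_channel_names = ['coastal', 'lwir11', 'lwir12', 'blue', 'green', 'red',
                    'nir', 'swir16', 'swir22', 'pan', 'cirrus']


def common_channels(first, second):
    """Hypothetical helper: names present in both lists, in `first` order."""
    second_set = set(second)
    return [name for name in first if name in second_set]


shared = common_channels(S2_channel_names, L8_channel_names)
print(shared)
```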
- class geowatch.tasks.mae.predictV3.MAEPredictConfig(*args, **kwargs)
Bases: DataConfig
Configuration for WashU MAE models.
Valid options: []
- Parameters:
*args – positional arguments for this data config
**kwargs – keyword arguments for this data config
- default = {'assets_dname': <Value('_assets')>, 'bands': <Value(['shared'])>, 'batch_size': <Value(1)>, 'device': <Value('cuda:0')>, 'input_kwcoco': <Value(None)>, 'io_workers': <Value(8)>, 'mae_ckpt_path': <Value(None)>, 'output_kwcoco': <Value(None)>, 'patch_overlap': <Value(0.25)>, 'sensor': <Value(['S2', 'L8'])>, 'window_resolution': <Value(1.0)>, 'workers': <Value(4)>}
- class geowatch.tasks.mae.predictV3.Attention(dim, heads=8, dim_head=64, dropout=0.0)
Bases: Module
- class geowatch.tasks.mae.predictV3.Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=0.0)
Bases: Module
- class geowatch.tasks.mae.predictV3.ViT(*, image_size, image_patch_size, frames, frame_patch_size, dim, depth, heads, mlp_dim, channels=6, dim_head=64, dropout=0.0, emb_dropout=0.0)
Bases: Module
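The ViT signature takes both spatial (`image_size`, `image_patch_size`) and temporal (`frames`, `frame_patch_size`) patch parameters. The sketch below shows how these would determine the token count and per-patch input size, assuming the common video-ViT convention of non-overlapping space-time patches; the module source is the authority on what it actually does. `vit3d_token_shape` is a hypothetical helper, and the argument values are illustrative (224, 16, and 2 echo the `WatchDataset` defaults for `patch_size`, `mask_patch_size`, and `num_images`; `frame_patch_size=1` is an assumption).

```python
def vit3d_token_shape(image_size, image_patch_size, frames,
                      frame_patch_size, channels=6):
    """Hypothetical helper: (number of tokens, flattened patch length)
    for non-overlapping space-time patches."""
    if image_size % image_patch_size != 0:
        raise ValueError("image_size must be divisible by image_patch_size")
    if frames % frame_patch_size != 0:
        raise ValueError("frames must be divisible by frame_patch_size")
    patches_per_side = image_size // image_patch_size
    num_tokens = patches_per_side ** 2 * (frames // frame_patch_size)
    patch_dim = channels * image_patch_size ** 2 * frame_patch_size
    return num_tokens, patch_dim


# Illustrative values only; frame_patch_size=1 is assumed.
num_tokens, patch_dim = vit3d_token_shape(
    image_size=224, image_patch_size=16, frames=2, frame_patch_size=1)
print(num_tokens, patch_dim)  # 392 1536
```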
- class geowatch.tasks.mae.predictV3.MAE(*, encoder, decoder_dim, masking_ratio=0.75, decoder_depth=8, decoder_heads=8, decoder_dim_head=64)
Bases: Module
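The `masking_ratio=0.75` default suggests the usual masked-autoencoder scheme, in which a random subset of tokens is hidden from the encoder and reconstructed by the decoder. The following is a minimal pure-Python sketch of that random split, assuming uniform random masking; `split_masked_visible` is a hypothetical helper, the token count of 392 is illustrative, and the module's own masking logic may differ.

```python
import random


def split_masked_visible(num_tokens, masking_ratio=0.75, rng=None):
    """Hypothetical helper: randomly partition token indices into
    masked and visible sets according to masking_ratio."""
    rng = rng or random.Random()
    num_masked = int(masking_ratio * num_tokens)
    order = list(range(num_tokens))
    rng.shuffle(order)  # uniform random permutation of token indices
    masked = sorted(order[:num_masked])
    visible = sorted(order[num_masked:])
    return masked, visible


# Illustrative token count; a fixed seed keeps the split reproducible.
masked, visible = split_masked_visible(392, masking_ratio=0.75,
                                       rng=random.Random(0))
print(len(masked), len(visible))  # 294 98
```

With a ratio of 0.75 the encoder in this scheme sees only a quarter of the tokens, which is what makes masked pretraining comparatively cheap per sample.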