geowatch.tasks.mae.predictV2 module
Baseline Example:
DVC_DATA_DPATH=$(geowatch_dvc --tags='phase2_data' --hardware=auto)
DVC_EXPT_DPATH=$(geowatch_dvc --tags='phase2_expt' --hardware=auto)
python -m geowatch.tasks.mae.predict \
    --device="cuda:0" \
    --mae_ckpt_path="/storage1/fs1/jacobsn/Active/user_s.sastry/smart_watch/new_models/checkpoints/Drop6-epoch=01-val_loss=0.20.ckpt" \
    --input_kwcoco="$DVC_DATA_DPATH/Drop6-MeanYear10GSD-V2/data_train_I2L_split6.kwcoco.zip" \
    --output_kwcoco="$DVC_DATA_DPATH/Drop6-MeanYear10GSD-V2/mae_v1_train_split6.kwcoco.zip" \
    --window_space_scale=1.0 --workers=8 --io_workers=8
# After your model predicts the outputs, you should be able to use the
# geowatch visualize tool to inspect your features.
python -m geowatch visualize $DVC_DATA_DPATH/Drop6-MeanYear10GSD-V2/mae_v1_train_split6.kwcoco.zip \
    --channels "red|green|blue,mae.8:11,mae.14:17" \
    --stack=only --workers=avail --animate=True --draw_anns=False
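The --channels value in the visualize command uses kwcoco channel notation. As a rough guide, assuming kwcoco's usual conventions ("," separates channel groups, "|" joins channels within a group, and "name.a:b" is an end-exclusive range like a Python slice), the hypothetical helper expand_channel_spec below shows how that string breaks down. kwcoco's own ChannelSpec class is the authoritative parser; this sketch only handles the "name.a:b" range form.

```python
def expand_channel_spec(spec):
    """Expand a kwcoco-style channel string into one list of band names per group.

    Hypothetical helper: "," separates groups, "|" separates bands within a
    group, and "name.a:b" expands to name.a ... name.(b-1) (end-exclusive).
    """
    groups = []
    for group in spec.split(","):
        bands = []
        for part in group.split("|"):
            prefix, dot, suffix = part.rpartition(".")
            if dot and ":" in suffix:
                # Range form such as "mae.8:11" -> mae.8, mae.9, mae.10
                start, stop = (int(v) for v in suffix.split(":"))
                bands.extend(f"{prefix}.{i}" for i in range(start, stop))
            else:
                # Plain band name such as "red"
                bands.append(part)
        groups.append(bands)
    return groups


print(expand_channel_spec("red|green|blue,mae.8:11,mae.14:17"))
```

Under these assumptions the example spec yields three groups of three bands each: the RGB bands, then mae.8 to mae.10, then mae.14 to mae.16.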
- class geowatch.tasks.mae.predictV2.WatchDataset(coco_dset, sensor=['S2'], bands=['shared'], segmentation=False, patch_size=224, mask_patch_size=16, num_images=2, mode='train', patch_overlap=0.25, bas=True, rng=None, mask_pct=0.5, mask_time_width=2, temporal_mode='cat', window_space_scale=1.0)
Bases: Dataset
- S2_l2a_channel_names = ['B02.tif', 'B01.tif', 'B03.tif', 'B04.tif', 'B05.tif', 'B06.tif', 'B07.tif', 'B08.tif', 'B09.tif', 'B11.tif', 'B12.tif', 'B8A.tif']
- S2_channel_names = ['coastal', 'blue', 'green', 'red', 'B05', 'B06', 'B07', 'nir', 'B09', 'cirrus', 'swir16', 'swir22', 'B8A']
- L8_channel_names = ['coastal', 'lwir11', 'lwir12', 'blue', 'green', 'red', 'nir', 'swir16', 'swir22', 'pan', 'cirrus']
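The default bands=['shared'] suggests a channel subset common to both sensors. The exact subset WatchDataset selects is not documented here (ViT defaults to channels=6, fewer than the eight names the two lists above have in common), so the sketch below only shows how one could compute the overlap of the two name lists and format it as a "|"-joined kwcoco channel string. shared_channels is a hypothetical helper.

```python
S2_channel_names = ['coastal', 'blue', 'green', 'red', 'B05', 'B06', 'B07',
                    'nir', 'B09', 'cirrus', 'swir16', 'swir22', 'B8A']
L8_channel_names = ['coastal', 'lwir11', 'lwir12', 'blue', 'green', 'red',
                    'nir', 'swir16', 'swir22', 'pan', 'cirrus']


def shared_channels(*channel_lists):
    """Hypothetical helper: channels present in every list, in first-list order."""
    common = set(channel_lists[0]).intersection(*channel_lists[1:])
    return [name for name in channel_lists[0] if name in common]


shared = shared_channels(S2_channel_names, L8_channel_names)
print("|".join(shared))  # kwcoco early-fused channel string
```

Keeping the first list's order makes the resulting channel string deterministic, which matters if the string is later used to select bands in a fixed order.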
- class geowatch.tasks.mae.predictV2.MAEPredictConfig(*args, **kwargs)
Bases: DataConfig
Configuration for WashU MAE models
Valid options: []
- Parameters:
*args – positional arguments for this data config
**kwargs – keyword arguments for this data config
- default = {'assets_dname': <Value('_assets')>, 'bands': <Value(['shared'])>, 'batch_size': <Value(1)>, 'device': <Value('cuda:0')>, 'input_kwcoco': <Value(None)>, 'io_workers': <Value(8)>, 'mae_ckpt_path': <Value(None)>, 'output_kwcoco': <Value(None)>, 'patch_overlap': <Value(0.25)>, 'sensor': <Value(['S2', 'L8'])>, 'window_resolution': <Value(1.0)>, 'workers': <Value(4)>}
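WatchDataset's patch_size (224) and patch_overlap (0.25, also the MAEPredictConfig default) imply overlapping windows with a stride of 224 * (1 - 0.25) = 168 pixels. The sketch below enumerates window corners under that reading. How the real dataset treats the right and bottom edges is not documented here, so snapping the last window flush to the border (and returning a single window for images smaller than a patch, which would then need padding) is an assumption; window_starts and grid_windows are hypothetical names.

```python
def window_starts(length, patch_size=224, patch_overlap=0.25):
    """Start offsets of sliding windows covering [0, length) along one axis.

    Assumption: stride = patch_size * (1 - patch_overlap), and a final window
    is snapped to the end so the last pixels are covered.
    """
    if not 0.0 <= patch_overlap < 1.0:
        raise ValueError("patch_overlap must be in [0, 1)")
    if length <= patch_size:
        return [0]  # one (possibly padded) window
    stride = max(1, int(round(patch_size * (1.0 - patch_overlap))))
    starts = list(range(0, length - patch_size + 1, stride))
    if starts[-1] != length - patch_size:
        starts.append(length - patch_size)  # snap last window to the edge
    return starts


def grid_windows(height, width, patch_size=224, patch_overlap=0.25):
    """(row, col) top-left corners of every window in a 2D grid."""
    rows = window_starts(height, patch_size, patch_overlap)
    cols = window_starts(width, patch_size, patch_overlap)
    return [(r, c) for r in rows for c in cols]


print(window_starts(600))          # offsets along one 600-pixel axis
print(len(grid_windows(600, 600)))  # windows over a 600 x 600 image
```

For a 600-pixel axis this gives offsets 0, 168, 336 and a snapped 376, so a 600 x 600 image is covered by 4 x 4 = 16 windows.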
- class geowatch.tasks.mae.predictV2.Attention(dim, heads=8, dim_head=64, dropout=0.0)
Bases: Module
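Attention(dim, heads=8, dim_head=64, dropout=0.0) appears to follow the multi-head self-attention layout used by the vit-pytorch package, where the inner width is heads * dim_head and scores are scaled by dim_head ** -0.5. The NumPy sketch below shows that computation for a single unbatched sequence. It is not this module's PyTorch code: it omits dropout (as at inference), and the projection weights are plain arguments rather than learned layers.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_self_attention(x, w_qkv, w_out, heads=8, dim_head=64):
    """x: (n, dim); w_qkv: (dim, 3 * heads * dim_head); w_out: (heads * dim_head, dim)."""
    n = x.shape[0]
    inner = heads * dim_head
    qkv = x @ w_qkv                                   # (n, 3 * inner)
    q, k, v = np.split(qkv, 3, axis=-1)               # each (n, inner)
    # Split the inner width into heads: (n, heads * dim_head) -> (heads, n, dim_head)
    q, k, v = (t.reshape(n, heads, dim_head).transpose(1, 0, 2) for t in (q, k, v))
    scores = q @ k.transpose(0, 2, 1) * dim_head ** -0.5  # (heads, n, n)
    attn = softmax(scores, axis=-1)                   # each row sums to 1
    out = attn @ v                                    # (heads, n, dim_head)
    out = out.transpose(1, 0, 2).reshape(n, inner)    # concatenate heads
    return out @ w_out                                # project back to (n, dim)
```

Because no positional information enters this function, permuting the input tokens simply permutes the outputs; positional embeddings are added elsewhere in a ViT.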
- class geowatch.tasks.mae.predictV2.Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=0.0)
Bases: Module
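Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=0.0) presumably stacks depth attention and feed-forward blocks. The NumPy sketch below assumes the pre-norm residual layout common in ViT code (x = x + attn(norm(x)), then x = x + ff(norm(x)), followed by a final LayerNorm), a GELU feed-forward of width mlp_dim, and LayerNorm without learned scale or shift. The attention callable stands in for the Attention block, and transformer_forward is a hypothetical name.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize each token over its feature axis (no learned scale/shift).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def gelu(x):
    # Tanh approximation of GELU.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def transformer_forward(x, layers, attention):
    """Run a pre-norm residual stack over tokens x of shape (n, dim).

    layers: one dict per depth step, with "w1" of shape (dim, mlp_dim)
        and "w2" of shape (mlp_dim, dim) for the feed-forward block.
    attention: callable mapping (n, dim) -> (n, dim), standing in for Attention.
    """
    for layer in layers:
        x = x + attention(layer_norm(x))            # attention sub-block + residual
        hidden = gelu(layer_norm(x) @ layer["w1"])  # expand to mlp_dim
        x = x + hidden @ layer["w2"]                # project back to dim + residual
    return layer_norm(x)                            # final normalization
```

The residual connections keep the token width fixed at dim through all depth steps; only the hidden layer inside each feed-forward block widens to mlp_dim.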
- class geowatch.tasks.mae.predictV2.ViT(*, image_size, image_patch_size, frames, frame_patch_size, dim, depth, heads, mlp_dim, channels=6, dim_head=64, dropout=0.0, emb_dropout=0.0)
Bases: Module
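ViT takes both spatial (image_size, image_patch_size) and temporal (frames, frame_patch_size) patch sizes, so each token is a 3D patch spanning frame_patch_size frames. Assuming the tokenization used by 3D ViTs in vit-pytorch, an input yields (image_size / image_patch_size)**2 * (frames / frame_patch_size) tokens, each of length channels * image_patch_size**2 * frame_patch_size. The NumPy sketch below performs that split before any linear embedding; patchify_video is a hypothetical name and the token ordering inside this module is an assumption.

```python
import numpy as np


def patchify_video(x, image_patch_size, frame_patch_size):
    """Split a (channels, frames, height, width) array into flattened 3D patches.

    Returns (num_patches, patch_dim), mirroring the einops pattern
    'c (f pf) (h p1) (w p2) -> (f h w) (p1 p2 pf c)' (assumed layout).
    """
    c, frames, height, width = x.shape
    p, pf = image_patch_size, frame_patch_size
    if height % p or width % p or frames % pf:
        raise ValueError("image and frame sizes must be divisible by the patch sizes")
    f, h, w = frames // pf, height // p, width // p
    x = x.reshape(c, f, pf, h, p, w, p)    # axes: c, f, pf, h, p1, w, p2
    x = x.transpose(1, 3, 5, 4, 6, 2, 0)   # axes: f, h, w, p1, p2, pf, c
    return x.reshape(f * h * w, p * p * pf * c)


tokens = patchify_video(np.zeros((6, 2, 224, 224)), image_patch_size=16, frame_patch_size=1)
print(tokens.shape)
```

With these illustrative values (6 channels, 2 frames, 224 x 224 pixels, 16-pixel patches, frame_patch_size=1) the input becomes 14 * 14 * 2 = 392 tokens of length 6 * 16 * 16 = 1536. The pairing of these values is an example, not a documented default.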
- class geowatch.tasks.mae.predictV2.MAE(*, encoder, decoder_dim, masking_ratio=0.75, decoder_depth=8, decoder_heads=8, decoder_dim_head=64)
Bases: Module
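MAE wraps an encoder and hides a fraction masking_ratio of the patch tokens (0.75 by default) before encoding; in the standard masked-autoencoder recipe the decoder then reconstructs the hidden patches from the visible ones. The sketch below shows only the masking split, drawing one random permutation per sample. random_token_mask is a hypothetical helper, and this module may draw its mask differently.

```python
import numpy as np


def random_token_mask(num_patches, masking_ratio=0.75, rng=None):
    """Split token indices into (masked, visible) index arrays, MAE-style.

    Hypothetical helper: one random permutation; the first
    int(masking_ratio * num_patches) indices are treated as masked.
    """
    if not 0.0 < masking_ratio < 1.0:
        raise ValueError("masking_ratio must be in (0, 1)")
    rng = np.random.default_rng(rng)  # accepts a seed, None, or a Generator
    num_masked = int(masking_ratio * num_patches)
    perm = rng.permutation(num_patches)
    masked = np.sort(perm[:num_masked])
    visible = np.sort(perm[num_masked:])
    return masked, visible


masked, visible = random_token_mask(392, masking_ratio=0.75, rng=0)
print(len(masked), len(visible))
```

For example, with 392 tokens (14 * 14 spatial patches over 2 frames) and the default ratio, int(0.75 * 392) = 294 tokens are masked and 98 stay visible, so the encoder only processes a quarter of the sequence.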