geowatch.tasks.mae.predict module

Baseline Example:

DVC_DATA_DPATH=$(geowatch_dvc --tags='phase2_data' --hardware=auto)
DVC_EXPT_DPATH=$(geowatch_dvc --tags='phase2_expt' --hardware=auto)
MAE_MODEL_FPATH="$DVC_EXPT_DPATH/models/wu/wu_mae_2023_04_21/Drop6-epoch=01-val_loss=0.20.ckpt"
KWCOCO_BUNDLE_DPATH=$DVC_DATA_DPATH/Drop7-MedianNoWinter10GSD

python -m geowatch.utils.simple_dvc request "$MAE_MODEL_FPATH"

python -m geowatch.tasks.mae.predict \
    --device="cuda:0" \
    --mae_ckpt_path="$MAE_MODEL_FPATH" \
    --input_kwcoco="$KWCOCO_BUNDLE_DPATH/imganns-KR_R001.kwcoco.zip" \
    --output_kwcoco="$KWCOCO_BUNDLE_DPATH/imganns-KR_R001-testmae2.kwcoco.zip" \
    --window_space_scale=1.0 \
    --workers=8 \
    --assets_dname=teamfeats2 \
    --io_workers=8

# After your model predicts the outputs, you should be able to use the
# geowatch visualize tool to inspect your features.
python -m geowatch visualize \
    "$KWCOCO_BUNDLE_DPATH/imganns-KR_R001-testmae2.kwcoco.zip" \
    --channels "red|green|blue,mae.8:11,mae.14:17" \
    --stack=only \
    --workers=avail \
    --animate=True \
    --draw_anns=False

# Batch computing

export CUDA_VISIBLE_DEVICES="1"
DVC_DATA_DPATH=$(geowatch_dvc --tags=phase2_data --hardware="hdd")
DVC_EXPT_DPATH=$(geowatch_dvc --tags='phase2_expt' --hardware='auto')
BUNDLE_DPATH=$DVC_DATA_DPATH/Drop7-MedianNoWinter10GSD
python -m geowatch.cli.queue_cli.prepare_teamfeats \
    --base_fpath "$BUNDLE_DPATH"/imganns-*[0-9].kwcoco.zip \
    --expt_dvc_dpath="$DVC_EXPT_DPATH" \
    --with_mae=1 \
    --skip_existing=1 \
    --assets_dname=teamfeats \
    --gres=0,1 \
    --tmux_workers=8 \
    --backend=tmux \
    --run=1

class geowatch.tasks.mae.predict.MAEPredictConfig(*args, **kwargs)[source]

Bases: DataConfig

Configuration for WashU MAE models

Valid options: []

Parameters:
  • *args – positional arguments for this data config

  • **kwargs – keyword arguments for this data config

default = {'assets_dname': <Value('_assets')>, 'bands': <Value(['shared'])>, 'batch_size': <Value(1)>, 'device': <Value('cuda:0')>, 'input_kwcoco': <Value(None)>, 'io_workers': <Value(8)>, 'mae_ckpt_path': <Value(None)>, 'output_kwcoco': <Value(None)>, 'patch_overlap': <Value(0.25)>, 'sensor': <Value(['S2', 'L8'])>, 'window_resolution': <Value(1.0)>, 'workers': <Value(4)>}
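The defaults above can also be assembled as a plain keyword dictionary for a programmatic call. A minimal sketch, assuming the paths are placeholders and that `main(cmdline=0, **kwargs)` (per the signature at the bottom of this page) accepts the same keys as the CLI flags:

```python
# Hypothetical sketch: mirror the MAEPredictConfig defaults as plain kwargs.
# The checkpoint and kwcoco paths below are placeholders, and driving
# prediction via main(cmdline=0, ...) is an assumption, not a verified API.
predict_kwargs = {
    'mae_ckpt_path': '/path/to/Drop6-epoch=01-val_loss=0.20.ckpt',
    'input_kwcoco': '/path/to/imganns-KR_R001.kwcoco.zip',
    'output_kwcoco': '/path/to/imganns-KR_R001-testmae2.kwcoco.zip',
    'device': 'cuda:0',
    'batch_size': 1,
    'workers': 4,
    'io_workers': 8,
    'patch_overlap': 0.25,
    'window_resolution': 1.0,
    'sensor': ['S2', 'L8'],
    'bands': ['shared'],
    'assets_dname': '_assets',
}
# from geowatch.tasks.mae.predict import main
# main(cmdline=0, **predict_kwargs)
```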
class geowatch.tasks.mae.predict.WatchDataset(coco_dset, sensor=['S2'], bands=['shared'], segmentation=False, patch_size=224, mask_patch_size=16, num_images=2, mode='train', patch_overlap=0.25, bas=True, rng=None, mask_pct=0.5, mask_time_width=2, temporal_mode='cat', window_space_scale=1.0)[source]

Bases: Dataset

Example

>>> # xdoctest: +REQUIRES(env:DVC_DPATH)
>>> from geowatch.tasks.mae.predict import *  # NOQA
>>> import geowatch
>>> import kwcoco
>>> import ubelt as ub
>>> dvc_dpath = geowatch.find_dvc_dpath(tags='drop7_data', hardware='auto')
>>> coco_fpath = dvc_dpath / 'Drop7-Cropped2GSD/BR_R002/BR_R002.kwcoco.zip'
>>> self = WatchDataset(coco_fpath)
>>> for idx in ub.ProgIter(range(len(self))):
>>>     images, item = self[idx]
S2_l2a_channel_names = ['B02.tif', 'B01.tif', 'B03.tif', 'B04.tif', 'B05.tif', 'B06.tif', 'B07.tif', 'B08.tif', 'B09.tif', 'B11.tif', 'B12.tif', 'B8A.tif']
S2_channel_names = ['coastal', 'blue', 'green', 'red', 'B05', 'B06', 'B07', 'nir', 'B09', 'cirrus', 'swir16', 'swir22', 'B8A']
L8_channel_names = ['coastal', 'lwir11', 'lwir12', 'blue', 'green', 'red', 'nir', 'swir16', 'swir22', 'pan', 'cirrus']
update_target_properties(target)[source]

Populate the target so it has the correct input scale and bands.

geowatch.tasks.mae.predict.pair(t)[source]
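`pair` follows the common ViT convention of normalizing a scalar size into an `(h, w)` tuple; a sketch of that convention (the implementation here is assumed to match the usual one):

```python
def pair(t):
    """Return t unchanged if it is already a tuple, else duplicate it: 224 -> (224, 224)."""
    return t if isinstance(t, tuple) else (t, t)
```

This lets arguments like `image_size` and `image_patch_size` accept either a single integer for square inputs or an explicit `(height, width)` pair.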
class geowatch.tasks.mae.predict.PreNorm(dim, fn)[source]

Bases: Module

forward(x, **kwargs)[source]
class geowatch.tasks.mae.predict.FeedForward(dim, hidden_dim, dropout=0.0)[source]

Bases: Module

forward(x)[source]
class geowatch.tasks.mae.predict.Attention(dim, heads=8, dim_head=64, dropout=0.0)[source]

Bases: Module

forward(x)[source]
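The `Attention` module implements multi-head scaled dot-product self-attention. A NumPy sketch of the computation it performs, using random projection weights in place of the module's learned parameters (the exact internals of this class are an assumption):

```python
import numpy as np

def scaled_dot_product_attention(x, heads=8, dim_head=64, rng=None):
    """Single-batch multi-head self-attention sketch: project x to Q, K, V,
    apply softmax(Q K^T / sqrt(dim_head)) V per head, then merge heads."""
    rng = np.random.default_rng(rng)
    n, dim = x.shape
    inner = heads * dim_head
    # Random projections stand in for the module's learned weight matrices.
    w_qkv = rng.standard_normal((dim, 3 * inner)) / np.sqrt(dim)
    w_out = rng.standard_normal((inner, dim)) / np.sqrt(inner)
    q, k, v = np.split(x @ w_qkv, 3, axis=-1)

    def split_heads(t):  # (n, inner) -> (heads, n, dim_head)
        return t.reshape(n, heads, dim_head).transpose(1, 0, 2)

    q, k, v = map(split_heads, (q, k, v))
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(dim_head)
    # Softmax over the key axis, shifted for numerical stability.
    attn = np.exp(scores - scores.max(axis=-1, keepdims=True))
    attn /= attn.sum(axis=-1, keepdims=True)
    out = attn @ v                               # (heads, n, dim_head)
    out = out.transpose(1, 0, 2).reshape(n, inner)
    return out @ w_out                           # back to (n, dim)
```

The output has the same shape as the input, which is what lets `Transformer` stack these blocks with residual connections.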
class geowatch.tasks.mae.predict.Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=0.0)[source]

Bases: Module

forward(x)[source]
class geowatch.tasks.mae.predict.ViT(*, image_size, image_patch_size, frames, frame_patch_size, dim, depth, heads, mlp_dim, channels=6, dim_head=64, dropout=0.0, emb_dropout=0.0)[source]

Bases: Module

forward(video)[source]
class geowatch.tasks.mae.predict.MAE(*, encoder, decoder_dim, masking_ratio=0.75, decoder_depth=8, decoder_heads=8, decoder_dim_head=64)[source]

Bases: Module

forward(img)[source]
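With the default `masking_ratio=0.75`, an MAE-style model hides a random 75% of patch tokens from the encoder and asks the decoder to reconstruct them. A self-contained sketch of that index partition (the class's actual masking code is assumed to follow this standard scheme):

```python
import random

def split_masked_patches(num_patches, masking_ratio=0.75, seed=None):
    """Randomly partition patch indices into a masked set (reconstructed by
    the decoder) and a visible set (seen by the encoder), MAE-style."""
    rng = random.Random(seed)
    num_masked = int(masking_ratio * num_patches)
    indices = list(range(num_patches))
    rng.shuffle(indices)
    return sorted(indices[:num_masked]), sorted(indices[num_masked:])
```

For a 14x14 grid of 196 patches this leaves only 49 visible tokens for the encoder, which is the main source of MAE's pretraining efficiency.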
class geowatch.tasks.mae.predict.MaeCityscape(dataset, **kwargs)[source]

Bases: LightningModule

forward(x)[source]
shared_step(batch, batch_idx)[source]
geowatch.tasks.mae.predict.sigmoid(a)[source]
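The signature suggests an elementwise logistic function. A numerically stable scalar sketch (whether the module's version handles arrays or uses this exact branching is an assumption):

```python
import math

def sigmoid(a):
    """Numerically stable logistic function 1 / (1 + exp(-a))."""
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    # For large negative a, exp(-a) overflows; rewrite in terms of exp(a).
    ea = math.exp(a)
    return ea / (1.0 + ea)
```

The branch avoids overflow at extreme inputs, where the naive formula raises `OverflowError` for large negative `a`.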
class geowatch.tasks.mae.predict.Predict(args)[source]

Bases: object

geowatch.tasks.mae.predict.main(cmdline=1, **kwargs)[source]