geowatch.tasks.fusion.predict module

Fusion prediction script.

Given a kwcoco file and a packaged model, run prediction and output a new kwcoco file where predicted heatmaps are new raster bands.

This module handles heatmap prediction over a kwcoco file. It contains some SMART-specific parts but is mostly general. It makes heavy use of CocoStitchingManager and KWCocoVideoDataModule. The critical loop is a simple custom for-loop over a dataloader. We do not currently integrate with LightningCLI here, although we may want to in the future (it is unclear).
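
For reference, a minimal programmatic invocation might look like the following sketch. The paths are placeholders, and cmdline=False keeps the call from consuming sys.argv; see the doctests under predict() for a complete end-to-end run.

from geowatch.tasks.fusion.predict import predict
# Placeholder paths; substitute a real packaged model and kwcoco files.
result_dataset = predict(
    cmdline=False,
    package_fpath='path/to/package.pt',        # trained model package
    test_dataset='path/to/input.kwcoco.json',  # input kwcoco file
    pred_dataset='path/to/pred.kwcoco.json',   # output kwcoco with heatmap bands
    batch_size=1,
)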

Todo

  • [ ] Prediction caching?

  • [ ] Reduce memory usage?

  • [ ] Pseudo Live.

  • [ ] Investigate benefits of LightningCLI integration?

  • [ ] Option to keep annotations and only loop over relevant areas for drawing interesting validation / test batches.

  • [ ] Optimize for the case where we have an image-only dataset.

  • [ ] Integrate debug visualizations into the CLI.

class geowatch.tasks.fusion.predict.DataModuleConfigMixin(*args, **kwargs)[source]

Bases: DataConfig

Valid options: []

Parameters:
  • *args – positional arguments for this data config

  • **kwargs – keyword arguments for this data config

default = {'absolute_weighting': <Value(False)>, 'augment_space_rot': <Value(True)>, 'augment_space_shift_rate': <Value(0.9)>, 'augment_space_xflip': <Value(True)>, 'augment_space_yflip': <Value(True)>, 'augment_time_resample_rate': <Value(0.8)>, 'balance_areas': <Value(False)>, 'balance_options': <Value(None)>, 'batch_size': <Value(1)>, 'channel_dropout': <Value(0.0)>, 'channel_dropout_rate': <Value(0.0)>, 'channels': <Value('auto')>, 'chip_dims': <Value('auto')>, 'chip_overlap': <Value(0.3)>, 'default_class_behavior': <Value('background')>, 'dist_weights': <Value(0)>, 'downweight_nan_regions': <Value(True)>, 'dynamic_fixed_resolution': <Value(None)>, 'exclude_sensors': <Value(None)>, 'failed_sample_policy': <Value('warn')>, 'fixed_resolution': <Value(None)>, 'force_bad_frames': <Value(False)>, 'ignore_dilate': <Value(0)>, 'include_sensors': <Value(None)>, 'input_space_scale': <Value('auto')>, 'key': 'set_cover_algo', 'mask_low_quality': <Value('auto')>, 'mask_nan_bands': <Value('')>, 'mask_samecolor_bands': <Value('red')>, 'mask_samecolor_method': <Value(None)>, 'mask_samecolor_values': <Value(0)>, 'max_epoch_length': <Value(None)>, 'min_spacetime_weight': <Value(0.9)>, 'modality_dropout': <Value(0.0)>, 'modality_dropout_rate': <Value(0.0)>, 'neg_to_pos_ratio': <Value(1.0)>, 'normalize_inputs': <Value(True)>, 'normalize_perframe': <Value(False)>, 'normalize_peritem': <Value('auto')>, 'num_balance_trees': <Value(16)>, 'num_workers': <Value(4)>, 'observable_threshold': <Value('auto')>, 'output_space_scale': <Value('auto')>, 'output_type': <Value('heterogeneous')>, 'pin_memory': <Value(True)>, 'prenormalize_inputs': <Value(None)>, 'quality_threshold': <Value('auto')>, 'reduce_item_size': <Value(False)>, 'request_rlimit_nofile': <Value('auto')>, 'resample_invalid_frames': <Value('auto')>, 'reseed_fit_random_generators': <Value(True)>, 'sampler_backend': <Value(None)>, 'sampler_workdir': <Value(None)>, 'sampler_workers': <Value('avail/2')>, 'select_images': <Value(None)>, 'select_videos': <Value(None)>, 'set_cover_algo': <Value('auto')>, 'sqlview': <Value(False)>, 'temporal_dropout': <Value(0.0)>, 'temporal_dropout_rate': <Value(1.0)>, 'test_dataset': <Value(None)>, 'test_with_annot_info': <Value(False)>, 'time_kernel': <Value('auto')>, 'time_sampling': <Value('auto')>, 'time_span': <Value('auto')>, 'time_steps': <Value('auto')>, 'torch_sharing_strategy': <Value('default')>, 'torch_start_method': <Value('default')>, 'train_dataset': <Value(None)>, 'upweight_centers': <Value(True)>, 'upweight_time': <Value(None)>, 'use_centered_positives': <Value(False)>, 'use_cloudmask': <Value('auto')>, 'use_grid_cache': <Value(True)>, 'use_grid_negatives': <Value(True)>, 'use_grid_positives': <Value(True)>, 'use_grid_valid_regions': <Value(True)>, 'vali_dataset': <Value(None)>, 'weight_dilate': <Value(0)>, 'window_space_scale': <Value('auto')>}
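
Because this is a scriptconfig-style DataConfig, keyword arguments override the defaults listed above. A minimal sketch, assuming standard scriptconfig override semantics:

from geowatch.tasks.fusion.predict import DataModuleConfigMixin
# Keys that are not passed fall back to the defaults listed above.
config = DataModuleConfigMixin(
    test_dataset='path/to/test.kwcoco.json',  # placeholder path
    batch_size=2,
    num_workers=0,
)
print(config['batch_size'])    # 2 (overridden)
print(config['chip_overlap'])  # 0.3 (default)
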
class geowatch.tasks.fusion.predict.PredictConfig(*args, **kwargs)[source]

Bases: DataModuleConfigMixin

Prediction script for the fusion task

Example

python -m geowatch.tasks.fusion.predict \
    --write_probs=True \
    --with_class=auto \
    --with_saliency=auto \
    --with_change=False \
    --package_fpath=/localdisk0/SCRATCH/watch/ben/smart_watch_dvc/training/raven/brodie/uky_invariants/features_22_03_14/runs/BASELINE_EXPERIMENT_V001/package.pt \
    --pred_dataset=/localdisk0/SCRATCH/watch/ben/smart_watch_dvc/training/raven/brodie/uky_invariants/features_22_03_14/runs/BASELINE_EXPERIMENT_V001/pred.kwcoco.json \
    --test_dataset=/localdisk0/SCRATCH/watch/ben/smart_watch_dvc/Drop2-Aligned-TA1-2022-02-15/data_nowv_vali.kwcoco.json \
    --num_workers=5 \
    --devices=0, \
    --batch_size=1

Valid options: []

Parameters:
  • *args – positional arguments for this data config

  • **kwargs – keyword arguments for this data config

default = {'absolute_weighting': <Value(False)>, 'accelerator': <Value('auto')>, 'augment_space_rot': <Value(True)>, 'augment_space_shift_rate': <Value(0.9)>, 'augment_space_xflip': <Value(True)>, 'augment_space_yflip': <Value(True)>, 'augment_time_resample_rate': <Value(0.8)>, 'balance_areas': <Value(False)>, 'balance_options': <Value(None)>, 'batch_size': <Value(1)>, 'channel_dropout': <Value(0.0)>, 'channel_dropout_rate': <Value(0.0)>, 'channels': <Value('auto')>, 'chip_dims': <Value('auto')>, 'chip_overlap': <Value(0.3)>, 'clear_annots': <Value(1)>, 'compress': <Value('DEFLATE')>, 'datamodule': <Value('KWCocoVideoDataModule')>, 'default_class_behavior': <Value('background')>, 'devices': <Value(None)>, 'dist_weights': <Value(0)>, 'downweight_edges': <Value(True)>, 'downweight_nan_regions': <Value(True)>, 'draw_batches': <Value(False)>, 'drop_unused_frames': <Value(0)>, 'dynamic_fixed_resolution': <Value(None)>, 'exclude_sensors': <Value(None)>, 'failed_sample_policy': <Value('warn')>, 'fixed_resolution': <Value(None)>, 'force_bad_frames': <Value(False)>, 'format': <Value('cog')>, 'hidden_layers_chan_code': <Value('hidden_layers')>, 'ignore_dilate': <Value(0)>, 'include_sensors': <Value(None)>, 'input_space_scale': <Value('auto')>, 'key': 'set_cover_algo', 'mask_low_quality': <Value('auto')>, 'mask_nan_bands': <Value('')>, 'mask_samecolor_bands': <Value('red')>, 'mask_samecolor_method': <Value(None)>, 'mask_samecolor_values': <Value(0)>, 'max_epoch_length': <Value(None)>, 'memmap': <Value(None)>, 'min_spacetime_weight': <Value(0.9)>, 'modality_dropout': <Value(0.0)>, 'modality_dropout_rate': <Value(0.0)>, 'neg_to_pos_ratio': <Value(1.0)>, 'normalize_inputs': <Value(True)>, 'normalize_perframe': <Value(False)>, 'normalize_peritem': <Value('auto')>, 'num_balance_trees': <Value(16)>, 'num_workers': <Value(4)>, 'observable_threshold': <Value('auto')>, 'output_space_scale': <Value('auto')>, 'output_type': <Value('heterogeneous')>, 'override_meanstd': <Value(None)>, 'package_fpath': <Value(None)>, 'pin_memory': <Value(True)>, 'pred_dataset': <Value(None)>, 'prenormalize_inputs': <Value(None)>, 'quality_threshold': <Value('auto')>, 'quantize': <Value(True)>, 'record_context': <Value(True)>, 'reduce_item_size': <Value(False)>, 'request_rlimit_nofile': <Value('auto')>, 'resample_invalid_frames': <Value('auto')>, 'reseed_fit_random_generators': <Value(True)>, 'saliency_chan_code': <Value('salient')>, 'sampler_backend': <Value(None)>, 'sampler_workdir': <Value(None)>, 'sampler_workers': <Value('avail/2')>, 'select_images': <Value(None)>, 'select_videos': <Value(None)>, 'set_cover_algo': <Value('auto')>, 'sqlview': <Value(False)>, 'temporal_dropout': <Value(0.0)>, 'temporal_dropout_rate': <Value(1.0)>, 'test_dataset': <Value(None)>, 'test_with_annot_info': <Value(False)>, 'thresh': <Value(0.01)>, 'time_kernel': <Value('auto')>, 'time_sampling': <Value('auto')>, 'time_span': <Value('auto')>, 'time_steps': <Value('auto')>, 'torch_sharing_strategy': <Value('default')>, 'torch_start_method': <Value('default')>, 'track_emissions': <Value('offline')>, 'train_dataset': <Value(None)>, 'tta_fliprot': <Value(0)>, 'tta_time': <Value(0)>, 'upweight_centers': <Value(True)>, 'upweight_time': <Value(None)>, 'use_centered_positives': <Value(False)>, 'use_cloudmask': <Value('auto')>, 'use_grid_cache': <Value(True)>, 'use_grid_negatives': <Value(True)>, 'use_grid_positives': <Value(True)>, 'use_grid_valid_regions': <Value(True)>, 'vali_dataset': <Value(None)>, 'weight_dilate': <Value(0)>, 'window_space_scale': <Value('auto')>, 'with_change': <Value('auto')>, 'with_class': <Value('auto')>, 'with_hidden_layers': <Value(False)>, 'with_saliency': <Value('auto')>, 'write_preds': <Value(False)>, 'write_probs': <Value(True)>, 'write_workers': <Value('datamodule')>}
geowatch.tasks.fusion.predict.build_stitching_managers(config, model, result_dataset, writer_queue=None)[source]

For each type of requested raster output, we construct a stitching manager that will help map batches back into the correct location in a larger image.

Returns:

Dict[str, CocoStitchingManager]
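
Conceptually, stitching accumulates overlapping window predictions into a full-resolution canvas and normalizes by per-pixel weights. The following numpy sketch illustrates the idea only; the window iteration and uniform weighting are assumptions, not the exact behavior of CocoStitchingManager.

import numpy as np

def stitch_windows(windows, full_shape):
    """
    windows: iterable of (slice_y, slice_x, probs) chips covering full_shape.
    Returns the weighted average of overlapping chip predictions.
    """
    accum = np.zeros(full_shape, dtype=np.float32)
    weight = np.zeros(full_shape, dtype=np.float32)
    for sl_y, sl_x, probs in windows:
        accum[sl_y, sl_x] += probs  # sum overlapping contributions
        weight[sl_y, sl_x] += 1.0   # count chips touching each pixel
    return accum / np.maximum(weight, 1.0)  # avoid divide-by-zero outside the grid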

geowatch.tasks.fusion.predict.resolve_datamodule(config, model, datamodule_defaults, fit_config)[source]

Creates an instance of the datamodule class.

Note that this also modifies the config. TODO: refactor / cleanup; break up the sections that handle getting the train-time params, resolving the datamodule args, and building the datamodule.

Parameters:

config (dict) – nested train-time configuration provided by the model. This should have a “data” key for dataset params.
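
The precedence rule is roughly "datamodule defaults, overridden by train-time params, overridden by explicit predict-time arguments". A hedged sketch of that merge (the key handling is illustrative, not the actual implementation):

def merge_datamodule_args(datamodule_defaults, fit_config, predict_config):
    # Start from the defaults, then apply the params baked into the model.
    resolved = dict(datamodule_defaults)
    resolved.update(fit_config.get('data', {}))
    # Explicit predict-time arguments win; 'auto' means inherit from train time.
    for key, value in predict_config.items():
        if key in resolved and value != 'auto':
            resolved[key] = value
    return resolved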

class geowatch.tasks.fusion.predict.PeriodicMemoryMonitor[source]

Bases: object

Helper to print out memory stats at certain time intervals

check()[source]
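
A minimal sketch of such a monitor, assuming a psutil-based report and a fixed reporting interval (not the actual internals):

import time
import psutil

class SimpleMemoryMonitor:
    """Illustrative stand-in for PeriodicMemoryMonitor."""
    def __init__(self, interval_seconds=60):
        self.interval = interval_seconds
        self._last = time.monotonic()

    def check(self):
        # Only report when the interval has elapsed since the last report.
        now = time.monotonic()
        if now - self._last >= self.interval:
            self._last = now
            mem = psutil.virtual_memory()
            print(f'memory: {mem.percent}% used, {mem.available / 1e9:.2f} GB available')

Calling check() inside the prediction loop then prints at most one report per interval.
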
geowatch.tasks.fusion.predict.predict(cmdline=False, **kwargs)[source]

Predict entry point and doctests

CommandLine

xdoctest -m geowatch.tasks.fusion.predict predict:0

Example

>>> # Train a demo model (in the future grab a pretrained demo model)
>>> from geowatch.tasks.fusion.predict import *  # NOQA
>>> import os
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> disable_lightning_hardware_warnings()
>>> args = None
>>> cmdline = False
>>> devices = None
>>> test_dpath = ub.Path.appdir('geowatch/tests/fusion/').ensuredir()
>>> results_path = (test_dpath / 'predict').ensuredir()
>>> results_path.delete()
>>> results_path.ensuredir()
>>> import kwcoco
>>> train_dset = kwcoco.CocoDataset.demo('special:vidshapes2-gsize64-frames9-speed0.5-multispectral')
>>> test_dset = kwcoco.CocoDataset.demo('special:vidshapes1-gsize64-frames9-speed0.5-multispectral')
>>> root_dpath = ub.Path(test_dpath, 'train').ensuredir()
>>> fit_config = kwargs = {
...     'subcommand': 'fit',
...     'fit.data.train_dataset': train_dset.fpath,
...     'fit.data.time_steps': 2,
...     'fit.data.time_span': "2m",
...     'fit.data.chip_dims': 64,
...     'fit.data.time_sampling': 'hardish3',
...     'fit.data.num_workers': 0,
...     #'package_fpath': package_fpath,
...     'fit.model.class_path': 'geowatch.tasks.fusion.methods.MultimodalTransformer',
...     'fit.model.init_args.global_change_weight': 1.0,
...     'fit.model.init_args.global_class_weight': 1.0,
...     'fit.model.init_args.global_saliency_weight': 1.0,
...     'fit.optimizer.class_path': 'torch.optim.SGD',
...     'fit.optimizer.init_args.lr': 1e-5,
...     'fit.trainer.max_steps': 10,
...     'fit.trainer.accelerator': 'cpu',
...     'fit.trainer.devices': 1,
...     'fit.trainer.max_epochs': 3,
...     'fit.trainer.log_every_n_steps': 1,
...     'fit.trainer.default_root_dir': os.fspath(root_dpath),
... }
>>> from geowatch.tasks.fusion import fit_lightning
>>> package_fpath = root_dpath / 'final_package.pt'
>>> fit_lightning.main(fit_config)
>>> # Unfortunately, it's not as easy to get the package path of
>>> # this call.
>>> assert ub.Path(package_fpath).exists()
>>> # Predict via that model
>>> predict_kwargs = kwargs = {
>>>     'package_fpath': package_fpath,
>>>     'pred_dataset': ub.Path(results_path) / 'pred.kwcoco.json',
>>>     'test_dataset': test_dset.fpath,
>>>     'datamodule': 'KWCocoVideoDataModule',
>>>     'batch_size': 1,
>>>     'num_workers': 0,
>>>     'devices': devices,
>>>     'draw_batches': 1,
>>>     'with_hidden_layers': True,
>>> }
>>> result_dataset = predict(**kwargs)
>>> dset = result_dataset
>>> dset.dataset['info'][-1]['properties']['config']['time_sampling']
>>> # Check that the result format looks correct
>>> for vidid in dset.index.videos.keys():
>>>     # Note: only some of the images in the pred sequence will get
>>> #     a change prediction, depending on the temporal sampling.
>>>     images = dset.images(dset.index.vidid_to_gids[vidid])
>>>     pred_chans = [[a['channels'] for a in aux] for aux in images.lookup('auxiliary')]
>>>     assert any('change' in cs for cs in pred_chans), 'some frames should have change'
>>>     assert not all('change' in cs for cs in pred_chans), 'some frames should not have change'
>>>     # Test number of annots in each frame
>>>     frame_to_cathist = {
>>>         img['frame_index']: ub.dict_hist(annots.cnames, labels=result_dataset.object_categories())
>>>         for img, annots in zip(images.objs, images.annots)
>>>     }
>>>     assert frame_to_cathist[0]['change'] == 0, 'first frame should have no change polygons'
>>>     # This test may fail with very low probability, so warn
>>>     import warnings
>>>     if sum(d['change'] for d in frame_to_cathist.values()) == 0:
>>>         warnings.warn('should have some change predictions elsewhere')
>>> coco_img = dset.images().coco_images[1]
>>> # Test that the new quantization does not break existing APIs
>>> pred1 = coco_img.imdelay('salient', nodata_method='float').finalize()
>>> assert pred1.max() <= 1
>>> # new delayed image does not make it easy to remove dequantization
>>> # add test back in if we add support for that.
>>> # pred2 = coco_img.imdelay('salient').finalize(nodata_method='float', dequantize=False)
>>> # assert pred2.max() > 1

Example

>>> # xdoctest: +REQUIRES(env:SLOW_DOCTEST)
>>> # FIXME: why does this test hang on the strict dashboard?
>>> # Train a demo model (in the future grab a pretrained demo model)
>>> from geowatch.tasks.fusion.predict import *  # NOQA
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> disable_lightning_hardware_warnings()
>>> args = None
>>> cmdline = False
>>> devices = None
>>> test_dpath = ub.Path.appdir('geowatch/tests/fusion/').ensuredir()
>>> results_path = ub.ensuredir((test_dpath, 'predict'))
>>> ub.delete(results_path)
>>> ub.ensuredir(results_path)
>>> import kwcoco
>>> train_dset = kwcoco.CocoDataset.demo('special:vidshapes4-multispectral-multisensor', num_frames=5, image_size=(64, 64))
>>> test_dset = kwcoco.CocoDataset.demo('special:vidshapes2-multispectral-multisensor', num_frames=5, image_size=(64, 64))
>>> datamodule = datamodules.kwcoco_video_data.KWCocoVideoDataModule(
>>>     train_dataset=train_dset, #'special:vidshapes8-multispectral-multisensor',
>>>     test_dataset=test_dset, #'special:vidshapes8-multispectral-multisensor',
>>>     chip_dims=32,
>>>     channels="r|g|b",
>>>     batch_size=1, time_steps=3, num_workers=2, normalize_inputs=10)
>>> datamodule.setup('fit')
>>> datamodule.setup('test')
>>> dataset_stats = datamodule.torch_datasets['train'].cached_dataset_stats(num=3)
>>> classes = datamodule.torch_datasets['train'].classes
>>> print("classes = ", classes)
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = methods.HeterogeneousModel(
>>>     classes=classes,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="trans_conv",
>>>     token_width=16,
>>>     global_change_weight=1, global_class_weight=1, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=datamodule.input_sensorchan)
>>> print("model.heads.keys = ", model.heads.keys())
>>> # Save the model as a package
>>> root_dpath = ub.Path(test_dpath, 'train').ensuredir()
>>> package_fpath = root_dpath / 'final_package.pt'
>>> model.save_package(package_fpath)
>>> assert ub.Path(package_fpath).exists()
>>> # Predict via that model
>>> test_dset = datamodule.train_dataset
>>> predict_kwargs = kwargs = {
>>>     'package_fpath': package_fpath,
>>>     'pred_dataset': ub.Path(results_path) / 'pred.kwcoco.json',
>>>     'test_dataset': test_dset.sampler.dset.fpath,
>>>     'datamodule': 'KWCocoVideoDataModule',
>>>     'channels': 'r|g|b',
>>>     'batch_size': 1,
>>>     'num_workers': 0,
>>>     'devices': devices,
>>> }
>>> result_dataset = predict(**kwargs)
>>> dset = result_dataset
>>> dset.dataset['info'][-1]['properties']['config']['time_sampling']
>>> # Check that the result format looks correct
>>> for vidid in dset.index.videos.keys():
>>>     # Note: only some of the images in the pred sequence will get
>>> #     a change prediction, depending on the temporal sampling.
>>>     images = dset.images(dset.index.vidid_to_gids[vidid])
>>>     pred_chans = [[a['channels'] for a in aux] for aux in images.lookup('auxiliary')]
>>>     print("pred_chans = ", pred_chans)
>>>     assert any('change' in cs for cs in pred_chans), 'some frames should have change'
>>>     assert not all('change' in cs for cs in pred_chans), 'some frames should not have change'
>>>     # Test number of annots in each frame
>>>     frame_to_cathist = {
>>>         img['frame_index']: ub.dict_hist(annots.cnames, labels=result_dataset.object_categories())
>>>         for img, annots in zip(images.objs, images.annots)
>>>     }
>>>     assert frame_to_cathist[0]['change'] == 0, 'first frame should have no change polygons'
>>>     # This test may fail with very low probability, so warn
>>>     import warnings
>>>     if sum(d['change'] for d in frame_to_cathist.values()) == 0:
>>>         warnings.warn('should have some change predictions elsewhere')
>>> coco_img = dset.images().coco_images[1]
>>> # Test that the new quantization does not break existing APIs
>>> pred1 = coco_img.imdelay('salient', nodata_method='float').finalize()
>>> assert pred1.max() <= 1
>>> # new delayed image does not make it easy to remove dequantization
>>> # add test back in if we add support for that.
>>> # pred2 = coco_img.imdelay('salient').finalize(nodata_method='float', dequantize=False)
>>> # assert pred2.max() > 1
class geowatch.tasks.fusion.predict.Predictor(config)[source]

Bases: object

Abstracts different stages of the prediction process

New in 0.17.1; needs to be refactored together with the rest of the code in this file.
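
As a rough mental model, the stages mirror the flow of predict(): load the model package, resolve the datamodule, build stitching managers, iterate over batches, and finalize the output dataset. An illustrative skeleton follows; the method names are hypothetical and not this class's actual API.

class StagedPredictorSketch:
    # Hypothetical skeleton; see predict() for the real control flow.
    def __init__(self, config):
        self.config = config

    def setup(self):
        ...  # load the packaged model and resolve the datamodule

    def run(self):
        ...  # iterate the dataloader, forward each batch, stitch outputs

    def finalize(self):
        ...  # flush the stitchers and write the predicted kwcoco dataset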

geowatch.tasks.fusion.predict.main(cmdline=True, **kwargs)[source]