geowatch.tasks.fusion.predict module¶
Fusion prediction script.
Given a kwcoco file and a packaged model, run prediction and output a new kwcoco file where predicted heatmaps are new raster bands.
This is the module that handles heatmap prediction over a kwcoco file.
There are SMART-specific parts, but it is mostly general. It makes heavy use of
CocoStitchingManager and KWCocoVideoDataModule. The critical loop is a simple
custom for-loop over a dataloader. We currently do not integrate with
LightningCLI here, but we may want to in the future (it is unclear).
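In practice the module is driven through its predict() entry point (documented below). A minimal sketch of programmatic use, with placeholder paths standing in for a real packaged model and kwcoco files:

    from geowatch.tasks.fusion.predict import predict
    # Placeholder paths; substitute a real packaged model and kwcoco datasets.
    result_dataset = predict(
        package_fpath='path/to/package.pt',
        test_dataset='path/to/test.kwcoco.json',
        pred_dataset='path/to/pred.kwcoco.json',
        devices=None,
    )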
Todo
[ ] Prediction caching?
[ ] Reduce memory usage?
[ ] Pseudo Live.
[ ] Investigate benefits of LightningCLI integration?
[ ] Option to keep annotations and only loop over relevant areas for drawing interesting validation / test batches.
[ ] Optimize for the case where we have an image-only dataset.
[ ] Integrate debug visualizations into the CLI.
- class geowatch.tasks.fusion.predict.DataModuleConfigMixin(*args, **kwargs)[source]¶
Bases: DataConfig
Valid options: []
- Parameters:
*args – positional arguments for this data config
**kwargs – keyword arguments for this data config
- default = {'absolute_weighting': <Value(False)>, 'augment_space_rot': <Value(True)>, 'augment_space_shift_rate': <Value(0.9)>, 'augment_space_xflip': <Value(True)>, 'augment_space_yflip': <Value(True)>, 'augment_time_resample_rate': <Value(0.8)>, 'balance_areas': <Value(False)>, 'balance_options': <Value(None)>, 'batch_size': <Value(1)>, 'channel_dropout': <Value(0.0)>, 'channel_dropout_rate': <Value(0.0)>, 'channels': <Value('auto')>, 'chip_dims': <Value('auto')>, 'chip_overlap': <Value(0.3)>, 'default_class_behavior': <Value('background')>, 'dist_weights': <Value(0)>, 'downweight_nan_regions': <Value(True)>, 'dynamic_fixed_resolution': <Value(None)>, 'exclude_sensors': <Value(None)>, 'failed_sample_policy': <Value('warn')>, 'fixed_resolution': <Value(None)>, 'force_bad_frames': <Value(False)>, 'ignore_dilate': <Value(0)>, 'include_sensors': <Value(None)>, 'input_space_scale': <Value('auto')>, 'key': 'set_cover_algo', 'mask_low_quality': <Value('auto')>, 'mask_nan_bands': <Value('')>, 'mask_samecolor_bands': <Value('red')>, 'mask_samecolor_method': <Value(None)>, 'mask_samecolor_values': <Value(0)>, 'max_epoch_length': <Value(None)>, 'min_spacetime_weight': <Value(0.9)>, 'modality_dropout': <Value(0.0)>, 'modality_dropout_rate': <Value(0.0)>, 'neg_to_pos_ratio': <Value(1.0)>, 'normalize_inputs': <Value(True)>, 'normalize_perframe': <Value(False)>, 'normalize_peritem': <Value('auto')>, 'num_balance_trees': <Value(16)>, 'num_workers': <Value(4)>, 'observable_threshold': <Value('auto')>, 'output_space_scale': <Value('auto')>, 'output_type': <Value('heterogeneous')>, 'pin_memory': <Value(True)>, 'prenormalize_inputs': <Value(None)>, 'quality_threshold': <Value('auto')>, 'reduce_item_size': <Value(False)>, 'request_rlimit_nofile': <Value('auto')>, 'resample_invalid_frames': <Value('auto')>, 'reseed_fit_random_generators': <Value(True)>, 'sampler_backend': <Value(None)>, 'sampler_workdir': <Value(None)>, 'sampler_workers': <Value('avail/2')>, 'select_images': <Value(None)>, 'select_videos': <Value(None)>, 'set_cover_algo': <Value('auto')>, 'sqlview': <Value(False)>, 'temporal_dropout': <Value(0.0)>, 'temporal_dropout_rate': <Value(1.0)>, 'test_dataset': <Value(None)>, 'test_with_annot_info': <Value(False)>, 'time_kernel': <Value('auto')>, 'time_sampling': <Value('auto')>, 'time_span': <Value('auto')>, 'time_steps': <Value('auto')>, 'torch_sharing_strategy': <Value('default')>, 'torch_start_method': <Value('default')>, 'train_dataset': <Value(None)>, 'upweight_centers': <Value(True)>, 'upweight_time': <Value(None)>, 'use_centered_positives': <Value(False)>, 'use_cloudmask': <Value('auto')>, 'use_grid_cache': <Value(True)>, 'use_grid_negatives': <Value(True)>, 'use_grid_positives': <Value(True)>, 'use_grid_valid_regions': <Value(True)>, 'vali_dataset': <Value(None)>, 'weight_dilate': <Value(0)>, 'window_space_scale': <Value('auto')>}¶
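Because this mixin follows the scriptconfig DataConfig conventions it inherits, the defaults above can be overridden with keyword arguments and inspected with dict-style access. A minimal sketch:

    from geowatch.tasks.fusion.predict import DataModuleConfigMixin
    # Override a couple of defaults; unspecified keys keep the values above.
    config = DataModuleConfigMixin(batch_size=2, num_workers=0)
    print(config['batch_size'])  # -> 2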
- class geowatch.tasks.fusion.predict.PredictConfig(*args, **kwargs)[source]¶
Bases: DataModuleConfigMixin
Prediction script for the fusion task
Example
python -m geowatch.tasks.fusion.predict \
    --write_probs=True \
    --with_class=auto \
    --with_saliency=auto \
    --with_change=False \
    --package_fpath=/localdisk0/SCRATCH/watch/ben/smart_watch_dvc/training/raven/brodie/uky_invariants/features_22_03_14/runs/BASELINE_EXPERIMENT_V001/package.pt \
    --pred_dataset=/localdisk0/SCRATCH/watch/ben/smart_watch_dvc/training/raven/brodie/uky_invariants/features_22_03_14/runs/BASELINE_EXPERIMENT_V001/pred.kwcoco.json \
    --test_dataset=/localdisk0/SCRATCH/watch/ben/smart_watch_dvc/Drop2-Aligned-TA1-2022-02-15/data_nowv_vali.kwcoco.json \
    --num_workers=5 \
    --devices=0, \
    --batch_size=1
Valid options: []
- Parameters:
*args – positional arguments for this data config
**kwargs – keyword arguments for this data config
- default = {'absolute_weighting': <Value(False)>, 'accelerator': <Value('auto')>, 'augment_space_rot': <Value(True)>, 'augment_space_shift_rate': <Value(0.9)>, 'augment_space_xflip': <Value(True)>, 'augment_space_yflip': <Value(True)>, 'augment_time_resample_rate': <Value(0.8)>, 'balance_areas': <Value(False)>, 'balance_options': <Value(None)>, 'batch_size': <Value(1)>, 'channel_dropout': <Value(0.0)>, 'channel_dropout_rate': <Value(0.0)>, 'channels': <Value('auto')>, 'chip_dims': <Value('auto')>, 'chip_overlap': <Value(0.3)>, 'clear_annots': <Value(1)>, 'compress': <Value('DEFLATE')>, 'datamodule': <Value('KWCocoVideoDataModule')>, 'default_class_behavior': <Value('background')>, 'devices': <Value(None)>, 'dist_weights': <Value(0)>, 'downweight_edges': <Value(True)>, 'downweight_nan_regions': <Value(True)>, 'draw_batches': <Value(False)>, 'drop_unused_frames': <Value(0)>, 'dynamic_fixed_resolution': <Value(None)>, 'exclude_sensors': <Value(None)>, 'failed_sample_policy': <Value('warn')>, 'fixed_resolution': <Value(None)>, 'force_bad_frames': <Value(False)>, 'format': <Value('cog')>, 'hidden_layers_chan_code': <Value('hidden_layers')>, 'ignore_dilate': <Value(0)>, 'include_sensors': <Value(None)>, 'input_space_scale': <Value('auto')>, 'key': 'set_cover_algo', 'mask_low_quality': <Value('auto')>, 'mask_nan_bands': <Value('')>, 'mask_samecolor_bands': <Value('red')>, 'mask_samecolor_method': <Value(None)>, 'mask_samecolor_values': <Value(0)>, 'max_epoch_length': <Value(None)>, 'memmap': <Value(None)>, 'min_spacetime_weight': <Value(0.9)>, 'modality_dropout': <Value(0.0)>, 'modality_dropout_rate': <Value(0.0)>, 'neg_to_pos_ratio': <Value(1.0)>, 'normalize_inputs': <Value(True)>, 'normalize_perframe': <Value(False)>, 'normalize_peritem': <Value('auto')>, 'num_balance_trees': <Value(16)>, 'num_workers': <Value(4)>, 'observable_threshold': <Value('auto')>, 'output_space_scale': <Value('auto')>, 'output_type': <Value('heterogeneous')>, 'override_meanstd': <Value(None)>, 'package_fpath': <Value(None)>, 'pin_memory': <Value(True)>, 'pred_dataset': <Value(None)>, 'prenormalize_inputs': <Value(None)>, 'quality_threshold': <Value('auto')>, 'quantize': <Value(True)>, 'record_context': <Value(True)>, 'reduce_item_size': <Value(False)>, 'request_rlimit_nofile': <Value('auto')>, 'resample_invalid_frames': <Value('auto')>, 'reseed_fit_random_generators': <Value(True)>, 'saliency_chan_code': <Value('salient')>, 'sampler_backend': <Value(None)>, 'sampler_workdir': <Value(None)>, 'sampler_workers': <Value('avail/2')>, 'select_images': <Value(None)>, 'select_videos': <Value(None)>, 'set_cover_algo': <Value('auto')>, 'sqlview': <Value(False)>, 'temporal_dropout': <Value(0.0)>, 'temporal_dropout_rate': <Value(1.0)>, 'test_dataset': <Value(None)>, 'test_with_annot_info': <Value(False)>, 'thresh': <Value(0.01)>, 'time_kernel': <Value('auto')>, 'time_sampling': <Value('auto')>, 'time_span': <Value('auto')>, 'time_steps': <Value('auto')>, 'torch_sharing_strategy': <Value('default')>, 'torch_start_method': <Value('default')>, 'track_emissions': <Value('offline')>, 'train_dataset': <Value(None)>, 'tta_fliprot': <Value(0)>, 'tta_time': <Value(0)>, 'upweight_centers': <Value(True)>, 'upweight_time': <Value(None)>, 'use_centered_positives': <Value(False)>, 'use_cloudmask': <Value('auto')>, 'use_grid_cache': <Value(True)>, 'use_grid_negatives': <Value(True)>, 'use_grid_positives': <Value(True)>, 'use_grid_valid_regions': <Value(True)>, 'vali_dataset': <Value(None)>, 'weight_dilate': <Value(0)>, 'window_space_scale': 
<Value('auto')>, 'with_change': <Value('auto')>, 'with_class': <Value('auto')>, 'with_hidden_layers': <Value(False)>, 'with_saliency': <Value('auto')>, 'write_preds': <Value(False)>, 'write_probs': <Value(True)>, 'write_workers': <Value('datamodule')>}¶
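A Python-side sketch mirroring the CLI example above (paths are placeholders, not real files):

    from geowatch.tasks.fusion.predict import PredictConfig
    config = PredictConfig(
        write_probs=True,
        with_class='auto',
        with_saliency='auto',
        with_change=False,
        package_fpath='path/to/package.pt',
        pred_dataset='path/to/pred.kwcoco.json',
        test_dataset='path/to/test.kwcoco.json',
        num_workers=5,
        batch_size=1,
    )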
- geowatch.tasks.fusion.predict.build_stitching_managers(config, model, result_dataset, writer_queue=None)[source]¶
For each type of requested raster output, we construct a stitching manager that maps batch predictions back to their correct locations in the larger image.
- Returns:
Dict[str, CocoStitchingManager]
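A hedged sketch of how the returned mapping might be consumed; the head names in the comment are assumptions based on the with_class / with_saliency / with_change options, not a guaranteed key set:

    # Not runnable standalone: assumes config, model, and result_dataset
    # were prepared as in predict().
    stitch_managers = build_stitching_managers(config, model, result_dataset)
    for head_key, manager in stitch_managers.items():
        # head_key is e.g. 'class', 'saliency', or 'change' (assumed names);
        # each manager stitches per-batch outputs into full-image rasters.
        print(head_key, type(manager).__name__)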
- geowatch.tasks.fusion.predict.resolve_datamodule(config, model, datamodule_defaults, fit_config)[source]¶
Creates an instance of the datamodule class.
Note that this also modifies the config. TODO: refactor / cleanup; break up the sections that handle getting the train-time params, resolving the datamodule args, and building the datamodule.
- Parameters:
config (dict) – nested train-time configuration provided by the model. This should have a “data” key for dataset params.
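A sketch of the intended call pattern, under the assumptions that fit_config is the train-time config recovered from the packaged model and that the function returns the datamodule instance directly (the return value is not documented here):

    # Not runnable standalone: config, model, datamodule_defaults, and
    # fit_config come from the packaged model and the prediction config.
    datamodule = resolve_datamodule(config, model, datamodule_defaults, fit_config)
    datamodule.setup('test')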
- class geowatch.tasks.fusion.predict.PeriodicMemoryMonitor[source]¶
Bases: object
Helper that prints memory statistics at periodic time intervals.
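The class API is not documented here, but the pattern it implements can be sketched as a standalone helper (illustrative names, not the actual interface):

    import time
    import psutil  # assumed available for memory stats

    class PeriodicMemoryMonitorSketch:
        def __init__(self, interval_seconds=60.0):
            self.interval_seconds = interval_seconds
            self._last_time = time.monotonic()

        def maybe_report(self):
            # Print memory stats only when the interval has elapsed.
            now = time.monotonic()
            if now - self._last_time >= self.interval_seconds:
                self._last_time = now
                mem = psutil.virtual_memory()
                print(f'memory used: {mem.percent}% of {mem.total} bytes')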
- geowatch.tasks.fusion.predict.predict(cmdline=False, **kwargs)[source]¶
Predict entry point and doctests
CommandLine
xdoctest -m geowatch.tasks.fusion.predict predict:0
Example
>>> # Train a demo model (in the future grab a pretrained demo model)
>>> from geowatch.tasks.fusion.predict import *  # NOQA
>>> import os
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> disable_lightning_hardware_warnings()
>>> args = None
>>> cmdline = False
>>> devices = None
>>> test_dpath = ub.Path.appdir('geowatch/tests/fusion/').ensuredir()
>>> results_path = (test_dpath / 'predict').ensuredir()
>>> results_path.delete()
>>> results_path.ensuredir()
>>> import kwcoco
>>> train_dset = kwcoco.CocoDataset.demo('special:vidshapes2-gsize64-frames9-speed0.5-multispectral')
>>> test_dset = kwcoco.CocoDataset.demo('special:vidshapes1-gsize64-frames9-speed0.5-multispectral')
>>> root_dpath = ub.Path(test_dpath, 'train').ensuredir()
>>> fit_config = kwargs = {
...     'subcommand': 'fit',
...     'fit.data.train_dataset': train_dset.fpath,
...     'fit.data.time_steps': 2,
...     'fit.data.time_span': "2m",
...     'fit.data.chip_dims': 64,
...     'fit.data.time_sampling': 'hardish3',
...     'fit.data.num_workers': 0,
...     #'package_fpath': package_fpath,
...     'fit.model.class_path': 'geowatch.tasks.fusion.methods.MultimodalTransformer',
...     'fit.model.init_args.global_change_weight': 1.0,
...     'fit.model.init_args.global_class_weight': 1.0,
...     'fit.model.init_args.global_saliency_weight': 1.0,
...     'fit.optimizer.class_path': 'torch.optim.SGD',
...     'fit.optimizer.init_args.lr': 1e-5,
...     'fit.trainer.max_steps': 10,
...     'fit.trainer.accelerator': 'cpu',
...     'fit.trainer.devices': 1,
...     'fit.trainer.max_epochs': 3,
...     'fit.trainer.log_every_n_steps': 1,
...     'fit.trainer.default_root_dir': os.fspath(root_dpath),
... }
>>> from geowatch.tasks.fusion import fit_lightning
>>> package_fpath = root_dpath / 'final_package.pt'
>>> fit_lightning.main(fit_config)
>>> # Unfortunately, it's not as easy to get the package path of
>>> # this call.
>>> assert ub.Path(package_fpath).exists()
>>> # Predict via that model
>>> predict_kwargs = kwargs = {
>>>     'package_fpath': package_fpath,
>>>     'pred_dataset': ub.Path(results_path) / 'pred.kwcoco.json',
>>>     'test_dataset': test_dset.fpath,
>>>     'datamodule': 'KWCocoVideoDataModule',
>>>     'batch_size': 1,
>>>     'num_workers': 0,
>>>     'devices': devices,
>>>     'draw_batches': 1,
>>>     'with_hidden_layers': True,
>>> }
>>> result_dataset = predict(**kwargs)
>>> dset = result_dataset
>>> dset.dataset['info'][-1]['properties']['config']['time_sampling']
>>> # Check that the result format looks correct
>>> for vidid in dset.index.videos.keys():
>>>     # Note: only some of the images in the pred sequence will get
>>>     # a change prediction, depending on the temporal sampling.
>>>     images = dset.images(dset.index.vidid_to_gids[1])
>>>     pred_chans = [[a['channels'] for a in aux] for aux in images.lookup('auxiliary')]
>>>     assert any('change' in cs for cs in pred_chans), 'some frames should have change'
>>>     assert not all('change' in cs for cs in pred_chans), 'some frames should not have change'
>>>     # Test number of annots in each frame
>>>     frame_to_cathist = {
>>>         img['frame_index']: ub.dict_hist(annots.cnames, labels=result_dataset.object_categories())
>>>         for img, annots in zip(images.objs, images.annots)
>>>     }
>>>     assert frame_to_cathist[0]['change'] == 0, 'first frame should have no change polygons'
>>>     # This test may fail with very low probability, so warn
>>>     import warnings
>>>     if sum(d['change'] for d in frame_to_cathist.values()) == 0:
>>>         warnings.warn('should have some change predictions elsewhere')
>>> coco_img = dset.images().coco_images[1]
>>> # Test that the new quantization does not break existing APIs
>>> pred1 = coco_img.imdelay('salient', nodata_method='float').finalize()
>>> assert pred1.max() <= 1
>>> # new delayed image does not make it easy to remove dequantization
>>> # add test back in if we add support for that.
>>> # pred2 = coco_img.imdelay('salient').finalize(nodata_method='float', dequantize=False)
>>> # assert pred2.max() > 1
Example
>>> # xdoctest: +REQUIRES(env:SLOW_DOCTEST)
>>> # FIXME: why does this test hang on the strict dashboard?
>>> # Train a demo model (in the future grab a pretrained demo model)
>>> from geowatch.tasks.fusion.predict import *  # NOQA
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> disable_lightning_hardware_warnings()
>>> args = None
>>> cmdline = False
>>> devices = None
>>> test_dpath = ub.Path.appdir('geowatch/tests/fusion/').ensuredir()
>>> results_path = ub.ensuredir((test_dpath, 'predict'))
>>> ub.delete(results_path)
>>> ub.ensuredir(results_path)
>>> import kwcoco
>>> train_dset = kwcoco.CocoDataset.demo('special:vidshapes4-multispectral-multisensor', num_frames=5, image_size=(64, 64))
>>> test_dset = kwcoco.CocoDataset.demo('special:vidshapes2-multispectral-multisensor', num_frames=5, image_size=(64, 64))
>>> datamodule = datamodules.kwcoco_video_data.KWCocoVideoDataModule(
>>>     train_dataset=train_dset,  #'special:vidshapes8-multispectral-multisensor',
>>>     test_dataset=test_dset,  #'special:vidshapes8-multispectral-multisensor',
>>>     chip_dims=32,
>>>     channels="r|g|b",
>>>     batch_size=1, time_steps=3, num_workers=2, normalize_inputs=10)
>>> datamodule.setup('fit')
>>> datamodule.setup('test')
>>> dataset_stats = datamodule.torch_datasets['train'].cached_dataset_stats(num=3)
>>> classes = datamodule.torch_datasets['train'].classes
>>> print("classes = ", classes)
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion.architectures.transformer import TransformerEncoderDecoder
>>> position_encoder = methods.heterogeneous.ScaleAgnostictPositionalEncoder(3)
>>> backbone = TransformerEncoderDecoder(
>>>     encoder_depth=1,
>>>     decoder_depth=1,
>>>     dim=position_encoder.output_dim + 16,
>>>     queries_dim=position_encoder.output_dim,
>>>     logits_dim=16,
>>>     cross_heads=1,
>>>     latent_heads=1,
>>>     cross_dim_head=1,
>>>     latent_dim_head=1,
>>> )
>>> model = methods.HeterogeneousModel(
>>>     classes=classes,
>>>     position_encoder=position_encoder,
>>>     backbone=backbone,
>>>     decoder="trans_conv",
>>>     token_width=16,
>>>     global_change_weight=1, global_class_weight=1, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=datamodule.input_sensorchan)
>>> print("model.heads.keys = ", model.heads.keys())
>>> # Save the model
>>> root_dpath = ub.Path(test_dpath, 'train').ensuredir()
>>> package_fpath = root_dpath / 'final_package.pt'
>>> model.save_package(package_fpath)
>>> assert ub.Path(package_fpath).exists()
>>> # Predict via that model
>>> test_dset = datamodule.train_dataset
>>> predict_kwargs = kwargs = {
>>>     'package_fpath': package_fpath,
>>>     'pred_dataset': ub.Path(results_path) / 'pred.kwcoco.json',
>>>     'test_dataset': test_dset.sampler.dset.fpath,
>>>     'datamodule': 'KWCocoVideoDataModule',
>>>     'channels': 'r|g|b',
>>>     'batch_size': 1,
>>>     'num_workers': 0,
>>>     'devices': devices,
>>> }
>>> result_dataset = predict(**kwargs)
>>> dset = result_dataset
>>> dset.dataset['info'][-1]['properties']['config']['time_sampling']
>>> # Check that the result format looks correct
>>> for vidid in dset.index.videos.keys():
>>>     # Note: only some of the images in the pred sequence will get
>>>     # a change prediction, depending on the temporal sampling.
>>>     images = dset.images(dset.index.vidid_to_gids[1])
>>>     pred_chans = [[a['channels'] for a in aux] for aux in images.lookup('auxiliary')]
>>>     print("pred_chans = ", pred_chans)
>>>     assert any('change' in cs for cs in pred_chans), 'some frames should have change'
>>>     assert not all('change' in cs for cs in pred_chans), 'some frames should not have change'
>>>     # Test number of annots in each frame
>>>     frame_to_cathist = {
>>>         img['frame_index']: ub.dict_hist(annots.cnames, labels=result_dataset.object_categories())
>>>         for img, annots in zip(images.objs, images.annots)
>>>     }
>>>     assert frame_to_cathist[0]['change'] == 0, 'first frame should have no change polygons'
>>>     # This test may fail with very low probability, so warn
>>>     import warnings
>>>     if sum(d['change'] for d in frame_to_cathist.values()) == 0:
>>>         warnings.warn('should have some change predictions elsewhere')
>>> coco_img = dset.images().coco_images[1]
>>> # Test that the new quantization does not break existing APIs
>>> pred1 = coco_img.imdelay('salient', nodata_method='float').finalize()
>>> assert pred1.max() <= 1
>>> # new delayed image does not make it easy to remove dequantization
>>> # add test back in if we add support for that.
>>> # pred2 = coco_img.imdelay('salient').finalize(nodata_method='float', dequantize=False)
>>> # assert pred2.max() > 1