geowatch.tasks.fusion.datamodules.kwcoco_datamodule module¶
Defines a lightning DataModule for kwcoco video data.
The parameters to each class are handled by scriptconfig objects, which prevents us from needing to specify the available options in multiple places.
- class geowatch.tasks.fusion.datamodules.kwcoco_datamodule.KWCocoVideoDataModuleConfig(*args, **kwargs)[source]¶
Bases:
KWCocoVideoDatasetConfig
These are the arguments accepted by the KWCocoDataModule.
The scriptconfig class is not used directly as it normally would be here. Instead we use it as a convenience to minimize lightning boilerplate later when it constructs its own argparse object, and for handling arguments passed directly to the KWCocoDataModule.
In the future this might be convertible to, or handled by, omegaconf; a minimal usage sketch follows the default values below.
- Parameters:
*args – positional arguments for this data config
**kwargs – keyword arguments for this data config
- default = {'absolute_weighting': <Value(False)>, 'augment_space_rot': <Value(True)>, 'augment_space_shift_rate': <Value(0.9)>, 'augment_space_xflip': <Value(True)>, 'augment_space_yflip': <Value(True)>, 'augment_time_resample_rate': <Value(0.8)>, 'balance_areas': <Value(False)>, 'balance_options': <Value(None)>, 'batch_size': <Value(4)>, 'channel_dropout': <Value(0.0)>, 'channel_dropout_rate': <Value(0.0)>, 'channels': <Value(None)>, 'chip_dims': <Value(128)>, 'chip_overlap': <Value(0.0)>, 'default_class_behavior': <Value('background')>, 'dist_weights': <Value(0)>, 'downweight_nan_regions': <Value(True)>, 'dynamic_fixed_resolution': <Value(None)>, 'exclude_sensors': <Value(None)>, 'failed_sample_policy': <Value('warn')>, 'fixed_resolution': <Value(None)>, 'force_bad_frames': <Value(False)>, 'ignore_dilate': <Value(0)>, 'include_sensors': <Value(None)>, 'input_space_scale': <Value(None)>, 'mask_low_quality': <Value(False)>, 'mask_nan_bands': <Value('')>, 'mask_samecolor_bands': <Value('red')>, 'mask_samecolor_method': <Value(None)>, 'mask_samecolor_values': <Value(0)>, 'max_epoch_length': <Value(None)>, 'min_spacetime_weight': <Value(0.9)>, 'modality_dropout': <Value(0.0)>, 'modality_dropout_rate': <Value(0.0)>, 'neg_to_pos_ratio': <Value(1.0)>, 'normalize_inputs': <Value(True)>, 'normalize_perframe': <Value(False)>, 'normalize_peritem': <Value(None)>, 'num_balance_trees': <Value(16)>, 'num_workers': <Value(4)>, 'observable_threshold': <Value(0.0)>, 'output_space_scale': <Value(None)>, 'output_type': <Value('heterogeneous')>, 'pin_memory': <Value(True)>, 'prenormalize_inputs': <Value(None)>, 'quality_threshold': <Value(0.0)>, 'reduce_item_size': <Value(False)>, 'request_rlimit_nofile': <Value('auto')>, 'resample_invalid_frames': <Value(3)>, 'reseed_fit_random_generators': <Value(True)>, 'sampler_backend': <Value(None)>, 'sampler_workdir': <Value(None)>, 'sampler_workers': <Value('avail/2')>, 'select_images': <Value(None)>, 'select_videos': <Value(None)>, 'set_cover_algo': <Value(None)>, 'sqlview': <Value(False)>, 'temporal_dropout': <Value(0.0)>, 'temporal_dropout_rate': <Value(1.0)>, 'test_dataset': <Value(None)>, 'test_with_annot_info': <Value(False)>, 'time_kernel': <Value(None)>, 'time_sampling': <Value('contiguous')>, 'time_span': <Value(None)>, 'time_steps': <Value(2)>, 'torch_sharing_strategy': <Value('default')>, 'torch_start_method': <Value('default')>, 'train_dataset': <Value(None)>, 'upweight_centers': <Value(True)>, 'upweight_time': <Value(None)>, 'use_centered_positives': <Value(False)>, 'use_cloudmask': <Value(None)>, 'use_grid_cache': <Value(True)>, 'use_grid_negatives': <Value(True)>, 'use_grid_positives': <Value(True)>, 'use_grid_valid_regions': <Value(True)>, 'vali_dataset': <Value(None)>, 'weight_dilate': <Value(0)>, 'window_space_scale': <Value(None)>}¶
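The following is a minimal sketch of the intended usage, assuming standard scriptconfig semantics (keyword construction and dict-style access); the keys and defaults shown are taken from the table above, and this is an illustrative assumption rather than a doctest from the module itself.
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import (
>>>     KWCocoVideoDataModuleConfig)
>>> # Construct the config from keyword arguments; unspecified keys keep
>>> # the defaults listed above (assumed scriptconfig behavior).
>>> config = KWCocoVideoDataModuleConfig(batch_size=2, time_steps=3)
>>> assert config['batch_size'] == 2
>>> assert config['num_workers'] == 4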
- class geowatch.tasks.fusion.datamodules.kwcoco_datamodule.KWCocoVideoDataModule(verbose=1, **kwargs)[source]¶
Bases:
LightningDataModule
Prepare the kwcoco dataset as a torch video datamodule.
Example
>>> # Demo of the data module on auto-generated toy data
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> import geowatch
>>> import kwcoco
>>> coco_dset = geowatch.coerce_kwcoco('vidshapes8-geowatch')
>>> channels = None
>>> batch_size = 1
>>> time_steps = 3
>>> chip_size = 416
>>> self = KWCocoVideoDataModule(
>>>     train_dataset=coco_dset,
>>>     test_dataset=None,
>>>     batch_size=batch_size,
>>>     normalize_inputs=8,
>>>     channels=channels,
>>>     num_workers=0,
>>>     time_steps=time_steps,
>>>     chip_size=chip_size,
>>>     neg_to_pos_ratio=0,
>>> )
>>> self.setup('fit')
>>> dl = self.train_dataloader()
>>> dataset = dl.dataset
>>> batch = next(iter(dl))
>>> batch = [dl.dataset[0]]
>>> # Visualize
>>> canvas = self.draw_batch(batch)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas)
>>> kwplot.show_if_requested()
Example
>>> # xdoctest: +REQUIRES(env:DVC_DPATH)
>>> # Run the following tests on real geowatch data if DVC is available
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> import geowatch
>>> import kwcoco
>>> dvc_dpath = geowatch.find_dvc_dpath()
>>> coco_fpath = dvc_dpath / 'Drop2-Aligned-TA1-2022-02-15/combo_ILM.kwcoco.json'
>>> #coco_fpath = dvc_dpath / 'Aligned-Drop2-TA1-2022-03-07/combo_DILM.kwcoco.json'
>>> #coco_fpath = dvc_dpath / 'Drop2-Aligned-TA1-2022-02-15/combo_DILM.kwcoco.json'
>>> dset = kwcoco.CocoDataset(coco_fpath)
>>> images = dset.images()
>>> train_dataset = dset
>>> #sub_images = dset.videos(names=['KR_R002']).images[0]
>>> #train_dataset = dset.subset(sub_images.lookup('id'))
>>> test_dataset = None
>>> img = ub.peek(train_dataset.imgs.values())
>>> chan_info = kwcoco_extensions.coco_channel_stats(dset)
>>> #channels = chan_info['common_channels']
>>> channels = 'blue|green|red|nir|swir16|swir22,forest|bare_ground,matseg_0|matseg_1|matseg_2,invariants.0:3,cloudmask'
>>> #channels = 'blue|green|red|depth'
>>> #channels = None
>>> #
>>> batch_size = 1
>>> time_steps = 8
>>> chip_size = 512
>>> datamodule = KWCocoVideoDataModule(
>>>     train_dataset=train_dataset,
>>>     test_dataset=test_dataset,
>>>     batch_size=batch_size,
>>>     channels=channels,
>>>     num_workers=0,
>>>     normalize_inputs=8,
>>>     time_steps=time_steps,
>>>     chip_size=chip_size,
>>>     neg_to_pos_ratio=0,
>>>     min_spacetime_weight=0.5,
>>> )
>>> datamodule.setup('fit')
>>> dl = datamodule.train_dataloader()
>>> dataset = dl.dataset
>>> dataset.requested_tasks['change'] = False
>>> dataset.disable_augmenter = True
>>> target = 0
>>> item, *_ = batch = [dataset[target]]
>>> #item, *_ = batch = next(iter(dl))
>>> # Visualize
>>> canvas = datamodule.draw_batch(batch)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas, doclf=1)
>>> kwplot.show_if_requested()
Example
>>> # xdoctest: +SKIP
>>> # NOTE: I DONT KNOW WHY THIS IS FAILING ON CI AT THE MOMENT. FIXME!
>>> # Run the data module on coco demo datamodules for the CI
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> import kwcoco
>>> import delayed_image
>>> train_dataset = kwcoco.CocoDataset.demo('vidshapes2-multispectral', num_frames=5)
>>> test_dataset = kwcoco.CocoDataset.demo('vidshapes1-multispectral', num_frames=5)
>>> channels = '|'.join([aux['channels'] for aux in train_dataset.imgs[1]['auxiliary']])
>>> chan_spec = delayed_image.channel_spec.FusedChannelSpec.coerce(channels)
>>> #
>>> batch_size = 2
>>> time_steps = 3
>>> chip_size = 128
>>> channels = channels
>>> self = KWCocoVideoDataModule(
>>>     train_dataset=train_dataset,
>>>     test_dataset=test_dataset,
>>>     batch_size=batch_size,
>>>     channels=channels,
>>>     num_workers=0,
>>>     time_steps=time_steps,
>>>     chip_size=chip_size,
>>>     normalize_inputs=True,
>>> )
>>> self.setup('fit')
>>> dl = self.train_dataloader()
>>> item, *_ = batch = next(iter(dl))
>>> expect_shape = (batch_size, time_steps, len(chan_spec), chip_size, chip_size)
>>> assert len(batch) == batch_size
>>> for item in batch:
...     assert len(item['frames']) == time_steps
...     for mode_key, mode_val in item['frames'][0]['modes'].items():
...         assert mode_val.shape[1:3] == (chip_size, chip_size)
For details on accepted arguments, see KWCocoVideoDataModuleConfig.
- property train_dataset¶
- property test_dataset¶
- property vali_dataset¶
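The properties above return the underlying torch datasets and are only populated after setup. Below is a hedged sketch of the access pattern, inferred from the examples in this module; the None return for unspecified splits is an assumption, not documented behavior.
>>> # Hypothetical usage sketch, not a doctest from the module itself.
>>> self = KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes8-multispectral', num_workers=0)
>>> self.setup('fit')
>>> dataset = self.train_dataset  # the dataset the train dataloader wraps
>>> assert self.test_dataset is None  # no test split was provided (assumed)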
- classmethod add_argparse_args(parent_parser)[source]¶
Previously the arguments were in multiple places including here. This has been updated to use the
KWCocoVideoDataModuleConfig
as the single point where arguments are defined. The functionality of this method is roughly the same as it used to be, given that scriptconfig objects can be transformed into argparse objects.
CommandLine
xdoctest -m geowatch.tasks.fusion.datamodules.kwcoco_datamodule KWCocoVideoDataModule.add_argparse_args
Example
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> cls = KWCocoVideoDataModule
>>> # TODO: make use of geowatch.utils.lightning_ext import argparse_ext
>>> import argparse
>>> parent_parser = argparse.ArgumentParser()
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
>>> args, _ = parent_parser.parse_known_args(['--use_grid_positives=True'])
>>> assert args.use_grid_positives
>>> args, _ = parent_parser.parse_known_args(['--use_grid_positives=False'])
>>> assert not args.use_grid_positives
>>> args, _ = parent_parser.parse_known_args(['--exclude_sensors=l8,f3'])
>>> assert args.exclude_sensors == 'l8,f3'
>>> args, _ = parent_parser.parse_known_args(['--exclude_sensors=l8'])
>>> assert args.exclude_sensors == 'l8'
- classmethod compatible(cfgdict)[source]¶
Given keyword arguments, find the subset that is compatible with this constructor. This is somewhat hacky because of the usage of scriptconfig, but could be made nicer by future updates.
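A hedged sketch of the intended use, assuming the return value is the filtered keyword dictionary that the description above implies; the 'learning_rate' key is a hypothetical extra used only for illustration.
>>> # Keep only the kwargs this constructor understands, then construct.
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> cfgdict = {'batch_size': 2, 'num_workers': 0, 'learning_rate': 1e-3}
>>> kwargs = KWCocoVideoDataModule.compatible(cfgdict)
>>> assert 'learning_rate' not in kwargs  # incompatible key dropped (assumed)
>>> self = KWCocoVideoDataModule(**kwargs)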
- draw_batch(batch, stage='train', outputs=None, max_items=2, overlay_on_image=False, classes=None, **kwargs)[source]¶
Visualize a batch produced by a KWCocoVideoDataset.
- Parameters:
batch (Dict[str, List[Tensor]]) – dictionary of uncollated lists of Dataset Items:
    change: [ [T-1, H, W] in [0, 1] forall examples ]
    saliency: [ [T, H, W, 2] in [0, 1] forall examples ]
    class: [ [T, H, W, 10] in [0, 1] forall examples ]
outputs (Dict[str, Tensor]) – possibly-collated dictionary of network outputs.
max_items (int) – Maximum number of items within this batch to draw in a single figure. Defaults to 2.
overlay_on_image (bool) – If True, overlay annotations on the image data for a more compact view. If False, draw annotations and images separately for a less cluttered view.
CommandLine
xdoctest -m geowatch.tasks.fusion.datamodules.kwcoco_datamodule KWCocoVideoDataModule.draw_batch
Example
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> from geowatch.tasks.fusion import datamodules
>>> self = datamodules.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes8-multispectral', channels='auto', num_workers=0)
>>> self.setup('fit')
>>> loader = self.train_dataloader()
>>> batch = next(iter(loader))
>>> item = batch[0]
>>> # Visualize
>>> B = len(batch)
>>> C, H, W = ub.peek(item['frames'][0]['modes'].values()).shape
>>> T = len(item['frames'])
>>> import torch
>>> outputs = {'change_probs': [torch.rand(T - 1, H, W) for _ in range(B)]}
>>> outputs.update({'class_probs': [torch.rand(T, H, W, 10) for _ in range(B)]})
>>> stage = 'train'
>>> canvas = self.draw_batch(batch, stage=stage, outputs=outputs)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas)
>>> kwplot.show_if_requested()
Example
>>> # xdoctest: +REQUIRES(--slow)
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> from geowatch.tasks.fusion import datamodules
>>> import geowatch
>>> train_dataset = geowatch.demo.demo_kwcoco_multisensor()
>>> self = datamodules.KWCocoVideoDataModule(
>>>     train_dataset=train_dataset, chip_size=256, time_steps=5, num_workers=0, batch_size=3)
>>> self.setup('fit')
>>> loader = self.train_dataloader()
>>> batch_iter = iter(loader)
>>> batch = next(batch_iter)
>>> batch[1] = None  # simulate a dropped batch item
>>> batch[0] = None  # simulate a dropped batch item
>>> #item = batch[0]
>>> # Visualize
>>> B = len(batch)
>>> outputs = {'change_probs': [], 'class_probs': [], 'saliency_probs': []}
>>> # Add dummy outputs
>>> import torch
>>> for item in batch:
>>>     if item is None:
>>>         [v.append([None]) for v in outputs.values()]
>>>     else:
>>>         [v.append([]) for v in outputs.values()]
>>>         for frame_idx, frame in enumerate(item['frames']):
>>>             H, W = frame['class_idxs'].shape
>>>             if frame_idx > 0:
>>>                 outputs['change_probs'][-1].append(torch.rand(H, W))
>>>             outputs['class_probs'][-1].append(torch.rand(H, W, 10))
>>>             outputs['saliency_probs'][-1].append(torch.rand(H, W, 2))
>>> from geowatch.utils import util_nesting
>>> print(ub.urepr(util_nesting.shape_summary(outputs), nl=1, sort=0))
>>> stage = 'train'
>>> canvas = self.draw_batch(batch, stage=stage, outputs=outputs, max_items=4)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas)
>>> kwplot.show_if_requested()
Example
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> from geowatch.tasks.fusion import datamodules
>>> self = datamodules.KWCocoVideoDataModule(
>>>     batch_size=3,
>>>     train_dataset='special:vidshapes8-multispectral', channels='auto', num_workers=0)
>>> self.setup('fit')
>>> loader = self.train_dataloader()
>>> batch = next(iter(loader))
>>> batch[1] = None
>>> item = batch[0]
>>> # Visualize
>>> B = len(batch)
>>> C, H, W = ub.peek(item['frames'][0]['modes'].values()).shape
>>> T = len(item['frames'])
>>> import torch
>>> outputs = {'change_probs': [torch.rand(T - 1, H, W) for _ in range(B)]}
>>> outputs.update({'class_probs': [torch.rand(T, H, W, 10) for _ in range(B)]})
>>> stage = 'train'
>>> canvas = self.draw_batch(batch, stage=stage, outputs=outputs)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas)
>>> kwplot.show_if_requested()