geowatch.tasks.fusion.datamodules package

Module contents

python -m geowatch.tasks.fusion

mkinit ~/code/watch/geowatch/tasks/fusion/datamodules/__init__.py --nomods -w

class geowatch.tasks.fusion.datamodules.KWCocoVideoDataModule(verbose=1, **kwargs)[source]

Bases: LightningDataModule

Prepare a kwcoco dataset as a torch video datamodule.

Example

>>> # Demo of the data module on auto-generated toy data
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> import geowatch
>>> import kwcoco
>>> coco_dset = geowatch.coerce_kwcoco('vidshapes8-geowatch')
>>> channels = None
>>> batch_size = 1
>>> time_steps = 3
>>> chip_size = 416
>>> self = KWCocoVideoDataModule(
>>>     train_dataset=coco_dset,
>>>     test_dataset=None,
>>>     batch_size=batch_size,
>>>     normalize_inputs=8,
>>>     channels=channels,
>>>     num_workers=0,
>>>     time_steps=time_steps,
>>>     chip_size=chip_size,
>>>     neg_to_pos_ratio=0,
>>> )
>>> self.setup('fit')
>>> dl = self.train_dataloader()
>>> dataset = dl.dataset
>>> batch = next(iter(dl))   # sample a batch via the dataloader
>>> batch = [dl.dataset[0]]  # or build a batch from dataset items directly
>>> # Visualize
>>> canvas = self.draw_batch(batch)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas)
>>> kwplot.show_if_requested()

Example

>>> # xdoctest: +REQUIRES(env:DVC_DPATH)
>>> # Run the following tests on real geowatch data if DVC is available
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> import geowatch
>>> import kwcoco
>>> dvc_dpath = geowatch.find_dvc_dpath()
>>> coco_fpath = dvc_dpath / 'Drop2-Aligned-TA1-2022-02-15/combo_ILM.kwcoco.json'
>>> #coco_fpath = dvc_dpath / 'Aligned-Drop2-TA1-2022-03-07/combo_DILM.kwcoco.json'
>>> #coco_fpath = dvc_dpath / 'Drop2-Aligned-TA1-2022-02-15/combo_DILM.kwcoco.json'
>>> dset = kwcoco.CocoDataset(coco_fpath)
>>> images = dset.images()
>>> train_dataset = dset
>>> #sub_images = dset.videos(names=['KR_R002']).images[0]
>>> #train_dataset = dset.subset(sub_images.lookup('id'))
>>> test_dataset = None
>>> import ubelt as ub
>>> from geowatch.utils import kwcoco_extensions
>>> img = ub.peek(train_dataset.imgs.values())
>>> chan_info = kwcoco_extensions.coco_channel_stats(dset)
>>> #channels = chan_info['common_channels']
>>> channels = 'blue|green|red|nir|swir16|swir22,forest|bare_ground,matseg_0|matseg_1|matseg_2,invariants.0:3,cloudmask'
>>> #channels = 'blue|green|red|depth'
>>> #channels = None
>>> #
>>> batch_size = 1
>>> time_steps = 8
>>> chip_size = 512
>>> datamodule = KWCocoVideoDataModule(
>>>     train_dataset=train_dataset,
>>>     test_dataset=test_dataset,
>>>     batch_size=batch_size,
>>>     channels=channels,
>>>     num_workers=0,
>>>     normalize_inputs=8,
>>>     time_steps=time_steps,
>>>     chip_size=chip_size,
>>>     neg_to_pos_ratio=0,
>>>     min_spacetime_weight=0.5,
>>> )
>>> datamodule.setup('fit')
>>> dl = datamodule.train_dataloader()
>>> dataset = dl.dataset
>>> dataset.requested_tasks['change'] = False
>>> dataset.disable_augmenter = True
>>> target = 0
>>> item, *_ = batch = [dataset[target]]
>>> #item, *_ = batch = next(iter(dl))
>>> # Visualize
>>> canvas = datamodule.draw_batch(batch)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas, doclf=1)
>>> kwplot.show_if_requested()

Example

>>> # xdoctest: +SKIP
>>> # NOTE: This is failing on CI for unknown reasons. FIXME!
>>> # Run the data module on coco demo datamodules for the CI
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> import kwcoco
>>> import delayed_image
>>> train_dataset = kwcoco.CocoDataset.demo('vidshapes2-multispectral', num_frames=5)
>>> test_dataset = kwcoco.CocoDataset.demo('vidshapes1-multispectral', num_frames=5)
>>> channels = '|'.join([aux['channels'] for aux in train_dataset.imgs[1]['auxiliary']])
>>> chan_spec = delayed_image.channel_spec.FusedChannelSpec.coerce(channels)
>>> #
>>> batch_size = 2
>>> time_steps = 3
>>> chip_size = 128
>>> self = KWCocoVideoDataModule(
>>>     train_dataset=train_dataset,
>>>     test_dataset=test_dataset,
>>>     batch_size=batch_size,
>>>     channels=channels,
>>>     num_workers=0,
>>>     time_steps=time_steps,
>>>     chip_size=chip_size,
>>>     normalize_inputs=True,
>>> )
>>> self.setup('fit')
>>> dl = self.train_dataloader()
>>> item, *_ = batch = next(iter(dl))
>>> # For reference, the fully collated batch would have this shape;
>>> # the asserts below check the per-item shapes instead.
>>> expect_shape = (batch_size, time_steps, len(chan_spec), chip_size, chip_size)
>>> assert len(batch) == batch_size
>>> for item in batch:
...     assert len(item['frames']) == time_steps
...     for mode_key, mode_val in item['frames'][0]['modes'].items():
...         assert mode_val.shape[1:3] == (chip_size, chip_size)

For details on accepted arguments see KWCocoVideoDataModuleConfig
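
A minimal sketch of inspecting those arguments programmatically (this assumes KWCocoVideoDataModuleConfig is a scriptconfig object importable from the kwcoco_datamodule module):

>>> # Hedged sketch: list the names of all accepted arguments.
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import KWCocoVideoDataModuleConfig
>>> config = KWCocoVideoDataModuleConfig()
>>> print(sorted(config.keys()))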

classmethod add_argparse_args(parent_parser)[source]

Previously the arguments were defined in multiple places, including here. This has been updated so that KWCocoVideoDataModuleConfig is the single point where arguments are defined. The functionality of this method is roughly unchanged, since scriptconfig objects can be transformed into argparse objects.

CommandLine

xdoctest -m geowatch.tasks.fusion.datamodules.kwcoco_datamodule KWCocoVideoDataModule.add_argparse_args

Example

>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> cls = KWCocoVideoDataModule
>>> # TODO: make use of geowatch.utils.lightning_ext import argparse_ext
>>> import argparse
>>> parent_parser = argparse.ArgumentParser()
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
>>> args, _ = parent_parser.parse_known_args(['--use_grid_positives=True'])
>>> assert args.use_grid_positives
>>> args, _ = parent_parser.parse_known_args(['--use_grid_positives=False'])
>>> assert not args.use_grid_positives
>>> args, _ = parent_parser.parse_known_args(['--exclude_sensors=l8,f3'])
>>> assert args.exclude_sensors == 'l8,f3'
>>> args, _ = parent_parser.parse_known_args(['--exclude_sensors=l8'])
>>> assert args.exclude_sensors == 'l8'

classmethod compatible(cfgdict)[source]

Given keyword arguments, find the subset that is compatible with this constructor. This is somewhat hacky because of the usage of scriptconfig, but could be made cleaner by future updates.
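
A minimal usage sketch (the key 'not_a_real_option' is hypothetical, used only to illustrate that incompatible keys should be dropped):

>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> # Hedged sketch: keep only the keys this constructor understands.
>>> cfgdict = {'batch_size': 2, 'chip_size': 96, 'not_a_real_option': 1}
>>> kwargs = KWCocoVideoDataModule.compatible(cfgdict)
>>> assert 'not_a_real_option' not in kwargs
>>> assert kwargs['batch_size'] == 2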

draw_batch(batch, stage='train', outputs=None, max_items=2, overlay_on_image=False, classes=None, **kwargs)[source]

Visualize a batch produced by a KWCocoVideoDataset.

Parameters:
  • batch (Dict[str, List[Tensor]]) – dictionary of uncollated lists of Dataset Items:
      change: [[T-1, H, W] in [0, 1] for each example]
      saliency: [[T, H, W, 2] in [0, 1] for each example]
      class: [[T, H, W, 10] in [0, 1] for each example]

  • outputs (Dict[str, Tensor]) – optional, maybe-collated network outputs to draw alongside the batch.

  • max_items (int) – Maximum number of items within this batch to draw in a single figure. Defaults to 2.

  • overlay_on_image (bool) – if True, overlay annotations on the image data for a more compact view; if False, draw annotations and images separately for a less cluttered view.

CommandLine

xdoctest -m geowatch.tasks.fusion.datamodules.kwcoco_datamodule KWCocoVideoDataModule.draw_batch

Example

>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> from geowatch.tasks.fusion import datamodules
>>> self = datamodules.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes8-multispectral', channels='auto', num_workers=0)
>>> self.setup('fit')
>>> loader = self.train_dataloader()
>>> batch = next(iter(loader))
>>> item = batch[0]
>>> # Visualize
>>> B = len(batch)
>>> import ubelt as ub
>>> C, H, W = ub.peek(item['frames'][0]['modes'].values()).shape
>>> T = len(item['frames'])
>>> import torch
>>> outputs = {'change_probs': [torch.rand(T - 1, H, W) for _ in range(B)]}
>>> outputs.update({'class_probs': [torch.rand(T, H, W, 10) for _ in range(B)]})
>>> stage = 'train'
>>> canvas = self.draw_batch(batch, stage=stage, outputs=outputs)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas)
>>> kwplot.show_if_requested()

Example

>>> # xdoctest: +REQUIRES(--slow)
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> from geowatch.tasks.fusion import datamodules
>>> import geowatch
>>> train_dataset = geowatch.demo.demo_kwcoco_multisensor()
>>> self = datamodules.KWCocoVideoDataModule(
>>>     train_dataset=train_dataset, chip_size=256, time_steps=5, num_workers=0, batch_size=3)
>>> self.setup('fit')
>>> loader = self.train_dataloader()
>>> batch_iter = iter(loader)
>>> batch = next(batch_iter)
>>> batch[1] = None  # simulate a dropped batch item
>>> batch[0] = None  # simulate a dropped batch item
>>> #item = batch[0]
>>> # Visualize
>>> B = len(batch)
>>> outputs = {'change_probs': [], 'class_probs': [], 'saliency_probs': []}
>>> # Add dummy outputs
>>> import torch
>>> for item in batch:
>>>     if item is None:
>>>         for v in outputs.values():
>>>             v.append([None])
>>>     else:
>>>         for v in outputs.values():
>>>             v.append([])
>>>         for frame_idx, frame in enumerate(item['frames']):
>>>             H, W = frame['class_idxs'].shape
>>>             if frame_idx > 0:
>>>                 outputs['change_probs'][-1].append(torch.rand(H, W))
>>>             outputs['class_probs'][-1].append(torch.rand(H, W, 10))
>>>             outputs['saliency_probs'][-1].append(torch.rand(H, W, 2))
>>> from geowatch.utils import util_nesting
>>> import ubelt as ub
>>> print(ub.urepr(util_nesting.shape_summary(outputs), nl=1, sort=0))
>>> stage = 'train'
>>> canvas = self.draw_batch(batch, stage=stage, outputs=outputs, max_items=4)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas)
>>> kwplot.show_if_requested()

Example

>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> from geowatch.tasks.fusion import datamodules
>>> self = datamodules.KWCocoVideoDataModule(
>>>     batch_size = 3,
>>>     train_dataset='special:vidshapes8-multispectral', channels='auto', num_workers=0)
>>> self.setup('fit')
>>> loader = self.train_dataloader()
>>> batch = next(iter(loader))
>>> batch[1] = None
>>> item = batch[0]
>>> # Visualize
>>> B = len(batch)
>>> import ubelt as ub
>>> C, H, W = ub.peek(item['frames'][0]['modes'].values()).shape
>>> T = len(item['frames'])
>>> import torch
>>> outputs = {'change_probs': [torch.rand(T - 1, H, W) for _ in range(B)]}
>>> outputs.update({'class_probs': [torch.rand(T, H, W, 10) for _ in range(B)]})
>>> stage = 'train'
>>> canvas = self.draw_batch(batch, stage=stage, outputs=outputs)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas)
>>> kwplot.show_if_requested()

setup(stage)[source]

property test_dataset

property train_dataset

property vali_dataset
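
A minimal sketch of the expected lifecycle, assuming (as in the examples above) that the dataset properties are populated by setup():

>>> # Hedged sketch: datasets become available after setup() is called.
>>> from geowatch.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
>>> self = KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes8-multispectral', channels='auto', num_workers=0)
>>> self.setup('fit')
>>> assert self.train_dataset is not None
>>> assert self.test_dataset is None  # no test dataset was requested
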
class geowatch.tasks.fusion.datamodules.KWCocoVideoDataset(sampler, mode='fit', test_with_annot_info=False, autobuild=True, **kwargs)[source]

Bases: Dataset, GetItemMixin, BalanceMixin, PreprocessMixin, IntrospectMixin, MiscMixin, SpacetimeAugmentMixin, BackwardCompatMixin, SMARTDataMixin

Accepted keyword arguments are specified in KWCocoVideoDatasetConfig

Example

>>> # Native Data Sampling
>>> from geowatch.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
>>> import ndsampler
>>> import kwcoco
>>> import geowatch
>>> coco_dset = geowatch.coerce_kwcoco('geowatch-multisensor-msi', geodata=True)
>>> print({c.get('sensor_coarse') for c in coco_dset.images().coco_images})
>>> print({c.channels.spec for c in coco_dset.images().coco_images})
>>> sampler = ndsampler.CocoSampler(coco_dset)
>>> self = KWCocoVideoDataset(sampler, time_dims=4, window_dims=(100, 200),
>>>                           input_space_scale='native',
>>>                           window_space_scale='0.05GSD',
>>>                           output_space_scale='native',
>>>                           channels='auto',
>>> )
>>> self.disable_augmenter = True
>>> target = self.sample_grid['targets'][self.sample_grid['positives_indexes'][3]]
>>> item = self[target]
>>> canvas = self.draw_item(item, overlay_on_image=0, rescale=0, max_channels=3)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas)
>>> kwplot.show_if_requested()

Example

>>> # Target GSD Data Sampling
>>> from geowatch.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
>>> import ndsampler
>>> import kwcoco
>>> import geowatch
>>> coco_dset = geowatch.coerce_kwcoco('geowatch', geodata=True)
>>> print({c.get('sensor_coarse') for c in coco_dset.images().coco_images})
>>> print({c.channels.spec for c in coco_dset.images().coco_images})
>>> sampler = ndsampler.CocoSampler(coco_dset)
>>> self = KWCocoVideoDataset(sampler, window_dims=(100, 100), time_dims=5,
>>>                           input_space_scale='0.35GSD',
>>>                           window_space_scale='0.7GSD',
>>>                           output_space_scale='0.2GSD',
>>>                           channels='auto',
>>> )
>>> self.disable_augmenter = True
>>> index = self.sample_grid['targets'][self.sample_grid['positives_indexes'][3]]
>>> import kwimage
>>> Box = kwimage.Box
>>> index['space_slice'] = Box.from_slice(index['space_slice']).translate((30, 0)).quantize().to_slice()
>>> item = self[index]
>>> #print('item summary: ' + ub.urepr(self.summarize_item(item), nl=3))
>>> canvas = self.draw_item(item, overlay_on_image=1, rescale=0, max_channels=3)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas)
>>> kwplot.show_if_requested()

Parameters:
  • sampler (kwcoco.CocoDataset | ndsampler.CocoSampler) – kwcoco dataset

  • mode (str) – fit or predict

  • autobuild (bool) – if False, defer potentially expensive initialization; in that case the user must call ._init()

  • **kwargs – see KWCocoVideoDatasetConfig for valid options; these options will be stored in the .config attribute.
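
A minimal sketch of deferred initialization (assuming, per the autobuild description above, that ._init() finalizes construction and builds everything needed for indexing, such as the sample grid):

>>> # Hedged sketch: defer the potentially expensive setup, then finalize manually.
>>> from geowatch.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
>>> import ndsampler
>>> import geowatch
>>> coco_dset = geowatch.coerce_kwcoco('vidshapes8-geowatch')
>>> sampler = ndsampler.CocoSampler(coco_dset)
>>> self = KWCocoVideoDataset(sampler, time_dims=2, window_dims=(64, 64),
>>>                           channels='auto', autobuild=False)
>>> self._init()  # performs the deferred initialization
>>> item = self[self.sample_grid['targets'][0]]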