geowatch.tasks.fusion.datamodules.network_io module¶
Module to define the BatchItem wrapper class.
This is currently a work in progress. The goal is to abstract the format of the items produced by the dataloader so that they are amenable to use with both our heterogeneous networks and more standard networks that require more regular data.
This has a lot in common with prior work in:
~/code/netharn/netharn/data/data_containers.py
- class geowatch.tasks.fusion.datamodules.network_io.BatchItem[source]¶
Bases:
dict
Ideally a batch item is simply an unstructured dictionary. This is the base class for more specific implementations, all of which are dictionaries, but the class will expose convenience methods.
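The idea of a dictionary that also exposes convenience methods can be sketched as follows. This is a minimal illustration and not the geowatch implementation: the class name `SketchBatchItem` and the `'frames'` and `'target'` keys are assumptions made for the example.

```python
class SketchBatchItem(dict):
    """Hypothetical stand-in for a BatchItem-style wrapper (not the geowatch class)."""

    @property
    def num_frames(self):
        # Assumption for this sketch: frames are stored under a 'frames' key.
        return len(self.get('frames', []))


item = SketchBatchItem(
    frames=[{'sensor': 'S2'}, {'sensor': 'L8'}],
    target={'gids': [1, 2]},
)

# Plain dictionary access keeps working for existing code.
print(item['target']['gids'])
# The subclass layers a convenience accessor on top of the raw dictionary.
print(item.num_frames)
```

Because the wrapper is still a `dict`, code written against the raw dictionary format does not need to change when the wrapper is introduced.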
- class geowatch.tasks.fusion.datamodules.network_io.HeterogeneousBatchItem[source]¶
Bases:
BatchItem
A BatchItem is a container to help organize the output of the KWCocoVideoDataset. For backwards compatibility it retains the original dictionary interface.
Example
>>> from geowatch.tasks.fusion.datamodules.network_io import *  # NOQA
>>> self = HeterogeneousBatchItem.demo()
>>> print(self)
>>> print(ub.urepr(self.summarize(), nl=2))
>>> canvas = self.draw()
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas, fnum=1, pnum=(1, 1, 1))
>>> kwplot.show_if_requested()
- property num_frames¶
- property sensorchan_histogram¶
- draw(item_output=None, combinable_extra=None, max_channels=5, max_dim=224, norm_over_time='auto', overlay_on_image=False, draw_weights=True, rescale='auto', classes=None, predictable_classes=None, show_summary_text=True, requested_tasks=None, legend=True, **kwargs)[source]¶
Visualize this batch item. Corresponds to the dataset IntrospectMixin.draw_item() method.
Not finished. The dataset draw_item method has context that is not currently available here, such as predictable classes, and that context needs to be represented in the item itself.
- summarize(coco_dset=None, stats=False)[source]¶
Return debugging stats about the item
- Parameters:
coco_dset (CocoDataset) – The coco dataset used to generate the item. If specified, allows the summary to look up extra information.
stats (bool) – if True, include statistics on the input data.
- Returns:
a summary of the item
- Return type:
Example
>>> from geowatch.tasks.fusion.datamodules.network_io import *  # NOQA
>>> self = HeterogeneousBatchItem.demo()
>>> item_summary = self.summarize(stats=0)
>>> print(f'item_summary = {ub.urepr(item_summary, nl=-2)}')
- class geowatch.tasks.fusion.datamodules.network_io.HomogeneousBatchItem[source]¶
Bases:
HeterogeneousBatchItem
Ideally this is a simplified representation that “just works” with standard off-the-shelf networks.
Example
>>> from geowatch.tasks.fusion.datamodules.network_io import *  # NOQA
>>> self = HomogeneousBatchItem.demo()
>>> print(self)
>>> print(ub.urepr(self.summarize(), nl=2))
>>> canvas = self.draw()
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas, fnum=1, pnum=(1, 1, 1))
>>> kwplot.show_if_requested()
- class geowatch.tasks.fusion.datamodules.network_io.RGBImageBatchItem[source]¶
Bases:
HomogeneousBatchItem
Only allows a single RGB image as the input.
Example
>>> from geowatch.tasks.fusion.datamodules.network_io import *  # NOQA
>>> self = RGBImageBatchItem.demo()
>>> print(self)
>>> print(ub.urepr(self.summarize(), nl=2))
>>> canvas = self.draw()
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas, fnum=1, pnum=(1, 1, 1))
>>> kwplot.show_if_requested()
- property frame¶
- property channels¶
- property imdata_chw¶
- property nonlocal_class_ohe¶
- class geowatch.tasks.fusion.datamodules.network_io.UncollatedBatch(iterable=(), /)[source]¶
Bases:
list
A generic list of batch items, which may or may not be collatable.
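What "may or may not be collatable" could mean in practice is sketched below. The class name `SketchUncollatedBatch`, the `is_collatable` method, and the rule it applies (same keys, same value lengths) are all assumptions for illustration. The source does not state how geowatch decides collatability.

```python
class SketchUncollatedBatch(list):
    """Hypothetical list-of-items container (not the geowatch class)."""

    def is_collatable(self):
        # Assumed rule for this sketch: every item must expose the same keys,
        # and each key must map to a value of the same length across items.
        if not self:
            return False
        signatures = [{k: len(v) for k, v in item.items()} for item in self]
        return all(sig == signatures[0] for sig in signatures)


regular = SketchUncollatedBatch([
    {'imdata': [0.1, 0.2, 0.3]},
    {'imdata': [0.4, 0.5, 0.6]},
])
ragged = SketchUncollatedBatch([
    {'imdata': [0.1, 0.2, 0.3]},
    {'imdata': [0.4, 0.5]},
])

print(regular.is_collatable())  # items share one layout, so they could be stacked
print(ragged.is_collatable())   # lengths differ, so stacking would fail
```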
- class geowatch.tasks.fusion.datamodules.network_io.HeterogeneousBatch(iterable=(), /)[source]¶
Bases:
UncollatedBatch
A HeterogeneousBatch is a List[HeterogeneousBatchItem].
- class geowatch.tasks.fusion.datamodules.network_io.UncollatedRGBImageBatch(iterable=(), /)[source]¶
Bases:
UncollatedBatch
A list of collatable RGBImageBatchItem objects.
- class geowatch.tasks.fusion.datamodules.network_io.CollatedRGBImageBatch[source]¶
Bases:
CollatedBatch
- class geowatch.tasks.fusion.datamodules.network_io.NetworkOutputs[source]¶
Bases:
dict
Network outputs should ALWAYS be a dictionary. This is the most flexible way to encode them so that they can be extended later.
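The extensibility argument can be illustrated with a small sketch. The class `SketchNetworkOutputs` and the function `training_step_loss` are hypothetical names. The keys `logits`, `probs`, and `loss` mirror those used in the CollatedNetworkOutputs example, but the values here are made up.

```python
class SketchNetworkOutputs(dict):
    """Hypothetical dictionary-based output container (not the geowatch class)."""


def training_step_loss(outputs):
    # A consumer reads only the keys it knows about and ignores the rest.
    return outputs['loss']


# An early version of a network returns logits and a loss.
v1 = SketchNetworkOutputs(logits={'nonlocal_class': [0.2, -1.3]}, loss=0.7)

# A later version adds a new key without touching the existing ones.
v2 = SketchNetworkOutputs(v1, probs={'nonlocal_class_probs': [0.55, 0.21]})

# The same consumer works unchanged with both versions.
print(training_step_loss(v1))
print(training_step_loss(v2))
```

Had the outputs been a positional tuple instead, adding `probs` would have forced every caller that unpacks the tuple to change.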
- class geowatch.tasks.fusion.datamodules.network_io.UncollatedNetworkOutputs[source]¶
Bases:
NetworkOutputs
- class geowatch.tasks.fusion.datamodules.network_io.CollatedNetworkOutputs[source]¶
Bases:
NetworkOutputs
Example
>>> from geowatch.tasks.fusion.datamodules.network_io import *  # NOQA
>>> B, H, W, C = 2, 3, 3, 11
>>> import torch
>>> logits = {
>>>     'nonlocal_class': (torch.rand(B, C) - 0.5) * 10,
>>>     'segmentation_class': (torch.rand(B, W, H, C) - 0.5) * 10,
>>>     'nonlocal_saliency': (torch.rand(B, 1) - 0.5) * 10,
>>>     'segmentation_saliency': (torch.rand(B, W, H, 1) - 0.5) * 10,
>>> }
>>> self = CollatedNetworkOutputs(
>>>     logits=logits,
>>>     probs={k + '_probs': v.sigmoid() for k, v in logits.items()},
>>>     loss_parts={},
>>>     loss=10,
>>> )
>>> new = self.decollate()
>>> self._debug_shape()
>>> new._debug_shape()
- geowatch.tasks.fusion.datamodules.network_io.decollate(collated)[source]¶
Break up a collated batch in a standardized way. Returns a list with one item per batch element. Each item has a structure that matches the collated batch, but without the leading batch dimension in each value.
Example
>>> from geowatch.tasks.fusion.datamodules.network_io import *  # NOQA
>>> import torch
>>> B, H, W, C = 5, 2, 3, 7
>>> collated = {
>>>     'segmentation_class': torch.rand(B, H, W, C),
>>>     'nonlocal_class': torch.rand(B, C),
>>> }
>>> uncollated = decollate(collated)
>>> assert len(uncollated) == B
>>> assert (uncollated[0]['nonlocal_class'] == collated['nonlocal_class'][0]).all()
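The behavior described above can be approximated in pure Python as a rough sketch. The function `decollate_sketch` and its helpers are hypothetical and are not the geowatch implementation. The sketch assumes that leaf values are indexable sequences sharing one leading batch length and that nesting happens only through dictionaries.

```python
def _batch_size(value):
    # Find the shared leading length, descending into nested dictionaries.
    if isinstance(value, dict):
        sizes = {_batch_size(v) for v in value.values()}
        if len(sizes) != 1:
            raise ValueError('inconsistent or missing leading batch dimension')
        return sizes.pop()
    return len(value)


def _take(value, index):
    # Select one batch element while preserving the dictionary structure.
    if isinstance(value, dict):
        return {k: _take(v, index) for k, v in value.items()}
    return value[index]


def decollate_sketch(collated):
    """Split a collated dictionary into a list with one dictionary per batch element."""
    return [_take(collated, i) for i in range(_batch_size(collated))]


collated = {
    'nonlocal_class': [[1, 2], [3, 4], [5, 6]],   # leading dimension B = 3
    'logits': {'segmentation': ['a', 'b', 'c']},  # nested dictionary, same B
}
uncollated = decollate_sketch(collated)
print(len(uncollated))
print(uncollated[1])
```

Each returned item has the same keys as the collated input, with the leading batch index removed from every leaf value.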