geowatch.tasks.fusion.datamodules.network_io module

Module to define the BatchItem wrapper class.

This is currently a work in progress. The goal is to abstract the format of the items produced by the dataloader so that they are amenable to use both with our heterogeneous networks and with more standard networks that require more regular data.

This has a lot in common with prior work in:

~/code/netharn/netharn/data/data_containers.py

class geowatch.tasks.fusion.datamodules.network_io.BatchItem

Bases: dict

Ideally, a batch item is simply an unstructured dictionary. This is the base class for more specific implementations, which are all dictionaries but expose convenience methods.

asdict()
draw(**kwargs)
classmethod demo(**kwargs)
classmethod summarize(**kwargs)
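Because every batch item is a plain dict underneath, the "dictionary plus convenience methods" pattern can be sketched as a thin dict subclass. The class name `SketchBatchItem` and the bodies of `asdict` and `summarize` below are hypothetical illustrations, not the geowatch implementation.

```python
class SketchBatchItem(dict):
    """Hypothetical dict subclass illustrating the BatchItem pattern."""

    def asdict(self):
        # Return a plain dict copy so downstream code never depends on the subclass.
        return dict(self)

    def summarize(self):
        # Hypothetical summary: report each key with the type name of its value.
        return {key: type(value).__name__ for key, value in self.items()}


item = SketchBatchItem(frames=[], target={'class_idx': 3})
plain = item.asdict()
print(type(plain).__name__)  # dict
print(item.summarize())      # {'frames': 'list', 'target': 'dict'}
```

Since the subclass adds no state of its own, any code that expects an ordinary dictionary keeps working.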
class geowatch.tasks.fusion.datamodules.network_io.HeterogeneousBatchItem

Bases: BatchItem

A BatchItem is a container to help organize the output of the KWCocoVideoDataset. For backwards compatibility it retains the original dictionary interface.

Example

>>> from geowatch.tasks.fusion.datamodules.network_io import *  # NOQA
>>> import ubelt as ub
>>> self = HeterogeneousBatchItem.demo()
>>> print(self)
>>> print(ub.urepr(self.summarize(), nl=2))
>>> canvas = self.draw()
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas, fnum=1, pnum=(1, 1, 1))
>>> kwplot.show_if_requested()
property num_frames
property sensorchan_histogram
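The `num_frames` and `sensorchan_histogram` properties have no docstrings here. A minimal sketch of what they might compute follows, assuming (not confirmed by this page) that an item stores a `'frames'` list whose entries each carry a `'sensor'` name and a `'modes'` dict keyed by channel code. The class name and the `"sensor:channels"` key format are hypothetical.

```python
from collections import Counter


class SketchHeterogeneousItem(dict):
    """Hypothetical item; assumes item['frames'][i] has 'sensor' and 'modes' keys."""

    @property
    def num_frames(self):
        # Each entry in 'frames' is one timestep of the sampled video.
        return len(self['frames'])

    @property
    def sensorchan_histogram(self):
        # Count how often each sensor / channel-group pair occurs across frames.
        hist = Counter()
        for frame in self['frames']:
            for chan_code in frame['modes']:
                hist[f"{frame['sensor']}:{chan_code}"] += 1
        return dict(hist)


# None stands in for the per-mode image arrays, which the sketch never reads.
item = SketchHeterogeneousItem(frames=[
    {'sensor': 'S2', 'modes': {'red|green|blue': None, 'nir': None}},
    {'sensor': 'L8', 'modes': {'red|green|blue': None}},
    {'sensor': 'S2', 'modes': {'red|green|blue': None}},
])
print(item.num_frames)  # 3
print(item.sensorchan_histogram)
```

A histogram like this makes the heterogeneity visible at a glance: different frames can come from different sensors and carry different channel groups.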
classmethod demo()
draw(item_output=None, combinable_extra=None, max_channels=5, max_dim=224, norm_over_time='auto', overlay_on_image=False, draw_weights=True, rescale='auto', classes=None, predictable_classes=None, show_summary_text=True, requested_tasks=None, legend=True, **kwargs)

Visualize this batch item. Corresponds to the dataset IntrospectMixin.draw_item() method.

Not finished. The dataset's draw_item method has access to context that is not currently available here, such as the predictable classes; that context needs to be represented in the item itself.

summarize(coco_dset=None, stats=False)

Return debugging stats about the item.

Parameters:
  • coco_dset (CocoDataset) – The coco dataset used to generate the item. If specified, allows the summary to look up extra information.

  • stats (bool) – If True, include statistics on the input data.

Returns:

a summary of the item

Return type:

dict

Example

>>> from geowatch.tasks.fusion.datamodules.network_io import *  # NOQA
>>> import ubelt as ub
>>> self = HeterogeneousBatchItem.demo()
>>> item_summary = self.summarize(stats=0)
>>> print(f'item_summary = {ub.urepr(item_summary, nl=-2)}')
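The idea behind a debugging summary can be sketched as a recursive walk that replaces each array with its shape and, when `stats` is requested, a few statistics. The function `summarize_sketch` below is hypothetical: it ignores the `coco_dset` lookups entirely and only assumes items are nested dicts and lists holding NumPy arrays.

```python
import numpy as np


def summarize_sketch(data, stats=False):
    """Hypothetical recursive summary of nested dicts/lists holding arrays."""
    if isinstance(data, dict):
        return {key: summarize_sketch(value, stats) for key, value in data.items()}
    if isinstance(data, list):
        return [summarize_sketch(value, stats) for value in data]
    if isinstance(data, np.ndarray):
        # Replace the array itself with its shape (and optional statistics).
        summary = {'shape': tuple(data.shape)}
        if stats:
            summary['min'] = float(np.nanmin(data))
            summary['max'] = float(np.nanmax(data))
            summary['mean'] = float(np.nanmean(data))
        return summary
    # Scalars and strings are small enough to report as-is.
    return data


item = {'frames': [{'sensor': 'S2', 'modes': {'red': np.arange(6, dtype=float).reshape(1, 2, 3)}}]}
print(summarize_sketch(item))
print(summarize_sketch(item, stats=True))
```

Using the NaN-aware reductions keeps the statistics meaningful when nodata pixels have been filled with NaN.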
class geowatch.tasks.fusion.datamodules.network_io.HomogeneousBatchItem

Bases: HeterogeneousBatchItem

Ideally, this is a simplified representation that “just works” with standard off-the-shelf networks.

Example

>>> from geowatch.tasks.fusion.datamodules.network_io import *  # NOQA
>>> import ubelt as ub
>>> self = HomogeneousBatchItem.demo()
>>> print(self)
>>> print(ub.urepr(self.summarize(), nl=2))
>>> canvas = self.draw()
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas, fnum=1, pnum=(1, 1, 1))
>>> kwplot.show_if_requested()
classmethod demo()

Example

>>> cls = HomogeneousBatchItem
>>> self = cls.demo()
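What makes data "regular" enough for an off-the-shelf network is that every frame shares one shape, so the frames can be stacked into a single tensor. The helper `stack_frames_sketch` below is a hypothetical illustration of that constraint, assuming each frame is a channel-first `(C, H, W)` NumPy array; it is not part of this module.

```python
import numpy as np


def stack_frames_sketch(frames):
    """Hypothetical: stack per-frame (C, H, W) arrays into one (T, C, H, W) array.

    Standard networks expect a regular tensor, so every frame must share the
    same channel count and spatial size; heterogeneous inputs are rejected.
    """
    shapes = {tuple(frame.shape) for frame in frames}
    if len(shapes) != 1:
        raise ValueError(f'frames are not homogeneous: {sorted(shapes)}')
    return np.stack(frames, axis=0)


frames = [np.zeros((3, 4, 4)), np.ones((3, 4, 4))]
print(stack_frames_sketch(frames).shape)  # (2, 3, 4, 4)
```

A heterogeneous item, by contrast, may mix sensors and resolutions across frames and would fail this check.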

class geowatch.tasks.fusion.datamodules.network_io.RGBImageBatchItem

Bases: HomogeneousBatchItem

Only allows a single RGB image as the input.

Example

>>> from geowatch.tasks.fusion.datamodules.network_io import *  # NOQA
>>> import ubelt as ub
>>> self = RGBImageBatchItem.demo()
>>> print(self)
>>> print(ub.urepr(self.summarize(), nl=2))
>>> canvas = self.draw()
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas, fnum=1, pnum=(1, 1, 1))
>>> kwplot.show_if_requested()
property frame
property channels
property imdata_chw
property nonlocal_class_ohe
classmethod demo(index=None)

>>> cls = RGBImageBatchItem
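The `imdata_chw` and `nonlocal_class_ohe` property names suggest a channel-first image array and a one-hot encoding of an image-level ("nonlocal") class label. Assuming that reading, which this page does not spell out, the two conversions could look like the hypothetical helpers `hwc_to_chw` and `one_hot` below.

```python
import numpy as np


def hwc_to_chw(imdata_hwc):
    """Hypothetical helper: reorder an (H, W, 3) RGB image to (3, H, W)."""
    if imdata_hwc.ndim != 3 or imdata_hwc.shape[2] != 3:
        raise ValueError('expected a single RGB image with shape (H, W, 3)')
    # Copy into contiguous memory so the result is cheap to hand to a framework.
    return np.ascontiguousarray(imdata_hwc.transpose(2, 0, 1))


def one_hot(class_idx, num_classes):
    """Hypothetical helper: one-hot encode a single image-level class index."""
    ohe = np.zeros(num_classes, dtype=np.float32)
    ohe[class_idx] = 1.0
    return ohe


image = np.random.rand(32, 48, 3)
print(hwc_to_chw(image).shape)  # (3, 32, 48)
print(one_hot(2, 4))            # [0. 0. 1. 0.]
```

Restricting the input to exactly three channels mirrors the class's stated rule that only a single RGB image is allowed.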

class geowatch.tasks.fusion.datamodules.network_io.UncollatedBatch(iterable=(), /)

Bases: list

A generic list of batch items, which may or may not be collatable.

class geowatch.tasks.fusion.datamodules.network_io.HeterogeneousBatch(iterable=(), /)

Bases: UncollatedBatch

A HeterogeneousBatch is a List[HeterogeneousBatchItem].

class geowatch.tasks.fusion.datamodules.network_io.UncollatedRGBImageBatch(iterable=(), /)

Bases: UncollatedBatch

A list of collatable RGBImageBatchItem objects.

classmethod demo(num_items=3)
classmethod coerce(data)
classmethod from_items(data)
collate()
Returns:

CollatedRGBImageBatch

Example

>>> from geowatch.tasks.fusion.datamodules.network_io import *  # NOQA
>>> self = UncollatedRGBImageBatch.demo()
>>> batch = self.collate()
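Collation turns a list of per-item dicts into one dict whose values gain a leading batch dimension. A minimal sketch of that idea follows, assuming every item has identical keys and same-shaped NumPy values; `collate_sketch` is hypothetical and does not reproduce how `collate()` builds a `CollatedRGBImageBatch`.

```python
import numpy as np


def collate_sketch(items):
    """Hypothetical collate: turn a list of same-keyed dicts into a dict of stacked arrays.

    Assumes every item has identical keys and that each value is an array of
    identical shape across items, which is what makes a batch "collatable".
    """
    keys = items[0].keys()
    for item in items[1:]:
        if item.keys() != keys:
            raise ValueError('all items must have the same keys to be collated')
    # Stacking adds a new leading batch dimension of size len(items).
    return {key: np.stack([item[key] for item in items]) for key in keys}


items = [
    {'imdata_chw': np.zeros((3, 8, 8)), 'nonlocal_class_ohe': np.eye(4)[i]}
    for i in range(3)
]
batch = collate_sketch(items)
print(batch['imdata_chw'].shape)          # (3, 3, 8, 8)
print(batch['nonlocal_class_ohe'].shape)  # (3, 4)
```

This is the step that fails for heterogeneous items, which is why the generic `UncollatedBatch` is described as possibly not collatable.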
class geowatch.tasks.fusion.datamodules.network_io.CollatedBatch

Bases: dict

asdict()
class geowatch.tasks.fusion.datamodules.network_io.CollatedRGBImageBatch

Bases: CollatedBatch

to(device)
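A `to(device)` method on a dict-based batch typically forwards the call to every value that supports it, as PyTorch tensors do. The sketch below assumes that behavior; `SketchCollatedBatch` is hypothetical, and it duck-types on a `.to` attribute so it does not need PyTorch to run.

```python
class SketchCollatedBatch(dict):
    """Hypothetical collated batch whose values may be tensors or nested dicts."""

    def to(self, device):
        # Move every value that supports .to(device) (e.g. torch tensors),
        # recursing into nested dicts and leaving other values unchanged.
        def _move(value):
            if isinstance(value, dict):
                return {key: _move(sub) for key, sub in value.items()}
            if hasattr(value, 'to'):
                return value.to(device)
            return value

        # Return a new batch, mirroring how torch.Tensor.to returns a new tensor.
        return self.__class__({key: _move(value) for key, value in self.items()})
```

With PyTorch installed, `SketchCollatedBatch(imdata=torch.zeros(2, 3, 8, 8)).to('cuda')` would return a new batch whose tensor lives on the GPU, while non-tensor entries such as lists of ids pass through untouched.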
class geowatch.tasks.fusion.datamodules.network_io.NetworkOutputs

Bases: dict

Network outputs should ALWAYS be a dictionary; this is the most flexible way to structure network outputs so that they can be extended later.

class geowatch.tasks.fusion.datamodules.network_io.UncollatedNetworkOutputs

Bases: NetworkOutputs

class geowatch.tasks.fusion.datamodules.network_io.CollatedNetworkOutputs

Bases: NetworkOutputs

Example

>>> from geowatch.tasks.fusion.datamodules.network_io import *  # NOQA
>>> B, H, W, C = 2, 3, 3, 11
>>> import torch
>>> logits = {
>>>     'nonlocal_class': (torch.rand(B, C) - 0.5) * 10,
>>>     'segmentation_class': (torch.rand(B, H, W, C) - 0.5) * 10,
>>>     'nonlocal_saliency': (torch.rand(B, 1) - 0.5) * 10,
>>>     'segmentation_saliency': (torch.rand(B, H, W, 1) - 0.5) * 10,
>>> }
>>> self = CollatedNetworkOutputs(
>>>     logits=logits,
>>>     probs={k + '_probs': v.sigmoid() for k, v in logits.items()},
>>>     loss_parts={},
>>>     loss=10,
>>> )
>>> new = self.decollate()
>>> self._debug_shape()
>>> new._debug_shape()
decollate()

Convert back into a per-item structure for easier analysis / drawing.

geowatch.tasks.fusion.datamodules.network_io.decollate(collated)

Break up a collated batch in a standardized way. Returns a list with one entry per batch item; each entry has the same structure as the collated batch, but without the leading batch dimension in each value.

Example

>>> from geowatch.tasks.fusion.datamodules.network_io import *  # NOQA
>>> import torch
>>> B, H, W, C = 5, 2, 3, 7
>>> collated = {
>>>     'segmentation_class': torch.rand(B, H, W, C),
>>>     'nonlocal_class': torch.rand(B, C),
>>> }
>>> uncollated = decollate(collated)
>>> assert len(uncollated) == B
>>> assert (uncollated[0]['nonlocal_class'] == collated['nonlocal_class'][0]).all()
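The flat example above generalizes naturally to nested dicts such as the `logits` and `probs` groups of `CollatedNetworkOutputs`. The recursive version below is a hypothetical sketch using NumPy; this page does not say whether the real `decollate` recurses, and the sketch assumes every leaf is an array sharing the same leading batch dimension (a batch-level scalar such as `loss` would need separate handling).

```python
import numpy as np


def decollate_sketch(collated):
    """Hypothetical decollate: split a (possibly nested) dict of batched arrays.

    Every leaf array is assumed to share the same leading batch dimension B.
    Returns a list of B dicts with the same nested structure, where each leaf
    has had its leading batch dimension removed.
    """
    def _batch_size(data):
        # Find B from the first leaf array in the nested structure.
        for value in data.values():
            if isinstance(value, dict):
                size = _batch_size(value)
                if size is not None:
                    return size
            else:
                return len(value)
        return None

    def _index(data, idx):
        # Rebuild the nested structure, taking slice idx of every leaf.
        return {
            key: _index(value, idx) if isinstance(value, dict) else value[idx]
            for key, value in data.items()
        }

    batch_size = _batch_size(collated)
    if batch_size is None:
        return []
    return [_index(collated, idx) for idx in range(batch_size)]


collated = {'logits': {
    'nonlocal_class': np.random.rand(5, 7),
    'segmentation_class': np.random.rand(5, 2, 3, 7),
}}
items = decollate_sketch(collated)
print(len(items))                                      # 5
print(items[0]['logits']['segmentation_class'].shape)  # (2, 3, 7)
```

Keeping the per-item structure identical to the collated one means drawing and analysis code can use the same keys before and after decollation.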