geowatch.tasks.fusion.datamodules.batch_visualization module

class geowatch.tasks.fusion.datamodules.batch_visualization.BatchVisualizationBuilder(item, item_output=None, combinable_extra=None, max_channels=5, max_dim=224, norm_over_time=0, overlay_on_image=False, draw_weights=True, draw_truth=True, classes=None, default_combinable_channels=None, requested_tasks=None, rescale=1)[source]

Bases: object

Helper object to build a batch visualization.

The basic logic is that we build one column for each timestep and then arrange the columns from left to right to show how the scene changes over time. Each column is made of “cells”, each of which can show the truth, a prediction, loss weights, or raw input channels.

CommandLine

xdoctest -m geowatch.tasks.fusion.datamodules.batch_visualization BatchVisualizationBuilder

Example

>>> from geowatch.tasks.fusion.datamodules.batch_visualization import *  # NOQA
>>> from geowatch.tasks.fusion.datamodules.kwcoco_dataset import KWCocoVideoDataset
>>> import geowatch
>>> coco_dset = geowatch.coerce_kwcoco('vidshapes2-geowatch', num_frames=5)
>>> channels = 'r|g|b,B10|B8a|B1|B8|B11,X.2|Y.2'
>>> combinable_extra = [['B10', 'B8', 'B8a']]  # special behavior
>>> # combinable_extra = None  # uncomment for raw behavior
>>> self = KWCocoVideoDataset(
>>>     coco_dset, time_dims=5, window_dims=(224, 256), channels=channels,
>>>     use_centered_positives=True, neg_to_pos_ratio=0)
>>> index = len(self) // 4
>>> item = self[index]
>>> item_output = BatchVisualizationBuilder.populate_demo_output(item, self.sampler.classes)
>>> #binprobs[0][:] = 0  # first change prob should be all zeros
>>> requested_tasks = self.requested_tasks
>>> builder = BatchVisualizationBuilder(
>>>     item, item_output, classes=self.classes, requested_tasks=requested_tasks,
>>>     default_combinable_channels=self.default_combinable_channels, combinable_extra=combinable_extra)
>>> #builder.overlay_on_image = 1
>>> #canvas = builder.build()
>>> builder.max_channels = 4
>>> builder.overlay_on_image = 0
>>> canvas2 = builder.build()
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> #kwplot.imshow(canvas, fnum=1, pnum=(1, 2, 1))
>>> #kwplot.imshow(canvas2, fnum=1, pnum=(1, 2, 2))
>>> kwplot.imshow(canvas2, fnum=1, doclf=True)
>>> kwplot.show_if_requested()

Example

>>> from geowatch.tasks.fusion.datamodules.batch_visualization import *  # NOQA
>>> from geowatch.tasks.fusion.datamodules.kwcoco_dataset import KWCocoVideoDataset
>>> import geowatch
>>> coco_dset = geowatch.coerce_kwcoco('vidshapes2-geowatch', num_frames=5)
>>> channels = 'r|g|b,B10|B8a|B1|B8|B11,X.2|Y.2'
>>> #coco_dset = geowatch.coerce_kwcoco('vidshapes2', num_frames=5)
>>> #channels = None
>>> combinable_extra = [['B10', 'B8', 'B8a']]  # special behavior
>>> # combinable_extra = None  # uncomment for raw behavior
>>> self = KWCocoVideoDataset(
>>>     coco_dset, time_dims=5, window_dims=(128, 165), channels=channels,
>>>     use_centered_positives=True, neg_to_pos_ratio=0, input_space_scale='native')
>>> index = len(self) // 4
>>> index = 0
>>> target = native_target = self.new_sample_grid['targets'][index].copy()
>>> #target['space_slice'] = (slice(224, 448), slice(224, 448))
>>> target['space_slice'] = (slice(196, 196 + 148), slice(32, 128))
>>> #target['space_slice'] = (slice(0, 196 + 148), slice(0, 128))
>>> target['gids'] = target['gids']
>>> #target['space_slice'] = (slice(16, 196 + 148), slice(16, 198))
>>> #target['space_slice'] = (slice(-70, 196 + 148), slice(-128, 128))
>>> native_target.pop('fliprot_params', None)
>>> native_target['allow_augment'] = 0
>>> native_item = self[native_target]
>>> # Resample the same item, but without native scale sampling for comparison
>>> rescaled_target = native_item['target'].copy()
>>> rescaled_target.pop('fliprot_params', None)
>>> rescaled_target['input_space_scale'] = 1
>>> rescaled_target['output_space_scale'] = 1
>>> rescaled_target['allow_augment'] = 0
>>> rescale = 0
>>> draw_weights = 1
>>> rescaled_item = self[rescaled_target]
>>> print(ub.urepr(self.summarize_item(native_item), nl=-1, sort=0))
>>> native_item_output = BatchVisualizationBuilder.populate_demo_output(native_item, self.sampler.classes, rng=0)
>>> rescaled_item_output = BatchVisualizationBuilder.populate_demo_output(rescaled_item, self.sampler.classes, rng=0)
>>> #rescaled_item_output = None
>>> #rescaled_item_output = None
>>> #binprobs[0][:] = 0  # first change prob should be all zeros
>>> requested_tasks = self.requested_tasks
>>> builder = BatchVisualizationBuilder(
>>>     native_item, native_item_output, classes=self.classes,
>>>     requested_tasks=requested_tasks,
>>>     default_combinable_channels=self.default_combinable_channels,
>>>     combinable_extra=combinable_extra, rescale=rescale, draw_weights=draw_weights)
>>> builder.max_channels = 4
>>> builder.overlay_on_image = 0
>>> native_canvas = builder.build()
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> plt = kwplot.autoplt()
>>> #kwplot.imshow(canvas, fnum=1, pnum=(1, 2, 1))
>>> #kwplot.imshow(canvas2, fnum=1, pnum=(1, 2, 2))
>>> kwplot.imshow(native_canvas, fnum=1, doclf=True, figtitle='Native Sampling')
>>> plt.gcf().tight_layout()
>>> ######
>>> # Resample the same item, but without native sampling for comparison
>>> print(ub.urepr(self.summarize_item(rescaled_item), nl=-1, sort=0))
>>> builder = BatchVisualizationBuilder(
>>>     rescaled_item, rescaled_item_output, classes=self.classes,
>>>     requested_tasks=requested_tasks,
>>>     default_combinable_channels=self.default_combinable_channels,
>>>     combinable_extra=combinable_extra, rescale=rescale, draw_weights=draw_weights)
>>> builder.max_channels = 4
>>> builder.overlay_on_image = 0
>>> rescaled_canvas = builder.build()
>>> kwplot.imshow(rescaled_canvas, fnum=2, doclf=True, figtitle='Rescaled Sampling')
>>> plt.gcf().tight_layout()
>>> ######
>>> from geowatch.tasks.fusion.datamodules.batch_visualization import _debug_sample_in_context
>>> _debug_sample_in_context(self, target)
>>> kwplot.show_if_requested()
classmethod populate_demo_output(item, classes, rng=None)[source]

Make dummy output for a batch item, for use in testing.

build()[source]
geowatch.tasks.fusion.datamodules.batch_visualization.colorize_weights(weights)[source]

Normally weights range between 0 and 1, but in some cases they may exceed 1. We handle this by coloring the 0-1 range in grayscale and the range from 1 to infinity in color.

Example

>>> from geowatch.tasks.fusion.datamodules.batch_visualization import *  # NOQA
>>> import kwarray
>>> weights = kwimage.gaussian_patch((32, 32))
>>> weights = kwarray.normalize(weights)
>>> weights[:16, :16] *= 10
>>> weights[16:, :16] *= 100
>>> weights[16:, 16:] *= 1000
>>> weights[:16, 16:] *= 10000
>>> canvas = colorize_weights(weights)
>>> # xdoctest: +REQUIRES(--show)
>>> canvas = kwimage.imresize(canvas, dsize=(512, 512), interpolation='nearest').clip(0, 1)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-10', org=(1, 1), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-100', org=(256, 1), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-1000', org=(256, 256), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-10000', org=(1, 256), border=True)
>>> import kwplot
>>> import kwplot
>>> kwplot.plt.ion()
>>> kwplot.imshow(canvas)

Example

>>> from geowatch.tasks.fusion.datamodules.batch_visualization import *  # NOQA
>>> import kwarray
>>> weights = kwimage.gaussian_patch((32, 32))
>>> n = 512
>>> weight_rows = [
>>>     np.linspace(0, 1, n),
>>>     np.linspace(0, 10, n),
>>>     np.linspace(0, 100, n),
>>>     np.linspace(0, 1000, n),
>>>     np.linspace(0, 2000, n),
>>>     np.linspace(0, 5000, n),
>>>     np.linspace(0, 8000, n),
>>>     np.linspace(0, 10000, n),
>>>     np.linspace(0, 100000, n),
>>>     np.linspace(0, 1000000, n),
>>> ]
>>> canvas = np.array([colorize_weights(row[None, :])[0] for row in weight_rows])
>>> # xdoctest: +REQUIRES(--show)
>>> canvas = kwimage.imresize(canvas, dsize=(512, 512), interpolation='nearest').clip(0, 1)
>>> p = int(512 / len(weight_rows))
>>> canvas = kwimage.draw_text_on_image(canvas, '0-1', org=(1, 1 + p * 0), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-10', org=(1, 1 + p * 1), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-100', org=(1, 1 + p * 2), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-1000', org=(1, 1 + p * 3), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-2000', org=(1, 1 + p * 4), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-5000', org=(1, 1 + p * 5), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-8000', org=(1, 1 + p * 6), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-10000', org=(1, 1 + p * 7), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-100000', org=(1, 1 + p * 8), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-1000000', org=(1, 1 + p * 9), border=True)
>>> import kwplot
>>> import kwplot
>>> kwplot.plt.ion()
>>> kwplot.imshow(canvas)