geowatch.tasks.fusion.datamodules.batch_visualization module¶
- class geowatch.tasks.fusion.datamodules.batch_visualization.BatchVisualizationBuilder(item, item_output=None, combinable_extra=None, max_channels=5, max_dim=224, norm_over_time=0, overlay_on_image=False, draw_weights=True, draw_truth=True, classes=None, default_combinable_channels=None, requested_tasks=None, rescale=1)[source]¶
Bases: object
Helper object to build a batch visualization.
The basic logic is that we build a column for each timestep and then arrange the columns from left to right to show how the scene changes over time. Each column is made of “cells”, each of which shows the truth, a prediction, loss weights, or raw input channels.
CommandLine
xdoctest -m geowatch.tasks.fusion.datamodules.batch_visualization BatchVisualizationBuilder
Example
>>> from geowatch.tasks.fusion.datamodules.batch_visualization import * # NOQA
>>> from geowatch.tasks.fusion.datamodules.kwcoco_dataset import KWCocoVideoDataset
>>> import geowatch
>>> coco_dset = geowatch.coerce_kwcoco('vidshapes2-geowatch', num_frames=5)
>>> channels = 'r|g|b,B10|B8a|B1|B8|B11,X.2|Y.2'
>>> combinable_extra = [['B10', 'B8', 'B8a']] # special behavior
>>> # combinable_extra = None # uncomment for raw behavior
>>> self = KWCocoVideoDataset(
>>>     coco_dset, time_dims=5, window_dims=(224, 256), channels=channels,
>>>     use_centered_positives=True, neg_to_pos_ratio=0)
>>> index = len(self) // 4
>>> item = self[index]
>>> item_output = BatchVisualizationBuilder.populate_demo_output(item, self.sampler.classes)
>>> #binprobs[0][:] = 0 # first change prob should be all zeros
>>> requested_tasks = self.requested_tasks
>>> builder = BatchVisualizationBuilder(
>>>     item, item_output, classes=self.classes, requested_tasks=requested_tasks,
>>>     default_combinable_channels=self.default_combinable_channels, combinable_extra=combinable_extra)
>>> #builder.overlay_on_image = 1
>>> #canvas = builder.build()
>>> builder.max_channels = 4
>>> builder.overlay_on_image = 0
>>> canvas2 = builder.build()
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> #kwplot.imshow(canvas, fnum=1, pnum=(1, 2, 1))
>>> #kwplot.imshow(canvas2, fnum=1, pnum=(1, 2, 2))
>>> kwplot.imshow(canvas2, fnum=1, doclf=True)
>>> kwplot.show_if_requested()
Example
>>> from geowatch.tasks.fusion.datamodules.batch_visualization import * # NOQA
>>> from geowatch.tasks.fusion.datamodules.kwcoco_dataset import KWCocoVideoDataset
>>> import geowatch
>>> coco_dset = geowatch.coerce_kwcoco('vidshapes2-geowatch', num_frames=5)
>>> channels = 'r|g|b,B10|B8a|B1|B8|B11,X.2|Y.2'
>>> #coco_dset = geowatch.coerce_kwcoco('vidshapes2', num_frames=5)
>>> #channels = None
>>> combinable_extra = [['B10', 'B8', 'B8a']] # special behavior
>>> # combinable_extra = None # uncomment for raw behavior
>>> self = KWCocoVideoDataset(
>>>     coco_dset, time_dims=5, window_dims=(128, 165), channels=channels,
>>>     use_centered_positives=True, neg_to_pos_ratio=0, input_space_scale='native')
>>> index = len(self) // 4
>>> index = 0
>>> target = native_target = self.new_sample_grid['targets'][index].copy()
>>> #target['space_slice'] = (slice(224, 448), slice(224, 448))
>>> target['space_slice'] = (slice(196, 196 + 148), slice(32, 128))
>>> #target['space_slice'] = (slice(0, 196 + 148), slice(0, 128))
>>> target['gids'] = target['gids']
>>> #target['space_slice'] = (slice(16, 196 + 148), slice(16, 198))
>>> #target['space_slice'] = (slice(-70, 196 + 148), slice(-128, 128))
>>> native_target.pop('fliprot_params', None)
>>> native_target['allow_augment'] = 0
>>> native_item = self[native_target]
>>> # Resample the same item, but without native scale sampling for comparison
>>> rescaled_target = native_item['target'].copy()
>>> rescaled_target.pop('fliprot_params', None)
>>> rescaled_target['input_space_scale'] = 1
>>> rescaled_target['output_space_scale'] = 1
>>> rescaled_target['allow_augment'] = 0
>>> rescale = 0
>>> draw_weights = 1
>>> rescaled_item = self[rescaled_target]
>>> print(ub.urepr(self.summarize_item(native_item), nl=-1, sort=0))
>>> native_item_output = BatchVisualizationBuilder.populate_demo_output(native_item, self.sampler.classes, rng=0)
>>> rescaled_item_output = BatchVisualizationBuilder.populate_demo_output(rescaled_item, self.sampler.classes, rng=0)
>>> #rescaled_item_output = None
>>> #binprobs[0][:] = 0 # first change prob should be all zeros
>>> requested_tasks = self.requested_tasks
>>> builder = BatchVisualizationBuilder(
>>>     native_item, native_item_output, classes=self.classes,
>>>     requested_tasks=requested_tasks,
>>>     default_combinable_channels=self.default_combinable_channels,
>>>     combinable_extra=combinable_extra, rescale=rescale, draw_weights=draw_weights)
>>> builder.max_channels = 4
>>> builder.overlay_on_image = 0
>>> native_canvas = builder.build()
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> plt = kwplot.autoplt()
>>> #kwplot.imshow(canvas, fnum=1, pnum=(1, 2, 1))
>>> #kwplot.imshow(canvas2, fnum=1, pnum=(1, 2, 2))
>>> kwplot.imshow(native_canvas, fnum=1, doclf=True, figtitle='Native Sampling')
>>> plt.gcf().tight_layout()
>>> ######
>>> # Resample the same item, but without native sampling for comparison
>>> print(ub.urepr(self.summarize_item(rescaled_item), nl=-1, sort=0))
>>> builder = BatchVisualizationBuilder(
>>>     rescaled_item, rescaled_item_output, classes=self.classes,
>>>     requested_tasks=requested_tasks,
>>>     default_combinable_channels=self.default_combinable_channels,
>>>     combinable_extra=combinable_extra, rescale=rescale, draw_weights=draw_weights)
>>> builder.max_channels = 4
>>> builder.overlay_on_image = 0
>>> rescaled_canvas = builder.build()
>>> kwplot.imshow(rescaled_canvas, fnum=2, doclf=True, figtitle='Rescaled Sampling')
>>> plt.gcf().tight_layout()
>>> ######
>>> from geowatch.tasks.fusion.datamodules.batch_visualization import _debug_sample_in_context
>>> _debug_sample_in_context(self, target)
>>> kwplot.show_if_requested()
- geowatch.tasks.fusion.datamodules.batch_visualization.colorize_weights(weights)[source]¶
Normally weights range between 0 and 1, but in some cases they may be larger. We handle this by coloring the 0-1 range in grayscale and the 1-to-infinity range in color.
Example
>>> from geowatch.tasks.fusion.datamodules.batch_visualization import * # NOQA
>>> import kwarray
>>> weights = kwimage.gaussian_patch((32, 32))
>>> weights = kwarray.normalize(weights)
>>> weights[:16, :16] *= 10
>>> weights[16:, :16] *= 100
>>> weights[16:, 16:] *= 1000
>>> weights[:16, 16:] *= 10000
>>> canvas = colorize_weights(weights)
>>> # xdoctest: +REQUIRES(--show)
>>> canvas = kwimage.imresize(canvas, dsize=(512, 512), interpolation='nearest').clip(0, 1)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-10', org=(1, 1), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-100', org=(1, 256), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-1000', org=(256, 256), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-10000', org=(256, 1), border=True)
>>> import kwplot
>>> kwplot.plt.ion()
>>> kwplot.imshow(canvas)
Example
>>> from geowatch.tasks.fusion.datamodules.batch_visualization import * # NOQA
>>> import kwarray
>>> weights = kwimage.gaussian_patch((32, 32))
>>> n = 512
>>> weight_rows = [
>>>     np.linspace(0, 1, n),
>>>     np.linspace(0, 10, n),
>>>     np.linspace(0, 100, n),
>>>     np.linspace(0, 1000, n),
>>>     np.linspace(0, 2000, n),
>>>     np.linspace(0, 5000, n),
>>>     np.linspace(0, 8000, n),
>>>     np.linspace(0, 10000, n),
>>>     np.linspace(0, 100000, n),
>>>     np.linspace(0, 1000000, n),
>>> ]
>>> canvas = np.array([colorize_weights(row[None, :])[0] for row in weight_rows])
>>> # xdoctest: +REQUIRES(--show)
>>> canvas = kwimage.imresize(canvas, dsize=(512, 512), interpolation='nearest').clip(0, 1)
>>> p = int(512 / len(weight_rows))
>>> canvas = kwimage.draw_text_on_image(canvas, '0-1', org=(1, 1 + p * 0), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-10', org=(1, 1 + p * 1), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-100', org=(1, 1 + p * 2), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-1000', org=(1, 1 + p * 3), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-2000', org=(1, 1 + p * 4), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-5000', org=(1, 1 + p * 5), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-8000', org=(1, 1 + p * 6), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-10000', org=(1, 1 + p * 7), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-100000', org=(1, 1 + p * 8), border=True)
>>> canvas = kwimage.draw_text_on_image(canvas, '0-1000000', org=(1, 1 + p * 9), border=True)
>>> import kwplot
>>> kwplot.plt.ion()
>>> kwplot.imshow(canvas)