# Source code for kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset

"""
Defines :class:`KWCocoVideoDataset`, a torch Dataset for kwcoco image and video
data.

The configurable input parameters are defined in
:class:`KWCocoVideoDatasetConfig`, which is used to resolve kwargs passed to
the main :class:`KWCocoVideoDataset` class. These parameters give the developer
fine-grained control over how sampling is done. At the most basic level the
developer should specify:

    * window_space_scale - the scale (i.e. resolution) of the (possibly virtual) space in which the sliding window is placed to build the sample grid. The size of the window itself is set by window_dims.

    * input_space_scale - the scale of the sampled inputs (defaults to the window space scale, but can be different).

    * time_kernel - or alternatively time_sampling / time_dims - to indicate how many frames are sampled over time and how they are distributed.

The following doctests provide a crash course on what sort of sampling
parameters are available.

CommandLine:
    xdoctest -m kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset __doc__:0 --show
    xdoctest -m kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset __doc__:1 --show

Example:
    >>> # Basic Data Sampling
    >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
    >>> import ndsampler
    >>> import kwcoco
    >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes1', num_frames=10)
    >>> sampler = ndsampler.CocoSampler(coco_dset)
    >>> self = KWCocoVideoDataset(sampler, time_dims=4, window_dims=(300, 300),
    >>>                           channels='r|g|b')
    >>> self.disable_augmenter = True
    >>> index = self.sample_grid['targets'][self.sample_grid['positives_indexes'][0]]
    >>> item = self[index]
    >>> # Summarize batch item in text
    >>> summary = self.summarize_item(item)
    >>> print('item summary: ' + ub.urepr(summary, nl=2))
    >>> # Draw batch item
    >>> canvas = self.draw_item(item)
    >>> # xdoctest: +REQUIRES(--show)
    >>> import kwplot
    >>> kwplot.autompl()
    >>> kwplot.imshow(canvas)
    >>> kwplot.show_if_requested()

Example:
    >>> # Basic Data Sampling
    >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
    >>> import ndsampler
    >>> import kwcoco
    >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes1', num_frames=10)
    >>> sampler = ndsampler.CocoSampler(coco_dset)
    >>> self = KWCocoVideoDataset(sampler, window_dims='full', channels='r|g|b')
    >>> self.disable_augmenter = True
    >>> index = self.sample_grid['targets'][self.sample_grid['positives_indexes'][0]]
    >>> item = self[index]
    >>> # Summarize batch item in text
    >>> summary = self.summarize_item(item)
    >>> print('item summary: ' + ub.urepr(summary, nl=2))
    >>> # Draw batch item
    >>> canvas = self.draw_item(item)
    >>> # xdoctest: +REQUIRES(--show)
    >>> import kwplot
    >>> kwplot.autompl()
    >>> kwplot.imshow(canvas)
    >>> kwplot.show_if_requested()


Example:
    >>> # Demo toy data without augmentation
    >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
    >>> import kwcoco
    >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes2-multispectral', num_frames=10)
    >>> channels = 'B10,B8a|B1,B8'
    >>> self = KWCocoVideoDataset(coco_dset, time_dims=4, window_dims=(300, 300),
    >>>                           channels=channels,
    >>>                           input_space_scale='native',
    >>>                           output_space_scale=None,
    >>>                           window_space_scale=1.2,
    >>>                           augment_space_shift_rate=0.5,
    >>>                           use_grid_negatives=False,
    >>>                           use_grid_positives=False,
    >>>                           use_centered_positives=True,
    >>>                           absolute_weighting=True,
    >>>                           time_sampling='uniform',
    >>>                           time_kernel='-1year,0,1month,1year',
    >>>                           modality_dropout=0.5,
    >>>                           channel_dropout=0.5,
    >>>                           temporal_dropout=0.7,
    >>>                           temporal_dropout_rate=1.0)
    >>> # Add weights to annots
    >>> annots = self.sampler.dset.annots()
    >>> annots.set('weight', 2 + np.random.rand(len(annots)) * 10)
    >>> self.disable_augmenter = False
    >>> index = self.sample_grid['targets'][self.sample_grid['positives_indexes'][3]]
    >>> item = self[index]
    >>> summary = self.summarize_item(item)
    >>> print('item summary: ' + ub.urepr(summary, nl=3))
    >>> canvas = self.draw_item(item, overlay_on_image=0, rescale=0, max_dim=1024)
    >>> # xdoctest: +REQUIRES(--show)
    >>> import kwplot
    >>> kwplot.autompl()
    >>> kwplot.imshow(canvas)
    >>> kwplot.show_if_requested()

Example:
    >>> # Demo toy data with augmentation
    >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
    >>> import kwcoco
    >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes2-multispectral', num_frames=10)
    >>> channels = 'B10,B8a|B1,B8'
    >>> self = KWCocoVideoDataset(coco_dset, time_dims=3, window_dims=(300, 300),
    >>>                           channels=channels,
    >>>                           input_space_scale='native',
    >>>                           output_space_scale=None,
    >>>                           window_space_scale=1.2,
    >>>                           time_sampling='soft2+distribute',
    >>>                           time_kernel='-1y,0,1y',
    >>>                           modality_dropout=0.5,
    >>>                           temporal_dropout=0.5)
    >>> assert not self.disable_augmenter
    >>> index = self.sample_grid['targets'][self.sample_grid['positives_indexes'][3]]
    >>> item = self[index]
    >>> assert item['target']['allow_augment']
    >>> print('item summary: ' + ub.urepr(self.summarize_item(item), nl=3))
    >>> canvas = self.draw_item(item, overlay_on_image=0, rescale=0)
    >>> # xdoctest: +REQUIRES(--show)
    >>> import kwplot
    >>> kwplot.autompl()
    >>> kwplot.imshow(canvas)
    >>> kwplot.show_if_requested()


SeeAlso:

    * For notes on spaces, see: ~/code/kwcoco_dataloader/docs/source/manual/development/coding_conventions.rst

Known Issues
------------
- [ ] FIXME: sensorchan codes should exclude non-specified sensors immediately before temporal sampling. Currently temporal sampling is given everything. E.g. (L8,S2):red|green|blue should not allow WV to be included in sampling.


Roadmap
-------

- [ ] Get external feedback and suggestions.
- [ ] Accept albumentations json or more concise spec for custom augmentation
- [ ] Optimize fixed channel case.
- [ ] Optimize fixed image size case.
- [ ] Optimize fixed video size case.
- [ ] Optimize the spacetime grid sampler.
- [ ] Allow input resolution to be specified as a fixed pixel size.
- [ ] Don't force computation of width / height when window_space_dims is "full".


Ignore:
    # For developers, to extract a copy of this dataloader that does not depend
    # on the rest of kwcoco_dataloader, you can attempt to "liberate" it:
    from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import KWCocoVideoDataset
    import liberator
    lib = liberator.Liberator()
    lib.add_dynamic(KWCocoVideoDataset)
    lib.expand(['kwcoco_dataloader'])
    text = lib.current_sourcecode()
    print(ub.highlight_code(text))
    num_lines = text.count(chr(10))
    print(f'num_lines={num_lines}')

"""
import copy
import einops
import kwarray
import kwcoco
import kwimage
import kwutil
import ndsampler
import numpy as np
import os
import pandas as pd
import rich
import scriptconfig as scfg
import torch
import ubelt as ub
import warnings

from os import getenv
from kwutil import util_time
from shapely.ops import unary_union
from torch.utils import data
from typing import Dict
from typing import NamedTuple

from kwcoco_dataloader import heuristics
from kwcoco_dataloader.utils import kwcoco_extensions
from kwcoco_dataloader.utils import util_kwarray
from kwcoco_dataloader.utils import util_kwimage
from kwcoco_dataloader.tasks.fusion.datamodules import util_positional_encoding
from kwcoco_dataloader.tasks.fusion.datamodules import data_utils
from kwcoco_dataloader.tasks.fusion.datamodules import balanced_sampling
from kwcoco_dataloader.tasks.fusion.datamodules import spacetime_grid_builder
from kwcoco_dataloader.tasks.fusion.datamodules.data_augment import SpacetimeAugmentMixin
from kwcoco_dataloader.tasks.fusion.datamodules.smart_mixins import SMARTDataMixin
from kwcoco_dataloader.tasks.fusion.datamodules.network_io import HeterogeneousBatchItem
from kwcoco_dataloader.tasks.fusion.datamodules.network_io import HomogeneousBatchItem
from kwcoco_dataloader.tasks.fusion.datamodules.network_io import RGBImageBatchItem
from kwcoco_dataloader.tasks.fusion.datamodules.batch_visualization import BatchVisualizationBuilder
from kwcoco_dataloader.tasks.fusion.datamodules.robust_normalizer import RobustNormalizer
from kwcoco_dataloader.tasks.fusion.datamodules.dynamic_channel_handler import DynamicChannels

from delayed_image.channel_spec import FusedChannelSpec
from delayed_image.channel_spec import ChannelSpec
from delayed_image.sensorchan_spec import SensorChanSpec
from delayed_image.sensorchan_spec import FusedSensorChanSpec  # NOQA
from delayed_image.sensorchan_spec import SensorSpec  # NOQA


try:
    from functools import cache
except ImportError:
    from ubelt import memoize as cache


# Labels used to organize the options declared in KWCocoVideoDatasetConfig.
# Options in that config pass one of these as their ``group=`` argument so
# that related settings are listed together. They are listed here in the same
# order as the corresponding sections of the config class.
SPACE_GROUP = 'spacetime (space)'      # window dims and window/input/output scale
TIME_GROUP = 'spacetime (time)'        # time steps, time sampling, span, kernel
SELECTION_GROUP = 'selection'          # image / video subset queries
SAMPLE_GROUP = 'sampling'              # spacetime-grid construction and balancing
NORM_GROUP = 'normalization'           # input normalization options
WEIGHT_GROUP = 'weighting'             # dilation, weighting and class-raster options
FILTER_GROUP = 'filtering'             # quality / nodata filtering and masking
AUGMENTATION_GROUP = 'augmentation'    # spatial, temporal and dropout augmentations
MISC_GROUP = 'misc'                    # miscellaneous


class KWCocoVideoDatasetConfig(scfg.DataConfig):
    """
    This is the configuration for a single dataset that could be used for
    train, test, or validation.

    In the future this might be convertible to, or handled by omegaconfig

    The core spacetime parameters are:

        * window_space_scale
        * input_space_scale
        * output_space_scale
        * time_steps
        * time_sampling
        * chip_dims / window_space_dims

    This dataset defines an implicit grid of where it will sample, and it uses
    these "targets" to request data from ndsampler, which is what gives us
    amortized random access to the dataset.

    The logic contained in this class concerns:
        * interacting with the spacetime grids sampler to build a grid
        * target-level augmentations
        * data-level augmentations (need more of these)
        * interacting with ndsampler to read data associated with a grid point
        * balanced sampling over the targets
        * mapping targets to a HeterogeneousBatchItem in the most general case.
    """
    # TODO:
    # 'positive_labels': scfg.Value(None, help=ub.paragraph(
    #     '''
    #     Labels to consider positive (in addition to inferred labels)
    #     ''')),

    sampler_backend = scfg.Value(None, help="Can be None, 'npy', or 'cog'.")

    sampler_workdir = scfg.Value(None, help="A location the sampler can write a cache if a backend is selected.")

    sampler_workers = scfg.Value('avail/2', help="Number of workers to precompute a sampler backend.")

    ###############
    # SPACE OPTIONS
    ###############

    chip_dims = scfg.Value(128, alias=['window_space_dims', 'window_dims', 'chip_size'], group=SPACE_GROUP, help=ub.paragraph(
        '''
        The spatial window dimension (i.e. width and height) used to sample
        from the images. This is the window that is "slid over" images in the
        dataset when building the spacetime grid.  If given as a number it is
        used as both width and height. Can also be a width, height tuple.  Can
        also be a string code. Valid codes are "full", which always read the
        entire image.

        NOTE: The main key will change to window_space_dims in the future.
        '''), nargs='+')
    fixed_resolution = scfg.Value(None, group=SPACE_GROUP, help=ub.paragraph(
        '''
        Convenience argument.
        If specified, fixes resolution of window, output, and input space.
        '''))
    window_space_scale = scfg.Value(None, alias=['window_resolution'], group=SPACE_GROUP, help=ub.paragraph(
        '''
        Change the "scale" or resolution of the video space used by the sliding
        window. Note: this modifies the GSD BEFORE the sample window has been
        selected, so the extent and resolution of the data changes. If
        specified as a numeric value then this is applied to as a scale factor.
        (E.g.  setting this to 2 is equivalent to scaling video space by 2).
        For geospatial data where each video has a "target_gsd", then this can
        be set to as an absolute by including the "GSD" suffix. (e.g. If this
        is set to "10GSD", then video space will be scaled to match).
        '''))
    input_space_scale = scfg.Value(None, alias=['space_scale', 'data_space_scale', 'input_resolution'], group=SPACE_GROUP, help=ub.paragraph(
        '''
        Change the "scale" or resolution of the sampled video space.  Note:
        this modifies the GSD AFTER the sample window has been selected, so the
        extend of the data does NOT change, but the resolution does. If
        specified as a numeric value then this is applied to as a scale factor.
        (E.g. setting this to 2 is equivalent to scaling video space by 2). For
        geospatial data where each video has a "target_gsd", then this can be
        set to as an absolute by including the "GSD" suffix. (e.g. If this is
        set to "10GSD", then video space will be scaled to match). This can
        also be set to "native" to use heterogeneous sampling.
        '''))
    output_space_scale = scfg.Value(None, alias=['target_space_scale', 'output_resolution'], group=SPACE_GROUP, help=ub.paragraph(
        '''
        Change the "scale" or resolution of the desired target
        resolution. Follows other GSD / scale semantics.
        '''))
    chip_overlap = scfg.Value(0.0, alias=['window_space_overlap', 'window_overlap'], group=SPACE_GROUP, help=ub.paragraph(
        '''
        Fraction of the spatial sliding window that will overlap.
        Only applies to training dataset when used in the data
        module.
        '''))

    dynamic_fixed_resolution = scfg.Value(None, type=str, help=ub.paragraph(
        '''
        Experimental. Added in 0.17.1
        This is a test time option. The idea is that we will modify the input,
        output, and window scale for large videos.
        example value: {'max_winspace_full_dims': [1000, 1000]}
        '''))

    ##############
    # TIME OPTIONS
    ##############

    time_steps = scfg.Value(2, alias=['time_dims'], group=TIME_GROUP, help=ub.paragraph(
        '''
        number of temporal samples (i.e. frames) per batch.
        NOTE: The default of this will change to 1 in the future.
        '''))
    time_sampling = scfg.Value('contiguous', type=str, group=TIME_GROUP, help=ub.paragraph(
        '''
        Strategy for expanding the time window across non-contiguous
        frames. Can be auto, contiguous, hard+distribute, or
        dilate_affinity
        '''))
    time_span = scfg.Value(None, group=TIME_GROUP, help=ub.paragraph(
        '''
        Roughly how much time should be between sample frames. This
        argument needs reworking.
        '''))
    time_kernel = scfg.Value(None, type=str, group=TIME_GROUP, help='Mutually exclusive with time_span.')

    ##############
    # MODE OPTIONS
    ##############

    channels = scfg.Value(None, type=str, group='sensorchan', help=ub.paragraph(
        '''
        channels to use should be SensorChanSpec coercible
        '''))
    include_sensors = scfg.Value(None, group='sensorchan', help=ub.paragraph(
        '''
        if specified can be comma separated valid sensors. NOTE:
        this should be specified via a sensorchan speci in channels
        instead
        '''))
    exclude_sensors = scfg.Value(None, type=str, group='sensorchan', help=ub.paragraph(
        '''
        comma delimited list of sensors to avoid, such as S2 or L8
        '''))
    dynamic_channels = scfg.Value(None, type=str, group='sensorchan', help=ub.paragraph(
        '''
        A YAML list to dynamically compute additional channels.
        Each item should be a dict with a key "name" for the name of the new
        channel and a key "expr" with the formula to compute the new channel
        from existing channels. Basic numpy expressions are supported.

        See :class:`DynamicChannels` for more details.
        An example that negates the red channel is:
        ``[{'name': 'negative_red', 'expr': '-red'}]``.
        '''))

    ##############
    # SIZE OPTIONS
    ##############

    select_images = scfg.Value(None, type=str, group=SELECTION_GROUP, help=ub.paragraph(
        '''
        A json query (via the jq spec) that specifies which images belong in
        the subset. Note, this is a passed as the body of the following jq
        query format string to filter valid ids '.images[] |
        select({select_images}) | .id'. Examples for this argument are as
        follows: '.id < 3' will select all image ids less than 3. '.file_name |
        test(".*png")' will select only images with file names that end with
        png.  '.file_name | test(".*png") | not' will select only images with
        file names that do not end with png. '.myattr == "foo"' will select
        only image dictionaries where the value of myattr is "foo". '.id < 3
        and (.file_name | test(".*png"))' will select only images with id less
        than 3 that are also pngs. .myattr | in({"val1": 1, "val4": 1}) will
        take images where myattr is either val1 or val4. Requires the "jq"
        python library is installed.
        '''))
    select_videos = scfg.Value(None, group=SELECTION_GROUP, help=ub.paragraph(
        '''
        A json query (via the jq spec) that specifies which videos belong in
        the subset. Note, this is a passed as the body of the following jq
        query format string to filter valid ids '.videos[] |
        select({select_images}) | .id'. Examples for this argument are as
        follows: '.name | startswith("foo")' will select only videos where the
        name starts with foo. Only applicable for dataset that contain videos.
        Requires the "jq" python library is installed.
        '''))

    # FIXME:
    # This needs to be reworked.
    # It is really the "maximum number of items per epoch".
    # Note: that is number of ITEMS, NOT number of BATCHES!
    # The number of batches is this divided by batch size
    # and the effective batch size is this divided by (batch size * accum)
    # And we really shouldn't specify the maximum here, we should just force it
    # to a specific length, and lean into sampling with replacement.
    max_epoch_length = scfg.Value(None, help=ub.paragraph(
        '''
        If specified, restricts number of ITEMS per epoch
        '''), alias=['max_items_per_epoch'])

    #######################
    # SAMPLING GRID OPTIONS
    #######################

    # TODO: add alias for set_cover_algo to be something more intuitive,
    # maybe grid_spacetime_setcover_algo, or grid_spacetime_sample_density
    # something like that...
    set_cover_algo = scfg.Value(None, group=SAMPLE_GROUP, help=ub.paragraph(
        '''
        Set cover algorithm to remove redundant gids when building space time
        targets. Options are 'approx' (a greedy solution) or 'exact' (an ILP
        solution). If None is passed, set cover is not computed. The 'exact'
        method requires the pulp package (and can be very slow so it is
        generally not recommended).
        '''), choices=[None, 'approx', 'exact'])
    use_centered_positives = scfg.Value(False, group=SAMPLE_GROUP, help=ub.paragraph(
        '''
        Use centers of annotations as window centers Only applies to training
        dataset when used in the data module.  Validation/test dataset defaults
        to False.
        '''))
    use_grid_positives = scfg.Value(True, group=SAMPLE_GROUP, help=ub.paragraph(
        '''
        Use sliding window cells that overlap with positive annotations as
        positives. Only applies to training dataset when used in the data
        module. Validation/test dataset defaults to True.
        '''))
    use_grid_negatives = scfg.Value(True, group=SAMPLE_GROUP, help=ub.paragraph(
        '''
        Use sliding window cells dont overlap with positive annotations as
        negatives. If set to "cleared", then only videos with a True "cleared"
        attribute contribute grid negatives. Only applies to training dataset
        when used in the data module. Validation/test dataset defaults to True.
        '''))
    use_grid_valid_regions = scfg.Value(True, group=SAMPLE_GROUP, help=ub.paragraph(
        '''
        If True, the initial grid will only place windows in valid regions.
        '''))
    neg_to_pos_ratio = scfg.Value(1.0, type=float, group=SAMPLE_GROUP, help=ub.paragraph(
        '''
        maximum ratio of samples with no annotations to samples with annots.
        Only applies to training dataset when used in the data module.
        Validation/test dataset defaults to zero.

        DEPRECATED: Use "balance_options" instead.
        To reproduce this with "balance_options" use:
        balance_options : [
            {attribute: contains_annotation, weights: {True: 0.5, False: 0.5}}
        ]
        '''))

    balance_options = scfg.Value(None, group=SAMPLE_GROUP, help=ub.paragraph(
        '''
        A YAML configuration that determines how to balance across discrete
        samples based on their annotation content. It should be specified as a
        list of dictionaries. Each dictionary must specify "attribute" as the
        name of the attribute to balance across. Each dictionary can optionally
        specify "weight" as a mapping from attribute values to a numeric weight
        indicating the relative importance of sampling an attribute with that
        value. A "default_weight" can be specified for attribute values that
        are not given. The order of the dictionaries matters. The first item
        will be perfectly balanced, everything else will be balanced with
        respect to the previous balancing. New in 0.17.0.
        '''))

    num_balance_trees = scfg.Value(16, group=SAMPLE_GROUP, help=ub.paragraph(
        '''
        The number of trees used to balance samples.
        This is useful only in the multi-label case where each sample may
        contain examples of multiple categories / attributes that are balanced
        over. In the case where each window can only contain one object this
        can be safely set to 1.
        '''))

    use_grid_cache = scfg.Value(True, group=SAMPLE_GROUP, help=ub.paragraph(
        '''
        If true, will cache the spacetime grid to make multiple runs quicker.
        '''))

    failed_sample_policy = scfg.Value('warn', choices=['ignore', 'raise', 'warn'], group=SAMPLE_GROUP, help=ub.paragraph(
        '''
        What to do if sampling fails, either ignore or raise an error
        '''))

    ############################
    # DATA NORMALIZATION OPTIONS
    ############################

    prenormalize_inputs = scfg.Value(None, group=NORM_GROUP, help=ub.paragraph(
        '''
        Can specified as list of dictionaries that effectively contains the
        dataset statistics to use. Details of that will be documented as the
        feature matures. See the kwcoco_dataloader.cli.coco_spectra script to help
        determine reasonable values for this. These normalizations are applied
        at the dataloader getitem level. This should be specified as a list of
        dictionaries each containing: * mean: * std: * min: * max: As well as
        the Modality to which the normalization applies, e.g.: * domain *
        channels * sensor If set to True, then we try to automatically compute
        these values. New in 0.4.3.
        '''))
    normalize_perframe = scfg.Value(False, group=NORM_GROUP, help=ub.paragraph(
        '''
        DEPRECATED. Use robust_normalize instead.
        Applies a pre-normalizaiton that normalizes each frame by itself. This
        is not recommended unless you have a larger chip size because there
        needs to be enough data within a frame for the normalization to be
        effective.
        '''))
    normalize_peritem = scfg.Value(None, group=NORM_GROUP, help=ub.paragraph(
        '''
        DEPRECATED. Use robust_normalize instead.
        Applies a pre-normalization across all frames in an item.  This
        preserves relative temporal variations.
        If True all channels are normalized this way with a default robust
        normalizer. Can be specified as a ChannelSpec, and in this case the
        default robust normalizer will only be applied to these channels.
        '''))

    robust_normalize = scfg.Value(None, group=NORM_GROUP, help=ub.paragraph(
        '''
        A YAML list of robust normalization parameters that provides
        fine-grained control over how groups of sensor / channel items are
        normalized within a batch.  For full details see
        :func:`RobustNormalizer.coerce`.

        An example to robustly normalize per-frame "r" and "g" values in
        different ways is:
        ``{'separate_time': True,
            'groups': [
                {'channels': 'r', 'mode': 'linear'},
                {'channels': 'g', 'mode': 'sigmoid', 'high': 0.91, 'low': 0.33}]}``
        '''))

    ###################
    # WEIGHTING OPTIONS
    ###################

    ignore_dilate = scfg.Value(0, group=WEIGHT_GROUP, help='Dilation applied to ignore masks.')
    weight_dilate = scfg.Value(0, group=WEIGHT_GROUP, help='Dilation applied to weight masks.')
    absolute_weighting = scfg.Value(False, group=WEIGHT_GROUP, help=ub.paragraph(
        '''
        if True allow weights to be larger than 1, otherwise item
        weights are rescaled.
        '''))
    min_spacetime_weight = scfg.Value(0.9, group=WEIGHT_GROUP, help=ub.paragraph(
        '''
        Minimum space-time dilation weight. Used in conjunction with
        '''))
    upweight_centers = scfg.Value(True, group=WEIGHT_GROUP, help=ub.paragraph(
        '''
        Applies a weighting such that the center of the frame incurs
        more loss.
        '''))
    upweight_time = scfg.Value(None, group=WEIGHT_GROUP, help=ub.paragraph(
        '''
        A number between 0.0 and 1.0 representing where to upweight
        time the most (1.0 is last frame 0.0 is the first frame).
        '''))
    dist_weights = scfg.Value(0, group=WEIGHT_GROUP, help=ub.paragraph(
        '''
        To use distance-transform based weights on annotations or not
        '''))
    balance_areas = scfg.Value(False, group=WEIGHT_GROUP, help=ub.paragraph(
        '''
        if True balance the weight of small and large polygons
        '''))

    default_class_behavior = scfg.Value('background', group=WEIGHT_GROUP, help=ub.paragraph(
        '''
        Toggles between new and old behavior for what value to use for the
        class truth index raster. Can be "background" for old behavior, which
        ensures that there is a predictable background class and initializes
        non-annotated areas with that value.  Alternatively, can be "ignore"
        which fills the index truth with a negative value indicating that those
        regions should be ignored.  The default of this param may change in the
        future.
        '''))

    ##################################
    # DYNAMIC FILTER / MASKING OPTIONS
    ##################################

    use_cloudmask = scfg.Value(None, group=FILTER_GROUP, help=ub.paragraph(
        '''
        Allow the dataloader to use the quality band to skip frames.
        DEPRECATED: set quality_threshold=0 to disable the cloudmask. Set to a
        positive value to use it, up to that threshold.
        '''))
    quality_threshold = scfg.Value(0.0, group=FILTER_GROUP, help=ub.paragraph(
        '''
        The minimum fraction of usable pixels required in a frame sample. If a
        frame has fewer than this fraction of usable pixels (i.e. not clouds or
        other quality flags), it is marked for resampling as a "bad" frame.
        '''))
    mask_low_quality = scfg.Value(False, group=FILTER_GROUP, help=ub.paragraph(
        '''
        if True, mask low quality pixels with nans
        '''))
    mask_nan_bands = scfg.Value('', group=FILTER_GROUP, help=ub.paragraph(
        '''
        Channels that propagate their nans to other bands / streams.
        This should be FusedChannelSpec coercible.
        '''))
    mask_samecolor_method = scfg.Value(None, group=FILTER_GROUP, help=ub.paragraph(
        '''
        If enabled, set as method to use for
        SAMECOLOR_QUALITY_HEURISTIC. Can be histogram or region.
        '''))
    mask_samecolor_bands = scfg.Value('red', group=FILTER_GROUP, help=ub.paragraph(
        '''
        Channels to use for SAMECOLOR_QUALITY_HEURISTIC. This should be
        FusedChannelSpec coercible.
        '''))
    mask_samecolor_values = scfg.Value(0, group=FILTER_GROUP, help=ub.paragraph(
        '''
        List of values to use for SAMECOLOR_QUALITY_HEURISTIC.
        Can be an integer or list of integers
        '''))
    force_bad_frames = scfg.Value(False, group=FILTER_GROUP, help=ub.paragraph(
        '''
        if True, force loading, even if data is nan / missing
        '''))
    observable_threshold = scfg.Value(0.0, group=FILTER_GROUP, help=ub.paragraph(
        '''
        The minimum fraction of non-nan pixels required in a frame sample. If a
        frame has fewer than this fraction of usable pixels (i.e. not clouds or
        other quality flags), it is marked for resampling as a "bad" frame.
        '''))

    # Toggle for reducing the loss contribution of unobservable (nan) pixels.
    downweight_nan_regions = scfg.Value(True, group=FILTER_GROUP, help=ub.paragraph(
        '''
        if True, unobservable (i.e. nan) pixels are downweighted
        '''))

    # Also accepted on the command line under the alias ``resample_max_tries``.
    resample_invalid_frames = scfg.Value(3, alias=['resample_max_tries'], group=FILTER_GROUP, help=ub.paragraph(
        '''
        Number of attempts to resample any frame marked as invalid via quality
        or nodata checks.
        '''))

    ######################
    # AUGMENTATION OPTIONS
    ######################
    ### TODO: these should likely become a nested jsonargparse
    ### style config for a more general "augmentation scheme".

    # See: ./data_augment.py

    # Spatial augmentations (only applied in fit mode per the help text).
    augment_space_shift_rate = scfg.Value(0.9, group=AUGMENTATION_GROUP, help=ub.paragraph(
        '''
        In fit mode, perform translation augmentations in this
        fraction of batch items.
        '''))
    augment_space_xflip = scfg.Value(True, group=AUGMENTATION_GROUP, help=ub.paragraph(
        '''
        In fit mode, if true, perform random x-flips
        '''))
    augment_space_yflip = scfg.Value(True, group=AUGMENTATION_GROUP, help=ub.paragraph(
        '''
        In fit mode, if true, perform random y-flips
        '''))
    augment_space_rot = scfg.Value(True, group=AUGMENTATION_GROUP, help=ub.paragraph(
        '''
        In fit mode, if true, perform random 90 degree rotations
        '''))
    augment_time_resample_rate = scfg.Value(0.8, group=AUGMENTATION_GROUP, help=ub.paragraph(
        '''
        In fit mode, perform temporal jitter this fraction of batch items.
        '''))
    # Temporal dropout is two-staged: ``temporal_dropout_rate`` selects which
    # batch items are eligible, and ``temporal_dropout`` is the per-frame drop
    # probability within a selected item. With the defaults (1.0 and 0.0)
    # every item is eligible but no frame is actually dropped.
    temporal_dropout_rate = scfg.Value(1.0, type=float, group=AUGMENTATION_GROUP, help=ub.paragraph(
        '''
        Drops frames in a fraction of batch items.
        '''))
    temporal_dropout = scfg.Value(0.0, type=float, group=AUGMENTATION_GROUP, help=ub.paragraph(
        '''
        Given that a batch item is selected for temporal dropout,
        this is the probability that each frame is dropped out. The
        main frame is never removed.
        '''))
    # Modality dropout follows the same "item rate" / "per-element
    # probability" pairing as temporal dropout; both default to 0.0 (off).
    modality_dropout_rate = scfg.Value(0.0, type=float, group=AUGMENTATION_GROUP, help=ub.paragraph(
        '''
        The fraction of batch-items modality dropout is applied to.
        '''))
    modality_dropout = scfg.Value(0.0, type=float, group=AUGMENTATION_GROUP, help=ub.paragraph(
        '''
        Drops late-fused modalities in each frame with this
        probability, except if the frame only has one modality left.
        '''))

    # TODO: specify channels that are allowed to be dropped out?
    # Channel dropout uses the same two-stage pairing; both default to 0.0 (off).
    channel_dropout_rate = scfg.Value(0.0, type=float, group=AUGMENTATION_GROUP, help=ub.paragraph(
        '''
        The fraction of batch-items channel dropout is applied to.
        '''))
    channel_dropout = scfg.Value(0.0, type=float, group=AUGMENTATION_GROUP, help=ub.paragraph(
        '''
        Drops early-fused channels within each modality with this probability
        except if it is the last channel within a modality.
        '''))

    # TODO:
    # 'metadata_dropout': scfg.Value(0.0, type=float, help=ub.paragraph(
    #     '''
    #     Drops extra metadata provided to the model for positional encodings.
    #     '''), group=AUGMENTATION_GROUP),

    # TODO:
    # 'augment_scale': scfg.Value(0.0, type=float, help=ub.paragraph(
    #     '''
    #     Train at multiple scales.
    #     '''), group=AUGMENTATION_GROUP),

    # See the help text: the balanced sampler ignores the requested index in
    # non-test modes, so each worker / ddp clone needs its own random stream.
    reseed_fit_random_generators = scfg.Value(True, type=bool, group=AUGMENTATION_GROUP, help=ub.paragraph(
        '''
        This option forces the dataloader random number generator to reseed
        itself, effectively ignoring any global seed in non- test mode. In test
        mode, this has no effect. The reason this defaults to True is because
        of our balanced sampling approach, where the index of a sample passed
        to getitem is ignored and we randomly return an item according to the
        balanced distribution. This relies on randomness and if this was set to
        False dataloader clones for ddp or multiple workers would generate the
        same sequence of data regardless of split indexes.
        '''))

    ##############
    # MISC OPTIONS
    ##############
    reduce_item_size = scfg.Value(False, group=MISC_GROUP, help=ub.paragraph(
        '''
        Introduced as a CIFAR optimization to prevent generated items from
        containing more information than is necessary. In the future this will
        likely be restructured so items are produced with the minimal amount of
        information by default, and there must be a request to grab the
        enriched variant.
        '''))

    # NOTE(review): unlike the other options in this section, ``output_type``
    # and ``requested_tasks`` declare no ``group=``; confirm whether they are
    # meant to be listed under MISC_GROUP.
    output_type = scfg.Value('heterogeneous', help=ub.paragraph(
        '''
        Can be heterogeneous, homogeneous, or rgb. This is a performance
        parameter that allows implementation assumptions to be made.
        Experimental in 0.18.4
        '''))

    requested_tasks = scfg.Value('auto', help=ub.paragraph(
        '''
        If auto, uses heuristics to define what task targets are generated.
        Otherwise, can be a YAML dict that updates the defaults.
        '''))

    def __post_init__(self):
        if isinstance(self['exclude_sensors'], str):
            self['exclude_sensors'] = [s.strip() for s in self['exclude_sensors'].split(',')]
        self['time_steps'] = int(self['time_steps'])

        if self['chip_dims'] is not None:
            arg = self['chip_dims']
            if isinstance(arg, str):
                if ',' in arg:
                    p1, p2 = arg.split(',')
                    arg = [int(p1), int(p2)]
            if isinstance(arg, list):
                assert len(arg) == 2, 'arglist should be len 2'
                arg = [int(arg[0]), int(arg[1])]
            if isinstance(arg, int):
                arg = [arg, arg]
            self['chip_dims'] = arg

        if self['mask_samecolor_method'] == 'None':
            self['mask_samecolor_method'] = None

        if self['fixed_resolution'] not in {None, 'None', 'none', 'null'}:
            self['window_space_scale'] = self['fixed_resolution']
            self['input_space_scale'] = self['fixed_resolution']
            self['output_space_scale'] = self['fixed_resolution']

        if self['input_space_scale'] in {None, 'None', 'window'}:
            self['input_space_scale'] = self['window_space_scale']

        if self['output_space_scale'] is {None, 'None', 'input'}:
            self['output_space_scale'] = self['input_space_scale']

        if self['output_space_scale'] == 'window':
            self['output_space_scale'] = self['window_space_scale']

        if self['time_sampling'] == 'auto':
            self['time_sampling'] = 'hard+distribute'

        if self['use_cloudmask'] is not None:
            if not self['use_cloudmask']:
                self['quality_threshold'] = 0


class TruthMixin:
    """
    Methods related to drawing truth rasters / training objectives

    CommandLine:
        LINE_PROFILE=1 xdoctest -m kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset TruthMixin:0 --bench

    Example:
        >>> # xdoctest: +REQUIRES(--bench)
        >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import KWCocoVideoDataset
        >>> import ndsampler
        >>> import kwcoco
        >>> import ubelt as ub
        >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes2', num_frames=10)
        >>> sampler = ndsampler.CocoSampler(coco_dset)
        >>> self = KWCocoVideoDataset(sampler, mode="fit", time_dims=4, window_dims=(196, 196),
        >>>                           channels='r|g|b', neg_to_pos_ratio=0)
        >>> for index in ub.ProgIter(range(1000)):
        >>>     self.getitem(index)
    """

    def _prepare_truth_info(self, final_gids, gid_to_sample, num_frames, target, target_):
        """
        Helper used to construct information about the truth before we start
        constructing the frames. This handles contextual relabeling of classes
        (i.e. if all frames show post construction relabel it as background).

        note: `target` is the original input value (guaranteed to not be
        modified from user input) and `target_` is the resolved variant we (may
        have modified).

        Args:
            final_gids (List[int]): image ids of the sampled frames, in the
                order they are later enumerated as time indexes.
            gid_to_sample (Dict[int, Dict]): maps each image id to its stream
                sample (a mapping of mode -> mode sample). At least one mode
                sample per frame must contain an ``'annots'`` entry.
            num_frames (int): number of frames, used to build time weights.
            target (Dict): the original target; only used in error messages.
            target_ (Dict): the resolved target; its ``dist_weights`` flag is
                forwarded into the result.

        Returns:
            Dict: with keys ``task_tid_to_cnames`` (per task, per track id, a
            list of relabeled category names, one per frame), ``gid_to_dets``
            (the per-frame detections, enriched with ``tids`` and
            ``weights``), ``time_weights``, and ``dist_weights``.

        Raises:
            AssertionError: if no mode sample of a frame contains annotations.
        """
        # build up info about the tracks
        dset = self.sampler.dset
        gid_to_dets: Dict[int, kwimage.Detections] = {}
        for gid in final_gids:
            stream_sample = gid_to_sample[gid]
            frame_dets = None
            for mode_sample in stream_sample.values():
                if 'annots' in mode_sample:
                    frame_dets: kwimage.Detections = mode_sample['annots']['frame_dets'][0]
                    break
            if frame_dets is None:
                raise AssertionError(ub.paragraph(
                    f'''
                    Did not sample correctly.
                    Please send this info to jon.crall@kitware.com:
                    {dset=!r}
                    {gid=!r}
                    {target=!r}
                    {target_=!r}
                    '''
                ))
            # The return detections will live in the "input/data" space
            gid_to_dets[gid] = frame_dets

        # Enrich each frame's detections with track ids and per-annotation
        # weights, and collect every track id seen anywhere in the sample.
        # A dict is used as an insertion-ordered set of track ids.
        all_tids = {}
        for gid, frame_dets in gid_to_dets.items():
            aids = frame_dets.data['aids']
            frame_annots = dset.annots(aids)
            tids = frame_annots.lookup('track_id', None)
            frame_dets.data['tids'] = tids
            frame_dets.data['weights'] = frame_annots.lookup('weight', 1.0)
            for tid in tids:
                all_tids[tid] = None

        # For every track, record its category id in each frame (None when
        # the track does not appear in that frame).
        tid_to_frame_cids = ub.ddict(list)
        for gid, frame_dets in gid_to_dets.items():
            cids = frame_dets.data['cids']
            tids = frame_dets.data['tids']
            frame_tid_to_cid = ub.dzip(tids, cids)
            for tid in all_tids:
                cid = frame_tid_to_cid.get(tid, None)
                tid_to_frame_cids[tid].append(cid)

        # TODO: be more efficient at this
        tid_to_frame_cnames = ub.map_vals(
            lambda cids: list(ub.take(self.classes.id_to_node, cids, None)),
            tid_to_frame_cids
        )

        task_tid_to_cnames = {
            'saliency': {},
            'class': {},
        }
        for tid, cnames in tid_to_frame_cnames.items():
            task_tid_to_cnames['class'][tid] = heuristics.hack_track_categories(cnames, 'class')
            task_tid_to_cnames['saliency'][tid] = heuristics.hack_track_categories(cnames, 'saliency')

        if self.config['upweight_centers'] or self.config['upweight_time'] is not None:
            if self.config['upweight_time'] is None:
                upweight_time = 0.5
            else:
                upweight_time = self.config['upweight_time']

            # Learn more from the center of the space-time patch
            time_weights = util_kwarray.biased_1d_weights(upweight_time, num_frames)

            time_weights = time_weights / time_weights.max()
            time_weights = time_weights.clip(0, 1)
            time_weights = np.maximum(time_weights, self.config['min_spacetime_weight'])
        else:
            time_weights = 1

        truth_info = {
            'task_tid_to_cnames': task_tid_to_cnames,
            'gid_to_dets': gid_to_dets,
            'time_weights': time_weights,
            'dist_weights': target_.get('dist_weights', False),
        }
        return truth_info

    def _populate_frame_labels(self, frame_item, gid, output_dsize, time_idx,
                               mode_to_invalid_mask, resolution_info,
                               truth_info, meta_info):
        """
        Enrich a ``frame_item`` with rasterized truth-labels.

        No return value ``frame_item`` is modified inplace.

        Helper function to populate truth labels for a frame in a video
        sequence.

        Args:
            frame_item (Dict): the frame dictionary to update. Depending on
                ``self.requested_tasks`` this writes ``nonlocal_class_ohe``,
                the ``box_*`` keys, ``ann_aids``, ``class_idxs``,
                ``class_ohe``, ``class_weights``, ``saliency``,
                ``saliency_weights``, and ``output_weights``.
            gid (int): image id of the frame (a key of
                ``truth_info['gid_to_dets']``).
            output_dsize (Tuple[int, int]): width / height of the output
                rasters; reversed to obtain the raster shape.
            time_idx (int): index of this frame within the sample, used to
                look up the per-frame relabeled category of each track.
            mode_to_invalid_mask (Dict): forwarded to
                ``_build_generic_frame_weights``.
            resolution_info (Dict): must contain ``common_input_scale`` and
                ``common_output_scale``.
            truth_info (Dict): the result of :func:`_prepare_truth_info`.
            meta_info (Dict): forwarded to ``_build_generic_frame_weights``.

        Raises:
            AssertionError: if the frame has no detections or an annotation
                has a saliency category not covered by the salient / ignore /
                non-salient class sets.
            NotImplementedError: if the input scale is constant but the
                output scale is native.

        TODO:
            - [ ] Reduce number of input parameters such that this function is
                  ammenable to a MWE doctest. This was factored out of the
                  original getitem, and could use work to reduce the number of
                  input params.
        """

        common_input_scale = resolution_info['common_input_scale']
        common_output_scale = resolution_info['common_output_scale']

        # The frame detections will be in a scaled videos space the
        # constant scale case.
        # TODO: will need special handling for "native" resolutions on
        # a per-mode / frame basis, we will need the concept of an
        # annotation window (where ndsampler lets us assume the corners
        # of each window are in correspondence)

        task_tid_to_cnames = truth_info['task_tid_to_cnames']
        gid_to_dets = truth_info['gid_to_dets']

        wants_saliency = self.requested_tasks['saliency']
        wants_class = self.requested_tasks['class']
        wants_change = self.requested_tasks['change']
        wants_boxes = self.requested_tasks['boxes']
        wants_nonlocal_class = self.requested_tasks['nonlocal_class']

        wants_class_sseg = wants_class or wants_change
        wants_saliency_sseg = wants_saliency
        wants_any_sseg = wants_saliency_sseg or wants_class_sseg
        wants_any_localization = wants_boxes or wants_any_sseg

        input_is_native = (isinstance(common_input_scale, str) and common_input_scale == 'native')
        output_is_native = (isinstance(common_output_scale, str) and common_output_scale == 'native')

        frame_dets = gid_to_dets[gid]
        if frame_dets is None:
            raise AssertionError('frame_dets = {!r}'.format(frame_dets))

        if wants_any_localization:
            # As of ndsampler >= 0.7.1 the dets are sampled in the input space
            # NOTE: except for the native/native case, the detections stored
            # in ``truth_info`` are rescaled in place here.
            if input_is_native:
                if output_is_native:
                    # Both scales are native, use detections as-is.
                    dets = frame_dets.copy()
                else:
                    # Input scale is native, but output scale is given,
                    # Need to resize. We enriched the dets with metadata
                    # to do this earlier.
                    annot_input_dsize = frame_dets.meta['input_dims'][::-1]
                    dets_scale = output_dsize / annot_input_dsize
                    dets = frame_dets.scale(dets_scale, inplace=True)
            else:
                if output_is_native:
                    raise NotImplementedError(
                        'input scale is constant and output scale is native. '
                        'no logic for this case yet.'
                    )
                else:
                    # Simple case where input/output scales are constant
                    dets_scale = common_output_scale / common_input_scale
                    dets = frame_dets.scale(dets_scale, inplace=True)
        else:
            dets = frame_dets

        # Create truth masks
        if self.config['default_class_behavior'] == 'background':
            default_class_index = self.bg_idx
        else:
            default_class_index = self.ignore_index

        # Normalize to a tuple so ``(C,) + space_shape`` below is tuple
        # concatenation (an ndarray here would broadcast-add instead).
        space_shape = tuple(output_dsize[::-1])

        if wants_class_sseg:
            frame_cidxs = np.full(space_shape, dtype=np.int32,
                                  fill_value=default_class_index)

        # A "Salient" class is anything that is a foreground class
        task_target_ohe = {}
        task_target_ignore = {}
        task_target_weight = {}

        # Rasterize frame targets into semantic segmentation masks
        ann_aids    = dets.data['aids']
        ann_cids    = dets.data['cids']
        ann_tids    = dets.data['tids']
        ann_weights = dets.data['weights']
        ann_boxes   = dets.data['boxes']

        if wants_any_sseg:
            ann_polys = dets.data['segmentations'].to_polygon_list()

            missing_poly_flags = [poly is None for poly in ann_polys]
            if any(missing_poly_flags):
                # Note: this will convert boxes into box-polygons, which is
                # generally a non-optimial segmentation target objective.  We
                # might do better by having a "policy" to control implicit
                # conversion of boxes to segmentation masks.  we could use an
                # ellipse and downweight edges, which might be more suitable
                # for general use-cases. We may also want to downweight the
                # entire polygon itself as it is a weak segmentation.
                missing_idxs = np.where(missing_poly_flags)[0]
                _box_polys = ann_boxes[missing_idxs].to_polygons()
                for idx, _poly in zip(missing_idxs, _box_polys):
                    ann_polys.data[idx] = _poly

            # Associate weights with polygons
            for poly, weight in zip(ann_polys, ann_weights):
                if weight is None:
                    weight = 1.0
                if poly is not None:
                    poly.meta['weight'] = weight

        if wants_any_localization:
            frame_box = kwimage.Box.from_dsize(space_shape[::-1])
            frame_box = frame_box.to_shapely()

        if wants_nonlocal_class:
            # Create an indicator vector that just says if the category appears
            # in the frame or not. This will help support simple classification
            # networks.
            nonlocal_class_ohe = np.zeros((self.num_predictable_classes,), dtype=np.uint8)
            for cid in ann_cids:
                cname = self.classes.id_to_node[cid]
                if cname in self.predictable_classes.node_to_idx:
                    cidx = self.predictable_classes.node_to_idx[cname]
                    nonlocal_class_ohe[cidx] = 1
            frame_item['nonlocal_class_ohe'] = nonlocal_class_ohe

        # Note: it is important to respect class indexes, ids, and
        # name mappings
        if wants_boxes:
            ann_ltrb = ann_boxes.to_ltrb().data

            box_labels = {
                'box_ltrb': [],
                # 'box_tids': [],
                'box_cidxs': [],
                'box_class_weights': [],
                'box_saliency_weights': [],
            }
            # Do we want saliency boxes and class boxes?
            for ltrb, cid, tid in zip(ann_ltrb, ann_cids, ann_tids):
                new_salient_catname = task_tid_to_cnames['saliency'][tid][time_idx]
                new_class_catname = task_tid_to_cnames['class'][tid][time_idx]
                new_class_cidx = self.classes.node_to_idx[new_class_catname]
                box_labels['box_ltrb'].append(ltrb)
                # box_labels['box_tids'].append(-1 if tid is None else tid)
                box_labels['box_cidxs'].append(new_class_cidx)
                box_labels['box_saliency_weights'].append(
                    float(new_salient_catname in self.salient_classes))
                box_labels['box_class_weights'].append(
                    float(new_class_catname in self.class_foreground_classes))
            box_labels['box_ltrb'] = np.array(box_labels['box_ltrb']).astype(np.float32)
            # box_labels['box_tids'] = np.array(box_labels['box_tids']).astype(np.int64)
            box_labels['box_cidxs'] = np.array(box_labels['box_cidxs']).astype(np.int64)
            box_labels['box_class_weights'] = np.array(box_labels['box_class_weights']).astype(np.float32)
            box_labels['box_saliency_weights'] = np.array(box_labels['box_saliency_weights']).astype(np.float32)
            frame_item.update(box_labels)

        if wants_saliency:
            ### Build single frame SALIENCY target labels and weights
            task_target_ohe['saliency'] = np.zeros(space_shape, dtype=np.uint8)
            task_target_ignore['saliency'] = np.zeros(space_shape, dtype=np.uint8)
            task_target_weight['saliency'] = np.empty(space_shape, dtype=np.float32)

            # Group polygons into foreground / background for the saliency task
            saliency_sseg_groups = {
                'foreground': [],
                'background': [],
                'ignore': [],
            }
            for poly, tid in zip(ann_polys, ann_tids):
                new_salient_catname = task_tid_to_cnames['saliency'][tid][time_idx]
                if new_salient_catname in self.salient_classes:
                    saliency_sseg_groups['foreground'].append(poly)
                elif new_salient_catname in self.salient_ignore_classes:
                    saliency_sseg_groups['ignore'].append(poly)
                elif new_salient_catname in self.non_salient_classes:
                    saliency_sseg_groups['background'].append(poly)
                else:
                    raise AssertionError(
                        'Saliency category {!r} is not salient, ignored, or '
                        'non-salient'.format(new_salient_catname))

            balance_areas = self.config['balance_areas']

            if balance_areas:
                import shapely
                # num_fg_polys = len(saliency_sseg_groups['foreground'])
                big_poly_fg = unary_union([p.to_shapely() for p in saliency_sseg_groups['foreground']])
                big_poly_ignore = unary_union([p.to_shapely() for p in saliency_sseg_groups['ignore']])
                try:
                    big_poly_bg = (frame_box - big_poly_fg) - big_poly_ignore
                except shapely.errors.GEOSException:
                    print('Warning: cannot balance areas, shapely error')
                    task_target_weight['saliency'][:] = 1
                    balance_areas = False
                else:
                    #unit_area_share = fg_polys.area / len(fg_polys)
                    total_area = frame_box.area
                    bg_cover_frac = big_poly_bg.area / (total_area + 1)
                    # fg_cover_frac = big_poly_fg.area / (total_area + 1)
                    bg_weight_share = (1 - bg_cover_frac)
                    task_target_weight['saliency'][:] = bg_weight_share ** 0.5
            else:
                task_target_weight['saliency'][:] = 1

            for poly in saliency_sseg_groups['background']:
                weight = poly.meta['weight']
                if balance_areas:
                    area_weight = (1 - (poly.area / (total_area + 1)))
                    weight = weight * area_weight
                if weight != 1:
                    poly.fill(task_target_weight['saliency'], value=weight, assert_inplace=True)

            for poly in saliency_sseg_groups['foreground']:
                task_target_ohe['saliency'] = poly.fill(task_target_ohe['saliency'], value=1, assert_inplace=True)
                weight = poly.meta['weight']

                if balance_areas:
                    area_weight = (1 - (poly.area / (total_area + 1)))
                    weight = weight * area_weight
                if weight != 1:
                    poly.fill(task_target_weight['saliency'], value=weight, assert_inplace=True)

            if truth_info['dist_weights']:
                # New feature where we encode that we care much more about
                # segmenting the inside of the object than the outside.
                # Effectively boundaries become uncertain.
                # ---
                # handle distance weight transform in one go, might want to
                # bring in prior code wrt to balance areas and polygon specific
                # weights here so overlapping polygons dont clobber each other.
                # (i.e. the order of the polygons matter when they overlap, and this
                # logic works around that to some degree. The real way to
                # handle this would be a layering system.)
                dist, poly_mask = util_kwimage.multiple_polygon_distance_transform_weighting(
                    saliency_sseg_groups['foreground'], shape=space_shape)
                max_dist = dist.max()
                if max_dist > 0:
                    dist_weight = dist / max_dist
                    weight_mask = dist_weight + (1 - poly_mask)
                    task_target_weight['saliency'] = task_target_weight['saliency'] * weight_mask

            for poly in saliency_sseg_groups['ignore']:
                #poly.fill(task_target_ohe['saliency'], value=1, assert_inplace=True)
                poly.fill(task_target_ignore['saliency'], value=1, assert_inplace=True)

            if not self.config['absolute_weighting']:
                max_weight = task_target_weight['saliency'].max()
                if max_weight > 0:
                    task_target_weight['saliency'] /= max_weight

        if wants_class_sseg:
            ### Build single frame CLASS target labels and weights
            task_target_ohe['class'] = np.zeros((self.num_predictable_classes,) + space_shape, dtype=np.uint8)
            task_target_ignore['class'] = np.zeros(space_shape, dtype=np.uint8)
            task_target_weight['class'] = np.ones(space_shape, dtype=np.float32)

            # Group polygons into foreground / background for the class task.
            # Polygons whose relabeled class falls in none of these groups are
            # not rasterized.
            class_sseg_groups = {
                'foreground': [],
                'background': [],
                'ignore': [],
                'undistinguished': [],
            }
            for poly, cid, tid in zip(ann_polys, ann_cids, ann_tids):
                new_class_catname = task_tid_to_cnames['class'][tid][time_idx]
                new_class_cidx = self.classes.node_to_idx[new_class_catname]
                # orig_cidx = self.classes.id_to_idx[cid]
                poly.meta['new_class_cidx'] = new_class_cidx
                # poly.meta['orig_cidx'] = orig_cidx
                if new_class_catname in self.ignore_classes:
                    class_sseg_groups['ignore'].append(poly)
                elif new_class_catname in self.class_foreground_classes.intersection(self.predictable_classes):
                    class_sseg_groups['foreground'].append(poly)
                elif new_class_catname in self.background_classes.intersection(self.predictable_classes):
                    class_sseg_groups['background'].append(poly)
                elif new_class_catname in self.undistinguished_classes.intersection(self.predictable_classes):
                    class_sseg_groups['undistinguished'].append(poly)

            balance_areas = self.config['balance_areas']

            if balance_areas:
                import shapely
                big_poly_fg = unary_union([p.to_shapely() for p in class_sseg_groups['foreground']])
                big_poly_ignore = unary_union([p.to_shapely() for p in class_sseg_groups['ignore']])
                big_poly_undistinguished = unary_union([p.to_shapely() for p in class_sseg_groups['undistinguished']])
                # Mirror the saliency branch: degenerate geometry must not
                # crash the dataloader, fall back to unbalanced weights.
                try:
                    big_poly_bg = ((frame_box - big_poly_fg) - big_poly_ignore) - big_poly_undistinguished
                except shapely.errors.GEOSException:
                    print('Warning: cannot balance areas, shapely error')
                    task_target_weight['class'][:] = 1
                    balance_areas = False
                else:
                    total_area = frame_box.area
                    bg_cover_frac = big_poly_bg.area / (total_area + 1)
                    # fg_cover_frac = big_poly_fg.area / (total_area + 1)
                    bg_weight_share = (1 - bg_cover_frac)
                    task_target_weight['class'][:] = bg_weight_share ** 0.5
            else:
                task_target_weight['class'][:] = 1

            for poly in class_sseg_groups['ignore']:
                poly.fill(task_target_ignore['class'], value=1, assert_inplace=True)

            for poly in class_sseg_groups['background']:
                idx = self.dataset_class_idx_to_predictable_class_idx[poly.meta['new_class_cidx']]
                poly.fill(task_target_ohe['class'][idx], value=1, assert_inplace=True)

            for poly in class_sseg_groups['undistinguished']:
                task_target_ignore['class'] = poly.fill(task_target_ignore['class'], value=1, assert_inplace=True)
                idx = self.dataset_class_idx_to_predictable_class_idx[poly.meta['new_class_cidx']]
                poly.fill(task_target_ohe['class'][idx], value=1, assert_inplace=True)

            for poly in class_sseg_groups['foreground']:
                idx = self.dataset_class_idx_to_predictable_class_idx[poly.meta['new_class_cidx']]
                poly.fill(task_target_ohe['class'][idx], value=1, assert_inplace=True)
                weight = poly.meta['weight']

                if balance_areas:
                    area_weight = (1 - (poly.area / (total_area + 1)))
                    weight = weight * area_weight

                if weight != 1:
                    poly.fill(task_target_weight['class'], value=weight, assert_inplace=True)

            if truth_info['dist_weights']:
                # New feature where we encode that we care much more about
                # segmenting the inside of the object than the outside.
                # Effectively boundaries become uncertain.
                # ---
                # handle distance weight transform in one go, might want to
                # bring in prior code wrt to balance areas and polygon specific
                # weights here so overlapping polygons dont clobber each other.
                # (i.e. the order of the polygons matter when they overlap, and this
                # logic works around that to some degree. The real way to
                # handle this would be a layering system.)
                dist, poly_mask = util_kwimage.multiple_polygon_distance_transform_weighting(
                    class_sseg_groups['foreground'], shape=space_shape)
                max_dist = dist.max()
                if max_dist > 0:
                    dist_weight = dist / max_dist
                    weight_mask = dist_weight + (1 - poly_mask)
                    task_target_weight['class'] = task_target_weight['class'] * weight_mask

            if not self.config['absolute_weighting']:
                max_weight = task_target_weight['class'].max()
                if max_weight > 0:
                    task_target_weight['class'] /= max_weight

        generic_frame_weight = self._build_generic_frame_weights(output_dsize,
                                                                 mode_to_invalid_mask,
                                                                 meta_info,
                                                                 time_idx)

        # Dilate ignore masks (dont care about the surrounding area # either)
        # frame_saliency = kwimage.morphology(frame_saliency, 'dilate', kernel=ignore_dilate)
        if self.config['ignore_dilate'] > 0:
            for k, v in task_target_ignore.items():
                task_target_ignore[k] = kwimage.morphology(v, 'dilate', kernel=self.config['ignore_dilate'])

        if self.config['weight_dilate'] > 0:
            for k, v in task_target_weight.items():
                task_target_weight[k] = kwimage.morphology(v, 'dilate', kernel=self.config['weight_dilate'])

        frame_item['ann_aids'] = ann_aids
        if wants_class_sseg:
            # Postprocess (Dilate?) the truth map
            for cidx, class_map in enumerate(task_target_ohe['class']):
                # class_map = kwimage.morphology(class_map, 'dilate', kernel=5)
                frame_cidxs[class_map > 0] = cidx

            task_frame_weight = (
                (1 - task_target_ignore['class']) *
                task_target_weight['class'] *
                generic_frame_weight
            )
            # TODO: no need to pass class-cidxs if class-ohe is present.
            # TODO: add metadata to the frame item to indicate the channel
            # ordering of each dimension (or used xarray / named tensors when
            # they become supported)
            frame_item['class_idxs'] = frame_cidxs
            frame_item['class_ohe'] = einops.rearrange(task_target_ohe['class'], 'c h w -> h w c')
            frame_item['class_weights'] = np.clip(task_frame_weight, 0, None)

        if wants_saliency_sseg:
            task_frame_weight = (
                (1 - task_target_ignore['saliency']) *
                task_target_weight['saliency'] *
                generic_frame_weight
            )
            frame_item['saliency'] = task_target_ohe['saliency']
            frame_item['saliency_weights'] = np.clip(task_frame_weight, 0, None)

        wants_outputs = self.requested_tasks['outputs']
        if wants_outputs:
            frame_item['output_weights'] = generic_frame_weight

class GetItemMixin(TruthMixin):
    """
    This mixin defines what is needed for the getitem method.
    """

    def _prepare_meta_info(self, num_frames):
        """
        Build metadata that is shared by every frame of a single item.

        Args:
            num_frames (int): number of frames in the sample.

        Returns:
            Dict: a dictionary with the key ``'time_weights'``. When neither
            ``upweight_centers`` nor ``upweight_time`` is configured this is
            the scalar 1. Otherwise it is the result of
            ``util_kwarray.biased_1d_weights`` (indexed by frame position
            downstream) rescaled so its maximum is 1, clipped to [0, 1], and
            floored at ``min_spacetime_weight``.
        """
        config = self.config
        requested_bias = config['upweight_time']
        if not config['upweight_centers'] and requested_bias is None:
            return {'time_weights': 1}

        # Learn more from the center of the space-time patch; 0.5 is the
        # bias used when only ``upweight_centers`` is enabled.
        bias = 0.5 if requested_bias is None else requested_bias
        raw_weights = util_kwarray.biased_1d_weights(bias, num_frames)
        normalized = (raw_weights / raw_weights.max()).clip(0, 1)
        floored = np.maximum(normalized, config['min_spacetime_weight'])
        return {'time_weights': floored}

    def _build_frame_items(self, final_gids, gid_to_sample,
                           truth_info, meta_info, resolution_info):
        """
        Assemble one item dictionary per sampled frame.

        Args:
            final_gids (List[int]): image ids in the order they are emitted;
                the position in this list becomes each frame's ``time_index``.
            gid_to_sample (Dict[int, Dict[str, Dict]]): for each image id, a
                mapping from mode key to its loaded sample. Each sample must
                provide ``'im'`` and may provide ``'invalid_mask'``.
            truth_info (Dict): forwarded to ``_populate_frame_labels``.
            meta_info (Dict): forwarded to ``_populate_frame_labels`` and
                ``_build_generic_frame_weights``.
            resolution_info (Dict): must provide ``'common_outspace_box'``,
                ``'vidspace_dsize'``, ``'vidspace_box'`` and ``'video_dsize'``.

        Returns:
            List[Dict]:
                A dictionary for each frame containing metadata, input tensors
                (``'modes'``: float32 channel-first arrays per mode), and,
                unless ``inference_only`` is set, truth tensors for the frame.
        """
        common_outspace_box = resolution_info['common_outspace_box']
        vidspace_dsize = resolution_info['vidspace_dsize']
        vidspace_box = resolution_info['vidspace_box']
        video_dsize = resolution_info['video_dsize']

        imgs = self.sampler.dset.index.imgs
        # TODO: handle all augmentation before we construct any labels
        frame_items = []
        for time_idx, gid in enumerate(final_gids):
            img = imgs[gid]
            stream_sample = gid_to_sample[gid]
            assert len(stream_sample) > 0, 'should have at least one stream'

            # Per-mode image data, invalid masks and (w, h) sizes.
            modes, invalid_masks, dsizes = {}, {}, {}
            for key, sample in stream_sample.items():
                raw_invalid = sample.get('invalid_mask', None)
                # Force float32 and keep a channel axis even for single-band
                # data, then move channels first for pytorch.
                hwc = kwarray.atleast_nd(np.asarray(sample['im'][0], dtype=np.float32), 3)
                modes[key] = einops.rearrange(hwc, 'h w c -> c h w')
                invalid_masks[key] = None if raw_invalid is None else raw_invalid[0]
                dsizes[key] = (hwc.shape[1], hwc.shape[0])

            # The truth resolution is pinned to the input data: either a box
            # shared by all frames, or (native case) the frame's largest mode
            # rescaled from video space.
            if common_outspace_box is not None:
                frame_outspace_box = common_outspace_box
            else:
                largest_dsize = np.array(max(dsizes.values(), key=np.prod))
                frame_outspace_box = vidspace_box.scale(
                    largest_dsize / vidspace_dsize).quantize(inplace=True)
            output_dsize = frame_outspace_box.dsize
            output_dims = output_dsize[::-1]  # the size we want to predict

            date_captured = img.get('date_captured', None)
            timestamp = (util_time.coerce_datetime(date_captured).timestamp()
                         if date_captured else np.nan)

            frame_item = {
                'gid': gid,
                'date_captured': img.get('date_captured', ''),
                'timestamp': timestamp,
                'time_index': time_idx,
                'sensor': img.get('sensor_coarse', img.get('sensor', '*')),
                'modes': modes,
                'change': None,
                'class_idxs_ignore_index': self.ignore_index,
                'class_idxs': None,
                'class_ohe': None,
                'saliency': None,
                'change_weights': None,
                'class_weights': None,
                'saliency_weights': None,
                'output_dims': output_dims,
            }

            if not self.config['reduce_item_size']:
                scale_outspace_from_vid = output_dsize / np.array(vidspace_dsize)
                # Box for the full video-sized image this output would be
                # embedded in, expressed at the output scale.
                outimg_box = kwimage.Box.from_dsize(
                    video_dsize * scale_outspace_from_vid).quantize(inplace=True)
                # Info describing how to construct each head's output.
                frame_item['change_output_dims'] = None if time_idx == 0 else output_dims
                frame_item['class_output_dims'] = output_dims
                frame_item['saliency_output_dims'] = output_dims
                frame_item['output_space_slice'] = frame_outspace_box.to_slice()
                frame_item['output_image_dsize'] = outimg_box.dsize
                frame_item['scale_outspace_from_vid'] = scale_outspace_from_vid
                frame_item['ann_aids'] = None

            if not self.inference_only:
                # Build single-frame truth
                self._populate_frame_labels(
                    frame_item, gid, output_dsize, time_idx,
                    invalid_masks, resolution_info, truth_info, meta_info)

            # Fall back to generic weights when labeling did not provide them
            # (e.g. in inference-only mode).
            if self.requested_tasks['outputs'] and 'output_weights' not in frame_item:
                frame_item['output_weights'] = self._build_generic_frame_weights(
                    output_dsize, invalid_masks, meta_info, time_idx)

            frame_items.append(frame_item)
        return frame_items

    def _build_generic_frame_weights(self, output_dsize, mode_to_invalid_mask, meta_info, time_idx):
        """
        Compute task-agnostic per-pixel weights for one frame.

        The result is the product of two optional factors:

        * space-time (``upweight_centers``): ``_space_weights`` of the output
          shape floored at ``min_spacetime_weight``, times
          ``meta_info['time_weights'][time_idx]``;
        * nodata (``downweight_nan_regions``): one minus the invalid fraction
          per pixel, averaged over all entries of ``mode_to_invalid_mask``
          (entries whose mask is None contribute zero).

        Args:
            output_dsize (Tuple[int, int]): output size as (width, height).
            mode_to_invalid_mask (Dict[str, ndarray | None]): per-mode invalid
                masks, 2D or 3D (a trailing band axis is averaged out).
            meta_info (Dict): as produced by ``_prepare_meta_info``.
            time_idx (int): position of this frame within the sample.

        Returns:
            ndarray | int: the combined weights, or the scalar 1 when neither
            option is enabled.

        Ignore:
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import KWCocoVideoDataset
            >>> import ndsampler
            >>> import kwcoco_dataloader
            >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes2', num_frames=10)
            >>> sampler = ndsampler.CocoSampler(coco_dset)
            >>> self = KWCocoVideoDataset(sampler, mode="fit", time_dims=4, window_dims=(196, 196),
            >>>                           channels='r|g|b', neg_to_pos_ratio=0, autobuild=False, upweight_centers=True)
            >>> # Setup example inputs this function might be run with
            >>> output_dsize = (196, 196)
            >>> mode_to_invalid_mask = {'r|g|b': None}
            >>> meta_info = {'time_weights': np.array([0.9, 1. , 1. , 0.9])}
            >>> time_idx = 0
            >>> generic_frame_weight = self._build_generic_frame_weights(output_dsize, mode_to_invalid_mask, meta_info, time_idx)
        """
        config = self.config

        if config['upweight_centers']:
            frame_time_weights = meta_info['time_weights']
            space_shape = output_dsize[::-1]
            center_weights = np.maximum(_space_weights(space_shape),
                                        config['min_spacetime_weight'])
            spacetime_weights = center_weights * frame_time_weights[time_idx]
        else:
            spacetime_weights = 1

        if not config['downweight_nan_regions']:
            nodata_weight = 1
        else:
            # Resample every mode's invalid fraction into the output space and
            # accumulate it.
            nodata_total = 0.0
            for mask in mode_to_invalid_mask.values():
                if mask is None:
                    continue
                per_pixel = mask.mean(axis=2) if mask.ndim == 3 else mask.astype(float)
                nodata_total += kwimage.imresize(per_pixel, dsize=output_dsize)
            # The divisor is the number of modes (not bands), including modes
            # that had no mask.
            nodata_weight = 1 - nodata_total / len(mode_to_invalid_mask)

        return nodata_weight * spacetime_weights

    def _populate_positional_information(self, frame_items):
        """
        Enrich each frame with information the model can use to build its
        positional encodings. It currently returns these, but it shouldn't.

        Args:
            frame_items (List[Dict]):
                Frames as produced by ``_build_frame_items``. Each frame's
                ``'timestamp'``, ``'time_index'``, ``'sensor'`` and ``'modes'``
                entries are read, and a ``'time_offset'`` entry (a float32
                array of shape ``(1,)`` holding the seconds elapsed since the
                previous frame, 0 for the first frame) is written back in place.

        Returns:
            Dict[str, torch.Tensor]:
                One row per (frame, mode) pair, in frame order, with keys
                ``'mode_tensor'`` (hash of the mode code), ``'time_index'``
                (ordinal encoding of the frame index), ``'sensor'`` (hash of
                the frame's sensor name) and ``'time_offset'`` (after
                ``data_utils.abslog_scaling``). If no frame has any mode, the
                result is empty and a diagnostic message is printed.

        NOTE:
            There is a part of this where we actually compute a sinusoidal
            positional encoding, but that should be part of the model, not the
            dataset.

            The dataset can provide metadata to tell the model what it can use
            to build positional encodings, but it should never tell it how to
            use them!
        """
        # TODO: what is the standard way to do the learned embedding
        # "input vector"?

        # TODO: preprocess any auxiliary learnable information into a
        # Tensor. It is likely ideal to pre-stack whenever possible, but we
        # need to keep the row-form data to make visualization
        # straight-forward. We could use a flag to toggle it depending on
        # if we need to visualize or not.
        permode_datas = ub.ddict(list)
        prev_timestamp = None

        # TODO: this should be part of the model.
        # The dataloader should know nothing about positional encodings
        # except what is needed in order to pass the data to the model.
        time_index_encoding = util_positional_encoding.ordinal_position_encoding(len(frame_items), 8).numpy()

        for frame_item in frame_items:
            # Difference the timestamps in float64 and only then cast to
            # float32: epoch seconds (~1.7e9) have a resolution of 128 seconds
            # in float32, so casting first corrupts the offset.
            frame_timestamp = np.array([frame_item['timestamp']], dtype=np.float64)
            if prev_timestamp is None:
                time_offset = np.array([0]).astype(np.float32)
            else:
                time_offset = (frame_timestamp - prev_timestamp).astype(np.float32)
            # TODO: add seasonal positional encoding

            # Per-frame values shared by every mode. Computing them outside the
            # mode loop keeps ``time_offset`` defined (and current) even for a
            # frame that has no modes.
            time_index_vec = time_index_encoding[frame_item['time_index']]
            # Hash the frame's sensor value; hashing the key name 'sensor'
            # would make this feature identical for every sensor.
            sensor_vec = data_utils._string_to_hashvec(str(frame_item['sensor']))

            for mode_code in frame_item['modes']:
                # Maybe this should be a model responsibility.
                # I dont like defining the positional encoding in the
                # dataset
                permode_datas['mode_tensor'].append(data_utils._string_to_hashvec(mode_code))
                permode_datas['time_index'].append(time_index_vec)
                permode_datas['time_offset'].append(time_offset)
                permode_datas['sensor'].append(sensor_vec)

            frame_item['time_offset'] = time_offset
            prev_timestamp = frame_timestamp

        positional_arrays = ub.map_vals(np.stack, permode_datas)
        time_offset = positional_arrays.pop('time_offset', None)
        if time_offset is not None:
            scaled_time_offset = data_utils.abslog_scaling(time_offset)
            positional_arrays['time_offset'] = scaled_time_offset
        else:
            print('NONE TIME OFFSET: {}'.format(list(permode_datas.keys())))

        # This is flattened for each frame for each mode.
        # A bit hacky, not in love with it.
        positional_tensors = ub.map_vals(torch.from_numpy, positional_arrays)
        return positional_tensors

    def _coerce_target(self, index):
        """
        Returns a target dictionary given an index or an explicit target dictionary.

        Args:
            index (int | Dict):
                Either an explicit target dictionary, or an integer index into
                ``self.sample_grid['targets']``. Outside of ``'test'`` mode,
                when a balanced sampler exists, the integer is ignored and a
                grid location is drawn from ``self.balanced_sampler`` instead.

        Returns:
            Dict: a shallow copy of the target with ``'requested_index'`` and
            ``'resolved_index'`` added (both ``'given-as-dictionary'`` when a
            dictionary was passed).

        Raises:
            FailedSample: if the balanced sampler fails, or the resolved target
                is None.
        """
        if isinstance(index, dict):
            target = index
            requested_index = resolved_index = 'given-as-dictionary'
        else:
            requested_index = index
            if self.mode == 'test' or self.balanced_sampler is None:
                # In test-mode the index directly determines the grid location.
                # Without a balancer we also do normal sequential sampling.
                resolved_index = requested_index
            else:
                # In non-test-mode we discard the user index and randomly
                # sample a grid location to achieve balanced sampling.
                try:
                    resolved_index = self.balanced_sampler.sample()
                except Exception as ex:
                    raise FailedSample(f'Failed to sample grid location: {ex=}') from ex
            target = self.sample_grid['targets'][resolved_index]

        # Check before copying; otherwise ``None.copy()`` raises an
        # AttributeError and this FailedSample is never reached.
        if target is None:
            raise FailedSample('no target')

        target = target.copy()
        target['requested_index'] = requested_index
        target['resolved_index'] = resolved_index
        return target

    def _resolve_target(self, target):
        """
        Creates a copy of the target with modified information expected by the
        getitem method. This applies any sampling augmentation if enabled. It
        also handles enriching the target with configuration level information
        if needed. There are other places in the code that do that, and it may
        be better if those are moved here.

        Args:
            target (Dict): the target to resolve; it is not modified.

        Returns:
            Tuple[Dict, Dict, Any]:
                The resolved target copy; the video dictionary (or, when the
                video id is not a known video, the dictionary of the single
                image in the target); and the result of
                ``self._resolve_resolution``.

        Raises:
            NotImplementedError: if the target has neither ``'video_id'`` nor
                ``'gids'``.
        """
        coco_dset = self.sampler.dset
        resolved = target.copy()
        resolved.update({
            'as_xarray': False,
            'legacy_annots': False,
            'legacy_target': False,
        })

        if 'video_id' not in resolved:
            if 'gids' not in resolved:
                raise NotImplementedError('TODO: Check if annot_ids is available to resolve the image id')
            first_gid = ub.peek(resolved['gids'])
            resolved['video_id'] = coco_dset.imgs[first_gid]['video_id']

        video_id = resolved['video_id']
        try:
            video = coco_dset.index.videos[video_id]
        except KeyError:
            # hack for single image datasets
            assert len(resolved['gids']) == 1, 'should have only 1 image id'
            video = coco_dset.index.imgs[resolved['gids'][0]]

        resolution_info = self._resolve_resolution(resolved, video)

        # Resolve per-target parameters, falling back to dataset defaults.
        default_augment = (not self.disable_augmenter) and self.mode == 'fit'
        allow_augment = resolved.get('allow_augment', default_augment)
        resolved['allow_augment'] = allow_augment

        if not self.inference_only:
            resolved['dist_weights'] = resolved.get('dist_weights', self.config['dist_weights'])

        if allow_augment:
            resolved = self._augment_spacetime_target(resolved)
        return resolved, video, resolution_info

    @ub.memoize_method
    def _cached_sample_sensorchan_matching_sensor(self, sensor_coarse):
        """
        Return ``self.input_sensorchan.matching_sensor(sensor_coarse)``.

        Memoized per instance via ``ub.memoize_method`` so repeated frames
        from the same sensor do not redo the lookup.
        """
        return self.input_sensorchan.matching_sensor(sensor_coarse)

    def _sample_one_frame(self, gid, sampler, coco_dset, target_, with_annots,
                          gid_to_isbad, gid_to_sample):
        """
        Core logic that uses the target dictionary to sample a single frame at
        a time via ndsampler. Some post-loading augmentation is also done here.

        Nothing is returned; results are written into ``gid_to_isbad`` and
        ``gid_to_sample``.

        Args:
            gid (int): image id of the frame to sample.
            sampler: the sampler whose ``load_sample`` is called per stream.
            coco_dset: the kwcoco dataset, used for ``coco_image(gid)``.
            target_ (Dict): resolved target. It is copied and narrowed to
                ``gids=[gid]`` before loading. Several config options can be
                overridden per target by keys in it (e.g.
                ``'force_bad_frames'``, ``'quality_threshold'``,
                ``'SAMECOLOR_BANDS'``).
            with_annots: forwarded to ``sampler.load_sample``. It is passed for
                each stream until one loaded sample contains ``'annots'``;
                after that, later streams are loaded with ``with_annots=False``.
            gid_to_isbad (Dict): updated in place. ``gid_to_isbad[gid]`` is
                ``False`` or a string describing why the frame is bad.
            gid_to_sample (Dict): updated in place. ``gid_to_sample[gid]`` maps
                each loaded input stream's channel spec string to its sample
                dictionary, with an added ``'invalid_mask'`` (None when no
                pixel is NaN).
        """
        # helper that was previously a nested function moved out for profiling
        coco_img = coco_dset.coco_image(gid)
        sensor_coarse = coco_img.img.get('sensor_coarse', coco_img.img.get('sensor', '*'))
        matching_sensorchan = self._cached_sample_sensorchan_matching_sensor(sensor_coarse)

        def _ensure_list(x):
            return x if isinstance(x, list) else [x]

        # Each heuristic setting below may be overridden by a key in the
        # target; otherwise it falls back to the dataset config.
        SAMECOLOR_QUALITY_HEURISTIC = target_.get('SAMECOLOR_QUALITY_HEURISTIC', self.config['mask_samecolor_method'])
        SAMECOLOR_BANDS = target_.get('SAMECOLOR_BANDS', FusedChannelSpec.coerce(self.config['mask_samecolor_bands']).as_set())
        SAMECOLOR_VALUES = target_.get('SAMECOLOR_VALUES', _ensure_list(self.config['mask_samecolor_values']))
        use_samecolor_region_method = SAMECOLOR_QUALITY_HEURISTIC == 'region'

        # When bad frames are forced through, keep loading the remaining
        # streams even after a heuristic flags this frame.
        force_bad_frames = target_.get('force_bad_frames', self.config['force_bad_frames'])
        stop_on_bad_image = not force_bad_frames
        quality_threshold = target_.get('quality_threshold', self.config['quality_threshold'])
        observable_threshold = target_.get('observable_threshold', self.config['observable_threshold'])
        mask_low_quality = target_.get('mask_low_quality', self.config['mask_low_quality'])

        # Bands whose NaN pixels are propagated to every stream of the frame.
        PROPAGATE_NAN_BANDS = target_.get('PROPAGATE_NAN_BANDS', FusedChannelSpec.coerce(self.config['mask_nan_bands']).as_set())

        tr_frame = target_.copy()
        tr_frame['gids'] = [gid]

        # TODO: separate ndsampler annotation loading function
        first_with_annot = with_annots

        # Flag will be set to true if any heuristic on any channel stream
        # forces us to mark this image as bad.
        force_bad = False

        # Track pixel positions we will force to nan
        unobservable_mask = data_utils.MultiscaleMask()

        # Handle a special quality band channel.
        if quality_threshold > 0 or mask_low_quality:
            # Flag the frame when the low-quality (e.g. cloudy) fraction
            # exceeds 1 - quality_threshold, i.e. when the good fraction
            # drops below quality_threshold.
            is_low_quality = self._interpret_quality_mask(
                sampler, coco_img, tr_frame)
            if is_low_quality is not None:
                is_low_quality = is_low_quality[0]  # just first frame
                cloud_threshold = (1 - quality_threshold)
                # TODO: account for nodata values here.
                # such that quality threshold is over the valid data
                # observations.
                cloud_frac = is_low_quality.mean()
                if cloud_frac > cloud_threshold:
                    force_bad = 'too cloudy'
                if mask_low_quality:
                    unobservable_mask.update(is_low_quality)
        else:
            is_low_quality = None

        if matching_sensorchan.chans.numel() == 0:
            # NOTE(review): this message reads ``self.sample_sensorchan`` while
            # the lookup above uses ``self.input_sensorchan``; confirm that
            # attribute exists, otherwise building the message raises.
            force_bad = f'Missing requested channels. {sensor_coarse=}, {matching_sensorchan=}, {self.sample_sensorchan=}'

        modality_streams = matching_sensorchan.streams()

        if target_['allow_augment'] and self.config['modality_dropout_rate']:
            # Augment by dropping out modalities, but always keep at least one.
            if self.config['modality_dropout_rate'] > self.augment_rng.rand():
                if self.config['modality_dropout']:
                    keep_score = self.augment_rng.rand(len(modality_streams))
                    keep_idxs = util_kwarray.argsort_threshold(
                        keep_score, self.config['modality_dropout'], num_top=1)
                    modality_streams = list(ub.take(modality_streams, keep_idxs))

        # Sample information from each stream (each stream is a separate mode)
        sample_streams = {}
        for input_stream in modality_streams:
            if stop_on_bad_image and force_bad:
                break

            # Determine what of the channels can be directly loaded versus needs
            # dynamic compute.
            if input_stream.spec in self._special_inputs:
                dynamic_task = self._special_inputs[input_stream.spec]
                # fixme matching_sensorchan could be multiple late fuse channels
                # hack because FusedSensorChanSpec.coerce is broken on 2024-03-31
                sample_stream = SensorChanSpec.coerce(dynamic_task['sample_sensorchan']).streams()[0]
            else:
                sample_stream = input_stream
                dynamic_task = None

            sample_chans = sample_stream.chans
            input_chans = input_stream.chans

            # Load as float32 and pad out-of-bounds regions with NaN so they
            # are picked up by the invalid mask below.
            tr_frame['channels'] = sample_chans
            tr_frame['padkw' ] = {'constant_values': np.nan}
            tr_frame['nodata' ] = 'float'
            tr_frame['dtype'] = np.float32

            if sample_chans.spec == 'unknown-chan':
                # Hack: unknown channels mean that the kwcoco didnt specify
                # them. Thus we should assume there are homogeneous.
                tr_frame.pop('channels', None)

            # FIXME: each kwcoco asset should be able to control its own
            # interpolation as a function of its role.
            sample = sampler.load_sample(
                tr_frame, with_annots=first_with_annot,
            )

            if dynamic_task is not None:
                # NOTE: We can optimize this code quite a bit if it is useful
                _sample_chan_names = sample_chans.to_list()
                # Lookup how to compute each dynamic channel and do it
                _chan_to_idx = {k: i for i, k in enumerate(_sample_chan_names)}
                # Splitting up with direct indexing is much faster than einops
                # and dict 2.3us vs 1.7us. TODO: Can go faster if we know what
                # the requested channels are.
                requested = _chan_to_idx.keys()  # todo: update this to only the needed channels.
                new_channel_lut = {
                    name: sample['im'][..., _chan_to_idx[name]]
                    for name in requested
                }

                # Question: is it more efficient to break up channels into a
                # dictionary here, or if we compute a mapping from the channel
                # name to the index, would it be faster to lookup the channel
                # by slice instead? Maybe avoid a copy? E.g.
                # if 0:
                #     new_channel_lut = {}
                #     for arg_idx, arg_chan in enumerate(_sample_chan_names):
                #         arg_data = sample['im'][..., arg_idx]
                #         new_channel_lut[arg_chan] = arg_data
                #     # But maybe only do this on demand for channels that were
                #     # requreted? Might not matter. Dont premature optimize.

                # Tell DynamicChannels to compute the relevant data.
                dynamic_results = self._dynamic_channels.evaluate(new_channel_lut, dynamic_task['dynamic_chans'])
                new_channel_lut.update(dynamic_results)

                # Construct the input channels as if we could have loaded them
                # all from disk, and then continue. The rest of the code should
                # not care that some of these channels are dynamic.
                new_im = np.stack([new_channel_lut[n] for n in input_chans], axis=3)
                sample['im'] = new_im

            stream_oset = ub.oset(input_chans)
            if SAMECOLOR_QUALITY_HEURISTIC:
                # Update our observable mask based on bands heuristically
                # marked as valid or observable (i.e. rgb bands)
                relevant_bands = stream_oset & SAMECOLOR_BANDS
                if relevant_bands:
                    samecolor_mask = data_utils.samecolor_nodata_mask(
                        input_chans, sample['im'][0], relevant_bands,
                        use_regions=use_samecolor_region_method,
                        samecolor_values=SAMECOLOR_VALUES)
                    unobservable_mask.update(samecolor_mask)

            relevant_bands = stream_oset & PROPAGATE_NAN_BANDS
            for band in relevant_bands:
                # Mark the nans in these bands as unobservable.
                bx = stream_oset.index(band)
                # NOTE: ``band`` is rebound from the channel name to its pixel
                # data here; the loop reassigns it on the next iteration.
                band = sample['im'][0][:, :, bx]
                nodata_mask = np.isnan(band)
                unobservable_mask.update(nodata_mask)

            # The unobservable mask accumulates across streams, so these checks
            # consider everything loaded so far for this frame.
            if unobservable_mask.masked_fraction == 1.0:
                force_bad = 'unobservable sample'
                if stop_on_bad_image:
                    break

            if observable_threshold:
                invalid_frac = unobservable_mask.masked_fraction
                observable_frac = 1 - invalid_frac
                if observable_frac < observable_threshold:
                    force_bad = 'failed observable threshold'
                    if stop_on_bad_image:
                        break

            if target_['allow_augment'] and self.config['channel_dropout']:
                # Augment by NaN-ing out random bands, keeping at least one.
                if self.config['channel_dropout_rate'] > self.augment_rng.rand():
                    num_bands = sample['im'].shape[3]
                    if num_bands > 1:
                        keep_score = self.augment_rng.rand(num_bands)
                        keep_idxs = util_kwarray.argsort_threshold(
                            keep_score, self.config['channel_dropout'],
                            num_top=1)
                        drop_flags = ~kwarray.boolmask(keep_idxs, num_bands)
                        if np.any(drop_flags):
                            sample['im'][:, :, :, drop_flags] = np.nan

            sample_streams[input_chans.spec] = sample
            if 'annots' in sample:
                # dont ask for annotations multiple times
                first_with_annot = False

        coco_video = coco_img.video
        if coco_video is None:
            domain = None
        else:
            # domain = coco_video.get('domain', coco_video.get('name', None))
            domain = coco_video.get('domain', None)

        # After all channels are sampled, apply final invalid mask.
        for input_chans, sample in sample_streams.items():

            if self.prenormalizers is not None:
                # Clamp to [min, max], standardize with mean / std, then squash
                # through a sigmoid, writing the result back into sample['im']
                # in place.
                modality = Modality(channels=input_chans, sensor=sensor_coarse,
                                    domain=domain)
                stats = self.prenormalizers[modality]
                out = sample['im']
                # norm = kwarray.normalize(
                #     sample['im'],
                #     mode='sigmoid',
                #     alpha=stats['std'][None, None, None, :],
                #     beta=stats['mean'][None, None, None, :],
                #     min_val=stats['min'][None, None, None, :],
                #     max_val=stats['max'][None, None, None, :],
                # )
                from scipy.special import expit as sigmoid
                np.minimum(out, stats['max'][None, None, None, :], out=out)
                np.maximum(out, stats['min'][None, None, None, :], out=out)
                presig = (out - stats['mean'][None, None, None, :]) / stats['std'][None, None, None, :]
                norm = sigmoid(presig)
                out[:] = norm

            # Force the accumulated unobservable pixels to NaN, then record a
            # NaN mask only when there is at least one invalid value.
            unobservable_mask.apply(sample['im'][0], np.nan)
            invalid_mask = np.isnan(sample['im'])
            any_invalid = np.any(invalid_mask)
            if any_invalid:
                sample['invalid_mask'] = invalid_mask
            else:
                sample['invalid_mask'] = None

        if not force_bad:
            if len(sample_streams) == 0:
                force_bad = 'no-streams'

        gid_to_isbad[gid] = force_bad
        gid_to_sample[gid] = sample_streams

        # Read from the environment on every call; enabled unless set to a
        # value other than true / 1 / t (case-insensitive).
        HACK_FIX_NATIVE_ANNOT_SIZE = getenv("HACK_FIX_NATIVE_ANNOT_SIZE", "True").lower() in ('true', '1', 't')
        if HACK_FIX_NATIVE_ANNOT_SIZE:
            # When sampling in native resolution, the annotations will be
            # sampled at that resolution. However, when there are multiple
            # modes for a input frame, it becomes unclear which native scale is
            # the right one to sample the annotations in. Thus we find the
            # maximum dimension over all the modes, and then upscale the
            # annotations to match that.
            annot_mode_dims = None
            all_mode_dims = []
            frame_dets = None
            for sample in sample_streams.values():
                mode_dims = sample['im'].shape[1:3]
                if 'annots' in sample:
                    frame_dets = sample['annots']['frame_dets'][0]
                    annot_mode_dims = mode_dims
                all_mode_dims.append(mode_dims)
            if all_mode_dims:
                max_mode_dims = np.array(max(all_mode_dims, key=np.prod))
                if frame_dets is not None:
                    # Dims are (h, w); reverse to the (x, y) order expected by
                    # the detections' scale.
                    fixup_scale = (max_mode_dims / annot_mode_dims)[::-1]
                    frame_dets.scale(fixup_scale, inplace=True)
                    # Save the input dimensions we scaled to.
                    # We will need to transform this to the output dims later.
                    frame_dets.meta['input_dims'] = max_mode_dims

    def getitem(self, index):
        """
        Sample and assemble a single batch item.

        This is the same as ``__getitem__`` except that it raises an error
        when it fails; ``__getitem__`` is responsible for handling that error.

        Args:
            index (int | Dict): index or target

        Returns:
            Dict: the batch item, wrapped in a
            :class:`HeterogeneousBatchItem`, :class:`HomogeneousBatchItem`,
            or :class:`RGBImageBatchItem` depending on the requested
            ``output_type``.

        Raises:
            FailedSample: if the space-time target could not be sampled (the
                resolved target is attached as an exception note).
            Exception: if sampling produced zero frames. Any other unexpected
                sampling error is reported and re-raised unchanged.
            KeyError: if the requested ``output_type`` is unknown.

        CommandLine:
            LINE_PROFILE=1 xdoctest -m kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset GetItemMixin.getitem

        CommandLine:
            xdoctest -m kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset GetItemMixin.getitem --show

        Example:
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
            >>> import kwcoco
            >>> import kwcoco_dataloader
            >>> coco_dset = kwcoco_dataloader.coerce_kwcoco('kwcoco_dataloader-msi-dates-geodata-heatmap', num_frames=5, image_size=(128, 128), num_videos=1)
            >>> # Remove the annotations in the first image to test new time weights
            >>> aids = coco_dset.images().take([0]).annots[0].lookup('id')
            >>> coco_dset.remove_annotations(aids)
            >>> #
            >>> # Each sensor uses all of its own channels
            >>> channels = 'auto'
            >>> self = KWCocoVideoDataset(coco_dset, time_dims=5,
            >>>                           window_resolution='0.09GSD',
            >>>                           input_resolution='0.09GSD',
            >>>                           window_dims=(128, 128),
            >>>                           channels=channels,
            >>>                           balance_areas=True,
            >>>                           weight_dilate=3)
            >>> self.disable_augmenter = True
            >>> # Pretend that some external object has given us information about desired class weights
            >>> # this could be frequency based, but we will use random weights here.
            >>> dataset_stats = self.cached_dataset_stats()
            >>> import kwarray
            >>> rng = kwarray.ensure_rng(0)
            >>> class_keys = dataset_stats['class_freq']
            >>> catname_to_weight = {c: rng.rand() for c in class_keys}
            >>> catname_to_weight['star'] = 2.0
            >>> self.catname_to_weight = catname_to_weight
            >>> #
            >>> index = 0
            >>> index = target = self.sample_grid['targets'][self.sample_grid['positives_indexes'][4]]
            >>> item = self[index]
            >>> # xdoctest: +REQUIRES(--show)
            >>> canvas = self.draw_item(item)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(canvas)
            >>> kwplot.show_if_requested()
        """
        target = self._coerce_target(index)

        # Handle details about the sampling target
        # Fill in details that might be missing. Does not modify the input.
        target_, video, resolution_info = self._resolve_target(target)

        vidspace_box = resolution_info['vidspace_box']
        try:
            final_gids, gid_to_sample = self._sample_from_target(target_, vidspace_box)
        except FailedSample as ex:
            from kwutil.util_exception import add_exception_note
            raise add_exception_note(ex, f'target_ = {ub.urepr(target_, nl=1)}')
        except Exception as ex:
            # Report context about the unexpected error, then propagate the
            # original exception unchanged.
            print(f'target_ = {ub.urepr(target_, nl=1)}')
            msg = f'Unknown sample error: ex = {ub.urepr(ex, nl=1)}'
            print(msg)
            warnings.warn(msg)
            raise

        num_frames = len(final_gids)
        if num_frames == 0:
            raise Exception('0 frames')

        ###
        # Process sampled data
        meta_info = self._prepare_meta_info(num_frames)

        if not self.inference_only:
            truth_info = self._prepare_truth_info(final_gids, gid_to_sample,
                                                  num_frames, target, target_)
        else:
            truth_info = None

        frame_items = self._build_frame_items(final_gids, gid_to_sample,
                                              truth_info, meta_info,
                                              resolution_info)

        # if self.config['prenormalize_inputs'] is not None:
        #     raise NotImplementedError

        self._robust_normalize_frame_items(frame_items, target)

        # Add in change truth
        if not self.inference_only:
            # Build multi-frame truth
            if self.requested_tasks['change']:
                # Change truth is computed between consecutive frames and
                # stored on the later frame of each pair, so the first frame
                # never receives a change target.
                for frame1, frame2 in ub.iter_window(frame_items, 2):
                    class_weights1 = frame1['class_weights']
                    class_weights2 = frame2['class_weights']
                    # TODO: prefer class-ohe if available
                    class_idxs1 = frame1['class_idxs']
                    class_idxs2 = frame2['class_idxs']
                    if class_idxs2.shape != class_idxs1.shape:
                        # Frames may have different output sizes; bring the
                        # previous frame's truth onto the current frame's grid.
                        class_idxs1 = kwimage.imresize(
                            class_idxs1, dsize=class_idxs2.shape[0:2][::-1],
                            interpolation='nearest')
                        class_weights1 = kwimage.imresize(
                            class_weights1, dsize=class_weights2.shape[0:2][::-1],
                            interpolation='nearest')
                    frame_change = (class_idxs1 != class_idxs2).astype(np.uint8)
                    # ToDO: configure kernel size here
                    frame_change = kwimage.morphology(frame_change, 'open', kernel=3)
                    # Geometric mean of the two frames' weights.
                    change_weights = np.sqrt(class_weights1 * class_weights2)
                    frame2['change'] = frame_change
                    frame2['change_weights'] = change_weights.clip(0, None)

        pixelwise_truth_keys = [
            'change', 'class_idxs', 'class_ohe',
            'saliency', 'class_weights',
            'saliency_weights', 'change_weights',
            'output_weights',
        ]
        annotwise_truth_keys = [
            'box_ltrb',
            # 'box_tids',
            'box_cidx', 'box_weight',
        ]
        framewise_truth_keys = [
            'nonlocal_class_ohe'
        ]
        coord_truth_keys = [
            'box_ltrb',
        ]
        truth_keys = pixelwise_truth_keys + annotwise_truth_keys + framewise_truth_keys

        # If we are augmenting
        fliprot_params = target_.get('fliprot_params', None)
        if fliprot_params is not None:
            for frame_item in frame_items:
                frame_modes = frame_item['modes']
                for mode_key in list(frame_modes.keys()):
                    # Augment the underlying data
                    mode_data = frame_modes[mode_key]
                    frame_modes[mode_key] = data_utils.fliprot(mode_data, **fliprot_params, axes=[1, 2])
                for key in pixelwise_truth_keys:
                    # Augment the truth rasters in the same way
                    data = frame_item.get(key, None)
                    if data is not None:
                        if key == 'class_ohe':
                            # class_ohe carries a trailing channel axis, so its
                            # spatial axes are shifted by one.
                            frame_item[key] = data_utils.fliprot(data, **fliprot_params, axes=[-3, -2])
                        else:
                            frame_item[key] = data_utils.fliprot(data, **fliprot_params, axes=[-2, -1])
                for key in coord_truth_keys:
                    # Augment the truth coordinates in the same way
                    data = frame_item.get(key, None)
                    if data is not None:
                        output_dims = frame_item['output_dims']
                        frame_item[key] = data_utils.fliprot_annot(
                            kwimage.Boxes(data, 'ltrb'), **fliprot_params, axes=[-2, -1], canvas_dsize=output_dims).data

        # Convert data to torch
        for frame_item in frame_items:
            frame_modes = frame_item['modes']
            for mode_key in list(frame_modes.keys()):
                mode_data = frame_modes[mode_key]
                frame_modes[mode_key] = kwarray.ArrayAPI.tensor(mode_data)
            for key in truth_keys:
                data = frame_item.get(key, None)
                if data is not None:
                    try:
                        frame_item[key] = kwarray.ArrayAPI.tensor(data)
                    except TypeError:
                        frame_item[key] = torch.tensor(data)

        # Only pass back some of the metadata (because I think torch
        # multiprocessing makes a new file descriptor for every Python object
        # or something like that)
        relevant_target_keys = {
            'gids', 'space_slice', 'video_id', 'fliprot_params',
            'main_idx', 'scale', 'main_skip_reason', 'allow_augment'
        }
        resolved_target_subset = ub.dict_isect(target_, relevant_target_keys)
        requested_target_subset = ub.dict_isect(target, relevant_target_keys)

        resolved_input_scale = resolution_info['resolved_input_scale']
        resolved_output_scale = resolution_info['resolved_output_scale']

        # Future directions:
        # Currently we submit items based on the idea that the outputs will be
        # alignable to the inputs. This need not be the case.  This should
        # provide the input sequence as well as what the desired output
        # sequence should be. The requested output sequence could be disjoint
        # from the input sequence. It could also be aligned, or perhaps it is
        # just a single classification prediction over the entire sequence.
        vidid = target_['video_id']
        item = {
            'frames': frame_items,
        }

        if not self.config['reduce_item_size']:
            LOCAL_RANK = os.environ.get('LOCAL_RANK', '-1')
            item.update({
                'producer_mode': self.mode,
                'producer_rank': LOCAL_RANK,
                'requested_index': target.get('requested_index', None),
                'resolved_index': target.get('resolved_index', None),
                # '_new_inputs': ...,
                # '_new_outputs': ...,
                'video_id': vidid,
                'video_name': video.get('name', None),
                'domain': video.get('domain', video.get('name', None)),
                'input_gsd': resolved_input_scale.get('gsd', None),
                'output_gsd': resolved_output_scale.get('gsd', None),

                # TODO: rename 'target' to resolved_target
                'target': resolved_target_subset,
                'requested_target': requested_target_subset,
            })
            # Probably not the job of the dataset to produce positional
            # encodings
            positional_tensors = self._populate_positional_information(frame_items)
            item['positional_tensors'] = positional_tensors
            # Abstract away details of the dictionary structure by wrapping in
            # a helper class.
            item['predictable_classes'] = self.predictable_classes
            item['requested_tasks'] = self.requested_tasks
        else:
            # overhead should be small to at least return context by default
            item.update({
                'target': resolved_target_subset,
            })

        if self.config['reduce_item_size']:
            nonessential_frame_keys = [
                'gid',
                'date_captured',
                'timestamp',
                'time_index',
                'sensor',

                # Could group these into head and input/head specific dictionaries?
                # info for how to construct the output.
                'change_output_dims',
                'class_output_dims',
                'saliency_output_dims',
                #
                # 'output_dims',
                'output_space_slice',
                'output_image_dsize',
                'scale_outspace_from_vid',
                'ann_aids',
            ]
            for frame in item['frames']:
                for k in nonessential_frame_keys:
                    frame.pop(k, None)

        # Wrap the dictionary item in a convenience class. A per-target
        # ``output_type`` overrides the configured default.
        output_type = target.get('output_type', self.config['output_type'])
        if output_type == 'heterogeneous':
            item = HeterogeneousBatchItem(item)
        elif output_type == 'homogeneous':
            item = HomogeneousBatchItem(item)
        elif output_type == 'rgb':
            item = RGBImageBatchItem(item)
        else:
            raise KeyError(output_type)

        return item

    def _robust_normalize_frame_items(self, frame_items, target):
        """
        Normalize the sampled frame data in place.

        A per-target ``robust_normalize`` spec overrides the dataset-level
        normalizer. The deprecated ``normalize_perframe`` and
        ``normalize_peritem`` config options are still honored, but they
        cannot be combined with robust normalization.

        Args:
            frame_items (List[Dict]): frame items whose ``modes`` arrays are
                normalized (modified in place).
            target (Dict): the requested target; only ``robust_normalize`` is
                read from it.

        SeeAlso:
            _init_robust_normalizers

        Example:
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
            >>> import kwcoco_dataloader
            >>> coco_dset = kwcoco_dataloader.coerce_kwcoco('kwcoco_dataloader-msi', num_frames=5, image_size=(128, 128), num_videos=1)
            >>> # Test full independent normalization
            >>> self = KWCocoVideoDataset(coco_dset, time_dims=5,
            >>>                           window_dims=(128, 128),
            >>>                           robust_normalize=ub.codeblock(
                '''
                separate_channels: True
                separate_time: True
                separate_sensors: True
                sensorchan: '*'
                '''))
            >>> print(f'self.robust_normalizer._normalizer_items = {ub.urepr(self.robust_normalizer._normalizer_items, nl=1)}')
            >>> index = 0
            >>> item = self[index]
            >>> # Check that all items were normalized independently, so each
            >>> # raster (assuming there are at least 2 unique values) should range
            >>> # between 0 and 1.
            >>> for mode_key, data in item.iter_modes():
            >>>     stats = kwarray.stats_dict(data, axis=(1, 2), nan=True)
            >>>     print(f'stats{mode_key} = {ub.urepr(stats, nl=1)}')
            >>>     assert np.allclose(stats['min'], 0)
            >>>     assert np.allclose(stats['max'], 1)
            >>> # xdoctest: +REQUIRES(--show)
            >>> canvas = self.draw_item(item)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(canvas)
            >>> kwplot.show_if_requested()

        Example:
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
            >>> import kwcoco_dataloader
            >>> coco_dset = kwcoco_dataloader.coerce_kwcoco('kwcoco_dataloader-msi', num_frames=5, image_size=(128, 128), num_videos=1)
            >>> # Test dependant normalization
            >>> self = KWCocoVideoDataset(coco_dset, time_dims=5,
            >>>                           window_dims=(128, 128),
            >>>                           robust_normalize=ub.codeblock(
                '''
                separate_channels: False
                separate_time: False
                separate_sensors: False
                groups:
                  - sensorchan: "r|g|b"
                  - sensorchan: "disparity|gauss"
                '''))
            >>> print(f'self.robust_normalizer._normalizer_items = {ub.urepr(self.robust_normalizer._normalizer_items, nl=1)}')
            >>> index = 0
            >>> item = self[index]
            >>> # Check that the items were normalized jointly (not
            >>> # independently), so not every raster should span exactly the
            >>> # range 0 to 1.
            >>> for mode_key, data in item.iter_modes():
            >>>     stats = kwarray.stats_dict(data, axis=(1, 2), nan=True)
            >>>     print(f'stats{mode_key} = {ub.urepr(stats, nl=1)}')
            >>>     assert np.any(stats['min'] != 0)
            >>>     assert np.any(stats['max'] != 1)
            >>> # xdoctest: +REQUIRES(--show)
            >>> canvas = self.draw_item(item)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(canvas)
            >>> kwplot.show_if_requested()

        Example:
            >>> # Test normalization with dynamic channels
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
            >>> import kwcoco_dataloader
            >>> coco_dset = kwcoco_dataloader.coerce_kwcoco('vidshapes2')
            >>> self = KWCocoVideoDataset(coco_dset, time_dims=5,
            >>>                           window_dims=(128, 128),
            >>>                           channels='r|neg_g|weird_b',
            >>>                           robust_normalize=ub.codeblock(
                                                '''
                                                separate_time: True
                                                groups:
                                                  - channels: "r"
                                                  - channels: "neg_g"
                                                    mode: linear
                                                    high: 0.91
                                                    low: 0.33
                                                  - channels: "weird_b"
                                                    mode: linear
                                                    high: 0.51
                                                    low: 0.49
                                                '''),
            >>>                           dynamic_channels=ub.codeblock(
                                              '''
                                              - name: neg_g
                                                expr: '-g'
                                              - name: weird_b
                                                expr: 'exp(b / 255) ** 3 + 1'
                                              '''))
            >>> index = 0
            >>> index = target = self.sample_grid['targets'][self.sample_grid['positives_indexes'][4]]
            >>> item = self[index]
            >>> # xdoctest: +REQUIRES(--show)
            >>> canvas = self.draw_item(item)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(canvas)
            >>> kwplot.show_if_requested()
        """
        # Allow each item to be robustly normalized independently
        robust_normalizer = target.get('robust_normalize', None)
        if robust_normalizer is None:
            robust_normalizer = self.robust_normalizer
        else:
            # Note: this is not efficient, and should only be used for
            # debugging / exploration.
            robust_normalizer = RobustNormalizer.coerce(
                robust_normalizer, default_sensorchan=self.input_sensorchan)

        if robust_normalizer is not None:
            robust_normalizer.normalize(frame_items, _debug=0)

        if self.config['normalize_perframe']:
            assert not robust_normalizer, 'normalize_perframe cannot be used with robust_normalize'
            # DEPRECATED
            ub.schedule_deprecation(
                'kwcoco_dataloader', 'normalize_perframe', 'param',
                'use robust_normalize instead with separate_time=True',
                deprecate='0.1.0', error='1.0.0', remove='1.1.0',
            )
            for frame_item in frame_items:
                frame_modes = frame_item['modes']
                for mode_key in list(frame_modes.keys()):
                    mode_data = frame_modes[mode_key]
                    to_restack = []
                    # Normalize each leading-axis slice of the mode separately.
                    for band in mode_data:
                        # TODO: use real nodata values? Ideally they have
                        # already been converted into nans
                        mask = (band != 0) & np.isfinite(band)

                        # FIXME: The normalizer defaults are not the same
                        # between perframe and peritem. Should they be? In
                        # either case we need to let the user specify them.
                        norm_band = kwimage.normalize_intensity(band, params={
                            'high': 0.90,
                            'mid': 0.5,
                            'low': 0.01,
                            'mode': 'linear',
                        }, mask=mask)
                        to_restack.append(norm_band)
                    frame_modes[mode_key] = np.stack(to_restack, axis=0)

        if self.config['normalize_peritem'] is not None and self.config['normalize_peritem'] is not False:
            assert not robust_normalizer, 'normalize_peritem cannot be used with robust_normalize'
            ub.schedule_deprecation(
                'kwcoco_dataloader', 'normalize_peritem', 'param',
                'use robust_normalize instead with separate_time=False',
                deprecate='0.1.0', error='1.0.0', remove='1.1.0',
            )
            # DEPRECATED
            # Gather items that need normalization, grouped by
            # (sensor, channel) so each group shares one normalizer.
            needs_norm = ub.ddict(list)
            for frame_item in frame_items:
                sensor = frame_item['sensor']
                frame_modes = frame_item['modes']
                for mode_key in list(frame_modes.keys()):
                    mode_chan = FusedChannelSpec.coerce(mode_key)
                    common_key = mode_chan.intersection(self.config['normalize_peritem'])
                    if common_key:
                        parent_data = frame_modes[mode_key]
                        for chan_name, chan_sl in mode_chan.component_indices(axis=0).items():
                            if chan_name in common_key:
                                chan_data = parent_data[chan_sl]
                                valid_mask = np.isfinite(chan_data)
                                needs_norm[(sensor, chan_name)].append((chan_data, valid_mask, parent_data, chan_sl))

            # TODO: we could do data augmentation with these or let the user
            # specify a better way.
            peritem_normalizer_params = {
                'high': 0.95,
                'mid': 0.5,
                'low': 0.00,
                'mode': 'linear',
            }
            for key, norm_items in needs_norm.items():
                raw_datas = np.concatenate([t[0].ravel() for t in norm_items], axis=0)
                valid_mask = np.concatenate([t[1].ravel() for t in norm_items], axis=0)
                valid_raw_datas = raw_datas[valid_mask]
                # Compute normalizers over the entire temporal range per-sensor
                normalizer = kwarray.find_robust_normalizers(valid_raw_datas,
                                                             params=peritem_normalizer_params)
                # HACK: For backwards compatibility re-introduce the same bug.
                # Historically a regularized ``min_val`` was computed here but
                # its result was discarded, and both bounds were then dropped.
                # Only the dropping had any effect, so only it is kept.
                normalizer.pop('min_val', None)
                normalizer.pop('max_val', None)
                # Apply the normalize to the original data
                for chan_data, valid_mask, parent_data, chan_sl in norm_items:
                    valid_data = chan_data[valid_mask]
                    # Apply normalizer (todo: use kwimage variant)
                    imdata_normalized = util_kwarray.apply_robust_normalizer(
                        normalizer, chan_data, valid_data, valid_mask,
                        dtype=np.float32, copy=True)
                    # Overwrite original data with new normalized variants
                    parent_data[chan_sl] = imdata_normalized

    def _sample_from_target(self, target_, vidspace_box):
        """
        Given a space-time target, samples frame rasters and annotation vectors.

        This includes
            * rejection sampling
            * quality masking
            * dynamic resolution

        Args:
            target_ (Dict): a resolved target. It must contain ``video_id``
                and ``gids``. It is modified in place: ``gids`` is replaced
                by the final chosen frames, and ``main_skip_reason`` is set
                when the main frame was rejected.
            vidspace_box: the spatial region of the target in video space.

        Returns:
            Tuple[List[int], Dict[int, Dict]]: the final image ids, ordered
            as in the video's image list, and the mapping from each sampled
            image id to its sample.

        Raises:
            FailedSample: if every sampled frame was marked as bad.
        """
        ###
        # Execute data sampling
        ###
        sampler = self.sampler
        coco_dset = self.sampler.dset

        vidid = target_['video_id']
        try:
            video = coco_dset.index.videos[vidid]
        except KeyError:
            # Hack for loose images
            assert len(target_['gids']) == 1, 'should have only 1 image id'
            gid = target_['gids'][0]
            video = coco_dset.index.imgs[gid]
            is_loose_img = True
        else:
            is_loose_img = False

        with_annots = False if self.inference_only else ['boxes', 'segmentation']

        # These dictionaries help maintain sampling state.
        # It might be nice to abstract this out into a smaller testable
        # component with the _sample_one_frame method.
        gid_to_sample: Dict[int, Dict] = {}
        gid_to_isbad: Dict[int, bool] = {}
        for gid in target_['gids']:
            self._sample_one_frame(gid, sampler, coco_dset, target_, with_annots,
                                   gid_to_isbad, gid_to_sample)

        # TODO: remove the need to access sample_grid if the target is already
        # resolved and we don't need to do any resampling.
        try:
            time_sampler = self.sample_grid['vidid_to_time_sampler'][vidid]
            video_gids = time_sampler.video_gids
        except AttributeError:
            video_gids = list(self.coco_dset.images(video_id=vidid))

        # If we skipped the main gid, record why. Use ``.get`` because the
        # main gid is not guaranteed to be among the sampled gids.
        main_gid = target_.get('main_gid', None)
        main_skip_reason = None
        if main_gid is not None:
            main_skip_reason = gid_to_isbad.get(main_gid, None) or None

        resample_invalid = target_.get('resample_invalid_frames', self.config['resample_invalid_frames'])
        num_images_wanted = len(target_['gids'])
        if resample_invalid and not is_loose_img:
            # ``True`` means "use the default number of retries"; any other
            # truthy value is interpreted as an explicit retry count.
            if resample_invalid is True:
                max_tries = 3
            else:
                max_tries = int(resample_invalid)
            vidname = video['name']
            self._resample_bad_images(
                video_gids, gid_to_isbad, sampler, coco_dset, target_,
                num_images_wanted, with_annots, gid_to_sample, vidspace_box,
                vidname, max_tries)

        good_gids = {gid for gid, flag in gid_to_isbad.items() if not flag}
        if len(good_gids) == 0:
            raise FailedSample(ub.paragraph(
                f'''
                Cannot force a good sample. Tried to sample {len(gid_to_isbad)}
                frames, but all were marked as bad. Reported reasoning:
                gid_to_isbad={gid_to_isbad}
                '''))

        force_bad_frames = target_.get('force_bad_frames', 0)
        if force_bad_frames:
            # Keep every attempted frame, including the ones marked bad.
            final_gids = [g for g in video_gids if g in gid_to_isbad]
        else:
            final_gids = [g for g in video_gids if g in good_gids]

        target_['gids'] = final_gids

        if main_skip_reason:
            target_['main_skip_reason'] = main_skip_reason

        return final_gids, gid_to_sample

    def _resolve_resolution(self, target_, video):
        """
        Resolve the input / output scale requested by a target and compute
        where the target lives in video space and in output space.

        Note:
            ``target_`` is modified in place:

            * ``input_space_scale`` / ``output_space_scale`` are filled in
              from the config when missing or None.
            * ``scale`` is set to the resolved input scale. When the input
              scale is ``'native'``, ``scale`` is removed and
              ``use_native_scale`` / ``realign_native`` are set instead.

        Args:
            target_ (Dict): the target being resolved; must contain
                ``space_slice``.
            video (Dict): the video (or loose image) dictionary; its
                ``target_gsd``, ``width`` and ``height`` are read.

        Returns:
            Dict: with keys ``common_outspace_box`` (None when the output
            scale is native), ``vidspace_box``, ``video_dsize``,
            ``vidspace_dsize``, ``resolved_input_scale``,
            ``resolved_output_scale``, ``common_input_scale`` and
            ``common_output_scale``.

        Raises:
            Exception: if the output scale is native but the input scale is
                not.
        """
        def _is_native(scale):
            # Resolved scales are not always strings, so check the type
            # before comparing against the 'native' code.
            return isinstance(scale, str) and scale == 'native'

        # Compute scale if we are doing that
        # This should live somewhere else, but lets just get it hooked up
        vidspace_gsd = video.get('target_gsd', None)

        # The target is allowed to overload the scales
        if target_.get('input_space_scale', None) is None:
            target_['input_space_scale'] = self.config['input_space_scale']
        if target_.get('output_space_scale', None) is None:
            target_['output_space_scale'] = self.config['output_space_scale']
        # Resolve spatial scale code
        resolved_input_scale = data_utils.resolve_scale_request(
            request=target_['input_space_scale'], data_gsd=vidspace_gsd)

        resolved_output_scale = data_utils.resolve_scale_request(
            request=target_['output_space_scale'], data_gsd=vidspace_gsd)

        common_input_scale = resolved_input_scale['scale']
        common_output_scale = resolved_output_scale['scale']
        target_['scale'] = common_input_scale

        # Put the target slice in video space.
        OPTIMIZE = 0
        if OPTIMIZE:
            # Need to have kwimage 0.10.1 to enable this
            sl_y, sl_x = target_['space_slice']
            y1 = sl_y.start
            y2 = sl_y.stop
            x1 = sl_x.start
            x2 = sl_x.stop
            vidspace_ltrb = np.array([[x1, y1, x2, y2]])
            _boxes = kwimage.Boxes(vidspace_ltrb, 'ltrb', canonical=True)
            vidspace_box = kwimage.Box(_boxes)
            vidspace_dsize = (x2 - x1, y2 - y1)
        else:
            vidspace_box = kwimage.Box.from_slice(target_['space_slice'],
                                                  clip=False, wrap=False)
            vidspace_dsize = vidspace_box.dsize

        # Size of the video the target is embedded in.
        video_dsize = np.array([video['width'], video['height']])

        if _is_native(common_input_scale):
            target_.pop('scale')
            # native scales will only work in late-fused modes
            target_['use_native_scale'] = True
            target_['realign_native'] = 'largest'
        elif _is_native(common_output_scale):
            raise Exception(
                'output scale can only be native when input scale is native')

        if _is_native(common_output_scale):
            common_outspace_box = None
        else:
            # Compute where this output chip should live in its output space canvas.
            if OPTIMIZE:
                sx = sy = common_output_scale
                common_outspace_ltrb = vidspace_ltrb * np.array([[sx, sy, sx, sy]])
                _boxes = kwimage.Boxes(common_outspace_ltrb, 'ltrb', canonical=True)
                common_outspace_box = kwimage.Box(_boxes)
            else:
                common_outspace_box = vidspace_box.scale(common_output_scale)
                common_outspace_box = common_outspace_box.quantize(inplace=True)

        # Return a named mapping rather than a long positional tuple.
        resolution_info = {
            'common_outspace_box': common_outspace_box,
            'vidspace_box': vidspace_box,
            'video_dsize': video_dsize,
            'vidspace_dsize': vidspace_dsize,
            'resolved_input_scale': resolved_input_scale,
            'resolved_output_scale': resolved_output_scale,
            'common_input_scale': common_input_scale,
            'common_output_scale': common_output_scale,
        }
        return resolution_info

    def _resample_bad_images(self, video_gids, gid_to_isbad, sampler,
                             coco_dset, target_, num_images_wanted, with_annots,
                             gid_to_sample, vidspace_box, vidname, max_tries):
        """
        If the initial sample has marked any of the images as "bad", then we
        attempt to find replacements by reusing the temporal sampler, but with
        extra arguments to exclude the bad frames.

        Args:
            video_gids: kept for backwards compatibility. Replacement image
                ids are looked up in the time sampler's own ``video_gids``,
                because that is the sequence its sampled indices refer to.
            gid_to_isbad (Dict[int, Any]): per-image "bad" flag / reason;
                updated in place (via ``_sample_one_frame``) as new frames
                are sampled.
            sampler: passed through to ``_sample_one_frame``.
            coco_dset: passed through to ``_sample_one_frame``.
            target_ (Dict): the resolved target; ``video_id`` is read.
            num_images_wanted (int): stop once this many good frames exist.
            with_annots: passed through to ``_sample_one_frame``.
            gid_to_sample (Dict[int, Dict]): updated in place with samples of
                the replacement frames.
            vidspace_box: unused except for debugging context.
            vidname (str): unused except for debugging context.
            max_tries (int): maximum number of resampling rounds.

        Note:
            This is best-effort: if the time sampler raises, or offers no new
            frames, resampling stops silently and the caller decides what to
            do with the remaining bad frames.
        """
        # Nothing to do when every frame is already good.
        if not any(gid_to_isbad.values()):
            return

        vidid = target_['video_id']
        time_sampler = self.sample_grid['vidid_to_time_sampler'][vidid]
        # Indices returned by the time sampler index into this sequence; an
        # array lets us fancy-index it even if it is stored as a list.
        sampler_gids = np.asarray(time_sampler.video_gids)

        for _ in range(max_tries):
            good_gids = np.array([gid for gid, flag in gid_to_isbad.items() if not flag])
            if len(good_gids) == num_images_wanted:
                break
            bad_gids = np.array([gid for gid, flag in gid_to_isbad.items() if flag])
            include_idxs = np.where(kwarray.isect_flags(sampler_gids, good_gids))[0]
            exclude_idxs = np.where(kwarray.isect_flags(sampler_gids, bad_gids))[0]
            try:
                chosen = time_sampler.sample(include=include_idxs,
                                             exclude=exclude_idxs,
                                             error_level=0,
                                             return_info=False)
            except Exception:
                # Best effort: give up on resampling if the sampler fails.
                break
            # Only the frames not already kept are new candidates.
            new_idxs = np.setdiff1d(chosen, include_idxs)
            new_gids = sampler_gids[new_idxs]
            if not len(new_gids):
                # Exhausted all resample possibilities
                break
            for gid in new_gids:
                self._sample_one_frame(gid, sampler, coco_dset, target_,
                                       with_annots, gid_to_isbad,
                                       gid_to_sample)


class IntrospectMixin:
    """
    Methods for introspection / visualization of data
    """

    def draw_item(self, item, item_output=None, combinable_extra=None,
                  max_channels=5, max_dim=224, norm_over_time='auto',
                  overlay_on_image=False, draw_weights=True, rescale='auto',
                  classes=None, show_summary_text=True, **kwargs):
        """
        Visualize an item produced by this DataSet.

        Each channel will be a row, and each column will be a timestep.

        Args:
            item (Dict | None): An item returned from the torch Dataset.
                If None (e.g. a failed sample), a placeholder canvas with a
                large red "X" is returned instead.

            item_output (Dict):
                Special task keys that we know how to plot.
                These should be some sort of binary or class prediction from
                the network. I'm not sure how best to pass the details
                of how they should be interpreted.

                Known keys:
                    change_probs
                    saliency_probs
                    class_probs
                    pred_ltrb

            combinable_extra (List[List[str]] | None):
                additional groups of channels that may be combined when
                drawn. Forwarded to :class:`BatchVisualizationBuilder`.

            max_channels (int) :
                maximum number of channel rows to draw

            max_dim (int):
                max dimension to resize each grid cell to. Also sets the size
                of the placeholder canvas returned when ``item`` is None.

            norm_over_time (bool | str):
                Forwarded to the builder. If 'auto', this becomes True when
                the ``normalize_peritem`` config option is not None.

            overlay_on_image (bool):
                if True, the truth and prediction is drawn on top of
                an image, otherwise it is drawn on a black image.

            draw_weights (bool):
                Forwarded to the builder (presumably toggles drawing of the
                per-pixel weights).

            rescale (bool | str):
                Forwarded to the builder. If 'auto', this becomes True when
                the ``input_space_scale`` config option is not 'native'.

            classes (kwcoco.CategoryTree | None):
                Classes any "class_probs" in the 'item_output' dictionary
                corresponds to.  If unspecified uses the classes from the
                datamodule (i.e. ``self.predictable_classes``).

            show_summary_text (bool):
                if True, draw additional summary debug information.
                Defaults to True.

            **kwargs:
                additional arguments to :class:`BatchVisualizationBuilder`.

        Returns:
            ndarray: the rendered canvas image.

        Note:
            The ``self.requested_tasks`` controls the task labels returned by
            getitem, and hence what can be visualized here.

        Note:
            In the future, the returned :class:`HeterogeneousBatchItem` will
            control how it is drawn, removing this responsibility from the
            dataset itself.

        Example:
            >>> # Basic Data Sampling with lots of small objects
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
            >>> import kwcoco_dataloader
            >>> anchors = np.array([[0.1, 0.1]])
            >>> size = (96, 96)
            >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes1', num_frames=4, num_tracks=40, anchors=anchors, image_size=size)
            >>> self = KWCocoVideoDataset(coco_dset, time_dims=4, window_dims=size, default_class_behavior='ignore')
            >>> self._notify_about_tasks(predictable_classes=['star', 'eff'])
            >>> self.requested_tasks['change'] = False
            >>> index = self.sample_grid['targets'][self.sample_grid['positives_indexes'][0]]
            >>> item = self[index]
            >>> canvas = self.draw_item(item, draw_weights=False)
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> label_to_color = {
            >>>     node: data['color']
            >>>     for node, data in self.predictable_classes.graph.nodes.items()}
            >>> label_to_color = ub.sorted_keys(label_to_color)
            >>> legend_img = kwplot.make_legend_img(label_to_color)
            >>> legend_img = kwimage.imresize(legend_img, scale=4.0)
            >>> show_canvas = kwimage.stack_images([canvas, legend_img], axis=1)
            >>> kwplot.imshow(show_canvas)
            >>> kwplot.show_if_requested()

        Example:
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
            >>> import kwcoco
            >>> import kwarray
            >>> import rich
            >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes2-multispectral', num_frames=5)
            >>> channels = 'B10|B8a|B1|B8|B11'
            >>> combinable_extra = [['B10', 'B8', 'B8a']]  # special behavior
            >>> # combinable_extra = None  # uncomment for raw behavior
            >>> mode = 'fit'
            >>> mode = 'test'
            >>> coco_dset.clear_annotations()
            >>> self = KWCocoVideoDataset(coco_dset, mode=mode, time_dims=5, window_dims=(530, 610), channels=channels, balance_areas=True)
            >>> #index = len(self) // 4
            >>> #index = self.sample_grid['targets'][self.sample_grid['positives_indexes'][5]]
            >>> index = self.sample_grid['targets'][0]
            >>> # More controlled settings for debug
            >>> self.disable_augmenter = True
            >>> item = self[index]
            >>> item_output = self._build_demo_outputs(item)
            >>> rich.print('item summary: ' + ub.urepr(self.summarize_item(item), nl=3))
            >>> canvas = self.draw_item(item, item_output, combinable_extra=combinable_extra, overlay_on_image=1)
            >>> canvas2 = self.draw_item(item, item_output, combinable_extra=combinable_extra, max_channels=3, overlay_on_image=0)
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(canvas, fnum=1, pnum=(1, 2, 1))
            >>> kwplot.imshow(canvas2, fnum=1, pnum=(1, 2, 2))
            >>> kwplot.show_if_requested()

        Ignore:
            >>> # xdoctest: +REQUIRES(env:DVC_DPATH)
            >>> import kwcoco_dataloader
            >>> import rich
            >>> data_dvc_dpath = kwcoco_dataloader.find_dvc_dpath(tags='phase2_data', hardware='auto')
            >>> expt_dvc_dpath = kwcoco_dataloader.find_dvc_dpath(tags='phase2_expt', hardware='auto')
            >>> coco_fpath = data_dvc_dpath / 'KHQ_Tutorial6_Data/Aligned-KHQ_Tutorial6_Data/KHQ_R001/imgonly-KHQ_R001-rawbands.kwcoco.zip'
            >>> from kwcoco_dataloader.tasks.fusion.predict import _prepare_predict_modules, PredictConfig
            >>> config = PredictConfig(**{
            >>>     'key': 'set_cover_algo',
            >>>     'test_dataset': coco_fpath,
            >>>     'mask_low_quality': True,
            >>>     'pred_dataset': expt_dvc_dpath / '_demo_khq_doctest/pred.kwcoco.zip',
            >>>     'package_fpath': expt_dvc_dpath / 'models/fusion/Drop7-MedianNoWinter10GSD/packages/Drop7-MedianNoWinter10GSD_bgrn_split6_V74/Drop7-MedianNoWinter10GSD_bgrn_split6_V74_epoch46_step4042.pt',
            >>>     'devices': [0],
            >>> })
            >>> config, model, datamodule = _prepare_predict_modules(config)
            >>> self = datamodule.torch_datasets['test']
            >>> index = self.sample_grid['targets'][0]
            >>> # More controlled settings for debug
            >>> self.disable_augmenter = True
            >>> combinable_extra = None
            >>> item = self[index]
            >>> item_output = self._build_demo_outputs(item)
            >>> rich.print('item summary: ' + ub.urepr(self.summarize_item(item), nl=3))
            >>> canvas = self.draw_item(item, item_output, combinable_extra=combinable_extra, overlay_on_image=1)
            >>> canvas2 = self.draw_item(item, item_output, combinable_extra=combinable_extra, max_channels=3, overlay_on_image=0)
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(canvas, fnum=1, pnum=(1, 2, 1))
            >>> kwplot.imshow(canvas2, fnum=1, pnum=(1, 2, 2))
            >>> kwplot.show_if_requested()
        """
        if rescale == 'auto':
            rescale = self.config['input_space_scale'] != 'native'

        if item is None:
            # BIG RED X: placeholder visualization for a failed sample.
            h = w = (max_dim or 224)
            bad_canvas = kwimage.draw_text_on_image(
                {'width': w, 'height': h}, 'X', org=(w // 2, h // 2),
                valign='center', halign='center', fontScale=10,
                color='red')
            return bad_canvas

        # TODO: when ready, delegate to the class method
        # ``HeterogeneousBatchItem.draw(legend=False)``.

        default_combinable_channels = self.default_combinable_channels

        if norm_over_time == 'auto':
            norm_over_time = self.config['normalize_peritem'] is not None

        # Hack to force the categories to draw right for SMART
        # FIXME: Use the correct class colors in visualization.
        from kwcoco_dataloader import heuristics
        heuristics.ensure_heuristic_category_tree_colors(self.predictable_classes, force=True)

        # Respect caller-specified classes; fall back to the dataset's own.
        # (Previously the ``classes`` argument was silently ignored.)
        if classes is None:
            classes = self.predictable_classes

        # FIXME: requested_tasks from user input is not respected
        builder = BatchVisualizationBuilder(
            item=item, item_output=item_output,
            default_combinable_channels=default_combinable_channels,
            norm_over_time=norm_over_time, max_dim=max_dim,
            max_channels=max_channels, overlay_on_image=overlay_on_image,
            draw_weights=draw_weights, combinable_extra=combinable_extra,
            classes=classes, requested_tasks=self.requested_tasks,
            rescale=rescale, **kwargs)
        canvas = builder.build()

        if show_summary_text:
            summary = self.summarize_item(item)
            # Per-frame details are too verbose for the header text.
            summary = ub.udict(summary) - {'frame_summaries'}
            summary_text = ub.urepr(summary, nobr=1, precision=2, nl=-1)
            header = kwimage.draw_text_on_image(None, text=summary_text, halign='left', color='kitware_blue')
            canvas = kwimage.stack_images([canvas, header])
        return canvas

    def summarize_item(self, item, stats=False):
        """
        Return debugging stats about the item

        Args:
            item (dict): an item returned by __getitem__
            stats (bool): if True, include statistics on input data.

        Returns:
            dict : a summary of the item

        Raises:
            ValueError: if ``item`` is None (a failed sample).

        Example:
            >>> # xdoctest: +SKIP
            >>> from kwcoco_dataloader.tasks.fusion.datamodules import kwcoco_dataset
            >>> import kwcoco
            >>> import kwcoco_dataloader
            >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes1', num_frames=10)
            >>> self = kwcoco_dataset.KWCocoVideoDataset(
            >>>     coco_dset, time_dims=4, window_dims=(300, 300),
            >>>     channels='r|g|b')
            >>> self.disable_augmenter = True
            >>> index = self.sample_grid['targets'][self.sample_grid['positives_indexes'][0]]
            >>> item = self[index]
            >>> item_summary = self.summarize_item(item, stats=True)
            >>> print(f'item_summary = {ub.urepr(item_summary, nl=-1)}')
        """
        if item is None:
            raise ValueError('Cant summarize a failed sample item=None')
        # Refactored to use the new HeterogeneousBatchItem class.
        item_summary = HeterogeneousBatchItem.summarize(item, stats=stats)
        return item_summary


class BalanceMixin:
    """
    Helpers to build the sample grid and balance it

    CommandLine:
        LINE_PROFILE=1 xdoctest -m kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset BalanceMixin:1 --bench

    Example:
        >>> # Test the legacy neg_to_pos_ratio setting (todo: use more general balance_options)
        >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import KWCocoVideoDataset
        >>> import ndsampler
        >>> import kwcoco_dataloader
        >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes2', num_frames=10, rng=0)
        >>> sampler = ndsampler.CocoSampler(coco_dset)
        >>> num_samples = 50
        >>> neg_to_pos_ratio = 0
        >>> self = KWCocoVideoDataset(sampler, mode="fit", time_dims=4, window_dims=(300, 300),
        >>>                           channels='r|g|b', neg_to_pos_ratio=neg_to_pos_ratio)
        >>> self.reseed(0)
        >>> num_targets = len(self.sample_grid['targets'])
        >>> positives_indexes = self.sample_grid['positives_indexes']
        >>> negatives_indexes = self.sample_grid['negatives_indexes']
        >>> print('dataset positive ratio:', len(positives_indexes) / num_targets)
        >>> print('dataset negative ratio:', len(negatives_indexes) / num_targets)
        >>> print('specified neg_to_pos_ratio:', neg_to_pos_ratio)
        >>> sampled_indexes = [self[x]['resolved_index'] for x in range(num_samples)]
        >>> num_positives = sum([x in positives_indexes for x in sampled_indexes])
        >>> num_negatives = num_samples - num_positives
        >>> print('sampled positive ratio:', num_positives / num_samples)
        >>> print('sampled negative ratio:', num_negatives / num_samples)
        >>> assert all([x in positives_indexes for x in sampled_indexes])
        >>> assert num_negatives == 0
        >>> assert num_positives > num_negatives
        >>> #...
        >>> neg_to_pos_ratio = .1
        >>> self = KWCocoVideoDataset(sampler, time_dims=4, window_dims=(300, 300),
        >>>                           channels='r|g|b', neg_to_pos_ratio=neg_to_pos_ratio)
        >>> self.reseed(0)
        >>> num_targets = len(self.sample_grid['targets'])
        >>> positives_indexes = self.sample_grid['positives_indexes']
        >>> negatives_indexes = self.sample_grid['negatives_indexes']
        >>> print('dataset positive ratio:', len(positives_indexes) / num_targets)
        >>> print('dataset negative ratio:', len(negatives_indexes) / num_targets)
        >>> print('specified neg_to_pos_ratio:', neg_to_pos_ratio)
        >>> sampled_indexes = [self[x]['resolved_index'] for x in range(num_samples)]
        >>> num_positives = sum([x in positives_indexes for x in sampled_indexes])
        >>> num_negatives = num_samples - num_positives
        >>> print('sampled positive ratio:', num_positives / num_samples)
        >>> print('sampled negative ratio:', num_negatives / num_samples)
        >>> assert num_negatives > 0
        >>> assert num_positives > num_negatives

    Example:
        >>> # xdoctest: +REQUIRES(--bench)
        >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import KWCocoVideoDataset
        >>> import ndsampler
        >>> import kwcoco_dataloader
        >>> import kwcoco
        >>> coco_fpath = '/media/joncrall/flash1/smart_phase3_data/Drop8-Cropped2GSD-V1/data_vali_rawbands_split6_n004_f9b08cce.kwcoco.zip'
        >>> coco_fpath = '/media/joncrall/flash1/smart_drop7/Drop7-Cropped2GSD-V2/data_vali_rawbands_split6.kwcoco.zip'
        >>> coco_dset = kwcoco.CocoDataset(coco_fpath)
        >>> self = KWCocoVideoDataset(coco_dset, mode="fit", time_dims=4, window_dims=(300, 300),
        >>>                           channels='red|green|blue', neg_to_pos_ratio=1.0)
    """

    def _get_video_names(self, vidids):
        """
        Look up a video name for each video id.

        Args:
            vidids (Sequence): one video id per sample; repeats are allowed
                and each distinct id is only looked up once.

        Returns:
            List: one name per input id, in input order. Ids that are not in
            the dataset's video index (e.g. loose images) map to the raw id.
        """
        coco_dset = self.sampler.dset
        unique_vidids, idx_to_unique_idx = np.unique(vidids, return_inverse=True)
        try:
            unique_vidnames = coco_dset.videos(unique_vidids).lookup('name')
        except KeyError:
            # handle loose images: fall back to the raw id for any id that is
            # not a registered video.
            unique_vidnames = [
                coco_dset.index.videos[video_id]['name']
                if video_id in coco_dset.index.videos else video_id
                for video_id in unique_vidids
            ]
        # Broadcast the unique names back to the original order.
        vidnames = list(ub.take(unique_vidnames, idx_to_unique_idx))
        return vidnames

    def _get_region_names(self, vidnames):
        """
        Map each video name to a region name.

        Names matching ``\\w+_[A-Z]\\d+_.*`` are truncated to their first two
        underscore-separated tokens; any other name is its own region.

        Side effect:
            Sets ``self.vidname_to_region_name`` to the full mapping.

        Args:
            vidnames (List[str]): video names (may contain repeats).

        Returns:
            List[str]: the region name for each input video name, in order.
        """
        from kwutil import util_pattern
        pat = util_pattern.Pattern.coerce(r'\w+_[A-Z]\d+_.*', 'regex')
        self.vidname_to_region_name = {}
        for vidname in set(vidnames):
            if pat.match(vidname):
                self.vidname_to_region_name[vidname] = "_".join(vidname.split('_')[:2])
            else:
                self.vidname_to_region_name[vidname] = vidname
        return list(ub.take(self.vidname_to_region_name, vidnames))

    def _load_target_annots(self, target, sequence=False):
        """
        Find annotations overlapping a target's spatial window.

        Args:
            target (Dict): a sample-grid target with a ``'space_slice'`` in
                video space, a ``'main_gid'``, and (for ``sequence=True``) a
                ``'gids'`` list.
            sequence (bool): if True, collect annotations from every image in
                ``target['gids']``; otherwise only from ``target['main_gid']``.

        Returns:
            the annotation ids reported by ``sampler.regions.overlapping_aids``
            (concatenated into a list when ``sequence=True``).

        TODO: need an ndsampler endpoint that just finds the annotations in a
        sample quickly.
        """
        vid_space_box = kwimage.Box.from_slice(target['space_slice'])
        sampler = self.sampler

        def _overlapping_aids(gid):
            # The window is in video space but annotation regions are indexed
            # per image, so warp the box into this image's space first.
            warp_vid_from_img = sampler.dset.coco_image(gid).warp_vid_from_img
            warp_img_from_vid = warp_vid_from_img.inv()
            img_space_box = vid_space_box.warp(warp_img_from_vid)
            return sampler.regions.overlapping_aids(gid, img_space_box.boxes)

        if sequence:
            all_aids = []
            for gid in target['gids']:
                all_aids.extend(_overlapping_aids(gid))
            return all_aids
        else:
            return _overlapping_aids(target['main_gid'])

    def _get_observed_annotations(self, targets):
        """
        Compute, for each target, a histogram of the category names of the
        annotations overlapping its main image window.

        Args:
            targets (List[Dict]): sample-grid targets.

        Returns:
            List[Dict[str, int]]: one category-name histogram per target.
        """
        observed_cats = []
        for target in ub.ProgIter(targets, desc='Building observed annots'):
            aids = self._load_target_annots(target)
            catnames = self.sampler.dset.annots(aids).category_names
            observed_cats.append(ub.dict_hist(catnames))
        return observed_cats

    def _setup_attribute_dataframe(self, new_sample_grid):
        """
        Build a dataframe of attributes (for each sample) that can be used for balancing.

        Args:
            new_sample_grid (Dict): must contain a ``'targets'`` list where
                each target has a ``'video_id'`` and
                ``target['annot_info']['main_gid_catnames']`` (presumably a
                category-name to count mapping - confirm with the grid builder).

        Returns:
            Tuple[pd.DataFrame, List[str]]:
                The attribute table (one row per target, plus an ``'index'``
                column holding the target's position) and the names of the
                columns whose cells are multi-label collections.
        """
        video_ids = [v['video_id'] for v in new_sample_grid['targets']]
        video_names = self._get_video_names(video_ids)
        region_names = self._get_region_names(video_names)
        # Category histograms were precomputed while building the grid, which
        # is cheaper than ``self._get_observed_annotations``.
        observed_catfreq = [v['annot_info']['main_gid_catnames'] for v in new_sample_grid['targets']]
        observed_annot = [len(f) > 0 for f in observed_catfreq]

        column_attrs = {
            'video_id': video_ids,
            'video_name': video_names,
            'region': region_names,
            'contains_annotation': observed_annot,
            'class': observed_catfreq,
        }

        from kwutil import util_environ
        # Transition away from SMART-specific processing, but keep a way to
        # reintroduce it if needed.
        SMART_PHASE_COMPAT = util_environ.envflag('SMART_PHASE_COMPAT', 0)
        if SMART_PHASE_COMPAT:
            # ``heuristics`` is only imported locally elsewhere in this module,
            # so it must be imported here too to avoid a NameError.
            from kwcoco_dataloader import heuristics
            # Hard coded heuristic attributes for the particular problem.
            # This will eventually be removed in favor of a more flexible
            # configuration.
            phase_names = set(heuristics.PHASES)
            observed_phases = [
                ub.dict_subset(x, phase_names.intersection(x.keys()))
                for x in observed_catfreq
            ]
            # associate target window with positive / negative
            # TODO: we should deprecate 'positive-indexes' and instead make it
            # something like "indexes-with-annots"
            column_attrs['contains_phase'] = [any(x) for x in observed_phases]
            column_attrs['phases'] = observed_phases

        if self.BACKWARDS_COMPAT_NEG_TO_POS:
            # To maintain compatibility with old neg_to_pos_ratio build an
            # indicator array that flags the samples the prev v0.17 code
            # considered as positive / negative. We will eventually remove
            # this logic. Samples were previously considered as negative if
            # they had no annotations OR the only annotations were
            # heuristically marked as negative.
            old_has_class_of_interest = [
                len(ub.udict.difference(f, self._old_balance_as_negative_classes)) > 0
                for f in observed_catfreq
            ]
            column_attrs['old_has_class_of_interest'] = old_has_class_of_interest

        # build a dataframe with target attributes
        df_sample_attributes = pd.DataFrame(column_attrs).reset_index(drop=False)

        # Mark which attributes are multi-label
        multilabel_attributes = []
        if SMART_PHASE_COMPAT:
            multilabel_attributes.append('phases')
        multilabel_attributes.append('class')

        return df_sample_attributes, multilabel_attributes

    def _init_balance(self, new_sample_grid):
        """
        Build data structure used for balanced sampling.

        Helper for __init__ which constructs a BalancedSampleTree (or a
        BalancedSampleForest when subdividing on a multi-label attribute with
        ``num_balance_trees > 1``) to balance sampling across input domains,
        and stores it in ``self.balanced_sampler``. When ``balance_options``
        is ``'sequential_without_replacement'`` the sampler is set to None.

        Each entry of ``balance_options`` is a dict with an ``'attribute'``
        key and optional ``'weights'`` / ``'default_weight'`` keys. A bare
        attribute-name string is accepted as shorthand for
        ``{'attribute': name}``.

        Args:
            new_sample_grid (Dict): the sample grid with a ``'targets'`` list.
        """
        if self.config['balance_options'] == 'sequential_without_replacement':
            self.balanced_sampler = None
            return

        import kwutil
        from kwutil import util_environ
        # Balance options are specified as an ordered list of the properties we
        # want to balance over, which can contain optional information about
        # how to do balancing.
        balance_options = kwutil.Yaml.coerce(self.config['balance_options'])

        if balance_options is None:
            balance_options = []
        elif isinstance(balance_options, (str, dict)):
            # A single option was given; treat it as a one-item list.
            balance_options = [balance_options]
        balance_options = [
            {'attribute': opt} if isinstance(opt, str) else opt
            for opt in balance_options
        ]

        if self.BACKWARDS_COMPAT_NEG_TO_POS:
            # If the old neg_to_pos_ratio config option is given, then add a
            # new balance option to the list that reconstructs it.
            npr = self.config['neg_to_pos_ratio']
            if npr is not None:
                npr_dist = np.asarray([1, npr]) / (1 + npr)
                balance_options = [{
                    'attribute': 'old_has_class_of_interest',
                    'weights': {
                        True: npr_dist[0],
                        False: npr_dist[1],
                    }
                }] + balance_options

        print('⚖️ - Balancing over attributes')
        df_sample_attributes, multilabel_attributes = self._setup_attribute_dataframe(new_sample_grid)
        sample_grid = df_sample_attributes.to_dict('records')
        balance_attrs = [d['attribute'] for d in balance_options]
        has_multilabel_attributes = set(balance_attrs) & set(multilabel_attributes)

        # Initialize an instance of BalancedSampleTree
        # NOTE: a fresh unseeded rng; reseeding happens below if configured.
        rng = kwarray.ensure_rng(rng=None)
        if has_multilabel_attributes and self.config['num_balance_trees'] > 1:
            # If we are going to subdivide on multi-label attributes we want to
            # use a forest instead of tree.
            print('Constructing balance forest 🌲🌳🌲🌳')
            balanced_sampler = balanced_sampling.BalancedSampleForest(
                sample_grid, rng=rng, n_trees=self.config['num_balance_trees'])
        else:
            print('Constructing balance tree 🌲')
            balanced_sampler = balanced_sampling.BalancedSampleTree(
                sample_grid, rng=rng)

        for balance_option in balance_options:
            print(f'Subdivide with balance_option = {ub.urepr(balance_option, nl=1)}')
            key = balance_option['attribute']
            key_weights = balance_option.get('weights', None)
            default_weight = balance_option.get('default_weight', 0)
            balanced_sampler.subdivide(key=key, weights=key_weights,
                                       default_weight=default_weight)

        # TODO: make this a config option
        REPORT_BALANCE = util_environ.envflag('REPORT_BALANCE', 0)
        if REPORT_BALANCE:
            from collections import Counter
            # Reporting for debugging
            targets = new_sample_grid['targets']

            num_targets = len(targets)
            num_samples = min(10_000, num_targets)

            # Scales the naive (whole-grid) counts to the number of balanced
            # draws so the columns are comparable; guard the empty grid.
            normalizer = num_samples / num_targets if num_targets else 0.0
            print('Report Balance ⚖️  :')
            print(f'num_targets={num_targets}')
            print(f'num_samples={num_samples}')
            print(f'normalizer={normalizer}')
            print(ub.paragraph(
                '''
                Note: you can ctrl+c to skip this step while it is running.
                Or you can set the environment variable ``REPORT_BALANCE=0`` to
                prevent it from running at all.
                '''))

            try:
                sampled_idxs = [balanced_sampler.sample() for _ in ub.ProgIter(range(num_samples), desc='sample')]

                # Inspect the attributes you balanced over and compare to the
                # naive case.
                naive_targets = df_sample_attributes.copy()
                balanced_targets = naive_targets.iloc[sampled_idxs]
                for attr in balance_attrs:
                    if attr in multilabel_attributes:
                        # Multi-label cells hold collections; Counter.update
                        # accumulates each row's entries.
                        naive = Counter()
                        for row in naive_targets[attr]:
                            naive.update(row)
                        balanced = Counter()
                        for row in balanced_targets[attr]:
                            balanced.update(row)
                    else:
                        naive = naive_targets.value_counts(attr)
                        balanced = balanced_targets.value_counts(attr)
                    freq_table = pd.DataFrame({'balanced': balanced, 'naive': naive})
                    freq_table = freq_table.sort_values('balanced', ascending=False)
                    freq_table['naive'] *= normalizer
                    print('--- Balance Report ---')
                    print(f'attr={attr}')
                    print(freq_table.to_string())
            except KeyboardInterrupt:
                print('Caught keyboard interrupt. Skipping balance report')

        self.balanced_sampler = balanced_sampler
        if self.config['reseed_fit_random_generators']:
            self.reseed()


class PreprocessMixin:
    """
    Methods related to dataset preprocessing
    """

    def cached_dataset_stats(self, num=None, num_workers=0, batch_size=2,
                             with_intensity=True, with_class=True):
        """
        Compute the dataset statistics (e.g. for input normalization) and
        cache them on disk.

        The cache is keyed on the dataset hash, the input sensor/channel
        spec, normalization-related config options, and the arguments to
        this function. Pass ``--force-recompute-stats`` on the command line
        to ignore an existing cache entry.

        Args:
            num (int | None): number of input items to compute stats for;
                forwarded to :func:`compute_dataset_stats`.
            num_workers (int): forwarded to :func:`compute_dataset_stats`.
            batch_size (int): forwarded to :func:`compute_dataset_stats`.
            with_intensity (bool): forwarded to :func:`compute_dataset_stats`
                and included in the cache key.
            with_class (bool): forwarded to :func:`compute_dataset_stats` and
                included in the cache key.

        Returns:
            Dict: the (possibly cached) dataset statistics.

        TODO:
            - [ ] Does this dataset have access to the workdir?
            - [ ] Cacher needs to depend on any part of the config of this
                  dataset that could impact the pixel intensity distribution.
        """
        # Get stats on the dataset (todo: nice way to disable augmentation temporarily for this)
        depends = ub.odict([
            ('num', num),
            ('hashid', self.sampler.dset._cached_hashid()),
            ('sensorchan', self.input_sensorchan.concise().spec),
            ('normalize_perframe', self.config['normalize_perframe']),  # deprecate
            ('with_intensity', with_intensity),
            ('with_class', with_class),
            ('prenormalizers', self.prenormalizers),
            ('depends_version', 22),  # bump if `compute_dataset_stats` changes
        ])
        if self.config['normalize_peritem']:
            # deprecate
            depends['normalize_peritem'] = self.config['normalize_peritem'].concise().spec
        if self.config['robust_normalize']:
            depends['robust_normalize'] = self.config['robust_normalize']
        workdir = None
        print('📊 Gather dataset stats')
        cacher = ub.Cacher('dset_mean', dpath=workdir,
                           appname='kwcoco_dataset/dataset_stats',
                           depends=depends, verbose=3)
        dataset_stats = cacher.tryload()
        if dataset_stats is None or ub.argflag('--force-recompute-stats'):
            # Forward the stat flags so the computed result actually matches
            # the cache key built from them above (they were previously
            # dropped, so full stats were always computed).
            dataset_stats = self.compute_dataset_stats(
                num, num_workers=num_workers, batch_size=batch_size,
                with_intensity=with_intensity, with_class=with_class)
            cacher.save(dataset_stats)
        print(f'dataset_stats = {ub.urepr(dataset_stats, nl=1)}')
        return dataset_stats

    def compute_dataset_stats(self, num=None, num_workers=0, batch_size=2,
                              with_intensity=True, with_class=True,
                              with_vidid=True):
        """
        Args:
            num (int | None): number of input items to compute stats for

        CommandLine:
            xdoctest -m kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset KWCocoVideoDataset.compute_dataset_stats:2

        Example:
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
            >>> import kwcoco
            >>> dct_dset = coco_dset = kwcoco.CocoDataset.demo('vidshapes2-multispectral', num_frames=3)
            >>> self = KWCocoVideoDataset(dct_dset, time_dims=2, window_dims=(256, 256), channels='auto')
            >>> self.compute_dataset_stats(num_workers=0)

        Example:
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
            >>> import kwcoco
            >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes2')
            >>> self = KWCocoVideoDataset(coco_dset, time_dims=2, window_dims=(256, 256), channels='auto')
            >>> stats = self.compute_dataset_stats()
            >>> assert stats['class_freq']['star'] > 0 or stats['class_freq']['superstar'] > 0 or stats['class_freq']['eff'] > 0
            >>> #assert stats['class_freq']['background'] > 0

        Example:
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
            >>> import kwcoco_dataloader
            >>> from kwcoco_dataloader.tasks.fusion import datamodules
            >>> num = 1
            >>> datamodule = datamodules.KWCocoVideoDataModule(
            >>>     train_dataset='vidshapes-kwcoco_dataloader', window_dims=64, time_steps=3,
            >>>     num_workers=0, batch_size=3, channels='auto',
            >>>     normalize_inputs=num)
            >>> datamodule.setup('fit')
            >>> self = datamodule.torch_datasets['train']
            >>> coco_dset = self.sampler.dset
            >>> print({c.get('sensor_coarse') for c in coco_dset.images().coco_images})
            >>> print({c.channels.spec for c in coco_dset.images().coco_images})
            >>> num_workers = 0
            >>> batch_size = 6
            >>> s = (self.compute_dataset_stats(num=num))
            >>> print('s = {}'.format(ub.urepr(s, nl=3)))
            >>> stats1 = self.compute_dataset_stats(num=num, with_intensity=False)
            >>> stats2 = self.compute_dataset_stats(num=num, with_class=False)
            >>> stats3 = self.compute_dataset_stats(num=num, with_class=False, with_intensity=False)

        Ignore:
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
            >>> import kwcoco
            >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes2')
            >>> for img in coco_dset.imgs.values():
            ...     img['sensor_coarse'] = 'demo'  # hack in a sensor
            >>> self = KWCocoVideoDataset(coco_dset, time_dims=1, window_dims=(32, 32), channels='demo:(r|g,b,n)')
            >>> stats = self.compute_dataset_stats(batch_size=1)

        Ignore:
            import sys, ubelt
            sys.path.append(ubelt.expandpath('~/code/watch'))
            globals().update(xdev.get_func_kwargs(KWCocoVideoDataset.compute_dataset_stats))

        """
        num = num if isinstance(num, int) and num is not True else 1000
        if not with_class and not with_intensity:
            num = 1  # efficiency hack

        print('🖥️ 📊 Compute dataset stats')
        stats_idxs = kwarray.shuffle(np.arange(len(self)), rng=0)[0:min(num, len(self))]
        stats_subset = torch.utils.data.Subset(self, stats_idxs)

        # Hack: disable augmentation if we are doing that
        self.disable_augmenter = True

        loader = self.make_loader(subset=stats_subset, num_workers=num_workers,
                                  shuffle=True, batch_size=batch_size)

        # Track moving average of each fused channel stream
        norm_stats = ub.ddict(lambda: kwarray.RunningStats(nan_policy='omit'))

        classes = self.predictable_classes
        num_classes = len(classes)
        bins = np.arange(num_classes + 1)
        total_freq = np.zeros(num_classes, dtype=np.int64)

        sensor_mode_hist = ub.ddict(int)
        video_id_histogram = ub.ddict(int)
        image_id_histogram = ub.ddict(int)

        # Make a list of all unique modes in the dataset.
        # User specifies all of this explicitly now
        unique_sensor_modes = set(
            (s.sensor.spec, s.chans.spec)
            for s in self.input_sensorchan.streams())

        print('unique_sensor_modes = {}'.format(ub.urepr(unique_sensor_modes, nl=1)))
        intensity_dtype = np.float64

        # Ensure instance level frequency data in addition to pixel level
        USE_INSTANCE_LEVEL_CLASS_STATS = 1
        if USE_INSTANCE_LEVEL_CLASS_STATS:
            annots = self.sampler.dset.annots()
            track_ids = annots.lookup('track_id', None)
            cnames = annots.cnames
            trackid_to_cnames = ub.udict(ub.group_items(cnames, track_ids))
            trackid_to_cnames = trackid_to_cnames.map_values(set)
            track_classes = list(ub.flatten(trackid_to_cnames.values()))
            annot_class_freq = ub.udict(ub.dict_hist(cnames)).sorted_keys()
            track_class_freq = ub.udict(ub.dict_hist(track_classes)).sorted_keys()
            print('annot_class_freq = {}'.format(ub.urepr(annot_class_freq, nl=1)))
            print('track_class_freq = {}'.format(ub.urepr(track_class_freq, nl=1)))
        else:
            track_class_freq = None
            annot_class_freq = None

        def current_input_stats():
            """
            Summarizes current stats estimates either for display or for the
            final output.
            """
            modality_input_stats = {}
            for modality, running in norm_stats.items():
                # ensure we have the expected shape
                try:
                    perchan_stats = running.summarize(axis=ub.NoParam, keepdims=True)
                except RuntimeError:
                    perchan_stats = {
                        'mean': np.array([np.nan]),
                        'std': np.array([np.nan]),
                        'min': np.array([np.nan]),
                        'max': np.array([np.nan]),
                        'n': np.array([np.nan]),
                    }
                chan_mean = perchan_stats['mean']
                chan_std = perchan_stats['std']
                chan_min = perchan_stats['min']
                chan_max = perchan_stats['max']
                chan_num = perchan_stats['n']

                # For nans, set the mean to zero and set the std to a huge
                # number if we dont have any data on it. That will prevent
                # the network from doing much with it which is really the
                # best we can do here.
                chan_mean[np.isnan(chan_mean)] = 0
                chan_std[np.isnan(chan_std)] = 1e8

                chan_mean = chan_mean.round(6)
                chan_std = chan_std.round(6)
                # print('perchan_stats = {}'.format(ub.urepr(perchan_stats, nl=1)))
                modality_input_stats[modality] = {
                    'mean': chan_mean,
                    'std': chan_std,
                    'min': chan_min,
                    'max': chan_max,
                    'n': chan_num,
                }

            # We are now computing input stats across a finer set of modality
            # variables. For backwards compatibility also return the old-style
            # input stats that are only over sensor/channel
            grouped_stats = ub.group_items(modality_input_stats.values(), [
                (u.sensor, u.channels) for u in modality_input_stats.keys()])
            old_input_stats = {}
            for (sensor, chan), group in grouped_stats.items():
                means = np.stack([m['mean'] for m in group], axis=0)
                stds = np.stack([m['std'] for m in group], axis=0)
                nums = np.stack([m['n'] for m in group], axis=0)
                maxs = np.stack([m['max'] for m in group], axis=0)
                mins = np.stack([m['min'] for m in group], axis=0)

                combo_mins = np.nanmin(mins, axis=0)
                combo_maxs = np.nanmax(maxs, axis=0)
                combo_mean, combo_std, combo_nums = util_kwarray.combine_mean_stds(
                    means, stds, nums, axis=0)
                combo = {
                    'mean': combo_mean[:, None, None],
                    'std': combo_std[:, None, None],
                    'min': combo_mins[:, None, None],
                    'max': combo_maxs[:, None, None],
                    'n': combo_nums[:, None, None],
                }
                old_input_stats[(sensor, chan)] = combo
            return modality_input_stats, old_input_stats

        def update_displayed_estimates(pman):
            """
            Build an intermediate summary to display in the progress bar
            """
            stat_lines = ['Current Estimated Dataset Statistics: ']
            from kwutil.slugify_ext import smart_truncate
            if with_intensity:
                modality_input_stats, old_input_stats = current_input_stats()
                input_stats2 = {}
                for mode, stats in old_input_stats.items():
                    sensor, channels = mode
                    sensorchan = SensorChanSpec.coerce(f'{sensor}:{channels}')
                    key = sensorchan.concise().spec
                    inner_stats = {}
                    for statname, arr in stats.items():
                        if statname == 'n':
                            arr_str = ub.urepr(arr.ravel().tolist(), nl=0, precision=0)
                        else:
                            arr_str = ub.urepr(arr.ravel().tolist(), nl=0, precision=4)
                        arr_str = smart_truncate(arr_str, max_length=120, trunc_loc=0.5, head='...~', tail='~...')
                        inner_stats[statname] = arr_str
                    if sensorchan.chans.numel() == 1:
                        input_stats2[key] = ub.urepr(inner_stats, sv=1, nl=0)
                    else:
                        input_stats2[key] = ub.urepr(inner_stats, sv=1)
                spectra_stats_text = ub.urepr(input_stats2, sv=1)
                intensity_info_text = 'Spectra Stats: ' + spectra_stats_text
                stat_lines.append(intensity_info_text)
            if with_class:
                class_stats = ub.sorted_vals(ub.dzip(classes, total_freq), reverse=True)
                class_info_text = 'Class Stats: ' + ub.urepr(class_stats)
                stat_lines.append(class_info_text)
            stat_lines.append('Unique Image Samples: {}'.format(len(image_id_histogram)))
            stat_lines.append('Unique Video Samples: {}'.format(len(video_id_histogram)))
            info_text = '\n'.join(stat_lines).strip()
            if info_text:
                # TODO: cleanup once new kwutil is released
                if (hasattr(pman, 'is_rich') and pman.is_rich) or hasattr(pman.backend, 'setup_rich'):
                    from rich.markup import escape
                    info_text = escape(info_text)
                pman.update_info(info_text)

        def update_intensity_estimates(item):
            # Update pixel-level intensity histogram
            domain = item.get('domain', None)
            for frame_item in item['frames']:
                sensor_code = frame_item.get('sensor', None)
                modes = frame_item['modes']

                for mode_code, mode_val in modes.items():
                    # FIXME: we need to avoid using any custom class
                    # in dataset stats because it might be pickled.
                    # We should use raw python and numpy / torch types only.
                    modality = Modality(sensor_code, mode_code, domain)

                    sensor_mode_hist[(sensor_code, mode_code)] += 1
                    running = norm_stats[modality]
                    val = mode_val.numpy().astype(intensity_dtype)
                    weights = np.isfinite(val).astype(intensity_dtype)

                    # Put channels last so we can update multiple at once
                    flat_vals = val.transpose(1, 2, 0).reshape(-1, val.shape[0])
                    flat_weights = weights.transpose(1, 2, 0).reshape(-1, weights.shape[0])
                    running.update_many(flat_vals, weights=flat_weights)

        def update_stats(item, total_freq):
            if with_vidid:
                vidid = item.get('video_id', None)
                video_id_histogram[vidid] += 1

            for frame_item in item['frames']:
                image_id_histogram[frame_item.get('gid', None)] += 1
                if with_class:
                    # Update pixel-level class histogram
                    # TODO: prefer class-ohe if available
                    class_idxs = frame_item['class_idxs']
                    if class_idxs is not None:
                        item_freq = np.histogram(class_idxs.ravel(), bins=bins)[0]
                        total_freq += item_freq

            if with_intensity:
                update_intensity_estimates(item)

        from kwutil import util_progress
        from kwutil import util_environ
        pman = util_progress.ProgressManager()
        # pman = util_progress.ProgressManager('progiter')

        # Create timer to periodically summarize intermediate results while
        # full dataset stats are accumulating
        timer = ub.Timer().tic()
        timer._first = True
        timer.postfix_update_threshold = 5  # seconds

        # TODO: we can compute the intensity histogram more efficiently by
        # only doing it for unique channels (which might be duplicated)
        with pman, warnings.catch_warnings():
            warnings.filterwarnings('ignore', 'invalid value encountered in true_divide', category=RuntimeWarning)
            prog = pman.progiter(loader, desc='estimate dataset stats', verbose=1)
            iter_ = iter(prog)

            for batch_items in iter_:

                for item in batch_items:
                    if item is None:
                        continue
                    update_stats(item, total_freq)

                if timer._first or timer.toc() > timer.postfix_update_threshold:
                    update_displayed_estimates(pman)
                    timer._first = 0
                    timer.tic()

            update_displayed_estimates(pman)

            # TODO: we should ensure we include at least one sample from each type
            # of modality.  Note: the requested order of the channels could be
            # different that what is registered in the dataset. Need to find a good
            # way to account for this.
            MISSING_SENSOR_FALLBACK = util_environ.envflag('MISSING_SENSOR_FALLBACK', 1)
            if MISSING_SENSOR_FALLBACK and with_intensity:
                missing_sensor_modes = set(unique_sensor_modes) - set(sensor_mode_hist)
                # Try to find a few examples with these missing modes
                if missing_sensor_modes:
                    print(f'Warning: we are missing stats for {missing_sensor_modes}. '
                          'We will try to force something for them')
                    coco_images = self.sampler.dset.images().coco_images
                    sensor_to_images = ub.group_items(
                        coco_images,
                        key=lambda x: x.img.get('sensor_coarse', x.img.get('sensor', None))
                    )
                    extra_sample_groups = []
                    for sensor, mode in missing_sensor_modes:
                        candidate_images = sensor_to_images.get(sensor, [])
                        if len(candidate_images) == 0:
                            print(f'sensor warning: unable to sample data for {sensor}:{mode}')
                        else:
                            filtered = []
                            for img in candidate_images:
                                if (img.channels & mode).numel():
                                    filtered.append(img)
                            if not filtered:
                                print(f'mode warning: unable to sample data for {sensor}:{mode}')
                            extra_sample_groups.append(filtered)

                    # Build extra fallback samples
                    for group in ub.ProgIter(extra_sample_groups, desc='process fallbacks'):
                        image_ids = [g.img['id'] for g in group]
                        images = self.sampler.dset.images(image_ids)
                        vidid_to_gids = ub.group_items(image_ids, images.lookup('video_id'))
                        for vidid, gids in vidid_to_gids.items():
                            video = self.sampler.dset.index.videos[vidid]
                            # Hack: just use the entire video, if that fails we should
                            # implement windowing here.
                            space_slice = (
                                slice(0, video['height']),
                                slice(0, video['width']),
                            )
                            sample = {'video_id': vidid, 'gids': gids[0:1],
                                      'space_slice': space_slice}
                            item = self[sample]
                            if item is not None:
                                update_stats(item, total_freq)

        self.disable_augmenter = False

        # Return the raw counts and let the model choose how to handle it
        if with_class:
            class_freq = ub.dzip(classes, total_freq)
        else:
            class_freq = None

        if with_intensity:
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', 'invalid value encountered in true_divide', category=RuntimeWarning)
                modality_input_stats, old_input_stats = current_input_stats()
        else:
            # modality_input_stats = None
            old_input_stats = None

        dataset_stats = {
            'unique_sensor_modes': unique_sensor_modes,
            'sensor_mode_hist': dict(sensor_mode_hist),
            'input_stats': old_input_stats,
            'class_freq': class_freq,  # pixelwise

            # 'modality_input_stats': modality_input_stats,  # new, might be too much information

            'annot_class_freq': annot_class_freq,
            'track_class_freq': track_class_freq,
            # 'video_id_histogram': dict(video_id_histogram),
        }
        return dataset_stats


class MiscMixin:
    """
    Miscellaneous helpers used by :class:`KWCocoVideoDataset`: RNG reseeding,
    predictable-class bookkeeping, task notification, demo output generation,
    and dataloader construction.

    TODO: better groups
    """

    def reseed(self, rng='auto'):
        """
        Reinitialize the random number generator of the balanced sampler.

        This is a no-op when the dataset has no ``balanced_sampler``.

        Args:
            rng (str | int | RandomState):
                If 'auto', a seed is built by XOR-ing a fresh random draw, a
                hash of the ``LOCAL_RANK`` environment variable, and
                ``secrets.randbits`` plus the current time, so that different
                DDP ranks get different streams. Any other value is passed
                directly to ``balanced_sampler.reseed``.

        TODO:
            HELP WANTED: Lack of determinism likely comes from this module and
            the order it gives data to predict. It would be very nice if we
            could fix that.
        """
        # Randomize across DDP workers
        if getattr(self, 'balanced_sampler', None) is not None:
            if rng == 'auto':
                rng = kwarray.ensure_rng(rng=None)
                import secrets
                import time
                # Really try to be random
                rng_seed = rng.randint(0, int(2 ** 31 - 2))
                rank_seed = int(ub.hash_data(int(os.environ.get('LOCAL_RANK', '0')), base=10)[0:9])
                secret_seed = secrets.randbits(22) + int(time.time())
                seed = secret_seed ^ rank_seed ^ rng_seed
                rng = kwarray.ensure_rng(rng=seed)
            self.balanced_sampler.reseed(rng)

    @property
    def coco_dset(self):
        """
        The underlying kwcoco dataset (``self.sampler.dset``).
        """
        return self.sampler.dset

    def _setup_predictable_classes(self, predictable_classes: list):
        """
        Set the classes that a model is expected to predict.

        Currently called twice, once on the original dataset construction, and
        again if the dataset is notified about a model / tasks.

        Args:
            predictable_classes (list): class names, coerced into a
                ``kwcoco.CategoryTree``. Every name must exist in
                ``self.classes`` (it is looked up in ``node_to_idx``).

        Side effects:
            Sets ``predictable_classes``, ``num_predictable_classes``, and
            ``dataset_class_idx_to_predictable_class_idx``. When the
            ``default_class_behavior`` config is 'background', it also sets
            ``bg_idx``. Finally, it ensures the predictable classes have colors.
        """
        self.predictable_classes = kwcoco.CategoryTree.coerce(predictable_classes)
        self.num_predictable_classes = len(self.predictable_classes)

        # Map from the "dataset" classes to the "predictable" class indexes.
        self.dataset_class_idx_to_predictable_class_idx = {
            self.classes.node_to_idx[class_name]: self.predictable_classes.node_to_idx[class_name]
            for class_name in self.predictable_classes
        }

        if self.config['default_class_behavior'] == 'background':
            # Ensure that predictable classes updates bg_idx (which is a hacky
            # construct that should be removed)
            predictable_bg_classes = set(self.background_classes) & set(self.predictable_classes)
            assert len(predictable_bg_classes) > 0, 'need to have at least 1 background predictable class'
            # Sorting makes the choice of background class deterministic.
            bg_catname = ub.peek(sorted(predictable_bg_classes))
            self.bg_idx = self.predictable_classes.node_to_idx[bg_catname]

        heuristics.category_tree_ensure_color(self.predictable_classes)

    def _notify_about_tasks(self, requested_tasks=None, model=None, predictable_classes=None):
        """
        Tell the dataset which tasks it needs to generate data for.

        This is a hacky method. Given the multimodal model, it tells every
        dataset which tasks to produce targets for, which also makes the
        visualizations cleaner.

        Args:
            requested_tasks (None | Dict[str, bool]):
                Task flags merged into ``self.requested_tasks``. Must be None
                when ``model`` is given.

            model (None | object):
                If given, tasks are inferred from ``model.global_head_weights``
                (a head is requested when its weight is positive), and
                ``model.predictable_classes`` is used when it exists.

            predictable_classes (None | List[str]):
                Classes the model can predict. These are only applied when
                'class' is a key of ``self.requested_tasks``.

        Raises:
            ValueError: if ``predictable_classes`` contains names that are not
                in ``self.classes``.

        TODO:
            Come up with a better protocol for a model to notify the dataloader
            about what it wants. Always let the user override this, but maybe
            the model can warn if the user doesn't give it all of the things it
            thinks it will want?
        """
        if model is not None:
            assert requested_tasks is None, 'requested tasks should be none'
            if hasattr(model, 'global_head_weights'):
                requested_tasks = {k: w > 0 for k, w in model.global_head_weights.items()}
            else:
                # Task flags can only be inferred from the head weights, so
                # warn when they are missing (not when classes are missing).
                warnings.warn(ub.paragraph(
                    f'''
                    Model {model.__class__} does not have the structure needed
                    to notify the dataset about tasks. A better design to make
                    specifying tasks easier is needed without relying on the
                    ``global_head_weights``.
                    '''))
            if hasattr(model, 'predictable_classes'):
                predictable_classes = model.predictable_classes
        print(f'dataset notified: requested_tasks={requested_tasks}, predictable_classes={predictable_classes}')
        if requested_tasks is not None:
            self.requested_tasks.update(requested_tasks)

        # NOTE(review): this tests key presence rather than the flag's value,
        # so classes are set up even if the class task is disabled. Confirm
        # that this is intended.
        if ('class' in self.requested_tasks) and (predictable_classes is not None):
            unknown_classes = set(predictable_classes) - set(self.classes)
            if unknown_classes:
                raise ValueError(
                    'predictable_classes must be a subset of the dataset '
                    f'classes, but these are unknown: '
                    f'{sorted(unknown_classes, key=str)}. '
                    f'Dataset classes: {list(self.classes)}')
            self._setup_predictable_classes(predictable_classes)

    def _build_demo_outputs(self, item, rng=None):
        """
        Construct dummy outputs that we would expect a network to generate.

        Args:
            item (Dict): a batch item. This reads the optional
                ``item['target']['fliprot_params']`` and each frame's
                ``output_dims``.
            rng (None | int | RandomState): seed or random state passed to
                ``kwarray.ensure_rng``. The default None (unseeded) matches
                the previous behavior.

        Returns:
            Dict: with keys:
                ``change_probs``: stacked random change heatmaps, one per
                frame after the first. This is an empty leading-dimension
                stack for single-frame items.
                ``class_probs``: stacked random class heatmaps, one per frame,
                channels last.
                ``pred_ltrb``: a list with one random ltrb box array per frame.

        Note:
            The ability to construct this method is a motivating factor behind
            the design decision that "a batch item should describe what its
            expected output should look like".
        """
        fliprot_params = item['target'].get('fliprot_params', None)
        rng = kwarray.ensure_rng(rng)
        frames = item['frames']
        item_output = {}

        # Generate random predicted change probabilities for each frame. The
        # first frame has no predecessor, so it does not have change.
        change_prob_list = []
        for frame in frames[1:]:
            change_prob = kwimage.Heatmap.random(
                dims=frame['output_dims'], classes=1, rng=rng).data['class_probs'][0]
            if fliprot_params:
                change_prob = data_utils.fliprot(change_prob, **fliprot_params)
            change_prob_list.append(change_prob)
        if change_prob_list:
            change_probs = np.stack(change_prob_list)
        else:
            # Single-frame items have no change predictions, and np.stack
            # raises on an empty list, so emit an empty (0, H, W) stack.
            # NOTE(review): float32 is assumed to match the random heatmaps.
            first_dims = tuple(frames[0]['output_dims']) if frames else (0, 0)
            change_probs = np.zeros((0,) + first_dims, dtype=np.float32)
        item_output['change_probs'] = change_probs

        # Generate random predicted class probabilities for each frame
        class_prob_list = []
        frame_pred_ltrb_list = []
        for frame in frames:
            class_prob = kwimage.Heatmap.random(
                dims=frame['output_dims'], classes=list(self.classes), rng=rng).data['class_probs']
            class_prob_ = einops.rearrange(class_prob, 'c h w -> h w c')
            if fliprot_params:
                class_prob_ = data_utils.fliprot(class_prob_, **fliprot_params)
            class_prob_list.append(class_prob_)
            # Also generate a predicted box for each frame
            frame_output_dsize = frame['output_dims'][::-1]
            num_pred_boxes = rng.randint(0, 8)
            pred_boxes = kwimage.Boxes.random(num_pred_boxes).scale(frame_output_dsize)
            frame_pred_ltrb_list.append(pred_boxes.to_ltrb().data)
        item_output['class_probs'] = np.stack(class_prob_list)
        item_output['pred_ltrb'] = frame_pred_ltrb_list
        return item_output

    def make_loader(self, subset=None, batch_size=1, num_workers=0, shuffle=False,
                    pin_memory=False, collate_fn='identity'):
        """
        Use this to make the dataloader so that it always uses this module's
        ``worker_init_fn``.

        Args:
            subset (None | Dataset): if specified, the loader is made for
                this dataset instead of ``self``.

            batch_size (int): number of items per batch.

            num_workers (int): number of dataloader worker processes.

            shuffle (bool): if True, shuffle the dataset order.

            pin_memory (bool): forwarded to ``torch.utils.data.DataLoader``.

            collate_fn (callable | str | None):
                Can be 'identity' (or None), 'stack' / 'torch-default', or a
                callable. The normal torch default is 'stack', but to support
                heterogeneous batch items we default to 'identity'.

        Returns:
            torch.utils.data.DataLoader

        Raises:
            KeyError: if ``collate_fn`` is an unknown string.

        Example:
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
            >>> import kwcoco
            >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes2-multispectral', num_frames=5)
            >>> self = KWCocoVideoDataset(coco_dset, time_dims=3, window_dims=(530, 610), channels='auto')
            >>> loader = self.make_loader(batch_size=2)
            >>> batch = next(iter(loader))
        """
        dataset = self if subset is None else subset

        if collate_fn is None:
            collate_fn = ub.identity
        elif isinstance(collate_fn, str):
            if collate_fn == 'identity':
                collate_fn = ub.identity
            elif collate_fn in {'stack', 'torch-default'}:
                import torch.utils.data as torch_data
                collate_fn = torch_data.dataloader.default_collate
            else:
                raise KeyError(collate_fn)

        loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=num_workers,
            shuffle=shuffle, pin_memory=pin_memory,
            worker_init_fn=worker_init_fn,
            collate_fn=collate_fn,
        )
        return loader


class BackwardCompatMixin:
    """
    Backwards compatibility for modified properties.
    (These may eventually be deprecated).
    """

    @property
    def new_sample_grid(self):
        """
        Deprecated alias for ``sample_grid``.

        Accessing this property emits a deprecation warning and then returns
        ``self.sample_grid`` unchanged.
        """
        # Called without arguments, schedule_deprecation can only describe an
        # anonymous "?" item, and older ubelt releases require ``modname``.
        # Name what is deprecated and how to migrate. No version thresholds
        # are given, so this only warns and never escalates to an error.
        ub.schedule_deprecation(
            'kwcoco_dataloader', 'new_sample_grid', 'property',
            'use sample_grid instead')
        return self.sample_grid


[docs] class KWCocoVideoDataset(data.Dataset, GetItemMixin, BalanceMixin, PreprocessMixin, IntrospectMixin, MiscMixin, SpacetimeAugmentMixin, BackwardCompatMixin, SMARTDataMixin): """ Accepted keyword arguments are specified in :class:`KWCocoVideoDatasetConfig` Example: >>> # Native Data Sampling >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import * # NOQA >>> import ndsampler >>> import kwcoco >>> import kwcoco_dataloader >>> coco_dset = kwcoco_dataloader.coerce_kwcoco('kwcoco_dataloader-multisensor-msi', geodata=True) >>> print({c.get('sensor_coarse') for c in coco_dset.images().coco_images}) >>> print({c.channels.spec for c in coco_dset.images().coco_images}) >>> sampler = ndsampler.CocoSampler(coco_dset) >>> self = KWCocoVideoDataset(sampler, time_dims=4, window_dims=(100, 200), >>> input_space_scale='native', >>> window_space_scale='0.05GSD', >>> output_space_scale='native', >>> channels='auto', >>> ) >>> self.disable_augmenter = True >>> target = self.sample_grid['targets'][self.sample_grid['positives_indexes'][3]] >>> item = self[target] >>> canvas = self.draw_item(item, overlay_on_image=0, rescale=0, max_channels=3) >>> # xdoctest: +REQUIRES(--show) >>> import kwplot >>> kwplot.autompl() >>> kwplot.imshow(canvas) >>> kwplot.show_if_requested() Example: >>> # Target GSD Data Sampling >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import * # NOQA >>> import ndsampler >>> import kwcoco >>> import kwcoco_dataloader >>> coco_dset = kwcoco_dataloader.coerce_kwcoco('kwcoco_dataloader', geodata=True) >>> print({c.get('sensor_coarse') for c in coco_dset.images().coco_images}) >>> print({c.channels.spec for c in coco_dset.images().coco_images}) >>> sampler = ndsampler.CocoSampler(coco_dset) >>> self = KWCocoVideoDataset(sampler, window_dims=(100, 100), time_dims=5, >>> input_space_scale='0.35GSD', >>> window_space_scale='0.7GSD', >>> output_space_scale='0.2GSD', >>> channels='auto', >>> ) >>> self.disable_augmenter = True >>> 
index = self.sample_grid['targets'][self.sample_grid['positives_indexes'][3]] >>> Box = kwimage.Box >>> index['space_slice'] = Box.from_slice(index['space_slice']).translate((30, 0)).quantize().to_slice() >>> item = self[index] >>> #print('item summary: ' + ub.urepr(self.summarize_item(item), nl=3)) >>> canvas = self.draw_item(item, overlay_on_image=1, rescale=0, max_channels=3) >>> # xdoctest: +REQUIRES(--show) >>> import kwplot >>> kwplot.autompl() >>> kwplot.imshow(canvas) >>> kwplot.show_if_requested() """ __scriptconfig__ = KWCocoVideoDatasetConfig def __init__(self, sampler, mode='fit', test_with_annot_info=False, autobuild=True, **kwargs): """ Args: sampler (kwcoco.CocoDataset | ndsampler.CocoSampler): kwcoco dataset mode (str): fit or predict autobuild (bool): if False, defer potentially expensive initialization. In this case the user must call ``._init()`` **kwargs: see :class:`KWCocoVideoDatasetConfig` for valid options these options will be stored in the ``.config`` attribute. """ config = KWCocoVideoDatasetConfig(**kwargs) # note: sampler can be a ndsampler.CocoSampler or a kwcoco.CocoDataset if config.sampler_backend is None: sampler = ndsampler.CocoSampler.coerce(sampler) else: from kwutil import util_parallel sampler = ndsampler.CocoSampler.coerce( sampler, workdir=config.sampler_workdir, backend=config.sampler_backend) if autobuild: workers = util_parallel.coerce_num_workers(config.sampler_workers) sampler.frames.prepare(workers=workers) chip_dims = config['chip_dims'] if isinstance(chip_dims, str): window_dims = chip_dims else: if not ub.iterable(chip_dims): chip_dims = (chip_dims, chip_dims) chip_h, chip_w = chip_dims window_dims = (chip_h, chip_w) config['chip_dims'] = window_dims self.config = config rich.print('self.config = {}'.format(ub.urepr(self.config, nl=1))) # TODO: remove this line. Reduce the number of top-level attributes and # maintain initialization variables in the config object itself. 
_cfgdict = self.config.to_dict() self.__dict__.update(_cfgdict) # Make config a normal dictionary to reduce attribute lookup overhead self.config = _cfgdict self.sampler = sampler # Add extra categories if we need to and construct a new classes object graph = self.sampler.classes.graph # Update with heuristics # HACK: Overwrite kwcoco data for _catinfo in heuristics.CATEGORIES: name = _catinfo['name'] exists_flag = name in graph.nodes if not exists_flag and _catinfo.get('required'): graph.add_node(name, **_catinfo) if exists_flag: graph.nodes[name].update(**_catinfo) # from kwutil import util_yaml # positive_labels = util_yaml.Yaml.coerce(config.positive_labels) self.classes = kwcoco.CategoryTree(graph) heuristics.category_tree_ensure_color(self.classes) self.background_classes = set(heuristics.BACKGROUND_CLASSES) & set(graph.nodes) self.negative_classes = set(heuristics.NEGATIVE_CLASSES) & set(graph.nodes) self.ignore_classes = set(heuristics.IGNORE_CLASSNAMES) & set(graph.nodes) self.undistinguished_classes = set(heuristics.UNDISTINGUISHED_CLASSES) & set(graph.nodes) # construct composite classes # the idea is that these specific definitions will be configurable in the future self.non_salient_classes = self.background_classes | self.negative_classes self.salient_ignore_classes = self.ignore_classes # should we remove the ignore classes from salient_classes in the future? 
# yes self.salient_classes = set(self.classes) - (self.non_salient_classes | self.ignore_classes) # define foreground classes for the class activity head self.class_foreground_classes = set(self.classes) - ( self.background_classes | self.ignore_classes | self.undistinguished_classes) self._setup_predictable_classes(sorted(self.background_classes | self.class_foreground_classes)) self.BACKWARDS_COMPAT_NEG_TO_POS = self.config['neg_to_pos_ratio'] is not None self.disable_augmenter = False self.prenormalizers = None self.augment_rng = kwarray.ensure_rng(None) self.mode = mode self.test_with_annot_info = test_with_annot_info # Used for mutex style losses where there is no data that can be used # to label a pixel. # TODO: need to communicate this value to the loss function, but we # should probably design a clean method of communicating between the # dataset and model first. self.ignore_index = -100 self._init_sensorchan() self._init_robust_normalizers() # hidden option for now (todo: expose this) self.inference_only = False # TODO: better "notification of heads" specification and implementation # TODO: modify these names to be less ambiguous. if self.config['requested_tasks'] == 'auto': # for auto, just use defaults for now. _task_updates = {} else: _task_updates = kwutil.Yaml.coerce(self.config['requested_tasks']) self.requested_tasks = { 'change': True, # Note: this is sequential frame change segmentation. 'class': True, # Note: this is per-frame class segmentation. 'saliency': True, # Note: this is per-frame saliency segmentation. 'boxes': True, # Note: this is per-frame bbox detection. 'nonlocal_class': False, # each frame is assigned non-localized class labels. # outputs is not really a task, it requests the weights needed for # predict-time stitching. 'outputs': mode != 'fit', } self.requested_tasks.update(_task_updates) # Hacks: combinable channels can be visualized as RGB images. 
# The only reason this is a hack is because of the hardcoded names # otherwise it is a cool feature. self.default_combinable_channels = [ ub.oset(['red', 'green', 'blue']), ub.oset(['Dred', 'Dgreen', 'Dblue']), ub.oset(['r', 'g', 'b']), ub.oset(['impervious', 'forest', 'water']), ub.oset(['baren', 'field', 'water']), ub.oset(['landcover_hidden.0', 'landcover_hidden.1', 'landcover_hidden.2']), ub.oset(['sam.0', 'sam.1', 'sam.2']), ub.oset(['sam.3', 'sam.4', 'sam.5']), ] + heuristics.HUERISTIC_COMBINABLE_CHANNELS if autobuild: self._init() def _init_sensorchan(self): """ Part of initialization that coerces sensorchannel information if it is not provided. """ channels = self.config['channels'] if channels is None or channels == 'auto': # Find reasonable channel defaults if channels is not specified. # Use dataset stats to determine something sensible. print('Channels specified as auto, attempting to introspsect') sensorchan_hist = kwcoco_extensions.coco_channel_stats(self.sampler.dset)['sensorchan_hist'] parts = [] for sensor, chan_hist in sensorchan_hist.items(): for c in chan_hist.keys(): chancode = ChannelSpec.coerce(c).fuse().spec parts.append(f'{sensor}:{chancode}') sensorchans = ','.join(sorted(parts)) sensorchans = SensorChanSpec.coerce(sensorchans) print(f'Automatically determined sensorchans = {ub.urepr(sensorchans, nl=1)}') if len(sensorchan_hist) > 0 and channels is None: # Only warn if not explicitly in auto mode warnings.warn( 'Channels are unspecified, but the dataset has a complex ' 'set of channels with multiple sensors. ' 'Passing an explicit sensorchan spec (via the `channels` ' 'argument would be better.') else: # hack sensorchan_hist = None sensorchans = channels self.sensorchan = SensorChanSpec.coerce(sensorchans).normalize() # The user can specify "dynamic channels", which are computed from # loaded channels on the fly. 
dynamic_channels_spec = copy.deepcopy(kwutil.Yaml.coerce(self.config['dynamic_channels'])) dynamic_channel_names = [] if dynamic_channels_spec is not None: self._dynamic_channels = DynamicChannels(dynamic_channels_spec) dynamic_channel_names = self._dynamic_channels._channel_names else: self._dynamic_channels = None # handle generic * sensors, the idea is that we find matches # in the dataset that can support the requested channels. if '*' in [s.sensor.spec for s in self.sensorchan.streams()]: # handle * sensor in a way that works with previous models # This code is a little messy and should be cleaned up if sensorchan_hist is None: sensorchan_stats = kwcoco_extensions.coco_channel_stats(self.sampler.dset) sensorchan_hist = sensorchan_stats['sensorchan_hist'] expanded_input_sensorchan_streams = [] for fused_sensorchan in self.sensorchan.streams(): sensor = fused_sensorchan.sensor chans = fused_sensorchan.chans if sensor.spec == '*': non_dynamic_chans = chans - dynamic_channel_names for cand_sensor, cand_chans in sensorchan_hist.items(): for cand_chan_group in cand_chans: cand_chan_group = ChannelSpec.coerce(cand_chan_group).fuse() chan_isect = non_dynamic_chans & cand_chan_group if chan_isect.spec == non_dynamic_chans.spec: expanded_input_sensorchan_streams.append(f'{cand_sensor}:{chans.spec}') break else: expanded_input_sensorchan_streams.append('{}:{}'.format(sensor, chans)) if not expanded_input_sensorchan_streams: print('sensorchan_hist = {}'.format(ub.urepr(sensorchan_hist, nl=1))) raise ValueError(f'The generic sensor * was given, but no data in the kwcoco file matched sensorchan: {self.sensorchan}') self.sensorchan = SensorChanSpec.coerce(','.join( list(ub.unique(expanded_input_sensorchan_streams)))).normalize() # TODO: Clean up this code. _input_sensorchans = [] _sample_sensorchans = [] # Holds cases that need to be dynamically computed, perhaps factor into # DynamicChannels later. 
self._special_inputs = {} for fused_sensorchan in self.sensorchan.streams(): sensor = fused_sensorchan.sensor chans = fused_sensorchan.chans _stream = chans.as_oset() _sample_stream = _stream.copy() # TODO: Might be interesting to have a feature to compute # specialized bands on the fly. # special_bands = _stream & util_bands.SPECIALIZED_BANDS _input_sensorchan = sensor.spec + ':' + '|'.join(_stream) dynamic_chans = _stream & dynamic_channel_names if dynamic_chans: _sample_stream -= dynamic_chans # Sometimes the input channels will depened on other chanels # that can be sampled, but are not explicitly used in the input # so we need to extend the sample to load those. _other_req = sorted(set(ub.flatten(self._dynamic_channels._lut[n]['args'] for n in dynamic_chans))) _sample_stream = _sample_stream | _other_req _sample_sensorchan = sensor.spec + ':' + '|'.join(_sample_stream) self._special_inputs[_input_sensorchan] = { 'sample_sensorchan': _sample_sensorchan, 'input_sensorchan': _input_sensorchan, 'dynamic_chans': dynamic_chans, } else: _sample_sensorchan = _input_sensorchan _input_sensorchans.append(_input_sensorchan) _sample_sensorchans.append(_sample_sensorchan) #### New: input_sensorchan will replace input_channels self.sample_sensorchan = SensorChanSpec( ','.join(_sample_sensorchans) ) self.input_sensorchan = SensorChanSpec.coerce( ','.join(_input_sensorchans) ) def _init_robust_normalizers(self): """ Handle coercion of normalizer_peritem and normalize_perframe. SeeAlso: _robust_normalize_frame_items """ if self.config['normalize_peritem']: ub.schedule_deprecation( 'kwcoco_dataloader', 'normalize_peritem', 'param', 'use robust_normalize instead', deprecate='0.1.0', error='1.0.0', remove='1.1.0', ) # (FIXME:this probably should be extended to be a sensorchan...) 
if self.config['normalize_peritem'] is True: # If True, then normalize all known channels # FIXME: input config probably should not be modified outside # of the __post_init__, we can set any resolved config to an # internal variable instead of overwriting the user-specified # value. self.config['normalize_peritem'] = FusedChannelSpec.coerce( '|'.join(sorted(set(ub.flatten([ s.chans.to_list() for s in self.input_sensorchan.streams()]))))) else: normperitem_data = self.config['normalize_peritem'] # HACK: if isinstance(normperitem_data, list): normperitem_data = ','.join(normperitem_data) # Otherwise assume the user specified what channels to normalize self.config['normalize_peritem'] = ChannelSpec.coerce(normperitem_data).fuse() else: self.config['normalize_peritem'] = None if self.config['normalize_perframe'] is None: ub.schedule_deprecation( 'kwcoco_dataloader', 'normalize_perframe', 'param', 'use robust_normalize instead', deprecate='0.1.0', error='1.0.0', remove='1.1.0', ) if self.config['robust_normalize'] not in [None, False]: self.robust_normalizer = RobustNormalizer.coerce( self.config['robust_normalize'], default_sensorchan=self.input_sensorchan, ) else: self.robust_normalizer = None def _init(self): """ The expensive part of initialization that builds the sample grid based on the user input. 
""" self._build_sample_grid() self._build_prenormalizers() if False: # # HACK TO PUT ALL DATA INTO MEMORY # for gid in ub.ProgIter(self.sampler.dset.images(), desc='prepopulate imdata'): # coco_img = self.sampler.dset.coco_image(gid) # img = coco_img.img # imdata = coco_img.imdelay().finalize() # img['imdata'] = imdata # HACK TO PUT ALL DATA INTO MEMORY bundle_dpath = ub.Path(self.sampler.dset.bundle_dpath) for gid in ub.ProgIter(self.sampler.dset.images(), desc='prepopulate imdata'): coco_img = self.sampler.dset.coco_image(gid) img = coco_img.img imdata = kwimage.imread(bundle_dpath / coco_img.img['file_name']) # imdata = coco_img.imdelay().finalize() img['imdata'] = imdata def _build_sample_grid(self): config = self.config grid_workers = int(os.environ.get('GEOWATCH_GRID_WORKERS', os.environ.get('WATCH_GRID_WORKERS', 0))) common_grid_kw = dict( time_dims=config['time_steps'], window_dims=config['chip_dims'], window_overlap=config['chip_overlap'], exclude_sensors=config['exclude_sensors'], include_sensors=config['include_sensors'], select_images=config['select_images'], select_videos=config['select_videos'], time_sampling=config['time_sampling'], time_span=config['time_span'], time_kernel=config['time_kernel'], window_space_scale=self.config['window_space_scale'], set_cover_algo=config['set_cover_algo'], workers=grid_workers, # could configure this use_cache=self.config['use_grid_cache'], respect_valid_regions=self.config['use_grid_valid_regions'], ) # print('common_grid_kw = {}'.format(ub.urepr(common_grid_kw, nl=1))) grid_kw = common_grid_kw.copy() # Remember this for backwards compat self._old_balance_as_negative_classes = ( self.ignore_classes | self.background_classes | self.negative_classes) annot_helper_kws = dict( negative_classes=self._old_balance_as_negative_classes, keepbound=False, use_annot_info=True, use_centered_positives=config['use_centered_positives'], use_grid_positives=config['use_grid_positives'], 
use_grid_negatives=config['use_grid_negatives'], ) mode = self.mode if mode == 'custom': sample_grid = None self.length = 1 elif mode == 'test': # FIXME: something is wrong with the cache when using an sqlview. # In test mode we have to sample everything for BAS # (TODO: for activity clf, we should only focus on candidate regions) if self.test_with_annot_info: grid_kw.update(annot_helper_kws) else: grid_kw.update(dict( keepbound=True, use_annot_info=False, )) grid_kw['dynamic_fixed_resolution'] = config['dynamic_fixed_resolution'] builder = spacetime_grid_builder.SpacetimeGridBuilder( dset=self.sampler.dset, **grid_kw ) sample_grid = builder.build() self.length = len(sample_grid['targets']) else: grid_kw.update(annot_helper_kws) builder = spacetime_grid_builder.SpacetimeGridBuilder( self.sampler.dset, **grid_kw ) sample_grid = builder.build() self._init_balance(sample_grid) if self.balanced_sampler is None: self.length = len(sample_grid['targets']) else: self.length = len(self.balanced_sampler) if config['max_epoch_length'] is not None: self.length = min(self.length, config['max_epoch_length']) self.sample_grid = sample_grid def _build_prenormalizers(self): if self.config['prenormalize_inputs'] is not None: prenormalizers = None if self.config['prenormalize_inputs'] is True: # default_prenorm = { # 'modality_stats': 'auto', # } # """ e.g. we expect modality stats to look like: modality_stats: - sensor: S2 channels: red|green|blue domain: KR_R001 month: 3 mean: [120, 231, 233] std: [30, 20, 24] min: [0, 0, 0] max: [10000, 10000, 10000] - ... """ # We generally want to compute these on the full dataset stats = self.cached_dataset_stats(num_workers=4) self.prenormalizers = stats['modality_input_stats'] # prenormalizers = [] # for key, value in stats['prenormalizers'].items(): # item = ub.udict(key._asdict()) | {k: v.ravel() for k, v in value.items()} # prenormalizers.append(item) # ... 
# raise NotImplementedError('need to compute prenormaliztions') elif isinstance(self.config['prenormalize_inputs'], dict): # TODO: Fixme! ... elif isinstance(self.config['prenormalize_inputs'], list): ... else: raise NotImplementedError if prenormalizers is None: stats = self.cached_dataset_stats(num_workers=4) prenormalizers = stats['modality_input_stats'] self.prenormalizers = prenormalizers def __len__(self): return self.length def __getitem__(self, index): """ Build an input batch. Standard pytorch Dataset API. Args: index (int | Dict): This can be an integer index between ``[0, len(self)]``. In test mode this will correspond to the index in the sample grid, but at train time it is randomized and you will usually get a different item each time. You can pass a "target" dictionary (e.g. an item from the sample grid). Note that the subset of a target needed to rebuild a specific batch is returned with each batch. Returns: Dict | None : In this system an item is always a dictionary it is up to the calling process to do any final collation. (avoiding collation makes writing this module a lot simpler). If the sample fails we return None, and the caller should also handle that. Example: >>> # Native sampling project data doctest >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import * # NOQA >>> import kwcoco_dataloader >>> import kwcoco >>> coco_dset = kwcoco_dataloader.coerce_kwcoco('kwcoco_dataloader-msi-geodata-dates') >>> self = KWCocoVideoDataset( >>> coco_dset, >>> time_dims=5, window_dims=(320, 320), >>> window_overlap=0, >>> input_space_scale='native', >>> window_space_scale='0.3GSD', >>> output_space_scale='0.6GSD', >>> dist_weights=1, >>> quality_threshold=0, >>> neg_to_pos_ratio=0, time_sampling='soft2', >>> ) >>> self.requested_tasks['change'] = False >>> # Find a sample with S2 and L8 images in it. >>> for target in self.sample_grid['targets']: ... sensors = coco_dset.images(target['gids']).lookup('sensor_coarse') ... 
shist = ub.dict_hist(sensors) ... if len(shist) > 1 and all(v > 1 for v in shist.values()): ... break >>> target['allow_augment'] = False >>> index = target >>> item = self[index] >>> print('item summary: ' + ub.urepr(self.summarize_item(item), nl=3)) >>> # xdoctest: +REQUIRES(--show) >>> canvas = self.draw_item(item, max_channels=10, overlay_on_image=0, rescale=0) >>> import kwplot >>> kwplot.autompl() >>> kwplot.imshow(canvas) >>> kwplot.show_if_requested() Example: >>> # xdoctest: +REQUIRES(env:DVC_DATA_DPATH) >>> # Native sampling project data doctest >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import * # NOQA >>> import kwcoco_dataloader >>> import kwcoco >>> dvc_dpath = kwcoco_dataloader.find_dvc_dpath(tags='phase2_data', hardware='auto') >>> coco_fpath = dvc_dpath / 'Drop6/data_vali_wsmall_split1.kwcoco.zip' >>> coco_dset = kwcoco.CocoDataset(coco_fpath) >>> self = KWCocoVideoDataset( >>> coco_dset, >>> time_dims=5, window_dims=(320, 320), >>> window_overlap=0, >>> channels="(S2,L8):blue|green|red|nir", >>> input_space_scale='native', >>> window_space_scale='10GSD', >>> output_space_scale='native', >>> #output_space_scale='10GSD', >>> dist_weights=1, >>> quality_threshold=0, >>> neg_to_pos_ratio=0, time_sampling='soft2', >>> ) >>> self.requested_tasks['change'] = False >>> # Find a sample with S2 and L8 images in it. >>> for target in self.sample_grid['targets']: ... sensors = coco_dset.images(target['gids']).lookup('sensor_coarse') ... shist = ub.dict_hist(sensors) ... if len(shist) > 1 and all(v > 1 for v in shist.values()): ... 
break >>> target['allow_augment'] = False >>> index = target >>> item = self[index] >>> print('item summary: ' + ub.urepr(self.summarize_item(item), nl=3)) >>> # xdoctest: +REQUIRES(--show) >>> canvas = self.draw_item(item, max_channels=10, overlay_on_image=0, rescale=0) >>> import kwplot >>> kwplot.autompl() >>> kwplot.imshow(canvas) >>> kwplot.show_if_requested() Example: >>> # xdoctest: +REQUIRES(env:DVC_DATA_DPATH) >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import * # NOQA >>> import kwcoco_dataloader >>> import kwcoco >>> dvc_dpath = kwcoco_dataloader.find_dvc_dpath(tags='phase2_data', hardware='auto') >>> coco_fpath = dvc_dpath / 'Drop6/data_vali_wsmall_split1.kwcoco.zip' >>> coco_dset = kwcoco.CocoDataset(coco_fpath) >>> self = KWCocoVideoDataset( >>> coco_dset, >>> time_dims=5, window_dims=(320, 320), >>> window_overlap=0, >>> channels="(S2,L8):blue|green|red|nir", >>> input_space_scale='10GSD', >>> window_space_scale='10GSD', >>> output_space_scale='10GSD', >>> dist_weights=1, >>> quality_threshold=0, >>> neg_to_pos_ratio=0, time_sampling='soft2', >>> ) >>> self.requested_tasks['change'] = False >>> index = self.sample_grid['targets'][self.sample_grid['positives_indexes'][0]] >>> index['allow_augment'] = False >>> item = self[index] >>> target = item['target'] >>> print('item summary: ' + ub.urepr(self.summarize_item(item), nl=3)) >>> # xdoctest: +REQUIRES(--show) >>> canvas = self.draw_item(item, max_channels=10, overlay_on_image=0, rescale=0) >>> import kwplot >>> kwplot.autompl() >>> kwplot.imshow(canvas) >>> kwplot.show_if_requested() """ try: return self.getitem(index) except FailedSample as ex: if self.config['failed_sample_policy'] == 'raise': raise elif self.config['failed_sample_policy'] == 'warn': warnings.warn('FailedSample: ex = {}'.format(ub.urepr(ex, nl=1))) elif self.config['failed_sample_policy'] == 'ignore': return None else: raise AssertionError(self.config['failed_sample_policy'])
def more_demos():
    """
    Extra doctest demonstrations for :class:`KWCocoVideoDataset`.

    Most of these examples need real project data and are gated on the
    ``DVC_DPATH`` environment variable; the final example runs on the
    ``vidshapes1`` kwcoco demo dataset with a dynamic channel.

    CommandLine:
        USE_RTREE=1 DVC_DPATH=1 XDEV_PROFILE=1 xdoctest -m kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset more_demos:0
        USE_RTREE=0 DVC_DPATH=1 XDEV_PROFILE=1 xdoctest -m kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset more_demos:0

    Example:
        >>> # xdoctest: +REQUIRES(env:DVC_DPATH)
        >>> # Demo with real data
        >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
        >>> import kwcoco_dataloader
        >>> import kwcoco
        >>> dvc_dpath = kwcoco_dataloader.find_dvc_dpath(tags='phase2_data', hardware='auto')
        >>> coco_fpath = dvc_dpath / 'Drop6/data_vali_split1.kwcoco.zip'
        >>> coco_dset = kwcoco.CocoDataset(coco_fpath)
        >>> ##'red|green|blue',
        >>> self = KWCocoVideoDataset(
        >>>     coco_dset,
        >>>     time_dims=7, window_dims=(196, 196),
        >>>     window_overlap=0,
        >>>     channels="(S2,L8):blue|green|red|nir",
        >>>     input_space_scale='3.3GSD',
        >>>     window_space_scale='3.3GSD',
        >>>     output_space_scale='1GSD',
        >>>     prenormalize_inputs=True,
        >>>     dist_weights=0,
        >>>     quality_threshold=0,
        >>>     neg_to_pos_ratio=0, time_sampling='soft2',
        >>> )
        >>> self.requested_tasks['change'] = 1
        >>> self.requested_tasks['saliency'] = 1
        >>> self.requested_tasks['class'] = 0
        >>> self.requested_tasks['boxes'] = 1
        >>> index = self.sample_grid['targets'][self.sample_grid['positives_indexes'][3]]
        >>> index['allow_augment'] = False
        >>> item = self[index]
        >>> target = item['target']
        >>> #for idx in range(100):
        ...     # self[idx]
        >>> print('item summary: ' + ub.urepr(self.summarize_item(item), nl=3))
        >>> # xdoctest: +REQUIRES(--show)
        >>> canvas = self.draw_item(item, max_channels=10, overlay_on_image=0, rescale=1)
        >>> import kwplot
        >>> kwplot.autompl()
        >>> kwplot.imshow(canvas, fnum=1)
        >>> kwplot.show_if_requested()

    Example:
        >>> # xdoctest: +REQUIRES(env:DVC_DPATH)
        >>> # This shows how you can use the dataloader to sample an arbitrary
        >>> # spacetime volume.
        >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
        >>> import kwcoco_dataloader
        >>> import kwcoco
        >>> dvc_dpath = kwcoco_dataloader.find_dvc_dpath(tags='phase2_data', hardware='auto')
        >>> #coco_fpath = dvc_dpath / 'Drop4-BAS/data_vali.kwcoco.json'
        >>> coco_fpath = dvc_dpath / 'Drop6/data_vali_split1.kwcoco.zip'
        >>> coco_dset = kwcoco.CocoDataset(coco_fpath)
        >>> ##'red|green|blue',
        >>> self = KWCocoVideoDataset(
        >>>     coco_dset,
        >>>     time_dims=7, window_dims=(196, 196),
        >>>     window_overlap=0,
        >>>     channels="(S2,L8):blue|green|red|nir",
        >>>     input_space_scale='3.3GSD',
        >>>     window_space_scale='3.3GSD',
        >>>     output_space_scale='1GSD',
        >>>     dist_weights=0,
        >>>     quality_threshold=0,
        >>>     neg_to_pos_ratio=0, time_sampling='soft2',
        >>> )
        >>> self.requested_tasks['change'] = 1
        >>> self.requested_tasks['saliency'] = 1
        >>> self.requested_tasks['class'] = 0
        >>> self.requested_tasks['boxes'] = 1
        >>> target = {
        >>>     'video_id': 3,
        >>>     'gids': [529, 555, 607, 697, 719, 730, 768],
        >>>     'main_idx': 3,
        >>>     'space_slice': (slice(0, 65, None), slice(130, 195, None)),
        >>> }
        >>> item = self[target]

    Example:
        >>> # xdoctest: +REQUIRES(env:DVC_DPATH)
        >>> # Tests the hard negative sampling
        >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
        >>> import kwcoco_dataloader
        >>> import kwcoco
        >>> dvc_dpath = kwcoco_dataloader.find_dvc_dpath(tags='phase2_data', hardware='auto')
        >>> coco_fpath = dvc_dpath / 'Drop6-MeanYear10GSD/data.kwcoco.zip'
        >>> coco_dset = kwcoco.CocoDataset(coco_fpath)
        >>> ##'red|green|blue',
        >>> self = KWCocoVideoDataset(
        >>>     coco_dset,
        >>>     time_dims=5, window_dims=(196, 196),
        >>>     window_overlap=0,
        >>>     channels="(S2,L8):blue|green|red",
        >>>     fixed_resolution='10GSD',
        >>>     robust_normalize=True,
        >>>     use_grid_negatives='cleared',
        >>>     use_grid_positives=False,
        >>>     use_centered_positives=True,
        >>>     time_kernel='(-2y,-1y,0,1y,2y)',
        >>> )
        >>> self.requested_tasks['change'] = 1
        >>> self.requested_tasks['saliency'] = 1
        >>> self.requested_tasks['class'] = 0
        >>> self.requested_tasks['boxes'] = 1
        >>> # Check that all of the negative regions are from cleared videos
        >>> videos = self.sampler.dset.videos()
        >>> vidid_to_cleared = ub.udict(ub.dzip(videos.lookup('id'), videos.lookup('cleared', False)))
        >>> assert self.config['use_grid_negatives'] == 'cleared'
        >>> positive_idxs = self.sample_grid['positives_indexes']
        >>> negative_idxs = self.sample_grid['negatives_indexes']
        >>> targets = self.sample_grid['targets']
        >>> negative_video_ids = {targets[x]['video_id'] for x in negative_idxs}
        >>> positive_video_ids = {targets[x]['video_id'] for x in positive_idxs}
        >>> assert all(vidid_to_cleared.subdict(negative_video_ids).values())
        >>> index = 0
        >>> item = self[index]
        >>> target = item['target']
        >>> print('item summary: ' + ub.urepr(self.summarize_item(item), nl=3))
        >>> # xdoctest: +REQUIRES(--show)
        >>> canvas = self.draw_item(item, max_channels=10, overlay_on_image=0, rescale=1)
        >>> import kwplot
        >>> kwplot.autompl()
        >>> kwplot.imshow(canvas, fnum=1)
        >>> kwplot.show_if_requested()

    Ignore:
        >>> self.disable_augmenter = True
        >>> self.config['mask_low_quality'] = True
        >>> self.config['force_bad_frames'] = True
        >>> self.config['resample_invalid_frames'] = 0
        >>> index = self.sample_grid['targets'][self.sample_grid['positives_indexes'][int((2.5 * 17594) // 3)]]
        >>> item1 = self[index]
        >>> self.config['robust_normalize'] = FusedChannelSpec.coerce('red|green|blue|nir')
        >>> item2 = self[index]
        >>> canvas1 = self.draw_item(item1, max_channels=10, overlay_on_image=0, rescale=0, draw_weights=0, draw_truth=0)
        >>> canvas2 = self.draw_item(item2, max_channels=10, overlay_on_image=0, rescale=0, draw_weights=0, draw_truth=0)
        >>> kwplot.imshow(canvas1, fnum=3, pnum=(2, 1, 1), title='no norm (per-frame normalized for viz purposes only)')
        >>> kwplot.imshow(canvas2, fnum=3, pnum=(2, 1, 2), title='per-item normalization (across time)')

    Example:
        >>> # Test Sampling with Dynamic Channels
        >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import *  # NOQA
        >>> import ndsampler
        >>> import kwcoco
        >>> coco_dset = kwcoco.CocoDataset.demo('vidshapes1', num_frames=10)
        >>> sampler = ndsampler.CocoSampler(coco_dset)
        >>> self = KWCocoVideoDataset(
        >>>     sampler,
        >>>     time_dims=4,
        >>>     window_dims=(300, 300),
        >>>     channels='g|neg_r',
        >>>     dynamic_channels=ub.codeblock(
        >>>         '''
        >>>         - name: neg_r
        >>>           expr: -r
        >>>         ''')
        >>> )
        >>> self.disable_augmenter = True
        >>> index = self.sample_grid['targets'][self.sample_grid['positives_indexes'][0]]
        >>> item = self[index]
        >>> assert '*:g|neg_r' in item.sensorchan_histogram
        >>> # Summarize batch item in text
        >>> summary = self.summarize_item(item)
        >>> print('item summary: ' + ub.urepr(summary, nl=2))
        >>> # Draw batch item
        >>> canvas = self.draw_item(item)
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> kwplot.autompl()
        >>> kwplot.imshow(canvas)
        >>> kwplot.show_if_requested()
    """


def worker_init_fn(worker_id):
    """
    Per-worker setup hook intended for a torch ``DataLoader``.

    Fetches the dataset owned by the current worker (unwrapping a
    ``torch.utils.data.Subset``), reseeds its random generators when the
    ``reseed_fit_random_generators`` config option is set, and, when the
    sampler's kwcoco dataset exposes ``connect``, reconnects it read-only
    (the backend reconnect is needed for SQL-backed datasets).

    Args:
        worker_id (int): unused; part of the ``worker_init_fn`` signature.
    """
    dataset = torch.utils.data.get_worker_info().dataset
    if isinstance(dataset, torch.utils.data.Subset):
        dataset = dataset.dataset

    if dataset.config['reseed_fit_random_generators']:
        dataset.reseed()

    if hasattr(dataset, 'sampler') and hasattr(dataset.sampler.dset, 'connect'):
        # Reconnect to the backend if we are using SQL
        dataset.sampler.dset.connect(readonly=True)


@cache
def _space_weights(space_shape):
    """
    Build a normalized 2D gaussian weight patch for a window of ``space_shape``.

    Each axis uses ``sigma = 4.8 * ((n - 1) * 0.5 - 1) + 0.8`` where ``n`` is
    that axis' size. Results are memoized per ``space_shape`` by
    ``functools.cache``, so repeated calls return the very same array object;
    callers should treat it as read-only.
    """
    def _axis_sigma(size):
        return 4.8 * ((size - 1) * 0.5 - 1) + 0.8

    sigma = (_axis_sigma(space_shape[1]), _axis_sigma(space_shape[0]))
    patch = kwimage.gaussian_patch(space_shape, sigma=sigma)
    return kwarray.normalize(patch, out=patch)


# Backwards compat
sample_video_spacetime_targets = spacetime_grid_builder.sample_video_spacetime_targets


class FailedSample(Exception):
    """
    Exception used to signal that a sample should be skipped.
    """


class Modality(NamedTuple):
    """
    A modality: a sensor, a channel code (as a FusedChannelSpec string), and
    a domain.
    """
    sensor: str
    channels: str
    domain: str


class Domain(NamedTuple):
    """
    DO NOT USE. BUT DO NOT REMOVE. NEEDED FOR BACKWARDS COMPAT

    Superseded by :class:`Modality`.
    """
    sensor: str
    channels: str
    video_name: str