Source code for geowatch.tasks.fusion.datamodules.temporal_sampling.sampler

"""
This module defines our generalized time sampling strategy.

This is used to define our dilated time sampling.

The following doctests illustrate the method on project data.


CommandLine:
    SMART_DATA_DVC_DPATH=1 XDEV_PROFILE=1 xdoctest -m geowatch.tasks.fusion.datamodules.temporal_sampling __doc__:3
    SMART_DATA_DVC_DPATH=1 xdoctest -m geowatch.tasks.fusion.datamodules.temporal_sampling __doc__:3

    xdoctest -m geowatch.tasks.fusion.datamodules.temporal_sampling __doc__:0 --show
    xdoctest -m geowatch.tasks.fusion.datamodules.temporal_sampling __doc__:1 --show

Example:
    >>> # Basic overview demo of the algorithm
    >>> from geowatch.tasks.fusion.datamodules.temporal_sampling.sampler import *  # NOQA
    >>> import geowatch
    >>> dset = geowatch.coerce_kwcoco('geowatch-msi', geodata=True, dates=True, num_frames=16, image_size=(8, 8))
    >>> vidid = dset.dataset['videos'][0]['id']
    >>> self = TimeWindowSampler.from_coco_video(
    >>>     dset, vidid,
    >>>     time_window=5,
    >>>     affinity_type='soft2', time_span='8m', update_rule='distribute',
    >>> )
    >>> # xdoctest: +REQUIRES(--show)
    >>> import kwplot
    >>> plt = kwplot.autoplt()
    >>> self.show_summary(samples_per_frame=3, fnum=3)
    >>> self.show_procedure(fnum=1)
    >>> plt.subplots_adjust(top=0.9)
    >>> kwplot.show_if_requested()

Example:
    >>> # Demo multiple different settings
    >>> from geowatch.tasks.fusion.datamodules.temporal_sampling.sampler import *  # NOQA
    >>> import geowatch
    >>> dset = geowatch.coerce_kwcoco('geowatch-msi', geodata=True, dates=True, num_frames=16, image_size=(8, 8), num_videos=1)
    >>> vidid = dset.dataset['videos'][0]['id']
    >>> self = TimeWindowSampler.from_coco_video(
    >>>     dset, vidid,
    >>>     time_window=7,
    >>>     affinity_type='uniform', time_span='8m', update_rule='',
    >>> )
    >>> # xdoctest: +REQUIRES(--show)
    >>> import kwplot
    >>> plt = kwplot.autoplt()
    >>> self.update_affinity(affinity_type='contiguous')
    >>> self.show_summary(samples_per_frame=3, fnum=1)
    >>> self.show_procedure(fnum=4)
    >>> plt.subplots_adjust(top=0.9)
    >>> self.update_affinity(affinity_type='soft2')
    >>> self.show_summary(samples_per_frame=3, fnum=2)
    >>> self.show_procedure(fnum=5)
    >>> plt.subplots_adjust(top=0.9)
    >>> self.update_affinity(affinity_type='hardish3')
    >>> self.show_summary(samples_per_frame=3, fnum=3)
    >>> self.show_procedure(fnum=6)
    >>> self.update_affinity(affinity_type='uniform')
    >>> self.show_summary(samples_per_frame=3, fnum=3)
    >>> self.show_procedure(fnum=6)
    >>> plt.subplots_adjust(top=0.9)
    >>> kwplot.show_if_requested()

Example:
    >>> # Demo corner case where there are too few observations
    >>> from geowatch.tasks.fusion.datamodules.temporal_sampling.sampler import *  # NOQA
    >>> import geowatch
    >>> dset = geowatch.coerce_kwcoco('geowatch-msi', geodata=True, dates=True, num_frames=1, num_videos=1, image_size=(8, 8))
    >>> vidid = dset.dataset['videos'][0]['id']
    >>> self = TimeWindowSampler.from_coco_video(
    >>>     dset, vidid,
    >>>     time_window=2,
    >>>     affinity_type='hardish3', time_span='1y', update_rule='',
    >>> )
    >>> idxs = self.sample()
    >>> assert idxs == [0, 0]
    >>> # xdoctest: +REQUIRES(--show)
    >>> import kwplot
    >>> plt = kwplot.autoplt()
    >>> self.show_summary(samples_per_frame=3, fnum=1)
    >>> self.show_procedure(fnum=4)
    >>> plt.subplots_adjust(top=0.9)
    >>> kwplot.show_if_requested()

Example:
    >>> # xdoctest: +REQUIRES(env:SMART_DATA_DVC_DPATH)
    >>> from geowatch.tasks.fusion.datamodules.temporal_sampling.sampler import *  # NOQA
    >>> import geowatch
    >>> data_dvc_dpath = geowatch.find_dvc_dpath(tags='phase2_data', hardware='auto')
    >>> coco_fpath = data_dvc_dpath / 'Drop6/imganns-KR_R001.kwcoco.zip'
    >>> dset = geowatch.coerce_kwcoco(coco_fpath)
    >>> vidid = dset.dataset['videos'][0]['id']
    >>> self = TimeWindowSampler.from_coco_video(
    >>>     dset, vidid,
    >>>     time_kernel='-1y,-8m,-2w,0,2w,8m,1y',
    >>>     affinity_type='soft4', update_rule='', deterministic=True
    >>>     #time_window=5,
    >>>     #affinity_type='hardish3', time_span='3m', update_rule='pairwise+distribute', deterministic=True
    >>> )
    >>> idxs = self.sample()
    >>> idxs = self.sample()
    >>> # xdoctest: +REQUIRES(--show)
    >>> import kwplot
    >>> plt = kwplot.autoplt()
    >>> self.show_summary(samples_per_frame=1, fnum=1)
    >>> chosen, info = self.show_procedure(fnum=4, idx=10)
    >>> plt.subplots_adjust(top=0.9)
    >>> kwplot.show_if_requested()

Example:
    >>> # xdoctest: +REQUIRES(env:SMART_DATA_DVC_DPATH)
    >>> from geowatch.tasks.fusion.datamodules.temporal_sampling import *  # NOQA
    >>> import geowatch
    >>> data_dvc_dpath = geowatch.find_dvc_dpath(tags='phase2_data', hardware='auto')
    >>> coco_fpath = data_dvc_dpath / 'Drop6/imganns-KR_R001.kwcoco.zip'
    >>> dset = geowatch.coerce_kwcoco(coco_fpath)
    >>> vidid = dset.dataset['videos'][0]['id']
    >>> self = MultiTimeWindowSampler.from_coco_video(
    >>>     dset, vidid,
    >>>     time_window=11,
    >>>     affinity_type='uniform-soft2-hardish3', update_rule=[''], gamma=2,
    >>>     time_span='6m-1y')
    >>> self.sample()
    >>> # xdoctest: +REQUIRES(--show)
    >>> import kwplot
    >>> kwplot.autosns()
    >>> self.show_summary(3)
    >>> kwplot.show_if_requested()

Example:
    >>> # xdoctest: +SKIP
    >>> # TODO: fix the time kernel
    >>> # Test under / over sample with time kernels
    >>> from geowatch.tasks.fusion.datamodules.temporal_sampling.sampler import *  # NOQA
    >>> import geowatch
    >>> dset = geowatch.coerce_kwcoco('geowatch-msi', geodata=True, dates=True, num_frames=16, image_size=(8, 8))
    >>> vidid = dset.dataset['videos'][0]['id']
    >>> self = TimeWindowSampler.from_coco_video(
    >>>     dset, vidid,
    >>>     time_window=4,
    >>>     affinity_type='uniform', update_rule='', time_kernel='-1y,0,1y',
    >>> )
    >>> self.sample()
"""
import kwarray
import numpy as np
import ubelt as ub
import itertools as it
from dateutil import parser
from .utils import coerce_time_kernel
from .utils import coerce_multi_time_kernel
from .plots import plot_dense_sample_indices
from .plots import plot_temporal_sample_indices
from .plots import show_affinity_sample_process
from .affinity import soft_frame_affinity
from .affinity import hard_frame_affinity
from .affinity import affinity_sample
from .exceptions import TimeSampleError


try:
    from line_profiler import profile
except ImportError:
    profile = ub.identity


class CommonSamplerMixin:
    """
    Alternate constructors shared by the time samplers in this module.

    The concrete class is expected to accept ``unixtimes``, ``sensors`` and
    the other keyword arguments forwarded here in its ``__init__``.
    """

    @classmethod
    def from_coco_video(cls, dset, vidid, gids=None, **kwargs):
        """
        Build a sampler from the images of a video in a kwcoco dataset.

        Args:
            dset: the coco dataset. Must provide ``images(...)`` and
                ``index.videos``.
            vidid: id of the video to sample from.
            gids (List[int] | None): image ids to use. If None, all images
                of the video ``vidid`` are used.
            **kwargs: forwarded to the class constructor. The keys
                ``unixtimes``, ``sensors`` and ``name`` are overwritten.

        Returns:
            an instance of ``cls``
        """
        if gids is None:
            gids = dset.images(video_id=vidid).lookup('id')
        images = dset.images(gids)

        # The name is only informative; fall back to a placeholder when the
        # video cannot be looked up.
        try:
            video_id = ub.peek(images.lookup('video_id'))
            name = dset.index.videos[video_id].get('name', '<no-name?>')
        except KeyError:
            name = '<no-name?>'

        # Images without a capture date are represented by a NaN timestamp.
        datetimes = [
            None if date is None else parser.parse(date)
            for date in images.lookup('date_captured', None)
        ]
        unixtimes = np.array([
            np.nan if dt is None else dt.timestamp()
            for dt in datetimes
        ])
        sensors = images.lookup('sensor_coarse', None)

        kwargs['unixtimes'] = unixtimes
        kwargs['sensors'] = sensors
        kwargs['name'] = name
        self = cls(**kwargs)
        return self

    @classmethod
    def from_datetimes(cls, datetimes, time_span='full', affinity_type='soft2', **kwargs):
        """
        Build a sampler directly from a collection of datetimes.

        Args:
            datetimes (Iterable[datetime.datetime]): one datetime per frame.
                May be a one-shot iterable.
            time_span: forwarded to the constructor. The special string
                ``'full'`` is replaced by ``max(datetimes) - min(datetimes)``.
            affinity_type (str): forwarded to the constructor.
            **kwargs: forwarded to the class constructor. The keys
                ``unixtimes``, ``sensors``, ``time_span`` and
                ``affinity_type`` are overwritten (``sensors`` is always
                set to None).

        Returns:
            an instance of ``cls``
        """
        # Materialize once: the datetimes are iterated for the timestamps and
        # again for the min / max, which would fail on a consumed iterator.
        datetimes = list(datetimes)
        unixtimes = np.array([dt.timestamp() for dt in datetimes])
        if isinstance(time_span, str) and time_span == 'full':
            time_span = max(datetimes) - min(datetimes)
        kwargs['unixtimes'] = unixtimes
        kwargs['sensors'] = None
        kwargs['time_span'] = time_span
        kwargs['affinity_type'] = affinity_type
        self = cls(**kwargs)
        return self
class MultiTimeWindowSampler(CommonSamplerMixin):
    """
    A wrapper that contains multiple time window samplers with different
    affinity matrices to increase the diversity of temporal sampling.

    One :class:`TimeWindowSampler` is built for every combination of the
    configured time spans, time kernels, affinity types and update rules.
    Each call to :func:`MultiTimeWindowSampler.sample` randomly picks one
    of these sub-samplers and delegates to it.

    Example:
        >>> from geowatch.tasks.fusion.datamodules.temporal_sampling.sampler import *  # NOQA
        >>> import datetime as datetime_mod
        >>> from datetime import datetime as datetime_cls
        >>> low = datetime_cls.now().timestamp()
        >>> high = low + datetime_mod.timedelta(days=365 * 5).total_seconds()
        >>> rng = kwarray.ensure_rng(0)
        >>> unixtimes = np.array(sorted(rng.randint(low, high, 32)), dtype=float)
        >>> sensors = ['a' for _ in range(len(unixtimes))]
        >>> time_window = 5
        >>> self = MultiTimeWindowSampler(
        >>>     unixtimes=unixtimes, sensors=sensors, time_window=time_window, update_rule='pairwise+distribute',
        >>>     #time_span=['2y', '1y', '5m'])
        >>>     time_span='7d-1m',
        >>>     affinity_type='soft2')
        >>> self.sample()
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> kwplot.autosns()
        >>> self.show_summary(10)

    Example:
        >>> from geowatch.tasks.fusion.datamodules.temporal_sampling.sampler import *  # NOQA
        >>> import datetime as datetime_mod
        >>> from datetime import datetime as datetime_cls
        >>> low = datetime_cls.now().timestamp()
        >>> high = low + datetime_mod.timedelta(days=365 * 5).total_seconds()
        >>> rng = kwarray.ensure_rng(0)
        >>> unixtimes = np.array(sorted(rng.randint(low, high, 32)), dtype=float)
        >>> sensors = ['a' for _ in range(len(unixtimes))]
        >>> time_window = 5
        >>> self = MultiTimeWindowSampler(
        >>>     unixtimes=unixtimes, sensors=sensors, time_window=time_window, update_rule='distribute',
        >>>     time_kernel=['-1y,-3m,0,3m,+1y', '-1m,-1d,0,1d,1m'],
        >>>     affinity_type='soft2')
        >>> self.sample()
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> kwplot.autosns()
        >>> self.show_summary(10, show_indexes=1, fnum=1)
        >>> list(self.sub_samplers.values())[0].show_summary(10, show_indexes=1, fnum=2)
        >>> list(self.sub_samplers.values())[1].show_summary(10, show_indexes=1, fnum=3)
    """

    def __init__(self, unixtimes, sensors, time_window=None,
                 affinity_type='hard', update_rule='distribute',
                 deterministic=False, gamma=1, time_span=None,
                 time_kernel=None, name='?', allow_fewer=True):
        """
        Args:
            unixtimes (List[int]): list of unix timestamps for each frame

            sensors (List[str]): list of attributes for each frame

            time_window (int): number of frames to sample

            affinity_type (str | List[str]): one or more affinity types.
                A string is split on ``'-'``, e.g. ``'uniform-soft2'``.

            update_rule (str | List[str]): one or more update rules.
                A string is split on ``'-'``.

            deterministic (bool): forwarded to each sub-sampler.

            gamma (float): forwarded to each sub-sampler.

            time_span (str | List[str] | None): one or more time spans.
                A string is split on ``'-'``, e.g. ``'6m-1y'``. A list looks
                like ``['2y', '1y', '5m']``.

            time_kernel (List[str]): a list of time kernels, processed by
                :func:`coerce_multi_time_kernel`.

            name (str): a name for this object, used in debug messages and
                plot titles.

            allow_fewer (bool): forwarded to each sub-sampler.
        """
        if time_span is None:
            time_span = [None]
        if isinstance(time_span, str):
            time_span = time_span.split('-')
        if isinstance(affinity_type, str):
            affinity_type = affinity_type.split('-')
        if isinstance(update_rule, str):
            update_rule = update_rule.split('-')
        if len(update_rule) == 0:
            # Always keep at least one (empty) rule so the product of
            # configurations in _build is not empty.
            update_rule = ['']

        self.time_kernel = coerce_multi_time_kernel(time_kernel)
        self.sensors = sensors
        self.unixtimes = unixtimes
        self.time_window = time_window
        self.update_rule = update_rule
        self.affinity_type = affinity_type
        self.deterministic = deterministic
        self.gamma = gamma
        self.name = name
        self.num_frames = len(unixtimes)
        self.time_span = time_span
        self.sub_samplers = {}
        self.allow_fewer = allow_fewer
        self._build()

    def _build(self):
        """
        Construct one sub-sampler per configuration combination.

        Raises:
            ValueError: if the configuration does not produce any
                sub-sampler (e.g. an empty list of affinity types).
        """
        config_grid = it.product(
            self.time_span, self.time_kernel,
            self.affinity_type, self.update_rule)

        sub_sampler = None
        for time_span, time_kernel, affinity_type, update_rule in config_grid:
            sub_sampler = TimeWindowSampler(
                unixtimes=self.unixtimes,
                sensors=self.sensors,
                time_window=self.time_window,
                update_rule=update_rule,
                affinity_type=affinity_type,
                deterministic=self.deterministic,
                gamma=self.gamma,
                name=self.name,
                time_span=time_span,
                time_kernel=time_kernel,
                allow_fewer=self.allow_fewer,
            )
            key = ':'.join([str(time_span), str(time_kernel), affinity_type, update_rule])
            self.sub_samplers[key] = sub_sampler

        if sub_sampler is None:
            # Previously this surfaced as an opaque unbound-variable error.
            raise ValueError(
                'the time_span / time_kernel / affinity_type / update_rule '
                'configuration did not produce any sub-sampler')

        # All sub-samplers are built over the same frames, so the indexes of
        # the last one are representative.
        self.indexes = sub_sampler.indexes
        # self.indexes = np.arange(self.affinity.shape[0])

    @profile
    def sample(self, main_frame_idx=None, include=None, exclude=None,
               return_info=False, error_level=0, rng=None):
        """
        Chooses a sub-sampler and samples from it.

        Args:
            main_frame_idx (int): "main" sample index.

            include (List[int]): other indexes forced to be included

            exclude (List[int]): other indexes forced to be excluded

            return_info (bool): for debugging / introspection

            error_level (int): See :func:`affinity_sample`.

            rng: random seed or generator, coerced with
                :func:`kwarray.ensure_rng`.

        Returns:
            ndarray | Tuple[ndarray, Dict]

        Raises:
            TimeSampleError: re-raised from the sub-sampler with the
                configuration of this object appended to the message.
        """
        rng = kwarray.ensure_rng(rng)
        # FIXME: the selection of the subsampler does not respect
        # self.deterministic
        chosen_key = rng.choice(list(self.sub_samplers.keys()))
        chosen_sampler = self.sub_samplers[chosen_key]
        try:
            return chosen_sampler.sample(
                main_frame_idx, include=include, exclude=exclude,
                return_info=return_info, rng=rng, error_level=error_level)
        except TimeSampleError as ex:
            debug_parts = [
                f'{self.name=}',
                f'{self.affinity_type=}',
                f'{self.deterministic=}',
                f'{self.gamma=}',
                f'{self.time_kernel=}',
                f'{self.time_span=}',
                f'{self.num_frames=}',
                f'{main_frame_idx=}',
                f'{include=}',
                f'{exclude=}',
            ]
            # Exception args are not guaranteed to be strings; stringify
            # them so building the message cannot mask the original error.
            message_parts = [str(arg) for arg in ex.args] + debug_parts
            ex.args = ('\n'.join(message_parts),)
            raise

    @property
    def affinity(self):
        """
        Approximate combined affinity, for this multi-sampler

        This is the element-wise mean of the sub-sampler affinities.
        """
        stacked = np.stack([
            sampler.affinity for sampler in self.sub_samplers.values()
        ])
        return np.mean(stacked, axis=0)

    @affinity.setter
    def affinity(self, value):
        # Define setter for backwards compatibility
        if len(self.sub_samplers) > 1:
            raise Exception(
                'no way to hack affinity directly when there is '
                'more than one subsampler')
        sub_sampler = ub.peek(self.sub_samplers.values())
        sub_sampler.affinity = value

    def show_summary(self, samples_per_frame=1, show_indexes=0, fnum=1):
        """
        Similar to :func:`TimeWindowSampler.show_summary`

        Args:
            samples_per_frame (int): number of samples drawn with each
                frame as the main frame.

            show_indexes (int): if truthy, also draw the dense
                "index-space" view of the samples.

            fnum (int): figure number to draw on.
        """
        import kwplot
        kwplot.autompl()

        sample_idxs = np.array([
            self.sample(idx)
            for idx in range(len(self.unixtimes))
            for _ in range(samples_per_frame)
        ])

        title_info = ub.codeblock(
            f'''
            name={self.name}
            affinity_type={self.affinity_type}
            deterministic={self.deterministic}
            update_rule={self.update_rule}
            gamma={self.gamma}
            ''')

        pnum_ = kwplot.PlotNums(nCols=2 + show_indexes)
        fig = kwplot.figure(fnum=fnum, doclf=True)

        fig = kwplot.figure(fnum=fnum, pnum=pnum_())
        ax = fig.gca()
        affinity = self.affinity
        kwplot.imshow(affinity, ax=ax)
        ax.set_title('combined frame affinity')

        if show_indexes:
            fig = kwplot.figure(fnum=fnum, pnum=pnum_())
            if samples_per_frame < 2:
                ax = plot_dense_sample_indices(sample_idxs, self.unixtimes, linewidths=0.1)
                ax.set_aspect('equal')
            else:
                ax = plot_dense_sample_indices(sample_idxs, self.unixtimes, linewidths=0.001)

        kwplot.figure(fnum=fnum, pnum=pnum_())
        plot_temporal_sample_indices(sample_idxs, self.unixtimes)
        fig.suptitle(title_info)
class TimeWindowSampler(CommonSamplerMixin):
    """
    Object oriented API to produce random temporal samples given a set of
    keyframes with metadata.

    This works by computing a pairwise "affinity" NxN matrix for each of the N
    keyframes. The details of the affinity matrix depend on parameters passed
    to this object. Intuitively, the value at ``Affinity[i, j]`` represents how
    much frame-i "wants" to be in the same sample as frame-j.

    Args:
        unixtimes (List[int]): list of unix timestamps for each frame

        sensors (List[str]): list of attributes for each frame

        time_window (int): number of frames to sample. If unspecified, the
            length of ``time_kernel`` is used.

        affinity_type (str): Method for computing the affinity matrix for the
            underlying sampling algorithm. Can be:
                "uniform" - Every pair of frames gets the same affinity.
                "soft" - The old generalized random affinity matrix.
                "soft2" - The new generalized random affinity matrix.
                "soft3" - The newer generalized random affinity matrix.
                "hard" - A simplified affinity algorithm.
                "hardish" - Like hard, but with a blur.
                "hardish2" - Like hard, with ``blur=3.0``.
                "hardish3" - Like hard, with ``blur=6.0``.
                "contiguous" - Neighboring frames get high affinity.

        update_rule (str):
            "+" separated string that can contain {"distribute", "pairwise"}.
            See :func:`affinity_sample` for details.

        gamma (float):
            Modulates sampling probability.
            See :func:`affinity_sample` for details.

        time_span (Coercible[datetime.timedelta]):
            The ideal distance in time that frames should be separated in.
            This is typically a string code. E.g. "1y" is one year.
            Mutually exclusive with ``time_kernel``.

        time_kernel: an explicit time kernel, processed by
            :func:`coerce_time_kernel`. Mutually exclusive with
            ``time_span``.

        affkw (dict | None): extra keyword arguments passed to the function
            that builds the affinity matrix.

        name (str):
            A name for this object. For developer convenience, has no
            influence on the algorithm.

        deterministic (bool):
            if True, on each step we choose the next timestamp with maximum
            probability. Otherwise, we randomly choose a timestep, but with
            probability according to the current distribution.  This is an
            attribute, which can be modified to change behavior (not thread
            safe).

        allow_fewer (bool): forwarded to :func:`affinity_sample`.

    Attributes:
        affinity (ndarray): the NxN affinity matrix.

        indexes (ndarray): ``np.arange(N)``, the index of each frame.

        main_indexes: deprecated alias of ``indexes``.

    Example:
        >>> # xdoctest: +REQUIRES(env:SMART_DATA_DVC_DPATH)
        >>> import os
        >>> from geowatch.tasks.fusion.datamodules.temporal_sampling.sampler import *  # NOQA
        >>> import kwcoco
        >>> import geowatch
        >>> data_dvc_dpath = geowatch.find_dvc_dpath(tags='phase2_data', hardware='auto')
        >>> coco_fpath = data_dvc_dpath / 'Drop6/data_vali_split1.kwcoco.zip'
        >>> dset = kwcoco.CocoDataset(coco_fpath)
        >>> vidid = dset.dataset['videos'][0]['id']
        >>> self = TimeWindowSampler.from_coco_video(
        >>>     dset, vidid,
        >>>     time_window=5,
        >>>     affinity_type='hardish3', time_span='1y',
        >>>     update_rule='distribute')
        >>> self.deterministic = False
        >>> self.show_summary(samples_per_frame=1, fnum=1)
        >>> self.deterministic = True
        >>> self.show_summary(samples_per_frame=3, fnum=2)
    """

    def __init__(self, unixtimes, sensors, time_window=None,
                 affinity_type='hard', update_rule='distribute',
                 deterministic=False, gamma=1, time_span=None,
                 time_kernel=None, affkw=None, name='?', allow_fewer=True):
        # The string 'None' is accepted as a spelling of None.
        if isinstance(time_span, str) and time_span == 'None':
            time_span = None

        if time_kernel is not None and time_span is not None:
            raise ValueError('time_span and time_kernel are mutex')

        self.time_kernel = None if time_kernel is None else coerce_time_kernel(time_kernel)

        if time_window is None:
            # Without an explicit window size, one frame is sampled per
            # kernel entry.
            assert self.time_kernel is not None, (
                'either time_window or time_kernel must be specified')
            time_window = len(self.time_kernel)

        self.sensors = sensors
        self.unixtimes = unixtimes
        self.time_window = time_window
        self.update_rule = update_rule
        self.affinity_type = affinity_type
        self.deterministic = deterministic
        self.gamma = gamma
        self.name = name
        self.num_frames = len(unixtimes)
        self.time_span = time_span
        self.affkw = affkw  # extra args to affinity matrix building
        self.allow_fewer = allow_fewer

        self.compute_affinity()

    def update_affinity(self, affinity_type=None, update_rule=None):
        """
        Construct the affinity matrix given the current ``affinity_type``.

        Sets ``self.affinity`` and ``self.indexes``.

        Args:
            affinity_type (str | None): if given, replaces
                ``self.affinity_type`` before the matrix is built.

            update_rule (str | None): if given, replaces
                ``self.update_rule``.

        Raises:
            KeyError: if the affinity type is unknown.

        Example:
            >>> # xdoctest: +REQUIRES(env:SMART_DATA_DVC_DPATH)
            >>> import os
            >>> from geowatch.tasks.fusion.datamodules.temporal_sampling.sampler import *  # NOQA
            >>> import kwcoco
            >>> import geowatch
            >>> data_dvc_dpath = geowatch.find_dvc_dpath(tags='phase2_data', hardware='auto')
            >>> coco_fpath = data_dvc_dpath / 'Drop6/data_vali_split1.kwcoco.zip'
            >>> dset = kwcoco.CocoDataset(coco_fpath)
            >>> vidid = dset.dataset['videos'][0]['id']
            >>> self = TimeWindowSampler.from_coco_video(
            >>>     dset, vidid,
            >>>     time_window=5,
            >>>     affinity_type='contiguous',
            >>>     update_rule='pairwise')
            >>> self.deterministic = True
            >>> self.show_procedure(fnum=1)
        """
        if self.affkw is None:
            self.affkw = {}

        if affinity_type is not None:
            self.affinity_type = affinity_type

        if update_rule is not None:
            self.update_rule = update_rule

        # update_rules = set(update_rule.split('+'))
        # config = ub.dict_subset({'pairwise': True, 'distribute': True}, update_rules)
        # do_pairwise = config.get('pairwise', False)
        # do_distribute = config.get('distribute', False)

        # Blur setting used by each of the "hard" affinity variants.
        hard_blur_lut = {
            'hard': False,
            'hardish': True,
            'hardish2': 3.0,
            'hardish3': 6.0,
        }

        kind = self.affinity_type
        if kind.startswith('uniform'):
            n = len(self.unixtimes)
            self.affinity = np.full((n, n), fill_value=1 / n, dtype=np.float32)
        elif kind.startswith('soft'):
            # The suffix after "soft" selects the version. It is an integer
            # when possible, otherwise the raw suffix string is passed on.
            if kind == 'soft':
                version = 1
            else:
                try:
                    version = int(kind[4:])
                except ValueError:
                    version = kind[4:]
            # Soft affinity
            self.affinity = soft_frame_affinity(
                unixtimes=self.unixtimes,
                sensors=self.sensors,
                time_span=self.time_span,
                time_kernel=self.time_kernel,
                version=version,
                **self.affkw)['final']
        elif kind in hard_blur_lut:
            # Hard / hardish affinity
            self.affinity = hard_frame_affinity(
                self.unixtimes, self.sensors,
                time_window=self.time_window,
                blur=hard_blur_lut[kind],
                time_span=self.time_span,
                time_kernel=self.time_kernel,
                **self.affkw)
        elif kind == 'contiguous':
            # Recovers the original method that we used to sample time.
            time_window = self.time_window
            unixtimes = self.unixtimes

            # Note: cant just use a sliding window because we currently need
            # and NxN affinity matrix, can remove duplicates at sample time.
            # This allows us to assume that indexes always correspond with
            # frames, and each frame has an "ideal" sample. This assumption
            # might not be necessary, but other code would need to change to
            # break it, so we keep it in for now.
            n_before = time_window // 2
            n_after = (time_window - n_before) - 1
            num_samples = len(unixtimes)
            all_indexes = np.arange(num_samples)

            sample_idxs = []
            for idx in all_indexes:
                # Window of ``time_window`` frames around ``idx``. The stop
                # is exclusive, hence the +1 to include the last frame after
                # ``idx``. Near the boundaries the window is shifted inward
                # so it keeps its size and still contains ``idx``; it is
                # clipped when there are fewer than ``time_window`` frames.
                stop_idx = min(idx + n_after + 1, num_samples)
                start_idx = max(stop_idx - time_window, 0)
                stop_idx = min(start_idx + time_window, num_samples)
                sample_idxs.append(all_indexes[start_idx:stop_idx])
            sample_idxs = np.array(sample_idxs)

            # Old, and somewhat better way, but does not give us NxN
            # time_slider = kwarray.SlidingWindow(
            #     (len(self.unixtimes),), (self.time_window,), stride=(1,),
            #     keepbound=True, allow_overshoot=True)
            # sample_idxs = np.array([all_indexes[sl] for sl in time_slider])

            # Construct the contiguous sliding window affinity.
            # (Note: an alternate approach would be to give the first
            # and last frames out-of-bounds padding, so they actually
            # dont give full affinity. That may be more natural)
            self.affinity = kwarray.one_hot_embedding(
                sample_idxs, len(self.unixtimes), dim=1).sum(axis=2)
        else:
            raise KeyError(self.affinity_type)

        self.indexes = np.arange(self.affinity.shape[0])

    compute_affinity = update_affinity

    # TODO: deprecate wrapper for moving APIs
    __todo__ = '''
    def deprecate_alias(func):
        def _wrapper(*args, **kwargs):
            ub.schedule_deprecation(...)
            return func(*args, **kwargs)
        return _wrapper

    compute_affinity = deprecate_alias(update_affinity)
    '''

    @property
    def main_indexes(self):
        """
        Deprecated alias of ``self.indexes``.
        """
        ub.schedule_deprecation(
            'geowatch', 'main_indexes', 'use indexes instead',
            deprecate='now')
        return self.indexes

    @profile
    def sample(self, main_frame_idx=None, include=None, exclude=None,
               return_info=False, error_level=0, rng=None):
        """
        Draw one temporal sample using :func:`affinity_sample`.

        Args:
            main_frame_idx (int): "main" sample index.

            include (List[int]): other indexes forced to be included

            exclude (List[int]): other indexes forced to be excluded

            return_info (bool): for debugging / introspection

            error_level (int): See :func:`affinity_sample`.

            rng: random seed or generator, coerced with
                :func:`kwarray.ensure_rng`.

        Returns:
            ndarray | Tuple[ndarray, Dict]

        Example:
            >>> # xdoctest: +REQUIRES(env:SMART_DATA_DVC_DPATH)
            >>> import os
            >>> import kwcoco
            >>> from geowatch.tasks.fusion.datamodules.temporal_sampling.sampler import *  # NOQA
            >>> from geowatch.utils.util_data import find_dvc_dpath
            >>> dvc_dpath = find_dvc_dpath()
            >>> coco_fpath = dvc_dpath / 'Drop2-Aligned-TA1-2022-02-15/data.kwcoco.json'
            >>> dset = kwcoco.CocoDataset(coco_fpath)
            >>> vidid = dset.dataset['videos'][0]['id']
            >>> self = TimeWindowSampler.from_coco_video(
            >>>     dset, vidid,
            >>>     time_span='1y',
            >>>     time_window=3,
            >>>     affinity_type='soft2',
            >>>     update_rule='distribute+pairwise')
            >>> self.deterministic = False
            >>> self.show_summary(samples_per_frame=1 if self.deterministic else 10, fnum=1)
            >>> self.show_procedure(fnum=2)

        Example:
            >>> import os
            >>> import kwcoco
            >>> from geowatch.tasks.fusion.datamodules.temporal_sampling.sampler import *  # NOQA
            >>> import geowatch
            >>> dset = geowatch.coerce_kwcoco('geowatch-msi', geodata=True, dates=True, num_frames=32, image_size=(32, 32), num_videos=1)
            >>> vidid = dset.dataset['videos'][0]['id']
            >>> self = TimeWindowSampler.from_coco_video(
            >>>     dset, vidid,
            >>>     time_span='1y',
            >>>     time_window=3,
            >>>     affinity_type='soft2',
            >>>     update_rule='distribute+pairwise')
            >>> self.deterministic = True
            >>> # xdoctest: +REQUIRES(--show)
            >>> self.show_summary(samples_per_frame=1 if self.deterministic else 10, fnum=1)
            >>> self.show_procedure(fnum=2)

        Ignore:
            import xdev
            globals().update(xdev.get_func_kwargs(TimeWindowSampler.sample))
        """
        # The main frame (if any) is always the first forced index.
        include_indices = [] if main_frame_idx is None else [main_frame_idx]
        if include is not None:
            include_indices.extend(include)

        rng = kwarray.ensure_rng(rng)

        # Ret could be an ndarray | Tuple[ndarray, Dict]
        ret = affinity_sample(
            affinity=self.affinity,
            size=self.time_window,
            include_indices=include_indices,
            exclude_indices=exclude,
            update_rule=self.update_rule,
            gamma=self.gamma,
            deterministic=self.deterministic,
            error_level=error_level,
            rng=rng,
            return_info=return_info,
            time_kernel=self.time_kernel,
            unixtimes=self.unixtimes,
            allow_fewer=self.allow_fewer,
        )
        return ret

    def show_summary(self, samples_per_frame=1, fnum=1, show_indexes=False,
                     with_temporal=True, compare_determ=True, title_suffix=''):
        """
        Visualize the affinity matrix and two views of a selected sample.

        Plots a figure with three subfigures.

        (1) The affinity matrix.

        (2) A visualization of a random sampled over "index-space".
        A matrix M, where each row is a sample index, each column is a
        timestep, ``M[i,j] = 1`` (the cell is colored white) to indicate that a
        sample-i includes timestep-j.

        (3) A visualization of the same random sample over "time-space".  A
        plot where x is the time-axis is drawn, and vertical lines indicate the
        selectable time indexes. For each sample, a horizontal line indicates
        the timespan of the sample and an "x" denotes exactly which timesteps
        are included in that sample.

        Args:
            samples_per_frame (int): number of samples drawn with each
                frame as the main frame.

            fnum (int): figure number to draw on.

            show_indexes (bool): if True, draw the "index-space" view (2).

            with_temporal (bool): if True, draw the "time-space" view (3).

            compare_determ (bool): if True, the time-space view is drawn
                twice: once sampled non-deterministically and once
                deterministically. ``self.deterministic`` is temporarily
                modified and restored afterwards.

            title_suffix (str): text appended to the figure title.

        Example:
            >>> # xdoctest: +REQUIRES(env:SMART_DATA_DVC_DPATH)
            >>> import os
            >>> from geowatch.tasks.fusion.datamodules.temporal_sampling import *  # NOQA
            >>> from geowatch.utils.util_data import find_dvc_dpath
            >>> import kwcoco
            >>> # xdoctest: +REQUIRES(--show)
            >>> dvc_dpath = find_dvc_dpath()
            >>> coco_fpath = dvc_dpath / 'Drop6/data_vali_split1.kwcoco.zip'
            >>> dset = kwcoco.CocoDataset(coco_fpath)
            >>> video_ids = list(ub.sorted_vals(dset.index.vidid_to_gids, key=len).keys())
            >>> vidid = video_ids[2]
            >>> # Demo behavior over a grid of parameters
            >>> grid = list(ub.named_product({
            >>>     'affinity_type': ['hard', 'soft2', 'hardish3', 'hardish2'],
            >>>     'update_rule': ['distribute', 'pairwise+distribute'][0:1],
            >>>     #'deterministic': [False, True],
            >>>     'deterministic': [False],
            >>>     'time_window': [5],
            >>> }))
            >>> import kwplot
            >>> kwplot.autompl()
            >>> for idx, kwargs in enumerate(grid):
            >>>     print('kwargs = {!r}'.format(kwargs))
            >>>     self = TimeWindowSampler.from_coco_video(dset, vidid, **kwargs)
            >>>     self.show_summary(samples_per_frame=30, fnum=idx, show_indexes=False)
        """
        import kwplot
        kwplot.autompl()

        num_frames = self.affinity.shape[0]
        _prev_deterministic = self.deterministic
        try:
            if compare_determ:
                self.deterministic = False

            sample_idxs = np.array([
                self.sample(idx)
                for idx in range(num_frames)
                for _ in range(samples_per_frame)
            ])

            title_info = ub.codeblock(
                f'''
                name={self.name}
                affinity_type={self.affinity_type}
                update_rule={self.update_rule}
                gamma={self.gamma}
                ''') + title_suffix
            # deterministic={self.deterministic}

            with_mat = True
            n_temp_plots = (compare_determ + 1) * with_temporal
            num_subplots = (with_mat + show_indexes + n_temp_plots)
            pnum_ = kwplot.PlotNums(nCols=num_subplots)
            fig = kwplot.figure(fnum=fnum, doclf=True)

            fig = kwplot.figure(fnum=fnum, pnum=pnum_())
            ax = fig.gca()
            kwplot.imshow(self.affinity, ax=ax, cmap='viridis')
            ax.set_title('frame affinity')

            if show_indexes:
                fig = kwplot.figure(fnum=fnum, pnum=pnum_())
                if samples_per_frame < 5:
                    ax = plot_dense_sample_indices(sample_idxs, self.unixtimes, linewidths=0.1)
                    ax.set_aspect('equal')
                else:
                    ax = plot_dense_sample_indices(sample_idxs, self.unixtimes, linewidths=0.001)

            if with_temporal:
                kwplot.figure(fnum=fnum, pnum=pnum_())
                plot_temporal_sample_indices(
                    sample_idxs, self.unixtimes, sensors=self.sensors,
                    title_suffix=': non-deterministic')

                if compare_determ:
                    self.deterministic = True
                    sample_idxs = [self.sample(idx) for idx in range(num_frames)]
                    kwplot.figure(fnum=fnum, pnum=pnum_())
                    plot_temporal_sample_indices(
                        sample_idxs, self.unixtimes, sensors=self.sensors,
                        title_suffix=': deterministic')

            fig.suptitle(title_info)
        finally:
            # Restore the flag even if sampling or plotting fails, so a
            # failed plot does not silently change the sampling behavior.
            self.deterministic = _prev_deterministic

    def show_affinity(self, fnum=3):
        """
        Simple drawing of the affinity matrix.

        Args:
            fnum (int): figure number to draw on.
        """
        import kwplot
        kwplot.autompl()
        fig = kwplot.figure(fnum=fnum)
        ax = fig.gca()
        kwplot.imshow(self.affinity, ax=ax)
        ax.set_title('frame affinity')

    def show_procedure(self, idx=None, exclude=None, fnum=2, rng=None):
        """
        Draw a figure that shows the process of performing on call to
        :func:`TimeWindowSampler.sample`. Each row illustrates an iteration of
        the algorithm. The left column draws the current indicies included in
        the sample and the right column draws how that sample (corresponding to
        the current row) influences the probability distribution for the next
        row.

        Args:
            idx (int | None): the main frame index passed to
                :func:`TimeWindowSampler.sample`.

            exclude (List[int]): indexes forced to be excluded

            fnum (int): figure number to draw on.

            rng: random seed or generator, coerced with
                :func:`kwarray.ensure_rng`.

        Returns:
            Tuple[ndarray, Dict]: the chosen indexes and the info dictionary
                returned by :func:`TimeWindowSampler.sample`.

        Example:
            >>> # xdoctest: +REQUIRES(env:SMART_DATA_DVC_DPATH)
            >>> import os
            >>> import kwcoco
            >>> from geowatch.tasks.fusion.datamodules.temporal_sampling import *  # NOQA
            >>> from geowatch.utils.util_data import find_dvc_dpath
            >>> dvc_dpath = find_dvc_dpath()
            >>> coco_fpath = dvc_dpath / 'Drop1-Aligned-L1-2022-01/data.kwcoco.json'
            >>> dset = kwcoco.CocoDataset(coco_fpath)
            >>> vidid = dset.dataset['videos'][0]['id']
            >>> self = TimeWindowSampler.from_coco_video(
            >>>     dset, vidid,
            >>>     time_window=5,
            >>>     affinity_type='soft2',
            >>>     update_rule='distribute+pairwise')
            >>> self.deterministic = False
            >>> self.show_procedure(idx=0, fnum=10)
            >>> self.show_affinity(fnum=100)
            >>> self.show_procedure(fnum=1)
            >>> self.deterministic = True
            >>> self.show_procedure(fnum=2)
            >>> self.show_procedure(fnum=3)
            >>> self.show_procedure(fnum=4)
            >>> self.deterministic = False
            >>> self.show_summary(samples_per_frame=3, fnum=10)

        Ignore:
            for idx in xdev.InteractiveIter(list(range(self.num_frames))):
                self.show_procedure(idx=idx, fnum=1)
                xdev.InteractiveIter.draw()

            self = TimeWindowSampler.from_coco_video(dset, vidid, time_window=5, affinity_type='soft2', update_rule='distribute+pairwise')
            self.deterministic = True
            self.show_summary(samples_per_frame=20, fnum=1)
            self.deterministic = False
            self.show_summary(samples_per_frame=20, fnum=2)

            self = TimeWindowSampler.from_coco_video(dset, vidid, time_window=5, affinity_type='hard', update_rule='distribute')
            self.deterministic = True
            self.show_summary(samples_per_frame=20, fnum=3)
            self.deterministic = False
            self.show_summary(samples_per_frame=20, fnum=4)

            self = TimeWindowSampler.from_coco_video(dset, vidid, time_window=5, affinity_type='hardish', update_rule='distribute')
            self.deterministic = True
            self.show_summary(samples_per_frame=20, fnum=5)
            self.deterministic = False
            self.show_summary(samples_per_frame=20, fnum=6)

        Ignore:
            import xdev
            globals().update(xdev.get_func_kwargs(TimeWindowSampler.show_procedure))
        """
        rng = kwarray.ensure_rng(rng)
        # if idx is None:
        #     idx = self.num_frames // 2
        from kwutil.util_time import coerce_timedelta

        def _concise_td(delta):
            # Abbreviate a timedelta to its largest unit: years (365 days),
            # months (30 days, only when more than 30 days remain), days,
            # or seconds.
            total_sec = delta.total_seconds()
            sign = 1 if total_sec >= 0 else -1
            delta = delta * sign
            extra_days = delta.days
            total_years, extra_days = divmod(extra_days, 365)
            if total_years > 0:
                return f'{sign * total_years}y'
            if extra_days > 30:
                total_months, extra_days = divmod(extra_days, 30)
                if total_months > 0:
                    return f'{sign * total_months}m'
            if extra_days > 0:
                return f'{sign * extra_days}d'
            else:
                return f'{sign * int(delta.total_seconds())}s'

        if self.time_kernel is not None:
            td_kernel = [coerce_timedelta(d) for d in self.time_kernel]
            approx_code = ','.join([_concise_td(delta) for delta in td_kernel])
        else:
            approx_code = None

        title_info = ub.codeblock(
            f'''
            name={self.name}
            affinity_type={self.affinity_type}
            deterministic={self.deterministic}
            update_rule={self.update_rule}
            gamma={self.gamma}
            {approx_code}
            ''')

        chosen, info = self.sample(idx, return_info=True, exclude=exclude, rng=rng)
        info['title_suffix'] = title_info
        show_affinity_sample_process(chosen, info, fnum=fnum)
        return chosen, info