geowatch.tasks.fusion.datamodules.temporal_sampling.affinity module
- geowatch.tasks.fusion.datamodules.temporal_sampling.affinity.affinity_sample(affinity, size, include_indices=None, exclude_indices=None, allow_fewer=False, update_rule='pairwise', gamma=1, deterministic=False, time_kernel=None, unixtimes=None, error_level=2, rng=None, return_info=False, jit=False)
Randomly select “size” timesteps from a larger pool based on “affinity”.
Given an NxN affinity matrix between frames and an initial set of indices to include, chooses other frames to complete the sample. Each row and column in the affinity matrix represents a “selectable” timestamp. The initial set of “include_indices” indicates which timesteps must be included in the sample. An iterative process then selects the remaining indices so that “size” timesteps are returned. In each iteration, the “next” timestep is chosen from a probability distribution derived from (1) the affinity matrix, (2) the currently included set of indexes, and (3) the update rule.
- Parameters:
affinity (ndarray) – pairwise affinity matrix
size (int) – Number of sample indices to return
include_indices (List[int]) – Indices that must be included in the sample
exclude_indices (List[int]) – Indices that cannot be included in the sample
allow_fewer (bool) – if True, we will allow fewer than the requested “size” samples to be returned.
update_rule (str) – Modifies how the affinity matrix is used to create the probability distribution for the “next” frame that will be selected. a “+” separated string of codes which can contain:
- pairwise - if included, each newly chosen sample will modulate the initial “main” affinity with its own affinity. Otherwise, only the affinity of the initially included rows is considered.
- distribute - if included, every step of weight updates will downweight samples temporally close to the most recently selected sample.
gamma (float, default=1.0) – Exponent that modulates the probability distribution. Lower gamma will “flatten” the probability curve. At gamma=0, all frames will be equally likely regardless of affinity. As gamma -> inf, the rule becomes more likely to sample the timestep with maximum probability at each step. In the limit this becomes equivalent to deterministic=True.
deterministic (bool) – if True, on each step we choose the next timestamp with maximum probability. Otherwise, we randomly choose a timestep, with probability according to the current distribution.
error_level (int) – Error and fallback behavior if perfect sampling is not possible.
- error level 0: might return excluded indexes, duplicate indexes, or 0-affinity indexes if everything else is exhausted.
- error level 1: duplicate indexes will raise an error.
- error level 2: duplicate and excluded indexes will raise an error.
- error level 3: duplicate, excluded, and 0-affinity indexes will raise an error.
rng (Coercible[RandomState]) – random state for reproducible sampling
return_info (bool) – If True, includes a dictionary of information that details the internal steps the algorithm took.
jit (bool) – NotImplemented - do not use
time_kernel (ndarray) – if specified, the sample will attempt to conform to this time kernel.
- Returns:
The chosen indexes for the sample, or, if return_info is True, a tuple of chosen and the info dictionary.
- Return type:
ndarray | Tuple[ndarray, Dict]
- Raises:
TimeSampleError – if sampling is impossible
- Possible Related Work:
Random Stratified Sampling Affinity Matrix
A quasi-random sampling approach to image retrieval
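The iterative selection described above can be sketched in plain NumPy. This is a hypothetical, simplified re-implementation for illustration, not the geowatch code: the helper name sketch_affinity_sample is invented, the starting distribution is assumed to be the elementwise product of the included rows, gamma is assumed to act as a power on the weights, IndexError stands in for TimeSampleError, and the “distribute” rule, time kernel, and error levels are left out.

```python
import numpy as np


def sketch_affinity_sample(affinity, size, include_indices, exclude_indices=(),
                           gamma=1.0, deterministic=False, pairwise=True, rng=None):
    """
    Hypothetical, simplified version of the iterative selection loop.

    Assumptions: the starting distribution is the elementwise product of the
    included rows, gamma is applied as a power, and the "distribute" rule,
    time kernel, and error levels are not modeled.
    """
    rng = np.random.default_rng(rng)
    affinity = np.asarray(affinity, dtype=float)
    exclude = np.asarray(exclude_indices, dtype=int)
    chosen = [int(i) for i in include_indices]
    # Initial distribution comes only from the rows of the required frames
    probs = affinity[chosen].prod(axis=0)
    while len(chosen) < size:
        # Apply gamma first so zeroed entries stay zero even when gamma == 0
        weights = probs ** gamma
        weights[chosen] = 0      # never pick the same frame twice
        weights[exclude] = 0     # never pick an excluded frame
        total = weights.sum()
        if total <= 0:
            raise IndexError('no selectable frames remain')
        if deterministic:
            next_idx = int(np.argmax(weights))
        else:
            next_idx = int(rng.choice(len(weights), p=weights / total))
        chosen.append(next_idx)
        if pairwise:
            # "pairwise": the new frame's own affinity row modulates the distribution
            probs = probs * affinity[next_idx]
    return np.array(sorted(chosen))


aff = np.array([[1.0, 0.9, 0.1, 0.0],
                [0.9, 1.0, 0.5, 0.2],
                [0.1, 0.5, 1.0, 0.8],
                [0.0, 0.2, 0.8, 1.0]])
print(sketch_affinity_sample(aff, size=3, include_indices=[0], deterministic=True))  # prints [0 1 2]
```

With the “pairwise” update, asking for all four frames deterministically from frame 0 fails in this sketch: after frames 0, 1, and 2 are picked, the product of their rows is zero at frame 3, which is the kind of situation the error levels above are meant to govern.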
CommandLine
xdoctest -m geowatch.tasks.fusion.datamodules.temporal_sampling.affinity affinity_sample:0 --show
xdoctest -m geowatch.tasks.fusion.datamodules.temporal_sampling.affinity affinity_sample:1 --show
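The limiting behavior described for gamma (uniform at gamma=0, approaching the maximum-probability choice as gamma grows) can be illustrated with a tiny sketch. It assumes gamma is applied as an elementwise power to the affinity weights before normalizing; the helper gamma_distribution is hypothetical.

```python
import numpy as np


def gamma_distribution(affinity_row, gamma):
    """Hypothetical: flatten (gamma < 1) or sharpen (gamma > 1) a row of affinities."""
    weights = np.asarray(affinity_row, dtype=float) ** gamma
    return weights / weights.sum()


row = np.array([0.2, 0.4, 0.8])
for gamma in [0.0, 1.0, 8.0]:
    # gamma=0 is uniform; larger gamma concentrates mass on the largest affinity
    print(gamma, np.round(gamma_distribution(row, gamma), 3))
```

Under this reading 0 ** 0 == 1, so even zero-affinity frames become selectable at gamma=0, which matches the “regardless of affinity” wording.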
Example
>>> from geowatch.tasks.fusion.datamodules.temporal_sampling import * # NOQA
>>> from geowatch.tasks.fusion.datamodules.temporal_sampling.affinity import * # NOQA
>>> low = datetime_cls.now().timestamp()
>>> high = low + datetime_mod.timedelta(days=365 * 5).total_seconds()
>>> rng = kwarray.ensure_rng(0)
>>> unixtimes = np.array(sorted(rng.randint(low, high, 113)), dtype=float)
>>> #
>>> affinity = soft_frame_affinity(unixtimes, version=2, time_span='1d')['final']
>>> include_indices = [5]
>>> size = 5
>>> chosen, info = affinity_sample(affinity, size, include_indices, update_rule='pairwise',
>>>                                return_info=True, deterministic=True)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> from geowatch.tasks.fusion.datamodules.temporal_sampling.plots import show_affinity_sample_process
>>> sns = kwplot.autosns()
>>> plt = kwplot.autoplt()
>>> show_affinity_sample_process(chosen, info)
>>> kwplot.show_if_requested()
Example
>>> from geowatch.tasks.fusion.datamodules.temporal_sampling import * # NOQA
>>> low = datetime_cls.now().timestamp()
>>> high = low + datetime_mod.timedelta(days=365 * 5).total_seconds()
>>> rng = kwarray.ensure_rng(0)
>>> unixtimes = np.array(sorted(rng.randint(low, high, 5)), dtype=float)
>>> self = TimeWindowSampler(unixtimes, sensors=None, time_window=4,
>>>                          affinity_type='soft2', time_span='0.3y',
>>>                          update_rule='distribute+pairwise', allow_fewer=False)
>>> self.deterministic = False
>>> import pytest
>>> with pytest.raises(IndexError):
>>>     self.sample(0, exclude=[1, 2, 4], error_level=3)
>>> with pytest.raises(IndexError):
>>>     self.sample(0, exclude=[1, 2, 4], error_level=2)
>>> self.sample(0, exclude=[1, 2, 4], error_level=1)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> chosen, info = self.show_procedure(idx=0, fnum=10, exclude=[1, 2, 4])
>>> print('info = {}'.format(ub.urepr(info, nl=4)))
>>> kwplot.show_if_requested()
Example
>>> # xdoctest: +REQUIRES(env:SMART_DATA_DVC_DPATH)
>>> from geowatch.tasks.fusion.datamodules.temporal_sampling import * # NOQA
>>> from geowatch.tasks.fusion.datamodules.temporal_sampling.utils import coerce_time_kernel
>>> import kwarray
>>> import geowatch
>>> data_dvc_dpath = geowatch.find_dvc_dpath(tags='phase2_data', hardware='auto')
>>> coco_fpath = data_dvc_dpath / 'Drop6/imgonly-KR_R001.kwcoco.json'
>>> dset = geowatch.coerce_kwcoco(coco_fpath)
>>> vidid = dset.dataset['videos'][0]['id']
>>> time_kernel_code = '-3m,-1w,0,3m,1y'
>>> self = TimeWindowSampler.from_coco_video(
>>>     dset, vidid,
>>>     time_window=5,
>>>     time_kernel=time_kernel_code,
>>>     affinity_type='soft3',
>>>     update_rule='')
>>> self.deterministic = False
>>> self.show_affinity()
>>> include_indices = [len(self.unixtimes) // 2]
>>> exclude_indices = []
>>> affinity = self.affinity
>>> size = self.time_window
>>> deterministic = self.deterministic
>>> update_rule = self.update_rule
>>> unixtimes = self.unixtimes
>>> gamma = self.gamma
>>> time_kernel = self.time_kernel
>>> rng = kwarray.ensure_rng(None)
>>> deterministic = True
>>> return_info = True
>>> error_level = 2
>>> chosen, info = affinity_sample(
>>>     affinity=affinity,
>>>     size=size,
>>>     include_indices=include_indices,
>>>     exclude_indices=exclude_indices,
>>>     update_rule=update_rule,
>>>     gamma=gamma,
>>>     deterministic=deterministic,
>>>     error_level=error_level,
>>>     rng=rng,
>>>     return_info=return_info,
>>>     time_kernel=time_kernel,
>>>     unixtimes=unixtimes,
>>> )
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> info['title_suffix'] = chr(10) + time_kernel_code
>>> from geowatch.tasks.fusion.datamodules.temporal_sampling.plots import show_affinity_sample_process
>>> show_affinity_sample_process(chosen, info, fnum=1)
>>> kwplot.show_if_requested()
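The “distribute” update rule (downweighting frames temporally close to the most recent pick) can be sketched as a multiplicative penalty. The helper distribute_update and the 1 - exp(-|dt| / scale) penalty curve are assumptions chosen for illustration; geowatch may use a different shape and scale.

```python
import numpy as np


def distribute_update(probs, unixtimes, last_idx, scale):
    """
    Hypothetical "distribute" step: suppress frames near the last pick.

    The factor 1 - exp(-|dt| / scale) is an illustrative assumption: it is 0 at
    the picked frame and approaches 1 for frames far from it in time.
    """
    t = np.asarray(unixtimes, dtype=float)
    dt = np.abs(t - t[last_idx])
    suppression = 1.0 - np.exp(-dt / scale)
    return np.asarray(probs, dtype=float) * suppression


day = 24 * 60 * 60
unixtimes = np.array([0, 1, 2, 30, 60]) * day
print(np.round(distribute_update(np.ones(5), unixtimes, last_idx=0, scale=7 * day), 3))
```

In a sampling loop this would be applied after each pick, on top of any “pairwise” update, which spreads the chosen frames out in time.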
- geowatch.tasks.fusion.datamodules.temporal_sampling.affinity.make_soft_mask(time_kernel, relative_unixtimes)
Assign probabilities to real observations based on an ideal time kernel.
- Parameters:
time_kernel (ndarray) – A list of relative seconds in the time kernel. Each element in this list is referred to as a “kernel entry”.
relative_unixtimes (ndarray) – A list of available unixtimes corresponding to real observations. These should be relative to an “ideal” center, i.e. the “main” observation the kernel is centered around should have a relative unixtime of zero.
- Returns:
A tuple of (kernel_masks, kernel_attrs). For each element in the time kernel there is a corresponding entry in the output kernel_masks and kernel_attrs lists: the former holds the probability assigned to each observation for that kernel entry, and the latter is a dictionary of information about that kernel entry.
Example
>>> # Generates the time kernel visualization
>>> from geowatch.tasks.fusion.datamodules.temporal_sampling.affinity import * # NOQA
>>> time_kernel = coerce_time_kernel('-1H,-5M,0,5M,1H')
>>> relative_unixtimes = coerce_time_kernel('-90M,-70M,-50M,0,1sec,10S,30M')
>>> # relative_unixtimes = coerce_time_kernel('-90M,-70M,-50M,-20M,-10M,0,1sec,10S,30M,57M,87M')
>>> kernel_masks, kernel_attrs = make_soft_mask(time_kernel, relative_unixtimes)
>>> #
>>> min_t = min(kattr['left'] for kattr in kernel_attrs)
>>> max_t = max(kattr['right'] for kattr in kernel_attrs)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwimage
>>> import kwplot
>>> plt = kwplot.autoplt()
>>> kwplot.figure(fnum=1, doclf=1)
>>> kernel_color = kwimage.Color.coerce('kitware_green').as01()
>>> obs_color = kwimage.Color.coerce('kitware_blue').as01()
>>> #
>>> kwplot.figure(fnum=1, pnum=(2, 1, 1))
>>> plt.plot(time_kernel, [0] * len(time_kernel), '-o', color=kernel_color, label='kernel')
>>> #
>>> for kattr in kernel_attrs:
>>>     rv = kattr['rv']
>>>     xs = np.linspace(min_t, max_t, 1000)
>>>     ys = rv.pdf(xs)
>>>     ys_norm = ys / ys.sum()
>>>     plt.plot(xs, ys_norm)
>>> #
>>> ax = plt.gca()
>>> ax.legend()
>>> ax.set_xlabel('time')
>>> ax.set_ylabel('ideal probability')
>>> ax.set_title('ideal kernel')
>>> #
>>> kwplot.figure(fnum=1, pnum=(2, 1, 2))
>>> plt.plot(relative_unixtimes, [0] * len(relative_unixtimes), '-o', color=obs_color, label='observation')
>>> ax = plt.gca()
>>> #
>>> for kattr in kernel_attrs:
>>>     rv = kattr['rv']
>>>     xs = relative_unixtimes
>>>     ys = rv.pdf(xs)
>>>     ys_norm = ys / ys.sum()
>>>     plt.plot(xs, ys_norm)
>>> ax.legend()
>>> ax.set_xlabel('time')
>>> ax.set_ylabel('sample probability')
>>> ax.set_title('discrete observations')
>>> plt.subplots_adjust(top=0.9, hspace=.3)
>>> kwplot.show_if_requested()
Example
>>> from geowatch.tasks.fusion.datamodules.temporal_sampling.affinity import * # NOQA
>>> time_kernel = coerce_time_kernel('-1H,-5M,0,5M,1H')
>>> relative_unixtimes = [np.nan] * 10
>>> # relative_unixtimes = coerce_time_kernel('-90M,-70M,-50M,-20M,-10M,0,1sec,10S,30M,57M,87M')
>>> kernel_masks, kernel_attrs = make_soft_mask(time_kernel, relative_unixtimes)
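The examples above suggest that each kernel entry carries its own continuous distribution (“rv”) and an interval (“left”, “right”), and that the per-observation probabilities come from evaluating that distribution at the observed times and normalizing. The sketch below follows that shape with plain NumPy. It is a hypothetical stand-in, not make_soft_mask itself: the helper name, the midpoint intervals, the Gaussian shape, and the quarter-interval standard deviation are all assumptions, and NaN observation times are simply given zero weight.

```python
import numpy as np


def sketch_soft_mask(time_kernel, relative_unixtimes):
    """
    Hypothetical stand-in for make_soft_mask (not the geowatch implementation).

    Assumptions: the kernel is strictly increasing with at least two entries,
    each entry owns the interval between the midpoints to its neighbors, and its
    ideal distribution is a Gaussian centered on the entry with a standard
    deviation of one quarter of that interval.
    """
    kernel = np.asarray(time_kernel, dtype=float)
    obs = np.asarray(relative_unixtimes, dtype=float)
    mids = (kernel[1:] + kernel[:-1]) / 2
    # Outermost entries mirror the gap to their single inner midpoint
    lefts = np.concatenate([[2 * kernel[0] - mids[0]], mids])
    rights = np.concatenate([mids, [2 * kernel[-1] - mids[-1]]])
    kernel_masks, kernel_attrs = [], []
    for center, left, right in zip(kernel, lefts, rights):
        std = (right - left) / 4
        density = np.exp(-0.5 * ((obs - center) / std) ** 2)
        density = np.nan_to_num(density)  # observations with missing times get zero weight
        total = density.sum()
        # Normalize to a probability over observations; an all-zero mask stays zero
        mask = density / total if total > 0 else density
        kernel_masks.append(mask)
        kernel_attrs.append({'center': center, 'left': left, 'right': right, 'std': std})
    return kernel_masks, kernel_attrs


# Offsets in seconds mirroring '-1H,-5M,0,5M,1H' and '-90M,-70M,-50M,0,1sec,10S,30M'
masks, attrs = sketch_soft_mask([-3600, -300, 0, 300, 3600],
                                [-5400, -4200, -3000, 0, 1, 10, 1800])
print(np.round(masks[2], 3))
```

The zero-centered kernel entry puts nearly all of its mass on the observations at 0, 1, and 10 seconds, which is the behavior the “discrete observations” panel in the first example visualizes.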
- geowatch.tasks.fusion.datamodules.temporal_sampling.affinity.hard_time_sample_pattern(unixtimes, time_window, time_kernel=None, time_span=None)
Finds hard time sampling indexes.
- Parameters:
unixtimes (ndarray) – list of unix timestamps indicating available temporal samples
time_window (int) – number of frames per sample
References
https://docs.google.com/presentation/d/1GSOaY31cKNERQObl_L3vk0rGu6zU7YM_ZFLrdksHSC0/edit#slide=id.p
Example
>>> low = datetime_cls.now().timestamp()
>>> high = low + datetime_mod.timedelta(days=365 * 5).total_seconds()
>>> rng = kwarray.ensure_rng(0)
>>> base_unixtimes = np.array(sorted(rng.randint(low, high, 20)), dtype=float)
>>> unixtimes = base_unixtimes.copy()
>>> #unixtimes[rng.rand(*unixtimes.shape) < 0.1] = np.nan
>>> time_window = 5
>>> sample_idxs = hard_time_sample_pattern(unixtimes, time_window, time_span='2y')
>>> name = 'demo-data'
>>> #unixtimes[:] = np.nan
>>> time_window = 5
>>> sample_idxs = hard_time_sample_pattern(unixtimes, time_window, time_span='2y')
>>> name = 'demo-data'
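The docstring only says that hard_time_sample_pattern “finds hard time sampling indexes”. One plausible reading, in contrast to the soft, probabilistic sampling above, is nearest-neighbor snapping: for every frame taken as the center, each kernel offset is assigned to the single observed frame closest to it. The sketch below shows that idea only; the helper name is hypothetical, building the kernel with np.linspace from a span and a window size is an assumption, and NaN timestamps (which the example above toggles) are not handled.

```python
import numpy as np


def sketch_hard_sample_pattern(unixtimes, kernel_offsets):
    """
    Hypothetical sketch of hard (nearest-neighbor) temporal sampling.

    For every frame taken as the center, each kernel offset is snapped to the
    observed frame whose timestamp is closest. Assumes no NaN timestamps.
    Returns an (N, K) array of frame indexes; rows may contain repeats near
    the ends of the sequence.
    """
    t = np.asarray(unixtimes, dtype=float)
    offsets = np.asarray(kernel_offsets, dtype=float)
    # ideal[i, k]: the timestamp slot k would like when centered on frame i
    ideal = t[:, None] + offsets[None, :]
    # Distance from every ideal timestamp to every observed one, shape (N, K, N)
    dist = np.abs(ideal[:, :, None] - t[None, None, :])
    return dist.argmin(axis=2)


# Hypothetical kernel: time_window=5 offsets spread evenly over a 40-second span
offsets = np.linspace(-20, 20, 5)
print(sketch_hard_sample_pattern(np.arange(0, 100, 10.0), offsets))
```

The (N, N) distance cube keeps the sketch short; for long sequences, np.searchsorted on the sorted timestamps would avoid the quadratic memory.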
- geowatch.tasks.fusion.datamodules.temporal_sampling.affinity.soft_frame_affinity(unixtimes, sensors=None, time_kernel=None, time_span=None, version=1, heuristics='default')
Produce pairwise affinity weights between frames based on a dilated time heuristic.
Example
>>> from geowatch.tasks.fusion.datamodules.temporal_sampling.affinity import * # NOQA
>>> low = datetime_mod.datetime.now().timestamp()
>>> high = low + datetime_mod.timedelta(days=365 * 5).total_seconds()
>>> rng = kwarray.ensure_rng(0)
>>> base_unixtimes = np.array(sorted(rng.randint(low, high, 113)), dtype=float)
>>> # Test no missing data case
>>> unixtimes = base_unixtimes.copy()
>>> allhave_weights = soft_frame_affinity(unixtimes, version=2)
>>> #
>>> # Test all missing data case
>>> unixtimes = np.full_like(unixtimes, fill_value=np.nan)
>>> allmiss_weights = soft_frame_affinity(unixtimes, version=2)
>>> #
>>> # Test partial missing data case
>>> unixtimes = base_unixtimes.copy()
>>> unixtimes[rng.rand(*unixtimes.shape) < 0.1] = np.nan
>>> anymiss_weights_1 = soft_frame_affinity(unixtimes, version=2)
>>> unixtimes = base_unixtimes.copy()
>>> unixtimes[rng.rand(*unixtimes.shape) < 0.5] = np.nan
>>> anymiss_weights_2 = soft_frame_affinity(unixtimes, version=2)
>>> unixtimes = base_unixtimes.copy()
>>> unixtimes[rng.rand(*unixtimes.shape) < 0.9] = np.nan
>>> anymiss_weights_3 = soft_frame_affinity(unixtimes, version=2)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autoplt()
>>> pnum_ = kwplot.PlotNums(nCols=5)
>>> kwplot.figure(fnum=1, doclf=True)
>>> # kwplot.imshow(kwarray.normalize(daylight_weights))
>>> kwplot.imshow(kwarray.normalize(allhave_weights['final']), pnum=pnum_(), title='no missing dates')
>>> kwplot.imshow(kwarray.normalize(anymiss_weights_1['final']), pnum=pnum_(), title='any missing dates (0.1)')
>>> kwplot.imshow(kwarray.normalize(anymiss_weights_2['final']), pnum=pnum_(), title='any missing dates (0.5)')
>>> kwplot.imshow(kwarray.normalize(anymiss_weights_3['final']), pnum=pnum_(), title='any missing dates (0.9)')
>>> kwplot.imshow(kwarray.normalize(allmiss_weights['final']), pnum=pnum_(), title='all missing dates')
>>> import pandas as pd
>>> sns = kwplot.autosns()
>>> fig = kwplot.figure(fnum=2, doclf=True)
>>> kwplot.imshow(kwarray.normalize(allhave_weights['final']), pnum=(1, 3, 1), title='pairwise affinity')
>>> row_idx = 5
>>> df = pd.DataFrame({k: v[row_idx] for k, v in allhave_weights.items()})
>>> df['index'] = np.arange(df.shape[0])
>>> data = df.drop(['final'], axis=1).melt(['index'])
>>> kwplot.figure(fnum=2, pnum=(1, 3, 2))
>>> sns.lineplot(data=data, x='index', y='value', hue='variable')
>>> fig.gca().set_title('Affinity components for row={}'.format(row_idx))
>>> kwplot.figure(fnum=2, pnum=(1, 3, 3))
>>> sns.lineplot(data=df, x='index', y='final')
>>> fig.gca().set_title('Affinity components for row={}'.format(row_idx))
Example
>>> # xdoctest: +REQUIRES(env:SMART_DATA_DVC_DPATH)
>>> from geowatch.tasks.fusion.datamodules.temporal_sampling import * # NOQA
>>> import geowatch
>>> import kwimage
>>> data_dvc_dpath = geowatch.find_dvc_dpath(tags='phase2_data', hardware='auto')
>>> coco_fpath = data_dvc_dpath / 'Drop6/imgonly-KR_R001.kwcoco.json'
>>> dset = geowatch.coerce_kwcoco(coco_fpath)
>>> vidid = dset.dataset['videos'][0]['id']
>>> self = TimeWindowSampler.from_coco_video(dset, vidid, time_window=5, time_kernel='-1y,-3m,0,3m,1y', affinity_type='soft3')
>>> unixtimes = self.unixtimes
>>> sensors = self.sensors
>>> time_kernel = self.time_kernel
>>> time_span = None
>>> version = 4
>>> heuristics = 'default'
>>> weights = soft_frame_affinity(unixtimes, sensors, time_kernel, time_span, version, heuristics)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autoplt()
>>> pnum_ = kwplot.PlotNums(nCols=5)
>>> kwplot.figure(fnum=1, doclf=True)
>>> kwplot.imshow(kwarray.normalize(weights['final']), pnum=pnum_(), title='all missing dates')
>>> import pandas as pd
>>> sns = kwplot.autosns()
>>> fig = kwplot.figure(fnum=2, doclf=True)
>>> kwplot.imshow(weights['final'], pnum=(1, 3, 1), title='pairwise affinity', cmap='viridis')
>>> row_idx = 200
>>> df = pd.DataFrame({k: v[row_idx] for k, v in weights.items()})
>>> df['index'] = np.arange(df.shape[0])
>>> data = df.drop(['final'], axis=1).melt(['index'])
>>> kwplot.figure(fnum=2, pnum=(1, 3, 2))
>>> sns.lineplot(data=data, x='index', y='value', hue='variable')
>>> fig.gca().set_title('Affinity components for row={}'.format(row_idx))
>>> kwplot.figure(fnum=2, pnum=(1, 3, 3))
>>> sns.lineplot(data=df, x='index', y='final')
>>> fig.gca().set_title('Affinity components for row={}'.format(row_idx))
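The plotting code in the examples above relies on soft_frame_affinity returning a dictionary of named affinity components plus a “final” matrix (weights['final'] and weights.items()). The sketch below mirrors that structure with a deliberately simple stand-in rather than the dilated time heuristic: exponential decay in |dt|, a constant for pairs with a missing timestamp, and a fixed penalty for cross-sensor pairs. The helper name, the component names, and all three rules are hypothetical.

```python
import numpy as np


def sketch_frame_affinity(unixtimes, time_span, sensors=None, missing_value=0.5,
                          cross_sensor_value=0.5):
    """
    Hypothetical pairwise frame affinity returning named components plus 'final'.

    Assumptions (not geowatch's dilated heuristic): affinity decays as
    exp(-|dt| / time_span), any pair involving a NaN timestamp gets
    `missing_value`, and pairs from different sensors are scaled by
    `cross_sensor_value`.
    """
    t = np.asarray(unixtimes, dtype=float)
    dt = np.abs(t[:, None] - t[None, :])
    # NaN propagates through dt, so pairs with a missing time fall back to a constant
    time_weight = np.where(np.isnan(dt), missing_value, np.exp(-dt / time_span))
    weights = {'time': time_weight}
    if sensors is not None:
        sensors = np.asarray(sensors)
        same_sensor = sensors[:, None] == sensors[None, :]
        weights['sensor'] = np.where(same_sensor, 1.0, cross_sensor_value)
    # The final affinity is the elementwise product of all components
    final = np.ones_like(time_weight)
    for component in weights.values():
        final = final * component
    weights['final'] = final
    return weights


day = 24 * 60 * 60.0
w = sketch_frame_affinity([0.0, day, 2 * day, np.nan], time_span=day,
                          sensors=['S2', 'S2', 'L8', 'S2'])
print(np.round(w['final'], 3))
```

Keeping each component separate before multiplying is what lets a per-row breakdown like the seaborn line plots above show which factor dominates the final affinity.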