from datetime import datetime as datetime_cls
import kwarray
import kwimage
import numpy as np
import ubelt as ub
from .utils import guess_missing_unixtimes
[docs]
def show_affinity_sample_process(chosen, info, fnum=1):
    """
    Debugging / demo visualization of the iterative sample algorithm.

    For details see :func:`TimeWindowSampler.show_procedure`.

    The figure is a grid with two columns and one row for initialization plus
    one row per sampling step. The left column shows the sampling
    probabilities (previously chosen indexes are drawn as gray vertical lines
    and the newly chosen index is annotated). The right column shows the
    update weights (and optional mask) applied after each choice; on the last
    step it instead shows the complete final sample.

    Args:
        chosen (List[int] | ArrayLike):
            The final chosen indexes. Only used by the (disabled) summary row.

        info (Dict):
            Record of the sampling procedure. Keys read here: ``steps`` (a
            list of dicts with ``next_idx``, ``probs`` and, for every step but
            the last, ``update_weights`` and optionally ``next_mask``),
            ``initial_chosen``, ``initial_weights``,
            ``initial_update_weights``, and optionally ``unixtimes``,
            ``initial_mask``, ``title_suffix`` and ``affinity`` (summary row
            only). Update weights may be a per-index array or a scalar.

        fnum (int): figure number to draw into.

    Returns:
        The figure object returned by ``kwplot.figure`` (used as a
        matplotlib figure: ``gca``, ``suptitle``, ``subplots_adjust``).
    """
    import kwplot
    mask_color = kwimage.Color.coerce('kitware_yellow').as01()
    probability_color = kwimage.Color.coerce('kitware_blue').as01()
    update_weight_color = kwimage.Color.coerce('kitware_green').as01()
    prev_chosen_color = kwimage.Color.coerce('kitware_darkgray').as01()
    chosen_text_color = 'black'
    chosen_arrow_color = 'orange'
    chosen_line_color = 'orange'

    steps = info['steps']
    unixtimes = info.get('unixtimes', None)
    if unixtimes is not None:
        # asarray so ``unixtimes[chosen_so_far]`` below can fancy-index even
        # if the helper hands back a plain list.
        unixtimes = np.asarray(guess_missing_unixtimes(unixtimes))

    # The summary row was found not to be helpful, so it is disabled.
    _include_summary_row = 0

    pnum_ = kwplot.PlotNums(nCols=2, nRows=len(steps) + (1 + _include_summary_row))
    fig = kwplot.figure(fnum=fnum, doclf=True)

    # ---- Initialization row, left panel: initial weights + included indexes
    fig = kwplot.figure(pnum=pnum_(), fnum=fnum)
    ax = fig.gca()
    initial_indexes = info['initial_chosen']
    probs = np.asarray(info['initial_weights'])
    ymax = probs.max()
    if ymax == 0:
        # Same guard as the step rows so the vertical lines stay visible.
        ymax = 1
    xmax = len(probs)

    SHOW_UNIXTIMES_IN_TOP_LEFT = 0
    if unixtimes is None or not SHOW_UNIXTIMES_IN_TOP_LEFT:
        _plot_vertical_lines(ax, initial_indexes, ymax, prev_chosen_color)
        ax.plot(np.arange(xmax), probs)
    else:
        datetimes = np.array([datetime_cls.fromtimestamp(t) for t in unixtimes])
        _plot_vertical_lines(ax, datetimes[initial_indexes], ymax, prev_chosen_color)
        ax.plot(datetimes, probs)
    ax.set_title('Initialize included indices')

    # ---- Initialization row, right panel: initial update weights (+ mask)
    fig = kwplot.figure(pnum=pnum_())
    ax = fig.gca()
    initial_mask = info.get('initial_mask', None)
    xidxs = np.arange(xmax)
    if initial_mask is not None:
        ax.fill_between(xidxs, initial_mask, color=mask_color, alpha=0.5)
    initial_update = _broadcast_weight_curve(info['initial_update_weights'], xmax)
    ax.plot(xidxs, initial_update, color=update_weight_color)
    if initial_mask is not None:
        ax.set_title('Initialize update weights & Mask')
    else:
        ax.set_title('Initialize update weights')

    chosen_so_far = list(initial_indexes)
    start_index = list(initial_indexes)
    num_initial = len(initial_indexes)
    num_steps = len(steps)
    for offset, step in enumerate(steps):
        # Iteration numbers continue counting after the initially chosen items.
        step_idx = num_initial + offset
        # Decide "last step" by position in ``steps`` (not by step_idx, which
        # is offset by the number of initially chosen indexes).
        is_last_step = offset == num_steps - 1

        # ---- Left panel: probabilities and the index chosen at this step
        fig = kwplot.figure(pnum=pnum_())
        ax = fig.gca()
        idx = step['next_idx']
        probs = np.asarray(step['probs'])
        ymax = probs.max()
        if ymax == 0:
            ymax = 1
        xmax = len(probs)
        x, y = idx, probs[idx]
        _plot_vertical_lines(ax, chosen_so_far, ymax, prev_chosen_color)
        ax.plot(np.arange(xmax), probs, color=probability_color)
        # Put the label above low points and below high points so the
        # arrow stays inside the axes.
        ypos = y + ymax * 0.3 if y < (ymax / 2) else y - ymax * 0.3
        ax.annotate('chosen', (x, y), xytext=(x, ypos), color=chosen_text_color,
                    arrowprops=dict(color=chosen_arrow_color, arrowstyle='->'))
        ax.plot([x, x], [0, ymax], color=chosen_line_color)
        ax.set_title('Iteration {}: sample'.format(step_idx))

        chosen_so_far.append(idx)

        # ---- Right panel: update weights, or the final sample on the last step
        fig = kwplot.figure(pnum=pnum_())
        ax = fig.gca()
        if not is_last_step:
            next_mask = step.get('next_mask', None)
            xidxs = np.arange(xmax)
            if next_mask is not None:
                ax.fill_between(xidxs, next_mask, color=mask_color, alpha=0.5)
            update = _broadcast_weight_curve(step['update_weights'], xmax)
            ax.plot(xidxs, update, color=update_weight_color)
            ax.plot([x, x], [0, update.max()], color=chosen_line_color)
            if next_mask is not None:
                ax.set_title('Iteration {}: update & mask weights'.format(step_idx))
            else:
                ax.set_title('Iteration {}: update weights'.format(step_idx))
        else:
            if unixtimes is None:
                _plot_vertical_lines(ax, chosen_so_far, ymax, prev_chosen_color)
            else:
                chosen_unixtimes = unixtimes[chosen_so_far]
                chosen_datetimes = np.array(
                    [datetime_cls.fromtimestamp(t) for t in chosen_unixtimes])
                _plot_vertical_lines(ax, chosen_datetimes, ymax, prev_chosen_color)
            ax.set_title('Final sample')

    if _include_summary_row:
        # This last row is not helpful, don't include it.
        affinity = info['affinity']
        fig = kwplot.figure(pnum=pnum_())
        ax = fig.gca()
        for row in affinity[chosen]:
            ax.plot(row)
        ax.set_title('Chosen affinities')
        final_mat = affinity[chosen][:, chosen]
        final_mat[np.isnan(final_mat)] = 0
        final_mat = kwimage.normalize(final_mat)
        kwplot.imshow(final_mat, pnum=pnum_(), title='Final affinities')

    title_suffix = info.get('title_suffix', '')
    fig.suptitle(f'Sample procedure: {start_index}{title_suffix}')
    fig.subplots_adjust(hspace=0.4)
    return fig


def _plot_vertical_lines(ax, xs, ymax, color):
    """Draw one vertical line from 0 to ``ymax`` at each position in ``xs``."""
    for x_ in xs:
        ax.plot([x_, x_], [0, ymax], color=color)


def _broadcast_weight_curve(weights, num):
    """
    Return ``weights`` as a length-``num`` float array.

    ``weights`` may be either a per-index array of length ``num`` or a single
    scalar; a scalar is broadcast so it can be plotted against
    ``np.arange(num)``. A mismatched array length raises ``ValueError``.
    """
    return np.broadcast_to(np.asarray(weights, dtype=float), (num,))
[docs]
def plot_dense_sample_indices(sample_idxs, unixtimes, title_suffix='', linewidths=0):
    """
    Visualization helper that draws a heatmap of which frames each sample uses.

    Each heatmap row is one sample and each column is one frame. A cell holds
    the number of times that frame index occurs in that sample.

    Args:
        sample_idxs (List[List[int]] | ArrayLike[ndim=2]):
            A list of frame indexes that index into unixtimes.
            I.e. multiple samples of frame index groups. Samples may have
            different lengths.

        unixtimes (List | ArrayLike[ndim=1] | None):
            An array of unix timestamps corresponding to frame indexes.
            If unspecified, then frame indexes are shown directly.

        title_suffix (str): text appended to the axes title.

        linewidths (float): width of the lines drawn between heatmap cells.

    Returns:
        The axes returned by ``seaborn.heatmap``.

    Example:
        >>> # xdoctest: +REQUIRES(module:kwplot)
        >>> unixtimes = None
        >>> sample_idxs = [
        >>>     [0, 1, 2],
        >>>     [3, 5, 6],
        >>>     [2, 3, 6],
        >>> ]
        >>> plot_dense_sample_indices(sample_idxs, unixtimes)
    """
    import seaborn as sns
    import pandas as pd

    use_datetimes = unixtimes is not None
    if use_datetimes:
        num_keyframes = len(unixtimes)
    else:
        # Without timestamps, use just enough columns for the largest index.
        # Empty samples are skipped so they cannot break ``max``.
        max_frame = max((max(s) for s in sample_idxs if len(s)), default=-1)
        num_keyframes = max_frame + 1

    # Count occurrences of each frame index per sample. Filling row by row
    # works for homogeneous arrays and ragged lists alike, so no
    # exception-driven fallback is needed.
    dense_sample = np.zeros((len(sample_idxs), num_keyframes), dtype=int)
    for row, frame_idxs in zip(dense_sample, sample_idxs):
        np.add.at(row, np.asarray(frame_idxs, dtype=int), 1)

    # =====================
    # Show Sample Pattern in heatmap
    df = pd.DataFrame(dense_sample)
    df.index.name = 'index'
    if use_datetimes:
        unixtimes = guess_missing_unixtimes(unixtimes)
        datetimes = np.array([datetime_cls.fromtimestamp(t) for t in unixtimes])
        df.columns = pd.to_datetime(datetimes).date
        df.columns.name = 'date'
    ax = sns.heatmap(data=df, cbar=False, linewidths=linewidths, linecolor='darkgray')
    ax.set_title('Sample Indexes' + title_suffix)
    ax.set_xlabel('Observation Index')
    ax.set_ylabel('Sample Index')
    return ax
[docs]
def plot_temporal_sample_indices(sample_idxs, unixtimes=None, sensors=None,
                                 title_suffix='', sample_ordering='start_time'):
    """
    Visualization helper that plots each sample as a row of markers over time.

    Every available observation is marked with a vertical line (colored by
    sensor when ``sensors`` is given), and each sample is drawn as a
    horizontal line of markers at its own y position.

    Args:
        sample_idxs (List[List[int]]):
            A list of frame indexes that index into unixtimes.
            I.e. multiple samples of frame index groups.

        unixtimes (List | None):
            An array of unix timestamps corresponding to frame indexes.
            If unspecified, then frame indexes are shown directly.

        sensors (List | None):
            Optional per-frame sensor labels used to color the observation
            lines. Each distinct sensor gets its own color, assigned in order
            of first appearance so colors are reproducible across runs.

        title_suffix (str): text appended to the axes title.

        sample_ordering (str):
            How samples are ordered along the y-axis: ``'start_time'``
            (default), ``'end_time'`` or ``'duration'``.

    Returns:
        The matplotlib axes that was drawn into (``plt.gca()``).

    Raises:
        ValueError: if ``sample_ordering`` is not a known ordering.

    Example:
        >>> # xdoctest: +REQUIRES(module:kwplot)
        >>> unixtimes = None
        >>> sample_idxs = [
        >>>     [0, 1, 2],
        >>>     [3, 5, 6],
        >>>     [2, 3, 6],
        >>> ]
        >>> plot_temporal_sample_indices(sample_idxs, unixtimes)
    """
    import matplotlib.pyplot as plt

    # Sort keys for the supported y-axis orderings.
    ordering_keys = {
        'start_time': lambda s: (min(s), max(s)),
        'end_time': lambda s: (max(s), min(s)),
        'duration': lambda s: (max(s) - min(s), min(s), max(s)),
    }
    try:
        sort_key = ordering_keys[sample_ordering]
    except KeyError:
        raise ValueError(
            f'Unknown sample_ordering={sample_ordering!r}; '
            f'expected one of {sorted(ordering_keys)}') from None

    if unixtimes is None:
        xlabel = 'Frame Index'
        max_frame = max(max(s) for s in sample_idxs)
        datetimes = np.arange(max_frame + 1)
    else:
        xlabel = 'Time'
        unixtimes = guess_missing_unixtimes(unixtimes)
        datetimes = np.array([datetime_cls.fromtimestamp(t) for t in unixtimes])

    # =====================
    # Show Sample Pattern WRT to time
    ax = plt.gca()

    # ``len`` check instead of truthiness so ndarray input does not raise.
    if sensors is not None and len(sensors):
        # dict.fromkeys keeps first-seen order; iterating a set of strings
        # varies between interpreter runs (hash randomization).
        unique_sensors = list(dict.fromkeys(sensors))
        unique_colors = kwimage.Color.distinct(len(unique_sensors))
        sensor_to_color = ub.dzip(unique_sensors, unique_colors)
        colors = [sensor_to_color[s] for s in sensors]
    else:
        colors = ['darkblue'] * len(datetimes)

    # Mark available observation locations
    for t, color in zip(datetimes, colors):
        ax.plot([t, t], [0, len(sample_idxs) + 1], color=color, alpha=0.5)

    # Order the samples along the y-axis
    sample_idxs = sorted(sample_idxs, key=sort_key)

    # Mark specific sample locations
    for sample_ypos, sample in enumerate(sample_idxs, start=1):
        # Convert to an integer array so tuple samples are fancy-indexed
        # rather than treated as a multi-dimensional index.
        sample = np.asarray(sample, dtype=int)
        ax.plot(datetimes[sample], [sample_ypos] * len(sample), '-', marker='.')

    ax.set_title('Sample Times' + title_suffix)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Sample Index')
    return ax
# import matplotlib.dates as mdates
# ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
# ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
# ax.figure.autofmt_xdate()
[docs]
def plot_temporal_sample(affinity, sample_idxs, unixtimes, sensors=None, fnum=1):
    """
    Visualization helper that stacks two views of the same set of samples.

    The upper subplot is the dense heatmap produced by
    :func:`plot_dense_sample_indices`; the lower subplot is the time-ordered
    view produced by :func:`plot_temporal_sample_indices`.

    Args:
        affinity: Unused by this function.
        sample_idxs (List[List[int]]): groups of frame indexes, one per sample.
        unixtimes (List | None): timestamps for each frame index, or None to
            plot frame indexes directly.
        sensors (List | None): optional per-frame sensor labels, forwarded to
            the temporal plot.
        fnum (int): figure number to draw into.
    """
    import kwplot
    kwplot.autompl()

    def _draw_heatmap():
        # =====================
        # Show Sample Pattern in heatmap
        plot_dense_sample_indices(sample_idxs, unixtimes, title_suffix='')

    def _draw_timeline():
        # =====================
        # Show Sample Pattern WRT to time
        plot_temporal_sample_indices(sample_idxs, unixtimes, sensors=sensors)

    # Subplot position paired with the routine that fills it, top to bottom.
    layout = (
        ((2, 1, 1), _draw_heatmap),
        ((2, 1, 2), _draw_timeline),
    )
    for subplot_pnum, draw in layout:
        kwplot.figure(fnum=fnum, pnum=subplot_pnum)
        draw()