import pandas as pd
import numpy as np
import ubelt as ub
import kwplot
import kwimage
from matplotlib.collections import LineCollection
from matplotlib.colors import to_rgba
from matplotlib.dates import date2num
from typing import Literal
import datetime as datetime_mod
[docs]
def viz_sc(region_dpaths, true_site_dpath, true_region_dpath, save_dpath):
"""
Seems broken in 1.0.0
Ignore:
>>> # xdoctest: +REQUIRES(module:iarpa_smart_metrics)
>>> from geowatch.cli.run_metrics_framework import * # NOQA
>>> from geowatch.demo.metrics_demo.generate_demodata import generate_demo_metrics_framework_data
>>> cmdline = 0
>>> base_dpath = ub.Path.appdir('geowatch', 'tests', 'test-iarpa-metrics2')
>>> data_dpath = base_dpath / 'inputs'
>>> dpath = base_dpath / 'outputs'
>>> demo_info3 = generate_demo_metrics_framework_data(
>>> roi='DR_R003',
>>> num_sites=11, num_observations=10, noise=3, p_observe=0.5,
>>> p_transition=0.2, drop_noise=0.3, drop_limit=0.5)
>>> print('demo_info3 = {}'.format(ub.urepr(demo_info3, nl=1)))
>>> out_dpath = dpath / 'region_metrics'
>>> merge_fpath = dpath / 'merged.json'
>>> out_dpath.delete()
>>> kwargs = {
>>> 'pred_sites': demo_info3['pred_site_dpath'],
>>> 'true_region_dpath': demo_info3['true_region_dpath'],
>>> 'true_site_dpath': demo_info3['true_site_dpath'],
>>> 'merge': True,
>>> 'merge_fpath': merge_fpath,
>>> 'out_dir': out_dpath,
>>> }
>>> main(cmdline=False, **kwargs)
>>> #
>>> from geowatch.tasks.metrics.viz_sc_results import * # NOQA
>>> region_dpaths = [save_dpath / 'region_metrics/DR_R003']
>>> save_dpath = merge_fpath.parent
>>> true_site_dpath = demo_info3['true_site_dpath']
>>> true_region_dpath = demo_info3['true_region_dpath']
>>> viz_sc(region_dpaths, true_site_dpath, true_region_dpath, save_dpath)
>>> # TODO: visualize
"""
from geowatch.tasks.metrics.merge_iarpa_metrics import RegionResult
kwplot.autompl()
results = []
for pth in region_dpaths:
try:
results.append(
RegionResult.from_dpath_and_anns_root(
pth, true_site_dpath, true_region_dpath)
)
except FileNotFoundError:
print(f'warning: missing region {pth}')
# merge SC
sc_results = [r for r in results if r.sc_dpath]
# check out:
# kwimage.stack_image
# kwplot.make_legend_image
phs = [ph for r in sc_results if (ph := r.sc_phasetable) is not None]
for ph in ub.ProgIter(phs, desc='visualize sc gantt'):
rid = ph.index.get_level_values('region_id')[0]
# site-level viz
for (site, site_cand), df_ in ph.groupby(['site', 'site_candidate']):
df = df_.droplevel([0, 1, 2])
viz_sc_gantt(
df,
' vs. '.join((site, site_cand)),
((save_dpath / rid).ensuredir() / ('_'.join((site, site_cand)) + '.png'))
)
# region-level viz
df = ph
plot_title = rid
save_fpath = (save_dpath / f'sc_{rid}.png')
viz_sc_multi(ph, plot_title, save_fpath)
# merged viz
print('Visualize merged')
merged_df = pd.concat(phs, axis=0)
viz_sc_multi(
merged_df,
' '.join(merged_df.index.unique(level='region_id')),
(save_dpath / 'sc_merged.png')
)
[docs]
def viz_sc_gantt(df, plot_title, save_fpath):
from geowatch.heuristics import PHASES as phases
plt = kwplot.autoplt()
sns = kwplot.autosns()
# TODO how to pick site boundary?
df = df.apply(lambda s: s.str.split(', '))
df['pred'] = df['pred'].fillna(method='ffill')
# df = df[~df['true'].isna()]
# df['pred'] = df['pred'].fillna('Unknown')
df['pred'] = df['pred'].fillna(method='bfill')
df = df.reset_index()
df['date'] = pd.to_datetime(df['date']).dt.date
df = df.melt(id_vars=['date'])
df = df.dropna()
# split out subsites
subsite_ixs = df['value'].str.len() > 1
var, val = [], []
for _, row in df.loc[subsite_ixs].iterrows():
for v in row['value']:
var.append(row['variable'])
val.append(v)
df = df.loc[~subsite_ixs]
df['value'] = df['value'].str[0]
df = pd.concat(
(df,
pd.DataFrame(dict(variable=var, value=val))),
axis=0,
ignore_index=True,
)
# order hack for relplot
phases_type = pd.api.types.CategoricalDtype(
(['Unknown'] + phases)[::-1], ordered=True)
df['value'] = df['value'].astype(phases_type)
df = df.sort_values(by='value')
# TODO threadsafe
grid = sns.relplot(
data=df,
x='date',
y='value',
hue='variable',
size='variable'
)
plt.title(plot_title)
grid.savefig(save_fpath)
plt.close()
[docs]
def viz_sc_multi(ph, plot_title, save_fpath,
date: Literal['absolute', 'from_start', 'from_active'] = 'absolute',
how: Literal['residual', 'strip'] = 'strip'):
from geowatch.heuristics import PHASES as phases
import geowatch.heuristics
plt = kwplot.autoplt()
sns = kwplot.autosns()
df = ph
# df.index = [df.index.map('{0[0]} {0[1]} {0[2]}'.format), df.index.get_level_values(3)]
# ignore subsites
# TODO how to pick site boundary?
df = df.apply(lambda s: s.str.split(', ').str[0])
# could do just region_id for error bars after fillna
df = df.reset_index()
df['group'] = df[['region_id', 'site', 'site_candidate']].astype('string').agg('\n'.join, axis=1)
df = df.drop(['region_id', 'site', 'site_candidate'], axis=1)
if how == 'residual': # true and pred must be aligned
df['pred'] = (df.groupby('group')['pred']
.fillna(method='ffill')
.fillna('No Activity'))
# df['pred'] = df['pred'].fillna('Unknown')
# df['pred'] = df['pred'].fillna(method='bfill')
df = df[~df['true'].isna()]
# df = df[df['true'] != 'Unknown'] # Unk should be gone from df after this
df['date'] = pd.to_datetime(df['date']) # .dt.date
# must do this before searchsorted
phases_type = pd.api.types.CategoricalDtype(
['Unknown'] + phases, ordered=True)
df['pred'] = df['pred'].astype(phases_type)
df['true'] = df['true'].astype(phases_type)
# assert df.groupby('group')['true'].is_monotonic_increasing.all()
# assert df.groupby('group')['pred'].is_monotonic_increasing.all()
if date == 'absolute': # absolute date
df['date'] = df['date'].dt.date
x_var = 'date'
elif date == 'from_start': # relative date since start
df['date'] = df.groupby('group')['date'].transform(lambda date: date - date.iloc[0])
# df['date'] = df['date'].dt.days.fillna(0)
df['date'] = df['date'].dt.days
x_var = 'days since start'
elif date == 'from_active': # relative date since last NA; requires 1 NA to exist before SP
df = df.groupby('group').apply(_align_start)
df['date'] = df['date'].dt.days
x_var = 'days since final No Activity'
palette = {c['name']: c['color'] for c in geowatch.heuristics.CATEGORIES}
if how == 'residual':
jitter = 0.1
df['diff'] = df['pred'].cat.codes - df['true'].cat.codes
df['diff'] += np.random.uniform(low=-jitter, high=jitter, size=len(df['diff']))
# TODO threadsafe
grid = sns.relplot(
kind='scatter',
data=df,
x='date',
y='diff',
hue='true',
palette=palette,
s=10,
)
grid.map(_add_colored_linesegments,
'date', 'diff', 'true', 'group',)
y_var = 'pred phases ahead of true phase'
elif how == 'strip':
# get tp idxs before reshaping
tps = df.groupby('group').apply(_tp_idxs)
dfm = df.melt(id_vars=['date', 'group'], value_name='phase').dropna()
dfm['phase'] = dfm['phase'].astype(phases_type)
dfm['yval'], ylabels = pd.factorize(dfm['group'])
tps.index = tps.index.map(dict(zip(ylabels, range(len(ylabels)))))
with pd.option_context('mode.chained_assignment', None):
dfm['yval'].loc[dfm['variable'] == 'pred'] -= 0.2
grid = sns.relplot(
kind='scatter',
data=dfm,
x='date',
y='yval',
hue='phase',
palette=palette,
s=10,
facet_kws=dict(legend_out=False),
)
grid.map(_add_colored_linesegments,
'date', 'yval', 'phase', 'yval', palette=palette)
grid.map(_highlight_tp, 'yval', tps=tps)
if len(ylabels) <= 20: # draw site names if they'll be readable
grid.set(yticks=range(len(ylabels)))
grid.set_yticklabels(ylabels, size=4)
# and write them as an index
yregion, ytrue, ypred = np.array(ylabels.str.split('\n').to_list()).T
pd.DataFrame(dict(
index=range(len(ylabels)),
region=yregion,
true=ytrue,
pred=ypred,
)).to_csv(save_fpath.with_suffix('.index.csv'))
y_var = '[region, true, pred]'
# df['group_phase'] = df[['group', 'true']].agg('_'.join, axis=1)
sns.move_legend(grid, 'upper right')
grid.set_axis_labels(x_var=x_var, y_var=y_var)
plt.title(plot_title)
grid.savefig(save_fpath)
plt.close()
def _align_start(grp, phase='Site Preparation', before=True):
grp = grp.sort_values(by='date')
# assert grp['true'].iloc[0] != phase, grp['true']
grp['date'] -= grp.iloc[grp['true'].searchsorted(phase) - int(before)]['date']
return grp
def _tp_idxs(grp):
grp = grp.sort_values(by='date')
grp['pred'] = (grp['pred']
.fillna(method='ffill')
.fillna('No Activity'))
grp = grp[~grp['true'].isna()]
match = grp['pred'] == grp['true']
ixs = grp[match.shift() != match].index.values
x = grp['date'].map(date2num)
blocks = [x.loc[[start, end]] for start, end, matches in zip(ixs, ixs[1:], match[ixs]) if matches]
return blocks
def _highlight_tp(y, **kwargs):
tps = kwargs.get('tps')
sites = y.round().abs().astype(int).unique()
for site in sites:
boxes = kwimage.Boxes([[*xs, site - 0.5, site + 0.5] for xs in tps.loc[site]], format='xxyy')
kwplot.draw_boxes(boxes, color='green', fill=True, lw=0, alpha=0.3)
def _add_colored_linesegments(x, y, phase, units, **kwargs):
"""
Ignore:
x = dfm['date']
y = dfm['yval']
phase = dfm['phase']
units = dfm['yval']
"""
# need args instead of kwargs because of grid.map() weirdness
plt = kwplot.autoplt()
palette = kwargs.get('palette')
if isinstance(x.iloc[0], datetime_mod.date):
x = date2num(x)
_df = pd.DataFrame(dict(
xy=zip(x, y),
hue=pd.Series(phase).map(palette).astype('string').map(to_rgba),
phase=phase, # need to keep this due to float comparisons in searchsorted
units=units,
))
lines = []
colors = []
for unit, grp in _df.groupby('units'):
# drop consecutive dups
ph = grp['phase']
ixs = ph.loc[ph.shift() != ph].index.values
for start, end, hue in zip(
ixs,
ixs[1:].tolist() + [None],
grp['hue'].loc[ixs]
):
line = grp['xy'].loc[start:end].values.tolist()
if len(line) > 0:
lines.append(line)
colors.append(hue)
lc = LineCollection(lines, alpha=0.5, colors=colors)
ax = plt.gca()
ax.add_collection(lc)
# ax.autoscale()