import kwimage
import kwcoco
import numpy as np
import pandas as pd
import itertools
from geowatch.heuristics import CNAMES_DCT
[docs]
def visualize_videos(pred_dset,
                     out_dir,
                     true_dset=None,
                     bas_mode=True):
    '''
    Write animated visualizations of the predicted score channels.

    Calls ``geowatch.cli.coco_visualize_videos`` on ``pred_dset`` in video
    space, stacking the selected channels and animating each video, without
    drawing annotations.

    Args:
        pred_dset: dataset with predicted score channels (presumably a
            ``kwcoco.CocoDataset``; it is passed through as ``src``).
        out_dir: output directory, passed as ``viz_dpath``.
        true_dset: optional ground-truth dataset. When given, the channels
            returned by the nested ``add_panoptic_img`` stub are appended;
            that stub currently returns none.
        bas_mode (bool): True selects the BAS channels
            (``keys_to_score_bas``), False the SC channels
            (``keys_to_score_sc``). Defaults to True. This is the branch the
            former hard-coded ``bas_mode = NotImplemented`` always took:
            ``NotImplemented`` is truthy, but using it as a bool is
            deprecated since Python 3.9 and raises TypeError in 3.14.
    '''
    if bas_mode:
        keys = keys_to_score_bas
        # draw videos (regions) separately
    else:
        keys = keys_to_score_sc
        # draw videos (sites) together within a region

    from geowatch.cli import coco_visualize_videos

    def add_panoptic_img(pred_dset, true_dset):
        '''
        Stub: intended to build extra "panoptic" channel keys comparing
        predictions against truth. Currently returns an empty list.

        It should handle the cross product of these keys (with counts):
            Active Construction 4500
            No Activity 713
            Post Construction 13179
            Site Preparation 1042
            Unknown 1499
            background 0
            ignore 20755
            negative 7455
            positive 1287
        and decide when to merge sites into regions.
        '''
        # TODO
        return []

    if true_dset is not None:
        pan_key = add_panoptic_img(pred_dset, true_dset)
        # Skip the concatenation when the stub yields nothing. Rebind rather
        # than use ``+=`` so the shared module-level spec is never modified
        # in place, even if it supports in-place addition.
        # NOTE(review): confirm FusedChannelSpec supports ``+`` with a list.
        if pan_key:
            keys = keys + pan_key

    # NOTE(review): if ``geowatch.cli.coco_visualize_videos`` is a module
    # rather than a callable, this should go through its
    # ``main(cmdline=False, ...)`` entry point instead -- confirm.
    coco_visualize_videos(
        src=pred_dset,
        space='video',
        viz_dpath=out_dir,
        channels=keys,
        any3=True,
        draw_anns=False,
        animate=True,
        zoom_to_tracks=False,
        stack=True,
    )
# Channel specs scored by this module: the single 'salient' channel for BAS,
# and one '|'-fused channel per scored positive category name for SC.
keys_to_score_bas = kwcoco.FusedChannelSpec.coerce('salient')
keys_to_score_sc = kwcoco.FusedChannelSpec.coerce('|'.join(
    cname for cname in CNAMES_DCT['positive']['scored']
))
# def chans_intersect(c1: kwcoco.ChannelSpec, c2: kwcoco.ChannelSpec) -> kwcoco.FusedChannelSpec:
[docs]
def are_bas_dct(dset):
    '''
    Classify each video as holding BAS tracks or SC tracks.

    This isn't needed because BAS annots will get normalized to SC anyway.

    Assumes:
        - every image is in a video
        - every video has either only BAS tracks or only SC tracks

    Args:
        dset: dataset exposing ``videos()`` (presumably a
            ``kwcoco.CocoDataset``).

    Returns:
        Dict[video_id, bool]: True if every annotation category name in the
        video is a BAS channel name (``keys_to_score_bas``), else False
        (treated as SC).

    Note:
        A video with no annotations has an empty name set. That set is a
        subset of both groups, so the video is reported as BAS and the
        warning is also printed.
    '''
    # ``code_list()`` returns a plain list, which has no ``to_set`` method;
    # build the sets explicitly.
    bas_cnames = set(keys_to_score_bas.code_list())
    sc_cnames = set(keys_to_score_sc.code_list())
    vids = dset.videos()
    are_bas = []
    for images in vids.images:
        # Union of category names over every annotation in this video.
        cnames = set(itertools.chain.from_iterable(a.cnames for a in images.annots))
        is_bas = cnames.issubset(bas_cnames)
        is_sc = cnames.issubset(sc_cnames)
        if is_bas == is_sc:
            # Both True: no annotations. Both False: mixed or unrecognized names.
            print('WARNING: multiple or unknown track types in video!')
        are_bas.append(is_bas)
    return dict(zip(vids.lookup('id'), are_bas))
[docs]
def viz_track_scores(dset, out_fpath, gt_dset=None):
    """
    Save a grid of per-track stacked phase-score plots.

    One facet is drawn per ``track_id`` (4 facets per row). Within a facet
    the x axis is the capture date of each annotation's image. The stacked
    areas are that annotation's phase scores, plus a ``'No Activity'`` band
    equal to ``1 - sum(scores)``.

    Args:
        dset: dataset to plot (presumably a ``kwcoco.CocoDataset``). Its
            annotations must carry ``scores`` and ``track_id``; its images
            must carry ``date_captured`` and ``sensor_coarse``.
        out_fpath: path passed to ``FacetGrid.savefig``.
        gt_dset: ground-truth dataset. Not used yet.

    Raises:
        AssertionError: if no image has every BAS score channel
            (``keys_to_score_bas``) or every SC score channel
            (``keys_to_score_sc``).

    Returns:
        None. If ``dset`` has no annotations, or a required annotation or
        image field is missing, a message is printed and nothing is saved.
    """
    import geowatch
    import kwplot
    from matplotlib.collections import LineCollection
    from matplotlib.colors import to_rgba
    plt = kwplot.autoplt()
    sns = kwplot.autosns()

    # Choose which channels to score: whichever group (BAS or SC) is fully
    # present in more images; ties go to SC.
    are_bas_imgs = []
    are_sc_imgs = []
    for coco_img in dset.images().coco_images:
        fused = coco_img.channels.fuse()
        are_bas_imgs.append(fused.intersection(keys_to_score_bas).numel() == keys_to_score_bas.numel())
        are_sc_imgs.append(fused.intersection(keys_to_score_sc).numel() == keys_to_score_sc.numel())
    n_bas_imgs = sum(are_bas_imgs)
    n_sc_imgs = sum(are_sc_imgs)
    # Explicit raise rather than ``assert`` so the check survives
    # ``python -O``. The exception type is kept for existing callers.
    if n_bas_imgs == 0 and n_sc_imgs == 0:
        raise AssertionError('no valid channels to score!')
    keys = keys_to_score_bas if n_bas_imgs > n_sc_imgs else keys_to_score_sc

    if gt_dset is not None:
        # TODO: add ground-truth phases in parallel with 'orig' for gt and
        # post-viterbi. Sketch of the intended inputs:
        # true_feats = json.load(open(f'gt_site_models/{track_id}.geojson'))['features'][1:]
        # true_labels = [f['properties']['current_phase'] for f in true_feats]
        # true_dates = [f['properties']['observation_date'] for f in true_feats]
        # true_dates = pd.to_datetime(true_dates).date
        pass

    # detailed viz
    annots = dset.annots()
    # Explicit emptiness check rather than ``assert`` so empty datasets are
    # still skipped (not crashed on) under ``python -O``.
    if len(annots) == 0:
        print('cannot viz tracks ', 'no annotations')
        return
    try:
        scores = annots.lookup('scores')
        tid = annots.lookup('track_id')
        dates = pd.to_datetime(annots.images.lookup('date_captured')).date
        sens = annots.images.lookup('sensor_coarse')
    except KeyError as e:
        print('cannot viz tracks ', e)
        return

    # One row per annotation; each score record is expanded into its own
    # columns keyed by phase/channel name.
    df = pd.DataFrame(dict(date=dates, sens=sens, tid=tid)).join(pd.DataFrame.from_records(scores))
    # Residual band so that the stacked bands sum to 1.
    # NOTE(review): if 'No Activity' is itself one of ``keys``, this
    # overwrites that column and duplicates it in ``ordered_phases`` -- confirm.
    df['No Activity'] = 1 - df[keys.as_list()].sum(axis=1)
    ordered_phases = ['No Activity'] + keys.as_list()
    # Highest-scoring phase per annotation, and an evenly spaced y position
    # in [0, 1] for it.
    df['orig'] = df[ordered_phases].idxmax(axis=1)
    df['y'] = df['orig'].map(dict(zip(ordered_phases, np.linspace(0, 1, len(ordered_phases)))))

    palette = {c['name']: c['color'] for c in geowatch.heuristics.CATEGORIES}
    palette['salient'] = geowatch.heuristics.CATEGORIES_DCT['positive']['unscored'][0]['color']

    def add_scores(x, *ordered_phases_cols, **kwargs):
        # FacetGrid.map hands over one column per phase; stack them as rows.
        ax = plt.gca()
        ax.stackplot(x, pd.DataFrame(ordered_phases_cols).values, labels=ordered_phases, colors=[palette[p] for p in ordered_phases])

    def add_colored_linesegments(x, y, phase, **kwargs):
        # Currently unused: its FacetGrid.map call below is commented out.
        _df = pd.DataFrame(dict(
            xy=zip(x, y),
            hue=pd.Series(phase).map(palette).astype('string').map(to_rgba),
            phase=phase,  # need to keep this due to float comparisons in searchsorted
        ))
        lines = []
        colors = []
        # drop consecutive dups
        ph = _df['phase']
        ixs = ph.loc[ph.shift() != ph].index.values
        for start, end, hue in zip(
                ixs,
                ixs[1:].tolist() + [None],
                _df['hue'].loc[ixs]
        ):
            # Label slicing is inclusive of ``end``, so consecutive segments
            # share their boundary point.
            line = _df['xy'].loc[start:end].values.tolist()
            if len(line) > 0:
                lines.append(line)
                colors.append(hue)
        lc = LineCollection(lines, alpha=1.0, colors=colors)
        ax = plt.gca()
        ax.add_collection(lc)

    def add_ticks(xs, sensor_coarses, **kwargs):
        # Currently unused: its FacetGrid.map call below is commented out.
        colors = dict(zip(['Sentinel-2', 'Landsat 8', 'WorldView'], kwimage.Color.distinct(3)))
        colors['S2'] = colors['Sentinel-2']
        colors['L8'] = colors['Landsat 8']
        colors['WW'] = colors['WorldView']
        ax = plt.gca()
        for x, sensor_coarse in zip(xs, sensor_coarses):
            plt.axvline(x,
                        ymin=0.8,
                        color=colors[sensor_coarse],
                        alpha=0.1)
        ax.legend()

    g = sns.FacetGrid(df,
                      col='tid',
                      aspect=3,
                      col_wrap=4,
                      sharex=False)
    g = g.map(add_scores, 'date', *ordered_phases)
    # TODO figure out why these aren't showing up
    # g = g.map(sns.scatterplot, 'date', 'y', 'orig', palette=palette, s=10)
    # g = g.map(add_colored_linesegments, 'date', 'y', 'orig')
    # g = g.map(add_ticks, 'date', 'sens')
    g = g.set(ylim=(0, 1), ylabel='score')
    g = g.add_legend()
    # summary viz from run_metrics_framework
    g.savefig(out_fpath)
    plt.close()