#!/usr/bin/env python3
r"""
Parses an existing tensorboard event file and draws the plots as pngs on disk
in the monitor/tensorboard directory.

Derived from netharn/mixins.py for dumping tensorboard plots to disk.

CommandLine:
    # cd into training directory
    GEOWATCH_PREIMPORT=0 python -m geowatch.utils.lightning_ext.callbacks.tensorboard_plotter .

    python -m geowatch.utils.lightning_ext.callbacks.tensorboard_plotter \
        /data/joncrall/dvc-repos/smart_expt_dvc/training/toothbrush/joncrall/Drop6/runs/Drop6_BAS_scratch_landcover_10GSD_split2_V4/lightning_logs/version_4/
"""
import scriptconfig as scfg
import os
import ubelt as ub
from pytorch_lightning.callbacks import Callback
__all__ = ['TensorboardPlotter']
# TODO: The callback could be moved to its own file so that the CLI variant
# and the core plotting logic live separately. That would give faster response
# times when using the CLI (i.e. it avoids the lightning import overhead).
class TensorboardPlotter(Callback):
    """
    Lightning callback that dumps PNG plots of tensorboard scalars to disk.

    Whenever a train, validation, or test epoch ends, the rank 0 process
    redraws the plots. By default the drawing runs in a separate daemon
    process, and a new one is only started once the previous one has finished.

    CommandLine:
        xdoctest -m geowatch.utils.lightning_ext.callbacks.tensorboard_plotter TensorboardPlotter

    Example:
        >>> # xdoctest: +REQUIRES(module:tensorboard)
        >>> from geowatch.utils.lightning_ext import demo
        >>> from geowatch.monkey import monkey_lightning
        >>> import pytorch_lightning as pl
        >>> import pandas as pd
        >>> monkey_lightning.disable_lightning_hardware_warnings()
        >>> self = demo.LightningToyNet2d(num_train=55)
        >>> default_root_dir = ub.Path.appdir('lightning_ext/tests/TensorboardPlotter').ensuredir()
        >>> #
        >>> trainer = pl.Trainer(callbacks=[TensorboardPlotter()],
        >>>                      default_root_dir=default_root_dir,
        >>>                      max_epochs=3, accelerator='cpu', devices=1)
        >>> trainer.fit(self)
        >>> train_dpath = trainer.logger.log_dir
        >>> print('trainer.logger.log_dir = {!r}'.format(train_dpath))
        >>> data = read_tensorboard_scalars(train_dpath)
        >>> for key in data.keys():
        >>>     d = data[key]
        >>>     df = pd.DataFrame({key: d['ydata'], 'step': d['xdata'], 'wall': d['wall']})
        >>>     print(df)
    """

    def _on_epoch_end(self, trainer, logs=None, serial=False):
        """
        Redraw the tensorboard plots for the trainer's log directory.

        Args:
            trainer: the lightning trainer that invoked the callback.
            logs: accepted for hook compatibility; not used here.
            serial (bool): if True, draw in the current process instead of
                handing the work to a background process.
        """
        from kwutil import util_environ
        if util_environ.envflag('DISABLE_TENSORBOARD_PLOTTER'):
            return

        # Only the rank 0 process draws.
        if trainer.global_rank != 0:
            return

        log_dpath = trainer.log_dir
        if log_dpath is None:
            import warnings
            warnings.warn('The trainer logdir is not set. Cannot dump a batch plot')
            return

        # Build a short description of the model to use as the plot title.
        # TODO: get step number
        network = trainer.model
        if hasattr(network, 'get_cfgstr'):
            plot_title = network.get_cfgstr()
        else:
            from geowatch.utils.lightning_ext import util_model
            from kwutil.slugify_ext import smart_truncate
            hparams_text = ub.urepr(
                util_model.model_hparams(network), compact=1, nl=0)
            summary = {
                'type': str(network.__class__),
                'hp': smart_truncate(hparams_text, max_length=8),
            }
            plot_title = smart_truncate(
                ub.urepr(summary, compact=1, nl=0), max_length=64)

        if serial:
            _dump_measures(log_dpath, plot_title)
            return

        # Drawing can take a non-trivial amount of time, so it is attempted
        # in a separate process.
        # NOTE: this causes thread-unsafe warning messages in the inner loop,
        # likely because we are forking while a thread is alive.
        if not hasattr(trainer, '_internal_procs'):
            trainer._internal_procs = ub.ddict(dict)
        pool = trainer._internal_procs['dump_tensorboard']

        # Forget about drawing processes that have already finished.
        finished_pids = [
            pid for pid, worker in list(pool.items()) if not worker.is_alive()
        ]
        for pid in finished_pids:
            pool.pop(pid)

        # At most one drawing process runs at a time; if one is still busy
        # this request is skipped.
        if not pool:
            import multiprocessing
            worker = multiprocessing.Process(
                target=_dump_measures, args=(log_dpath, plot_title))
            worker.daemon = True
            worker.start()
            pool[worker.pid] = worker

    def on_train_epoch_end(self, trainer, logs=None):
        """Redraw the plots when a training epoch ends."""
        return self._on_epoch_end(trainer, logs=logs)

    def on_validation_epoch_end(self, trainer, logs=None):
        """Redraw the plots when a validation epoch ends."""
        return self._on_epoch_end(trainer, logs=logs)

    def on_test_epoch_end(self, trainer, logs=None):
        """Redraw the plots when a test epoch ends."""
        return self._on_epoch_end(trainer, logs=logs)
def read_tensorboard_scalars(train_dpath, verbose=1, cache=1):
    """
    Load every tensorboard scalar series found in a directory.

    The result can be cached on disk (in ``<train_dpath>/_cache``) because
    extracting the events of interest from protobuf can be slow. The cache is
    keyed on a hash of the event files, so it is recomputed when they change.

    Args:
        train_dpath (str | PathLike): directory containing
            ``events.out.tfevents*`` files.
        verbose (int): if truthy, print progress information.
        cache (int | bool): if truthy, use the on-disk cache.

    Returns:
        dict: maps each scalar key to a dictionary with the lists ``xdata``
        (steps), ``ydata`` (values), and ``wall`` (wall times), each ordered
        by wall time.

    Raises:
        ImportError: if tensorboard cannot be imported.

    Ignore:
        train_dpath = '/home/joncrall/.cache/lightning_ext/tests/TensorboardPlotter/lightning_logs/version_2'
        tb_data = read_tensorboard_scalars(train_dpath)
    """
    try:
        from tensorboard.backend.event_processing import event_accumulator
    except ImportError:
        raise ImportError('tensorboard/tensorflow is not installed')

    train_dpath = ub.Path(train_dpath)
    event_paths = sorted(train_dpath.glob('events.out.tfevents*'))

    # Hash the event files so stale caches are not reused.
    if cache:
        depends = ub.hash_data(list(map(ub.hash_file, event_paths)))
    else:
        depends = ''
    cacher = ub.Cacher('tb_scalars', depends=depends, enabled=cache,
                       dpath=train_dpath / '_cache')

    datas = cacher.tryload()
    if datas is not None:
        return datas

    datas = {}
    prog = ub.ProgIter(event_paths[::-1], desc='read tensorboard',
                       enabled=verbose, verbose=verbose * 3)
    for event_path in prog:
        event_path = os.fspath(event_path)
        if verbose:
            print('reading tensorboard scalars')
        accum = event_accumulator.EventAccumulator(event_path)
        if verbose:
            print('loading tensorboard scalars')
        accum.Reload()
        if verbose:
            print('iterate over scalars')
        for tag in accum.scalars.Keys():
            series = datas.setdefault(
                tag, {'xdata': [], 'ydata': [], 'wall': []})
            for event in accum.scalars.Items(tag):
                series['xdata'].append(int(event.step))
                series['ydata'].append(float(event.value))
                series['wall'].append(float(event.wall_time))

    # Reorder every column of every series by wall time.
    for series in datas.values():
        order = ub.argsort(series['wall'])
        for column in list(series):
            series[column] = list(ub.take(series[column], order))

    cacher.save(datas)
    return datas
def _write_helper_scripts(out_dpath, train_dpath):
    """
    Writes scripts to let the user refresh data on the fly.

    Two bash scripts are written into ``out_dpath`` and marked executable for
    user and group (a failure to change permissions is only reported):

    * ``stack.sh`` - stacks the individual plots into a single image.
    * ``redraw.sh`` - reruns this module's CLI on the training directory.

    Args:
        out_dpath (str | PathLike): directory the scripts are written to.
        train_dpath (str | PathLike): training directory the scripts refer to.
    """
    out_dpath = ub.Path(out_dpath)
    # The path is embedded inside double quotes in the scripts. Bash does not
    # perform tilde expansion inside quotes, so abbreviate the home directory
    # as $HOME (which does expand there) instead of "~".
    train_dpath_ = ub.Path(train_dpath).resolve().shrinkuser(home='$HOME')

    # TODO: make this a nicer python script that arranges figures nicely.
    stack_text = ub.codeblock(
        fr'''
        #!/usr/bin/env bash
        kwimage stack_images --out "{train_dpath_}/monitor/tensorboard-stack.png" -- "{train_dpath_}/monitor/tensorboard/"*.png
        ''')

    refresh_text = ub.codeblock(
        fr'''
        #!/usr/bin/env bash
        GEOWATCH_PREIMPORT=0 python -m geowatch.utils.lightning_ext.callbacks.tensorboard_plotter \
            "{train_dpath_}"
        ''')

    scripts = [
        (out_dpath / 'stack.sh', stack_text),
        (out_dpath / 'redraw.sh', refresh_text),
    ]
    for script_fpath, script_text in scripts:
        script_fpath.write_text(script_text)
        try:
            script_fpath.chmod('ug+x')
        except PermissionError as ex:
            # Best effort: the script is still usable via ``bash <script>``.
            print(f'Unable to change permissions on {script_fpath}: {ex}')
def _coerce_smoothing_values(smoothing):
    """Normalize the ``smoothing`` argument into a list of smoothing factors."""
    if smoothing is None:
        return []
    if isinstance(smoothing, str) and smoothing == 'auto':
        return [0.6, 0.95]
    if isinstance(smoothing, (list, tuple)):
        return list(smoothing)
    return [smoothing]


def _build_simple_key_table(tb_keys, smoothing_values, ignore_outliers):
    """Build per-key plot preferences from hard-coded name heuristics."""
    y01_measures = [
        '_acc', '_ap', '_mAP', '_auc', '_mcc', '_brier', '_mauc',
        '_f1', '_iou',
    ]
    y0_measures = ['error', 'loss']
    no_smooth_keys = {'lr', 'momentum', 'epoch'}

    key_table = []
    for plot_key in tb_keys:
        # no idea what hp metric is, but it doesn't seem important
        if plot_key == 'hp_metric' or '/' in plot_key:
            continue
        key_lower = plot_key.lower()
        row = {'key': plot_key}
        # The measure names are fragments of the full key (e.g. "_acc" in
        # "val_acc"), so they are matched as case-insensitive substrings.
        if any(m.lower() in key_lower for m in y01_measures):
            row['ymax'] = 1
            row['ymin'] = 0
        if any(m.lower() in key_lower for m in y0_measures):
            if ignore_outliers:
                row['ymax'] = 'ignore_outliers'
            row['ymin'] = 0
        if plot_key in no_smooth_keys:
            row['smoothing'] = None
        else:
            row['smoothing'] = smoothing_values
        key_table.append(row)
    return key_table


def _dump_measures(train_dpath, title='?name?', smoothing='auto', ignore_outliers=True, verbose=0):
    """
    Draw every tensorboard scalar in a training directory as a PNG on disk.

    The plots are written to ``<train_dpath>/monitor/tensorboard/<key>.png``
    together with helper scripts, and afterwards the plots are stacked into a
    single overview image.

    This is its own function in case we need to modify formatting.

    Args:
        train_dpath (str | PathLike): directory containing the tensorboard
            event files.
        title (str): text placed at the top of the title of each plot.
        smoothing (str | float | List[float] | None): smoothing factors to
            draw in addition to the raw curve. ``'auto'`` uses
            ``[0.6, 0.95]``. A list is used as given and a single number is
            treated as a one-item list. ``None`` disables smoothing.
        ignore_outliers (bool): if True, the upper y-limit of error / loss
            plots is chosen from the inlier range of the data.
        verbose (int): verbosity level.
    """
    import math
    import kwplot
    import kwutil
    from kwplot.auto_backends import BackendContext
    import pandas as pd

    train_dpath = ub.Path(train_dpath).resolve()
    if not train_dpath.name.startswith('version_'):
        # hack: use knowledge of common directory structures to find
        # the root directory of training output for a specific training run
        if not (train_dpath / 'monitor').exists():
            if (train_dpath / '../monitor').exists():
                train_dpath = (train_dpath / '..')
            elif (train_dpath / '../../monitor').exists():
                train_dpath = (train_dpath / '../..')

    tb_data = read_tensorboard_scalars(train_dpath, cache=0, verbose=verbose)

    out_dpath = ub.Path(train_dpath, 'monitor', 'tensorboard').ensuredir()
    _write_helper_scripts(out_dpath, train_dpath)

    smoothing_values = _coerce_smoothing_values(smoothing)

    keys = [k for k in tb_data.keys() if '/' not in k]
    if len(keys) == 0:
        print('warning: no known keys to plot')
        print(f'available keys: {list(tb_data.keys())}')

    USE_NEW_PLOT_PREF = 0
    if USE_NEW_PLOT_PREF:
        # TODO: finish this
        default_plot_preferences = kwutil.Yaml.loads(ub.codeblock(
            '''
            attributes:
                - pattern: [
                        '*_acc*', '*_ap*', '*_mAP*', '*_auc*', '*_mcc*', '*_brier*', '*_mauc*',
                        '*_f1*', '*_iou*',
                    ]
                  ymax: 1
                  ymin: 0
                - pattern: ['*error*', '*loss*']
                  ymin: 0
                - pattern: ['*lr*', '*momentum*', '*epoch*']
                  smoothing: null
                - pattern: ['hp_metric']
                  ignore: true
            '''))
        plot_preferences_fpath = train_dpath / 'plot_preferences.yaml'
        if plot_preferences_fpath.exists():
            user_plot_preferences = kwutil.Yaml.coerce(plot_preferences_fpath)
            plot_preferences = default_plot_preferences.copy()
            plot_preferences.update(user_plot_preferences)
        else:
            plot_preferences = default_plot_preferences
        print(f'plot_preferences = {ub.urepr(plot_preferences, nl=3)}')

        for item in plot_preferences['attributes']:
            item['pattern_'] = kwutil.util_pattern.MultiPattern.coerce(item['pattern'])

        key_table = []
        for plot_key in keys:
            row = {'key': plot_key}
            row['smoothing'] = smoothing_values
            for item in plot_preferences['attributes']:
                if item['pattern_'].match(plot_key.lower()):
                    row.update(item)
            row.pop('pattern', None)
            row.pop('pattern_', None)
            key_table.append(row)
    else:
        key_table = _build_simple_key_table(
            tb_data.keys(), smoothing_values, ignore_outliers)

    key_table = [r for r in key_table if not r.get('ignore', False)]

    with BackendContext('agg'):
        import seaborn as sns
        sns.set()

        nice = title
        fig = kwplot.figure(fnum=1)
        fig.clf()
        ax = fig.gca()

        key_iter = ub.ProgIter(key_table, desc='dump plots', verbose=verbose * 3)
        for key_row in key_iter:
            key = key_row['key']
            key_iter.set_extra(key)
            snskw = {
                'y': key,
                'x': 'step',
            }

            d = tb_data[key]
            df_orig = pd.DataFrame({key: d['ydata'], 'step': d['xdata']})
            is_nan = df_orig[key].isnull()
            num_nan = is_nan.sum()
            num_non_nan = (~is_nan).sum()
            df_orig['smoothing'] = 0.0
            variants = [df_orig]

            # TODO: can we get a heuristic for how much smoothing we might
            # want? (e.g. look at the entropy of the derivative curve)
            row_smoothing = key_row.get('smoothing', None) or []
            for smoothing_value in row_smoothing:
                if smoothing_value is not None and smoothing_value > 0:
                    df_smooth = df_orig.copy()
                    df_smooth[key] = smooth_curve(df_orig[key], smoothing_value)
                    df_smooth['smoothing'] = smoothing_value
                    variants.append(df_smooth)

            if len(variants) == 1:
                df = variants[0]
            else:
                if verbose:
                    print('Combine smoothed variants')
                df = pd.concat(variants).reset_index()
                snskw['hue'] = 'smoothing'

            kw = {}
            ymin = key_row.get('ymin', None)
            ymax = key_row.get('ymax', None)
            if ymin is not None:
                kw['ymin'] = float(ymin)
            if ymax is not None:
                if ymax == 'ignore_outliers':
                    if num_non_nan > 3:
                        if verbose:
                            print('Finding outliers')
                        # Always use the raw (unsmoothed) values of this key,
                        # without NaNs, which would make the quantiles NaN.
                        raw_values = df_orig[key].dropna()
                        _, inlier_high = tensorboard_inlier_ylim(raw_values)
                        inlier_high = float(inlier_high)
                        lower_bound = kw.get('ymin', -math.inf)
                        # Skip limits that are unusable for an axis.
                        if math.isfinite(inlier_high) and inlier_high > lower_bound:
                            kw['ymax'] = inlier_high
                else:
                    kw['ymax'] = float(ymax)

            if verbose:
                print('Begin plot')
            # NOTE: this is actually pretty slow
            # TODO: port title builder to kwplot and use it
            ax.cla()
            try:
                if num_non_nan <= 1:
                    sns.scatterplot(data=df, ax=ax, **snskw)
                else:
                    # todo: we have an alternative in kwplot that can
                    # handle nans, use that instead.
                    sns.lineplot(data=df, ax=ax, **snskw)
            except Exception as ex:
                # Keep going, but surface the problem in the plot title.
                plot_title = nice + '\n' + key + str(ex)
            else:
                plot_title = nice + '\n' + key

            # Fall back to the automatically chosen limits where no
            # preference was given.
            initial_ylim = ax.get_ylim()
            if kw.get('ymax', None) is None:
                kw['ymax'] = initial_ylim[1]
            if kw.get('ymin', None) is None:
                kw['ymin'] = initial_ylim[0]

            try:
                ax.set_ylim(kw['ymin'], kw['ymax'])
            except Exception as ex:
                # Best effort: a bad limit should not prevent the plot from
                # being written.
                if verbose:
                    print(f'Unable to set ylim for {key}: {ex}')

            if num_nan > 0:
                plot_title += '(num_nan={})'.format(num_nan)

            ax.set_title(plot_title)

            # Make room for the multi-line title before the figure is saved
            # so every plot (including the first) uses the same layout.
            ax.figure.subplots_adjust(top=0.8)

            # png is smaller than jpg for this kind of plot
            fpath = out_dpath / (key + '.png')
            if verbose:
                print('Save plot: ' + str(fpath))
            ax.figure.savefig(fpath)

    do_tensorboard_stack(train_dpath)
def do_tensorboard_stack(train_dpath):
    """
    Stack the individual tensorboard plots into one overview image.

    Reads every ``*.png`` in ``<train_dpath>/monitor/tensorboard`` (in sorted
    order) and writes the stacked grid to
    ``<train_dpath>/monitor/tensorboard-stack.png``. Nothing is written when
    there are no plots to stack.

    Args:
        train_dpath (str | PathLike): root directory of the training output.
    """
    from pathlib import Path
    train_dpath = Path(train_dpath)
    monitor_dpath = train_dpath / 'monitor'
    tensorboard_dpath = monitor_dpath / 'tensorboard'

    image_paths = sorted(tensorboard_dpath.glob('*.png'))
    if not image_paths:
        # Nothing was plotted (e.g. no scalar keys were found), so there is
        # nothing to stack.
        return None

    # Imported late so the empty case does not pay for the import.
    import kwimage
    images = [kwimage.imread(fpath) for fpath in image_paths]
    canvas = kwimage.stack_images_grid(images)
    stack_fpath = monitor_dpath / 'tensorboard-stack.png'
    kwimage.imwrite(stack_fpath, canvas)
def smooth_curve(ydata, beta):
    """
    Smooth a curve with an exponentially weighted moving average.

    Intended to mimic the curve smoothing used by tensorboard.

    Args:
        ydata (ArrayLike): the values to smooth.
        beta (float): smoothing factor. ``1 - beta`` is given to pandas as
            the ``alpha`` of the exponentially weighted window.

    Returns:
        The smoothed values as an array, or ``ydata`` itself (unchanged) when
        ``1 - beta`` is not positive.
    """
    import pandas as pd
    alpha = 1.0 - beta
    if alpha <= 0:
        # No usable weight is left for new observations; give the data back.
        return ydata
    series = pd.Series(ydata)
    smoothed = series.ewm(alpha=alpha).mean()
    return smoothed.values
# def inlier_ylim(ydata):
# """
# outlier removal used by tensorboard
# """
# import kwarray
# normalizer = kwarray.find_robust_normalizers(ydata, {
# 'low': 0.05,
# 'high': 0.95,
# })
# low = normalizer['min_val']
# high = normalizer['max_val']
# return (low, high)
def tensorboard_inlier_ylim(ydata, q1=0.05, q2=0.95):
    """
    Estimate y-axis limits that ignore outliers.

    Intended to mimic the outlier removal used by tensorboard. The span
    between the ``q1`` and ``q2`` quantiles of the data is treated as the
    inlier range, and it is extrapolated linearly to estimate the extent the
    full ``[0, 1]`` quantile range would cover if the tails behaved like the
    inliers.

    Args:
        ydata (ArrayLike): the values to find limits for. Non-finite entries
            (NaN, inf) are ignored.
        q1 (float): lower quantile of the inlier range.
        q2 (float): upper quantile of the inlier range.

    Returns:
        Tuple[float, float]: the ``(low, high)`` limits. Both are NaN if
        there are no finite values.

    Raises:
        ValueError: if not ``0 <= q1 < q2 <= 1``.
    """
    import numpy as np
    if not (0 <= q1 < q2 <= 1):
        raise ValueError(
            f'quantiles must satisfy 0 <= q1 < q2 <= 1, got q1={q1}, q2={q2}')

    values = np.asarray(ydata, dtype=float).ravel()
    # A single NaN would otherwise turn both quantiles into NaN.
    values = values[np.isfinite(values)]
    if values.size == 0:
        return (float('nan'), float('nan'))

    low_, high_ = np.quantile(values, [q1, q2])

    # Extrapolate how big the entire span should be based on the inliers:
    # the inlier extent covers a (q2 - q1) fraction of it, and the remainder
    # is split between the two sides proportionally to the quantile mass
    # missing on each side.
    inner_q = q2 - q1
    inner_extent = high_ - low_
    pad1 = inner_extent * q1 / inner_q
    pad2 = inner_extent * (1 - q2) / inner_q

    low = low_ - pad1
    high = high_ + pad2
    return (low, high)
def redraw_cli(train_dpath):
    """
    Create png plots for the tensorboard data in a training directory.

    The plot title is assembled from the name of the grandparent directory of
    ``train_dpath`` (used as the experiment name), the ``hparams.yaml`` file
    if it exists, and a summary of selected settings from ``config.yaml`` if
    it exists.

    Args:
        train_dpath (str | PathLike): the training directory to redraw.
    """
    from kwutil.util_yaml import Yaml
    # Resolve before inspecting parents: for a relative path such as "."
    # (the CLI default) the parents have empty names.
    train_dpath = ub.Path(train_dpath).resolve()

    expt_name = train_dpath.parent.parent.name

    hparams_fpath = train_dpath / 'hparams.yaml'
    if hparams_fpath.exists():
        print('Found hparams')
        # An empty YAML file loads as None
        hparams = Yaml.load(hparams_fpath) or {}
        if 'name' in hparams:
            title = hparams['name']
        else:
            from kwutil.slugify_ext import smart_truncate
            model_config = {
                # 'type': str(model.__class__),
                'hp': smart_truncate(ub.urepr(hparams, compact=1, nl=0), max_length=8),
            }
            model_cfgstr = smart_truncate(ub.urepr(
                model_config, compact=1, nl=0), max_length=64)
            title = model_cfgstr
        title = f'{expt_name}\n{title}'
    else:
        print('Did not find hparams')
        title = expt_name

    # Add in other relevant data
    config_fpath = train_dpath / 'config.yaml'
    if config_fpath.exists():
        # Sections may be missing or explicitly null; treat both as empty.
        config = Yaml.load(config_fpath) or {}
        trainer_config = config.get('trainer') or {}
        optimizer_config = config.get('optimizer') or {}
        data_config = config.get('data') or {}
        optimizer_args = optimizer_config.get('init_args') or {}

        devices = trainer_config.get('devices', None)
        batch_size = data_config.get('batch_size', None)
        accum_batches = trainer_config.get('accumulate_grad_batches', None)
        optim_lr = optimizer_args.get('lr', None)
        decay = optimizer_args.get('weight_decay', None)
        learn_dynamics_str = (
            f'BS=({batch_size} x {accum_batches}), LR={optim_lr}, '
            f'decay={decay}, devs={devices}'
        )
        title = title + '\n' + learn_dynamics_str

    print(f'train_dpath={train_dpath}')
    print(f'title={title}')
    _dump_measures(train_dpath, title, verbose=1)

    # Print a clickable link to the directory with the plots.
    import rich
    tensorboard_dpath = train_dpath / 'monitor/tensorboard'
    rich.print(f'[link={tensorboard_dpath}]{tensorboard_dpath}[/link]')
class TensorboardPlotterCLI(scfg.DataConfig):
    """
    Helper CLI executable to redraw on demand.
    """
    # Training directory to redraw; may be given as the first positional
    # argument and defaults to the current directory.
    train_dpath = scfg.Value('.', help='train_dpath', position=1)

    @classmethod
    def main(cls, cmdline=1, **kwargs):
        """
        Parse the configuration, print it, and redraw the plots.

        Args:
            cmdline (int | bool): forwarded to ``cls.cli``.
            **kwargs: forwarded to ``cls.cli`` as ``data``.
        """
        from rich import print as rich_print
        parsed = cls.cli(cmdline=cmdline, data=kwargs, strict=True)
        rich_print(f'config = {ub.urepr(parsed, nl=1)}')
        redraw_cli(parsed.train_dpath)
if __name__ == '__main__':
    # Script entry point: redraw the plots for the requested directory.
    #
    # CommandLine:
    #     GEOWATCH_PREIMPORT=0 python -X importtime -m geowatch.utils.lightning_ext.callbacks.tensorboard_plotter .
    TensorboardPlotterCLI.main()