#!/usr/bin/env python3
"""
Script for running the iarpa metrics code with a bit of post-processing.
TODO:
- [ ] Rename to run_polygon_evaluation.py? Or run_iarpa_metrics.py?
Note: this currently depends on :py:mod:`iarpa_smart_metrics`, which can be
obtained from:
https://smartgitlab.com/TE/metrics-and-test-framework
or
https://gitlab.kitware.com/smart/metrics-and-test-framework
Example:
>>> # xdoctest: +REQUIRES(module:iarpa_smart_metrics)
>>> from geowatch.cli.run_metrics_framework import * # NOQA
>>> from geowatch.demo.metrics_demo.generate_demodata import generate_demo_metrics_framework_data
>>> cmdline = 0
>>> base_dpath = ub.Path.appdir('geowatch', 'tests', 'test-iarpa-metrics2')
>>> data_dpath = base_dpath / 'inputs'
>>> dpath = base_dpath / 'outputs'
>>> demo_info1 = generate_demo_metrics_framework_data(
>>> roi='DR_R001',
>>> num_sites=5, num_observations=10, noise=2, p_observe=0.5,
>>> p_transition=0.3, drop_noise=0.5, drop_limit=0.5)
>>> demo_info2 = generate_demo_metrics_framework_data(
>>> roi='DR_R002',
>>> num_sites=7, num_observations=10, noise=1, p_observe=0.5,
>>> p_transition=0.1, drop_noise=0.8, drop_limit=0.5)
>>> demo_info3 = generate_demo_metrics_framework_data(
>>> roi='DR_R003',
>>> num_sites=11, num_observations=10, noise=3, p_observe=0.5,
>>> p_transition=0.2, drop_noise=0.3, drop_limit=0.5)
>>> print('demo_info1 = {}'.format(ub.urepr(demo_info1, nl=1)))
>>> print('demo_info2 = {}'.format(ub.urepr(demo_info2, nl=1)))
>>> print('demo_info3 = {}'.format(ub.urepr(demo_info3, nl=1)))
>>> out_dpath = dpath / 'region_metrics'
>>> merge_fpath = dpath / 'merged.json'
>>> out_dpath.delete()
>>> kwargs = {
>>> 'pred_sites': demo_info1['pred_site_dpath'],
>>> 'true_region_dpath': demo_info1['true_region_dpath'],
>>> 'true_site_dpath': demo_info1['true_site_dpath'],
>>> 'merge': True,
>>> 'merge_fpath': merge_fpath,
>>> 'out_dir': out_dpath,
>>> }
>>> main(cmdline=False, **kwargs)
>>> # TODO: visualize
Example:
>>> # xdoctest: +REQUIRES(module:iarpa_smart_metrics)
>>> from geowatch.cli.run_metrics_framework import * # NOQA
>>> from geowatch.demo.metrics_demo.generate_demodata import generate_demo_metrics_framework_data
>>> # Test single region case
>>> cmdline = 0
>>> base_dpath = ub.Path.appdir('geowatch', 'tests', 'test-iarpa-metrics5')
>>> data_dpath = (base_dpath / 'inputs').ensuredir()
>>> dpath = (base_dpath / 'outputs').ensuredir()
>>> demo_info1 = generate_demo_metrics_framework_data(
>>> roi='DR_R004',
>>> num_sites=5, num_observations=10, noise=2, p_observe=0.5,
>>> p_transition=0.3, drop_noise=0.5, drop_limit=0.5, outdir=data_dpath)
>>> print('demo_info1 = {}'.format(ub.urepr(demo_info1, nl=1)))
>>> out_dpath = dpath / 'poly_eval'
>>> merge_fpath = dpath / 'poly_eval.json'
>>> out_dpath.delete()
>>> kwargs = {
>>> 'pred_sites': demo_info1['pred_site_dpath'],
>>> 'true_region_dpath': demo_info1['true_region_dpath'],
>>> 'true_site_dpath': demo_info1['true_site_dpath'],
>>> 'merge': True,
>>> 'merge_fpath': merge_fpath,
>>> 'out_dir': out_dpath,
>>> }
>>> main(cmdline=False, **kwargs)
>>> # TODO: visualize
"""
import os
import json
import shlex
import ubelt as ub
import scriptconfig as scfg
from packaging import version
import warnings
class MetricsConfig(scfg.DataConfig):
    """
    Score IARPA site model GeoJSON files using IARPA's metrics-and-test-framework
    """
    # --- predictions ---
    pred_sites = scfg.Value(None,
        required=True,
        nargs='*',
        help=ub.paragraph('''
            List of paths to predicted v2 site models. Or a path to a single text
            file containing a list of paths to predicted site models.
            All region_ids from these sites will be scored, and it will be assumed
            that there are no other sites in these regions.
            '''))

    # --- truth ---
    gt_dpath = scfg.Value(None,
        help=ub.paragraph('''
            Path to a local copy of the ground truth annotations,
            https://smartgitlab.com/TE/annotations. If None, use geowatch_dvc to
            find $DVC_DATA_DPATH/annotations.
            '''))
    true_site_dpath = scfg.Value(None,
        help=ub.paragraph('''
            Directory containing true site models. Defaults to
            gt_dpath / site_models
            '''))
    true_region_dpath = scfg.Value(None,
        help=ub.paragraph('''
            Directory containing true region models. Defaults to
            gt_dpath / region_models
            '''))

    # --- outputs ---
    out_dir = scfg.Value(None,
        help=ub.paragraph('''
            Output directory where scores will be written. Each
            region will have its own subdirectory. Defaults to
            ./iarpa-metrics-output/
            '''))
    merge = scfg.Value('overwrite',
        help=ub.paragraph('''
            Merge BAS and SC metrics from all regions and output to
            {out_dir}/merged/.
            'overwrite' = rerun IARPA metrics,
            'read' = assume they exist on disk,
            (TODO 'write' = rerun IARPA metrics if needed.)
            If False, the metrics are run but not merged.
            '''))
    merge_fpath = scfg.Value(None,
        help=ub.paragraph('''
            Forces the merge summary to be written to a specific
            location.
            '''))
    merge_fbetas = scfg.Value([],
        help=ub.paragraph('''
            A list of BAS F-scores to compute besides F1.
            '''))
    tmp_dir = scfg.Value(None,
        help=ub.paragraph('''
            If specified, will write temporary data here instead of
            using a non-persistent directory
            '''))

    # --- visualization ---
    enable_viz = scfg.Value(False,
        isflag=1,
        help=ub.paragraph('''
            If true, enables the iarpa "region" visualization. Can also be
            a comma separated list of iarpa visualizations to enable, e.g.
            "region,slices,detection_table".
            '''))
    name = scfg.Value('unknown',
        help=ub.paragraph('''
            Short name for the algorithm used to generate the model.
            UNUSED. TODO: incorporate
            '''))
    enable_sc_viz = scfg.Value(False,
        isflag=1,
        help=ub.paragraph('''
            If true, enables our SC visualization
            '''))

    # --- execution ---
    load_workers = scfg.Value(0,
        help=ub.paragraph('''
            The number of workers used to load site models.
            '''))
    parallel = scfg.Value(False, isflag=True,
        help=ub.paragraph('''
            Run the IARPA T&E metrics in parallel. Note:
            Only works with IARPA T&E metrics version 1.0.0 or greater.
            '''))

    # --- T&E parameters passed through to iarpa_smart_metrics ---
    performer = scfg.Value('kit', help='the performer id')
    tau = scfg.Value(0.2, help='T&E tau association threshold param')
    rho = scfg.Value(0.5, help='T&E rho site detection threshold (min proportion of matching polygons)')
def _link_dated_images(img_date_dct, dst_dpath, suffix, verbose=0):
    """
    Symlink each image into ``dst_dpath`` under a date-prefixed name.

    Args:
        img_date_dct (Dict[ub.Path, str]): image path -> observation date.
            Dashes are removed from the date to form the name prefix.
        dst_dpath (ub.Path): existing directory to create the links in.
        suffix (str): extension given to the link *name* (e.g. '.jp2'). The
            linked file itself is not converted.
        verbose (int): if truthy, report images that no longer exist.
    """
    for img_path, img_date in img_date_dct.items():
        link_name = '_'.join((img_date.replace('-', ''), img_path.with_suffix(suffix).name))
        if img_path.exists():
            ub.symlink(img_path, dst_dpath / link_name, verbose=0)
        elif verbose:
            print(f'warning: {img_path=} not found')


def ensure_thumbnails(image_root, region_id, sites, verbose=0):
    '''
    Symlink and organize images in the format the metrics framework expects.

    For the region visualizations the framework globs:
        > image_list = glob(f"{self.image_path}/
        >     {self.region_model.id.replace('_', '/')}/images/*/*/*.jp2")
    For the site visualizations it globs:
        > image_list = glob(f"{self.image_path}/
        >     {gt_ann_id.replace('_', '/')}/crops/*.tif")

    Which becomes (ids are split on underscores):
        {country_code}/
            {region_num}/
                images/
                    a/
                        b/
                            *.jp2
                {site_num}/
                    crops/
                        *.tif

    Args:
        image_root (str | PathLike): root directory to save under
        region_id (str): ex. 'KR_R001'
        sites (List[dict]): proposed site models. Observation features are
            expected to hold an image path in the 'source' property
            (TODO change to 'misc_info' field).
        verbose (int): if truthy, print a warning for every observation
            whose image cannot be linked. Defaults to 0 (silent).

    Raises:
        ValueError: if a feature has an unknown 'type', or a site model
            contains no 'site' feature.
    '''
    image_root = ub.Path(image_root)

    # Gather {site_id: {image_path: observation_date}} for existing images.
    site_img_date_dct = {}
    for site in sites:
        site_id = None
        img_date_dct = {}
        for feat in site['features']:
            props = feat['properties']
            feat_type = props['type']
            if feat_type == 'observation':
                # 'source' may be absent or null; such observations have no
                # image to link.
                source = props.get('source', None)
                img_path = None if source is None else ub.Path(source)
                if img_path is not None and img_path.is_file():
                    img_date_dct[img_path] = props['observation_date']
                elif verbose:
                    print(f'warning: image {img_path} is not a valid path')
            elif feat_type == 'site':
                site_id = props['site_id']
            else:
                raise ValueError(feat_type)
        if site_id is None:
            # Without this check a stale site_id from the previous site would
            # be reused and that site's images overwritten.
            raise ValueError(
                f'A site model in region {region_id} has no "site" feature')
        site_img_date_dct[site_id] = img_date_dct

    # Region visualizations glob ``images/*/*/*.jp2``; the two intermediate
    # directory names are arbitrary placeholders.
    region_root = image_root.joinpath(*region_id.split('_'), 'images', 'a', 'b')
    region_root.mkdir(parents=True, exist_ok=True)
    all_img_dates = ub.dict_union(*site_img_date_dct.values())
    _link_dated_images(all_img_dates, region_root, '.jp2', verbose=verbose)

    # Site visualizations glob ``{site path}/crops/*.tif``.
    for site_id, img_date_dct in site_img_date_dct.items():
        site_root = image_root.joinpath(*site_id.split('_'), 'crops')
        site_root.mkdir(parents=True, exist_ok=True)
        # TODO crop
        _link_dated_images(img_date_dct, site_root, '.tif', verbose=verbose)
def _load_pred_sites(pred_sites, load_workers):
    """
    Load the predicted site models and provenance info from any manifest.

    Args:
        pred_sites: the ``pred_sites`` config value, passed to
            ``util_gis.coerce_geojson_paths``.
        load_workers (int): number of workers used to read the geojson files.

    Returns:
        Tuple[List[dict], List]: the loaded site model json data, and the
        ``info`` list of the (at most one) ``tracking_result`` manifest.

    Raises:
        Exception: if more than one manifest, or no predicted sites, are given.
        AssertionError: if the manifest is not a ``tracking_result`` dict.
    """
    from geowatch.utils import util_gis
    pred_site_infos = util_gis.coerce_geojson_paths(pred_sites, return_manifests=True)
    manifest_fpaths = pred_site_infos['manifest_fpaths']
    if len(manifest_fpaths) > 1:
        raise Exception('Only expected at most one manifest')

    parent_info = []
    for manifest_fpath in manifest_fpaths:
        # The manifest contains info about how these predictions were
        # computed. Grab that if possible.
        print('Load parent info from manifest')
        with open(manifest_fpath, 'r') as file:
            manifest = json.load(file)
        # Explicit check (not ``assert``) so it still runs under ``python -O``.
        if not (isinstance(manifest, dict) and manifest.get('type', None) == 'tracking_result'):
            raise AssertionError(f'Expected a tracking_result manifest: {manifest_fpath}')
        parent_info.extend(manifest.get('info', []))

    loaded_sites = [
        info['data'] for info in util_gis.coerce_geojson_datas(
            pred_site_infos['geojson_fpaths'], format='json', workers=load_workers)
    ]
    if not loaded_sites:
        # FIXME: when the tracker produces no results, we fail to score here.
        # Is there a way to produce a valid empty file in the tracker?
        raise Exception('No input predicted sites were given')
    return loaded_sites, parent_info


def _resolve_truth_dpaths(config):
    """
    Resolve the truth region and site model directories.

    Explicitly configured directories are used as-is. Any that are missing
    default to ``region_models`` / ``site_models`` under ``gt_dpath``.
    ``gt_dpath`` itself defaults to ``annotations`` in the ``phase2_data``
    DVC directory.

    Returns:
        Tuple[ub.Path, ub.Path]: (true_region_dpath, true_site_dpath)

    Raises:
        AssertionError: if a default is needed and ``gt_dpath`` is not a
            directory.
    """
    true_site_dpath = config.true_site_dpath
    true_region_dpath = config.true_region_dpath
    if true_region_dpath is None or true_site_dpath is None:
        if config.gt_dpath is not None:
            gt_dpath = ub.Path(config.gt_dpath).absolute()
        else:
            import geowatch
            data_dvc_dpath = geowatch.find_dvc_dpath(tags='phase2_data')
            gt_dpath = data_dvc_dpath / 'annotations'
            print(f'gt_dpath unspecified, defaulting to {gt_dpath=}')
        if not gt_dpath.is_dir():
            raise AssertionError(gt_dpath)
        if true_region_dpath is None:
            true_region_dpath = gt_dpath / 'region_models'
        if true_site_dpath is None:
            true_site_dpath = gt_dpath / 'site_models'
    return ub.Path(true_region_dpath), ub.Path(true_site_dpath)


def _viz_flags(enable_viz):
    """
    Build the iarpa_smart_metrics visualization CLI flags.

    Args:
        enable_viz (bool | str):
            - False disables every visualization.
            - True enables only "region", with a warning.
            - A comma separated string enables the named visualizations.

    Returns:
        List[str]: ``--no-viz-*`` flags for the disabled visualizations,
        then ``--viz-*`` flags for the enabled ones. Each group is sorted
        by key.
    """
    key_to_disable_flag = {
        'region': '--no-viz-region',  # we often want this enabled
        'slices': '--no-viz-slices',
        'detection_table': '--no-viz-detection-table',
        'comparison_table': '--no-viz-comparison-table',
        'associate_metrics': '--no-viz-associate-metrics',
        'activity_metrics': '--no-viz-activity-metrics',
    }
    if isinstance(enable_viz, str):
        # Allow the user to enable specific visualizations (whitespace around
        # commas is tolerated).
        chosen = {key.strip() for key in enable_viz.split(',') if key.strip()}
        unknown = chosen - set(key_to_disable_flag)
        if unknown:
            warnings.warn(f'Ignoring unknown enable_viz keys: {sorted(unknown)}')
    elif enable_viz:
        # Enabling all visualizations is usually a bad idea, so True means
        # "region" only.
        warnings.warn(ub.paragraph(
            '''
            When enable_viz is True, we default to only using "region"
            Can also set things like --enable_viz="region,slices,detection_table"
            etc...
            '''))
        chosen = {'region'}
    else:
        chosen = set()
    ordered = sorted(key_to_disable_flag.items())
    flags = [flag for key, flag in ordered if key not in chosen]
    flags += [flag.replace('--no-', '--') for key, flag in ordered if key in chosen]
    return flags


def _prepare_region_commands(config, grouped_sites, tmp_dpath, main_out_dir,
                             true_site_dpath, true_region_dpath):
    """
    Write per-region inputs and build one metrics command per region.

    For each region that has truth, this function:
        - writes the predicted sites to
          ``{tmp_dpath}/site/{region_id}/latest/{region_id}/``;
        - creates image symlinks under ``{tmp_dpath}/image``;
        - records the command in ``{main_out_dir}/{region_id}/invocation.sh``.

    Returns:
        Tuple[List[str], List[ub.Path]]: the shell commands and the
        per-region output directories, in sorted region order. Regions
        without truth are skipped and appear in neither list.
    """
    import safer
    from iarpa_smart_metrics.commons import as_local_path

    viz_flags = _viz_flags(config.enable_viz)
    commands = []
    out_dirs = []
    for region_id, region_sites in ub.ProgIter(sorted(grouped_sites.items()),
                                               desc='prepare regions for eval'):
        roi = region_id
        # Test to see if GT regions exist as they would be checked for in the
        # iarpa_smart_metrics tool.
        gt_dir = ub.Path(as_local_path(os.fspath(true_site_dpath), "annotations/truth/",
                                       reg_exp=f".*{roi}.*.geojson"))
        if next(gt_dir.glob(f"*{roi}*.geojson"), None) is None:
            warnings.warn(f'No truth for region: {roi}. Skipping')
            continue

        site_dpath = (tmp_dpath / 'site' / region_id).ensuredir()
        image_dpath = (tmp_dpath / 'image').ensuredir()
        out_dir = (main_out_dir / region_id).ensuredir()
        out_dirs.append(out_dir)

        # The predicted sites are written into a ``latest/{region_id}``
        # subdirectory, which is the directory passed as --sm_dir.
        pred_site_sub_dpath = (site_dpath / 'latest' / region_id).ensuredir()
        for site in region_sites:
            site_id = site['features'][0]['properties']['site_id']
            geojson_fpath = pred_site_sub_dpath / (site_id + '.geojson')
            with safer.open(geojson_fpath, 'w', temp_file=not ub.WIN32) as f:
                json.dump(site, f)
        ensure_thumbnails(image_dpath, region_id, region_sites)

        run_eval_command = [
            'python', '-m', 'iarpa_smart_metrics.run_evaluation',
            '--roi', roi,
            '--gt_dir', os.fspath(gt_dir),
            '--rm_dir', os.fspath(true_region_dpath),
            '--sm_dir', os.fspath(pred_site_sub_dpath),
            '--image_dir', os.fspath(image_dpath),
            '--output_dir', os.fspath(out_dir),
            ## Restrict to make this faster
            '--tau', str(config.tau),
            '--rho', str(config.rho),
            '--activity', 'overall',
            '--loglevel', 'ERROR',
            f'--performer={config.performer}',
            '--eval_num=0',
            '--eval_run_num=0',
            '--sequestered_id', '',
        ]
        run_eval_command.append('--parallel' if config.parallel else '--serial')
        run_eval_command += viz_flags

        # shlex.join quotes every argument, so the string is shell-safe.
        cmd = shlex.join(run_eval_command)
        region_invocation_text = ub.codeblock(
            '''
            #!/bin/bash
            __doc__="
            This is an auto-generated file that records the command used to
            generate this evaluation of this particular region.
            "
            ''') + '\n' + cmd + '\n'
        # Dump this command to disk for reference and debugging.
        (out_dir / 'invocation.sh').write_text(region_invocation_text)
        commands.append(cmd)
    return commands, out_dirs


def _write_confusion_script(main_out_dir, true_region_dpath, true_site_dpath, region_id):
    """
    Write an owner-executable ``confusion_analysis.sh`` into ``main_out_dir``.

    The script invokes ``geowatch.mlops.confusor_analysis`` for ``region_id``
    and forwards any extra script arguments.
    """
    import stat
    print('Dump confusion')
    confusion_analysis_text = ub.codeblock(
        rf'''
        #!/bin/bash
        __doc__="
        This is an auto-generated file that records the command used to
        run confusion analysis on this single-region evaluation.
        "
        python -m geowatch.mlops.confusor_analysis \
        --metrics_node_dpath={main_out_dir} \
        --out_dpath={main_out_dir}/confusion_analysis \
        --true_region_dpath={true_region_dpath} \
        --true_site_dpath={true_site_dpath} \
        --region_id={region_id} \
        --viz_sites=True \
        --reload=0 "$@"
        ''')
    cfsn_invoke_fpath = main_out_dir / 'confusion_analysis.sh'
    cfsn_invoke_fpath.write_text(confusion_analysis_text)
    cfsn_invoke_fpath.chmod(cfsn_invoke_fpath.stat().st_mode | stat.S_IXUSR)


def _merge_region_metrics(config, region_dpaths, main_out_dir, true_site_dpath,
                          true_region_dpath, info, parent_info):
    """
    Merge per-region BAS/SC metrics into one json summary.

    Output locations:
        - The summary goes to ``config.merge_fpath``, which defaults to
          ``{main_out_dir}/merged/summary2.json``.
        - ``bas_df.pkl`` and ``sc_df.pkl`` are written next to it.
        - Per-region BAS region visualizations are symlinked into
          ``region_viz_overall``.
    """
    from geowatch.tasks.metrics.merge_iarpa_metrics import merge_metrics_results
    from geowatch.tasks.metrics.merge_iarpa_metrics import iarpa_bas_color_legend
    import kwimage
    import safer

    if config.merge_fpath is None:
        merge_dpath = (main_out_dir / 'merged').ensuredir()
        merge_fpath = merge_dpath / 'summary2.json'
    else:
        merge_fpath = ub.Path(config.merge_fpath)
        merge_dpath = merge_fpath.parent.ensuredir()

    json_data, bas_df, sc_df, _ = merge_metrics_results(
        region_dpaths, true_site_dpath, true_region_dpath, config.merge_fbetas)
    # TODO: parent info should probably belong to info itself
    json_data['info'] = info
    json_data['parent_info'] = parent_info
    with safer.open(merge_fpath, 'w', temp_file=not ub.WIN32) as f:
        json.dump(json_data, f, indent=4)
    print('merge_fpath = {!r}'.format(merge_fpath))

    # Consolidate visualizations
    combined_viz_dpath = merge_dpath / 'region_viz_overall'
    if config.enable_viz:
        # Write a legend to go with the BAS viz
        combined_viz_dpath.ensuredir()
        legend_img = iarpa_bas_color_legend()
        kwimage.imwrite(combined_viz_dpath / 'bas_legend.png', legend_img)

    bas_df.to_pickle(merge_dpath / 'bas_df.pkl')
    sc_df.to_pickle(merge_dpath / 'sc_df.pkl')

    # Symlink to visualizations
    for dpath in region_dpaths:
        viz_dpath = (dpath / 'overall' / 'bas' / 'region').ensuredir()
        for viz_fpath in viz_dpath.iterdir():
            combined_viz_dpath.ensuredir()
            viz_link = viz_fpath.augment(dpath=combined_viz_dpath)
            ub.symlink(viz_fpath, viz_link, verbose=1)

    # viz SC
    if config.enable_sc_viz:
        combined_viz_dpath.ensuredir()
        from geowatch.tasks.metrics.viz_sc_results import viz_sc
        viz_sc(region_dpaths, true_site_dpath, true_region_dpath, combined_viz_dpath)


def main(cmdline=True, **kwargs):
    """
    Score predicted site models against truth with iarpa_smart_metrics.

    Steps:
        1. Group the predicted sites by region.
        2. Write each group into the directory layout the metrics tool
           expects, and record the command in
           ``{out_dir}/{region_id}/invocation.sh``.
        3. Run the command unless ``merge='read'``.
        4. If ``merge`` is truthy, merge the per-region results into one
           summary json.

    Args:
        cmdline (bool): forwarded to ``MetricsConfig.cli``; controls whether
            command line arguments are parsed.
        **kwargs: values for :class:`MetricsConfig` fields.

    Raises:
        AssertionError: if iarpa_smart_metrics is missing or older than 2.5.0.
        Exception: if no predicted sites, or more than one manifest, are given.

    CommandLine:
        xdoctest -m geowatch.cli.run_metrics_framework main
    """
    from kwcoco.util import util_json
    from geowatch.utils import process_context
    from tempfile import TemporaryDirectory
    import subprocess

    config = MetricsConfig.cli(cmdline=cmdline, data=kwargs)
    config_dict = config.asdict()
    print('config = {}'.format(ub.urepr(config_dict, nl=2, sort=0)))

    try:
        # Do we have the latest and greatest?
        import iarpa_smart_metrics
        IARPA_METRICS_VERSION = version.Version(iarpa_smart_metrics.__version__)
    except Exception as ex:
        raise AssertionError('The iarpa_smart_metrics package should be pip installed '
                             'in your virtualenv') from ex
    # Explicit check (not ``assert``) so it still runs under ``python -O``.
    if IARPA_METRICS_VERSION < version.Version('2.5.0'):
        raise AssertionError(
            f'iarpa_smart_metrics>=2.5.0 is required, got {IARPA_METRICS_VERSION}')

    # Record the git hash of the metrics code if possible.
    try:
        metrics_modpath = ub.Path(iarpa_smart_metrics.__file__).parent
        gitout = ub.cmd('git rev-parse --short HEAD', cwd=metrics_modpath, check=None)
        # Empty stdout means git could not report a hash; record None then.
        IARPA_METRICS_GIT_HASH = gitout.stdout.strip() or None
    except Exception:
        IARPA_METRICS_GIT_HASH = None

    # The config is serialized with the results, so coerce it to json and
    # stringify anything that is still unserializable. The problems are
    # materialized before mutating the structure they were found in.
    jsonified_config = util_json.ensure_json_serializable(config_dict)
    walker = ub.IndexableWalker(jsonified_config)
    problems = list(util_json.find_json_unserializable(jsonified_config))
    for problem in problems:
        walker[problem['loc']] = str(problem['data'])

    proc_context = process_context.ProcessContext(
        type='process',
        name='geowatch.cli.run_metrics_framework',
        config=jsonified_config,
        extra={'iarpa_smart_metrics_version': iarpa_smart_metrics.__version__},
    )
    proc_context.start()

    pred_sites, parent_info = _load_pred_sites(config['pred_sites'], config['load_workers'])
    true_region_dpath, true_site_dpath = _resolve_truth_dpaths(config)

    temp_dir = None
    if config.tmp_dir is not None:
        tmp_dpath = ub.Path(config.tmp_dir)
    else:
        temp_dir = TemporaryDirectory(suffix='iarpa-metrics-tmp')
        tmp_dpath = ub.Path(temp_dir.name)

    try:
        # split sites by region
        grouped_sites = ub.group_items(
            pred_sites, lambda site: site['features'][0]['properties']['region_id'])
        main_out_dir = ub.Path(config.out_dir or './iarpa-metrics-output')
        main_out_dir.ensuredir()

        commands, out_dirs = _prepare_region_commands(
            config, grouped_sites, tmp_dpath, main_out_dir,
            true_site_dpath, true_region_dpath)

        if config.merge != 'read':
            for cmd in commands:
                try:
                    ub.cmd(cmd, verbose=3, check=True, shell=True)
                except subprocess.CalledProcessError:
                    # Best effort: keep going so other regions still get scored.
                    print('error in metrics framework, probably due to zero '
                          'TP site matches or a region without site truth.')

        # With exactly one (scored) region, write a confusion analysis script.
        if len(grouped_sites) == 1 and out_dirs:
            region_id = next(iter(grouped_sites))
            _write_confusion_script(main_out_dir, true_region_dpath,
                                    true_site_dpath, region_id)

        print('out_dirs = {}'.format(ub.urepr(out_dirs, nl=1)))

        if config.merge and out_dirs:
            context = proc_context.stop()
            context['IARPA_METRICS_VERSION'] = str(IARPA_METRICS_VERSION)
            context['IARPA_METRICS_GIT_HASH'] = IARPA_METRICS_GIT_HASH
            _merge_region_metrics(
                config, out_dirs, main_out_dir, true_site_dpath,
                true_region_dpath, info=[context], parent_info=parent_info)
    finally:
        # Explicit cleanup avoids TemporaryDirectory's implicit-cleanup
        # ResourceWarning; it still happens at the end of main as before.
        if temp_dir is not None:
            temp_dir.cleanup()
# Module-level handles: ``__config__`` names the config class and the entry
# point is attached to it as ``main`` (presumably so a CLI dispatcher can
# find both; the consumer is not visible in this file).
__config__ = MetricsConfig
__config__.main = main


if __name__ == '__main__':
    # Running the module as a script parses the command line (cmdline=True).
    main()