"""
Basline Example:
DVC_DATA_DPATH=$(geowatch_dvc --tags='phase2_data' --hardware=auto)
DVC_EXPT_DPATH=$(geowatch_dvc --tags='phase2_expt' --hardware=auto)
python -m geowatch.tasks.invariants.predict \
--input_kwcoco=$DVC_DATA_DPATH/Drop4-BAS/data_vali_KR_R001.kwcoco.json \
--output_kwcoco=$DVC_DATA_DPATH/Drop4-BAS/all_tests/model_thirteen_epoch/data_vali_KR_R001_invariants.kwcoco.json \
--pretext_package=$DVC_EXPT_DPATH/models/uky/uky_invariants_2022_12_17/TA1_pretext_model/pretext_package.pt \
--input_space_scale=10GSD \
--window_space_scale=10GSD \
--patch_size=256 \
--do_pca 0 \
--patch_overlap=0.3 \
--workers="2" \
--write_workers 0 \
--tasks before_after pretext
# After your model predicts the outputs, you should be able to use the
# geowatch visualize tool to inspect your features.
python -m geowatch visualize $DVC_DATA_DPATH/Drop4-BAS/all_tests/model_thirteen_epoch/data_vali_invariants.kwcoco.json \
--channels "invariants.5:8,invariants.8:11,invariants.14:17" --stack=only --workers=avail --animate=True \
--draw_anns=False
SeeAlso:
~/code/watch/geowatch/cli/queue_cli/prepare_teamfeats.py
"""
import torch
import ubelt as ub
from .data.datasets import GriddedDataset
from .pretext_model import pretext
from .segmentation_model import segmentation_model as seg_model
from geowatch.tasks.fusion.coco_stitcher import CocoStitchingManager
import kwimage
from kwutil import util_parallel
from geowatch.utils.lightning_ext import util_device
import scriptconfig as scfg
[docs]
class InvariantPredictConfig(scfg.DataConfig):
"""
Configuration for UKY invariant models
"""
device = scfg.Value('cuda', type=str)
pretext_ckpt_path = scfg.Value(None, type=str)
segmentation_ckpt_path = scfg.Value(None, type=str)
pretext_package_path = scfg.Value(None, type=str)
segmentation_package_path = scfg.Value(None, type=str)
batch_size = scfg.Value(1, type=int)
workers = scfg.Value(4, help=ub.paragraph(
'''
number of background data loading workers
'''), alias=['num_workers'])
io_workers = scfg.Value(0, help=ub.paragraph(
'''
number of background data writing workers
'''), alias=['write_workers'])
window_resolution = scfg.Value('10GSD', help='The window GSD to build the grid at', alias=['window_space_scale'])
input_resolution = scfg.Value('10GSD', help='The input GSD to sample the grid at', alias=['input_space_scale'])
sensor = scfg.Value(['S2', 'L8'], nargs='+')
bands = scfg.Value(['shared'], type=str, help=ub.paragraph(
'''
Choose bands on which to train. Can specify 'all' for all
bands from given sensor, or 'share' to use common bands when
using both S2 and L8 sensors
'''), nargs='+')
patch_size = scfg.Value(256, type=int)
patch_overlap = scfg.Value(0.25, type=float)
input_kwcoco = scfg.Value(None, type=str, required=True, help=ub.paragraph(
'''
Path to kwcoco dataset with images to generate feature for
'''))
output_kwcoco = scfg.Value(None, type=str, required=True, help=ub.paragraph(
'''
Path to write an output kwcoco file. Output file will be a
copy of input_kwcoco with addition feature fields generated
by predict.py rerooted to point to the original data.
'''))
tasks = scfg.Value(['all'], help=ub.paragraph(
'''
Specify which tasks to choose from (segmentation,
before_after, or pretext. Can also specify 'all')
'''), nargs='+')
do_pca = scfg.Value(1, type=int, help=ub.paragraph(
'''
Set to 1 to perform pca. Choose output dimension in num_dim
argument.
'''))
pca_projection_path = scfg.Value('', type=str, help='Path to pca projection matrix')
track_emissions = scfg.Value(True, help='Set to false to disable codecarbon')
assets_dname = scfg.Value('_assets', help=ub.paragraph(
'''
The name of the top-level directory to write new assets.
'''))
def __post_init__(self):
if 'all' in self.tasks:
self['tasks'] = ['segmentation', 'before_after', 'pretext']
[docs]
class Predictor(object):
"""
CommandLine:
DVC_DPATH=$(geowatch_dvc)
DVC_DPATH=$DVC_DPATH xdoctest -m geowatch.tasks.invariants.predict Predictor
python -m geowatch visualize $DVC_DPATH/Drop2-Aligned-TA1-2022-02-15/test_uky.kwcoco.json \
--channels='invariants.0:3' --animate=True --with_anns=False
Example:
>>> # xdoctest: +REQUIRES(env:DVC_DPATH)
>>> from geowatch.tasks.invariants.predict import * # NOQA
>>> import kwcoco
>>> import geowatch
>>> dvc_dpath = geowatch.find_dvc_dpath()
>>> # Write out smaller version of the dataset
>>> dset = kwcoco.CocoDataset(dvc_dpath / 'Drop2-Aligned-TA1-2022-02-15/data_nowv_vali.kwcoco.json')
>>> images = dset.videos(names=['KR_R001']).images[0]
>>> sub_images = images.compress([s != 'WV' for s in images.lookup('sensor_coarse')])[::5]
>>> sub_dset = dset.subset(sub_images)
>>> sub_dset.fpath = (dvc_dpath / 'Drop2-Aligned-TA1-2022-02-15/small_test_data_nowv_vali.kwcoco.json')
>>> sub_dset.dump(sub_dset.fpath)
>>> input_kwcoco = sub_dset.fpath
>>> output_kwcoco = dvc_dpath / 'Drop2-Aligned-TA1-2022-02-15/test_uky.kwcoco.json'
>>> pretext_package_path = dvc_dpath / 'models/uky/uky_invariants_2022_03_11/TA1_pretext_model/pretext_package.pt'
>>> pca_projection_path = dvc_dpath / 'models/uky/uky_invariants_2022_03_11/TA1_pretext_model/pca_projection_matrix.pt'
>>> segmentation_package_path = dvc_dpath / 'models/uky/uky_invariants_2022_02_11/TA1_segmentation_model/segmentation_package.pt'
>>> argv = []
>>> argv += ['--input_kwcoco', f'{sub_dset.fpath}']
>>> argv += ['--output_kwcoco', f'{output_kwcoco}']
>>> argv += ['--pca_projection_path', f'{pca_projection_path}']
>>> argv += ['--pretext_package_path', f'{pretext_package_path}']
>>> argv += ['--segmentation_package_path', f'{segmentation_package_path}']
>>> argv += ['--patch_overlap', '0.25']
>>> argv += ['--workers', '2']
>>> argv += ['--tasks', 'all']
>>> argv += ['--do_pca', '1']
>>> args = InvariantPredictConfig.cli(argv=argv)
>>> self = Predictor(args)
>>> self.forward(args)
"""
def __init__(self, args):
# import xdev
# xdev.make_warnings_print_tracebacks()
# TODO: Add a cache flag. If cache==1, Determine what images we have
# already predicted for and then take a kwcoco subset containing
# the cache misses. See dzyne and rutgers predictors for example
# implementations.
# Doesnt work?
# FIX_ALBUMENTATIONS_HACK = 0
# if FIX_ALBUMENTATIONS_HACK:
# from albumentations.core import composition
# composition.Transforms = composition.TransformType
self.batch_size = args.batch_size
self.do_pca = args.do_pca
self.tasks = args.tasks
self.workers = util_parallel.coerce_num_workers(args.workers)
print('self.workers = {!r}'.format(self.workers))
self.io_workers = util_parallel.coerce_num_workers(args.io_workers)
print(f'self.io_workers={self.io_workers}')
self.devices = util_device.coerce_devices(args.device)
assert len(self.devices) == 1, 'only 1 for now'
self.device = device = self.devices[0]
print('device = {!r}'.format(device))
# Initialize models
print('Initialize models')
if 'all' in args.tasks:
self.tasks = ['segmentation', 'before_after', 'pretext']
else:
self.tasks = args.tasks
### Define tasks
if 'segmentation' in self.tasks:
if args.segmentation_package_path:
print('Initialize segmentation model from package')
self.segmentation_model = seg_model.load_package(args.segmentation_package_path)
else:
print('Initialize segmentation model from checkpoint')
self.segmentation_model = seg_model.load_from_checkpoint(args.segmentation_ckpt_path, dataset=None)
self.segmentation_model = self.segmentation_model.to(device)
if 'pretext' in self.tasks:
if args.pretext_package_path:
print('Initialize pretext model from package')
self.pretext_model = pretext.load_package(args.pretext_package_path)
else:
print('Initialize pretext model from checkpoint')
self.pretext_model = pretext.load_from_checkpoint(args.pretext_ckpt_path, train_dataset=None, vali_dataset=None)
self.pretext_model = self.pretext_model.eval().to(device)
# pretext_hparams = pretext_model.hparams
try:
# Hack
self.pretext_model.sort_accuracy = None
except Exception:
pass
self.pretext_model.__dict__['sort_accuracy'] = ub.identity # HUGE HACK
self.in_feature_dims = self.pretext_model.hparams.feature_dim_shared
if args.do_pca:
print('Initialize PCA model')
self.pca_projector = torch.load(args.pca_projection_path).to(device)
self.out_feature_dims = self.pca_projector.shape[0]
else:
self.out_feature_dims = self.in_feature_dims
self.num_out_channels = self.out_feature_dims
if 'segmentation' in self.tasks:
self.num_out_channels += 1
if 'before_after' in self.tasks:
self.num_out_channels += 1
# initialize dataset
import kwcoco
print('load coco dataset')
self.coco_dset = kwcoco.CocoDataset = kwcoco.CocoDataset.coerce(args.input_kwcoco)
# self.coco_dset = self.coco_dset.subset(list(self.coco_dset.images()[0:20]))
###
print('build grid dataset')
self.dataset = GriddedDataset(self.coco_dset,
sensor=args.sensor,
bands=args.bands,
patch_size=args.patch_size,
patch_overlap=args.patch_overlap,
window_space_scale=args.window_resolution,
input_space_scale=args.input_resolution,
output_space_scale=args.input_resolution,
mode='test')
print('copy dataset')
self.output_dset = self.dataset.coco_dset.copy()
print('reroot')
self.output_dset.reroot(absolute=True) # Make all paths absolute
self.output_dset.fpath = args.output_kwcoco # Change output file path and bundle path
self.output_dset.reroot(absolute=False) # Reroot in the new bundle path
self.finalized_gids = set()
# self.stitcher_dict = {}
self.save_channels = f'invariants:{self.num_out_channels}'
self.output_kwcoco_path = ub.Path(args.output_kwcoco)
self.imwrite_kw = {
'compress': 'DEFLATE',
'backend': 'gdal',
'blocksize': 128,
}
self.stitch_manager = CocoStitchingManager(
result_dataset=self.output_dset,
short_code='pred_invariants',
chan_code=self.save_channels,
stiching_space='video',
prob_compress=self.imwrite_kw['compress'],
quantize=True,
assets_dname=args.assets_dname,
)
from geowatch.utils import process_context
self.proc_context = process_context.ProcessContext(
type='process',
name='geowatch.tasks.invariants.predict',
config=args.to_dict(),
track_emissions=args.track_emissions,
)
[docs]
def forward(self):
device = self.device
loader = torch.utils.data.DataLoader(
self.dataset, num_workers=self.workers,
batch_size=self.batch_size, shuffle=False)
num_batches = len(loader)
# Start background processes
# Build a task queue for background write results workers
from kwutil import util_parallel
from kwutil import util_progress
writer_queue = util_parallel.BlockingJobQueue(max_workers=self.io_workers)
self.stitch_manager.writer_queue = writer_queue
self.proc_context.start()
self.proc_context.add_disk_info(ub.Path(self.output_dset.fpath).parent)
self.output_dset.dataset.setdefault('info', [])
self.output_dset.dataset['info'].append(self.proc_context.obj)
print('Evaluating and saving features')
pman = util_progress.ProgressManager()
with torch.set_grad_enabled(False), pman:
seen_images = set()
prog = pman.progiter(enumerate(loader), total=num_batches, desc='Compute invariants', verbose=1)
for idx, batch in prog:
save_feat = []
save_feat2 = []
# Handle input nans
img1 = batch['image1']
img2 = batch['image2']
offset_image1 = batch['offset_image1']
augmented_image1 = batch['augmented_image1']
invalid_mask1 = torch.isnan(img1)[0].any(dim=0)
invalid_mask2 = torch.isnan(img2)[0].any(dim=0)
batch['image1'] = torch.nan_to_num(img1).to(device)
batch['image2'] = torch.nan_to_num(img2).to(device)
batch['offset_image1'] = torch.nan_to_num(offset_image1).to(device)
batch['augmented_image1'] = torch.nan_to_num(augmented_image1).to(device)
if 'pretext' in self.tasks:
image_stack = torch.stack([batch['image1'], batch['image2'], batch['offset_image1'], batch['augmented_image1']], dim=1)
image_stack = image_stack.to(device)
# Remove nans before going into the network
image_stack = torch.nan_to_num(image_stack)
all_features = self.pretext_model(image_stack)
#select features corresponding to first image
features = all_features[:, 0, :, :, :]
#select features corresponding to second image
features2 = all_features[:, 1, :, :, :]
if self.do_pca:
features = torch.einsum('xy,byhw->bxhw', self.pca_projector, features)
features2 = torch.einsum('xy,byhw->bxhw', self.pca_projector, features2)
features = features.squeeze().permute(1, 2, 0).cpu()
features2 = features2.squeeze().permute(1, 2, 0).cpu()
features[invalid_mask1] = float('nan')
features2[invalid_mask2] = float('nan')
save_feat.append(features)
save_feat2.append(features2)
if 'before_after' in self.tasks:
### TO DO: Set to output of separate model.
before_after_heatmap = self.pretext_model.shared_step(batch)['before_after_heatmap'][0].permute(1, 2, 0)
before_after_heatmap = torch.sigmoid(torch.exp(before_after_heatmap[:, :, 1]) - torch.exp(before_after_heatmap[:, :, 0])).unsqueeze(-1).cpu()
before_after_heatmap[invalid_mask1] = float('nan')
before_after_heatmap[invalid_mask2] = float('nan')
save_feat.append(before_after_heatmap)
save_feat2.append(before_after_heatmap)
if 'segmentation' in self.tasks:
image_stack = [batch[key] for key in batch if key.startswith('image')]
image_stack = torch.stack(image_stack, dim=1).to(self.device)
predictions = torch.exp(self.segmentation_model(image_stack)['predictions'])
segmentation_heatmap = torch.sigmoid(predictions[0, 0, 1, :, :] - predictions[0, 0, 0, :, :]).unsqueeze(0).permute(1, 2, 0).cpu()
segmentation_heatmap2 = torch.sigmoid(predictions[0, 1, 1, :, :] - predictions[0, 1, 0, :, :]).unsqueeze(0).permute(1, 2, 0).cpu()
segmentation_heatmap[invalid_mask1] = float('nan')
segmentation_heatmap2[invalid_mask2] = float('nan')
save_feat.append(segmentation_heatmap)
save_feat2.append(segmentation_heatmap2)
save_feat = torch.cat(save_feat, dim=-1)
save_feat = save_feat.numpy()
save_feat2 = torch.cat(save_feat2, dim=-1)
save_feat2 = save_feat2.numpy()
target = self.dataset.patches[idx]
# These dataloader has told us that these iamges are now
# complete, and thus can be finalized free any memory used by
# its stitcher
new_complete_gids = target.get('new_complete_gids', [])
for gid in new_complete_gids:
assert gid not in seen_images
seen_images.add(gid)
# if 1:
# img = self.dataset.coco_dset.index.imgs[gid]
# frame_index = img['frame_index']
# video_id = img['video_id']
# prog.ensure_newline()
# print(f'finalize {video_id=}, {gid=}, {frame_index=}')
# writer_queue.submit(self.stitch_manager.finalize, gid)
self.stitch_manager.submit_finalize_image(gid)
gid1, gid2 = target['gids']
# slice_ = target['space_slice']
# stitcher1 = self.ensure_stitcher(gid1)
# stitcher2 = self.ensure_stitcher(gid2)
# CocoStitchingManager._stitcher_center_weighted_add(
# stitcher1, slice_, save_feat)
# CocoStitchingManager._stitcher_center_weighted_add(
# stitcher2, slice_, save_feat2)
sample_outspace_ltrb = kwimage.Box.coerce(batch['sample_outspace_ltrb'].numpy(), format='ltrb')
full_stitch_outspace_box = kwimage.Box.coerce(batch['full_stitch_outspace_ltrb'].numpy(), format='ltrb')
scale_outspace_from_vid = batch['scale_outspace_from_vid'].numpy()[0]
outspace_slice = sample_outspace_ltrb.to_slice()
outspace_dsize = full_stitch_outspace_box.dsize
self.stitch_manager.accumulate_image(
gid1, outspace_slice, save_feat,
dsize=outspace_dsize,
scale=scale_outspace_from_vid)
self.stitch_manager.accumulate_image(
gid2, outspace_slice, save_feat2,
dsize=outspace_dsize,
scale=scale_outspace_from_vid)
print('Finalize already compelted jobs')
writer_queue.wait_until_finished(desc='Finalize submitted jobs')
# Finalize everything else that hasn't completed
for gid in pman.progiter(list(self.stitch_manager.image_stitchers.keys()), desc='submit loose write jobs'):
if gid not in seen_images:
seen_images.add(gid)
self.stitch_manager.submit_finalize_image(gid)
# writer_queue.submit(self.stitch_manager.finalize, gid)
print('Finalize loose jobs')
writer_queue.wait_until_finished()
print('Finish process context')
self.proc_context.add_device_info(device)
self.proc_context.stop()
print('Write to dset.fpath = {!r}'.format(self.output_dset.fpath))
self.output_dset.dump(self.output_dset.fpath, newlines=True)
print('Done')
[docs]
def main():
args = InvariantPredictConfig.cli()
import rich
rich.print('config = {}'.format(ub.urepr(args, nl=1)))
Predictor(args).forward()
if __name__ == '__main__':
"""
SeeAlso:
../../cli/queue_cli/prepare_teamfeats.py
# Team Features on Drop3
DVC_DPATH=$(geowatch_dvc)
KWCOCO_BUNDLE_DPATH=$DVC_DPATH/Aligned-Drop3-TA1-2022-03-10
python -m geowatch.cli.queue_cli.prepare_teamfeats \
--base_fpath=$KWCOCO_BUNDLE_DPATH/data.kwcoco.json \
--with_depth=0 \
--with_landcover=0 \
--with_materials=0 \
--with_invariants=1 \
--do_splits=0 \
--gres=0 --backend=serial --run=1
CommandLine:
python -m geowatch.tasks.template.predict --help
DVC_DPATH=$(geowatch_dvc)
PRETEXT_PATH=$DVC_DPATH/models/uky/uky_invariants_2022_02_11/TA1_pretext_model/pretext_package.pt
SSEG_PATH=$DVC_DPATH/models/uky/uky_invariants_2022_02_11/TA1_segmentation_model/segmentation_package.pt
PCA_FPATH=$DVC_DPATH/models/uky/uky_invariants_2022_02_11/TA1_pretext_model/pca_projection_matrix.pt
KWCOCO_BUNDLE_DPATH=$DVC_DPATH/Drop2-Aligned-TA1-2022-02-15
python -m geowatch.tasks.invariants.predict \
--pretext_package_path "$PRETEXT_PATH" \
--segmentation_package_path "$SSEG_PATH" \
--pca_projection_path "$PCA_FPATH" \
--input_kwcoco $KWCOCO_BUNDLE_DPATH/data.kwcoco.json \
--workers=avail \
--do_pca 0 \
--patch_overlap=0.3 \
--output_kwcoco $KWCOCO_BUNDLE_DPATH/uky_invariants.kwcoco.json \
--tasks before_after pretext
python -m geowatch stats $KWCOCO_BUNDLE_DPATH/uky_invariants.kwcoco.json
python -m geowatch visualize $KWCOCO_BUNDLE_DPATH/uky_invariants/invariants_nowv_vali.kwcoco.json \
--channels "invariants.7,invariants.6,invariants.5" --animate=True \
--select_images '.sensor_coarse != "WV"' --draw_anns=False
Ignore:
### Command 1 / 2 - geowatch-teamfeat-job-0
python -m geowatch.tasks.invariants.predict \
--input_kwcoco "/home/joncrall/remote/toothbrush/data/dvc-repos/smart_data_dvc/Aligned-Drop4-2022-08-08-TA1-S2-L8-ACC/data_kr1br2.kwcoco.json" \
--output_kwcoco "/home/joncrall/remote/toothbrush/data/dvc-repos/smart_data_dvc/Aligned-Drop4-2022-08-08-TA1-S2-L8-ACC/data_kr1br2_uky_invariants.kwcoco.json" \
--pretext_package_path "/home/joncrall/remote/toothbrush/data/dvc-repos/smart_expt_dvc/models/uky/uky_invariants_2022_03_21/pretext_model/pretext_package.pt" \
--pca_projection_path "/home/joncrall/remote/toothbrush/data/dvc-repos/smart_expt_dvc/models/uky/uky_invariants_2022_03_21/pretext_model/pretext_pca_104.pt" \
--do_pca 0 \
--patch_overlap=0.0 \
--workers="2" \
--io_workers 0 \
--tasks before_after pretext
cd /home/joncrall/remote/toothbrush/data/dvc-repos/smart_data_dvc-ssd/Aligned-Drop4-2022-08-08-TA1-S2-L8-ACC
kwcoco subset --src=data.kwcoco.json --dst=AE_R001.kwcoco.json --select_videos='.name == "AE_R001"'
kwcoco subset --src=data.kwcoco.json --dst=NZ_R001.kwcoco.json --select_videos='.name == "NZ_R001"'
python -m geowatch.tasks.invariants.predict \
--input_kwcoco "/home/joncrall/remote/toothbrush/data/dvc-repos/smart_data_dvc-ssd/Aligned-Drop4-2022-08-08-TA1-S2-L8-ACC/KR_R001.kwcoco.json" \
--output_kwcoco "/home/joncrall/remote/toothbrush/data/dvc-repos/smart_data_dvc-ssd/Aligned-Drop4-2022-08-08-TA1-S2-L8-ACC/KR_R001_invariants.kwcoco.json" \
--pretext_package_path "/home/joncrall/remote/toothbrush/data/dvc-repos/smart_expt_dvc/models/uky/uky_invariants_2022_03_21/pretext_model/pretext_package.pt" \
--pca_projection_path "/home/joncrall/remote/toothbrush/data/dvc-repos/smart_expt_dvc/models/uky/uky_invariants_2022_03_21/pretext_model/pretext_pca_104.pt" \
--input_space_scale=30GSD \
--window_space_scale=30GSD \
--patch_size=256 \
--do_pca 0 \
--patch_overlap=0.0 \
--workers="2" \
--io_workers 2 \
--tasks before_after pretext
python -m geowatch visualize KR_R001_invariants.kwcoco.json \
--channels "invariants.5:8,invariants.8:11,invariants.14:17" --stack=only --workers=avail --animate=True \
--draw_anns=False
"""
main()