r"""
Prediction script for landcover features.
Given a checkout of the model and drop6 data, the following demos computing and
visualizing a subset of the features.
CommandLine:
DVC_EXPT_DPATH=$(geowatch_dvc --tags=phase2_expt --hardware=auto)
DVC_DATA_DPATH=$(geowatch_dvc --tags=phase2_data --hardware=auto)
KWCOCO_BUNDLE_DPATH=$DVC_DATA_DPATH/Drop6
DZYNE_LANDCOVER_MODEL_FPATH="$DVC_EXPT_DPATH/models/landcover/sentinel2.pt"
INPUT_DATASET_FPATH=$KWCOCO_BUNDLE_DPATH/imganns-KR_R001.kwcoco.zip
OUTPUT_DATASET_FPATH=$KWCOCO_BUNDLE_DPATH/imganns-KR_R001_landcover_small.kwcoco.zip
echo "
DVC_DATA_DPATH="$DVC_DATA_DPATH"
DVC_EXPT_DPATH="$DVC_EXPT_DPATH"
DZYNE_LANDCOVER_MODEL_FPATH="$DZYNE_LANDCOVER_MODEL_FPATH"
INPUT_DATASET_FPATH="$INPUT_DATASET_FPATH"
OUTPUT_DATASET_FPATH="$OUTPUT_DATASET_FPATH"
"
export CUDA_VISIBLE_DEVICES="1"
python -m geowatch.tasks.landcover.predict \
--dataset="$INPUT_DATASET_FPATH" \
--deployed="$DZYNE_LANDCOVER_MODEL_FPATH" \
--device=0 \
--num_workers=4 \
--select_images='(.frame_index < 100) and (.sensor_coarse == "S2")' \
--with_hidden=6 \
--output="$OUTPUT_DATASET_FPATH"
geowatch stats $OUTPUT_DATASET_FPATH
geowatch visualize $OUTPUT_DATASET_FPATH \
--animate=True --channels="red|green|blue,barren|forest|water,landcover_hidden.0:3,landcover_hidden.3:6" \
--skip_missing=True --workers=4 --draw_anns=False --smart=True
"""
import torch
import datetime
import ubelt as ub
from pathlib import Path
import kwcoco
from torch.utils.data import DataLoader
from kwutil import util_parallel
from kwutil import util_progress
from . import detector
from .model_info import lookup_model_info
from .utils import setup_logging
import scriptconfig as scfg
[docs]
class LandcoverPredictConfig(scfg.DataConfig):
dataset = scfg.Value(None, required=True, help='input kwcoco dataset')
deployed = scfg.Value(None, required=True, help='pytorch weights file')
output = scfg.Value(None, required=True, help='output kwcoco dataset')
num_workers = scfg.Value(0, type=str, help='number of dataloading workers. Can be "auto"')
io_workers = scfg.Value('auto', type=str, help='Number of writer threads. Defaults to min(num_workers, 2)')
device = scfg.Value('auto', type=str, help='auto, cpu, or integer of the device to use')
select_images = scfg.Value(None, type=str, help='if specified, a jq operation to filter images')
select_videos = scfg.Value(None, type=str, help='if specified, a jq operation to filter videos')
with_hidden = scfg.Value(None, type=int, help='if true, also write out this many of the hidden activations')
track_emissions = scfg.Value(True, help='Set to False to disable codecarbon')
window_dim = scfg.Value(1024, help='Set to False to disable codecarbon')
assets_dname = scfg.Value('_assets', help=ub.paragraph(
'''
The name of the top-level directory to write new assets.
'''))
[docs]
def predict(cmdline=1, **kwargs):
"""
Example:
>>> # xdoctest: +REQUIRES(env:DVC_DPATH)
>>> from geowatch.tasks.landcover.predict import * # NOQA
>>> from geowatch.tasks.landcover.predict import _predict_single
>>> import kwcoco
>>> import geowatch
>>> dvc_data_dpath = geowatch.find_dvc_dpath(tags='phase2_data', hardware='auto')
>>> dvc_expt_dpath = geowatch.find_dvc_dpath(tags='phase2_expt', hardware='auto')
>>> dset = kwcoco.CocoDataset(dvc_data_dpath / 'Drop6/imganns-KR_R001.kwcoco.zip')
>>> deployed = dvc_expt_dpath / 'models/landcover/sentinel2.pt'
>>> kwargs = {
>>> 'dataset': dset.fpath,
>>> 'deployed': deployed,
>>> 'output': ub.Path(dset.fpath).augment(stemsuffix='_landcover', multidot=True),
>>> 'select_images': '.sensor_coarse == "S2"',
>>> }
>>> cmdline = 0
>>> predict(cmdline, **kwargs)
"""
from geowatch.utils.lightning_ext import util_device
from geowatch.utils import process_context
from geowatch.utils import kwcoco_extensions
from geowatch.tasks.fusion.coco_stitcher import CocoStitchingManager
import rich
config = LandcoverPredictConfig.cli(cmdline=cmdline, data=kwargs, strict=True)
rich.print('config = {}'.format(ub.urepr(config, align=':', nl=1)))
coco_dset_filename = config.dataset
weights_filename = Path(config.deployed)
output_dset_filename = get_output_file(config.output)
deployed = ub.Path(config.deployed)
if not deployed.is_file():
raise ValueError('Landcover model does not exist')
device = util_device.coerce_devices(config.device)[0]
print(f'device={device}')
config.num_workers = util_parallel.coerce_num_workers(config.num_workers)
if config.io_workers == 'auto':
config.io_workers = min(2, config.num_workers)
config.io_workers = util_parallel.coerce_num_workers(config.io_workers)
input_dset = kwcoco.CocoDataset.coerce(coco_dset_filename)
filtered_gids = kwcoco_extensions.filter_image_ids(
input_dset, include_sensors=None, exclude_sensors=None,
select_images=config.select_images, select_videos=config.select_videos)
input_dset = input_dset.subset(filtered_gids)
print('Selected input_dset = {!r}'.format(input_dset))
model_info = lookup_model_info(weights_filename)
ptdataset = model_info.create_dataset(input_dset)
model = model_info.load_model(weights_filename, device)
print('Using {}'.format(type(model_info).__name__))
print('Creating output dataset')
output_dset = input_dset.copy()
output_dset.fpath = output_dset_filename
print('Initialize dataloader')
dataloader = DataLoader(ptdataset, num_workers=config.num_workers,
batch_size=None, collate_fn=lambda x: x)
# Start the worker processes before we do threading
print('Initialize dataloader iter')
dataloader_iter = iter(dataloader)
# Create a queue that writes data to disk in the background
print('Initialize stitchers')
writer_queue = util_parallel.BlockingJobQueue(max_workers=config.io_workers)
landcover_stitcher = CocoStitchingManager(
output_dset,
'landcover',
chan_code='|'.join(model_info.model_outputs),
stiching_space='image',
writer_queue=writer_queue,
expected_minmax=(0, 1),
assets_dname=config.assets_dname,
)
num_hidden = config.with_hidden
if config.with_hidden:
_register_hidden_layer_hook(model)
hidden_stitcher = CocoStitchingManager(
output_dset,
'landcover_hidden',
chan_code=f'landcover_hidden.0:{num_hidden}',
stiching_space='image',
writer_queue=writer_queue,
assets_dname=config.assets_dname,
)
hidden_stitcher.num_hidden = num_hidden
else:
model._activation_cache = None
hidden_stitcher = None
print('Initialize process context')
proc_context = process_context.ProcessContext(
type='process',
name='geowatch.tasks.invariants.predict',
config=config.to_dict(),
track_emissions=config.track_emissions,
)
proc_context.start()
print('Starting main predict loop')
# pman = util_progress.ProgressManager('progiter')
pman = util_progress.ProgressManager('rich')
window_dim = config.window_dim
with pman, torch.no_grad():
_prog = pman.progiter(dataloader_iter, total=len(dataloader),
desc='predict landcover')
for img_info in _prog:
try:
_predict_single(
img_info, model, model_info.model_outputs,
landcover_stitcher, hidden_stitcher,
output_dset=output_dset, window_dim=window_dim)
except KeyboardInterrupt:
print('interrupted')
break
except Exception as ex:
print('warning: ex = {}'.format(ub.urepr(ex, nl=1)))
print('Unable to load id:{} - {}'.format(img_info['id'], img_info['name']))
raise
writer_queue.wait_until_finished()
print('Finish process context')
proc_context.add_disk_info(ub.Path(input_dset.fpath).parent)
proc_context.add_device_info(device)
proc_context.stop()
output_dset.dataset['info'].append(proc_context.obj)
output_dset.dump(str(output_dset_filename), indent=2)
print('output written to {}'.format(output_dset_filename))
def _register_hidden_layer_hook(model):
# TODO: generalize to other models
# Specific to UNetR model
# These are at half of the output image resolution.
model._activation_cache = {}
def record_hidden_activation(layer, input, output):
activation = output.detach()
model._activation_cache['hidden'] = activation
layer_of_interest = model.decoder1[3]
layer_of_interest._forward_hooks.clear()
layer_of_interest.register_forward_hook(record_hidden_activation)
def _predict_single(img_info,
model,
model_outputs,
landcover_stitcher,
hidden_stitcher,
output_dset: kwcoco.CocoDataset, window_dim=1024):
"""
Modifies the coco dataset inplace, returns the data that needs to be
written to disk.
"""
import kwarray
import kwimage
gid = img_info['id']
img = img_info['imgdata']
## Hardcoded params
window_dims = (window_dim, window_dim)
window_overlap = 0.3
hidden_scale = 0.5 # scale from raw predictions to hidden predictions
image_box = kwimage.Box.from_dsize(img.shape[0:2][::-1])
if hidden_stitcher is not None:
hidden_box = image_box.scale(hidden_scale).astype(int)
hidden_dsize = hidden_box.dsize
# Need to run a sliding window so we can manage larger image
slider = kwarray.SlidingWindow(img.shape[0:2], window_dims,
overlap=window_overlap,
keepbound=True,
allow_overshoot=True)
valid_pred = False
for img_slice in slider:
subimg = img[img_slice]
pred = detector.run(model, subimg, img_info)
if pred is None:
continue
else:
valid_pred = True
scale = img.shape[0] / img_info["height"]
landcover_stitcher.accumulate_image(gid, img_slice, pred, asset_dsize=img.shape[0:2][::-1], scale_asset_from_stitchspace=scale)
if hidden_stitcher is not None:
# Lots of hardcoded things here.
# Hack to unpad here.
if 'hidden' not in model._activation_cache:
print('WARNING: we expected hidden, but its not there')
else:
hidden_raw = model._activation_cache['hidden'].cpu().numpy()
model._activation_cache.pop('hidden', None)
hidden = hidden_raw[0].transpose(1, 2, 0)
h, w = pred.shape[0:2]
hidden = hidden[0:h // 2, 0:w // 2, 0:hidden_stitcher.num_hidden]
output_box = kwimage.Box.from_slice(img_slice)
hidden_box = output_box.scale(hidden_scale).astype(int)
hidden_slice = hidden_box.to_slice()
hidden_stitcher.accumulate_image(
gid, hidden_slice, hidden, asset_dsize=hidden_dsize,
scale_asset_from_stitchspace=hidden_scale)
if valid_pred:
landcover_stitcher.submit_finalize_image(gid)
if hidden_stitcher is not None:
hidden_stitcher.submit_finalize_image(gid)
[docs]
def get_output_file(output):
default_output_filename = 'out_{}.kwcoco.json'.format(datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
if output is None:
output_dir = Path('/tmp')
return output_dir.joinpath(default_output_filename)
else:
output = Path(output)
if output.is_dir():
return output.joinpath(default_output_filename)
else:
return output
if __name__ == '__main__':
setup_logging()
predict()