import json
import logging
import warnings
from functools import partial
import ubelt as ub
import os
import click
import kwcoco
import kwimage
import numpy as np
import torch
# import torchvision.transforms
# from medpy.filter.smoothing import anisotropic_diffusion
from torch.utils.data import DataLoader
from tqdm import tqdm
from scipy import ndimage
import geowatch_tpl # NOQA
from .datasets import WVRgbDataset, WVSuperRgbDataset
# from .pl_highres_verify import MultiTaskModel, modify_bn, dfactor, local_utils
from .pl_highres_verify import MultiTaskModel, modify_bn, dfactor
from .utils import process_image_chunked
from ..landcover.detector import get_device
from ..landcover.predict import get_output_file
from ..landcover.utils import setup_logging
log = logging.getLogger(__name__)

# Debug variables
# Each switch defaults to 1 (normal operation). Setting one to zero skips the
# corresponding stage so the rest of the pipeline can be exercised alone.
ENABLE_WRITES = 1  # set to zero to disable writing predictions and the output kwcoco file to disk
ENABLE_MODEL = 1  # set to zero to disable running the model forward (run_inference returns all-NaN)
ENABLE_PROCESS = 1  # set to zero to disable image processing entirely (predictions are all zeros)
ENABLE_CHUNKS = 1  # set to zero to disable chunking, i.e. process each image in one pass (increases VRAM)
@click.command()
@click.option('--dataset', required=True, type=click.Path(exists=True), help='input kwcoco dataset')
@click.option('--deployed', required=True, type=click.Path(exists=True), help='pytorch weights file')
@click.option('--output', required=False, type=click.Path(), help='output kwcoco dataset')
@click.option('--window_size', required=False, type=int, default=1024, help='sliding window size')
@click.option('--dump_shards', required=False, default=False, help='if True, output partial kwcoco files as they are completed')
@click.option('--data_workers', required=False, default=0, help='background data loaders')
@click.option('--select_images', required=False, default=None, help='if specified, a jq operation to filter images')
@click.option('--select_videos', required=False, default=None, help='if specified, a jq operation to filter videos')
@click.option('--asset_suffix', required=False, default='_assets/dzyne_depth', help='folder relative to output to save features in')
@click.option('--cache', required=False, default=0, help='if True, enable caching of results')
@click.option('--use_super_res_bands', required=False, default=False, help='if True, uses super-res (upsampled) bands to form source RGB image')
@click.option('--scale', required=False, default=4, help='Specify scaling factor between image size and super-res bands, only applicable when use_super_res_bands is true')
def predict(dataset, deployed, output, window_size=2048, dump_shards=False,
            data_workers=0, select_images=None, select_videos=None,
            asset_suffix='_assets/dzyne_depth', cache=False, use_super_res_bands=False, scale=4):
    """
    Predict a "depth" raster for every WV image of a kwcoco dataset.

    For each selected image the model output is quantized to uint8, written
    to ``<output bundle>/<asset_suffix>/[<video name>/]<image name>_depth.tif``
    and registered as an auxiliary item with channel ``depth`` in a copy of
    the input dataset, which is finally dumped to the output kwcoco path.
    Failures on a single image are logged and the loop continues with the
    next image; a KeyboardInterrupt stops the loop but the output dataset is
    still written.

    Args:
        dataset: path of the input kwcoco dataset.
        deployed: path of the pytorch weights (state dict) file.
        output: requested output kwcoco path; the final path is resolved by
            ``get_output_file``.
        window_size: side length of the square sliding window. Images that
            are smaller than this along both axes are processed in a single
            pass. NOTE: the command line default is 1024 while the default
            when calling the function directly is 2048.
        dump_shards: if True, a one-image kwcoco file is written next to
            each prediction as soon as it is finished.
        data_workers: number of DataLoader worker processes.
        select_images: optional jq expression used to filter images.
        select_videos: optional jq expression used to filter videos.
        asset_suffix: folder, relative to the output bundle, in which the
            prediction rasters are stored.
        cache: if truthy, images whose prediction file already exists are
            not recomputed; the existing file is still registered in the
            output dataset.
        use_super_res_bands: only the exact value ``True`` selects
            ``WVSuperRgbDataset``; anything else uses ``WVRgbDataset``.
        scale: scale factor forwarded to ``WVSuperRgbDataset``.

    Example:
        >>> # xdoctest: +REQUIRES(env:DVC_DPATH)
        >>> from geowatch.tasks.depth.predict import *  # NOQA
        >>> import geowatch
        >>> dvc_dpath = geowatch.find_dvc_dpath()
        >>> dataset = dvc_dpath / 'Drop2-Aligned-TA1-2022-02-15/data_vali.kwcoco.json'
        >>> output = dvc_dpath / 'Drop2-Aligned-TA1-2022-02-15/dzyne_depth_test.kwcoco.json'
        >>> deployed = dvc_dpath / "models/depth/weights_v1.pt"
        >>> data_workers = 0
        >>> cache = 0
        >>> select_images = '.name == "crop_20150123T020752Z_N37.734145E128.855484_N37.811709E128.946746_WV_0"'
        >>> select_videos = None
        >>> dump_shards = False
        >>> window_size = 1536
        >>> asset_suffix = '_assets/test_dzyne_depth'
        >>> predict.callback(dataset=dataset, deployed=deployed, output=output,
        >>>                  window_size=window_size, dump_shards=dump_shards,
        >>>                  data_workers=data_workers, select_images=select_images,
        >>>                  select_videos=select_videos, asset_suffix=asset_suffix,
        >>>                  cache=cache)
    """
    weights_filename = ub.Path(deployed)
    output_dset_filename = ub.Path(get_output_file(output))
    output_bundle_dpath = output_dset_filename.parent
    output_data_dir = output_bundle_dpath / asset_suffix
    log.info('Input: {}'.format(dataset))
    log.info('Weights: {}'.format(weights_filename))
    log.info('Output: {}'.format(output_dset_filename))
    log.info('Output Images: {}'.format(output_data_dir))
    input_dset = kwcoco.CocoDataset.coerce(dataset)
    input_bundle_dpath = ub.Path(input_dset.bundle_dpath)
    from geowatch.utils import kwcoco_extensions
    from geowatch.tasks.fusion.predict import quantize_float01
    # Restrict the work to WV images that also pass the optional jq filters
    filtered_gids = kwcoco_extensions.filter_image_ids(
        input_dset,
        include_sensors=['WV'],
        select_images=select_images,
        select_videos=select_videos
    )
    log.info('Valid Images: {}'.format(len(filtered_gids)))
    input_dset = input_dset.subset(filtered_gids)
    # The output dataset starts as a copy of the (filtered) input; the
    # predictions are appended to it as auxiliary items below.
    output_dset = input_dset.copy()
    if input_bundle_dpath != output_bundle_dpath:
        # Need to change the root of the output directory
        # The kwcoco reroot logic is flaky for complex cases, so be careful
        # In the normal case where the output and input kwcoco share the same
        # bundle, then this logic is avoided
        output_dset.reroot(absolute=True)
        output_dset.fpath = str(output_dset_filename)
        new_prefix = os.path.relpath(input_bundle_dpath, output_bundle_dpath)
        output_dset.reroot(old_prefix=str(input_bundle_dpath),
                           new_prefix=str(new_prefix), absolute=False,
                           check=True)
    # input data
    if use_super_res_bands is True:
        torch_dataset = WVSuperRgbDataset(dset=input_dset, scale=scale)  # TODO - auto-detect scaling
    else:
        torch_dataset = WVRgbDataset(input_dset)
    if cache:
        log.debug('checking for cached files')
        # Remove any image ids that are already computed
        gid_to_pred_filename = {}
        miss_gids = []
        hit_gids = []
        for gid in torch_dataset.gids:
            img_info = torch_dataset.dset.imgs[gid]
            pred_filename = _image_pred_filename(torch_dataset,
                                                 output_data_dir, img_info)
            gid_to_pred_filename[gid] = pred_filename
            if pred_filename.exists():
                hit_gids.append(gid)
            else:
                miss_gids.append(gid)
        log.info(
            f'Found {len(hit_gids)} / {len(gid_to_pred_filename)} cached depth maps')
        # Might be a better way to indicate a subset, but this works
        torch_dataset.gids = miss_gids
    # model
    log.debug('loading model')
    config = _load_config()
    config['backbone_params']['pretrained'] = False  # dont download on predict
    model = MultiTaskModel(config=config)
    # The map_location callable returns the storage unchanged; the model is
    # moved to the target device explicitly further down.
    state_dict = torch.load(
        weights_filename,
        map_location=lambda storage,
        loc: storage)
    model.load_state_dict(state_dict)
    model = modify_bn(model, track_running_stats=False, bn_momentum=0.01)
    model = model.eval()
    model.to(get_device())
    S = window_size
    chip_size = (S, S, 3)
    overlap = (128, 128, 0)
    output_dtype = np.float32  # Will be quantized as a final step
    process_func = partial(run_inference, model=model)
    log.debug('processing images')
    # batch_size=1 with an identity collate_fn: every batch is a list holding
    # exactly one dataset item.
    dataloader = DataLoader(torch_dataset, num_workers=data_workers,
                            batch_size=1, pin_memory=1, collate_fn=lambda x: x)
    with torch.no_grad():
        for batch in tqdm(dataloader, miniters=1, unit='image', disable=False):
            assert len(batch) == 1
            batch_item = batch[0]
            gid = batch_item['id']
            # get clean img_info
            img_info = torch_dataset.dset.imgs[gid]
            pred_filename = _image_pred_filename(torch_dataset,
                                                 output_data_dir, img_info)
            if cache and pred_filename.exists():
                # Dereference items after we are done with them
                batch_item = None
                image = None
                continue
            try:
                image = batch_item['imgdata']
                log.info('processing image {}'.format(image.shape))
                if ENABLE_PROCESS:
                    use_chunks = ENABLE_CHUNKS
                    # An image smaller than the window along both axes is
                    # processed in a single pass.
                    if image.shape[0] < S and image.shape[1] < S:
                        use_chunks = 0
                    if use_chunks:
                        pred = process_image_chunked(image, process_func,
                                                     chip_size=chip_size,
                                                     overlap=overlap,
                                                     output_dtype=output_dtype)
                    else:
                        pred = process_func(image)
                else:
                    pred = np.zeros(image.shape[0:2], dtype=output_dtype)
                # Dereference items after we are done with them
                batch_item = None  # dereference for memory
                image = None  # dereference for memory
                quant_pred, quantization = quantize_float01(
                    pred, old_min=0, old_max=1, quantize_dtype=np.uint8)
                pred = None  # dereference for memory
                info = _write_output(img_info, quant_pred, pred_filename,
                                     output_bundle_dpath, quantization)
                quant_pred = None  # dereference for memory
                # Register the new raster as an auxiliary item of the image
                aux = output_dset.imgs[gid].get('auxiliary', [])
                aux.append(info)
                output_dset.imgs[gid]['auxiliary'] = aux
                if dump_shards:
                    # Dump debugging shard
                    shard_dset = output_dset.subset([gid])
                    shard_dset.reroot(absolute=True)
                    shard_dset.fpath = pred_filename.augment(
                        ext='.kwcoco.json')
                    # output_dpath / (imgname + '_depth.kwcoco.json')
                    shard_dset.dump(shard_dset.fpath, indent=2)
            except KeyboardInterrupt:
                log.info('interrupted')
                break
            except Exception:
                # Best effort: log the failure (with traceback) and move on
                # to the next image.
                log.exception(
                    'Unable to load id:{} - {}'.format(img_info['id'], img_info['name']))
    # ``hit_gids`` only exists when ``cache`` is truthy; the short-circuit
    # ``and`` keeps this safe.
    if cache and hit_gids:
        from geowatch.utils import util_gdal
        # add metadata for cache items
        for gid in hit_gids:
            img_info = torch_dataset.dset.imgs[gid]
            pred_filename = _image_pred_filename(torch_dataset,
                                                 output_data_dir, img_info)
            with util_gdal.GdalDataset.open(pred_filename, 'r') as gdal_img:
                pred_shape = (gdal_img.RasterYSize, gdal_img.RasterXSize,
                              gdal_img.RasterCount)
            # pred_shape = kwimage.load_image_shape(pred_filename)
            # Hack to get the quantization dict that would have been computed
            # at predict time.
            _, quantization = quantize_float01(
                None, old_min=0, old_max=1, quantize_dtype=np.uint8)
            info = _build_aux_info(img_info, pred_shape, pred_filename,
                                   output_bundle_dpath, quantization)
            aux = output_dset.imgs[gid].get('auxiliary', [])
            aux.append(info)
            output_dset.imgs[gid]['auxiliary'] = aux
    if ENABLE_WRITES:
        output_dset.dump(str(output_dset_filename), indent=2)
        output_dset.validate()
    log.info('output written to {}'.format(output_dset_filename))
def _image_pred_filename(torch_dataset, output_data_dir, img_info):
    """
    Build the path of the depth raster that belongs to one image.

    Args:
        torch_dataset: object whose ``dset.index.videos`` maps a video id to
            a video dictionary with a ``name`` entry.
        output_data_dir: directory (path-like supporting ``/``) that holds
            all prediction rasters.
        img_info (dict): image dictionary; ``name`` is required and
            ``video_id`` is optional.

    Returns:
        ``output_data_dir/<video name>/<image name>_depth.tif`` when the
        image belongs to a video, else ``output_data_dir/<image name>_depth.tif``.
    """
    raster_name = img_info['name'] + '_depth.tif'
    video_id = img_info.get('video_id', None)
    if video_id is None:
        # Images without a video are stored directly in the output folder
        return output_data_dir / raster_name
    video_name = torch_dataset.dset.index.videos[video_id]['name']
    return output_data_dir / video_name / raster_name
def fake_model(batch2, tta=True):
    """
    Stand-in for the real network, used for testing.

    The first image of the batch is Gaussian blurred and averaged over its
    channels; that single map is returned as both the ``depth`` and the
    ``seg`` prediction.

    Args:
        batch2 (dict): must contain ``image``, a tensor indexable as
            ``[0]`` that yields a channels-first image.
        tta (bool): accepted so the call signature matches how
            ``run_inference`` invokes a model; it is ignored here.

    Returns:
        Tuple[dict, dict]: the prediction dictionary and the unmodified
        input batch.
    """
    # Channels-first tensor -> channels-last numpy array for kwimage
    np_data = batch2['image'][0].permute(1, 2, 0).numpy()
    x = kwimage.gaussian_blur(np_data, sigma=7)
    # Add the leading (batch, channel) axes that run_inference indexes away
    depth = torch.from_numpy(x.mean(axis=2))[None, None]
    pred2 = dict(depth=depth, seg=depth)
    return pred2, batch2
def _test():
    """
    Small test to check that stitching logic works when nan regions are
    involved.

    A random polygon of the test image is filled with NaN, the image is
    processed chunk by chunk with ``fake_model`` on the CPU, and the input,
    the stitched float prediction and its quantized version are shown side
    by side with kwplot.
    """
    import kwimage
    src = kwimage.ensure_float01(kwimage.grab_test_image(dsize=(2048, 2048)))
    nan_poly = kwimage.Polygon.random(rng=32021).scale(src.shape[0] * 3)
    image = nan_poly.fill(src.copy() * 255, np.nan)
    # The prediction must stay floating point until it is quantized: an
    # integer dtype cannot represent the NaN nodata regions this test is
    # about, nor the fractional values quantize_float01 expects. This
    # matches the dtype used by ``predict``.
    output_dtype = np.float32
    overlap = (0, 0, 0)
    chip_size = (512, 512, 3)
    model = fake_model
    process_func = partial(run_inference, model=model, device='cpu')
    pred = process_image_chunked(image, process_func,
                                 chip_size=chip_size,
                                 overlap=overlap,
                                 output_dtype=output_dtype)
    from geowatch.tasks.fusion.predict import quantize_float01
    quant_pred, quantization = quantize_float01(pred, old_min=0, old_max=1,
                                                quantize_dtype=np.uint8)
    print('quantization = {}'.format(ub.urepr(quantization, nl=1)))
    import kwplot
    kwplot.autompl()
    kwplot.imshow(kwimage.normalize_intensity(
        image), pnum=(1, 3, 1), doclf=True)
    kwplot.imshow(pred, pnum=(1, 3, 2))
    kwplot.imshow(quant_pred, pnum=(1, 3, 3))
def run_inference(image, model, device=0):
    """
    Run the model on one image (or image chip) and post-process the result.

    Args:
        image (ndarray): array of shape (H, W, C). NaN marks nodata. Values
            are divided by 255 before they are given to the model.
        model (Callable): invoked as ``model(batch, tta=True)`` and expected
            to return ``(pred, batch)`` where ``pred['depth']`` and
            ``pred['seg']`` can be indexed with ``[0, 0, :, :]``.
        device: torch device the input tensors are moved to.

    Returns:
        ndarray: float32 array of shape (H, W). Pixels whose input is NaN in
        every channel are NaN. The result is all NaN when the whole input is
        NaN or when ``ENABLE_MODEL`` is disabled.

    Example:
        >>> from geowatch.tasks.depth.predict import *  # NOQA
        >>> import kwimage
        >>> import kwarray
        >>> src = kwimage.ensure_float01(kwimage.grab_test_image(dsize=(512, 512)))
        >>> src = kwimage.Polygon.random(rng=None).scale(src.shape[0]).fill(src.copy(), np.nan)
        >>> model = fake_model
        >>> image = src * 255
        >>> device = 'cpu'
        >>> result = run_inference(image, model, device=device)
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> kwplot.autompl()
        >>> kwplot.imshow(src, pnum=(1, 2, 1), doclf=True)
        >>> kwplot.imshow(result, pnum=(1, 2, 2))
    """
    img_h, img_w = image.shape[0:2]
    # The result always has the spatial shape of the *input*, even when the
    # image has to be padded internally.
    pred_shape = (img_h, img_w)

    if not ENABLE_MODEL:
        return np.full(pred_shape, fill_value=np.nan, dtype=np.float32)

    # Images smaller than 64 pixels along an axis are NaN-padded up to 64
    # (presumably the smallest input the network accepts -- TODO confirm).
    # Each axis is clamped separately so an axis that is already large
    # enough never receives a negative pad width.
    min_w, min_h = 64, 64
    pad_h = max(0, min_h - img_h)
    pad_w = max(0, min_w - img_w)
    if pad_h or pad_w:
        pad_dims = ((0, pad_h), (0, pad_w), (0, 0))
        image = np.pad(image, pad_dims, mode='constant',
                       constant_values=np.nan)

    nodata_mask = np.isnan(image)
    if np.all(nodata_mask):
        # Nothing to predict on
        return np.full(pred_shape, fill_value=np.nan, dtype=np.float32)

    with torch.no_grad():
        # Replace nans with zeros before going into the network
        image_float = image / 255.0  # not sure why we want to do this...
        # image_float = image.copy()
        image_float[nodata_mask] = 0
        image_tensor = torch.from_numpy(
            image_float.transpose((2, 0, 1))).contiguous()

        # Per-channel statistics over the valid (non-NaN) pixels. Note that
        # they are computed on the unscaled image.
        flat_pixels = image.reshape(-1, image.shape[-1])
        mean = np.nanmean(flat_pixels, axis=0)
        std = np.nanstd(flat_pixels, axis=0)

        batch2 = {
            "image": image_tensor[None, ...].to(device),
            "image_mean": torch.from_numpy(mean)[None, ...].to(device),
            "image_std": torch.from_numpy(std)[None, ...].to(device),
        }
        pred2, batch2 = model(batch2, tta=True)
        output_depth = pred2['depth'][0, 0, :, :].detach().cpu().numpy()
        output_label = pred2['seg'][0, 0, :, :].detach().cpu().numpy()

    # output_depth[nodata_mask.all(axis=2)] = np.nan
    # Blend the segmentation output with the (scaled, clipped) depth output
    weighted_depth = dfactor * output_depth
    alpha = 0.9
    weighted_seg = alpha * output_label + \
        (1.0 - alpha) * np.minimum(0.99, weighted_depth / 70.0)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        # tmp2 = 255 * anisotropic_diffusion(weighted_seg, niter=1, kappa=100, gamma=0.8)
        tmp2 = anisotropic_diffusion(
            weighted_seg, niter=1, kappa=100, gamma=0.8)
        # weighted_final = ndimage.median_filter(tmp2.astype(np.uint8), size=7)
        weighted_final = ndimage.median_filter(tmp2, size=7)

    # Restore nodata where every channel of the input pixel was NaN
    weighted_final[nodata_mask.all(axis=2)] = np.nan

    # Undo pad. Slicing with the original extent is a no-op for unpadded
    # images and, unlike ``[:-pad_h, :-pad_w]``, is correct when only one
    # axis was padded (``[:-0]`` would select nothing).
    weighted_final = weighted_final[:img_h, :img_w]
    return weighted_final
# Vendored in to deal with 3.10 issue
def anisotropic_diffusion(img, niter=1, kappa=50,
                          gamma=0.1, voxelspacing=None, option=1):
    r"""
    Edge-preserving, XD Anisotropic diffusion.

    Parameters
    ----------
    img : array_like
        Input image (a float32 copy is made; the input is not modified).
    niter : integer
        Number of iterations.
    kappa : integer
        Conduction coefficient, e.g. 20-100. ``kappa`` controls conduction
        as a function of the gradient. If ``kappa`` is low small intensity
        gradients are able to block conduction and hence diffusion across
        steep edges. A large value reduces the influence of intensity gradients
        on conduction.
    gamma : float
        Controls the speed of diffusion. Pick a value :math:`<= .25` for stability.
    voxelspacing : tuple of floats or array_like
        The distance between adjacent pixels in all img.ndim directions.
        Defaults to 1 along every axis.
    option : {1, 2, 3}
        Whether to use the Perona Malik diffusion equation No. 1 or No. 2,
        or Tukey's biweight function.
        Equation 1 favours high contrast edges over low contrast ones, while
        equation 2 favours wide regions over smaller ones. See [1]_ for details.
        Equation 3 preserves sharper boundaries than previous formulations and
        improves the automatic stopping of the diffusion. See [2]_ for details.

    Returns
    -------
    anisotropic_diffusion : ndarray
        Diffused image (float32).

    Raises
    ------
    ValueError
        If ``option`` is not 1, 2 or 3.

    Notes
    -----
    Original MATLAB code by Peter Kovesi,
    School of Computer Science & Software Engineering,
    The University of Western Australia,
    pk @ csse uwa edu au,
    <http://www.csse.uwa.edu.au>

    Translated to Python and optimised by Alistair Muldal,
    Department of Pharmacology,
    University of Oxford,
    <alistair.muldal@pharm.ox.ac.uk>

    Adapted to arbitrary dimensionality and added to the MedPy library Oskar Maier,
    Institute for Medical Informatics,
    Universitaet Luebeck,
    <oskar.maier@googlemail.com>

    June 2000  original version. -
    March 2002 corrected diffusion eqn No 2. -
    July 2012 translated to Python -
    August 2013 incorporated into MedPy, arbitrary dimensionality -

    References
    ----------
    .. [1] P. Perona and J. Malik.
       Scale-space and edge detection using anisotropic diffusion.
       IEEE Transactions on Pattern Analysis and Machine Intelligence,
       12(7):629-639, July 1990.
    .. [2] M.J. Black, G. Sapiro, D. Marimont, D. Heeger
       Robust anisotropic diffusion.
       IEEE Transactions on Image Processing,
       7(3):421-432, March 1998.
    """
    import numpy

    # define conduction gradients functions
    if option == 1:
        def condgradient(delta, spacing):
            return numpy.exp(-(delta / kappa)**2.) / float(spacing)
    elif option == 2:
        def condgradient(delta, spacing):
            return 1. / (1. + (delta / kappa)**2.) / float(spacing)
    elif option == 3:
        kappa_s = kappa * (2**0.5)

        def condgradient(delta, spacing):
            top = 0.5 * ((1. - (delta / kappa_s)**2.)**2.) / float(spacing)
            return numpy.where(numpy.abs(delta) <= kappa_s, top, 0)
    else:
        # Fail early with a clear message instead of an undefined-name
        # error at the first use of ``condgradient``.
        raise ValueError(
            'option must be 1, 2 or 3, got {!r}'.format(option))

    # initialize output array
    out = numpy.array(img, dtype=numpy.float32, copy=True)

    # set default voxel spacing if not supplied
    if voxelspacing is None:
        # Use the converted array so that plain sequences (which have no
        # ``ndim`` attribute) are accepted as input.
        voxelspacing = (1.,) * out.ndim

    # initialize some internal variables
    deltas = [numpy.zeros_like(out) for _ in range(out.ndim)]

    for _ in range(niter):
        # calculate the diffs
        # (forward differences; the last entry along each axis stays zero)
        for i in range(out.ndim):
            slicer = tuple([slice(None, -1) if j == i else slice(None)
                            for j in range(out.ndim)])
            deltas[i][slicer] = numpy.diff(out, axis=i)

        # update matrices
        matrices = [condgradient(delta, spacing) * delta
                    for delta, spacing in zip(deltas, voxelspacing)]

        # subtract a copy that has been shifted ('Up/North/West' in 3D case)
        # by one pixel, which turns the per-axis fluxes into their
        # divergence.
        for i in range(out.ndim):
            slicer = tuple([slice(1, None) if j == i else slice(None)
                            for j in range(out.ndim)])
            matrices[i][slicer] = numpy.diff(matrices[i], axis=i)

        # update the image
        out += gamma * (numpy.sum(matrices, axis=0))

    return out
def _build_aux_info(img_info, pred_shape, pred_filename, output_bundle_dpath,
                    quantization):
    """
    Describe a prediction raster as a kwcoco auxiliary item.

    Args:
        img_info (dict): image dictionary providing ``width`` and ``height``.
        pred_shape (Sequence[int]): shape of the raster; index 0 is used as
            the height and index 1 as the width.
        pred_filename: path of the raster; must be located below
            ``output_bundle_dpath``.
        output_bundle_dpath: directory the stored file name is relative to.
        quantization: stored unchanged under the ``quantization`` key.

    Returns:
        dict: auxiliary item for the single band ``depth`` channel, with an
        affine scale that maps raster coordinates to image coordinates.
    """
    rel_fpath = pred_filename.relative_to(output_bundle_dpath)
    pred_h = pred_shape[0]
    pred_w = pred_shape[1]
    # Scale is given as [x, y], i.e. width ratio first
    warp_aux_to_img = {
        'scale': [img_info['width'] / pred_w, img_info['height'] / pred_h],
        'type': 'affine',
    }
    return dict(
        file_name=str(rel_fpath),
        channels='depth',
        height=pred_h,
        width=pred_w,
        num_bands=1,
        quantization=quantization,
        warp_aux_to_img=warp_aux_to_img,
    )
def _write_output(img_info, pred, pred_filename,
                  output_bundle_dpath, quantization):
    """
    Write one quantized prediction to disk and describe it for kwcoco.

    Args:
        img_info (dict): image dictionary the prediction belongs to.
        pred (ndarray): quantized prediction to store.
        pred_filename: destination path of the raster.
        output_bundle_dpath: bundle directory of the output dataset.
        quantization (dict): quantization description; its ``nodata`` entry
            is passed to the image writer.

    Returns:
        dict: the auxiliary item built by ``_build_aux_info``. It is
        returned even when ``ENABLE_WRITES`` is disabled and nothing is
        written.
    """
    aux_item = _build_aux_info(img_info, pred.shape, pred_filename,
                               output_bundle_dpath, quantization)
    if not ENABLE_WRITES:
        return aux_item

    # Make sure the destination folder exists before writing
    pred_filename.parent.mkdir(parents=True, exist_ok=True)
    write_kwargs = dict(
        backend='gdal',
        blocksize=256,
        nodata=quantization['nodata'],
        compress='DEFLATE',
        overviews=3,
    )
    kwimage.imwrite(str(pred_filename), pred, **write_kwargs)
    return aux_item
def _load_config():
    """
    Load the model configuration shipped with the package.

    Returns:
        The parsed content of ``config.json`` from the
        ``geowatch.tasks.depth`` package.
    """
    from importlib import resources as importlib_resources
    # The context manager guarantees the resource file is closed again
    with importlib_resources.open_text('geowatch.tasks.depth', 'config.json') as fp:
        return json.load(fp)
# Script entry point: ``python -m geowatch.tasks.depth.predict ...``
if __name__ == '__main__':
    # The string below is never used by the program; it only holds usage
    # notes (model variants, measured VRAM per window size, example command
    # lines and memory allocator settings).
    r"""
# Notes:
weights_v1 - for RGB
weights_v2_gray - for PAN
TODO: Predict with both models, one for RGB and one with PAN
# VRAM usage with weights_v2_gray
# window_size=512: 4.951 GB
# window_size=640: 7.406 GB
# window_size=704: 8.912 GB
# window_size=736: 9.310 GB
# window_size=768: 10.099 GB
# window_size=1024: 17.111 GB
# window_size=1152: 21.007 GB
## Drop 4 ##
python3 -m geowatch.tasks.depth.predict \
--deployed=/smart/backup/models/depth/weights_v1.pt \
--dataset=/output/Aligned-Drop4-2022-08-08-TA1-S2-WV-PD-ACC/data_wv_superRes.kwcoco.json \
--output=/output/Aligned-Drop4-2022-08-08-TA1-S2-WV-PD-ACC/data_wv_superRes_depth.kwcoco.json \
--data_workers=8 \
--window_size=2048 \
--cache=1 \
--use_super_res_bands=True \
--scale=4
python -m geowatch visualize $KWCOCO_BUNDLE_DPATH/dzyne_depth.kwcoco.json \
--animate=True --channels="depth,red|green|blue" --skip_missing=True \
--select_images '.sensor_coarse == "WV"' --workers=4 --draw_anns=False
python -m geowatch stats $KWCOCO_BUNDLE_DPATH/dzyne_depth.kwcoco.json
python -m kwcoco stats $KWCOCO_BUNDLE_DPATH/dzyne_depth.kwcoco.json
Notes:
Does export MALLOC_MMAP_THRESHOLD_=0 help with memory?
export MALLOC_MMAP_THRESHOLD_=8192
export MALLOC_ARENA_MAX=4
"""
    setup_logging()
    # Keep torch hub files under a fixed location in /tmp
    torch.hub.set_dir('/tmp/weights')
    # Invoke the click command; click parses the command line arguments
    predict()