# Source code for geowatch.tasks.depth.utils

# from distutils.log import error
import logging

import dask.array as da
import numpy as np
from tqdm import tqdm
import kwarray

# Logger for this module, keyed by the module's import name.
log = logging.getLogger(name=__name__)


def _process_image_chunked_with_kwarray(image,
                                        process_func,
                                        chip_size=(2048, 2048, 3),
                                        overlap=(128, 128, 0),
                                        output_dtype=np.uint8,
                                        verbose=1):
    """
    Apply ``process_func`` to sliding windows of ``image`` and stitch the
    per-window outputs back together with ``kwarray.Stitcher``.

    Args:
        image (ndarray): input image; windows are taken over its first two
            axes.
        process_func (Callable): applied to each image chip; its output is
            stitched back at that chip's spatial location.
        chip_size (Tuple[int, ...]): window size; only the first two entries
            (height, width) are used.
        overlap (Tuple[int, ...]): accepted for API parity with the dask
            backend but currently IGNORED (see the HACK note below).
        output_dtype: accepted for API parity but currently not applied to
            the result.
        verbose (int): when falsy, the progress bar is disabled.

    Returns:
        ndarray: the finalized stitcher output. The stitcher is sized to the
        input's spatial shape (``image.shape[0:2]``).
    """
    # Imported lazily so this module does not depend on the fusion
    # prediction module unless this backend is actually used.
    from geowatch.tasks.fusion.predict import CocoStitchingManager

    # HACK: the window overlap is hard-coded as a fraction (0.3) rather than
    # derived from the pixel-valued ``overlap`` argument.
    # NOTE(review): deriving it (e.g. stride = chip - overlap) would change
    # results for existing callers; confirm before switching.
    slider = kwarray.SlidingWindow(image.shape[0:2], chip_size[0:2],
                                   overlap=0.3,
                                   keepbound=True,
                                   allow_overshoot=True)

    stitcher = kwarray.Stitcher(slider.input_shape)

    for sl in tqdm(slider, desc='sliding window', disable=not verbose):
        new_chip = process_func(image[sl])
        # Presumably down-weights chip borders so overlapping windows blend
        # smoothly (per the helper's name) — a plain ``stitcher.add(sl,
        # new_chip)`` would weight all locations equally.
        CocoStitchingManager._stitcher_center_weighted_add(
            stitcher, sl, new_chip)

    return stitcher.finalize()


def _process_image_chunked_with_dask(image, process_func,
                                     chip_size=(2048, 2048, 3),
                                     overlap=(128, 128, 0),
                                     output_dtype=np.uint8, verbose=1):
    """
    Apply ``process_func`` blockwise over ``image`` using
    ``dask.array.Array.map_overlap``.

    The array is rechunked to ``chip_size - 2 * overlap`` per axis, and dask
    pads each chunk with ``overlap`` neighboring elements per side. The block
    handed to ``process_func`` is therefore at most ``chip_size``. Because
    ``boundary='none'`` is used, chunks on the image edge receive no padding
    on their outer side.

    Args:
        image (ArrayLike): three-axis input; axis 2 is dropped by
            ``process_func`` (``drop_axis=2``).
        process_func (Callable): maps a 3D block to a 2D result with axis 2
            removed.
        chip_size (Tuple[int, int, int]): maximum padded block size per axis.
        overlap (Tuple[int, int, int]): padding per side, per axis. Forced to
            zero when the image fits within a single chip.
        output_dtype: dtype dask is told the output will have.
        verbose (int): when falsy, the progress bar is disabled.

    Returns:
        The computed (in-memory) result of the blockwise map.

    Raises:
        ValueError: if ``overlap`` is not strictly less than half of
            ``chip_size`` along every axis.
    """

    def process_wrapper(img: np.ndarray, pbar, block_info=None):
        if block_info:
            # dask supplies block metadata; use it to size the progress bar.
            pbar.total = np.prod(block_info[0]['num-chunks'])
            pbar.refresh()
        try:
            return process_func(img)
        finally:
            # Count the chunk even if processing it failed.
            pbar.update()

    gh, gw = image.shape[0:2]
    ch, cw = chip_size[0:2]
    if gh <= ch and gw <= cw:
        # A single chip covers the whole image: nothing to overlap with.
        overlap = (0, 0, 0)

    # Padding adds ``overlap`` on each side, so shrink the chunk to keep the
    # padded block no larger than ``chip_size``.
    chunk_size = tuple(c - 2 * o for c, o in zip(chip_size, overlap))
    if any(c <= 0 for c in chunk_size):
        raise ValueError(
            'overlap must be less than half of chip_size along every axis; '
            f'got chip_size={chip_size!r}, overlap={overlap!r}')

    image: da.Array = da.asanyarray(image)
    image = image.rechunk(chunk_size)
    log.debug('overlap = %r', overlap)
    log.debug('image = %r', image)

    pbar = tqdm(unit=' chip', disable=not verbose)
    try:
        pred = image.map_overlap(
            process_wrapper,
            # overlap on each dimension
            depth=overlap,
            # FIXME: dont do this? Drop after?
            # input is w,h,b and output is w,h, so tell map_overlap that
            # axis 2 is dropped
            drop_axis=2,
            dtype=output_dtype,
            boundary='none',
            # forwarded by dask to ``process_wrapper``
            pbar=pbar,
        )
        # NOTE: single-threaded on purpose; the original author suspected a
        # leak or memory issue with other schedulers.
        pred = pred.compute(scheduler='single-threaded')
    finally:
        # Always release the progress bar, even if the computation fails.
        pbar.close()

    return pred


def process_image_chunked(image, process_func, chip_size=(2048, 2048, 3),
                          overlap=(128, 128, 0), output_dtype=np.uint8,
                          verbose=1, sliding_window_method='kwarray'):
    """
    Apply ``process_func`` to ``image`` in overlapping chips and combine the
    per-chip results into a single output.

    Args:
        image (ndarray): the image to process.
        process_func (Callable): function applied to each chip.
        chip_size (Tuple[int, int, int]): size of each chip.
        overlap (Tuple[int, int, int]): overlap between neighboring chips,
            per side, in pixels. Must be less than half of ``chip_size``
            along each axis: the dask backend processes chunks of
            ``chip_size - 2 * overlap`` padded by ``overlap`` on each side.
            NOTE: the 'kwarray' backend currently ignores this value and
            uses a fixed fractional overlap.
        output_dtype: output dtype hint forwarded to the backend.
        verbose (int): verbosity flag forwarded to the backend.
        sliding_window_method (str): backend to use, either ``'kwarray'`` or
            ``'dask'``.

    Returns:
        The combined result produced by the selected backend.

    Raises:
        KeyError: if ``sliding_window_method`` is not a known backend.

    Example:
        >>> from geowatch.tasks.depth.utils import *  # NOQA
        >>> import kwimage
        >>> import kwarray
        >>> image = kwimage.ensure_float01(kwimage.grab_test_image(dsize=(512, 512)))
        >>> nan_poly = kwimage.Polygon.random(rng=None).scale(image.shape[0])
        >>> image = nan_poly.fill(image.copy(), np.nan)
        >>> process_func = lambda x: kwimage.gaussian_blur(x, sigma=7).mean(axis=2)
        >>> non_chunked = process_func(image)
        >>> chip_size = (128, 128, 3)
        >>> overlap = (32, 32, 0)
        >>> output_dtype = np.uint8
        >>> verbose = 0
        >>> print('kwarray')
        >>> result1 = process_image_chunked(image, process_func, chip_size, overlap, output_dtype, verbose=1, sliding_window_method='kwarray')
        >>> print('dask')
        >>> result2 = process_image_chunked(image, process_func, chip_size, overlap, output_dtype, verbose=1, sliding_window_method='dask')
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> kwplot.autompl()
        >>> kwplot.imshow(image, pnum=(1, 3, 1), doclf=True)
        >>> kwplot.imshow(result1, pnum=(1, 3, 2), title='kwarray')
        >>> kwplot.imshow(result2, pnum=(1, 3, 3), title='dask')
    """
    if sliding_window_method == 'kwarray':
        return _process_image_chunked_with_kwarray(
            image, process_func, chip_size, overlap, output_dtype, verbose)
    elif sliding_window_method == 'dask':
        return _process_image_chunked_with_dask(
            image, process_func, chip_size, overlap, output_dtype, verbose)
    else:
        raise KeyError(sliding_window_method)