geowatch.tasks.fusion.datamodules.data_utils module

I don't like the name of this file and want to rename it, but for now it exists to keep the size of the datamodule down.

Todo

  • [ ] Break BalancedSampleTree and BalancedSampleForest into their own balanced sampling module.

  • [ ] Make a good augmentation module

  • [ ] Determine where MultiscaleMask should live.

geowatch.tasks.fusion.datamodules.data_utils.resolve_scale_request(request=None, data_gsd=None)[source]

Helper for handling user- and machine-specified spatial scale requests.

Parameters:
  • request (None | float | str) – Indicate a relative or absolute requested scale. If given as a float, it is interpreted as a scale factor relative to the underlying data. If given as a string, it accepts the format “{:f} *GSD” (e.g. '10GSD' or '10 GSD') and resolves to an absolute GSD. Defaults to 1.0.

  • data_gsd (None | float) – If specified, the GSD of the underlying data (only valid for geospatial data). TODO: is there a better generalization?

Returns:

resolved – containing keys

  • scale (float): the scale factor needed to obtain the requested resolution

  • gsd (float | None): if data_gsd is given, this is the absolute GSD of the request.

Return type:

Dict[str, Any]

Note

The returned scale is relative to the DATA. If you are resizing a sampled image, use it directly; if you are adjusting a sample WINDOW, use it inversely (i.e., divide the window size by the scale).

Example

>>> from geowatch.tasks.fusion.datamodules.data_utils import *  # NOQA
>>> resolve_scale_request(1.0)
>>> resolve_scale_request('native')
>>> resolve_scale_request('10 GSD', data_gsd=10)
>>> resolve_scale_request('20 GSD', data_gsd=10)

Example

>>> from geowatch.tasks.fusion.datamodules.data_utils import *  # NOQA
>>> import ubelt as ub
>>> grid = list(ub.named_product({
>>>     'request': ['10GSD', '30GSD'],
>>>     'data_gsd': [10, 30],
>>> }))
>>> grid += list(ub.named_product({
>>>     'request': [None, 1.0, 2.0, 0.25, 'native'],
>>>     'data_gsd': [None, 10, 30],
>>> }))
>>> for kwargs in grid:
>>>     print('kwargs = {}'.format(ub.urepr(kwargs, nl=0)))
>>>     resolved = resolve_scale_request(**kwargs)
>>>     print('resolved = {}'.format(ub.urepr(resolved, nl=0)))
>>>     print('---')
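The rules stated above (a float is a factor relative to the data, a “{:f} *GSD” string is an absolute GSD, and the default is 1.0) can be sketched as follows. This is an illustration under assumptions, not the module's implementation: the regular expression, the error raised when a GSD string arrives without data_gsd, and leaving out the 'native' request are all choices made here, and resolve_scale_request_sketch is a hypothetical name.

```python
import re

# Hypothetical pattern for "<number> GSD" with optional whitespace, e.g. "10GSD" or "10 GSD".
_GSD_PATTERN = re.compile(r'^\s*(\d+(?:\.\d*)?|\.\d+)\s*GSD\s*$')


def resolve_scale_request_sketch(request=None, data_gsd=None):
    """Resolve a relative (float) or absolute ("X GSD") scale request."""
    if request is None:
        request = 1.0  # documented default
    if isinstance(request, str):
        match = _GSD_PATTERN.match(request)
        if match is None:
            raise ValueError(f'unsupported request: {request!r}')
        if data_gsd is None:
            raise ValueError('an absolute GSD request needs data_gsd')
        target_gsd = float(match.group(1))
        # Relative to the DATA: a coarser target GSD gives scale < 1 (downsample).
        return {'scale': data_gsd / target_gsd, 'gsd': target_gsd}
    scale = float(request)
    gsd = None if data_gsd is None else data_gsd / scale
    return {'scale': scale, 'gsd': gsd}


resolved = resolve_scale_request_sketch('20 GSD', data_gsd=10)
# Per the Note: an output window of 128 pixels covers 128 / scale pixels of the data.
window_in_data = 128 / resolved['scale']
```

With 10 GSD data and a 20 GSD request the scale is 0.5, so a sampled image is halved in size while the sample window in data pixels doubles.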
geowatch.tasks.fusion.datamodules.data_utils.polygon_distance_transform(poly, shape, dtype)[source]

Example

import cv2
import kwimage
import numpy as np
poly = kwimage.Polygon.random().scale(32)
poly_mask = np.zeros((32, 32), dtype=np.uint8)
poly_mask = poly.fill(poly_mask, value=1)
dist = cv2.distanceTransform(poly_mask, cv2.DIST_L2, 3)
# Visualize the result (requires kwplot)
import kwplot
kwplot.autompl()
kwplot.imshow(dist, cmap='viridis', doclf=1)
poly.draw(fill=0, border=1)
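cv2.distanceTransform with DIST_L2 and a 3x3 mask returns an approximation of the Euclidean distance from each nonzero pixel to the nearest zero pixel. For small masks the exact values can be computed by brute force in NumPy, which is a convenient reference when checking results like the one above. brute_force_edt is a hypothetical helper; it is quadratic in the number of pixels, so it is only meant for tiny test arrays.

```python
import numpy as np


def brute_force_edt(mask):
    """Exact Euclidean distance from each nonzero pixel to the nearest zero pixel."""
    mask = np.asarray(mask).astype(bool)
    dist = np.zeros(mask.shape, dtype=np.float32)
    fg = np.argwhere(mask)   # (N, 2) row/col coordinates of nonzero pixels
    bg = np.argwhere(~mask)  # (M, 2) row/col coordinates of zero pixels
    if len(fg) == 0 or len(bg) == 0:
        return dist  # no foreground, or no zeros to measure against: all zeros
    diff = fg[:, None, :] - bg[None, :, :]  # (N, M, 2) pairwise offsets
    nearest = np.sqrt((diff ** 2).sum(axis=2)).min(axis=1)
    dist[fg[:, 0], fg[:, 1]] = nearest
    return dist


mask = np.zeros((5, 5), dtype=np.uint8)
mask[1:4, 1:4] = 1  # a 3x3 square of ones surrounded by zeros
dist = brute_force_edt(mask)
# The center is 2 pixels from the zeros; the ring of ones is 1 pixel away.
```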

geowatch.tasks.fusion.datamodules.data_utils.abslog_scaling(arr)[source]
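abslog_scaling has no description here. Going only by its name, a common "absolute log" transform is sign(x) · log(1 + |x|), which compresses large magnitudes while keeping the sign and mapping 0 to 0. The sketch below shows that transform; whether abslog_scaling matches it exactly (for instance in how it treats NaNs) is an assumption, and abslog_scaling_sketch is a hypothetical name.

```python
import numpy as np


def abslog_scaling_sketch(arr):
    """Signed log transform: sign(x) * log(1 + |x|) (assumed, not confirmed)."""
    arr = np.asarray(arr, dtype=float)
    return np.sign(arr) * np.log1p(np.abs(arr))


values = np.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])
scaled = abslog_scaling_sketch(values)  # large magnitudes are compressed, sign is kept
```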
geowatch.tasks.fusion.datamodules.data_utils.fliprot(img, rot_k=0, flip_axis=None, axes=(0, 1))[source]
Parameters:
  • img (ndarray) – H, W, C

  • rot_k (int) – number of counterclockwise 90-degree rotations

  • flip_axis (Tuple[int, …]) – either [], [0], [1], or [0, 1]. 0 is the y axis and 1 is the x axis.

  • axes (Tuple[int, int]) – the location of the y and x axes

Example

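>>> import numpy as np
>>> from geowatch.tasks.fusion.datamodules.data_utils import fliprot, inv_fliprot
>>> img = np.arange(16).reshape(4, 4)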
>>> unique_fliprots = [
>>>     {'rot_k': 0, 'flip_axis': None},
>>>     {'rot_k': 0, 'flip_axis': (0,)},
>>>     {'rot_k': 1, 'flip_axis': None},
>>>     {'rot_k': 1, 'flip_axis': (0,)},
>>>     {'rot_k': 2, 'flip_axis': None},
>>>     {'rot_k': 2, 'flip_axis': (0,)},
>>>     {'rot_k': 3, 'flip_axis': None},
>>>     {'rot_k': 3, 'flip_axis': (0,)},
>>> ]
>>> for params in unique_fliprots:
>>>     img_fw = fliprot(img, **params)
>>>     img_inv = inv_fliprot(img_fw, **params)
>>>     assert np.all(img == img_inv)
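The doctest above checks that inv_fliprot undoes fliprot for the eight distinct rotation/flip combinations (the symmetries of a square). A minimal NumPy sketch of such a pair follows, assuming rotations are np.rot90 quarter turns applied before the flip; the module may order the operations differently, and fliprot_sketch / inv_fliprot_sketch are hypothetical names.

```python
import numpy as np


def fliprot_sketch(img, rot_k=0, flip_axis=None, axes=(0, 1)):
    # Assumed order: rotate first (ccw quarter turns), then flip.
    if rot_k % 4:
        img = np.rot90(img, k=rot_k, axes=axes)
    if flip_axis:
        # flip_axis indexes into ``axes``: 0 is the y axis, 1 is the x axis.
        img = np.flip(img, axis=tuple(axes[i] for i in flip_axis))
    return img


def inv_fliprot_sketch(img, rot_k=0, flip_axis=None, axes=(0, 1)):
    # Undo in reverse order: flips are their own inverse, then rotate back clockwise.
    if flip_axis:
        img = np.flip(img, axis=tuple(axes[i] for i in flip_axis))
    if rot_k % 4:
        img = np.rot90(img, k=-rot_k, axes=axes)
    return img


img = np.arange(4 * 4 * 3).reshape(4, 4, 3)  # H, W, C
variants = set()
roundtrip_ok = True
for rot_k in range(4):
    for flip_axis in (None, (0,)):
        fw = fliprot_sketch(img, rot_k, flip_axis)
        roundtrip_ok &= np.array_equal(inv_fliprot_sketch(fw, rot_k, flip_axis), img)
        variants.add(fw.tobytes())
```

Undoing the operations in reverse order is what makes the inverse correct: a flip is self-inverse, and np.rot90 with k=-rot_k cancels the rotation.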
geowatch.tasks.fusion.datamodules.data_utils.fliprot_annot(annot, rot_k, flip_axis=None, axes=(0, 1), canvas_dsize=None)[source]
geowatch.tasks.fusion.datamodules.data_utils.inv_fliprot_annot(annot, rot_k, flip_axis=None, axes=(0, 1), canvas_dsize=None)[source]
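fliprot_annot and inv_fliprot_annot carry no description here. The idea they name, applying the same rotation and flip to annotation geometry, can be sketched for raw point coordinates. Assuming the rotation is applied before the flip and treating canvas_dsize as (width, height) (the usual kwimage/OpenCV convention), one counterclockwise np.rot90 step sends a point (x, y) to (y, W - x) and swaps width and height, while a flip along y or x maps y to H - y or x to W - x. fliprot_points is a hypothetical helper; the real functions take an annotation object, not an array.

```python
import numpy as np


def fliprot_points(xy, rot_k=0, flip_axis=None, canvas_dsize=None):
    """Apply rot_k ccw quarter turns, then flips, to (x, y) points (hypothetical helper)."""
    xy = np.array(xy, dtype=float)  # (N, 2) copy in continuous pixel coordinates
    w, h = canvas_dsize             # assumed (width, height)
    for _ in range(rot_k % 4):
        # One np.rot90 step: (x, y) -> (y, w - x), and the canvas becomes (h, w).
        xy = np.stack([xy[:, 1], w - xy[:, 0]], axis=1)
        w, h = h, w
    for axis in flip_axis or ():
        if axis == 0:
            xy[:, 1] = h - xy[:, 1]  # flip along the y axis
        else:
            xy[:, 0] = w - xy[:, 0]  # flip along the x axis
    return xy


# Check against pixel centers on a 5x7 image (H=5, W=7) for all 16 parameter combos.
img = np.arange(5 * 7).reshape(5, 7)
mismatches = 0
for rot_k in range(4):
    for flip_axis in (None, (0,), (1,), (0, 1)):
        out = np.rot90(img, k=rot_k)
        if flip_axis:
            out = np.flip(out, axis=flip_axis)
        for r in range(5):
            for c in range(7):
                x, y = fliprot_points([[c + 0.5, r + 0.5]], rot_k, flip_axis, (7, 5))[0]
                mismatches += int(out[int(y), int(x)] != img[r, c])
```

Inverting the point mapping follows the same pattern as inverting the image transform: undo the flips on the final canvas first, then rotate clockwise.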
geowatch.tasks.fusion.datamodules.data_utils.inv_fliprot(img, rot_k=0, flip_axis=None, axes=(0, 1))[source]

Undo a fliprot

Parameters:

img (ndarray) – H, W, C

class geowatch.tasks.fusion.datamodules.data_utils.BalancedSampleTree(sample_grid, rng=None)[source]

Bases: NiceRepr

Manages sampling from a tree of indexes. Helps balance samples over multiple criteria.

Todo

Move to its own file, possibly in a new module. This is a very general construct and would benefit from compiled-language optimizations.

Example

>>> import ubelt as ub
>>> from geowatch.tasks.fusion.datamodules.data_utils import BalancedSampleTree
>>> # Given a grid of sample locations and attribute information
>>> # (e.g., region, category).
>>> sample_grid = [
>>>     { 'region': 'region1', 'category': 'background', 'color': "blue" },
>>>     { 'region': 'region1', 'category': 'background', 'color': "purple" },
>>>     { 'region': 'region1', 'category': 'background', 'color': "blue" },
>>>     { 'region': 'region1', 'category': 'background', 'color': "red" },
>>>     { 'region': 'region1', 'category': 'background', 'color': "green" },
>>>     { 'region': 'region1', 'category': 'background', 'color': "purple" },
>>>     { 'region': 'region1', 'category': 'background', 'color': "blue" },
>>>     { 'region': 'region1', 'category': 'rare',       'color': "red" },
>>>     { 'region': 'region1', 'category': 'rare',       'color': "green" },
>>>     { 'region': 'region1', 'category': 'background', 'color': "red" },
>>>     { 'region': 'region1', 'category': 'background', 'color': "green" },
>>>     { 'region': 'region2', 'category': 'background', 'color': "blue" },
>>>     { 'region': 'region2', 'category': 'background', 'color': "purple" },
>>>     { 'region': 'region2', 'category': 'background', 'color': "red" },
>>>     { 'region': 'region2', 'category': 'background', 'color': "green" },
>>>     { 'region': 'region2', 'category': 'rare',       'color': "purple" },
>>>     { 'region': 'region2', 'category': 'rare',       'color': "blue" },
>>> ]
>>> #
>>> # First we can just create a flat uniform sampling grid
>>> # and inspect the imbalance that causes.
>>> self = BalancedSampleTree(sample_grid)
>>> print(f'self={self}')
>>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
>>> hist0 = ub.dict_hist([(g['region'], g['category']) for g in sampled])
>>> print('hist0 = {}'.format(ub.urepr(hist0, nl=1)))
>>> #
>>> # We can subdivide the indexes based on region to improve balance.
>>> self.subdivide('region')
>>> print(f'self={self}')
>>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
>>> hist1 = ub.dict_hist([(g['region'], g['category']) for g in sampled])
>>> print('hist1 = {}'.format(ub.urepr(hist1, nl=1)))
>>> #
>>> # We can further subdivide by category.
>>> self.subdivide('category')
>>> print(f'self={self}')
>>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
>>> hist2 = ub.dict_hist([(g['region'], g['category']) for g in sampled])
>>> print('hist2 = {}'.format(ub.urepr(hist2, nl=1)))
>>> #
>>> # We can further subdivide by color, with custom weights.
>>> weights = { 'red': .25, 'blue': .25, 'green': .4, 'purple': .1 }
>>> self.subdivide('color', weights=weights)
>>> print(f'self={self}')
>>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
>>> hist3 = ub.dict_hist([
>>>     (g['region'], g['category'], g['color']) for g in sampled
>>> ])
>>> print('hist3 = {}'.format(ub.urepr(hist3, nl=1)))
>>> hist3_color = ub.dict_hist([(g['color']) for g in sampled])
>>> print('color weights = {}'.format(ub.urepr(weights, nl=1)))
>>> print('hist3 (color) = {}'.format(ub.urepr(hist3_color, nl=1)))
Parameters:
  • sample_grid (List[Dict]) – List of items with properties to be sampled

  • rng (int | None | RandomState) – random number generator or seed

reseed(rng)[source]

Reseed (or unseed) the random number generator

Parameters:

rng (int | None | RandomState) – random number generator or seed

subdivide(key, weights=None, default_weight=0)[source]
Parameters:
  • key (str) – A key into the item dictionary of a sample that maps to the property to balance over.

  • weights (None | Dict[Any, Number]) – an optional mapping from the values that key can take to numeric weights.

  • default_weight (None | Number) – if an attribute is unspecified in the weight table, this is the default weight it should be given. Default is 0.

sample()[source]
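Going by the description and the doctest above, the tree groups sample indexes by one attribute per subdivide call and draws a sample by walking from the root to a leaf. The following is a minimal pure-Python sketch of that idea, not the class's actual implementation: build_tree and sample_tree are hypothetical helpers, children at each level are chosen uniformly unless a per-key weight table is given (unlisted values get default_weight, mirroring subdivide), and the grid is a simplified version of the example's.

```python
import random
from collections import Counter, defaultdict


def build_tree(indexes, items, keys):
    """Recursively group item indexes by each key in turn (hypothetical helper)."""
    if not keys:
        return list(indexes)  # leaf: indexes that share every attribute so far
    groups = defaultdict(list)
    for idx in indexes:
        groups[items[idx][keys[0]]].append(idx)
    return {value: build_tree(sub, items, keys[1:]) for value, sub in groups.items()}


def sample_tree(tree, keys, rng, weights=None, default_weight=0):
    """Walk from the root to a leaf, choosing one child per level, then pick an index."""
    node = tree
    for key in keys:
        children = sorted(node)
        table = (weights or {}).get(key)
        if table is None:
            choice = rng.choice(children)  # uniform over the values at this level
        else:
            probs = [table.get(c, default_weight) for c in children]
            choice = rng.choices(children, weights=probs, k=1)[0]
        node = node[choice]
    return rng.choice(node)  # uniform over the indexes in the leaf


items = (
    [{'region': 'region1', 'category': 'background'}] * 9
    + [{'region': 'region1', 'category': 'rare'}] * 2
    + [{'region': 'region2', 'category': 'background'}] * 4
    + [{'region': 'region2', 'category': 'rare'}] * 2
)
keys = ['region', 'category']
tree = build_tree(range(len(items)), items, keys)
rng = random.Random(0)
counts = Counter()
for _ in range(4000):
    idx = sample_tree(tree, keys, rng)
    counts[(items[idx]['region'], items[idx]['category'])] += 1
```

Each (region, category) leaf is drawn roughly a quarter of the time, even though region1/background holds 9 of the 17 items; a flat uniform draw would pick it more than half the time.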
class geowatch.tasks.fusion.datamodules.data_utils.BalancedSampleForest(sample_grid, rng=None, n_trees=16, scoring='uniform')[source]

Bases: NiceRepr

Manages sampling from a forest of BalancedSampleTrees. Helps balance samples in the multi-label case.

CommandLine

LINE_PROFILE=1 xdoctest -m geowatch.tasks.fusion.datamodules.data_utils BalancedSampleForest:1 --benchmark

Example

>>> import ubelt as ub
>>> from geowatch.tasks.fusion.datamodules.data_utils import BalancedSampleForest
>>> sample_grid = [
>>>     { 'region': 'region1', 'color': {'blue': 10, 'red': 3}},
>>>     { 'region': 'region1', 'color': {'green': 3, 'purple': 2}},
>>>     { 'region': 'region1', 'color': {'blue': 1}},
>>>     { 'region': 'region1', 'color': {'green': 3, 'red': 5}},
>>>     { 'region': 'region1', 'color': {'purple': 1, 'blue': 1}},
>>>     { 'region': 'region2', 'color': {'blue': 5, 'red': 5}},
>>>     { 'region': 'region2', 'color': {'green': 5, 'purple': 5}},
>>> ]
>>> #
>>> self = BalancedSampleForest(sample_grid)
>>> print(f'self={self}')
>>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
>>> hist0 = ub.dict_hist([g['region'] for g in sampled])
>>> print('hist0 = {}'.format(ub.urepr(hist0, nl=1)))
>>> #
>>> self.subdivide('region')
>>> print(f'self={self}')
>>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
>>> hist1 = ub.dict_hist([g['region'] for g in sampled])
>>> print('hist1 = {}'.format(ub.urepr(hist1, nl=1)))
>>> #
>>> self.subdivide('color')
>>> print(f'self={self}')
>>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
>>> hist2 = ub.dict_hist([(g['region'],) + tuple(g['color'].keys()) for g in sampled])
>>> print('hist2 = {}'.format(ub.urepr(hist2, nl=1)))

Example

>>> # xdoctest: +REQUIRES(--benchmark)
>>> import ubelt as ub
>>> from geowatch.tasks.fusion.datamodules.data_utils import BalancedSampleForest
>>> # Make a very large dataset to test speed constraints
>>> sample_grid = [
>>>     { 'region': 'region1', 'color': {'blue': 10, 'red': 3}},
>>>     { 'region': 'region1', 'color': {'green': 3, 'purple': 2}},
>>>     { 'region': 'region1', 'color': {'blue': 1}},
>>>     { 'region': 'region1', 'color': {'green': 3, 'red': 5}},
>>>     { 'region': 'region1', 'color': {'purple': 1, 'blue': 1}},
>>>     { 'region': 'region2', 'color': {'blue': 5, 'red': 5}},
>>>     { 'region': 'region2', 'color': {'green': 5, 'purple': 5}},
>>> ] * 10000
>>> #
>>> self = BalancedSampleForest(sample_grid)
>>> print(f'self={self}')
>>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
>>> hist0 = ub.dict_hist([g['region'] for g in sampled])
>>> print('hist0 = {}'.format(ub.urepr(hist0, nl=1)))
>>> #
>>> self.subdivide('region')
>>> print(f'self={self}')
>>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
>>> hist1 = ub.dict_hist([g['region'] for g in sampled])
>>> print('hist1 = {}'.format(ub.urepr(hist1, nl=1)))
>>> #
>>> self.subdivide('color')
>>> print(f'self={self}')
>>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
>>> hist2 = ub.dict_hist([(g['region'],) + tuple(g['color'].keys()) for g in sampled])
>>> print('hist2 = {}'.format(ub.urepr(hist2, nl=1)))

Todo

Currently this looks at every attribute of each item in the sample grid. I think we want to specify which attributes can be balanced over, which would help avoid a deep copy.

reseed(rng)[source]

Reseed (or unseed) the random number generator

Parameters:

rng (int | None | RandomState) – random number generator or seed

subdivide(key, weights=None, default_weight=0)[source]
sample()[source]

Uniformly sample a tree from the forest, then sample from it.
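The forest targets items that carry several labels at once, like the color dictionaries above. One way to realize the documented sample() behavior is sketched below under an assumption the docs do not state: each tree commits every multi-label item to a single label chosen uniformly at random, so different trees see different single-label views of the data. build_forest and sample_forest are hypothetical helpers; the real class may assign labels differently (for example via its scoring option).

```python
import random
from collections import Counter, defaultdict


def build_forest(items, key, n_trees, rng):
    """Build n_trees single-label groupings of multi-label items (hypothetical helper)."""
    forest = []
    for _ in range(n_trees):
        groups = defaultdict(list)
        for idx, item in enumerate(items):
            # Assumption: each tree commits the item to one of its labels, uniformly.
            label = rng.choice(sorted(item[key]))
            groups[label].append(idx)
        forest.append(dict(groups))
    return forest


def sample_forest(forest, rng):
    tree = rng.choice(forest)         # uniformly pick a tree
    label = rng.choice(sorted(tree))  # then balance over the labels it contains
    return label, rng.choice(tree[label])


items = [
    {'region': 'region1', 'color': {'blue': 10, 'red': 3}},
    {'region': 'region1', 'color': {'green': 3, 'purple': 2}},
    {'region': 'region1', 'color': {'blue': 1}},
    {'region': 'region1', 'color': {'green': 3, 'red': 5}},
    {'region': 'region1', 'color': {'purple': 1, 'blue': 1}},
    {'region': 'region2', 'color': {'blue': 5, 'red': 5}},
    {'region': 'region2', 'color': {'green': 5, 'purple': 5}},
]
rng = random.Random(0)
forest = build_forest(items, 'color', n_trees=16, rng=rng)
draws = [sample_forest(forest, rng) for _ in range(4000)]
label_counts = Counter(label for label, _ in draws)
```

Spreading items across many trees lets an item labeled both green and red contribute to either group over the course of sampling, which a single tree with a fixed assignment could not do.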

geowatch.tasks.fusion.datamodules.data_utils.samecolor_nodata_mask(stream, hwc, relevant_bands, use_regions=0, samecolor_values=None)[source]

Find a 2D mask that indicates what values should be set to nan. This is typically done by finding clusters of zeros in specific bands.

Example

>>> from geowatch.tasks.fusion.datamodules.data_utils import *  # NOQA
>>> import kwcoco
>>> import kwarray
>>> import ubelt as ub
>>> stream = kwcoco.FusedChannelSpec.coerce('foo|red|green|bar')
>>> stream_oset = ub.oset(stream)
>>> relevant_bands = ['red', 'green']
>>> relevant_band_idxs = [stream_oset.index(b) for b in relevant_bands]
>>> rng = kwarray.ensure_rng(0)
>>> hwc = (rng.rand(32, 32, stream.numel()) * 3).astype(int)
>>> use_regions = 0
>>> samecolor_values = {0}
>>> samecolor_mask = samecolor_nodata_mask(
>>>     stream, hwc, relevant_bands, use_regions=use_regions,
>>>     samecolor_values=samecolor_values)
>>> assert samecolor_mask.sum() == (hwc[..., relevant_band_idxs] == 0).any(axis=2).sum()
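The doctest's final assertion pins down the behavior for use_regions=0: a pixel is flagged when any relevant band holds one of the samecolor values. A NumPy sketch of just that case follows; the use_regions path is not modelled, and samecolor_mask_sketch is a hypothetical name that takes band indexes rather than a channel spec.

```python
import numpy as np


def samecolor_mask_sketch(hwc, relevant_band_idxs, samecolor_values=(0,)):
    """Flag pixels where any relevant band holds one of the samecolor values."""
    bands = hwc[..., list(relevant_band_idxs)]                 # (H, W, k)
    return np.isin(bands, list(samecolor_values)).any(axis=2)  # (H, W) bool


rng = np.random.default_rng(0)
hwc = rng.integers(0, 3, size=(32, 32, 4))  # bands: foo|red|green|bar
relevant_band_idxs = [1, 2]                 # red, green
mask = samecolor_mask_sketch(hwc, relevant_band_idxs, samecolor_values={0})
```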
class geowatch.tasks.fusion.datamodules.data_utils.MultiscaleMask[source]

Bases: object

A helper class to build up a mask indicating which pixels are unobservable, based on data from different resolutions.

In other words, if you have multiple masks, each at a different resolution, this iteratively upscales the masks to the largest resolution seen so far and takes their logical OR. This helps keep the memory footprint small.

Todo

Does this live in kwimage?

CommandLine

xdoctest -m geowatch.tasks.fusion.datamodules.data_utils MultiscaleMask --show

Example

>>> from geowatch.tasks.fusion.datamodules.data_utils import *  # NOQA
>>> image = kwimage.grab_test_image()
>>> image = kwimage.ensure_float01(image)
>>> rng = kwarray.ensure_rng(1)
>>> mask1 = kwimage.Mask.random(shape=(12, 12), rng=rng).data
>>> mask2 = kwimage.Mask.random(shape=(32, 32), rng=rng).data
>>> mask3 = kwimage.Mask.random(shape=(16, 16), rng=rng).data
>>> omask = MultiscaleMask()
>>> omask.update(mask1)
>>> omask.update(mask2)
>>> omask.update(mask3)
>>> masked_image = omask.apply(image, np.nan)
>>> # Now we can use our upscaled masks on an image.
>>> masked_image = kwimage.fill_nans_with_checkers(masked_image, on_value=0.3)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> inputs = kwimage.stack_images(
>>>     [kwimage.atleast_3channels(m * 255) for m in [mask1, mask2, mask3]],
>>>     pad=2, bg_value='kw_green', axis=1)
>>> kwplot.imshow(inputs, pnum=(1, 3, 1), title='input masks')
>>> kwplot.imshow(omask.mask, pnum=(1, 3, 2), title='final mask')
>>> kwplot.imshow(masked_image, pnum=(1, 3, 3), title='masked image')
>>> kwplot.show_if_requested()
update(mask)[source]

Expand the observable mask to the larger data and take the logical or of the resized masks.

apply(image, value)[source]

Set the locations in image that correspond to this mask to value.

property masked_fraction
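A minimal NumPy sketch of the update/apply idea described above. It assumes nearest-neighbor upscaling and a target shape equal to the elementwise maximum of the shapes seen so far; both are assumptions, and the real class may resize differently. MultiscaleMaskSketch and resize_nearest are hypothetical names, and masked_fraction here is simply the mean of the accumulated mask.

```python
import numpy as np


def resize_nearest(mask, shape):
    """Nearest-neighbor resize of a 2D array to ``shape`` via index lookup."""
    h, w = mask.shape
    rows = np.arange(shape[0]) * h // shape[0]
    cols = np.arange(shape[1]) * w // shape[1]
    return mask[rows[:, None], cols[None, :]]


class MultiscaleMaskSketch:
    def __init__(self):
        self.mask = None  # accumulated boolean mask at the largest resolution so far

    def update(self, mask):
        mask = np.asarray(mask, dtype=bool)
        if self.mask is None:
            self.mask = mask.copy()
        else:
            shape = (max(self.mask.shape[0], mask.shape[0]),
                     max(self.mask.shape[1], mask.shape[1]))
            # Upscale both masks to the common shape, then combine with logical OR.
            self.mask = resize_nearest(self.mask, shape) | resize_nearest(mask, shape)
        return self

    def apply(self, image, value):
        # Resize the mask to the image's spatial shape and overwrite masked pixels.
        full = resize_nearest(self.mask, image.shape[:2])
        out = np.array(image, dtype=float)  # copy so the input is untouched
        out[full] = value
        return out

    @property
    def masked_fraction(self):
        return float(self.mask.mean())


mask1 = np.zeros((2, 2), dtype=bool)
mask1[0, 0] = True
mask2 = np.zeros((4, 4), dtype=bool)
mask2[3, 3] = True
omask = MultiscaleMaskSketch()
omask.update(mask1)
omask.update(mask2)
masked = omask.apply(np.ones((8, 8)), np.nan)
```

Only one mask at the current largest resolution is kept, which is where the small memory footprint comes from: earlier masks are folded in and discarded.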