"""
I don't like the name of this file and want to rename it, but for now it
exists to keep the size of the datamodule down.
TODO:
- [ ] Break BalancedSampleTree and BalancedSampleForest into their own balanced sampling module.
- [ ] Make a good augmentation module
- [ ] Determine where MultiscaleMask should live.
"""
import numpy as np
import ubelt as ub
import kwimage
import kwarray
import networkx as nx
try:
from line_profiler import profile
except Exception:
profile = ub.identity
[docs]
def resolve_scale_request(request=None, data_gsd=None):
    """
    Helper for handling user and machine specified spatial scale requests

    Args:
        request (None | float | str):
            Indicate a relative or absolute requested scale. If given as a
            float, this is interpreted as a scale factor relative to the
            underlying data. If given as a string, it will accept the format
            "{:f} *GSD" and resolve to an absolute GSD, the keyword "native",
            or a numeric string that is parsed as a relative scale factor.
            String keywords are matched case-insensitively and surrounding
            whitespace is ignored. Defaults to 1.0.

        data_gsd (None | float):
            if specified, this indicates the GSD of the underlying data.
            (Only valid for geospatial data). TODO: is there a better
            generalization?

    Returns:
        Dict[str, Any] : resolved : containing keys
            scale (float | str): the scale factor (relative to the data) that
                realizes the request, or "native" for a native request.
            gsd (float | str | None): if data_gsd is given, this is the
                absolute GSD of the request ("native" for a native request,
                None when it cannot be determined).
            data_gsd (float | None): the input ``data_gsd``, passed through.

    Raises:
        ValueError: if the request is given in terms of GSD but ``data_gsd``
            is unspecified, or if a string request cannot be parsed.

    Note:
        The returned scale is relative to the DATA. If you are resizing a
        sampled image, then use it directly, but if you are adjusting a sample
        WINDOW, then it needs to be used inversely.

    Example:
        >>> from geowatch.tasks.fusion.datamodules.data_utils import *  # NOQA
        >>> resolve_scale_request(1.0)
        >>> resolve_scale_request('native')
        >>> resolve_scale_request('10 GSD', data_gsd=10)
        >>> resolve_scale_request('20 GSD', data_gsd=10)

    Example:
        >>> from geowatch.tasks.fusion.datamodules.data_utils import *  # NOQA
        >>> import ubelt as ub
        >>> grid = list(ub.named_product({
        >>>     'request': ['10GSD', '30GSD'],
        >>>     'data_gsd': [10, 30],
        >>> }))
        >>> grid += list(ub.named_product({
        >>>     'request': [None, 1.0, 2.0, 0.25, 'native'],
        >>>     'data_gsd': [None, 10, 30],
        >>> }))
        >>> for kwargs in grid:
        >>>     print('kwargs = {}'.format(ub.urepr(kwargs, nl=0)))
        >>>     resolved = resolve_scale_request(**kwargs)
        >>>     print('resolved = {}'.format(ub.urepr(resolved, nl=0)))
        >>>     print('---')
    """
    # FIXME: rectify with util_resolution
    final_gsd = None
    final_scale = None
    if request is None:
        final_scale = 1.0
        final_gsd = data_gsd
    elif isinstance(request, str):
        # Match keywords case-insensitively (consistent with the
        # case-insensitive "GSD" suffix) and ignore surrounding whitespace.
        text = request.strip()
        lowered = text.lower()
        if lowered == 'native':
            final_gsd = 'native'
            final_scale = 'native'
        elif lowered.endswith('gsd'):
            if data_gsd is None:
                raise ValueError(
                    'The request was given in terms of GSD, but '
                    'the underlying data GSD was unspecified')
            final_gsd = float(text[:-3].strip())
            final_scale = data_gsd / final_gsd
        else:
            # A bare numeric string is a relative scale factor
            final_scale = float(text)
    else:
        final_scale = float(request)
    if final_gsd is None:
        if data_gsd is not None:
            final_gsd = np.array(data_gsd) / final_scale
    resolved = {
        'scale': final_scale,
        'gsd': final_gsd,
        'data_gsd': data_gsd,
    }
    return resolved
[docs]
def abslog_scaling(arr):
    """
    Compress the dynamic range of values with a signed log transform.

    Computes ``sign(arr) * log(|arr| + 1)``. NaN inputs map to 0.

    Args:
        arr (ndarray | float): input values (array or scalar).

    Returns:
        ndarray | float: the transformed values, same shape as the input.
    """
    # NaN has no sign; treat it as 0 so NaN inputs produce 0 in the output
    orig_sign = np.nan_to_num(np.sign(arr))
    shifted = np.log(np.abs(arr) + 1)
    is_nan = np.isnan(shifted)
    if np.ndim(shifted) == 0:
        # numpy scalars do not support item assignment
        if is_nan:
            shifted = 0.1
    else:
        shifted[is_nan] = 0.1
    return orig_sign * shifted
[docs]
def fliprot(img, rot_k=0, flip_axis=None, axes=(0, 1)):
    """
    Rotate and then flip an image-like array.

    Args:
        img (ndarray): H, W, C

        rot_k (int): number of ccw rotations

        flip_axis (None | int | Tuple[int, ...]):
            either None, [], [0], [1], or [0, 1] (a bare int is also
            accepted). 0 is the y axis and 1 is the x axis.

        axes (Tuple[int, int]): the location of the y and x axes

    Returns:
        ndarray: the transformed array (possibly a view of ``img``)

    Example:
        >>> img = np.arange(16).reshape(4, 4)
        >>> unique_fliprots = [
        >>>     {'rot_k': 0, 'flip_axis': None},
        >>>     {'rot_k': 0, 'flip_axis': (0,)},
        >>>     {'rot_k': 1, 'flip_axis': None},
        >>>     {'rot_k': 1, 'flip_axis': (0,)},
        >>>     {'rot_k': 2, 'flip_axis': None},
        >>>     {'rot_k': 2, 'flip_axis': (0,)},
        >>>     {'rot_k': 3, 'flip_axis': None},
        >>>     {'rot_k': 3, 'flip_axis': (0,)},
        >>>     {'rot_k': 1, 'flip_axis': (0, 1)},
        >>> ]
        >>> for params in unique_fliprots:
        >>>     img_fw = fliprot(img, **params)
        >>>     img_inv = inv_fliprot(img_fw, **params)
        >>>     assert np.all(img == img_inv)
    """
    if rot_k != 0:
        img = np.rot90(img, k=rot_k, axes=axes)
    if flip_axis is not None:
        # Map the logical flip axes (0=y, 1=x) onto the array dimensions given
        # by ``axes``. Convert to a list first: indexing an array with a tuple
        # such as (0, 1) is a multi-dimensional index and would fail.
        flip_idxs = [flip_axis] if np.ndim(flip_axis) == 0 else list(flip_axis)
        _flip_axis = tuple(int(axes[i]) for i in flip_idxs)
        img = np.flip(img, axis=_flip_axis)
    return img
[docs]
def fliprot_annot(annot, rot_k, flip_axis=None, axes=(0, 1), canvas_dsize=None):
    """
    Apply a rotate-then-flip transform to an annotation, intended to mirror
    what :func:`fliprot` does to the pixels of the corresponding image.

    The annotation is first rotated by ``rot_k`` quarter turns about the
    center of the canvas and re-centered in the rotated canvas. It is then
    flipped about the center of the resulting canvas.

    Args:
        annot (Any): an object with a ``warp`` method that accepts a
            ``kwimage.Affine`` and returns the warped object (e.g. kwimage
            Boxes or a Polygon, as in the example below).

        rot_k (int): number of quarter-turn rotations. The rotation angle
            passed to ``kwimage.Affine.rotate`` is ``-(rot_k * 2 * pi / 4)``.

        flip_axis (None | Iterable[int]): logical axes to flip about the
            canvas center. 0 negates the y coordinate, and 1 negates the x
            coordinate.

        axes (Tuple[int, int]): not used by this function. It is accepted for
            signature parity with :func:`fliprot`.

        canvas_dsize (Tuple[int, int]): the (width, height) of the canvas the
            annotation lives in. It must be given whenever ``rot_k != 0`` or
            ``flip_axis`` is not None, because it is indexed in those cases.

    Returns:
        Any: the warped annotation.

    Ignore:
        >>> from geowatch.tasks.fusion.datamodules.data_utils import *  # NOQA
        >>> import kwimage
        >>> H, W = 121, 153
        >>> canvas_dsize = (W, H)
        >>> box1 = kwimage.Boxes.random(1).scale((W, H)).quantize()
        >>> ltrb = box1.data
        >>> rot_k = 4
        >>> annot = box1
        >>> annot = box1.to_polygons()[0]
        >>> annot1 = annot.copy()
        >>> unique_fliprots = [
        >>>     {'rot_k': 0, 'flip_axis': None},
        >>>     {'rot_k': 0, 'flip_axis': (0,)},
        >>>     {'rot_k': 1, 'flip_axis': None},
        >>>     {'rot_k': 1, 'flip_axis': (0,)},
        >>>     {'rot_k': 2, 'flip_axis': None},
        >>>     {'rot_k': 2, 'flip_axis': (0,)},
        >>>     {'rot_k': 3, 'flip_axis': None},
        >>>     {'rot_k': 3, 'flip_axis': (0,)},
        >>> ]
        >>> results = []
        >>> for params in unique_fliprots:
        >>>     annot2 = fliprot_annot(annot, canvas_dsize=canvas_dsize, **params)
        >>>     annot3 = inv_fliprot_annot(annot2, canvas_dsize=canvas_dsize, **params)
        >>>     results.append({
        >>>         'annot2': annot2,
        >>>         'annot3': annot3,
        >>>         'params': params,
        >>>     })
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> kwplot.autompl()
        >>> image1 = kwimage.grab_test_image('astro', dsize=(W, H))
        >>> pnum_ = kwplot.PlotNums(nSubplots=len(results))
        >>> for result in results:
        >>>     image2 = fliprot(image1.copy(), **result['params'])
        >>>     image3 = inv_fliprot(image2.copy(), **result['params'])
        >>>     annot2 = result['annot2']
        >>>     annot3 = result['annot3']
        >>>     canvas1 = annot1.draw_on(image1.copy(), edgecolor='kitware_green', fill=False)
        >>>     canvas2 = annot2.draw_on(image2.copy(), edgecolor='kitware_blue', fill=False)
        >>>     canvas3 = annot3.draw_on(image3.copy(), edgecolor='kitware_red', fill=False)
        >>>     canvas = kwimage.stack_images([canvas1, canvas2, canvas3], axis=1)
        >>>     kwplot.imshow(canvas, pnum=pnum_(), title=ub.urepr(result['params'], nl=0, compact=1, nobr=1))
    """
    # TODO: can use the new `Affine.fliprot` when 0.9.22 releases
    import kwimage
    if rot_k != 0:
        # Center of the original canvas
        x0 = canvas_dsize[0] / 2
        y0 = canvas_dsize[1] / 2
        # generalized way
        # Translate center of old canvas to the origin
        T1 = kwimage.Affine.translate((-x0, -y0))
        # Construct the rotation (negative angle, one quarter turn per rot_k)
        tau = np.pi * 2
        theta = -(rot_k * tau / 4)
        R = kwimage.Affine.rotate(theta=theta)
        # Find the center of the new rotated canvas
        canvas_box = kwimage.Box.from_dsize(canvas_dsize)
        new_canvas_box = canvas_box.warp(R)
        x2 = new_canvas_box.width / 2
        y2 = new_canvas_box.height / 2
        # Translate to the center of the new canvas
        T2 = kwimage.Affine.translate((x2, y2))
        # print(f'T1=\n{ub.urepr(T1)}')
        # print(f'R=\n{ub.urepr(R)}')
        # print(f'T2=\n{ub.urepr(T2)}')
        # Composite: recenter at origin, rotate, move to the new canvas center
        A = T2 @ R @ T1
        annot = annot.warp(A)
        # TODO: specialized faster way
        # lt_x, lt_y, rb_x, rb_y = boxes.components
    else:
        # No rotation: the flip center (if needed) is computed below
        x2 = y2 = None
        # boxes = kwimage.Boxes(ltrb, 'ltrb')
    if flip_axis is not None:
        if x2 is None:
            # Flip about the center of the (unrotated) canvas
            x2 = canvas_dsize[0] / 2
            y2 = canvas_dsize[1] / 2
        # Make the flip matrix
        F = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        for axis in flip_axis:
            # Logical axis 0 (y) negates matrix row 1; axis 1 (x) negates row 0
            mdim = 1 - axis
            F[mdim, mdim] *= -1
        # Flip about the canvas center: move center to origin, flip, move back
        T1 = kwimage.Affine.translate((-x2, -y2))
        T2 = kwimage.Affine.translate((x2, y2))
        A = T2 @ F @ T1
        annot = annot.warp(A)
    return annot
[docs]
def inv_fliprot_annot(annot, rot_k, flip_axis=None, axes=(0, 1), canvas_dsize=None):
    """
    Undo :func:`fliprot_annot`.

    The forward transform rotates first and flips second. The inverse must
    therefore undo the flip first, about the center of the *rotated* canvas
    where the transformed annotation lives, and only then undo the rotation.
    Undoing them in the other order is incorrect whenever ``rot_k`` is odd and
    a single axis is flipped. A reflection conjugates a rotation into its
    inverse, so that order leaves a residual 180 degree rotation.

    Args:
        annot (Any): an annotation produced by :func:`fliprot_annot`.

        rot_k (int): the ``rot_k`` passed to the forward transform.

        flip_axis (None | Iterable[int]): the ``flip_axis`` passed to the
            forward transform.

        axes (Tuple[int, int]): forwarded to :func:`fliprot_annot`.

        canvas_dsize (Tuple[int, int]): the (width, height) of the ORIGINAL
            canvas, i.e. the same value given to the forward transform.

    Returns:
        Any: the annotation mapped back into the original canvas.
    """
    # After an odd number of quarter turns the canvas width and height swap,
    # so the forward-transformed annotation lives in the swapped canvas.
    if canvas_dsize is not None and rot_k % 2 == 1:
        rotated_dsize = canvas_dsize[::-1]
    else:
        rotated_dsize = canvas_dsize
    # Undo the flip (which the forward transform applied last) first.
    annot = fliprot_annot(annot, 0, flip_axis=flip_axis, axes=axes,
                          canvas_dsize=rotated_dsize)
    # Then undo the rotation, starting from the rotated canvas.
    annot = fliprot_annot(annot, -rot_k, flip_axis=None, axes=axes,
                          canvas_dsize=rotated_dsize)
    return annot
[docs]
def inv_fliprot(img, rot_k=0, flip_axis=None, axes=(0, 1)):
    """
    Undo a fliprot

    The forward transform rotates then flips, so this flips first and then
    applies the opposite rotation.

    Args:
        img (ndarray): H, W, C

        rot_k (int): the number of ccw rotations used in the forward call

        flip_axis (None | int | Tuple[int, ...]):
            the flip axes used in the forward call (0 is y, 1 is x)

        axes (Tuple[int, int]): the location of the y and x axes

    Returns:
        ndarray: the restored array (possibly a view of ``img``)
    """
    if flip_axis is not None:
        # Map logical flip axes onto array dimensions. Convert to a list
        # first: indexing an array with a tuple such as (0, 1) is a
        # multi-dimensional index and would fail.
        flip_idxs = [flip_axis] if np.ndim(flip_axis) == 0 else list(flip_axis)
        _flip_axis = tuple(int(axes[i]) for i in flip_idxs)
        img = np.flip(img, axis=_flip_axis)
    if rot_k != 0:
        img = np.rot90(img, k=-rot_k, axes=axes)
    return img
@ub.memoize
def _string_to_hashvec(key):
    """
    Map ``key`` to a deterministic float32 unit vector derived from its
    blake3 hash. Results are memoized.

    The hex digest of the hash is encoded to bytes and those bytes are
    reinterpreted as int32 values, so a 64-character digest yields 16 values.
    Because the reinterpreted bytes are the ASCII characters of the hex
    digest, the components are NOT uniformly distributed random values. The
    vector is only normalized to unit length.

    Note there are magic numbers hard-coded in this function, and is the
    reason for the blake3 dependency. Would likely be better to make it
    configurable and use sha256 as the default.
    """
    hex_digest = ub.hash_data(key, base=16, hasher='blake3')
    raw_bytes = hex_digest.encode()
    vec = np.frombuffer(raw_bytes, dtype=np.int32).astype(np.float32)
    return vec / np.linalg.norm(vec)
def _boxes_snap_to_edges(given_box, snap_target):
    """
    Translate ``given_box`` so that its edges are pushed back inside
    ``snap_target`` wherever they stick out.

    Only the first box's offset is used (via ``.ravel()[0]``), so this is
    intended for a single box.

    Args:
        given_box (kwimage.Boxes): the box to move
        snap_target (kwimage.Boxes): the region to snap into

    Returns:
        kwimage.Boxes: the translated box

    Ignore:
        >>> from geowatch.tasks.fusion.datamodules.data_utils import *  # NOQA
        >>> import kwimage
        >>> from geowatch.tasks.fusion.datamodules.data_utils import _boxes_snap_to_edges
        >>> snap_target = kwimage.Boxes([[0, 0, 10, 10]], 'ltrb')
        >>> given_box = kwimage.Boxes([[-3, 5, 3, 13]], 'ltrb')
        >>> adjusted_box = _boxes_snap_to_edges(given_box, snap_target)
        >>> print('adjusted_box = {!r}'.format(adjusted_box))
        >>> _boxes_snap_to_edges(kwimage.Boxes([[-3, 3, 20, 13]], 'ltrb'), snap_target)
        >>> _boxes_snap_to_edges(kwimage.Boxes([[-3, -3, 3, 3]], 'ltrb'), snap_target)
        >>> _boxes_snap_to_edges(kwimage.Boxes([[7, 7, 15, 15]], 'ltrb'), snap_target)
    """
    snap_l, snap_t, snap_r, snap_b = snap_target.components
    box_l, box_t, box_r, box_b = given_box.components
    # Positive shift when the box pokes out past the left / top edge
    push_x = np.maximum(snap_l - box_l, 0)
    push_y = np.maximum(snap_t - box_t, 0)
    # Negative shift when the box pokes out past the right / bottom edge
    pull_x = np.minimum(snap_r - box_r, 0)
    pull_y = np.minimum(snap_b - box_b, 0)
    dx = (push_x + pull_x).ravel()[0]
    dy = (push_y + pull_y).ravel()[0]
    return given_box.translate((dx, dy))
[docs]
class BalancedSampleTree(ub.NiceRepr):
    """
    Manages a sampling from a tree of indexes. Helps with balancing
    samples over multiple criteria.

    The tree starts as a root node (``'__root__'``) whose children are the
    indexes of ``sample_grid``. Each call to :func:`subdivide` inserts a new
    level of intermediate nodes (one per distinct attribute value) directly
    above the leaves. :func:`sample` walks from the root to a leaf, picking
    children uniformly or according to per-node weights, and returns the
    leaf index.

    TODO:
        Move to its own file - possibly a new module. This is a very general
        construct, and would benefit from binary-language optimizations.

    Example:
        >>> from geowatch.tasks.fusion.datamodules.data_utils import BalancedSampleTree
        >>> # Given a grid of sample locations and attribute information
        >>> # (e.g., region, category).
        >>> sample_grid = [
        >>>     { 'region': 'region1', 'category': 'background', 'color': "blue" },
        >>>     { 'region': 'region1', 'category': 'background', 'color': "purple" },
        >>>     { 'region': 'region1', 'category': 'background', 'color': "blue" },
        >>>     { 'region': 'region1', 'category': 'background', 'color': "red" },
        >>>     { 'region': 'region1', 'category': 'background', 'color': "green" },
        >>>     { 'region': 'region1', 'category': 'background', 'color': "purple" },
        >>>     { 'region': 'region1', 'category': 'background', 'color': "blue" },
        >>>     { 'region': 'region1', 'category': 'rare',       'color': "red" },
        >>>     { 'region': 'region1', 'category': 'rare',       'color': "green" },
        >>>     { 'region': 'region1', 'category': 'background', 'color': "red" },
        >>>     { 'region': 'region1', 'category': 'background', 'color': "green" },
        >>>     { 'region': 'region2', 'category': 'background', 'color': "blue" },
        >>>     { 'region': 'region2', 'category': 'background', 'color': "purple" },
        >>>     { 'region': 'region2', 'category': 'background', 'color': "red" },
        >>>     { 'region': 'region2', 'category': 'background', 'color': "green" },
        >>>     { 'region': 'region2', 'category': 'rare',       'color': "purple" },
        >>>     { 'region': 'region2', 'category': 'rare',       'color': "blue" },
        >>> ]
        >>> #
        >>> # First we can just create a flat uniform sampling grid
        >>> # and inspect the imbalance that causes.
        >>> self = BalancedSampleTree(sample_grid)
        >>> print(f'self={self}')
        >>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
        >>> hist0 = ub.dict_hist([(g['region'], g['category']) for g in sampled])
        >>> print('hist0 = {}'.format(ub.urepr(hist0, nl=1)))
        >>> #
        >>> # We can subdivide the indexes based on region to improve balance.
        >>> self.subdivide('region')
        >>> print(f'self={self}')
        >>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
        >>> hist1 = ub.dict_hist([(g['region'], g['category']) for g in sampled])
        >>> print('hist1 = {}'.format(ub.urepr(hist1, nl=1)))
        >>> #
        >>> # We can further subdivide by category.
        >>> self.subdivide('category')
        >>> print(f'self={self}')
        >>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
        >>> hist2 = ub.dict_hist([(g['region'], g['category']) for g in sampled])
        >>> print('hist2 = {}'.format(ub.urepr(hist2, nl=1)))
        >>> #
        >>> # We can further subdivide by color, with custom weights.
        >>> weights = { 'red': .25, 'blue': .25, 'green': .4, 'purple': .1 }
        >>> self.subdivide('color', weights=weights)
        >>> print(f'self={self}')
        >>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
        >>> hist3 = ub.dict_hist([
        >>>     (g['region'], g['category'], g['color']) for g in sampled
        >>> ])
        >>> print('hist3 = {}'.format(ub.urepr(hist3, nl=1)))
        >>> hist3_color = ub.dict_hist([(g['color']) for g in sampled])
        >>> print('color weights = {}'.format(ub.urepr(weights, nl=1)))
        >>> print('hist3 (color) = {}'.format(ub.urepr(hist3_color, nl=1)))
    """

    @profile
    def __init__(self, sample_grid, rng=None):
        """
        Args:
            sample_grid (List[Dict]):
                List of items with properties to be sampled

            rng (int | None | RandomState):
                random number generator or seed

        Raises:
            TypeError: if ``sample_grid`` is not a list, or its first item
                is not a dict.
            ValueError: if ``sample_grid`` is empty.
        """
        self.rng = kwarray.ensure_rng(rng)
        # validate input (only the first item is type-checked, for speed)
        if not isinstance(sample_grid, list):
            raise TypeError(ub.paragraph(
                f"""
                BalancedSampleTree only accepts List[Dict], but outer type
                was {type(sample_grid)}.
                """))
        if not sample_grid:
            raise ValueError('Input sample_grid is empty')
        if not isinstance(sample_grid[0], dict):
            raise TypeError(ub.paragraph(
                f"""
                BalancedSampleTree only accepts List[Dict], but inner type
                was {type(sample_grid[0])}.
                """))
        self.graph = self._create_graph(sample_grid)
        self._leaf_nodes = [n for n in self.graph.nodes if self.graph.out_degree[n] == 0]

    def reseed(self, rng):
        """
        Reseed (or unseed) the random number generator

        Args:
            rng (int | None | RandomState):
                random number generator or seed
        """
        self.rng = kwarray.ensure_rng(rng)

    @profile
    def _create_graph(self, sample_grid):
        """
        Build the initial two-level tree: a root node with one leaf child per
        item of ``sample_grid``. Each leaf is keyed by its index and carries
        the item's properties as node attributes.
        """
        graph = nx.DiGraph()
        # make a special root node
        root_node = '__root__'
        graph.add_node(root_node, weights=None)
        for index, item in enumerate(sample_grid):
            # Using urepr in the critial loop is too slow for large sample
            # grids
            # maybe we add an option to enable this for debugging / demo?
            # label = f'{index:02d} ' + ub.urepr(item, nl=0, compact=1, nobr=1)
            label = f'{index:02d}'
            graph.add_node(index, label=label, **item)
            graph.add_edge(root_node, index)
        return graph

    @profile
    def _get_parent(self, n):
        """
        Get the parent of a node (assumes a tree). Returns None if the node
        has no parent (i.e. it is the root).
        """
        # This function is called a lot, so sanity checks (e.g. that there is
        # exactly one predecessor) are intentionally skipped.
        return next(iter(self.graph.pred[n]), None)

    @profile
    def _reweight(self, node, idx_child):
        """
        Drop the weight of the ``idx_child``-th child of ``node`` (if the node
        is weighted) and renormalize the remaining weights.
        """
        _weights = self.graph.nodes[node]['weights']
        if _weights is not None:
            # remove weight for this child
            _weights = np.delete(_weights, idx_child)
            # reweight
            if _weights.sum() != 0:
                _weights = _weights / _weights.sum()
            else:
                # NOTE(review): all remaining children have zero weight. The
                # length-1 zero vector stored here is not a valid probability
                # vector for ``sample``; such a node should probably be pruned.
                _weights = np.zeros(1)
            self.graph.nodes[node]['weights'] = _weights

    @profile
    def _prune_and_reweight(self, nodes):
        """
        Remove scheduled branches and fix up ancestor weights.

        Args:
            nodes (List[Tuple[Any, List]]): pairs of (parent, orphans) to
                remove. Ancestors left without children are removed too, and
                every surviving ancestor has the removed child's weight
                dropped.

        Raises:
            ValueError: if no leaf nodes remain afterwards.
        """
        for parent, orphans in nodes:
            grandpa = self._get_parent(parent)
            if grandpa is None:
                # already removed this branch (or ``parent`` is the root)
                self.graph.remove_nodes_from([parent] + orphans)
                continue
            # get parent index from grandpa, remove nodes
            idx_parent = list(self.graph.successors(grandpa)).index(parent)
            self.graph.remove_nodes_from([parent] + orphans)
            # update weights of grandpa, walking up the tree
            queue = [(grandpa, idx_parent)]
            while queue:
                curr_grandpa, curr_idx_parent = queue.pop()
                num_children = len(list(self.graph.successors(curr_grandpa)))
                if num_children >= 1:
                    self._reweight(curr_grandpa, curr_idx_parent)
                else:
                    # removed only child, remove the grandparent
                    _parent = curr_grandpa
                    _grandpa = self._get_parent(curr_grandpa)
                    if _grandpa is not None:
                        _idx_parent = list(self.graph.successors(_grandpa)).index(_parent)
                        queue.append((_grandpa, _idx_parent))
                    self.graph.remove_node(curr_grandpa)
        # update leaf nodes
        self._leaf_nodes = [n for n in self._leaf_nodes if self.graph.has_node(n)]
        if len(self._leaf_nodes) == 0:
            raise ValueError("Leaf nodes became empty.")

    @profile
    def subdivide(self, key, weights=None, default_weight=0):
        """
        Args:
            key (str):
                A key into the item dictionary of a sample that maps to the
                property to balance over.

            weights (None | Dict[Any, Number]):
                an optional mapping from values that ``key`` could point to
                to a numeric weight.

            default_weight (None | Number):
                if an attribute is unspecified in the weight table, this is
                the default weight it should be given. Default is 0.

        Raises:
            ValueError: if pruning zero-weighted groups removes every leaf.
        """
        remove_nodes = []
        remove_edges = []
        add_edges = []
        add_nodes = []
        # Group all leaf nodes by their direct parents
        # It is possible that we could optimize this with a column-based data
        # structure, but this current structure is far more general and easier
        # to read.
        parent_to_leafs = ub.group_items(self._leaf_nodes, key=self._get_parent)
        for parent, children in parent_to_leafs.items():
            # Group children by the new attribute
            val_to_subgroup = ub.group_items(children, lambda n: self.graph.nodes[n][key])
            # try:
            #     val_to_subgroup = ub.odict(sorted(val_to_subgroup.items()))
            # except TypeError:
            #     val_to_subgroup = ub.odict(sorted(val_to_subgroup.items(), key=str))
            # Add weights to the prior parent. The weights are ordered like
            # ``val_to_subgroup``, which is also the order the new child
            # nodes are connected to ``parent`` below; ``sample`` pairs the
            # weights with ``graph.successors`` in that insertion order.
            if weights is not None:
                weights_group = np.asarray(list(ub.take(weights, val_to_subgroup.keys(), default=default_weight)))
                denom = weights_group.sum()
                if denom == 0:
                    # All options have zero weight, schedule group for pruning
                    remove_nodes.append((parent, children))
                    continue
                self.graph.nodes[parent]['weights'] = weights_group / denom
            else:
                self.graph.nodes[parent]["weights"] = None
            # Create a node for each child
            for value, subgroup in val_to_subgroup.items():
                # Use a dotted name to make unambiguous tree splits
                new_parent = f'{parent}.{key}={value}'
                # Mark edges to add / remove to implement the split
                remove_edges.extend((parent, n) for n in subgroup)
                add_edges.append((parent, new_parent))
                add_edges.extend((new_parent, n) for n in subgroup)
                add_nodes.append(new_parent)
        # Modify the graph
        self.graph.remove_edges_from(remove_edges)
        self.graph.add_nodes_from(add_nodes, weights=None)
        self.graph.add_edges_from(add_edges)
        self._prune_and_reweight(remove_nodes)

    @profile
    def _sample_many(self, num):
        """ Yield ``num`` independent samples. """
        for _ in range(num):
            idx = self.sample()
            yield idx

    @profile
    def sample(self):
        """
        Walk from the root to a leaf and return that leaf (a ``sample_grid``
        index). At each node a child is chosen uniformly when the node has no
        weights, otherwise according to the node's weights.
        """
        current = '__root__'
        while self.graph.out_degree(current) > 0:
            children = list(self.graph.successors(current))
            num = len(children)
            weights = self.graph.nodes[current]['weights']
            if weights is None:
                idx = self.rng.randint(0, num)
            else:
                idx = self.rng.choice(num, 1, p=weights)[0]
            current = children[idx]
        return current

    @profile
    def __len__(self):
        return len(self._leaf_nodes)

    @profile
    def __nice__(self):
        n_nodes = self.graph.number_of_nodes()
        n_edges = self.graph.number_of_edges()
        n_leafs = self.__len__()
        n_depth = len(nx.algorithms.dag.dag_longest_path(self.graph))
        return f'nodes={n_nodes}, edges={n_edges}, leafs={n_leafs}, depth={n_depth}'
[docs]
class BalancedSampleForest(ub.NiceRepr):
    """
    Manages a sampling from a forest of BalancedSampleTree's. Helps with balancing
    samples in the multi-label case.

    CommandLine:
        LINE_PROFILE=1 xdoctest -m geowatch.tasks.fusion.datamodules.data_utils BalancedSampleForest:1 --benchmark

    Example:
        >>> from geowatch.tasks.fusion.datamodules.data_utils import BalancedSampleForest
        >>> sample_grid = [
        >>>     { 'region': 'region1', 'color': {'blue': 10, 'red': 3}},
        >>>     { 'region': 'region1', 'color': {'green': 3, 'purple': 2}},
        >>>     { 'region': 'region1', 'color': {'blue': 1}},
        >>>     { 'region': 'region1', 'color': {'green': 3, 'red': 5}},
        >>>     { 'region': 'region1', 'color': {'purple': 1, 'blue': 1}},
        >>>     { 'region': 'region2', 'color': {'blue': 5, 'red': 5}},
        >>>     { 'region': 'region2', 'color': {'green': 5, 'purple': 5}},
        >>> ]
        >>> #
        >>> self = BalancedSampleForest(sample_grid)
        >>> print(f'self={self}')
        >>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
        >>> hist0 = ub.dict_hist([g['region'] for g in sampled])
        >>> print('hist0 = {}'.format(ub.urepr(hist0, nl=1)))
        >>> #
        >>> self.subdivide('region')
        >>> print(f'self={self}')
        >>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
        >>> hist1 = ub.dict_hist([g['region'] for g in sampled])
        >>> print('hist1 = {}'.format(ub.urepr(hist1, nl=1)))
        >>> #
        >>> self.subdivide('color')
        >>> print(f'self={self}')
        >>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
        >>> hist2 = ub.dict_hist([(g['region'],) + tuple(g['color'].keys()) for g in sampled])
        >>> print('hist2 = {}'.format(ub.urepr(hist2, nl=1)))

    Example:
        >>> # xdoctest: +REQUIRES(--benchmark)
        >>> from geowatch.tasks.fusion.datamodules.data_utils import BalancedSampleForest
        >>> # Make a very large dataset to test speed constraints
        >>> sample_grid = [
        >>>     { 'region': 'region1', 'color': {'blue': 10, 'red': 3}},
        >>>     { 'region': 'region1', 'color': {'green': 3, 'purple': 2}},
        >>>     { 'region': 'region1', 'color': {'blue': 1}},
        >>>     { 'region': 'region1', 'color': {'green': 3, 'red': 5}},
        >>>     { 'region': 'region1', 'color': {'purple': 1, 'blue': 1}},
        >>>     { 'region': 'region2', 'color': {'blue': 5, 'red': 5}},
        >>>     { 'region': 'region2', 'color': {'green': 5, 'purple': 5}},
        >>> ] * 10000
        >>> #
        >>> self = BalancedSampleForest(sample_grid)
        >>> print(f'self={self}')
        >>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
        >>> hist0 = ub.dict_hist([g['region'] for g in sampled])
        >>> print('hist0 = {}'.format(ub.urepr(hist0, nl=1)))
        >>> #
        >>> self.subdivide('region')
        >>> print(f'self={self}')
        >>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
        >>> hist1 = ub.dict_hist([g['region'] for g in sampled])
        >>> print('hist1 = {}'.format(ub.urepr(hist1, nl=1)))
        >>> #
        >>> self.subdivide('color')
        >>> print(f'self={self}')
        >>> sampled = list(ub.take(sample_grid, self._sample_many(100)))
        >>> hist2 = ub.dict_hist([(g['region'],) + tuple(g['color'].keys()) for g in sampled])
        >>> print('hist2 = {}'.format(ub.urepr(hist2, nl=1)))

    TODO:
        Currently this will look at all attributes passed in each item in the
        sample grid. I think we want to specify what the attributes that could
        be balanced over are.
    """

    @profile
    def __init__(self, sample_grid, rng=None, n_trees=16, scoring='uniform'):
        """
        Args:
            sample_grid (List[Dict]):
                List of items with properties to be sampled. Multi-label
                properties are dictionaries mapping each possible value to a
                frequency.

            rng (int | None | RandomState):
                random number generator or seed

            n_trees (int):
                number of trees, each with its own random hard assignment of
                the multi-label properties.

            scoring (str):
                how a multi-label property is resolved to one value:
                'uniform' or 'inverse' (see :func:`_create_forest`).
        """
        super().__init__()
        self.rng = kwarray.ensure_rng(rng)
        # TODO: validate input
        self.n_trees = n_trees
        self.forest = self._create_forest(sample_grid, n_trees, scoring)

    def reseed(self, rng):
        """
        Reseed (or unseed) the random number generator

        Args:
            rng (int | None | RandomState):
                random number generator or seed
        """
        self.rng = kwarray.ensure_rng(rng)
        for tree in self.forest:
            tree.reseed(self.rng)

    @profile
    def _create_forest(self, sample_grid, n_trees, scoring):
        """
        Generate N BalancedSampleTree's, producing a hard assignment for
        each multi-label attribute. Expects a multi-label attribute to arrive
        as a dictionary with possible values as keys and frequencies as values.

        For each tree, every dict-valued attribute is replaced by:

        * ``None`` if the dict is empty,
        * its only key if it has one key,
        * otherwise one key chosen uniformly (``scoring='uniform'``) or with
          probability proportional to ``1 - freq / total``
          (``scoring='inverse'``).

        Raises:
            NotImplementedError: for an unknown ``scoring`` (only raised when
                a dict attribute with two or more keys is encountered).
        """
        forest = []
        verbose = 1
        for _ in ub.ProgIter(range(n_trees), desc='Build balanced forests', verbose=verbose):
            # Only top-level values are reassigned below, so a per-item
            # shallow copy is enough to keep the caller's grid (and the other
            # trees) unchanged, and it is much cheaper than a deepcopy.
            local_sample_grid = [dict(sample) for sample in sample_grid]
            for sample in local_sample_grid:
                for key, val in sample.items():
                    if not isinstance(val, dict):
                        continue
                    if len(val) == 0:
                        sample[key] = None
                    elif len(val) == 1:
                        sample[key] = next(iter(val))
                    elif scoring == 'inverse':
                        labels = list(val.keys())
                        freqs = np.asarray(list(val.values()))
                        label_weights = 1 - (freqs / freqs.sum())
                        label_weights = label_weights / label_weights.sum()
                        choice_idx = self.rng.choice(len(labels), 1, p=label_weights)[0]
                        sample[key] = labels[choice_idx]
                    elif scoring == 'uniform':
                        sample[key] = self.rng.choice(list(val.keys()))
                    else:
                        raise NotImplementedError(f'Unknown scoring={scoring!r}')
            # initialize a BalancedSampleTree with this sample grid
            bst = BalancedSampleTree(local_sample_grid, rng=self.rng)
            forest.append(bst)
        return forest

    @profile
    def subdivide(self, key, weights=None, default_weight=0):
        """
        Apply :func:`BalancedSampleTree.subdivide` with the same arguments to
        every tree in the forest.
        """
        for tree in self.forest:
            tree.subdivide(key, weights=weights, default_weight=default_weight)

    @profile
    def _sample_many(self, num):
        """ Yield ``num`` independent samples. """
        for _ in range(num):
            idx = self.sample()
            yield idx

    @profile
    def sample(self):
        """ Uniformly sample a tree from the forest, then sample from it. """
        idx = self.rng.choice(self.n_trees)
        return self.forest[idx].sample()

    @profile
    def __len__(self):
        return len(self.forest[0])

    @profile
    def __nice__(self):
        graph = self.forest[0].graph
        n_trees = self.n_trees
        n_nodes = graph.number_of_nodes()
        n_edges = graph.number_of_edges()
        n_leafs = len(self)
        n_depth = len(nx.algorithms.dag.dag_longest_path(graph))
        return f'trees={n_trees}, nodes={n_nodes}, edges={n_edges}, leafs={n_leafs}, depth={n_depth}'
[docs]
def samecolor_nodata_mask(stream, hwc, relevant_bands, use_regions=0,
                          samecolor_values=None):
    """
    Find a 2D mask that indicates what values should be set to nan.

    This is typically done by finding clusters of zeros in specific bands.

    Args:
        stream (Iterable[str]): the channel names, in the order of the last
            axis of ``hwc``.
        hwc (ndarray): image data indexed as ``hwc[:, :, band_index]``.
        relevant_bands (Iterable[str]): the channel names to inspect.
        use_regions (bool | int): if truthy, use the (slower) region-based
            detector at a coarser scale; otherwise use the histogram method.
        samecolor_values (None | Set): values to look for, forwarded to the
            detector.

    Returns:
        ndarray: the per-band masks combined with a logical OR. If
        ``relevant_bands`` is empty, an all-False boolean mask with shape
        ``hwc.shape[0:2]`` is returned.

    Example:
        >>> from geowatch.tasks.fusion.datamodules.data_utils import *  # NOQA
        >>> import kwcoco
        >>> import kwarray
        >>> stream = kwcoco.FusedChannelSpec.coerce('foo|red|green|bar')
        >>> stream_oset = ub.oset(stream)
        >>> relevant_bands = ['red', 'green']
        >>> relevant_band_idxs = [stream_oset.index(b) for b in relevant_bands]
        >>> rng = kwarray.ensure_rng(0)
        >>> hwc = (rng.rand(32, 32, stream.numel()) * 3).astype(int)
        >>> use_regions = 0
        >>> samecolor_values = {0}
        >>> samecolor_mask = samecolor_nodata_mask(
        >>>     stream, hwc, relevant_bands, use_regions=use_regions,
        >>>     samecolor_values=samecolor_values)
        >>> assert samecolor_mask.sum() == (hwc[..., relevant_band_idxs] == 0).any(axis=2).sum()
    """
    relevant_bands = list(relevant_bands)
    if not relevant_bands:
        # Nothing to inspect: without this, np.stack below would fail on an
        # empty list.
        return np.zeros(hwc.shape[0:2], dtype=bool)
    from geowatch.utils import util_kwimage
    stream_oset = ub.oset(stream)
    relevant_band_idxs = [stream_oset.index(b) for b in relevant_bands]
    relevant_masks = []
    for b_sl in relevant_band_idxs:
        bands = hwc[:, :, b_sl]
        bands = np.ascontiguousarray(bands)
        if use_regions:
            # Speed up the compuation by doing this at a coarser scale
            is_samecolor = util_kwimage.find_samecolor_regions(
                bands, scale=0.4, min_region_size=49,
                values=samecolor_values)
        else:
            # Faster histogram method
            is_samecolor = util_kwimage.find_high_frequency_values(
                bands, values=samecolor_values)
        relevant_masks.append(is_samecolor)

    if len(relevant_masks) == 1:
        samecolor_mask = relevant_masks[0]
    else:
        samecolor_mask = (np.stack(relevant_masks, axis=2) > 0).any(axis=2)
    return samecolor_mask
[docs]
class MultiscaleMask:
    """
    A helper class to build up a mask indicating what pixels are unobservable
    based on data from different resolution.

    In other words, if you have multiple masks, and each mask has a different
    resolution, then this will iteratively upscale the masks to the largest
    resolution so far and perform a logical or. This helps keep the memory
    footprint small.

    TODO:
        Does this live in kwimage?

    CommandLine:
        xdoctest -m geowatch.tasks.fusion.datamodules.data_utils MultiscaleMask --show

    Example:
        >>> from geowatch.tasks.fusion.datamodules.data_utils import *  # NOQA
        >>> image = kwimage.grab_test_image()
        >>> image = kwimage.ensure_float01(image)
        >>> rng = kwarray.ensure_rng(1)
        >>> mask1 = kwimage.Mask.random(shape=(12, 12), rng=rng).data
        >>> mask2 = kwimage.Mask.random(shape=(32, 32), rng=rng).data
        >>> mask3 = kwimage.Mask.random(shape=(16, 16), rng=rng).data
        >>> omask = MultiscaleMask()
        >>> omask.update(mask1)
        >>> omask.update(mask2)
        >>> omask.update(mask3)
        >>> masked_image = omask.apply(image, np.nan)
        >>> # Now we can use our upscaled masks on an image.
        >>> masked_image = kwimage.fill_nans_with_checkers(masked_image, on_value=0.3)
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> kwplot.autompl()
        >>> inputs = kwimage.stack_images(
        >>>     [kwimage.atleast_3channels(m * 255) for m in [mask1, mask2, mask3]],
        >>>     pad=2, bg_value='kw_green', axis=1)
        >>> kwplot.imshow(inputs, pnum=(1, 3, 1), title='input masks')
        >>> kwplot.imshow(omask.mask, pnum=(1, 3, 2), title='final mask')
        >>> kwplot.imshow(masked_image, pnum=(1, 3, 3), title='masked image')
        >>> kwplot.show_if_requested()
    """

    def __init__(self):
        # The accumulated 2D boolean mask, or None before the first update
        self.mask = None
        # Cached value of ``masked_fraction``; reset by ``update``
        self._fraction = None

    def update(self, mask):
        """
        Expand the observable mask to the larger data and take the logical or
        of the resized masks.

        Args:
            mask (ndarray): a 2D mask, or a 3D mask with a single trailing
                channel. Nonzero values are treated as masked.

        Raises:
            ValueError: if the mask has more than 3 dimensions, or is 3D with
                more than one channel.
        """
        self._fraction = None
        if len(mask.shape) > 2:
            if len(mask.shape) != 3 or mask.shape[2] != 1:
                raise ValueError(f'bad mask shape {mask.shape}')
            mask = mask[..., 0]
        if self.mask is None:
            # Store a boolean copy so the stored mask always has the same
            # dtype (logical_or below yields bool) and is not aliased to the
            # caller's array.
            self.mask = mask.astype(bool)
        else:
            mask1 = self.mask
            mask2 = mask
            dsize1 = mask1.shape[0:2][::-1]
            dsize2 = mask2.shape[0:2][::-1]
            if dsize1 != dsize2:
                area1 = np.prod(dsize1)
                area2 = np.prod(dsize2)
                if area2 > area1:
                    # Ensure mask1 is always the larger of the two
                    mask1, mask2 = mask2, mask1
                    dsize1, dsize2 = dsize2, dsize1
                # Enlarge the smaller mask
                mask2 = mask2.astype(np.uint8)
                mask2 = kwimage.imresize(mask2, dsize=dsize1,
                                         interpolation='nearest')
            self.mask = np.logical_or(mask1, mask2)

    def apply(self, image, value):
        """
        Set the locations in ``image`` that correspond to this mask to
        ``value``.

        The mask is resized (nearest neighbor) to the first two dimensions of
        ``image`` if needed, and broadcast over any remaining dimensions, so
        both (H, W) and (H, W, C) images are supported. ``image`` is modified
        in place and returned.
        """
        mask = self.mask
        if mask is None:
            return image
        dsize1 = image.shape[0:2][::-1]
        dsize2 = mask.shape[0:2][::-1]
        if dsize1 != dsize2:
            # Ensure the mask corresponds to the image size
            mask = mask.astype(np.uint8)
            mask = kwimage.imresize(mask, dsize=dsize1,
                                    interpolation='nearest')
        mask = mask.astype(bool)
        # Append singleton axes so the (H, W) mask broadcasts over any
        # trailing (e.g. channel) dimensions of the image.
        mask = mask.reshape(mask.shape[0:2] + (1,) * (image.ndim - 2))
        mask = np.broadcast_to(mask, image.shape)
        image[mask] = value
        return image

    @property
    def masked_fraction(self):
        """
        The fraction of masked pixels (0 if no mask has been added). Cached
        until the next call to ``update``.
        """
        if self._fraction is None:
            if self.mask is None:
                self._fraction = 0
            else:
                self._fraction = self.mask.mean()
        return self._fraction