"""
Functions that may eventually be moved to kwarray
"""
import functools
import itertools as it
import math
import numpy as np
import os
import ubelt as ub
import warnings
try:
from packaging.version import parse as Version
except ImportError:
from distutils.version import LooseVersion as Version
# Detect the installed torch version (if any) from package metadata only, so
# torch itself is never imported here. Prefer importlib.metadata and fall back
# to the older pkg_resources API when importlib.metadata is unavailable.
try:
    import importlib.metadata
    try:
        _TORCH_VERSION = Version(importlib.metadata.version('torch'))
    except importlib.metadata.PackageNotFoundError:
        # torch is not installed
        _TORCH_VERSION = None
except ImportError:
    import pkg_resources
    try:
        _TORCH_VERSION = Version(pkg_resources.get_distribution('torch').version)
    except pkg_resources.DistributionNotFound:
        # torch is not installed
        _TORCH_VERSION = None
# Version feature flags used elsewhere in this module. Each flag is None (not
# False) when torch is not installed.
if _TORCH_VERSION is None:
    _TORCH_LT_1_7_0 = None
    _TORCH_LT_2_1_0 = None
    _TORCH_HAS_MAX_BUG = None
else:
    _TORCH_LT_1_7_0 = _TORCH_VERSION < Version('1.7')
    _TORCH_LT_2_1_0 = _TORCH_VERSION < Version('2.1')
    # NOTE(review): the "max bug" is tied to torch < 1.7 here; the bug itself
    # is not described in this file — confirm against its usages.
    _TORCH_HAS_MAX_BUG = _TORCH_LT_1_7_0
try:
# The math variant only exists in Python 3+ but is faster for scalars
# so try and use it
from math import isclose
except Exception:
from numpy import isclose
[docs]
def cartesian_product(*arrays):
    """
    Fast numpy version of :func:`itertools.product`.

    Each output row is one combination of input elements. The last input
    varies fastest, which is the same ordering as ``itertools.product``.

    TODO: Move to kwarray

    Args:
        *arrays (ndarray): one or more 1D sequences to combine.

    Returns:
        ndarray: array with one row per combination and one column per input,
        using the common result dtype of all the inputs.

    References:
        https://stackoverflow.com/a/11146645/887074

    Example:
        >>> cartesian_product(np.array([1, 2]), np.array([3, 4]))
        array([[1, 3],
               [1, 4],
               [2, 3],
               [2, 4]])
    """
    n_inputs = len(arrays)
    common_dtype = np.result_type(*arrays)
    grid_shape = [len(a) for a in arrays]
    grid_shape.append(n_inputs)
    grid = np.empty(grid_shape, dtype=common_dtype)
    # np.ix_ gives one open (broadcastable) mesh per input; broadcasting each
    # into its own trailing channel fills the full grid.
    open_meshes = np.ix_(*arrays)
    for axis in range(n_inputs):
        grid[..., axis] = open_meshes[axis]
    return grid.reshape(-1, n_inputs)
[docs]
def tukey_biweight_loss(r, c=4.685):
    """
    Beaton Tukey Biweight

    Computes the function:

    .. code::

        L(r) = (c ** 2) / 6 * (1 - (1 - (r / c) ** 2) ** 3)   if abs(r) < c
        L(r) = (c ** 2) / 6                                    otherwise

    Args:
        r (float | ArrayLike): residual parameter
        c (float): tuning constant (defaults to 4.685 which is 95% efficient
            for normal distributions of residuals)

    TODO:
        - [ ] Move elsewhere (kwarray?) or find a package that provides it

    Returns:
        float | ndarray:
            a Python float when ``r`` is a scalar, otherwise a float32 array
            with the same shape as ``r``.

    References:
        https://en.wikipedia.org/wiki/Robust_statistics
        https://mathworld.wolfram.com/TukeysBiweight.html
        https://statisticaloddsandends.wordpress.com/2021/04/23/what-is-the-tukey-loss-function/
        https://arxiv.org/pdf/1505.06606.pdf

    Example:
        >>> from geowatch.utils.util_kwarray import * # NOQA
        >>> import ubelt as ub
        >>> r = np.linspace(-20, 20, 1000)
        >>> data = {'r': r}
        >>> grid = ub.named_product({
        >>>     'c': [4.685, 2, 6],
        >>> })
        >>> for kwargs in grid:
        >>>     key = ub.urepr(kwargs, compact=1)
        >>>     loss = tukey_biweight_loss(r, **kwargs)
        >>>     data[key] = loss
        >>> import pandas as pd
        >>> melted = pd.DataFrame(data).melt(['r'])
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> sns = kwplot.autosns()
        >>> kwplot.figure(fnum=1, doclf=True)
        >>> ax = sns.lineplot(data=melted, x='r', y='value', hue='variable', style='variable')
        >>> #ax.set_ylim(*robust_limits(melted.value))

    Example:
        >>> assert tukey_biweight_loss(0.0) == 0.0
        >>> assert isclose(tukey_biweight_loss(100.0, c=3.0), 1.5)
    """
    # https://statisticaloddsandends.wordpress.com/2021/04/23/what-is-the-tukey-loss-function/
    # Python scalars and lists cannot be boolean-indexed, so coerce to an
    # array, but remember whether the caller should get a plain float back.
    return_scalar = not isinstance(r, np.ndarray) and np.ndim(r) == 0
    r = np.asarray(r)
    is_inside = np.abs(r) < c
    c26 = (c ** 2) / 6
    # Residuals with abs(r) >= c saturate at the maximum loss c^2 / 6.
    loss = np.full_like(r, fill_value=c26, dtype=np.float32)
    r_inside = r[is_inside]
    loss[is_inside] = c26 * (1 - (1 - (r_inside / c) ** 2) ** 3)
    if return_scalar:
        return float(loss)
    return loss
[docs]
def asymptotic(x, offset=1, gamma=1, degree=0, horizontal=1):
    """
    A function with a horizontal asymptote at ``horizontal``

    Evaluates ``(x + offset) ** gamma / (x + offset + 1) ** (gamma + degree)``
    and adds a constant shift.

    - When ``degree == 0`` the ratio tends to 1 (for positive ``gamma``), so
      the shift is ``horizontal - 1``.
    - Otherwise the ratio tends to 0 and the shift is ``horizontal``.

    Args:
        x (ndarray): input parameter
        offset (float): shifts function to the left or the right
        gamma (float): higher values approach the asymptote more slowly
        degree (float): extra power in the denominator; must be >= 0
        horizontal (float): location of the horizontal asymptote

    Returns:
        ndarray | float: the function evaluated at ``x``

    TODO:
        - [ ] Move elsewhere (kwarray?) or find a package that provides it

    Example:
        >>> from geowatch.utils.util_kwarray import * # NOQA
        >>> import ubelt as ub
        >>> x = np.linspace(0, 27, 1000)
        >>> data = {'x': x}
        >>> grid = ub.named_product({
        >>>     #'gamma': [0.5, 1.0, 2.0, 3.0],
        >>>     'gamma': [1.0, 3.0],
        >>>     'degree': [0, 1, 2, 3],
        >>>     'offset': [0, 2],
        >>>     'horizontal': [1],
        >>> })
        >>> for kwargs in grid:
        >>>     key = ub.urepr(kwargs, compact=1)
        >>>     data[key] = asymptotic(x, **kwargs)
        >>> import pandas as pd
        >>> melted = pd.DataFrame(data).melt(['x'])
        >>> print(melted)
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> sns = kwplot.autosns()
        >>> kwplot.figure(fnum=1, doclf=True)
        >>> ax = sns.lineplot(data=melted, x='x', y='value', hue='variable', style='variable')
        >>> ax.set_ylim(0, 2)
    """
    numer_power = gamma
    denom_power = gamma + degree
    assert numer_power <= denom_power
    if numer_power == denom_power:
        # The ratio approaches 1, so shift it down to land on ``horizontal``.
        shift = horizontal - 1
    else:
        shift = horizontal
    shifted_x = x + offset
    ratio = shifted_x ** numer_power / (shifted_x + 1) ** denom_power
    return ratio + shift
[docs]
def robust_limits(values):
    """
    Estimate plotting limits that ignore extreme outliers.

    - The lower limit continues the slope between the 5% and 8% quantiles
      down to the 0% level.
    - The upper limit continues the slope between the 90% and 95% quantiles
      up to the 100% level.

    NaN values are ignored.

    TODO:
        - [ ] Proper Robust estimator for matplotlib ylim and general use

    Args:
        values (ArrayLike): data values

    Returns:
        Tuple[float, float]:
            ``(robust_min, robust_max)``. Both are NaN when there is no
            non-NaN data.

    Example:
        >>> values = np.array([-1000, -4, -3, -2, 0, 2.7, 3.1415, 1, 2, 3, 4, 100000])
        >>> robust_min, robust_max = robust_limits(values)
        >>> assert robust_min < robust_max
    """
    # Must be sorted: np.quantile returns results in the requested order and
    # the index arithmetic below relies on that. Previously 0.5 appeared twice
    # and 0.95 was missing, so the upper limit was derived from the median.
    quants = [0.0, 0.05, 0.08, 0.2, 0.5, 0.8, 0.9, 0.95, 1.0]
    values = np.asarray(values)
    values = values[~np.isnan(values)]
    if values.size == 0:
        return float('nan'), float('nan')
    quantiles = np.quantile(values, quants)

    # Lower limit: extend the quants[1] -> quants[2] slope down to quants[0].
    lower_idx1 = 1
    upper_idx1 = 2
    part = quantiles[upper_idx1] - quantiles[lower_idx1]
    inner_w = quants[upper_idx1] - quants[lower_idx1]
    extrap_w = quants[lower_idx1] - quants[0]
    robust_min = quantiles[lower_idx1] - part * extrap_w / inner_w

    # Upper limit: extend the quants[-3] -> quants[-2] slope up to quants[-1].
    lower_idx2 = -3
    upper_idx2 = -2
    part = quantiles[upper_idx2] - quantiles[lower_idx2]
    inner_w = quants[upper_idx2] - quants[lower_idx2]
    extrap_w = quants[-1] - quants[upper_idx2]
    robust_max = quantiles[upper_idx2] + part * extrap_w / inner_w
    return robust_min, robust_max
[docs]
def unique_rows(arr, ordered=False):
    """
    Find the unique rows of a 2D array.

    Rows are compared by their raw bytes, because each row is viewed as a
    single ``np.void`` item. For floating point data this means, e.g.,
    ``0.0`` and ``-0.0`` count as different rows.

    Note: function also added to kwarray and will be available in >0.5.20

    Args:
        arr (ndarray): 2D array of shape ``(N, D)``; need not be contiguous.
        ordered (bool):
            if True, rows are returned in order of first occurrence;
            otherwise in the order produced by :func:`numpy.unique` on the
            row bytes.

    Returns:
        ndarray: array of shape ``(M, D)`` containing each distinct row once.

    Example:
        >>> import kwarray
        >>> from kwarray.util_numpy import * # NOQA
        >>> rng = kwarray.ensure_rng(0)
        >>> arr = rng.randint(0, 2, size=(12, 3))
        >>> arr_unique = unique_rows(arr)
        >>> print('arr_unique = {!r}'.format(arr_unique))
    """
    # Viewing each row as one void item requires C-contiguous rows; sliced or
    # transposed inputs would otherwise raise inside ``ndarray.view``.
    arr = np.ascontiguousarray(arr)
    num_cols = arr.shape[1]
    dtype_view = np.dtype((np.void, arr.dtype.itemsize * num_cols))
    arr_view = arr.view(dtype_view)
    if ordered:
        arr_view_unique, first_idxs = np.unique(arr_view, return_index=True)
    else:
        arr_view_unique = np.unique(arr_view)
    arr_unique = arr_view_unique.view(arr.dtype).reshape(-1, num_cols)
    if ordered:
        # Restore first-occurrence order.
        arr_unique = arr_unique[np.argsort(first_idxs)]
    return arr_unique
[docs]
def find_robust_normalizers(data, params='auto'):
    """
    Finds robust normalization statistics for a single observation

    Args:
        data (ndarray): a 1D numpy array where invalid data has already been removed
        params (str | dict):
            normalization params. Either ``'auto'`` (use the defaults),
            ``'tukey'`` (use Tukey fences for the extrema), or a dict that
            overrides any of the default keys: ``extrema``
            (``'custom-quantile'`` or ``'tukey'``), ``scaling``, ``low``,
            ``mid``, ``high``.

    Returns:
        Dict[str, str | float]:
            normalization parameters. For empty data this is
            ``{'type': None, 'min_val': nan, 'max_val': nan}``.

    Raises:
        KeyError: if ``params`` or ``params['extrema']`` is not recognized.
        NotImplementedError: for ``params='std'``, which is not implemented.

    TODO:
        - [ ] No Magic Numbers! Use first principles to deterimine defaults.
        - [ ] Probably a lot of literature on the subject.
        - [ ] Is this a kwarray function in general?
        - [ ] https://arxiv.org/pdf/1707.09752.pdf
        - [ ] https://www.tandfonline.com/doi/full/10.1080/02664763.2019.1671961
        - [ ] https://www.rips-irsp.com/articles/10.5334/irsp.289/

    Example:
        >>> data = np.random.rand(100)
        >>> norm_params1 = find_robust_normalizers(data, params='auto')
        >>> norm_params2 = find_robust_normalizers(data, params={'low': 0, 'high': 1.0})
        >>> norm_params3 = find_robust_normalizers(np.empty(0), params='auto')
        >>> print('norm_params1 = {}'.format(ub.urepr(norm_params1, nl=1)))
        >>> print('norm_params2 = {}'.format(ub.urepr(norm_params2, nl=1)))
        >>> print('norm_params3 = {}'.format(ub.urepr(norm_params3, nl=1)))
    """
    if data.size == 0:
        return {
            'type': None,
            'min_val': np.nan,
            'max_val': np.nan,
        }

    # should center the desired distribution to visualize on zero
    # beta = np.median(imdata)
    default_params = {
        'extrema': 'custom-quantile',
        'scaling': 'linear',
        'low': 0.01,
        'mid': 0.5,
        'high': 0.9,
    }
    if isinstance(params, str):
        if params == 'auto':
            params = {}
        elif params == 'tukey':
            params = {
                'extrema': 'tukey',
            }
        elif params == 'std':
            # This used to fall through and crash when merging a str into the
            # defaults; fail with an explicit message instead.
            raise NotImplementedError('params="std" is not implemented')
        else:
            raise KeyError(params)
    # User-specified keys override the defaults.
    params = {**default_params, **params}

    if params['extrema'] == 'tukey':
        # TODO:
        # https://github.com/derekbeaton/OuRS
        # https://en.wikipedia.org/wiki/Feature_scaling
        fence_extremes = _tukey_quantile_extreme_estimator(data)
    elif params['extrema'] == 'custom-quantile':
        fence_extremes = _custom_quantile_extreme_estimator(data, params)
    else:
        raise KeyError(params['extrema'])

    min_val, mid_val, max_val = fence_extremes
    beta = mid_val
    # Division factor: chosen so the min / max values land at about +/- 6.21
    # before a sigmoid, i.e. 6.212606 ~= logit(0.998).
    alpha = max(abs(min_val - beta), abs(max_val - beta)) / 6.212606
    return {
        'type': 'normalize',
        'mode': params['scaling'],
        'min_val': min_val,
        'max_val': max_val,
        'beta': beta,
        'alpha': alpha,
    }
def _custom_quantile_extreme_estimator(data, params):
    """
    Estimate robust ``(min, mid, max)`` values from user-chosen quantiles.

    The low and high quantile values are padded outward in proportion to the
    quantile mass that lies beyond them. They are then clamped to the
    observed data range, so the estimates never extend past the true
    min / max.

    Args:
        data (ndarray): non-empty array of valid values
        params (dict):
            must contain the quantile levels ``low``, ``mid``, and ``high``
            (each in [0, 1]).

    Returns:
        Tuple[float, float, float]: ``(min_val, mid_val, max_val)``
    """
    quant_low = params['low']
    quant_mid = params['mid']
    quant_high = params['high']
    qvals = [0, quant_low, quant_mid, quant_high, 1]
    quantile_vals = np.quantile(data, qvals)
    (quant_low_abs, quant_low_val, quant_mid_val, quant_high_val,
     quant_high_abs) = quantile_vals

    # TODO: we could implement a hueristic where we do a numerical inspection
    # of the intensity distribution. We could apply a normalization that is
    # known to work for data with that sort of histogram distribution.
    # This might involve fitting several parametarized distributions to the
    # data and choosing the one with the best fit. (check how many modes there
    # are).

    # Compute amount of weight in each quantile
    quant_center_amount = (quant_high_val - quant_low_val)
    quant_low_amount = (quant_mid_val - quant_low_val)
    quant_high_amount = (quant_high_val - quant_mid_val)
    if math.isclose(quant_center_amount, 0):
        high_weight = 0.5
        low_weight = 0.5
    else:
        high_weight = quant_high_amount / quant_center_amount
        low_weight = quant_low_amount / quant_center_amount

    # Fraction of the distribution that lies outside the low / high quantiles.
    quant_high_residual = (1.0 - quant_high)
    quant_low_residual = (quant_low - 0.0)
    low_pad_val = quant_low_residual * (low_weight * quant_center_amount)
    high_pad_val = quant_high_residual * (high_weight * quant_center_amount)

    # Pad outward from each quantile, but never beyond the observed extrema.
    # (The upper bound was previously max(abs_max, high - pad), which always
    # evaluated to the absolute maximum and defeated the robust estimate.)
    min_val = max(quant_low_abs, quant_low_val - low_pad_val)
    max_val = min(quant_high_abs, quant_high_val + high_pad_val)
    mid_val = quant_mid_val
    return (min_val, mid_val, max_val)
def _tukey_quantile_extreme_estimator(data):
    """
    Estimate robust ``(min, mid, max)`` values using Tukey's outlier fences.

    Args:
        data (ndarray): non-empty array of valid values

    Returns:
        Tuple[float, float, float]:
            ``(Q1 - 1.5 * IQR, median, Q3 + 1.5 * IQR)``
    """
    # Tukey method for outliers
    # https://www.youtube.com/watch?v=zY1WFMAA-ec
    q1, q2, q3 = np.quantile(data, [0.25, 0.5, 0.75])
    iqr = q3 - q1
    # One might wonder where the 1.5 in the above interval comes from -- Paul
    # Velleman, a statistician at Cornell University, was a student of John
    # Tukey, who invented this test for outliers. He wondered the same thing.
    # When he asked Tukey, "Why 1.5?", Tukey answered, "Because 1 is too small
    # and 2 is too large."
    # Cite: http://mathcenter.oxford.emory.edu/site/math117/shapeCenterAndSpread/
    fence_lower = q1 - 1.5 * iqr
    # The upper fence is anchored at the third quartile (it was mistakenly
    # anchored at Q1 before).
    fence_upper = q3 + 1.5 * iqr
    return fence_lower, q2, fence_upper
[docs]
def apply_normalizer(data, normalizer, mask=None, set_value_at_mask=float('nan')):
    """
    Apply normalization parameters, e.g. from :func:`find_robust_normalizers`.

    Args:
        data (ndarray): input data; it is not modified.
        normalizer (dict):
            must have a ``'type'`` key. If the type is None the data is only
            cast to float32. Otherwise ``'mode'`` is required, and the
            optional ``'beta'``, ``'alpha'``, ``'min_val'``, and ``'max_val'``
            keys are forwarded to :func:`normalize`.
        mask (ndarray | None):
            boolean array marking valid entries of ``data``. Only valid
            entries are normalized.
        set_value_at_mask (float):
            value written to entries where ``mask`` is False.

    Returns:
        ndarray: float32 array with the same shape as ``data``.
    """
    dtype = np.float32
    # astype returns a new array, so ``data`` itself is never modified below.
    result = data.astype(dtype)
    valid_data = result if mask is None else result[mask]

    if normalizer['type'] is None:
        # No value transform. Masked-out entries are still filled below;
        # previously the full-size array was assigned into the masked subset,
        # which raised a shape error.
        data_normalized = valid_data
    elif valid_data.size > 0:
        data_normalized = normalize(
            valid_data, mode=normalizer['mode'],
            beta=normalizer.get('beta'), alpha=normalizer.get('alpha'),
            min_val=normalizer.get('min_val'),
            max_val=normalizer.get('max_val')
        )
    else:
        # Nothing valid to normalize.
        data_normalized = valid_data

    if mask is None:
        return data_normalized
    result[mask] = data_normalized
    result[~mask] = set_value_at_mask
    return result
[docs]
def normalize(arr, mode='linear', alpha=None, beta=None, out=None,
              min_val=None, max_val=None):
    """
    Rebalance signal values via contrast stretching.

    By default linearly stretches array values to minimum and maximum values.

    Args:
        arr (ndarray): array to normalize, usually an image
        out (ndarray | None):
            output array. If given and it is not ``arr`` itself, the values
            of ``arr`` are copied into it first. Note, that we will create an
            internal floating point copy for integer computations.
        mode (str): either linear or sigmoid.
        alpha (float): Only used if mode=sigmoid. Division factor
            (pre-sigmoid). If unspecified computed as:
            ``max(abs(old_min - beta), abs(old_max - beta)) / 6.212606``.
            Note this parameter is sensitive to if the input is a float or
            uint8 image.
        beta (float): subtractive factor (pre-sigmoid). This should be the
            intensity of the most interesting bits of the image, i.e. bring
            them to the center (0) of the distribution.
            Defaults to the midpoint ``(max + min) / 2``. Note this parameter
            is sensitive to if the input is a float or uint8 image.
        min_val: override minimum value (smaller values are clamped to it)
        max_val: override maximum value (larger values are clamped to it)

    Returns:
        ndarray:
            ``out`` (a new array when ``out`` is None). Float inputs are
            mapped into [0, 1]; integer inputs into [0, iinfo(dtype).max].

    Raises:
        NotImplementedError: if ``arr`` is neither an integer nor float array.
        KeyError: if ``mode`` is not 'linear' or 'sigmoid'.

    References:
        https://en.wikipedia.org/wiki/Normalization_(image_processing)

    Example:
        >>> raw_f = np.random.rand(8, 8)
        >>> norm_f = normalize(raw_f)
        >>> raw_f = np.random.rand(8, 8) * 100
        >>> norm_f = normalize(raw_f)
        >>> assert isclose(norm_f.min(), 0)
        >>> assert isclose(norm_f.max(), 1)
        >>> raw_u = (np.random.rand(8, 8) * 255).astype(np.uint8)
        >>> norm_u = normalize(raw_u)
        >>> raw_m = (np.zeros((8, 8)) + 10)
        >>> norm_m = normalize(raw_m, min_val=0, max_val=20)
        >>> assert isclose(norm_m.min(), 0.5)
        >>> assert isclose(norm_m.max(), 0.5)
        >>> # Ensure that we're clamping if explicit min or max values
        >>> # are provided
        >>> raw_m = (np.zeros((8, 8)) + 10)
        >>> norm_m = normalize(raw_m, min_val=0, max_val=5)
        >>> assert isclose(norm_m.min(), 1.0)
        >>> assert isclose(norm_m.max(), 1.0)
        >>> # An explicit output buffer receives the normalized values of arr
        >>> buf = np.empty((8, 8))
        >>> res = normalize(raw_m, out=buf, min_val=0, max_val=20)
        >>> assert res is buf and isclose(buf.max(), 0.5)

    Example:
        >>> # xdoctest: +REQUIRES(module:kwimage)
        >>> import kwimage
        >>> arr = kwimage.grab_test_image('lowcontrast')
        >>> arr = kwimage.ensure_float01(arr)
        >>> norms = {}
        >>> norms['arr'] = arr.copy()
        >>> norms['linear'] = normalize(arr, mode='linear')
        >>> norms['sigmoid'] = normalize(arr, mode='sigmoid')
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> kwplot.autompl()
        >>> kwplot.figure(fnum=1, doclf=True)
        >>> pnum_ = kwplot.PlotNums(nSubplots=len(norms))
        >>> for key, img in norms.items():
        >>>     kwplot.imshow(img, pnum=pnum_(), title=key)
    """
    if out is None:
        out = arr.copy()
    elif out is not arr:
        # Normalize the values of ``arr``. Previously a separate ``out``
        # buffer was processed as-is, ignoring ``arr`` entirely.
        out[...] = arr

    # TODO:
    # - [ ] Parametarize new_min / new_max values
    #   - [ ] infer from datatype
    #   - [ ] explicitly given
    new_min = 0.0
    if arr.dtype.kind in ('i', 'u'):
        # Need a floating point workspace
        float_out = out.astype(np.float32)
        new_max = float(np.iinfo(arr.dtype).max)
    elif arr.dtype.kind == 'f':
        float_out = out
        new_max = 1.0
    else:
        raise NotImplementedError

    # TODO:
    # - [ ] Parametarize old_min / old_max strategies
    #   - [X] explicitly given min and max
    #   - [ ] raw-naive min and max inference
    #   - [ ] outlier-aware min and max inference
    if min_val is not None:
        old_min = min_val
        float_out[float_out < min_val] = min_val
    else:
        old_min = float_out.min()

    if max_val is not None:
        old_max = max_val
        float_out[float_out > max_val] = max_val
    else:
        old_max = float_out.max()

    old_span = old_max - old_min
    new_span = new_max - new_min

    if mode == 'linear':
        # out = (arr - old_min) * (new_span / old_span) + new_min
        factor = 1.0 if old_span == 0 else (new_span / old_span)
        if old_min != 0:
            float_out -= old_min
    elif mode == 'sigmoid':
        # out = new_span * sigmoid((arr - beta) / alpha) + new_min
        from scipy.special import expit as sigmoid
        if beta is None:
            # Center the sigmoid on the midpoint of the value range. This was
            # previously ``old_max - old_min`` (the span), which is not the
            # center of the data.
            beta = (old_max + old_min) / 2
        if alpha is None:
            # Division factor: pushes the original min / max to about
            # -6.21 / +6.21 before the sigmoid (6.212606 ~= logit(0.998)).
            alpha = max(abs(old_min - beta), abs(old_max - beta)) / 6.212606
        if math.isclose(alpha, 0):
            alpha = 1
        energy = float_out
        energy -= beta
        energy /= alpha
        # Ideally the data of interest is roughly in the range (-6, +6)
        float_out = sigmoid(energy, out=float_out)
        factor = new_span
    else:
        raise KeyError(mode)

    # Stretch / shift to the desired output range
    if factor != 1:
        float_out *= factor
    if new_min != 0:
        float_out += new_min

    # Integer inputs were computed in a separate float workspace.
    if float_out is not out:
        out[:] = float_out.astype(out.dtype)
    return out
[docs]
def balanced_number_partitioning(items, num_parts):
    """
    Greedy approximation to multiway number partitioning

    Items are visited from heaviest to lightest. Each item goes into the
    partition with the smallest running total. This greedy scheme aims to
    keep the largest partition small.

    Args:
        items (np.ndarray):
            list of numbers (i.e. weights) to split between partitions.
        num_parts (int): number of partitions

    Returns:
        List[np.ndarray]:
            ``num_parts`` integer arrays. Array ``i`` holds the indexes of
            the items assigned to partition ``i``.

    References:
        https://en.wikipedia.org/wiki/Multiway_number_partitioning
        https://en.wikipedia.org/wiki/Balanced_number_partitioning

    Example:
        >>> from geowatch.utils.util_kwarray import * # NOQA
        >>> items = np.array([1, 3, 29, 22, 4, 5, 9])
        >>> num_parts = 3
        >>> bin_assignments = balanced_number_partitioning(items, num_parts)
        >>> import kwarray
        >>> groups = kwarray.apply_grouping(items, bin_assignments)
        >>> bin_weights = [g.sum() for g in groups]
    """
    weights = np.asanyarray(items)
    heaviest_first = np.argsort(weights)[::-1]
    partitions = [[] for _ in range(num_parts)]
    totals = np.zeros(num_parts)
    for idx in heaviest_first:
        # Place the item into the currently lightest partition.
        target = totals.argmin()
        partitions[target].append(idx)
        totals[target] += weights[idx]
    return [np.array(part, dtype=int) for part in partitions]
[docs]
def torch_array_equal(data1, data2, equal_nan=False) -> bool:
    """
    Check whether two tensors have the same shape and the same values.

    Args:
        data1 (torch.Tensor): first tensor
        data2 (torch.Tensor): second tensor
        equal_nan (bool): if True, NaNs in matching positions compare equal.

    Returns:
        bool: True if the shapes match and every element is equal.

    Example:
        >>> # xdoctest: +REQUIRES(module:torch)
        >>> import torch
        >>> data1 = torch.rand(5, 5)
        >>> data2 = data1 + 1
        >>> result1 = torch_array_equal(data1, data2)
        >>> result3 = torch_array_equal(data1, data1)
        >>> assert result1 is False
        >>> assert result3 is True

    Example:
        >>> # xdoctest: +REQUIRES(module:torch)
        >>> import torch
        >>> data1 = torch.rand(5, 5)
        >>> data1[0] = np.nan
        >>> data2 = data1
        >>> result1 = torch_array_equal(data1, data2)
        >>> result3 = torch_array_equal(data1, data2, equal_nan=True)
        >>> assert result1 is False
        >>> assert result3 is True

    Example:
        >>> # xdoctest: +REQUIRES(module:torch)
        >>> import torch
        >>> # Broadcast-compatible but differently shaped tensors are not equal
        >>> assert torch_array_equal(torch.zeros(5, 5), torch.zeros(1, 5)) is False
    """
    # TODO: just use
    # return kwarray.ArrayAPI.coerce('torch').array_equal(data1, data2, equal_nan)
    import torch
    # ``torch.eq`` broadcasts, so without this check a (5, 5) tensor of equal
    # rows would match a single (1, 5) row. This mirrors ``torch.equal``.
    if data1.shape != data2.shape:
        return False
    if equal_nan:
        val_flags = torch.eq(data1, data2)
        nan_flags = (data1.isnan() & data2.isnan())
        flags = val_flags | nan_flags
        return bool(flags.all())
    if _TORCH_LT_2_1_0:
        return torch.equal(data1, data2)
    # Torch 2.1 introduced a bug in torch.equal so we need an alternate
    # implementation.
    # References:
    #     https://github.com/pytorch/pytorch/issues/111251
    return bool(torch.eq(data1, data2).all())
[docs]
def combine_mean_stds(means, stds, nums=None, axis=None, keepdims=False,
                      bessel=True):
    r"""
    Combine per-group means and standard deviations into pooled statistics.

    Uses the pooled-variance identity

    .. math::

        \sigma^2 = \frac{\sum_i \max(n_i - b, 0) \sigma_i^2
                         + \sum_i n_i (\mu_i - \mu)^2}{\max(N - b, 0)}

    Here :math:`\mu` is the ``nums``-weighted mean, :math:`N = \sum_i n_i`,
    and :math:`b` is ``bessel``. When ``nums`` is None the sample sizes are
    treated as infinite. The identity then reduces to the mean of
    :math:`\sigma_i^2 + (\mu_i - \mu)^2` over the combined entries.

    Args:
        means (array): means[i] is the mean of the ith entry to combine
        stds (array): stds[i] is the std of the ith entry to combine
        nums (array | None):
            nums[i] is the number of samples in the ith entry to combine.
            if None, assumes sample sizes are infinite.
        axis (int | Tuple[int] | None):
            axis to combine the statistics over
        keepdims (bool):
            if True return arrays with the same number of dimensions they were
            given in.
        bessel (int):
            Set to 1 to enables bessel correction to unbias the combined std
            estimate. Only disable if you have the true population means, or
            you think you know what you are doing.

    Returns:
        Tuple[ndarray, ndarray, ndarray | None]:
            ``(combo_mean, combo_std, combo_num)``. ``combo_num`` is None
            when ``nums`` is None. A zero denominator yields NaN.

    References:
        https://stats.stackexchange.com/questions/55999/is-it-possible-to-find-the-combined-standard-deviation

    SeeAlso:
        development kwarray has a similar hidden function in util_averages.
        Might expose later.

    Example:
        >>> means = np.stack([np.array([1.2, 3.2, 4.1])] * 100, axis=0)
        >>> stds = np.stack([np.array([4.2, 0.2, 2.1])] * 100, axis=0)
        >>> nums = np.stack([np.array([10, 100, 10])] * 100, axis=0)
        >>> cm1, cs1, _ = combine_mean_stds(means, stds, nums, axis=None)
        >>> print('combo_mean = {}'.format(ub.urepr(cm1, nl=1)))
        >>> print('combo_std  = {}'.format(ub.urepr(cs1, nl=1)))
        >>> means = np.stack([np.array([1.2, 3.2, 4.1])] * 1, axis=0)
        >>> stds = np.stack([np.array([4.2, 0.2, 2.1])] * 1, axis=0)
        >>> nums = np.stack([np.array([10, 100, 10])] * 1, axis=0)
        >>> cm2, cs2, _ = combine_mean_stds(means, stds, nums, axis=None)
        >>> print('combo_mean = {}'.format(ub.urepr(cm2, nl=1)))
        >>> print('combo_std  = {}'.format(ub.urepr(cs2, nl=1)))
        >>> means = np.stack([np.array([1.2, 3.2, 4.1])] * 5, axis=0)
        >>> stds = np.stack([np.array([4.2, 0.2, 2.1])] * 5, axis=0)
        >>> nums = np.stack([np.array([10, 100, 10])] * 5, axis=0)
        >>> cm3, cs3, combo_num = combine_mean_stds(means, stds, nums, axis=1)
        >>> print('combo_mean = {}'.format(ub.urepr(cm3, nl=1)))
        >>> print('combo_std  = {}'.format(ub.urepr(cs3, nl=1)))
        >>> assert np.allclose(cm1, cm2) and np.allclose(cm2, cm3)
        >>> assert not np.allclose(cs1, cs2)
        >>> assert np.allclose(cs2, cs3)

    Example:
        >>> # Identical groups with infinite samples keep the same std
        >>> cm, cs, cn = combine_mean_stds(np.zeros(4), np.full(4, 2.0), axis=0)
        >>> assert np.isclose(cs, 2.0) and cn is None

    Example:
        >>> from geowatch.utils.util_kwarray import * # NOQA
        >>> means = np.random.rand(2, 3, 5, 7)
        >>> stds = np.random.rand(2, 3, 5, 7)
        >>> nums = (np.random.rand(2, 3, 5, 7) * 10) + 1
        >>> cm, cs, cn = combine_mean_stds(means, stds, nums, axis=1, keepdims=1)
        >>> print('cs = {}'.format(ub.urepr(cs, nl=1)))
        >>> assert cm.shape == cs.shape == cn.shape
        ...
        >>> print(f'cm.shape={cm.shape}')
        >>> cm, cs, cn = combine_mean_stds(means, stds, nums, axis=(0, 2), keepdims=1)
        >>> assert cm.shape == cs.shape == cn.shape
        >>> print(f'cm.shape={cm.shape}')
        >>> cm, cs, cn = combine_mean_stds(means, stds, nums, axis=(1, 3), keepdims=1)
        >>> assert cm.shape == cs.shape == cn.shape
        >>> print(f'cm.shape={cm.shape}')
        >>> cm, cs, cn = combine_mean_stds(means, stds, nums, axis=None)
        >>> assert cm.shape == cs.shape == cn.shape
        >>> print(f'cm.shape={cm.shape}')
        cm.shape=(2, 1, 5, 7)
        cm.shape=(1, 3, 1, 7)
        cm.shape=(2, 1, 5, 1)
        cm.shape=()
    """
    if nums is None:
        # Assume the limit as nums -> infinite
        combo_num = None
        combo_mean = np.average(means, weights=None, axis=axis)
        combo_mean = _postprocess_keepdims(means, combo_mean, axis)
        # The pooled formula combines variances, so square the stds.
        numer_p1 = (stds ** 2).sum(axis=axis, keepdims=1)
        numer_p2 = ((means - combo_mean) ** 2).sum(axis=axis, keepdims=1)
        numer = numer_p1 + numer_p2
        # Count every entry that was reduced over, not only those along the
        # first axis (which is all ``len(stds)`` measured).
        if axis is None:
            denom = stds.size
        else:
            axes = axis if isinstance(axis, (list, tuple)) else (axis,)
            denom = int(np.prod([stds.shape[a] for a in axes]))
    else:
        combo_num = nums.sum(axis=axis, keepdims=1)
        weights = nums / combo_num
        combo_mean = np.average(means, weights=weights, axis=axis)
        combo_mean = _postprocess_keepdims(means, combo_mean, axis)
        # The pooled formula combines variances, so square the stds.
        numer_p1 = (np.maximum(nums - bessel, 0) * stds ** 2).sum(axis=axis, keepdims=1)
        numer_p2 = (nums * ((means - combo_mean) ** 2)).sum(axis=axis, keepdims=1)
        numer = numer_p1 + numer_p2
        denom = np.maximum(combo_num - bessel, 0)

    # A 0 / 0 ratio yields NaN; silence numpy's "invalid value" warning for it.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', 'invalid value encountered', category=RuntimeWarning)
        combo_std = np.sqrt(numer / denom)

    if not keepdims:
        indexer = _no_keepdim_indexer(combo_mean, axis)
        combo_mean = combo_mean[indexer]
        combo_std = combo_std[indexer]
        if combo_num is not None:
            combo_num = combo_num[indexer]

    return combo_mean, combo_std, combo_num


def _no_keepdim_indexer(result, axis):
    """
    Computes an indexer to postprocess a result with keepdims=True
    that will modify the result as if keepdims=False.

    Args:
        result (ndarray): array whose reduced axes have length 1
        axis (int | Tuple[int] | None): the reduced axes (None means all)

    Returns:
        Tuple[int | slice, ...]: indexer that drops the reduced axes
    """
    if axis is None:
        indexer = [0] * len(result.shape)
    else:
        indexer = [slice(None)] * len(result.shape)
        if isinstance(axis, (list, tuple)):
            for a in axis:
                indexer[a] = 0
        else:
            indexer[axis] = 0
    indexer = tuple(indexer)
    return indexer


def _postprocess_keepdims(original, result, axis):
    """
    Can update the result of a function that does not support keepdims to look
    as if keepdims was supported.

    Args:
        original (ndarray): the array that was reduced
        result (ndarray | scalar): the reduction result without keepdims
        axis (int | Tuple[int] | None): the reduced axes (None means all)

    Returns:
        ndarray: ``result`` with length-1 axes re-inserted at ``axis``
    """
    # Newer versions of numpy have keepdims on more functions
    if axis is not None:
        expander = [slice(None)] * len(original.shape)
        if isinstance(axis, (list, tuple)):
            for a in axis:
                expander[a] = None
        else:
            expander[axis] = None
        result = result[tuple(expander)]
    else:
        expander = [None] * len(original.shape)
        result = np.array(result)[tuple(expander)]
    return result
[docs]
def apply_robust_normalizer(normalizer, imdata, imdata_valid, mask, dtype, copy=True):
    """
    Normalize image data with precomputed robust normalization parameters.

    Args:
        normalizer (dict):
            normalization parameters, e.g. from
            :func:`find_robust_normalizers`. If ``normalizer['type']`` is
            None the data is only cast to ``dtype``. If it is ``'normalize'``,
            the ``mode``, ``beta``, and ``alpha`` keys are passed to
            :func:`kwarray.normalize`.
        imdata (ndarray): the full data array.
        imdata_valid (ndarray):
            the valid values to normalize. Presumably ``imdata[mask]`` when
            ``mask`` is given — TODO confirm with callers.
        mask (ndarray | None):
            positions in ``imdata`` that receive the normalized valid values.
            If None, the normalized ``imdata_valid`` is returned directly.
        dtype (type | np.dtype): dtype to cast the data to.
        copy (bool):
            passed to ``astype``. When False and no dtype conversion is
            needed, ``imdata`` may be modified in place.

    Returns:
        ndarray: the normalized data.

    Raises:
        KeyError: if ``normalizer['type']`` is not recognized.
    """
    if normalizer['type'] is None:
        imdata_normalized = imdata.astype(dtype, copy=copy)
    elif normalizer['type'] == 'normalize':
        # Note: we are using kwarray normalize, the one in kwimage is deprecated
        import kwarray
        arr = imdata_valid.astype(dtype, copy=copy)
        imdata_valid_normalized = kwarray.normalize(
            arr, mode=normalizer['mode'],
            beta=normalizer['beta'], alpha=normalizer['alpha'],
        )
        if mask is None:
            imdata_normalized = imdata_valid_normalized
        else:
            # Cast to the target dtype before writing: storing normalized
            # values into e.g. a uint8 copy of ``imdata`` would truncate them.
            imdata_normalized = imdata.astype(dtype, copy=copy)
            imdata_normalized[mask] = imdata_valid_normalized
    else:
        raise KeyError(normalizer['type'])
    return imdata_normalized
[docs]
@functools.cache
def _biased_1d_weights_cached(upweight_time, num_frames):
    """
    Memoized worker for :func:`biased_1d_weights`.

    The cached array is marked read-only because it is shared by every call
    with the same arguments.
    """
    from scipy.stats import norm
    import kwimage
    # from kwarray.distributions import TruncNormal
    sigma = kwimage.im_cv2._auto_kernel_sigma(kernel=((num_frames, 1)))[1][0]
    # Center of the bump, mapped onto frame-center coordinates (i + 0.5).
    mean = upweight_time * (num_frames - 1) + 0.5
    # rv = TruncNormal(mean=mean, std=sigma, low=0.0, high=num_frames).rv
    rv = norm(mean, sigma)
    locs = np.arange(num_frames) + 0.5
    values = rv.pdf(locs)
    values.flags.writeable = False
    return values


def biased_1d_weights(upweight_time, num_frames):
    """
    Gaussian weights over ``num_frames`` frames, peaked near ``upweight_time``.

    The Gaussian is centered at ``upweight_time * (num_frames - 1) + 0.5`` and
    evaluated at the frame centers ``0.5, 1.5, ...``. Its sigma comes from
    ``kwimage``'s automatic kernel sigma for a kernel of size ``num_frames``.
    The weights are unnormalized pdf values.

    Results are memoized per ``(upweight_time, num_frames)``. Each call returns
    a fresh copy, so callers may modify the returned array without corrupting
    the cache. (Previously the cached array itself was returned.)

    Args:
        upweight_time (float):
            relative position of the peak; 0 is the first frame and 1 the last.
        num_frames (int): number of weights to produce.

    Returns:
        ndarray: array of shape ``(num_frames,)``

    Example:
        >>> # xdoctest: +REQUIRES(module:kwimage)
        >>> # xdoctest: +REQUIRES(module:scipy)
        >>> values = biased_1d_weights(0.5, 5)
        >>> assert values.shape == (5,)
        >>> assert values.argmax() == 2
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> plt = kwplot.autoplt()
        >>> kwplot.figure(fnum=1, doclf=1)
        >>> for upweight_time in [0.0, 0.1, 0.5, 0.9, 1.0]:
        >>>     plt.plot(biased_1d_weights(upweight_time, 5))
    """
    return _biased_1d_weights_cached(upweight_time, num_frames).copy()


# Keep the cache-management API that ``functools.cache`` used to expose.
biased_1d_weights.cache_info = _biased_1d_weights_cached.cache_info
biased_1d_weights.cache_clear = _biased_1d_weights_cached.cache_clear
[docs]
def argsort_threshold(arr, threshold=None, num_top=None, objective='maximize'):
    """
    Find all indexes over a threshold, but always return at least the
    `num_top`, and potentially more.

    Args:
        arr (ndarray): 1D array of scores
        threshold (float | None):
            return indexes that are better than this threshold. If None, the
            threshold selects nothing, so only the ``num_top`` best are
            returned. If ``num_top`` is also None, every index is returned.
        num_top (int | None):
            always return at least this number of "best" indexes (or all of
            them if there are fewer).
        objective (str):
            if maximize, filters things above the threshold, otherwise filters
            below the threshold.

    Returns:
        ndarray: top indexes, ordered from best to worst

    Raises:
        KeyError: if ``objective`` is not 'maximize' or 'minimize'.

    Example:
        >>> from geowatch.utils.util_kwarray import * # NOQA
        >>> arr = np.array([0.3, .2, 0.1, 0.15, 0.11, 0.15, 0.2, 0.6, 0.32])
        >>> argsort_threshold(arr, threshold=0.5, num_top=0)
        array([7])
        >>> argsort_threshold(arr, threshold=0.5, num_top=3)
        array([7, 8, 0])
        >>> argsort_threshold(arr, num_top=2)
        array([7, 8])
        >>> len(argsort_threshold(arr, threshold=0.0, num_top=3))
        9
    """
    # Order indices from best to worst according to the objective.
    ascending_sortx = arr.argsort()
    if objective == 'maximize':
        sortx = ascending_sortx[::-1]
    elif objective == 'minimize':
        sortx = ascending_sortx
    else:
        raise KeyError(objective)
    sorted_arr = arr[sortx]

    # Mark any index "better" than the score threshold
    if threshold is None:
        # No score criterion. Select everything only if no minimum count was
        # requested either.
        flags = np.full(len(sorted_arr), num_top is None, dtype=bool)
    elif objective == 'maximize':
        flags = sorted_arr > threshold
    else:
        flags = sorted_arr < threshold

    if num_top is not None:
        # Always return at least `num_top`; slicing clips to the array length
        # (indexing ``sorted_arr[num_top - 1]`` used to raise IndexError).
        flags[:num_top] = True
    return sortx[flags]
from geowatch.utils.remedian import Remedian # NOQA
"""
Defines the :class:`SlidingWindow` and :class:`Stitcher` classes.
The :class:`SlidingWindow` generates a grid of slices over an
:func:`numpy.ndarray`, which can then be used to compute on subsets of the
data. The :class:`Stitcher` can then take these results and recombine them into
a final result that matches the larger array.
"""
[docs]
class SlidingWindow(ub.NiceRepr):
"""
Slide a window of a certain shape over an array with a larger shape.
This can be used for iterating over a grid of sub-regions of 2d-images,
3d-volumes, or any n-dimensional array.
Yields slices of shape `window` that can be used to index into an array
with shape `shape` via numpy / torch fancy indexing. This allows for fast
fast iteration over subregions of a larger image. Because we generate a
grid-basis using only shapes, the larger image does not need to be in
memory as long as its width/height/depth/etc...
Args:
shape (Tuple[int, ...]): shape of source array to slide across.
window (Tuple[int, ...]): shape of window that will be slid over the
larger image.
overlap (float, default=0): a number between 0 and 1 indicating the
fraction of overlap that parts will have. Specifying this is
mutually exclusive with `stride`. Must be `0 <= overlap < 1`.
stride (int, default=None): the number of cells (pixels) moved on each
step of the window. Mutually exclusive with overlap.
keepbound (bool, default=False): if True, a non-uniform stride will be
taken to ensure that the right / bottom of the image is returned as
a slice if needed. Such a slice will not obey the overlap
constraints. (Defaults to False)
allow_overshoot (bool, default=False): if False, we will raise an
error if the window doesn't slide perfectly over the input shape.
Attributes:
basis_shape - shape of the grid corresponding to the number of strides
the sliding window will take.
basis_slices - slices that will be taken in every dimension
Yields:
Tuple[slice, ...]: slices used for numpy indexing, the number of slices
in the tuple
Note:
For each dimension, we generate a basis (which defines a grid), and we
slide over that basis.
TODO:
- [ ] have an option that is allowed to go outside of the window bounds
on the right and bottom when the slider overshoots.
Example:
>>> shape = (10, 10)
>>> window = (5, 5)
>>> self = SlidingWindow(shape, window)
>>> for i, index in enumerate(self):
>>> print('i={}, index={}'.format(i, index))
i=0, index=(slice(0, 5, None), slice(0, 5, None))
i=1, index=(slice(0, 5, None), slice(5, 10, None))
i=2, index=(slice(5, 10, None), slice(0, 5, None))
i=3, index=(slice(5, 10, None), slice(5, 10, None))
Example:
>>> shape = (16, 16)
>>> window = (4, 4)
>>> self = SlidingWindow(shape, window, overlap=(.5, .25))
>>> print('self.stride = {!r}'.format(self.stride))
self.stride = [2, 3]
>>> list(ub.chunks(self.grid, 5))
[[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4)],
[(1, 0), (1, 1), (1, 2), (1, 3), (1, 4)],
[(2, 0), (2, 1), (2, 2), (2, 3), (2, 4)],
[(3, 0), (3, 1), (3, 2), (3, 3), (3, 4)],
[(4, 0), (4, 1), (4, 2), (4, 3), (4, 4)],
[(5, 0), (5, 1), (5, 2), (5, 3), (5, 4)],
[(6, 0), (6, 1), (6, 2), (6, 3), (6, 4)]]
Example:
>>> # Test shapes that dont fit
>>> # When the window is bigger than the shape, the left-aligned slices
>>> # are returend.
>>> self = SlidingWindow((3, 3), (12, 12), allow_overshoot=True, keepbound=True)
>>> print(list(self))
[(slice(0, 12, None), slice(0, 12, None))]
>>> print(list(SlidingWindow((3, 3), None, allow_overshoot=True, keepbound=True)))
[(slice(0, 3, None), slice(0, 3, None))]
>>> print(list(SlidingWindow((3, 3), (None, 2), allow_overshoot=True, keepbound=True)))
[(slice(0, 3, None), slice(0, 2, None)), (slice(0, 3, None), slice(1, 3, None))]
"""
def __init__(self, shape, window, overlap=None, stride=None,
             keepbound=False, allow_overshoot=False):
    """
    Build the per-dimension slice basis for sliding ``window`` over ``shape``.

    Args:
        shape (Sequence[int]): shape of the array to slide over.
        window (Sequence[int | None] | None): window size in each dimension.
            A None window (or a None entry) means the full extent of
            ``shape`` in that dimension.
        overlap (float | Sequence[float] | None): fractional overlap in
            [0, 1) between consecutive windows. Mutually exclusive with
            ``stride``; if neither is given, zero overlap is used.
        stride (int | Sequence[int] | None): step between consecutive
            windows. Mutually exclusive with ``overlap``.
        keepbound (bool): if True, a final (possibly irregularly strided)
            window flush with the end of each dimension is always emitted.
        allow_overshoot (bool): if False, raise a ValueError when the
            regular strides do not land exactly on the end of a dimension.

    Raises:
        ValueError: if the window / overlap / stride are inconsistent, or if
            the windows overshoot and ``allow_overshoot`` is False.
    """
    stride, overlap, window = self._compute_stride(
        overlap, stride, shape, window)
    # Keyword arguments for ``_slices1d`` in each dimension. Strict bound
    # checking is only requested when neither keepbound nor
    # allow_overshoot relaxes it.
    slice_kwargs = [dict(margin=d, stop=D, step=s, keepbound=keepbound,
                         check=not keepbound and not allow_overshoot)
                    for d, D, s in zip(window, shape, stride)]
    undershot_shape = []
    overshoots = []
    for kw in slice_kwargs:
        final_pos = kw['stop'] - kw['margin']
        n_steps = final_pos // kw['step']
        overshoot = final_pos % kw['step']
        undershot_shape.append(n_steps + 1)
        overshoots.append(overshoot)
    self._final_step = overshoots
    if not allow_overshoot and any(overshoots):
        raise ValueError('overshoot={} stide_kw={}'.format(overshoots,
                                                            slice_kwargs))
    self.stride = stride
    self.overlap = overlap
    self.window = window
    self.input_shape = shape
    # The undershot basis shape only contains indices that correspond
    # perfectly to the input. It may crop a bit of the ends. If this is
    # equal to basis_shape, then the window perfectly fits the input.
    self.undershot_shape = undershot_shape
    # NOTE: if we have overshot, then basis shape will not perfectly
    # align to the original image. This shape will be a bit bigger.
    self.basis_slices = [list(_slices1d(**kw)) for kw in slice_kwargs]
    self.basis_shape = [len(b) for b in self.basis_slices]
    # np.prod of an empty list is the float 1.0, which len() rejects;
    # store a builtin int so __len__ works for 0-d shapes too.
    self.n_total = int(np.prod(self.basis_shape))
def __nice__(self):
return 'bshape={}, shape={}, window={}, stride={}'.format(
tuple(self.basis_shape),
tuple(self.input_shape),
self.window,
tuple(self.stride)
)
def _compute_stride(self, overlap, stride, shape, window):
"""
Ensures that stride hasoverlap the correct shape. If stride is not
provided, compute stride from desired overlap.
"""
if window is None:
window = shape
if isinstance(stride, np.ndarray):
stride = tuple(stride)
# TODO: some auto overlap?
if isinstance(overlap, np.ndarray):
overlap = tuple(overlap)
if len(window) != len(shape):
raise ValueError('incompatible dims: {} {}'.format(len(window),
len(shape)))
if any(d is None for d in window):
window = [D if d is None else d for d, D in zip(window, shape)]
if overlap is None and stride is None:
overlap = 0
if not (overlap is None) ^ (stride is None):
raise ValueError('specify overlap({}) XOR stride ({})'.format(
overlap, stride))
if stride is None:
if not isinstance(overlap, (list, tuple)):
overlap = [overlap] * len(window)
if any(frac < 0 or frac >= 1 for frac in overlap):
raise ValueError((
'part overlap was {}, but fractional overlaps must be '
'in the range [0, 1)').format(overlap))
stride = [int(round(d - d * frac))
for frac, d in zip(overlap, window)]
else:
if not isinstance(stride, (list, tuple)):
stride = [stride] * len(window)
# Recompute fractional overlap after integer stride is computed
overlap = [(d - s) / d for s, d in zip(stride, window)]
assert len(stride) == len(shape), 'incompatible dims'
if not all(stride):
raise ValueError(
'Step must be positive everywhere. Got={}'.format(stride))
return stride, overlap, window
def __len__(self):
return self.n_total
def _iter_basis_frac(self):
for slices in self:
frac = [sl.start / D for sl, D in zip(slices, self.source.shape)]
yield frac
def __iter__(self):
for slices in it.product(*self.basis_slices):
yield slices
def __getitem__(self, index):
"""
Get a specific item by its flat (raveled) index
Example:
>>> from kwarray.util_slider import * # NOQA
>>> window = (10, 10)
>>> shape = (20, 20)
>>> self = SlidingWindow(shape, window, stride=5)
>>> itered_items = list(self)
>>> assert len(itered_items) == len(self)
>>> indexed_items = [self[i] for i in range(len(self))]
>>> assert itered_items[0] == self[0]
>>> assert itered_items[-1] == self[-1]
>>> assert itered_items == indexed_items
"""
if index < 0:
index = len(self) + index
# Find the nd location in the grid
basis_idx = np.unravel_index(index, self.basis_shape)
# Take the slice for each of the n dimensions
slices = tuple([bdim[i]
for bdim, i in zip(self.basis_slices, basis_idx)])
return slices
@property
def grid(self):
    """
    Enumerate the n-dimensional index of every window in the basis grid.

    The indices come out in the same order as iterating over the windows
    themselves (last dimension varying fastest).

    Yields:
        Tuple[int, ...]: one integer per dimension, the i-th lying in
        ``range(basis_shape[i])``.
    """
    per_dim_ranges = [range(n) for n in self.basis_shape]
    yield from it.product(*per_dim_ranges)
@property
def slices(self):
    """
    Iterator over every window's slices; the same as ``iter(self)``.

    Example:
        >>> shape = (220, 220)
        >>> window = (10, 10)
        >>> self = SlidingWindow(shape, window, stride=5)
        >>> list(self)[41:45]
        [(slice(0, 10, None), slice(205, 215, None)),
         (slice(0, 10, None), slice(210, 220, None)),
         (slice(5, 15, None), slice(0, 10, None)),
         (slice(5, 15, None), slice(5, 15, None))]
        >>> print('self.overlap = {!r}'.format(self.overlap))
        self.overlap = [0.5, 0.5]
    """
    window_iter = iter(self)
    return window_iter
@property
def centers(self):
    """
    Center coordinate of every window, in iteration order.

    For a slice ``[start, stop)`` the center is the midpoint of the
    inclusive index range ``[start, stop - 1]``.

    Yields:
        Tuple[float, ...]: one center coordinate per dimension.

    Example:
        >>> shape = (4, 4)
        >>> window = (3, 3)
        >>> self = SlidingWindow(shape, window, stride=1)
        >>> list(zip(self.centers, self.slices))
        [((1.0, 1.0), (slice(0, 3, None), slice(0, 3, None))),
         ((1.0, 2.0), (slice(0, 3, None), slice(1, 4, None))),
         ((2.0, 1.0), (slice(1, 4, None), slice(0, 3, None))),
         ((2.0, 2.0), (slice(1, 4, None), slice(1, 4, None)))]
        >>> shape = (3, 3)
        >>> window = (2, 2)
        >>> self = SlidingWindow(shape, window, stride=1)
        >>> list(zip(self.centers, self.slices))
        [((0.5, 0.5), (slice(0, 2, None), slice(0, 2, None))),
         ((0.5, 1.5), (slice(0, 2, None), slice(1, 3, None))),
         ((1.5, 0.5), (slice(1, 3, None), slice(0, 2, None))),
         ((1.5, 1.5), (slice(1, 3, None), slice(1, 3, None)))]
    """
    def _midpoint(sl_):
        # Midpoint of the inclusive index range [start, stop - 1]
        return sl_.start + (sl_.stop - sl_.start - 1) / 2

    for window_slices in self:
        yield tuple(map(_midpoint, window_slices))
[docs]
class Stitcher(ub.NiceRepr):
    """
    From kwarray: v0.6.19

    Stitches multiple possibly overlapping slices into a larger array.

    This is used to invert the SlidingWindow. For semantic segmentation the
    patches are probability chips. Overlapping chips are averaged together.

    SeeAlso:
        :class:`kwarray.RunningStats` - similarly performs running means, but
            can also track other statistics.

    Example:
        >>> # Build a high resolution image and slice it into chips
        >>> highres = np.random.rand(5, 200, 200).astype(np.float32)
        >>> target_shape = (1, 50, 50)
        >>> slider = SlidingWindow(highres.shape, target_shape, overlap=(0, .5, .5))
        >>> # Show how Stitcher can be used to reconstruct the original image
        >>> stitcher = Stitcher(slider.input_shape)
        >>> for sl in list(slider):
        ...     chip = highres[sl]
        ...     stitcher.add(sl, chip)
        >>> assert stitcher.weights.max() == 4, 'some parts should be processed 4 times'
        >>> recon = stitcher.finalize()

    Example:
        >>> # Demo stitching 3 patterns where one has nans
        >>> pat1 = np.full((32, 32), fill_value=0.2)
        >>> pat2 = np.full((32, 32), fill_value=0.4)
        >>> pat3 = np.full((32, 32), fill_value=0.8)
        >>> pat1[:, 16:] = 0.6
        >>> pat2[16:, :] = np.nan
        >>> # Test with nan_policy=omit
        >>> stitcher = Stitcher(shape=(32, 64), nan_policy='omit')
        >>> stitcher[0:32, 0:32](pat1)
        >>> stitcher[0:32, 16:48](pat2)
        >>> stitcher[0:32, 33:64](pat3[:, 1:])
        >>> final1 = stitcher.finalize()
        >>> # Test with nan_policy=propogate
        >>> stitcher = Stitcher(shape=(32, 64), nan_policy='propogate')
        >>> stitcher[0:32, 0:32](pat1)
        >>> stitcher[0:32, 16:48](pat2)
        >>> stitcher[0:32, 33:64](pat3[:, 1:])
        >>> final2 = stitcher.finalize()
        >>> # Checks
        >>> assert np.isnan(final1).sum() == 16, 'only should contain nan where no data was stitched'
        >>> assert np.isnan(final2).sum() == 512, 'should contain nan wherever a nan was stitched'
        >>> # xdoctest: +REQUIRES(--show)
        >>> # xdoctest: +REQUIRES(module:kwplot)
        >>> import kwplot
        >>> import kwimage
        >>> kwplot.autompl()
        >>> kwplot.imshow(pat1, title='pat1', pnum=(3, 3, 1))
        >>> kwplot.imshow(kwimage.nodata_checkerboard(pat2, square_shape=1), title='pat2 (has nans)', pnum=(3, 3, 2))
        >>> kwplot.imshow(pat3, title='pat3', pnum=(3, 3, 3))
        >>> kwplot.imshow(kwimage.nodata_checkerboard(final1, square_shape=1), title='stitched (nan_policy=omit)', pnum=(3, 1, 2))
        >>> kwplot.imshow(kwimage.nodata_checkerboard(final2, square_shape=1), title='stitched (nan_policy=propogate)', pnum=(3, 1, 3))
    """

    def __init__(self, shape, device='numpy', dtype='float32',
                 nan_policy='propogate', memmap=False):
        """
        Args:
            shape (tuple): dimensions of the large image that will be created
                from the smaller pixels or patches.
            device (str | int | torch.device):
                default is 'numpy', but if given as a torch device, then
                underlying operations will be done with torch tensors instead.
            dtype (str):
                the datatype to use in the underlying accumulators. NOTE: only
                used by the numpy backend; torch accumulators are created with
                torch's default dtype.
            nan_policy (str):
                'propogate' (default) lets nans flow into the result. 'omit'
                gives nan entries zero weight when stitching. 'raise' raises a
                ValueError when a patch containing nan is added.
            memmap (bool | PathLike):
                if truthy, the stitcher uses numpy memory maps for its
                accumulators. If a str / PathLike, it is used as the directory
                for the memmap files (created if missing); if True, a fresh
                temporary directory is used. The file paths are recorded in
                ``self.paths``. Only supported with the numpy backend.

        Raises:
            ValueError: if ``nan_policy`` is not recognized.
            NotImplementedError: if ``memmap`` is requested with a torch device.
        """
        # Validate everything up front so an invalid configuration never
        # leaves memmap files or temporary directories behind.
        if nan_policy not in {'propogate', 'omit', 'raise'}:
            raise ValueError(nan_policy)
        use_memmap = bool(memmap)
        is_numpy = device == 'numpy'
        if use_memmap and not is_numpy:
            raise NotImplementedError('cannot do torch memmaping')

        self.nan_policy = nan_policy
        self.shape = shape
        self.device = device
        self.paths = None

        if is_numpy:
            if use_memmap:
                import uuid
                from tempfile import mkdtemp
                uid = uuid.uuid4()
                if isinstance(memmap, (str, os.PathLike)):
                    memmap_dpath = ub.Path(memmap)
                    memmap_dpath.mkdir(parents=True, exist_ok=True)
                else:
                    memmap_dpath = ub.Path(mkdtemp())
                memmap_sums_fpath = memmap_dpath / f'{uid}-sums.npy'
                memmap_weights_fpath = memmap_dpath / f'{uid}-weights.npy'
                self.paths = {
                    'sums': memmap_sums_fpath,
                    'weights': memmap_weights_fpath,
                }
                # Seems to always init to zero (freshly created 'w+' files)
                self.sums = np.memmap(memmap_sums_fpath, dtype=dtype,
                                      mode='w+', shape=shape)
                self.weights = np.memmap(memmap_weights_fpath, dtype=dtype,
                                         mode='w+', shape=shape)
            else:
                self.sums = np.zeros(shape, dtype=dtype)
                self.weights = np.zeros(shape, dtype=dtype)
        else:
            import torch
            self.sums = torch.zeros(shape, device=device)
            self.weights = torch.zeros(shape, device=device)

        if nan_policy in {'omit', 'raise'}:
            # Backend-specific nan detection used by ``add``
            if is_numpy:
                self._isnan = np.isnan
                self._any = np.any
            else:
                self._isnan = torch.isnan
                self._any = torch.any

    def __nice__(self):
        return str(self.sums.shape)

    def add(self, indices, patch, weight=None):
        """
        Incorporate a new (possibly overlapping) patch or pixel using a
        weighted sum.

        Args:
            indices (slice | tuple | None):
                typically a Tuple[slice] of pixels or a single pixel, but this
                can be any numpy fancy index.
            patch (ndarray | Tensor): data to patch into the bigger image.
            weight (float | ndarray): weight of this patch (default to 1.0)

        Raises:
            ValueError: if ``nan_policy`` is 'raise' and ``patch`` has nans.
        """
        if self.nan_policy == 'omit':
            mask = self._isnan(patch)
            if self._any(mask):
                # Detect nans, set weight and value to zero. The conversion
                # and copy APIs differ between numpy arrays and torch tensors.
                keep = ~mask
                if self.device == 'numpy':
                    keep = keep.astype(self.weights.dtype)
                    patch = patch.copy()
                else:
                    keep = keep.to(self.weights.dtype)
                    patch = patch.clone()
                weight = keep if weight is None else weight * keep
                # Zero the nans on a copy so the caller's data is untouched
                patch[mask] = 0
        elif self.nan_policy == 'raise':
            if self._any(self._isnan(patch)):
                raise ValueError('nan_policy is raise')
        if weight is None:
            self.sums[indices] += patch
            self.weights[indices] += 1.0
        else:
            self.sums[indices] += (patch * weight)
            self.weights[indices] += weight

    def __getitem__(self, indices):
        """
        Convenience function to use slice notation directly.

        Returns:
            Callable: ``add`` with ``indices`` bound, i.e.
            ``stitcher[sl](patch)`` is ``stitcher.add(sl, patch)``.
        """
        from functools import partial
        return partial(self.add, indices)

    def average(self):
        """
        Averages out contributions from overlapping adds using weighted average

        Returns:
            ndarray: out - the stitched image
        """
        return self.finalize()

    def finalize(self, indices=None):
        """
        Averages out contributions from overlapping adds

        Locations that never received any weight come out as nan (0 / 0).

        Args:
            indices (None | slice | tuple): if None, finalize the entire
                block, otherwise only finalize a subregion.

        Returns:
            ndarray: final - the stitched image
        """
        if indices is None:
            final = self.sums / self.weights
        else:
            final = self.sums[indices] / self.weights[indices]
        return final
def _slices1d(margin, stop, step=None, start=0, keepbound=False, check=True):
"""
Helper to generates slices in a single dimension.
Args:
margin (int): the length of the slice (window)
stop (int): the length of the image dimension
step (int, default=None): the length of each step / distance between
slices
start (int, default=0): starting point (in most cases set this to 0)
keepbound (bool): if True, a non-uniform step will be taken to ensure
that the right / bottom of the image is returned as a slice if
needed. Such a slice will not obey the overlap constraints.
(Defaults to False)
check (bool): if True an error will be raised if the window does not
cover the entire extent from start to stop, even if keepbound is
True.
Yields:
slice : slice in one dimension of size (margin)
Example:
>>> stop, margin, step = 2000, 360, 360
>>> keepbound = True
>>> strides = list(_slices1d(margin, stop, step, keepbound, check=False))
>>> assert all([(s.stop - s.start) == margin for s in strides])
Example:
>>> stop, margin, step = 200, 46, 7
>>> keepbound = True
>>> strides = list(_slices1d(margin, stop, step, keepbound=False, check=True))
>>> starts = np.array([s.start for s in strides])
>>> stops = np.array([s.stop for s in strides])
>>> widths = stops - starts
>>> assert np.all(np.diff(starts) == step)
>>> assert np.all(widths == margin)
Example:
>>> import pytest
>>> stop, margin, step = 200, 36, 7
>>> with pytest.raises(ValueError):
... list(_slices1d(margin, stop, step))
"""
if step is None:
step = margin
if check:
# see how far off the end we would fall if we didnt check bounds
perfect_final_pos = (stop - start - margin)
overshoot = perfect_final_pos % step
if overshoot > 0:
raise ValueError(
('margin={} and step={} overshoot endpoint={} '
'by {} units when starting from={}').format(
margin, step, stop, overshoot, start))
pos = start
# probably could be more efficient with numpy here
while True:
endpos = pos + margin
yield slice(pos, endpos)
# Stop once we reached the end
if endpos == stop:
break
pos += step
if pos + margin > stop:
if keepbound:
# Ensure the boundary is always used even if steps
# would overshoot Could do some other strategy here
pos = stop - margin
if pos < 0:
break
else:
break