import kwcoco
import ubelt as ub
import torch
import numpy as np
from geowatch.tasks.fusion.methods.network_modules import _class_weights_from_freq
[docs]
class ExtendTorchMixin:
[docs]
def reset_weights(self):
for name, mod in self.named_modules():
if hasattr(mod, 'reset_parameters'):
mod.reset_parameters()
def _device_dict(self):
return {key: item.device for key, item in self.state_dict().items()}
[docs]
def devices(self):
"""
Returns all devices this module state is mounted on
Returns:
Set[torch.device]: set of devices used by this model
"""
state_devices = self._device_dict()
devices = set(state_devices.values())
if hasattr(self, 'device_ids'):
# Handle data parallel
for _id in self.device_ids:
devices.add(torch.device(_id))
return devices
@property
def main_device(self):
"""
The main/src torch device used by this model
"""
if hasattr(self, 'src_device_obj'):
return self.src_device_obj
else:
devices = self.devices()
if len(devices) > 1:
raise NotImplementedError('no information maintained on which device is primary')
else:
return list(devices)[0]
[docs]
class MSIDemoDataMixin:
[docs]
@classmethod
def demo_dataset_stats(cls):
"""
Mock data that mimiks a dataset summary a kwcoco dataloader could
provide.
"""
channels = kwcoco.ChannelSpec.coerce('pan,red|green|blue,nir|swir16|swir22')
unique_sensor_modes = {
('sensor1', 'pan'),
('sensor1', 'red|green|blue'),
('sensor1', 'nir|swir16|swir22'),
}
input_stats = {k: {
'mean': np.random.rand(len(k[1].split('|')), 1, 1),
'std': np.random.rand(len(k[1].split('|')), 1, 1),
} for k in unique_sensor_modes}
classes = kwcoco.CategoryTree.coerce(3)
dataset_stats = {
'unique_sensor_modes': unique_sensor_modes,
'input_stats': input_stats,
'class_freq': {c: np.random.randint(0, 10000) for c in classes},
}
return channels, classes, dataset_stats
[docs]
def demo_batch(self, batch_size=1, num_timesteps=3, width=8, height=8, nans=0, rng=None, new_mode_sample=0):
"""
Example:
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> channels, clases, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>> arch_name='smt_it_stm_p1', tokenizer='linconv',
>>> decoder='mlp', classes=clases, global_saliency_weight=1,
>>> dataset_stats=dataset_stats, input_sensorchan=channels)
>>> batch = self.demo_batch()
>>> if 1:
>>> print(_debug_inbatch_shapes(batch))
>>> result = self.forward_step(batch)
>>> if 1:
>>> print(_debug_inbatch_shapes(batch))
Example:
>>> # With nans
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> channels, clases, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>> arch_name='smt_it_stm_p1', tokenizer='linconv',
>>> decoder='mlp', classes=clases, global_saliency_weight=1,
>>> dataset_stats=dataset_stats, input_sensorchan=channels)
>>> batch = self.demo_batch(nans=0.5, num_timesteps=2)
>>> item = batch[0]
>>> if 1:
>>> print(_debug_inbatch_shapes(batch))
>>> result1 = self.forward_step(batch)
>>> result2 = self.forward_step(batch, with_loss=0)
>>> if 1:
>>> print(_debug_inbatch_shapes(result1))
>>> print(_debug_inbatch_shapes(result2))
Example:
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import * # NOQA
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> channels, clases, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>> arch_name='smt_it_stm_p1', tokenizer='linconv',
>>> decoder='mlp', classes=clases, global_saliency_weight=1,
>>> dataset_stats=dataset_stats, input_sensorchan=channels)
>>> batch = self.demo_batch(new_mode_sample=1)
>>> print(_debug_inbatch_shapes(batch))
"""
import kwarray
from kwarray import distributions
import kwimage
def _specific_coerce(val, rng=None):
# Coerce for what we want to do here,
import numbers
if isinstance(val, numbers.Integral):
distri = distributions.Constant(val, rng=rng)
elif isinstance(val, (tuple, list)) and len(val) == 2:
low, high = val
distri = distributions.DiscreteUniform(low, high, rng=rng)
else:
raise TypeError(val)
return distri
rng = kwarray.ensure_rng(rng)
B = batch_size
C = len(self.classes)
T = num_timesteps
batch = []
width_distri = _specific_coerce(width, rng=rng)
height_distri = _specific_coerce(height, rng=rng)
sensor_to_modes = ub.ddict(list)
for sensor, mode in self.dataset_stats['unique_sensor_modes']:
sensor_to_modes[sensor].append(mode)
import itertools as it
import kwcoco
# sensor_mode_iter = it.cycle(self.dataset_stats['unique_sensor_modes'])
sensor_iter = it.cycle(sensor_to_modes.keys())
# OLD_MODES = 0
for bx in range(B):
modes = []
frames = []
for time_index in range(T):
# Sample output target shape
H0 = height_distri.sample()
W0 = width_distri.sample()
# Sample input shapes
H1 = height_distri.sample()
W1 = width_distri.sample()
H2 = height_distri.sample()
W2 = width_distri.sample()
if new_mode_sample:
sensor = next(sensor_iter)
modes = {}
H, W = H0, W0
for mode in sensor_to_modes[sensor]:
size = kwcoco.FusedChannelSpec.coerce(mode).numel()
modes[mode] = rng.rand(size, H, W).astype("float32")
H = height_distri.sample()
W = width_distri.sample()
else:
sensor = 'sensor1'
modes = {
'pan': rng.rand(1, H0, W0).astype("float32"),
'red|green|blue': rng.rand(3, H1, W1).astype("float32"),
'nir|swir16|swir22': rng.rand(3, H2, W2).astype("float32"),
}
# Add in channels the model exepcts
for stream in self.input_sensorchan.streams():
C = stream.chans.numel()
modes[stream.chans.spec] = rng.rand(C, H1, W1).astype("float32")
frame = {}
if time_index == 0:
frame['change'] = None
frame['change_weights'] = None
frame['change_output_dims'] = None
else:
frame['change'] = rng.randint(low=0, high=1, size=(H0, W0))
frame['change_weights'] = rng.rand(H0, W0)
frame['change_output_dims'] = (H0, W0)
# TODO: allow class-ohe xor class-idxs to be missing
frame['class_idxs'] = rng.randint(low=0, high=C - 1, size=(H0, W0))
frame['class_ohe'] = np.eye(C)[np.random.choice(C, H0 * W0)].reshape(H0, W0, C)
frame['class_weights'] = rng.rand(H0, W0)
frame['class_output_dims'] = (H0, W0)
frame['saliency'] = rng.randint(low=0, high=1, size=(H0, W0))
frame['saliency_weights'] = rng.rand(H0, W0)
frame['saliency_output_dims'] = (H0, W0)
frame['box_ltrb'] = kwimage.Boxes.random(10).to_ltrb().tensor().data
frame['date_captured'] = '',
frame['gid'] = bx
frame['sensor'] = sensor
frame['time_index'] = bx
frame['time_offset'] = np.array([1]),
frame['timestamp'] = 1
frame['modes'] = modes
# specify the desired predicted output size for this frame
frame['output_dims'] = (H0, W0)
if nans:
for v in modes.values():
flags = rng.rand(*v.shape[1:]) < nans
v[:, flags] = float('nan')
if time_index > 0:
frame['change_weights'][flags] = 0
frame['class_weights'][flags] = 0
frame['saliency_weights'][flags] = 0
for k in ['change', 'change_weights', 'class_idxs', 'class_ohe',
'class_weights', 'saliency', 'saliency_weights']:
v = frame[k]
if v is not None:
frame[k] = torch.from_numpy(v)
for k in modes.keys():
v = modes[k]
if v is not None:
modes[k] = torch.from_numpy(v)
frames.append(frame)
positional_tensors = {
'mode_tensor': torch.rand(T, 16),
'sensor': torch.rand(T, 16),
'time_index': torch.rand(T, 8),
'time_offset': torch.rand(T, 1),
}
target = {
'gids': list(range(T)),
'space_slice': [
slice(0, H0),
slice(0, W0),
]
}
item = {
'video_id': 3,
'video_name': 'toy_video_3',
'frames': frames,
'positional_tensors': positional_tensors,
'target': target,
}
batch.append(item)
return batch
[docs]
class LightningModelMixin:
@property
def has_trainer(self):
try:
# Lightning 1.7 raises an attribute error if not attached
return self.trainer is not None
except RuntimeError:
return False
[docs]
class OverfitMixin:
[docs]
def overfit(self, batch):
"""
Overfit script and demo
CommandLine:
python -m xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.overfit --overfit-demo
Example:
>>> # xdoctest: +REQUIRES(--overfit-demo)
>>> # ============
>>> # DEMO OVERFIT:
>>> # ============
>>> from geowatch.tasks.fusion.methods.heterogeneous import * # NOQA
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.utils.util_data import find_dvc_dpath
>>> import geowatch
>>> import kwcoco
>>> from os.path import join
>>> import os
>>> if 0:
>>> '''
>>> # Generate toy datasets
>>> DATA_DPATH=$HOME/data/work/toy_change
>>> TRAIN_FPATH=$DATA_DPATH/vidshapes_msi_train/data.kwcoco.json
>>> mkdir -p "$DATA_DPATH"
>>> kwcoco toydata --key=vidshapes-videos8-frames5-randgsize-speed0.2-msi-multisensor --bundle_dpath "$DATA_DPATH/vidshapes_msi_train" --verbose=5
>>> '''
>>> coco_fpath = ub.expandpath('$HOME/data/work/toy_change/vidshapes_msi_train/data.kwcoco.json')
>>> coco_dset = kwcoco.CocoDataset.coerce(coco_fpath)
>>> channels="B11,r|g|b,B1|B8|B11"
>>> if 1:
>>> dvc_dpath = geowatch.find_dvc_dpath(tags='phase3_data', hardware='auto')
>>> coco_dset = (dvc_dpath / 'Drop8-ARA-Median10GSD-V1') / 'KR_R001/imganns-KR_R001-rawbands.kwcoco.zip'
>>> channels='blue|green|red|nir'
>>> if 0:
>>> coco_dset = geowatch.demo.demo_kwcoco_multisensor(max_speed=0.5)
>>> # coco_dset = 'special:vidshapes8-frames9-speed0.5-multispectral'
>>> #channels='B1|B11|B8|r|g|b|gauss'
>>> channels='X.2|Y:2:6,B1|B8|B8a|B10|B11,r|g|b,disparity|gauss,flowx|flowy|distri'
>>> coco_dset = kwcoco.CocoDataset.coerce(coco_dset)
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>> train_dataset=coco_dset,
>>> chip_size=128, batch_size=1, time_steps=5,
>>> channels=channels,
>>> normalize_peritem='blue|green|red|nir',
>>> normalize_inputs=32, neg_to_pos_ratio=0,
>>> num_workers='avail/2',
>>> mask_low_quality=True,
>>> observable_threshold=0.6,
>>> use_grid_positives=False, use_centered_positives=True,
>>> )
>>> datamodule.setup('fit')
>>> dataset = torch_dset = datamodule.torch_datasets['train']
>>> torch_dset.disable_augmenter = True
>>> dataset_stats = datamodule.dataset_stats
>>> input_sensorchan = datamodule.input_sensorchan
>>> classes = datamodule.classes
>>> print('dataset_stats = {}'.format(ub.urepr(dataset_stats, nl=3)))
>>> print('input_sensorchan = {}'.format(input_sensorchan))
>>> print('classes = {}'.format(classes))
>>> # Choose subclass to test this with (does not cover all cases)
>>> self = methods.HeterogeneousModel(
>>> classes=classes,
>>> dataset_stats=dataset_stats,
>>> input_sensorchan=channels,
>>> #token_dim=708,
>>> #token_dim=768 - 60,
>>> #backbone='vit_B_16_imagenet1k',
>>> #token_dim=208,
>>> #backbone='sits-former',
>>> )
>>> # Choose subclass to test this with (does not cover all cases)
>>> self = methods.MultimodalTransformer(
>>> classes=classes,
>>> dataset_stats=dataset_stats,
>>> input_sensorchan=channels,
>>> #token_dim=708,
>>> #token_dim=768 - 60,
>>> #backbone='vit_B_16_imagenet1k',
>>> #token_dim=208,
>>> #backbone='sits-former',
>>> )
>>> self.datamodule = datamodule
>>> datamodule._notify_about_tasks(model=self)
>>> # Run one visualization
>>> loader = datamodule.train_dataloader()
>>> # Load one batch and show it before we do anything
>>> batch = next(iter(loader))
>>> print(ub.urepr(dataset.summarize_item(batch[0]), nl=3))
>>> import kwplot
>>> plt = kwplot.autoplt(force='Qt5Agg')
>>> plt.ion()
>>> canvas = datamodule.draw_batch(batch, max_channels=5, overlay_on_image=0)
>>> kwplot.imshow(canvas, fnum=1)
>>> # Run overfit
>>> device = 0
>>> self.overfit(batch)
"""
import kwplot
# import torch_optimizer
import xdev
import kwimage
import pandas as pd
from torch.optim import AdamW
from kwutil.slugify_ext import smart_truncate
from kwplot.mpl_make import render_figure_to_image
sns = kwplot.autosns()
datamodule = self.datamodule
device = 0
self = self.to(device)
# loader = datamodule.train_dataloader()
# batch = next(iter(loader))
walker = ub.IndexableWalker(batch)
for path, val in walker:
if isinstance(val, torch.Tensor):
walker[path] = val.to(device)
outputs = self.training_step(batch)
max_channels = 3
canvas = datamodule.draw_batch(batch, outputs=outputs, max_channels=max_channels, overlay_on_image=0)
kwplot.imshow(canvas)
loss_records = []
loss_records = [g[0] for g in ub.group_items(loss_records, lambda x: x['step']).values()]
step = 0
_frame_idx = 0
# dpath = ub.ensuredir('_overfit_viz09')
optimizer = AdamW(lr=3e-4, weight_decay=3e-6, params=self.parameters())
# optim = torch_optimizer.RAdam(self.parameters(), lr=self.learning_rate, weight_decay=1e-5)
fnum = 2
fig = kwplot.figure(fnum=fnum, doclf=True)
fig.set_size_inches(15, 6)
fig.subplots_adjust(left=0.05, top=0.9)
prev = None
for _frame_idx in xdev.InteractiveIter(list(range(_frame_idx + 1, 1000))):
# for _frame_idx in list(range(_frame_idx, 1000)):
num_steps = 20
ex = None
for _i in ub.ProgIter(range(num_steps), desc='overfit'):
optimizer.zero_grad()
outputs = self.training_step(batch)
# outputs['item_losses']
loss = outputs['loss']
if torch.any(torch.isnan(loss)):
print('NAN OUTPUT!!!')
print('loss = {!r}'.format(loss))
print('prev = {!r}'.format(prev))
ex = Exception('prev = {!r}'.format(prev))
break
# elif loss > 1e4:
# # Turn down the learning rate when loss gets huge
# scale = (loss / 1e4).detach()
# loss /= scale
prev = loss
# item_losses_ = default_collate(outputs['item_losses'])
# item_losses = ub.map_vals(lambda x: sum(x).item(), item_losses_)
loss.backward()
item_losses = {'loss': loss.detach().cpu().numpy().ravel().mean()}
loss_records.extend([{'part': key, 'val': val, 'step': step} for key, val in item_losses.items()])
optimizer.step()
step += 1
canvas = datamodule.draw_batch(batch, outputs=outputs, max_channels=max_channels, overlay_on_image=0, max_items=4)
kwplot.imshow(canvas, pnum=(1, 2, 1), fnum=fnum)
fig = kwplot.figure(fnum=fnum, pnum=(1, 2, 2))
#kwplot.imshow(canvas, pnum=(1, 2, 1))
ax = sns.lineplot(data=pd.DataFrame(loss_records), x='step', y='val', hue='part')
try:
ax.set_yscale('logit')
except Exception:
...
fig.suptitle(smart_truncate(str(optimizer).replace('\n', ''), max_length=64))
img = render_figure_to_image(fig)
img = kwimage.convert_colorspace(img, src_space='bgr', dst_space='rgb')
# fpath = join(dpath, 'frame_{:04d}.png'.format(_frame_idx))
#kwimage.imwrite(fpath, img)
xdev.InteractiveIter.draw()
if ex:
raise ex
# TODO: can we get this batch to update in real time?
# TODO: start a server process that listens for new images
# as it gets new images, it starts playing through the animation
# looping as needed
[docs]
class PackageMixin:
def _save_package(self, package_path, context=None, verbose=1):
"""
We define this as a protected method to allow modules to reuse the core
code, but force each module to define the ``save_package`` method
themselves with a doctest. In the future if this logic is general we
may remove that restriction and refactor tests to be part of unit
tests.
"""
# import copy
import json
import torch.package
# Fix an issue on 3.10 with torch 1.12
from geowatch.monkey import monkey_torch
monkey_torch.fix_package_modules()
# shallow copy of self, to apply attribute hacks to
# model = copy.copy(self)
model = self
backup_attributes = {}
# Remove attributes we don't want to pickle before we serialize
# then restore them
unsaved_attributes = [
'trainer',
'train_dataloader',
'val_dataloader',
'test_dataloader',
'_load_state_dict_pre_hooks', # lightning 1.5
'_trainer', # lightning 1.7
]
for key in unsaved_attributes:
try:
val = getattr(model, key, None)
except Exception:
val = None
if val is not None:
backup_attributes[key] = val
train_dpath_hint = getattr(model, 'train_dpath_hint', None)
if model.has_trainer:
if train_dpath_hint is None:
train_dpath_hint = model.trainer.log_dir
datamodule = model.trainer.datamodule
if datamodule is not None:
model.datamodule_hparams = datamodule.hparams
metadata_fpaths = []
if train_dpath_hint is not None:
train_dpath_hint = ub.Path(train_dpath_hint)
metadata_fpaths += list(train_dpath_hint.glob('hparams.yaml'))
metadata_fpaths += list(train_dpath_hint.glob('fit_config.yaml'))
metadata_fpaths += list(train_dpath_hint.glob('config.yaml'))
try:
for key in backup_attributes.keys():
setattr(model, key, None)
arch_name = 'model.pkl'
module_name = 'watch_tasks_fusion'
"""
exp = torch.package.PackageExporter(package_path, debug=True)
"""
import warnings
import sys
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
# with torch.package.PackageExporter(package_path) as exp:
with torch.package.PackageExporter(package_path) as exp:
# if True:
# TODO: this is not a problem yet, but some package types (mainly
# binaries) will need to be excluded and added as mocks
exp.extern('**', exclude=[
'geowatch.tasks.fusion.**',
'geowatch.tasks.fusion.methods.*'
])
# exp.intern('geowatch.tasks.fusion.methods.*', allow_empty=False)
exp.intern('geowatch.tasks.fusion.**', allow_empty=False)
# Attempt to standardize some form of package metadata that can
# allow for model importing with fewer hard-coding requirements
# TODO:
# Add information about how this was trained, and what epoch it
# was saved at.
package_header = {
'version': '0.4.0',
'arch_name': arch_name,
'module_name': module_name,
'packaging_time': ub.timestamp(),
'git_hash': None,
'module_path': None,
'torch_version': torch.__version__,
'python_version': sys.version,
'context': context,
}
# Encode a git hash if we can identify that we are in a git
# repository
try:
import os
module_path = ub.Path(ub.modname_to_modpath(self.__class__.__module__)).absolute()
package_header['module_path'] = os.fspath(module_path)
info = ub.cmd('git rev-parse --short HEAD', cwd=module_path.parent)
if info.returncode == 0:
package_header['git_hash'] = info.stdout.strip()
except Exception:
...
exp.save_text(
'package_header', 'package_header.json',
json.dumps(package_header)
)
exp.save_pickle(module_name, arch_name, model)
# Save metadata
for meta_fpath in metadata_fpaths:
with open(meta_fpath, 'r') as file:
text = file.read()
exp.save_text('package_header', meta_fpath.name, text)
finally:
# restore attributes
for key, val in backup_attributes.items():
setattr(model, key, val)
[docs]
@classmethod
def load_package(cls, package_path, verbose=1):
"""
DEPRECATE IN FAVOR OF geowatch.tasks.fusion.utils.load_model_from_package
TODO:
- [ ] Make the logic that defines the save_package and load_package
methods with appropriate package header data a lightning
abstraction.
"""
# NOTE: there is no gaurentee that this loads an instance of THIS
# model, the model is defined by the package and the tool that loads it
# is agnostic to the model contained in said package.
# This classmethod existing is a convinience more than anything else
from geowatch.tasks.fusion.utils import load_model_from_package
self = load_model_from_package(package_path)
return self
[docs]
class CoerceMixins:
def _coerce_class_weights(self, class_weights):
"""
Handles automatic class weighting based on dataset stats.
Args:
class_weights (str | FloatTensor):
If already a tensor does nothing. If the string "auto" then
class frequency weighting is used. The string "auto" can be
suffixed with a "class modulation code".
Note:
A class modulate code is a a special syntax that lets the user
modulate automatically computed class weights. Should be a comma
separated list of name*weight or name*weight+offset. E.g.
`auto:negative*0,background*0.001,No Activity*0.1+1`
Example:
>>> # xdoctest: +IGNORE_WANT
>>> from geowatch.tasks.fusion.methods.watch_module_mixins import * # NOQA
>>> self = WatchModuleMixins()
>>> self.classes = ['a', 'b', 'c', 'd', 'e']
>>> self.class_freq = {
>>> 'a': 100, 'b': 100, 'c': 100, 'd': 100, 'e': 100, 'f': 100,
>>> }
>>> self._coerce_class_weights('auto')
tensor([1., 1., 1., 1., 1.])
>>> self.class_freq = {
>>> 'a': 100, 'b': 100, 'c': 200, 'd': 300, 'e': 500, 'f': 800,
>>> }
>>> self._coerce_class_weights('auto')
tensor([1.0000, 1.0000, 0.5000, 0.3333, 0.2000])
>>> self._coerce_class_weights('auto:a+1,b*2,c*0+31415')
tensor([2.0000e+00, 2.0000e+00, 3.1415e+04, 3.3333e-01, 2.0000e-01])
"""
from geowatch import heuristics
hueristic_ignore_keys = heuristics.IGNORE_CLASSNAMES
if isinstance(class_weights, str):
if class_weights.startswith('auto'):
if ':' in class_weights:
class_weights, modulate_class_weights = class_weights.split(':')
else:
modulate_class_weights = None
class_weights.split(':')
if self.class_freq is None:
heuristic_weights = {}
else:
class_freq = ub.udict(self.class_freq) - hueristic_ignore_keys
total_freq = np.array(list(class_freq.values()))
cat_weights = _class_weights_from_freq(total_freq)
catnames = list(class_freq.keys())
print('total_freq = {!r}'.format(total_freq))
print('cat_weights = {!r}'.format(cat_weights))
print('catnames = {!r}'.format(catnames))
heuristic_weights = ub.dzip(catnames, cat_weights)
#print('heuristic_weights = {}'.format(ub.urepr(heuristic_weights, nl=1)))
heuristic_weights.update({k: 0 for k in hueristic_ignore_keys})
print('heuristic_weights = {}'.format(ub.urepr(heuristic_weights, nl=1, align=':')))
class_weights = []
for catname in self.classes:
w = heuristic_weights.get(catname, 1.0)
class_weights.append(w)
using_class_weights = ub.dzip(self.classes, class_weights)
# Add in user-specific modulation of the weights
if modulate_class_weights:
import re
parts = [p.strip() for p in modulate_class_weights.split(',')]
parts = [p for p in parts if p]
for part in parts:
toks = re.split('([+*])', part)
catname = toks[0]
rest_iter = iter(toks[1:])
weight = using_class_weights[catname]
nrhtoks = len(toks) - 1
assert nrhtoks % 2 == 0
nstmts = nrhtoks // 2
for _ in range(nstmts):
opcode = next(rest_iter)
arg = float(next(rest_iter))
if opcode == '*':
weight = weight * arg
elif opcode == '+':
weight = weight + arg
else:
raise KeyError(opcode)
# Modulate
using_class_weights[catname] = weight
print('using_class_weights = {}'.format(ub.urepr(using_class_weights, nl=1, align=':')))
class_weights = [
using_class_weights.get(catname, 1.0)
for catname in self.classes
if catname not in hueristic_ignore_keys
]
class_weights = torch.FloatTensor(class_weights)
else:
raise KeyError(class_weights)
else:
raise NotImplementedError(f'{class_weights!r}')
return class_weights
def _coerce_saliency_weights(self, saliency_weights):
"""
Finds weights to balance saliency forward / background classes.
Args:
saliency_weights (Tensor | str | None):
Can be None, a raw tensor, "auto", or a string ``"<bg>:<fg>"``.
Can also accept a YAML mapping from the keys "bg" and "fg" to
their respective float weights, e.g. ``"{fg: 1, bg: 2}"``.
Returns:
Tensor:
A length-2 tensor where index 0 contains the background weight
and index 1 contains the foreground weight.
CommandLine:
xdoctest -m geowatch.tasks.fusion.methods.watch_module_mixins WatchModuleMixins._coerce_saliency_weights
Example:
>>> # xdoctest: +IGNORE_WANT
>>> from geowatch.tasks.fusion.methods.watch_module_mixins import * # NOQA
>>> self = WatchModuleMixins()
>>> self.saliency_num_classes = 2
>>> self.background_classes = ['a', 'b', 'c']
>>> self.foreground_classes = ['d', 'e']
>>> self.class_freq = {
>>> 'a': 100, 'b': 100, 'c': 100, 'd': 100, 'e': 100, 'f': 100,
>>> }
>>> self._coerce_saliency_weights('auto')
tensor([1.0000, 1.4925])
>>> self.background_classes = ['a', 'b', 'c']
>>> self.foreground_classes = []
>>> self._coerce_saliency_weights('auto')
tensor([ 1., 300.])
>>> self.background_classes = []
>>> self.foreground_classes = []
>>> self._coerce_saliency_weights('auto')
tensor([1., 0.])
>>> self._coerce_saliency_weights('2:1')
tensor([2., 1.])
>>> self._coerce_saliency_weights('70:20')
tensor([70., 20.])
>>> self._coerce_saliency_weights('{fg: 1, bg: 2}')
tensor([2., 1.])
>>> self._coerce_saliency_weights({'fg': 1, 'bg': 2})
tensor([2., 1.])
>>> self._coerce_saliency_weights(ub.codeblock(
'''
foreground: 3
background: 5
'''))
tensor([5., 3.])
>>> import pytest
>>> with pytest.raises(ValueError):
>>> self._coerce_saliency_weights(ub.codeblock(
'''
foreground2: 3
background: 5
'''))
>>> with pytest.raises(Exception):
>>> self._coerce_saliency_weights(123)
"""
from kwutil.util_yaml import Yaml
saliency_weights_ = Yaml.coerce(saliency_weights)
# Provide aliases for the primary keys "fg" and "bg".
# This allows the user to user more explicit aliases. This also allows
# us to change the internal representation while maintaining backwards
# compatability.
primary_key_to_aliases = {
'bg': ['background'],
'fg': ['foreground'],
}
if isinstance(saliency_weights_, dict):
for primary_key, aliases in primary_key_to_aliases.items():
for alias_key in aliases:
if alias_key in saliency_weights_:
saliency_weights_[primary_key] = saliency_weights_.pop(alias_key)
try:
if saliency_weights_ is None:
saliency_weights_ = {'bg': 1.0, 'fg': 1.0}
elif isinstance(saliency_weights_, str):
if saliency_weights_ == 'auto':
class_freq = self.class_freq
saliency_weights_ = {'bg': 1.0, 'fg': 1.0}
if class_freq is not None:
print(f'class_freq={class_freq}')
bg_freq = sum(class_freq.get(k, 0) for k in self.background_classes)
fg_freq = sum(class_freq.get(k, 0) for k in self.foreground_classes)
saliency_weights_['fg'] = bg_freq / (fg_freq + 1)
elif saliency_weights_.lower() in {'null', 'none'}:
saliency_weights_ = {'bg': 1.0, 'fg': 1.0}
else:
bg_weight, fg_weight = saliency_weights_.split(':')
saliency_weights_ = {
'bg': float(bg_weight.strip()),
'fg': float(fg_weight.strip()),
}
elif isinstance(saliency_weights_, dict):
...
else:
raise TypeError('Unexpected saliency weights type')
except Exception as ex:
from kwutil.util_exception import add_exception_note
notes = f'saliency_weights : {type(saliency_weights)} = {saliency_weights!r}'
new_ex = add_exception_note(ex, notes)
raise new_ex
print(f'saliency_weights_ = {ub.urepr(saliency_weights_, nl=1)}')
if set(primary_key_to_aliases) != set(saliency_weights_):
raise ValueError(ub.paragraph(
'''
Saliency weights must contain values for each primary key or
use one of the known aliases. Valid primary keys and their
aliases are: {primary_key_to_aliases!r}. The user input for
saliency_weights={saliency_weights!r} and this was interpreted
as: {saliency_weights_!r}.
'''))
bg_fg_weights = [
saliency_weights_.get('bg', 1.0),
saliency_weights_.get('fg', 1.0)
]
# What is the motivation for having "saliency_num_classes" be not 2?
_n = self.saliency_num_classes
_w = bg_fg_weights + ([0.0] * (_n - len(bg_fg_weights)))
final_saliency_weights = torch.Tensor(_w)
return final_saliency_weights
[docs]
class DatasetStatsMixin:
[docs]
def set_dataset_specific_attributes(self, input_sensorchan, dataset_stats):
"""
Set module attributes based on dataset stats it will be trained on.
Args:
input_sensorchan (str | kwcoco.SensorchanSpec | None):
The input sensor channels the model should expect
dataset_stats (Dict | None):
See :func:`demo_dataset_stats` for an example of this structure
Returns:
None | Dict: input_stats
The following attributes will be set after calling this method.
* self.class_freq
* self.dataset_stats
* self.input_sensorchan
* self.unique_sensor_modes
We also return an ``input_stats`` variable which should be used for
setting model-dependent handling of input normalization.
The handling of dataset_stats and input_sensorchan are weirdly coupled
for legacy reasons and duplicated across several modules. This is a
common location for that code to allow it to be more easily refactored
and simplified at a later date.
"""
if dataset_stats is not None:
input_stats = dataset_stats['input_stats']
class_freq = dataset_stats['class_freq']
if input_sensorchan is None:
input_sensorchan = ','.join(
[f'{s}:{c}' for s, c in dataset_stats['unique_sensor_modes']])
else:
class_freq = None
input_stats = None
# Handle channel-wise input mean/std in the network (This is in
# contrast to common practice where it is done in the dataloader)
if input_sensorchan is None:
raise Exception(
'need to specify input_sensorchan at least as the number of '
f'input channels. Got input_sensorchan={input_sensorchan!r} and '
f'dataset_stats={dataset_stats!r}')
input_sensorchan = kwcoco.SensorChanSpec.coerce(input_sensorchan)
if dataset_stats is None:
# Handle the case where we know what the input streams are, but not
# what their statistics are.
input_stats = None
unique_sensor_modes = {
(s.sensor.spec, s.chans.spec)
for s in input_sensorchan.streams()
}
else:
unique_sensor_modes = dataset_stats['unique_sensor_modes']
if True:
# hack None to '*'
_fixed_sensor_modes = []
for k in unique_sensor_modes:
if isinstance(k, str):
if k == '*':
s = c = '*'
else:
raise AssertionError
else:
s, c = k
if s is None:
s = '*'
_fixed_sensor_modes.append((s, c))
unique_sensor_modes = _fixed_sensor_modes
if input_stats is not None:
_fixed_input_stats = {}
for k, v in input_stats.items():
s, c = k
if s is None:
s = '*'
_fixed_input_stats[(s, c)] = v
input_stats = _fixed_input_stats
self.class_freq = class_freq
self.dataset_stats = dataset_stats
self.unique_sensor_modes = unique_sensor_modes
self.input_sensorchan = input_sensorchan
return input_stats
[docs]
class WatchModuleMixins(MSIDemoDataMixin, ExtendTorchMixin,
LightningModelMixin, CoerceMixins, PackageMixin,
OverfitMixin, DatasetStatsMixin, DeprecatedMixin):
"""
Mixin methods for geowatch lightning modules
"""