"""
Defines a lightning DataModule for kwcoco video data.
The parameters to each are handled by scriptconfig objects, which prevents us
from needing to specify what the available options are in multiple places.
"""
import kwcoco
import kwimage
# import ndsampler
import pytorch_lightning as pl
import ubelt as ub
import scriptconfig as scfg
from kwcoco_dataloader.utils import util_globals
from kwutil import util_parallel
from kwcoco_dataloader import heuristics
from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_dataset import KWCocoVideoDatasetConfig, KWCocoVideoDataset
from kwcoco_dataloader.tasks.fusion.datamodules.batch_visualization import _memo_legend
from typing import Dict
# ``profile`` is a decorator hook for line profiling. It resolves to
# ``xdev.profile`` when the optional ``xdev`` development package can be
# imported, and otherwise stays a no-op (``ub.identity``).
profile = ub.identity
try:
    import xdev
    profile = xdev.profile
except Exception:
    # xdev is optional; keep the no-op fallback on any import/lookup failure.
    pass
class KWCocoVideoDataModuleConfig(KWCocoVideoDatasetConfig):
    """
    These are the arguments accepted by the KWCocoVideoDataModule.

    This extends :class:`KWCocoVideoDatasetConfig`; the options it shares
    with that class are forwarded to the underlying train / vali / test
    datasets by the datamodule.

    The scriptconfig class is not used directly as it normally would be here.
    Instead we use it as a convenience to minimize lightning boilerplate later
    when it constructs its own argparse object, and for handling arguments
    passed directly to the KWCocoVideoDataModule.

    In the future this might be convertible to, or handled by, omegaconf.
    """
    train_dataset = scfg.Value(None, help='path to the train kwcoco file', group='datasets')
    vali_dataset = scfg.Value(None, help='path to the validation kwcoco file', group='datasets')
    test_dataset = scfg.Value(None, help='path to the test kwcoco file', group='datasets')

    batch_size = scfg.Value(4, type=int, help=(
        'number of items in each batch produced by the dataloaders'))

    pin_memory = scfg.Value(True, isflag=True, type=bool, help=ub.paragraph(
        '''
        Can increase speed, but is potentially unstable. For details,
        see https://pytorch.org/docs/stable/data.html#memory-pinning
        '''
    ))

    normalize_inputs = scfg.Value(True, help=ub.paragraph(
        '''
        If True, computes the mean/std for this dataset on each mode
        so this can be passed to the model.
        If set to a number it will only draw that many samples to estimate
        the mean/std. If set to "transfer", the datamodule does not compute
        these statistics and they are expected to be transferred from an
        existing model.
        '''))

    num_workers = scfg.Value(4, type=str, alias=['workers'], help=ub.paragraph(
        '''
        number of background workers. Can be "auto" or an avail
        expression.
        '''))

    request_rlimit_nofile = scfg.Value('auto', help=ub.paragraph(
        '''
        As a convenience, on Linux systems this automatically requests that
        ulimit raises the maximum number of open files allowed. Auto currently
        simply sets this to 8192, so use a number higher than this if you run
        into too many open file errors, or set your ulimit explicitly before
        running this software.
        '''), group='resources')

    # NOTE: ub.paragraph joins lines with spaces, so URLs must stay on a
    # single source line or they get split in the rendered help text.
    torch_sharing_strategy = scfg.Value('default', help=ub.paragraph(
        '''
        Torch multiprocessing sharing strategy. Can be "default",
        "file_descriptor", or "file_system". On Linux, the default is
        "file_descriptor". See
        https://pytorch.org/docs/stable/multiprocessing.html#sharing-strategies
        for descriptions of options. When using an sqlview (e.g.
        sqlview="sqlite"), using "file_system" can help prevent the
        "received 0 items of ancdata" error. It is unclear why using
        "file_descriptor" fails in this case for some datasets.
        '''), group='resources')

    torch_start_method = scfg.Value('default', help=ub.paragraph(
        '''
        Torch multiprocessing start method. Can be "default", "fork",
        "spawn", or "forkserver". On Linux, Python's default start method
        is "fork" (changed to "forkserver" in Python 3.14); on macOS and
        Windows it is "spawn".
        '''), group='resources')

    sampler_backend = scfg.Value(None, help=ub.paragraph(
        '''
        Can be None, 'npy', or 'cog'.
        '''))

    test_with_annot_info = scfg.Value(False, isflag=1, help=ub.paragraph(
        '''
        If True, the test dataset is allowed to use annotations to refine the
        sampling. This is useful at predict time for drawing batches.
        '''))

    sqlview = scfg.Value(False, help=ub.paragraph(
        '''
        If False, reads the COCO dataset as a json file. Otherwise
        it can be "sqlite" or "postgresql" to cache the json file in an SQL
        database for faster response times and a lower memory
        footprint.
        '''))
class KWCocoVideoDataModule(pl.LightningDataModule):
    """
    Prepare the kwcoco dataset as torch video datamodules

    Arguments are declared once in :class:`KWCocoVideoDataModuleConfig`;
    see that class for the accepted options.

    Example:
        >>> # Demo of the data module on auto-generated toy data
        >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
        >>> import kwcoco_dataloader
        >>> import kwcoco
        >>> coco_dset = kwcoco_dataloader.coerce_kwcoco('vidshapes8-kwcoco_dataloader')
        >>> channels = None
        >>> batch_size = 1
        >>> time_steps = 3
        >>> chip_size = 416
        >>> self = KWCocoVideoDataModule(
        >>>     train_dataset=coco_dset,
        >>>     test_dataset=None,
        >>>     batch_size=batch_size,
        >>>     normalize_inputs=8,
        >>>     channels=channels,
        >>>     num_workers=0,
        >>>     time_steps=time_steps,
        >>>     chip_size=chip_size,
        >>>     neg_to_pos_ratio=0,
        >>> )
        >>> self.setup('fit')
        >>> dl = self.train_dataloader()
        >>> dataset = dl.dataset
        >>> batch = next(iter(dl))
        >>> batch = [dl.dataset[0]]
        >>> # Visualize
        >>> canvas = self.draw_batch(batch)
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> kwplot.autompl()
        >>> kwplot.imshow(canvas)
        >>> kwplot.show_if_requested()

    Example:
        >>> # xdoctest: +REQUIRES(env:DVC_DPATH)
        >>> # Run the following tests on real kwcoco_dataloader data if DVC is available
        >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
        >>> import kwcoco_dataloader
        >>> import kwcoco
        >>> dvc_dpath = kwcoco_dataloader.find_dvc_dpath()
        >>> coco_fpath = dvc_dpath / 'Drop2-Aligned-TA1-2022-02-15/combo_ILM.kwcoco.json'
        >>> #coco_fpath = dvc_dpath / 'Aligned-Drop2-TA1-2022-03-07/combo_DILM.kwcoco.json'
        >>> #coco_fpath = dvc_dpath / 'Drop2-Aligned-TA1-2022-02-15/combo_DILM.kwcoco.json'
        >>> dset = kwcoco.CocoDataset(coco_fpath)
        >>> images = dset.images()
        >>> train_dataset = dset
        >>> #sub_images = dset.videos(names=['KR_R002']).images[0]
        >>> #train_dataset = dset.subset(sub_images.lookup('id'))
        >>> test_dataset = None
        >>> img = ub.peek(train_dataset.imgs.values())
        >>> # NOTE: kwcoco_extensions is not imported in this example
        >>> #chan_info = kwcoco_extensions.coco_channel_stats(dset)
        >>> #channels = chan_info['common_channels']
        >>> channels = 'blue|green|red|nir|swir16|swir22,forest|bare_ground,matseg_0|matseg_1|matseg_2,invariants.0:3,cloudmask'
        >>> #channels = 'blue|green|red|depth'
        >>> #channels = None
        >>> #
        >>> batch_size = 1
        >>> time_steps = 8
        >>> chip_size = 512
        >>> datamodule = KWCocoVideoDataModule(
        >>>     train_dataset=train_dataset,
        >>>     test_dataset=test_dataset,
        >>>     batch_size=batch_size,
        >>>     channels=channels,
        >>>     num_workers=0,
        >>>     normalize_inputs=8,
        >>>     time_steps=time_steps,
        >>>     chip_size=chip_size,
        >>>     neg_to_pos_ratio=0,
        >>>     min_spacetime_weight=0.5,
        >>> )
        >>> datamodule.setup('fit')
        >>> dl = datamodule.train_dataloader()
        >>> dataset = dl.dataset
        >>> dataset.requested_tasks['change'] = False
        >>> dataset.disable_augmenter = True
        >>> target = 0
        >>> item, *_ = batch = [dataset[target]]
        >>> #item, *_ = batch = next(iter(dl))
        >>> # Visualize
        >>> canvas = datamodule.draw_batch(batch)
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> kwplot.autompl()
        >>> kwplot.imshow(canvas, doclf=1)
        >>> kwplot.show_if_requested()

    Example:
        >>> # xdoctest: +SKIP
        >>> # NOTE: I DONT KNOW WHY THIS IS FAILING ON CI AT THE MOMENT. FIXME!
        >>> # Run the data module on coco demo datamodules for the CI
        >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
        >>> import kwcoco
        >>> import delayed_image
        >>> train_dataset = kwcoco.CocoDataset.demo('vidshapes2-multispectral', num_frames=5)
        >>> test_dataset = kwcoco.CocoDataset.demo('vidshapes1-multispectral', num_frames=5)
        >>> channels = '|'.join([aux['channels'] for aux in train_dataset.imgs[1]['auxiliary']])
        >>> chan_spec = delayed_image.channel_spec.FusedChannelSpec.coerce(channels)
        >>> #
        >>> batch_size = 2
        >>> time_steps = 3
        >>> chip_size = 128
        >>> channels = channels
        >>> self = KWCocoVideoDataModule(
        >>>     train_dataset=train_dataset,
        >>>     test_dataset=test_dataset,
        >>>     batch_size=batch_size,
        >>>     channels=channels,
        >>>     num_workers=0,
        >>>     time_steps=time_steps,
        >>>     chip_size=chip_size,
        >>>     normalize_inputs=True,
        >>> )
        >>> self.setup('fit')
        >>> dl = self.train_dataloader()
        >>> item, *_ = batch = next(iter(dl))
        >>> expect_shape = (batch_size, time_steps, len(chan_spec), chip_size, chip_size)
        >>> assert len(batch) == batch_size
        >>> for item in batch:
        ...     assert len(item['frames']) == time_steps
        ...     for mode_key, mode_val in item['frames'][0]['modes'].items():
        ...         assert mode_val.shape[1:3] == (chip_size, chip_size)
    """
    __scriptconfig__ = KWCocoVideoDataModuleConfig

    def __init__(self, verbose=1, **kwargs):
        """
        For details on accepted arguments see KWCocoVideoDataModuleConfig

        Args:
            verbose (int): if truthy, print diagnostic information during
                construction and setup.
            **kwargs: forwarded to :class:`KWCocoVideoDataModuleConfig`.
        """
        super().__init__()
        self.verbose = verbose
        self.config = KWCocoVideoDataModuleConfig(**kwargs)
        cfgdict = self.config.to_dict()
        self.save_hyperparameters(cfgdict)
        # Backwards compatibility. Previous iterations had the
        # config saved directly as datamodule arguments
        self.__dict__.update(cfgdict)

        self.train_kwcoco = self.config['train_dataset']
        self.vali_kwcoco = self.config['vali_dataset']
        self.test_kwcoco = self.config['test_dataset']

        if ub.WIN32:
            # Presumably normalizes MSYS-style paths given on Windows.
            from kwutil import util_windows
            self.train_kwcoco = util_windows.fix_msys_path(self.train_kwcoco)
            self.vali_kwcoco = util_windows.fix_msys_path(self.vali_kwcoco)
            self.test_kwcoco = util_windows.fix_msys_path(self.test_kwcoco)

        common_keys = set(KWCocoVideoDatasetConfig.__default__.keys())
        # Pass the relevant parts of the config to the underlying datasets
        self.train_dataset_config = ub.dict_subset(cfgdict, common_keys)
        # with small changes made for validation and test datasets.
        self.vali_dataset_config = self.train_dataset_config.copy()
        self.vali_dataset_config['chip_overlap'] = 0.0
        # TODO: reconsider this hard-coded decision. It may bias our validation
        # check towards too many false positives. That is what we want if we
        # are having trouble there, but that setting should be configurable.
        self.vali_dataset_config['neg_to_pos_ratio'] = 0.0
        self.vali_dataset_config['use_grid_positives'] = True
        self.vali_dataset_config['use_centered_positives'] = False

        self.test_dataset_config = self.train_dataset_config.copy()
        self.test_dataset_config['test_with_annot_info'] = self.config['test_with_annot_info']

        self.num_workers = util_parallel.coerce_num_workers(cfgdict['num_workers'])
        self.dataset_stats = None

        # will only correspond to train
        self.classes = None
        self.predictable_classes = None
        # self.input_channels = None
        self.input_sensorchan = None

        # Can we get rid of inject method?
        # Unfortunately lightning seems to only enable / disables
        # validation depending on the methods that are defined, so we are
        # not able to statically define them.
        ub.inject_method(self, lambda self: self._make_dataloader('train', shuffle=True, pin_memory=self.config['pin_memory']), 'train_dataloader')

        # Store train / test / vali
        self.torch_datasets: Dict[str, KWCocoVideoDataset] = {}
        self.coco_datasets: Dict[str, kwcoco.CocoDataset] = {}

        self.requested_tasks = None
        self.did_setup = False

        if self.verbose:
            # NOTE: time_steps, chip_dims, *_space_scale are presumably
            # inherited dataset-config options set by the __dict__ update.
            print('Init KWCocoVideoDataModule')
            print('self.train_kwcoco = {!r}'.format(self.train_kwcoco))
            print('self.vali_kwcoco = {!r}'.format(self.vali_kwcoco))
            print('self.test_kwcoco = {!r}'.format(self.test_kwcoco))
            print('self.input_sensorchan = {!r}'.format(self.input_sensorchan))
            print('self.time_steps = {!r}'.format(self.time_steps))
            print('self.chip_dims = {!r}'.format(self.chip_dims))
            print('self.window_space_scale = {!r}'.format(self.window_space_scale))
            print('self.input_space_scale = {!r}'.format(self.input_space_scale))
            print('self.output_space_scale = {!r}'.format(self.output_space_scale))
            print(f'self.num_workers={self.num_workers!r}')

    def setup(self, stage):
        """
        Read the kwcoco files and build the torch datasets for ``stage``.

        Args:
            stage (str | None): "fit" / "train" builds the train dataset (and
                the validation dataset if one was configured); "test" builds
                the test dataset; None builds all of them.

        Note:
            Only the first call does any work; later calls are ignored until
            ``self.did_setup`` is reset.

            NOTE(review): Lightning can also pass "validate" or "predict";
            neither builds any dataset here, but ``did_setup`` is still set.
        """
        if self.did_setup:
            print('datamodules are already setup. Ignoring extra setup call')
            return

        import kwcoco_dataloader
        if self.verbose:
            print('Setup DataModule: stage = {!r}'.format(stage))

        util_globals.configure_global_attributes(**{
            'num_workers': self.num_workers,
            'torch_sharing_strategy': self.torch_sharing_strategy,
            'torch_start_method': self.torch_start_method,
            'request_rlimit_nofile': self.request_rlimit_nofile,
        })

        sqlview = self.config['sqlview']

        # Clear existing coco datasets so a reload occurs (should never happen
        # if the user doesnt touch `self.did_setup`).
        self.coco_datasets.clear()

        # make a temp mapping from train/vali/test to the specified coco inputs
        _coco_inputs = {
            'train': self.train_kwcoco,
            'vali': self.vali_kwcoco,
            'test': self.test_kwcoco,
        }

        def _read_kwcoco_split(_key):
            """
            Quick and dirty helper originally used to debug an issue. Keeping
            something similar to ensure train/test/vali kwcoco are read in the
            same way.

            This modifies the self.coco_datasets attribute.
            """
            _coco_input = _coco_inputs[_key]
            _coco_output = self.coco_datasets.get(_key, None)
            if _coco_output is None and _coco_input is not None:
                if self.verbose:
                    print(f'Read {_key} kwcoco dataset')
                # Use the demo coerce function to read the kwcoco file because
                # it allows for special demo inputs useful in doctests.
                _coco_output = kwcoco_dataloader.coerce_kwcoco(_coco_input, sqlview=sqlview)
                self.coco_datasets[_key] = _coco_output
            return _coco_output

        if stage in {'fit', 'train'} or stage is None:
            train_coco_dset = _read_kwcoco_split('train')
            self.coco_datasets['train'] = train_coco_dset

            # HACK: load the validation kwcoco before we do any further
            # processing.
            _read_kwcoco_split('vali')

            if self.verbose:
                print('Build train kwcoco dataset')
                print('self.exclude_sensors', self.exclude_sensors)
            # coco_train_sampler = ndsampler.CocoSampler(train_coco_dset)
            coco_train_sampler = train_coco_dset
            train_dataset = KWCocoVideoDataset(
                coco_train_sampler, mode='fit', **self.train_dataset_config,
            )

            self.classes = train_dataset.classes
            self.predictable_classes = train_dataset.predictable_classes
            self.torch_datasets['train'] = train_dataset

            if self.input_sensorchan is None:
                self.input_sensorchan = train_dataset.input_sensorchan

            stats_params = {
                'num': None,
                'with_intensity': False,
                'with_class': True,
                'num_workers': self.num_workers,
                'batch_size': self.batch_size,
            }
            if isinstance(self.prenormalize_inputs, list):
                # The user specified normalization info
                # (placeholder: nothing is done with it here yet)
                ...

            normalize_inputs = self.normalize_inputs
            if normalize_inputs:
                if isinstance(normalize_inputs, str):
                    if normalize_inputs == 'transfer':
                        # THIS MEANS WE EXPECT THAT WE CAN TRANSFER FROM AN
                        # EXISTING MODEL. THE FIT METHOD MUST HANDLE THIS
                        stats_params = None
                    else:
                        raise NotImplementedError(
                            'TODO: handle special normalization keys, '
                            'e.g. imagenet')
                else:
                    # bool is a subclass of int, so it must be excluded
                    # explicitly. Otherwise normalize_inputs=True would be
                    # passed on as num=True rather than num=None (use all).
                    is_sample_count = (
                        isinstance(normalize_inputs, int) and
                        not isinstance(normalize_inputs, bool))
                    if is_sample_count:
                        stats_params['num'] = normalize_inputs
                    else:
                        stats_params['num'] = None
            else:
                stats_params['with_intensity'] = False

            # Hack for now:
            # TODO: Note: also need for class weights
            if stats_params is not None:
                print(f'stats_params={stats_params}')
                self.dataset_stats = train_dataset.cached_dataset_stats(**stats_params)

            if self.vali_kwcoco is not None:
                vali_coco_dset = _read_kwcoco_split('vali')
                if self.verbose:
                    print('Build validation kwcoco dataset')
                # vali_coco_sampler = ndsampler.CocoSampler(vali_coco_dset)
                vali_coco_sampler = vali_coco_dset
                vali_dataset = KWCocoVideoDataset(
                    vali_coco_sampler, mode='vali', **self.vali_dataset_config)
                self.torch_datasets['vali'] = vali_dataset
                ub.inject_method(self, lambda self: self._make_dataloader('vali', shuffle=False, pin_memory=self.config['pin_memory']), 'val_dataloader')

        if stage == 'test' or stage is None:
            test_coco_dset = _read_kwcoco_split('test')
            if self.verbose:
                print('Build test kwcoco dataset')
            # test_coco_sampler = ndsampler.CocoSampler(test_coco_dset)
            test_coco_sampler = test_coco_dset
            self.coco_datasets['test'] = test_coco_dset
            self.torch_datasets['test'] = KWCocoVideoDataset(
                test_coco_sampler, mode='test', **self.test_dataset_config,
            )
            ub.inject_method(self, lambda self: self._make_dataloader('test', shuffle=False, pin_memory=self.config['pin_memory']), 'test_dataloader')

        print('self.torch_datasets = {}'.format(ub.urepr(self.torch_datasets, nl=1)))
        self._notify_about_tasks(self.requested_tasks)
        self.did_setup = True

    # Can we use these instead of inject method?
    # def train_dataloader(self):
    #     return self._make_dataloader('train', shuffle=True)

    # def val_dataloader(self):
    #     return self._make_dataloader('vali', shuffle=True)

    # def test_dataloader(self):
    #     return self._make_dataloader('test', shuffle=True)

    @property
    def train_dataset(self):
        """The train torch dataset, or None if it has not been built."""
        return self.torch_datasets.get('train', None)

    @property
    def test_dataset(self):
        """The test torch dataset, or None if it has not been built."""
        return self.torch_datasets.get('test', None)

    @property
    def vali_dataset(self):
        """The validation torch dataset, or None if it has not been built."""
        return self.torch_datasets.get('vali', None)

    def _make_dataloader(self, stage, shuffle=False, pin_memory=True):
        """
        Build a dataloader for a stage. If the stage doesn't exist, returns
        None.

        Args:
            stage (str): key into ``self.torch_datasets`` ("train", "vali",
                or "test").
            shuffle (bool): forwarded to ``dataset.make_loader``.
            pin_memory (bool): forwarded to ``dataset.make_loader``.

        Returns:
            torch.utils.data.DataLoader | None
        """
        dataset = self.torch_datasets.get(stage, None)
        if dataset is None:
            return None
        loader = dataset.make_loader(
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=pin_memory,
        )
        return loader

    def _notify_about_tasks(self, requested_tasks=None, model=None, predictable_classes=None):
        """
        Hacky method. Given the multimodal model, tell all the datasets which
        tasks they will need to generate data for. (This helps make the
        visualizations cleaner).

        Args:
            requested_tasks (Dict[str, bool] | None): mapping from task name
                to whether it is requested. Must be None when ``model`` is
                given.
            model (object | None): if given, tasks are derived from its
                ``global_head_weights`` (weight > 0 means requested) or, in
                the experimental style, from the keys of ``model.heads``.
            predictable_classes (object | None): forwarded to each dataset.
                In the ``heads`` style it is replaced by
                ``model.predictable_classes`` when the model has one.
        """
        if model is not None:
            assert requested_tasks is None
            if hasattr(model, 'global_head_weights'):
                requested_tasks = {k: w > 0 for k, w in model.global_head_weights.items()}
            elif hasattr(model, 'heads'):
                # Experimental new style of notification
                requested_tasks = {}
                requested_tasks['saliency'] = False
                requested_tasks['class'] = False
                requested_tasks['change'] = False
                requested_tasks['boxes'] = False
                requested_tasks['nonlocal_class'] = False
                requested_tasks.update({k: True for k in model.heads.keys()})
                # TODO: handle per-head predictable classes.
                if hasattr(model, 'predictable_classes'):
                    predictable_classes = model.predictable_classes
            else:
                import warnings
                warnings.warn(ub.paragraph(
                    f'''
                    Model {model.__class__} does not have the structure needed
                    to notify the dataset about tasks. A better design to make
                    specifying tasks easier is needed without relying on the
                    ``global_head_weights``.
                    '''))
        print(f'datamodule notified: requested_tasks={requested_tasks} predictable_classes={predictable_classes}')
        if requested_tasks is not None:
            self.requested_tasks = requested_tasks
            for dataset in self.torch_datasets.values():
                dataset._notify_about_tasks(requested_tasks, predictable_classes=predictable_classes)

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """
        Previously the arguments were in multiple places including here. This
        has been updated to use the :class:`KWCocoVideoDataModuleConfig` as the
        single point where arguments are defined. The functionality of this
        method is roughly the same as it used to be given that scriptconfig
        objects can be transformed into argparse objects.

        Args:
            parent_parser (argparse.ArgumentParser): parser that receives a
                "kwcoco_datamodule" argument group.

        Returns:
            argparse.ArgumentParser: the same ``parent_parser``.

        CommandLine:
            xdoctest -m kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_datamodule KWCocoVideoDataModule.add_argparse_args

        Example:
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
            >>> cls = KWCocoVideoDataModule
            >>> # TODO: make use of kwcoco_dataloader.utils.lightning_ext import argparse_ext
            >>> import argparse
            >>> parent_parser = argparse.ArgumentParser()
            >>> cls.add_argparse_args(parent_parser)
            >>> parent_parser.print_help()
            >>> args, _ = parent_parser.parse_known_args(['--use_grid_positives=True'])
            >>> assert args.use_grid_positives
            >>> args, _ = parent_parser.parse_known_args(['--use_grid_positives=False'])
            >>> assert not args.use_grid_positives
            >>> args, _ = parent_parser.parse_known_args(['--exclude_sensors=l8,f3'])
            >>> assert args.exclude_sensors == 'l8,f3'
            >>> args, _ = parent_parser.parse_known_args(['--exclude_sensors=l8'])
            >>> assert args.exclude_sensors == 'l8'
        """
        # from functools import partial
        parser = parent_parser.add_argument_group('kwcoco_datamodule')
        config = KWCocoVideoDataModuleConfig()
        config.argparse(parser)
        return parent_parser

    @classmethod
    def compatible(cls, cfgdict):
        """
        Given keyword arguments, find the subset that is compatible with this
        constructor. This is somewhat hacked because of usage of scriptconfig,
        but could be made nicer by future updates.

        Args:
            cfgdict (Dict[str, Any]): candidate keyword arguments.

        Returns:
            Dict[str, Any]: the items of ``cfgdict`` whose keys are explicit
            constructor parameters or KWCocoVideoDataModuleConfig options.
        """
        # init_kwargs = ub.compatible(config, cls.__init__)
        import inspect
        nameable_kinds = {inspect.Parameter.POSITIONAL_OR_KEYWORD,
                          inspect.Parameter.KEYWORD_ONLY}
        cls_sig = inspect.signature(cls)
        explicit_argnames = [
            argname for argname, argtype in cls_sig.parameters.items()
            if argtype.kind in nameable_kinds
        ]
        valid_argnames = explicit_argnames + list(KWCocoVideoDataModuleConfig.__default__.keys())
        datamodule_vars = ub.dict_isect(cfgdict, valid_argnames)
        return datamodule_vars

    def draw_batch(self, batch, stage='train', outputs=None, max_items=2,
                   overlay_on_image=False, classes=None, **kwargs):
        r"""
        Visualize a batch produced by a KWCocoVideoDataset.

        Args:
            batch (Dict[str, List[Tensor]]): dictionary of uncollated lists of Dataset Items
                change: [ [T-1, H, W] \in [0, 1] \forall examples ]
                saliency: [ [T, H, W, 2] \in [0, 1] \forall examples ]
                class: [ [T, H, W, 10] \in [0, 1] \forall examples ]

            outputs (Dict[str, Tensor]):
                maybe-collated list of network outputs?

            max_items (int):
                Maximum number of items within this batch to draw in a single
                figure. Defaults to 2.

            overlay_on_image (bool):
                if True overlay annotations on image data for a more compact
                view. if False separate annotations / images for a less
                cluttered view.

        Returns:
            ndarray: the rendered canvas. If no item could be drawn (all
            items were None or ``max_items`` <= 0) a small blank placeholder
            is used in place of the item grid.

        CommandLine:
            xdoctest -m kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_datamodule KWCocoVideoDataModule.draw_batch

        Example:
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
            >>> from kwcoco_dataloader.tasks.fusion import datamodules
            >>> self = datamodules.KWCocoVideoDataModule(
            >>>     train_dataset='special:vidshapes8-multispectral', channels='auto', num_workers=0)
            >>> self.setup('fit')
            >>> loader = self.train_dataloader()
            >>> batch = next(iter(loader))
            >>> item = batch[0]
            >>> # Visualize
            >>> B = len(batch)
            >>> C, H, W = ub.peek(item['frames'][0]['modes'].values()).shape
            >>> T = len(item['frames'])
            >>> import torch
            >>> outputs = {'change_probs': [torch.rand(T - 1, H, W) for _ in range(B)]}
            >>> outputs.update({'class_probs': [torch.rand(T, H, W, 10) for _ in range(B)]})
            >>> stage = 'train'
            >>> canvas = self.draw_batch(batch, stage=stage, outputs=outputs)
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(canvas)
            >>> kwplot.show_if_requested()

        Example:
            >>> # xdoctest: +REQUIRES(--slow)
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
            >>> from kwcoco_dataloader.tasks.fusion import datamodules
            >>> import kwcoco_dataloader
            >>> train_dataset = kwcoco_dataloader.demo.demo_kwcoco_multisensor()
            >>> self = datamodules.KWCocoVideoDataModule(
            >>>     train_dataset=train_dataset, chip_size=256, time_steps=5, num_workers=0, batch_size=3)
            >>> self.setup('fit')
            >>> loader = self.train_dataloader()
            >>> batch_iter = iter(loader)
            >>> batch = next(batch_iter)
            >>> batch[1] = None  # simulate a dropped batch item
            >>> batch[0] = None  # simulate a dropped batch item
            >>> #item = batch[0]
            >>> # Visualize
            >>> B = len(batch)
            >>> outputs = {'change_probs': [], 'class_probs': [], 'saliency_probs': []}
            >>> # Add dummy outputs
            >>> import torch
            >>> for item in batch:
            >>>     if item is None:
            >>>         [v.append([None]) for v in outputs.values()]
            >>>     else:
            >>>         [v.append([]) for v in outputs.values()]
            >>>         for frame_idx, frame in enumerate(item['frames']):
            >>>             H, W = frame['class_idxs'].shape
            >>>             if frame_idx > 0:
            >>>                 outputs['change_probs'][-1].append(torch.rand(H, W))
            >>>             outputs['class_probs'][-1].append(torch.rand(H, W, 10))
            >>>             outputs['saliency_probs'][-1].append(torch.rand(H, W, 2))
            >>> from kwcoco_dataloader.utils import util_nesting
            >>> print(ub.urepr(util_nesting.shape_summary(outputs), nl=1, sort=0))
            >>> stage = 'train'
            >>> canvas = self.draw_batch(batch, stage=stage, outputs=outputs, max_items=4)
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(canvas)
            >>> kwplot.show_if_requested()

        Example:
            >>> from kwcoco_dataloader.tasks.fusion.datamodules.kwcoco_datamodule import *  # NOQA
            >>> from kwcoco_dataloader.tasks.fusion import datamodules
            >>> self = datamodules.KWCocoVideoDataModule(
            >>>     batch_size = 3,
            >>>     train_dataset='special:vidshapes8-multispectral', channels='auto', num_workers=0)
            >>> self.setup('fit')
            >>> loader = self.train_dataloader()
            >>> batch = next(iter(loader))
            >>> batch[1] = None
            >>> item = batch[0]
            >>> # Visualize
            >>> B = len(batch)
            >>> C, H, W = ub.peek(item['frames'][0]['modes'].values()).shape
            >>> T = len(item['frames'])
            >>> import torch
            >>> outputs = {'change_probs': [torch.rand(T - 1, H, W) for _ in range(B)]}
            >>> outputs.update({'class_probs': [torch.rand(T, H, W, 10) for _ in range(B)]})
            >>> stage = 'train'
            >>> canvas = self.draw_batch(batch, stage=stage, outputs=outputs)
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(canvas)
            >>> kwplot.show_if_requested()
        """
        import numpy as np

        dataset = self.torch_datasets[stage]
        # Get the raw dataset class
        while hasattr(dataset, 'dataset'):
            dataset = dataset.dataset

        # assume collation is disabled
        batch_items = batch

        if dataset.requested_tasks['nonlocal_class']:
            # FIXME: HACK JUST TO MAKE CIFAR VISUALIZE,
            # TODO: NEED TO BETTER CHANGE THE TYPE OF THE OUTPUTS AT THE
            # LIGHTNING LEVEL
            from kwcoco_dataloader.tasks.fusion.datamodules import network_io
            outputs = network_io.CollatedNetworkOutputs(outputs)

        KNOWN_HEADS = ['change_probs', 'class_probs', 'saliency_probs', 'nonlocal_class_probs', 'box']

        canvas_list = []
        for item_idx, item in zip(range(max_items), batch_items):
            # HACK: I'm not sure how general accepting outputs is
            # TODO: more generic handling of outputs.
            # Should be able to accept
            # - [ ] binary probability of change
            # - [ ] fine-grained probability of change
            # - [ ] per-frame semantic segmentation
            # - [ ] detections with box results!
            if item is None:
                # Dropped batch items are skipped (their slot in max_items is
                # still consumed).
                continue
            if outputs is not None:
                # Extract outputs only for this specific batch item.
                if hasattr(outputs, 'decollate'):
                    # Experimental new logic
                    # decollated_outputs = outputs.decollate()
                    # item_output = decollated_outputs['item_probs']
                    item_output = {}
                    # hack, because decollate is not playing nice
                    _nonlocal_probs = outputs['probs']['nonlocal_class_probs'][item_idx]
                    if _nonlocal_probs.ndim == 1:
                        # Add in a fake time dimension
                        _nonlocal_probs = [_nonlocal_probs]
                    item_output['nonlocal_class_probs'] = _nonlocal_probs
                else:
                    # Original logic for multimodal transformer
                    item_output = ub.AutoDict()
                    for head_key in KNOWN_HEADS:
                        if head_key in outputs:
                            item_output[head_key] = []
                            head_outputs = outputs[head_key]
                            head_item_output = head_outputs[item_idx]
                            if head_item_output is not None:
                                if head_key == 'box':
                                    # Handle box head separately.
                                    # TODO: Should the network handle this conversion?
                                    box_ltrb = head_item_output['box_ltrb'].data.cpu().float().numpy()
                                    box_probs = head_item_output['box_probs'].data.cpu().float().numpy()
                                    for frame_box_ltrb, frame_box_probs in zip(box_ltrb, box_probs):
                                        item_output[head_key].append({
                                            'box_ltrb': frame_box_ltrb,
                                            'box_probs': frame_box_probs
                                        })
                                else:
                                    # Handle original heatmap case
                                    for frame_out in head_item_output:
                                        item_output[head_key].append(frame_out.data.cpu().float().numpy())
                            else:
                                item_output[head_key].append(None)
            else:
                item_output = {}

            part = dataset.draw_item(
                item, item_output=item_output,
                overlay_on_image=overlay_on_image, classes=classes, **kwargs)
            canvas_list.append(part)

        if canvas_list:
            # Choose a sensible chunksize for the grid based on the input image
            # aspect ratios
            # TODO: could add this as a grid heuristic.
            num_images = len(canvas_list)
            hs = np.array([c.shape[0] for c in canvas_list])
            ws = np.array([c.shape[1] for c in canvas_list])
            h_majorness = hs > (ws * 1.2)
            w_majorness = ws > (hs * 1.2)
            if h_majorness.sum() >= w_majorness.sum():
                majors, minors = hs, ws
                stack_axis = 0
            else:
                majors, minors = ws, hs
                stack_axis = 1
            majors_per_minor = (majors / minors).mean()
            # Not sure if this is quite right
            chunksize = int(np.ceil(np.sqrt(majors_per_minor * num_images)))
            # Scratch notes for deriving a better layout:
            #     import sympy as sym
            #     majors_per_minor, num_imgs = sym.symbols('majors_per_minor, num_imgs')
            #     real_grid_major, real_grid_minor = sym.symbols('real_grid_w, real_grid_h')
            #     ideal_grid_dim = sym.symbols('ideal_grid_dim')
            #     sym.sqrt(num_imgs)
            #     vars = (majors_per_minor, num_imgs, real_grid_major, real_grid_minor, ideal_grid_dim)
            #     # TODO: get the system that solves for the number of images we
            #     # stack across the minor dimension such that we roughly get a
            #     # square image in the end.
            #     equations = [
            #         sym.Eq(ideal_grid_dim * ideal_grid_dim, num_imgs * majors_per_minor),
            #         sym.Eq(majors_per_minor * ideal_grid_dim, real_grid_minor),
            #         sym.Eq(ideal_grid_dim, real_grid_major),
            #     ]
            #     print('equations = {}'.format(ub.urepr(equations, nl=1)))
            #     from sympy import solve
            #     solutions = solve(equations, *vars, dict=True)
            #     print('solutions = {}'.format(ub.urepr(solutions, nl=2)))
            #     solutions = solve(equations, real_grid_major, dict=True)
            #     print('solutions = {}'.format(ub.urepr(solutions, nl=2)))
            #     solutions = solve(equations, real_grid_minor, dict=True)
            #     print('solutions = {}'.format(ub.urepr(solutions, nl=2)))
            canvas = kwimage.stack_images_grid(
                canvas_list, chunksize=chunksize, axis=stack_axis, overlap=-12, bg_value=[64, 60, 60])
        else:
            # Nothing was drawable (every item was None, or max_items <= 0).
            # The aspect-ratio heuristic above would take the mean of an empty
            # array and fail on int(nan), so use a small blank placeholder in
            # the background color instead.
            canvas = np.full((32, 32, 3), (64, 60, 60), dtype=np.uint8)

        with_legend = self.requested_tasks is None or self.requested_tasks.get('class', True)
        if with_legend:
            if classes is None:
                classes = dataset.classes
            heuristics.category_tree_ensure_color(classes)
            label_to_color = {
                node: data['color']
                for node, data in classes.graph.nodes.items()}
            label_to_color = ub.sorted_keys(label_to_color)
            legend_img = _memo_legend(label_to_color)
            canvas = kwimage.stack_images([canvas, legend_img], axis=1)

        return canvas
def _tmp(train_dataset):
    """
    Debugging helper: render the color legend of ``train_dataset.classes``
    next to that of ``train_dataset.predictable_classes`` in a 1x2 layout.
    """
    import kwplot
    category_trees = (train_dataset.classes, train_dataset.predictable_classes)
    colormaps = [
        {node: data['color'] for node, data in tree.graph.nodes.items()}
        for tree in category_trees
    ]
    legends = [kwplot.make_legend_img(cmap) for cmap in colormaps]
    for column, legend in enumerate(legends, start=1):
        kwplot.imshow(legend, pnum=(1, 2, column))