geowatch.tasks.fusion.methods.channelwise_transformer module¶
Our data might look like this: a sequence of frames, where each frame can contain heterogeneous data:
[
    {
        'frame_index': 0,
        'time_offset': 0,
        'sensor': 'S2',
        'modes': {
            'blue|green|red|swir1|swir2|nir': <Tensor shape=(6, 64, 64)>,
            'pan': <Tensor shape=(1, 112, 112)>,
        },
        'truth': {
            'class_idx': <Tensor shape=(5, 128, 128)>,
        }
    },
    {
        'frame_index': 1,
        'time_offset': 100,
        'sensor': 'L8',
        'modes': {
            'blue|green|red|lwir1|lwir2|nir': <Tensor shape=(6, 75, 75)>,
        },
        'truth': {
            'class_idx': <Tensor shape=(5, 128, 128)>,
        }
    },
    {
        'frame_index': 2,
        'time_offset': 120,
        'sensor': 'S2',
        'modes': {
            'blue|green|red|swir1|swir2|nir': <Tensor shape=(6, 64, 64)>,
            'pan': <Tensor shape=(1, 112, 112)>,
        },
        'truth': {
            'class_idx': <Tensor shape=(5, 128, 128)>,
        }
    },
    {
        'frame_index': 3,
        'time_offset': 130,
        'sensor': 'WV',
        'modes': {
            'blue|green|red|nir': <Tensor shape=(4, 224, 224)>,
            'pan': <Tensor shape=(1, 512, 512)>,
        },
        'truth': {
            'class_idx': <Tensor shape=(5, 128, 128)>,
        }
    },
]
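For concreteness, here is a minimal sketch (with hypothetical random values) of how one such frame dictionary could be constructed, mirroring the shapes above:

import torch

frame = {
    'frame_index': 0,
    'time_offset': 0,
    'sensor': 'S2',
    'modes': {
        # Channel-concatenated bands, each mode at its native resolution.
        'blue|green|red|swir1|swir2|nir': torch.rand(6, 64, 64),
        'pan': torch.rand(1, 112, 112),
    },
    'truth': {
        # Truth rasters live at the output resolution.
        'class_idx': torch.rand(5, 128, 128),
    },
}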
- class geowatch.tasks.fusion.methods.channelwise_transformer.MultimodalTransformerConfig(*args, **kwargs)[source]¶
Bases:
DataConfig
Arguments accepted by the MultimodalTransformer
The scriptconfig class is not used directly here as it normally would be. Instead, we use it as a convenience to minimize the lightning boilerplate needed for the __init__ and add_argparse_args methods.
Note, this does not entirely define the __init__ method, just the parameters that are exposed on the command line. An update to scriptconfig could allow that to be combined, but I'm not sure if it's a good idea. The arguments not specified here are usually ones that the dataset must provide at definition time.
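As an illustration of the pattern, here is a minimal sketch (hypothetical names, not this module's actual config) of how a scriptconfig DataConfig can declare the CLI-exposed parameters once and validate keyword arguments in an __init__:

import scriptconfig as scfg

class DemoModelConfig(scfg.DataConfig):
    # Each Value declares a default and help text in one place.
    learning_rate = scfg.Value(1e-3, help='optimizer learning rate')
    dropout = scfg.Value(0.1, help='dropout probability')

class DemoModel:
    def __init__(self, **kwargs):
        # Known keys override defaults; unknown keys raise an error.
        self.config = DemoModelConfig(**kwargs)

model = DemoModel(learning_rate=3e-4)
assert model.config['learning_rate'] == 3e-4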
- Parameters:
*args – positional arguments for this data config
**kwargs – keyword arguments for this data config
- default = {'arch_name': <Value('smt_it_joint_p8')>, 'attention_impl': <Value('exact')>, 'attention_kwargs': <Value(None)>, 'backbone_depth': <Value(None)>, 'change_head_hidden': <Value(2)>, 'change_loss': <Value('cce')>, 'class_head_hidden': <Value(2)>, 'class_loss': <Value('focal')>, 'class_weights': <Value('auto')>, 'continual_learning': <Value(False)>, 'decoder': <Value('mlp')>, 'decouple_resolution': <Value(False)>, 'dropout': <Value(0.1)>, 'focal_gamma': <Value(2.0)>, 'global_box_weight': <Value(0.0)>, 'global_change_weight': <Value(1.0)>, 'global_class_weight': <Value(1.0)>, 'global_saliency_weight': <Value(1.0)>, 'learning_rate': <Value(0.001)>, 'lr_scheduler': <Value('CosineAnnealingLR')>, 'modulate_class_weights': <Value('')>, 'multimodal_reduce': <Value('max')>, 'name': <Value('unnamed_model')>, 'negative_change_weight': <Value(1.0)>, 'ohem_ratio': <Value(None)>, 'optimizer': <Value('RAdam')>, 'perterb_scale': <Value(0.0)>, 'positional_dims': <Value(48)>, 'positive_change_weight': <Value(1.0)>, 'predictable_classes': <Value(None)>, 'rescale_nans': <Value(None)>, 'saliency_head_hidden': <Value(2)>, 'saliency_loss': <Value('focal')>, 'saliency_weights': <Value('auto')>, 'squash_modes': <Value(False)>, 'stream_channels': <Value(8)>, 'token_norm': <Value('none')>, 'tokenizer': <Value('rearrange')>, 'weight_decay': <Value(0.0)>, 'window_size': <Value(8)>}¶
- normalize()¶
- class geowatch.tasks.fusion.methods.channelwise_transformer.MultimodalTransformer(classes=10, dataset_stats=None, input_sensorchan=None, input_channels=None, **kwargs)[source]¶
Bases:
LightningModule, WatchModuleMixins
Todo
[ ] Change name MultimodalTransformer -> FusionModel
[ ] Move parent module methods -> models
CommandLine
xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.tasks.fusion import datamodules
>>> print('(STEP 0): SETUP THE DATA MODULE')
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes-geowatch', num_workers=4, channels='auto')
>>> datamodule.setup('fit')
>>> dataset = datamodule.torch_datasets['train']
>>> print('(STEP 1): ESTIMATE DATASET STATS')
>>> dataset_stats = dataset.cached_dataset_stats(num=3)
>>> print('dataset_stats = {}'.format(ub.urepr(dataset_stats, nl=3)))
>>> loader = datamodule.train_dataloader()
>>> print('(STEP 2): SAMPLE BATCH')
>>> batch = next(iter(loader))
>>> for item_idx, item in enumerate(batch):
>>>     print(f'item_idx={item_idx}')
>>>     item_summary = dataset.summarize_item(item)
>>>     print('item_summary = {}'.format(ub.urepr(item_summary, nl=2)))
>>> print('(STEP 3): THE REST OF THE TEST')
>>> #self = MultimodalTransformer(arch_name='smt_it_joint_p8')
>>> self = MultimodalTransformer(arch_name='smt_it_joint_p2',
>>>     dataset_stats=dataset_stats,
>>>     classes=datamodule.predictable_classes,
>>>     decoder='segmenter',
>>>     change_loss='dicefocal',
>>>     #attention_impl='performer'
>>>     attention_impl='exact'
>>> )
>>> device = torch.device('cpu')
>>> self = self.to(device)
>>> # Run forward pass
>>> from geowatch.utils import util_netharn
>>> num_params = util_netharn.number_of_parameters(self)
>>> print('num_params = {!r}'.format(num_params))
>>> output = self.forward_step(batch, with_loss=True)
>>> import torch.profiler
>>> from torch.profiler import profile, ProfilerActivity
>>> with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
>>>     with torch.profiler.record_function("model_inference"):
>>>         output = self.forward_step(batch, with_loss=True)
>>> print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
Example
>>> # Note: it is important that the non-kwargs are saved as hyperparams
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import MultimodalTransformer
>>> self = model = MultimodalTransformer(arch_name="smt_it_joint_p2", input_sensorchan='r|g|b')
>>> assert "classes" in model.hparams
>>> assert "dataset_stats" in model.hparams
>>> assert "input_sensorchan" in model.hparams
>>> assert "tokenizer" in model.hparams
- classmethod add_argparse_args(parent_parser)[source]¶
Only required for backwards compatibility until lightning CLI is the primary entry point.
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.utils.configargparse_ext import ArgumentParser
>>> cls = MultimodalTransformer
>>> parent_parser = ArgumentParser(formatter_class='defaults')
>>> cls.add_argparse_args(parent_parser)
>>> parent_parser.print_help()
>>> parent_parser.parse_known_args()
print(scfg.Config.port_argparse(parent_parser, style='dataconf'))
- classmethod compatible(cfgdict)[source]¶
Given keyword arguments, find the subset that is compatible with this constructor. This is somewhat hacky because of the usage of scriptconfig, but could be made nicer by future updates.
# init_kwargs = ub.compatible(config, cls.__init__)
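The underlying idea, as a standalone sketch (a hypothetical helper, not the actual implementation), is to filter a configuration dict down to the keys the constructor accepts:

import inspect

def compatible_subset(cfgdict, func):
    # Keep only the keys that appear as named parameters of func.
    params = inspect.signature(func).parameters
    return {k: v for k, v in cfgdict.items() if k in params}

# e.g. init_kwargs = compatible_subset(config, SomeModel.__init__)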
- configure_optimizers()[source]¶
Todo
[ ] Enable use of other optimization algorithms on the CLI
[ ] Enable use of other scheduler algorithms on the CLI
Note
This is not called when using LightningCLI; the LightningCLI overrides it.
References
https://pytorch-optimizer.readthedocs.io/en/latest/index.html
https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # noqa
>>> from geowatch.utils.lightning_ext.monkeypatches import disable_lightning_hardware_warnings
>>> disable_lightning_hardware_warnings()
>>> self = MultimodalTransformer(arch_name="smt_it_joint_p2", input_sensorchan='r|g|b')
>>> max_epochs = 80
>>> self.trainer = pl.Trainer(max_epochs=max_epochs)
>>> [opt], [sched] = self.configure_optimizers()
>>> rows = []
>>> # Inspect what the LR curve will look like
>>> for _ in range(max_epochs):
...     sched.last_epoch += 1
...     lr = sched.get_last_lr()[0]
...     rows.append({'lr': lr, 'last_epoch': sched.last_epoch})
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> import pandas as pd
>>> data = pd.DataFrame(rows)
>>> sns = kwplot.autosns()
>>> sns.lineplot(data=data, y='lr', x='last_epoch')
Example
>>> # Verify lr and decay are set correctly
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> my_lr = 2.3e-5
>>> my_decay = 2.3e-5
>>> kw = dict(arch_name="smt_it_joint_p2", input_sensorchan='r|g|b', learning_rate=my_lr, weight_decay=my_decay)
>>> self = MultimodalTransformer(**kw)
>>> [opt], [sched] = self.configure_optimizers()
>>> assert opt.param_groups[0]['lr'] == my_lr
>>> assert opt.param_groups[0]['weight_decay'] == my_decay
>>> #
>>> self = MultimodalTransformer(**kw, optimizer='sgd')
>>> [opt], [sched] = self.configure_optimizers()
>>> assert opt.param_groups[0]['lr'] == my_lr
>>> assert opt.param_groups[0]['weight_decay'] == my_decay
>>> #
>>> self = MultimodalTransformer(**kw, optimizer='AdamW')
>>> [opt], [sched] = self.configure_optimizers()
>>> assert opt.param_groups[0]['lr'] == my_lr
>>> assert opt.param_groups[0]['weight_decay'] == my_decay
>>> #
>>> # self = MultimodalTransformer(**kw, optimizer='MADGRAD')
>>> # [opt], [sched] = self.configure_optimizers()
>>> # assert opt.param_groups[0]['lr'] == my_lr
>>> # assert opt.param_groups[0]['weight_decay'] == my_decay
- overfit(batch)[source]¶
Overfit script and demo
CommandLine
python -m xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.overfit --overfit-demo
Example
>>> # xdoctest: +REQUIRES(--overfit-demo)
>>> # ============
>>> # DEMO OVERFIT:
>>> # ============
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> from geowatch.utils.util_data import find_dvc_dpath
>>> import geowatch
>>> import kwcoco
>>> from os.path import join
>>> import os
>>> if 1:
>>>     print('''
...     # Generate toy datasets
...     DATA_DPATH=$HOME/data/work/toy_change
...     TRAIN_FPATH=$DATA_DPATH/vidshapes_msi_train/data.kwcoco.json
...     mkdir -p "$DATA_DPATH"
...     kwcoco toydata --key=vidshapes-videos8-frames5-randgsize-speed0.2-msi-multisensor --bundle_dpath "$DATA_DPATH/vidshapes_msi_train" --verbose=5
...     ''')
>>>     coco_fpath = ub.expandpath('$HOME/data/work/toy_change/vidshapes_msi_train/data.kwcoco.json')
>>>     coco_fpath = 'vidshapes-videos8-frames5-randgsize-speed0.2-msi-multisensor'
>>>     coco_dset = kwcoco.CocoDataset.coerce(coco_fpath)
>>>     channels = "B11,r|g|b,B1|B8|B11"
>>> if 0:
>>>     dvc_dpath = geowatch.find_dvc_dpath(tags='phase2_data', hardware='auto')
>>>     coco_dset = (dvc_dpath / 'Drop4-BAS') / 'data_vali.kwcoco.json'
>>>     channels = 'swir16|swir22|blue|green|red|nir'
>>>     coco_dset = (dvc_dpath / 'Drop4-BAS') / 'combo_vali_I2.kwcoco.json'
>>>     channels = 'blue|green|red|nir,invariants.0:17'
>>> if 0:
>>>     coco_dset = geowatch.demo.demo_kwcoco_multisensor(max_speed=0.5)
>>>     # coco_dset = 'special:vidshapes8-frames9-speed0.5-multispectral'
>>>     #channels='B1|B11|B8|r|g|b|gauss'
>>>     channels = 'X.2|Y:2:6,B1|B8|B8a|B10|B11,r|g|b,disparity|gauss,flowx|flowy|distri'
>>> coco_dset = kwcoco.CocoDataset.coerce(coco_dset)
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset=coco_dset,
>>>     chip_size=128, batch_size=1, time_steps=5,
>>>     channels=channels,
>>>     normalize_peritem='blue|green|red|nir',
>>>     normalize_inputs=32, neg_to_pos_ratio=0,
>>>     num_workers='avail/2',
>>>     mask_low_quality=True,
>>>     observable_threshold=0.6,
>>>     use_grid_positives=False, use_centered_positives=True,
>>> )
>>> datamodule.setup('fit')
>>> dataset = torch_dset = datamodule.torch_datasets['train']
>>> torch_dset.disable_augmenter = True
>>> dataset_stats = datamodule.dataset_stats
>>> input_sensorchan = datamodule.input_sensorchan
>>> classes = datamodule.classes
>>> print('dataset_stats = {}'.format(ub.urepr(dataset_stats, nl=3)))
>>> print('input_sensorchan = {}'.format(input_sensorchan))
>>> print('classes = {}'.format(classes))
>>> # Choose subclass to test this with (does not cover all cases)
>>> self = methods.MultimodalTransformer(
>>>     # ===========
>>>     # Backbone
>>>     #arch_name='smt_it_joint_p2',
>>>     arch_name='smt_it_stm_p8',
>>>     stream_channels=16,
>>>     #arch_name='deit',
>>>     optimizer='AdamW',
>>>     learning_rate=1e-5,
>>>     weight_decay=1e-3,
>>>     #attention_impl='performer',
>>>     attention_impl='exact',
>>>     #decoder='segmenter',
>>>     #saliency_head_hidden=4,
>>>     decoder='mlp',
>>>     change_loss='dicefocal',
>>>     #class_loss='cce',
>>>     class_loss='dicefocal',
>>>     #saliency_loss='dicefocal',
>>>     saliency_loss='focal',
>>>     # ===========
>>>     # Change Loss
>>>     global_change_weight=1e-5,
>>>     positive_change_weight=1.0,
>>>     negative_change_weight=0.5,
>>>     # ===========
>>>     # Class Loss
>>>     global_class_weight=1e-5,
>>>     class_weights='auto',
>>>     # ===========
>>>     # Saliency Loss
>>>     global_saliency_weight=1.00,
>>>     # ===========
>>>     # Domain Metadata (Look Ma, not hard coded!)
>>>     dataset_stats=dataset_stats,
>>>     classes=classes,
>>>     input_sensorchan=input_sensorchan,
>>>     #tokenizer='dwcnn',
>>>     tokenizer='linconv',
>>>     multimodal_reduce='learned_linear',
>>>     #tokenizer='rearrange',
>>>     # normalize_perframe=True,
>>>     window_size=8,
>>> )
>>> self.datamodule = datamodule
>>> datamodule._notify_about_tasks(model=self)
>>> # Run one visualization
>>> loader = datamodule.train_dataloader()
>>> # Load one batch and show it before we do anything
>>> batch = next(iter(loader))
>>> print(ub.urepr(dataset.summarize_item(batch[0]), nl=3))
>>> import kwplot
>>> plt = kwplot.autoplt(force='Qt5Agg')
>>> plt.ion()
>>> canvas = datamodule.draw_batch(batch, max_channels=5, overlay_on_image=0)
>>> kwplot.imshow(canvas, fnum=1)
>>> # Run overfit
>>> device = 0
>>> self.overfit(batch)
- forward_step(batch, with_loss=False, stage='unspecified')[source]¶
Generic forward step used for test / train / validation
- Returns:
a dictionary with keys for various predictions and losses
- Return type:
Dict
CommandLine
xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.forward_step
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> import geowatch
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes-geowatch',
>>>     num_workers=0, chip_size=96, time_steps=4,
>>>     normalize_inputs=8, neg_to_pos_ratio=0, batch_size=5,
>>>     channels='auto',
>>> )
>>> datamodule.setup('fit')
>>> train_dset = datamodule.torch_datasets['train']
>>> loader = datamodule.train_dataloader()
>>> batch = next(iter(loader))
>>> # Test with "failed samples"
>>> batch[0] = None
>>> batch[2] = None
>>> batch[3] = None
>>> batch[4] = None
>>> if 1:
>>>     from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>>     print(_debug_inbatch_shapes(batch))
>>> # Choose subclass to test this with (does not cover all cases)
>>> self = model = methods.MultimodalTransformer(
>>>     arch_name='smt_it_joint_p8', tokenizer='rearrange',
>>>     decoder='segmenter',
>>>     dataset_stats=datamodule.dataset_stats, global_saliency_weight=1.0, global_change_weight=1.0, global_class_weight=1.0,
>>>     classes=datamodule.predictable_classes, input_sensorchan=datamodule.input_sensorchan)
>>> with_loss = True
>>> outputs = self.forward_step(batch, with_loss=with_loss)
>>> canvas = datamodule.draw_batch(batch, outputs=outputs, max_items=3, overlay_on_image=False)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwplot
>>> kwplot.autompl()
>>> kwplot.imshow(canvas)
>>> kwplot.show_if_requested()
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='segmenter', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels)
>>> batch = self.demo_batch()
>>> outputs = self.forward_step(batch, with_loss=True)
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(batch))
>>> print(_debug_inbatch_shapes(outputs))
Example
>>> # Test learned_linear multimodal reduce
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels, multimodal_reduce='learned_linear')
>>> batch = self.demo_batch()
>>> outputs = self.forward_step(batch, with_loss=True)
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(batch))
>>> print(_debug_inbatch_shapes(outputs))
>>> # outputs['loss'].backward()
- forward_item(item, with_loss=False)[source]¶
CommandLine
xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.forward_item:1
Example
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='segmenter', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels)
>>> item = self.demo_batch(width=64, height=65)[0]
>>> outputs = self.forward_item(item, with_loss=True)
>>> print('item')
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(item))
>>> print('outputs')
>>> print(_debug_inbatch_shapes(outputs))
Example
>>> # Box head
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels,
>>>     decouple_resolution=False, global_box_weight=1)
>>> batch = self.demo_batch(width=64, height=64, num_timesteps=3)
>>> item = batch[0]
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(batch))
>>> result1 = self.forward_step(batch, with_loss=True)
>>> assert len(result1['box'])
>>> assert 'box_ltrb' in result1['box'][0]
>>> assert len(result1['box'][0]['box_ltrb'].shape) == 3
>>> print(_debug_inbatch_shapes(result1))
>>> # Check we can go backward
>>> result1['loss'].backward()
Example
>>> # Decoupled resolutions
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> channels, classes, dataset_stats = MultimodalTransformer.demo_dataset_stats()
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_stm_p1', tokenizer='linconv',
>>>     decoder='mlp', classes=classes, global_saliency_weight=1,
>>>     dataset_stats=dataset_stats, input_sensorchan=channels,
>>>     decouple_resolution=True, global_box_weight=0)
>>> batch = self.demo_batch(width=(11, 21), height=(16, 64), num_timesteps=3)
>>> item = batch[0]
>>> from geowatch.utils.util_netharn import _debug_inbatch_shapes
>>> print(_debug_inbatch_shapes(batch))
>>> result1 = self.forward_step(batch, with_loss=True)
>>> print(_debug_inbatch_shapes(result1))
>>> # Check we can go backward
>>> result1['loss'].backward()
- save_package(package_path, context=None, verbose=1)[source]¶
Save model architecture and checkpoint as a torch package.
- Parameters:
package_path (str | PathLike) – where to save the package
context (Any) – custom json-serializable data to save in the header
verbose (int) – verbosity level
CommandLine
xdoctest -m geowatch.tasks.fusion.methods.channelwise_transformer MultimodalTransformer.save_package
Example
>>> # Test without datamodule
>>> import ubelt as ub
>>> from os.path import join
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> from geowatch.tasks.fusion import datamodules
>>> model = self = methods.MultimodalTransformer(
>>>     arch_name="smt_it_joint_p2", input_sensorchan=5,
>>>     change_head_hidden=0, saliency_head_hidden=0,
>>>     class_head_hidden=0)
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> recon = methods.MultimodalTransformer.load_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not model
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     assert (model_state[key] == recon_state[key]).all()
>>>     assert model_state[key] is not recon_state[key]
- forward(batch)[source]¶
Example
>>> import pytest
>>> pytest.skip('not currently used')
>>> from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
>>> from geowatch.tasks.fusion import datamodules
>>> channels = 'B1,B8|B8a,B10|B11'
>>> channels = 'B1|B8|B10|B8a|B11'
>>> datamodule = datamodules.KWCocoVideoDataModule(
>>>     train_dataset='special:vidshapes8-multispectral', num_workers=0, channels=channels)
>>> datamodule.setup('fit')
>>> train_dataset = datamodule.torch_datasets['train']
>>> dataset_stats = train_dataset.cached_dataset_stats()
>>> loader = datamodule.train_dataloader()
>>> tokenizer = 'convexpt-v1'
>>> tokenizer = 'dwcnn'
>>> batch = next(iter(loader))
>>> #self = MultimodalTransformer(arch_name='smt_it_joint_p8')
>>> self = MultimodalTransformer(
>>>     arch_name='smt_it_joint_p8',
>>>     dataset_stats=dataset_stats,
>>>     change_loss='dicefocal',
>>>     decoder='dicefocal',
>>>     attention_impl='performer',
>>>     tokenizer=tokenizer,
>>> )
>>> #images = torch.stack([ub.peek(f['modes'].values()) for f in batch[0]['frames']])[None, :]
>>> #images.shape
>>> #self.forward(images)
- geowatch.tasks.fusion.methods.channelwise_transformer.slice_to_agree(a1, a2, axes=None)[source]¶
Example
from geowatch.tasks.fusion.methods.channelwise_transformer import *  # NOQA
a1 = np.random.rand(3, 5, 7, 9, 3)
a2 = np.random.rand(3, 5, 6, 9, 3)
b1, b2 = slice_to_agree(a1, a2)
print(f'{a1.shape=} {a2.shape=}')
print(f'{b1.shape=} {b2.shape=}')

a1 = np.random.rand(3, 5, 7, 9, 1)
a2 = np.random.rand(3, 1, 6, 9, 3)
b1, b2 = slice_to_agree(a1, a2, axes=[0, 1, 2, 3])
print(f'{a1.shape=} {a2.shape=}')
print(f'{b1.shape=} {b2.shape=}')
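Conceptually, each selected axis is cropped to the smaller of the two sizes. A minimal sketch of that idea (assuming both arrays have the same number of dimensions; not necessarily the exact implementation):

import numpy as np

def slice_to_agree_sketch(a1, a2, axes=None):
    # Crop both arrays so the selected axes have matching lengths.
    axes = range(a1.ndim) if axes is None else axes
    slices = [slice(None)] * a1.ndim
    for ax in axes:
        slices[ax] = slice(0, min(a1.shape[ax], a2.shape[ax]))
    slices = tuple(slices)
    return a1[slices], a2[slices]

b1, b2 = slice_to_agree_sketch(np.zeros((3, 5)), np.zeros((4, 4)))
assert b1.shape == b2.shape == (3, 4)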
- geowatch.tasks.fusion.methods.channelwise_transformer.perterb_params(optimizer, std)[source]¶
Given an optimizer, perturb all parameters with Gaussian noise
From: [ShrinkAndPerterb].
While the presented conventional approaches do not remedy the warm-start problem, we have identified a remarkably simple trick that efficiently closes the generalization gap. At each round of training t, when new samples are appended to the training set, we propose initializing the network’s parameters by shrinking the weights found in the previous round of optimization towards zero, then adding a small amount of parameter noise.
Specifically, we initialize each learnable parameter
Math:
θ[i, t] = λ * θ[i, t - 1] + p[t]
where p[t] ∼ N(0, σ²) and 0 < λ < 1.
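As a minimal sketch (a hypothetical in-place helper; the module's perterb_params instead operates on an optimizer), the update can be applied directly to a model's parameters:

import torch

def shrink_perturb_(model, lam=0.9, std=1e-3):
    # θ[i, t] = λ * θ[i, t - 1] + p[t], with p[t] ~ N(0, std**2)
    with torch.no_grad():
        for p in model.parameters():
            p.mul_(lam).add_(torch.randn_like(p) * std)

net = torch.nn.Linear(4, 2)
shrink_perturb_(net)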
References