geowatch.tasks.fusion.methods.unet_baseline module

class geowatch.tasks.fusion.methods.unet_baseline.NanToNum(num=0.0)[source]

Bases: Module

Module that replaces NaN values in input tensors with a fixed number (0.0 by default).

forward(x)[source]
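
A minimal usage sketch; it assumes only the NaN-replacement behavior documented above (e.g. a thin wrapper around a torch.nan_to_num-style operation):

Example

>>> import torch
>>> from geowatch.tasks.fusion.methods.unet_baseline import NanToNum
>>> layer = NanToNum(num=0.0)
>>> x = torch.tensor([1.0, float('nan'), 2.0])
>>> y = layer(x)
>>> # NaN entries become the configured replacement number
>>> assert not torch.isnan(y).any()
>>> assert float(y[1]) == 0.0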
class geowatch.tasks.fusion.methods.unet_baseline.UNetBaseline(classes=10, dataset_stats=None, input_sensorchan=None, token_dim: int = 32, name: str = 'unnamed_model', class_weights: str = 'auto', saliency_weights: str = 'auto', positive_change_weight: float = 1.0, negative_change_weight: float = 1.0, global_class_weight: float = 1.0, global_change_weight: float = 1.0, global_saliency_weight: float = 1.0, change_loss: str = 'cce', class_loss: str = 'focal', saliency_loss: str = 'focal', ohem_ratio: float | None = None, focal_gamma: float | None = 2.0)[source]

Bases: LightningModule, WatchModuleMixins

Parameters:
  • name – Specify a name for the experiment. (It is unclear whether the model is the right place for this.)

  • token_width – Width of each square token.

  • token_dim – Dimensionality of each computed token.

  • spatial_scale_base – The scale assigned to each token equals scale_base / token_density, where the token density is the number of tokens along a given spatial axis.

  • temporal_scale_base – The scale assigned to each token equals scale_base / token_density, where the token density is the number of tokens along the temporal axis.

  • class_weights – Class weighting strategy.

  • saliency_weights – Saliency weighting strategy.

Example

>>> # Note: it is important that the non-kwargs are saved as hyperparams
>>> from geowatch.tasks.fusion.methods.unet_baseline import UNetBaseline
>>> model = UNetBaseline(
>>>     input_sensorchan='r|g|b',
>>> )
get_cfgstr()[source]
process_frame(frame) → Dict[str, Dict[str, Any]][source]
process_example(example)[source]
process_batch(batch)[source]
encode_frame(processed_frame)[source]
encode_example(processed_example)[source]
encode_batch(processed_batch)[source]
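
These process_*/encode_* methods operate at frame, example, and batch granularity; per the signatures above, each encode_* method takes the corresponding processed_* output. A hedged sketch of the assumed composition (the structure of the intermediate values is not documented on this page, only the signatures):

Example

>>> from geowatch.tasks import fusion
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(width=64, height=64)
>>> # process_batch prepares/tokenizes raw inputs; encode_batch embeds the result
>>> processed_batch = model.process_batch(batch)
>>> encoded_batch = model.encode_batch(processed_batch)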
forward(batch)[source]

Example

>>> from geowatch.tasks import fusion
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(width=64, height=64)
>>> outputs = model.forward(batch)
>>> for task_key, task_outputs in outputs.items():
>>>     if "probs" in task_key: continue
>>>     if task_key == "class": task_key = "class_idxs"
>>>     for task_pred, example in zip(task_outputs, batch):
>>>         for frame_idx, (frame_pred, frame) in enumerate(zip(task_pred, example["frames"])):
>>>             if (frame_idx == 0) and task_key.startswith("change"): continue
>>>             assert frame_pred.shape[1:] == frame[task_key].shape, f"{frame_pred.shape} should equal {frame[task_key].shape} for task '{task_key}'"
shared_step(batch, batch_idx=None, stage='train', with_loss=True)[source]

Example

>>> # xdoctest: +REQUIRES(env:SLOW_TESTS)
>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> # xdoctest: +REQUIRES(env:SLOW_TESTS)
>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=2, width=64, height=65, num_timesteps=3)
>>> batch += [None]
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()

Example

>>> from geowatch.tasks import fusion
>>> import torch
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes, token_dim=2,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> batch = model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=0.1)
>>> batch += model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=0.5)
>>> batch += model.demo_batch(batch_size=1, width=32, height=35, num_timesteps=3, nans=1.0)
>>> outputs = model.shared_step(batch)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer.zero_grad()
>>> loss = outputs["loss"]
>>> loss.backward()
>>> optimizer.step()
training_step(batch, batch_idx=None)[source]
validation_step(batch, batch_idx=None)[source]
test_step(batch, batch_idx=None)[source]
predict_step(batch, batch_idx=None)[source]
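
These are the standard LightningModule step hooks, so the model can be driven directly by a pytorch_lightning Trainer. A minimal training sketch, assuming the module configures its own optimizer (that wiring is not shown on this page) and using an illustrative identity collate over demo examples rather than the project's datamodules:

Example

>>> # xdoctest: +SKIP
>>> import pytorch_lightning as pl
>>> from torch.utils.data import DataLoader
>>> from geowatch.tasks import fusion
>>> channels, classes, dataset_stats = fusion.methods.UNetBaseline.demo_dataset_stats()
>>> model = fusion.methods.UNetBaseline(
>>>     classes=classes,
>>>     dataset_stats=dataset_stats,
>>>     input_sensorchan=channels,
>>> )
>>> # Reuse demo examples as a tiny in-memory dataset; a batch is a list of examples
>>> examples = model.demo_batch(batch_size=4, width=64, height=64)
>>> loader = DataLoader(examples, batch_size=2, collate_fn=lambda items: items)
>>> trainer = pl.Trainer(max_epochs=1, accelerator='cpu', logger=False,
>>>                      enable_checkpointing=False)
>>> trainer.fit(model, loader)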
forward_step(batch, batch_idx=None, stage='train', with_loss=True)

save_package(package_path, context=None, verbose=1)[source]

CommandLine

xdoctest -m geowatch.tasks.fusion.methods.unet_baseline UNetBaseline.save_package

Example

>>> # Test without datamodule
>>> import ubelt as ub
>>> import torch
>>> from os.path import join
>>> #from geowatch.tasks.fusion.methods.unet_baseline import *  # NOQA
>>> dpath = ub.Path.appdir('geowatch/tests/package').ensuredir()
>>> package_path = join(dpath, 'my_package.pt')
>>> # Use one of our fusion.architectures in a test
>>> from geowatch.tasks.fusion import methods
>>> model = self = methods.UNetBaseline(
>>>     input_sensorchan=5,
>>> )
>>> # Save the model (TODO: need to save datamodule as well)
>>> model.save_package(package_path)
>>> # Test that the package can be reloaded
>>> #recon = methods.UNetBaseline.load_package(package_path)
>>> from geowatch.tasks.fusion.utils import load_model_from_package
>>> recon = load_model_from_package(package_path)
>>> # Check consistency and data is actually different
>>> recon_state = recon.state_dict()
>>> model_state = model.state_dict()
>>> assert recon is not self
>>> assert set(recon_state) == set(model_state)
>>> for key in recon_state.keys():
>>>     v1 = model_state[key]
>>>     v2 = recon_state[key]
>>>     if not torch.allclose(v1, v2, equal_nan=True):
>>>         print('v1 = {}'.format(ub.urepr(v1, nl=1)))
>>>         print('v2 = {}'.format(ub.urepr(v2, nl=1)))
>>>         raise AssertionError(f'Difference in key={key}')
>>>     assert v1 is not v2, 'should be distinct copies'
