geowatch.tasks.fusion.utils module

geowatch.tasks.fusion.utils.millify(n)

geowatch.tasks.fusion.utils.load_model_from_package(package_path)

Loads a Kitware-flavored torch package (requires that a package_header exists).

Notes

  • I don’t like that we need to know module_name and arch_name a priori given a path to a package; I just want to be able to construct the model instance. The package header solves this.

geowatch.tasks.fusion.utils.load_model_header(package_path)

Only grabs header info from a packaged model.
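
Both loaders depend on a header stored inside the package that says which pickled resource holds the model. The following is a minimal sketch of that pattern using the public torch.package API (PackageExporter.save_text / save_pickle, PackageImporter.load_text / load_pickle). The header location (package_header/package_header.json), its module_name / arch_name keys, and the helper function names are assumptions for illustration, not geowatch's confirmed package layout.

```python
import json

import torch
from torch import package

# Assumed header location; geowatch's actual layout may differ.
HEADER_PACKAGE = 'package_header'
HEADER_RESOURCE = 'package_header.json'


def save_model_with_header(model: torch.nn.Module, package_path: str,
                           module_name: str = 'model',
                           arch_name: str = 'model.pkl') -> None:
    """Write a torch package holding a model plus a header that names it."""
    header = {'module_name': module_name, 'arch_name': arch_name}
    with package.PackageExporter(package_path) as exporter:
        # Keep torch itself outside the archive; the importer uses the
        # installed torch.
        exporter.extern('torch.**')
        exporter.save_text(HEADER_PACKAGE, HEADER_RESOURCE, json.dumps(header))
        exporter.save_pickle(module_name, arch_name, model)


def read_header(package_path: str) -> dict:
    """Read only the header text; the model pickle is not touched."""
    importer = package.PackageImporter(package_path)
    return json.loads(importer.load_text(HEADER_PACKAGE, HEADER_RESOURCE))


def load_model(package_path: str) -> torch.nn.Module:
    """Use the header to locate the model resource, then unpickle it."""
    importer = package.PackageImporter(package_path)
    header = json.loads(importer.load_text(HEADER_PACKAGE, HEADER_RESOURCE))
    return importer.load_pickle(header['module_name'], header['arch_name'])
```

In this sketch, read_header only calls load_text, so it never unpickles the model weights; load_model needs no a-priori knowledge of the resource names because it takes them from the header.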

class geowatch.tasks.fusion.utils.Lambda(lambda_)

Bases: Module

forward(x)
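
No docstring is given for Lambda. Assuming, from its name and signature, that it stores a callable and applies it in forward, a minimal equivalent looks like the sketch below; a wrapper like this lets a plain function sit inside nn.Sequential. The class name LambdaSketch is hypothetical.

```python
import torch
from torch import nn


class LambdaSketch(nn.Module):
    """Wrap an arbitrary callable so it can be used as an nn.Module.

    Assumption: this mirrors what ``Lambda(lambda_)`` appears to do; it is
    not copied from geowatch.
    """

    def __init__(self, lambda_):
        super().__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        # Apply the stored callable to the input.
        return self.lambda_(x)


# Example: global-average-pool the spatial dims inside a Sequential model.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    LambdaSketch(lambda t: t.flatten(start_dim=2).mean(dim=2)),
    nn.Linear(8, 2),
)
out = model(torch.rand(4, 3, 16, 16))
```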

class geowatch.tasks.fusion.utils.DimensionDropout(dim, n_keep)

Bases: Module

forward(x)
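
No docstring is given for DimensionDropout. One plausible reading of the signature, which is an assumption rather than documented behavior, is that during training it keeps a random subset of n_keep indices along axis dim and drops the rest. A hypothetical sketch of that behavior (DimensionDropoutSketch is not a geowatch name):

```python
import torch
from torch import nn


class DimensionDropoutSketch(nn.Module):
    """Randomly keep ``n_keep`` slices along ``dim`` while training.

    Hypothetical re-implementation for illustration; geowatch's
    ``DimensionDropout`` may behave differently.
    """

    def __init__(self, dim: int, n_keep: int):
        super().__init__()
        self.dim = dim
        self.n_keep = n_keep

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # At evaluation time, pass every slice through.
            return x
        n_total = x.shape[self.dim]
        n_keep = min(self.n_keep, n_total)
        # Pick a random subset of indices, sorted to preserve their order.
        index = torch.randperm(n_total, device=x.device)[:n_keep].sort().values
        return x.index_select(self.dim, index)


layer = DimensionDropoutSketch(dim=1, n_keep=3)
x = torch.rand(2, 5, 4)
layer.train()
y_train = layer(x)
layer.eval()
y_eval = layer(x)
```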

geowatch.tasks.fusion.utils.ordinal_position_encoding(num_items, feat_size, method='sin', device='cpu')

A positional encoding that represents the ordinal position of each item.

Parameters:
  • num_items (int) – number of items to encode, i.e. the number of positions along a spatial or temporal index

  • feat_size (int) – number of feature dimensions in the encoding generated for each item

Example

>>> # Use 5 feature dimensions to encode 3 timesteps
>>> from geowatch.tasks.fusion.utils import *  # NOQA
>>> num_timesteps = num_items = 3
>>> feat_size = 5
>>> encoding = ordinal_position_encoding(num_items, feat_size)
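
For intuition, the standard sine/cosine scheme from Transformer models maps each integer position to a vector of sinusoids at geometrically spaced frequencies. The sketch below implements that textbook formulation; whether method='sin' uses exactly these frequencies is an assumption, and sinusoid_encoding is a hypothetical name.

```python
import math

import torch


def sinusoid_encoding(num_items: int, feat_size: int) -> torch.Tensor:
    """Return a (num_items, feat_size) tensor of sine/cosine position codes."""
    position = torch.arange(num_items, dtype=torch.float32).unsqueeze(1)  # (N, 1)
    # One frequency per (sin, cos) pair; ceil handles an odd feat_size.
    n_pairs = math.ceil(feat_size / 2)
    steps = torch.arange(n_pairs, dtype=torch.float32)
    freqs = torch.exp(-math.log(10000.0) * 2 * steps / feat_size)  # 10000^(-2i/d)
    angles = position * freqs  # (N, n_pairs)
    encoding = torch.empty(num_items, 2 * n_pairs)
    encoding[:, 0::2] = torch.sin(angles)  # even columns
    encoding[:, 1::2] = torch.cos(angles)  # odd columns
    # Drop the extra column when feat_size is odd.
    return encoding[:, :feat_size]


# Same sizes as the doctest above: 3 timesteps, 5 feature dimensions.
encoding = sinusoid_encoding(num_items=3, feat_size=5)
```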

class geowatch.tasks.fusion.utils.SinePositionalEncoding(dest_dim, dim_to_encode, size=4)

Bases: Module

Parameters:
  • dest_dim (int) – feature dimension to concatenate the encoding onto

  • dim_to_encode (int) – dimension whose positions the encoding represents

  • size (int) – number of distinct encodings generated for dim_to_encode

Example

>>> import torch
>>> from geowatch.tasks.fusion.utils import *  # NOQA
>>> dest_dim = 3
>>> dim_to_encode = 2
>>> size = 8
>>> self = SinePositionalEncoding(dest_dim, dim_to_encode, size=size)
>>> x = torch.rand(3, 5, 7, 11, 13)
>>> y = self(x)
forward(x)
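
Reading the three parameters together suggests this behavior: a code with size channels is computed along axis dim_to_encode, broadcast across the other axes, and concatenated onto the input along axis dest_dim. The sketch below implements that interpretation as a plain function using sine/cosine codes; it is an assumption about the layer, and concat_sine_encoding is a hypothetical name.

```python
import math

import torch


def concat_sine_encoding(x: torch.Tensor, dest_dim: int, dim_to_encode: int,
                         size: int = 4) -> torch.Tensor:
    """Append ``size`` sinusoidal channels along ``dest_dim`` that encode the
    index along ``dim_to_encode`` (hypothetical helper)."""
    ndim = x.ndim
    dest_dim %= ndim
    dim_to_encode %= ndim
    if dest_dim == dim_to_encode:
        raise ValueError('dest_dim and dim_to_encode must differ')
    n = x.shape[dim_to_encode]
    position = torch.arange(n, dtype=x.dtype, device=x.device).unsqueeze(1)  # (n, 1)
    n_pairs = math.ceil(size / 2)
    steps = torch.arange(n_pairs, dtype=x.dtype, device=x.device)
    freqs = torch.exp(-math.log(10000.0) * 2 * steps / size)  # (n_pairs,)
    angles = position * freqs  # (n, n_pairs)
    code = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)[:, :size]
    if dim_to_encode > dest_dim:
        code = code.t()  # (size, n) so axis order matches the view below
    view_shape = [1] * ndim
    view_shape[dim_to_encode] = n
    view_shape[dest_dim] = size
    target_shape = list(x.shape)
    target_shape[dest_dim] = size
    # Broadcast the code over every other axis, then concatenate.
    code = code.reshape(view_shape).expand(target_shape)
    return torch.cat([x, code], dim=dest_dim)


# Same sizes as the doctest above.
x = torch.rand(3, 5, 7, 11, 13)
y = concat_sine_encoding(x, dest_dim=3, dim_to_encode=2, size=8)
```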

geowatch.tasks.fusion.utils.model_json(model, max_depth=inf, depth=0)

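Example

>>> import torchvision
>>> import ubelt as ub
>>> from geowatch.tasks.fusion.utils import model_json
>>> model = torchvision.models.resnet50()
>>> info = model_json(model, max_depth=1)
>>> print(ub.urepr(info, sort=0, nl=-1))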
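
The function carries no prose description. Assuming it summarizes a module tree as nested dictionaries truncated at max_depth, a recursive sketch over nn.Module.named_children might look like this. The function name module_tree and the output keys (type, num_params, children) are hypothetical and are not geowatch's format.

```python
import math

from torch import nn


def module_tree(model: nn.Module, max_depth: float = math.inf,
                depth: int = 0) -> dict:
    """Summarize ``model`` as nested dicts, stopping below ``max_depth``."""
    info = {
        'type': type(model).__name__,
        'num_params': sum(p.numel() for p in model.parameters()),
    }
    if depth < max_depth:
        # Recurse into direct submodules only while under the depth limit.
        children = {
            name: module_tree(child, max_depth, depth + 1)
            for name, child in model.named_children()
        }
        if children:
            info['children'] = children
    return info


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Sequential(nn.Linear(8, 2)))
info = module_tree(model, max_depth=1)
```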

geowatch.tasks.fusion.utils.category_tree_ensure_color(classes)

Ensures that each category in a CategoryTree has a color.

Todo

  • [ ] Add to CategoryTree

  • [ ] TODO: better function

  • [ ] Consolidate with ~/code/watch/geowatch/tasks/fusion/utils :: category_tree_ensure_color

  • [ ] Consolidate with ~/code/watch/geowatch/utils/kwcoco_extensions :: category_category_colors

  • [ ] Consolidate with ~/code/watch/geowatch/heuristics.py :: ensure_heuristic_category_tree_colors

  • [ ] Consolidate with ~/code/watch/geowatch/heuristics.py :: ensure_heuristic_coco_colors

Example

>>> import kwcoco
>>> classes = kwcoco.CategoryTree.demo()
>>> assert not any('color' in data for data in classes.graph.nodes.values())
>>> category_tree_ensure_color(classes)
>>> assert all('color' in data for data in classes.graph.nodes.values())
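
The doctest checks that every node's data dict gains a color key. Below is a dependency-free sketch of that idea. It assumes existing colors should be left alone and fills in the missing ones with evenly spaced HSV hues; this is a swapped-in strategy, and geowatch may pick colors differently. It accepts any mapping from node to attribute dict, such as classes.graph.nodes. The function name ensure_node_colors is hypothetical.

```python
import colorsys


def ensure_node_colors(node_data):
    """Give every node without a 'color' an RGB tuple with values in [0, 1].

    ``node_data`` maps node names to attribute dicts (e.g.
    ``classes.graph.nodes``); the dicts are modified in place.
    """
    missing = [data for data in node_data.values() if 'color' not in data]
    n = len(missing)
    for i, data in enumerate(missing):
        # Evenly spaced hues keep the newly assigned colors distinct.
        data['color'] = colorsys.hsv_to_rgb(i / n, 0.8, 0.9)
    return node_data


nodes = {'background': {'color': (0.0, 0.0, 0.0)}, 'cat': {}, 'dog': {}}
ensure_node_colors(nodes)
```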