geowatch.cli.torch_model_stats module

Print stats about a torch model

Example Usage:

DVC_DPATH=$(geowatch_dvc)
PACKAGE_FPATH=$DVC_DPATH/models/fusion/SC-20201117/BOTH_smt_it_stm_p8_L1_DIL_v55/BOTH_smt_it_stm_p8_L1_DIL_v55_epoch=5-step=53819.pt
python -m geowatch.cli.torch_model_stats $PACKAGE_FPATH

class geowatch.cli.torch_model_stats.TorchModelStatsConfig(*args, **kwargs)

Bases: DataConfig

Print stats about a torch model.

Currently, some things are hard-coded for fusion models.

Valid options: []

Parameters:
  • *args – positional arguments for this data config

  • **kwargs – keyword arguments for this data config

default = {'hparams': <Value(True)>, 'src': <PathList(None)>, 'stem_stats': <Value(True)>}
geowatch.cli.torch_model_stats.main(cmdline=False, **kwargs)
geowatch.cli.torch_model_stats.torch_checkpoint_stats(checkpoint_fpath)

A fallback set of statistics that can be computed from a checkpoint alone.

Summarizes a PyTorch checkpoint by extracting useful metadata about the model’s state_dict and optimizer states.

Parameters:
  • checkpoint_fpath (str | PathLike) – Path to the checkpoint file.

Returns:
  Summary statistics of the checkpoint.

Return type:
  dict
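As a rough illustration of the kind of summary torch_checkpoint_stats describes, here is a minimal Python sketch that operates on a checkpoint dictionary that has already been loaded (for example with torch.load(path, map_location="cpu")). The helper name summarize_checkpoint and the fields it reports are hypothetical and are not the actual output of torch_checkpoint_stats. It assumes a PyTorch Lightning-style layout with "state_dict" and "optimizer_states" keys, and it treats any value exposing a shape attribute as a tensor.

```python
import math
from collections import Counter
from typing import Any, Dict, List, Mapping


def summarize_checkpoint(checkpoint: Mapping[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: summarize an already-loaded checkpoint dict.

    Assumes a PyTorch Lightning-style layout: weights under "state_dict"
    and optimizer state dicts under "optimizer_states". When "state_dict"
    is missing, the mapping itself is treated as the state dict, as with
    a file written by torch.save(model.state_dict(), path).
    """
    state_dict = checkpoint.get("state_dict", checkpoint)

    # Treat any value exposing a ``shape`` attribute as a tensor.
    tensors = {k: v for k, v in state_dict.items() if hasattr(v, "shape")}

    # A tensor's element count is the product of its shape (1 for 0-d tensors).
    # A module's state_dict also holds persistent buffers, so this counts
    # buffer elements (e.g. BatchNorm running stats) as well as parameters.
    total_elements = sum(math.prod(v.shape) for v in tensors.values())

    # Count tensors per top-level prefix: "backbone.conv.weight" -> "backbone".
    tensors_per_prefix = Counter(key.split(".", 1)[0] for key in tensors)

    optimizers: List[Dict[str, Any]] = []
    for opt_state in checkpoint.get("optimizer_states", []):
        # torch.optim.Optimizer.state_dict() holds "state" and "param_groups".
        param_groups = opt_state.get("param_groups", [])
        optimizers.append({
            "num_param_groups": len(param_groups),
            "learning_rates": [group.get("lr") for group in param_groups],
            "num_param_states": len(opt_state.get("state", {})),
        })

    summary: Dict[str, Any] = {
        "num_tensors": len(tensors),
        "total_elements": total_elements,
        "tensors_per_prefix": dict(tensors_per_prefix),
        "optimizers": optimizers,
    }
    # Copy simple training-progress metadata when it is present.
    for key in ("epoch", "global_step"):
        if key in checkpoint:
            summary[key] = checkpoint[key]
    return summary
```

Taking the loaded dictionary rather than a file path keeps the sketch free of a torch dependency, so any stand-in objects with a shape attribute are enough to exercise it.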

geowatch.cli.torch_model_stats.torch_model_stats(package_fpath, stem_stats=True, dvc_dpath=None)
geowatch.cli.torch_model_stats.fallback(fpath)