"""
Print stats about a torch model

Example Usage:
    DVC_DPATH=$(geowatch_dvc)
    PACKAGE_FPATH=$DVC_DPATH/models/fusion/SC-20201117/BOTH_smt_it_stm_p8_L1_DIL_v55/BOTH_smt_it_stm_p8_L1_DIL_v55_epoch=5-step=53819.pt
    python -m geowatch.cli.torch_model_stats $PACKAGE_FPATH
"""
import scriptconfig as scfg
import ubelt as ub


class TorchModelStatsConfig(scfg.DataConfig):
    """
    Print stats about a torch model.

    Currently some things are hard-coded for fusion models.
    """
    src = scfg.PathList(help='path to one or more torch models', position=1)
    stem_stats = scfg.Value(True, isflag=True, help='if True, print more verbose model mean/std')
    hparams = scfg.Value(True, isflag=True, help='if True, print fit hyperparameters')
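
# The CLI flags are generated by scriptconfig from the config fields above.
# Illustrative invocations (the model path is hypothetical):
#
#   python -m geowatch.cli.torch_model_stats path/to/model.pt
#   python -m geowatch.cli.torch_model_stats path/to/model.pt --stem_stats=False --hparams=False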


def main(cmdline=False, **kwargs):
    """
    Ignore:
        import geowatch
        from geowatch.cli.torch_model_stats import *  # NOQA
        dvc_dpath = geowatch.find_dvc_dpath()
        package_fpath1 = dvc_dpath / 'models/fusion/SC-20201117/BOTH_smt_it_stm_p8_L1_DIL_v55/BOTH_smt_it_stm_p8_L1_DIL_v55_epoch=5-step=53819.pt'
        package_fpath2 = dvc_dpath / 'models/fusion/SC-20201117/BAS_smt_it_stm_p8_TUNE_L1_RAW_v58/BAS_smt_it_stm_p8_TUNE_L1_RAW_v58_epoch=3-step=81135.pt'
        package_fpath = dvc_dpath / 'models/fusion/SC-20201117/BAS_smt_it_stm_p8_L1_raw_v53/BAS_smt_it_stm_p8_L1_raw_v53_epoch=3-step=85011.pt'
        kwargs = {
            'src': [package_fpath2, package_fpath1]
        }
        main(cmdline=False, **kwargs)
    """
    config = TorchModelStatsConfig.cli(cmdline=cmdline, data=kwargs, strict=True)
    import rich
    from rich.markup import escape
    import geowatch
    import warnings
    rich.print('config = {}'.format(escape(ub.urepr(config, nl=1))))

    package_paths = config['src']
    if not ub.iterable(package_paths):
        package_paths = [package_paths]

    try:
        dvc_dpath = geowatch.find_dvc_dpath()
    except Exception:
        dvc_dpath = None

    package_rows = []
    for package_fpath in package_paths:
        print('--------')
        print(f'package_fpath={package_fpath}')
        stem_stats = config['stem_stats']
        try:
            row = torch_model_stats(package_fpath, stem_stats=stem_stats,
                                    dvc_dpath=dvc_dpath)
        except RuntimeError as pkg_ex:
            print('Reading as a package failed, attempting to read as a checkpoint')
            checkpoint_fpath = package_fpath
            try:
                checkpoint_stats = torch_checkpoint_stats(checkpoint_fpath)
            except Exception as ckpt_ex:
                warnings.warn(f'Unable to read input as a checkpoint due to: {ckpt_ex}.')
                raise pkg_ex
            else:
                warnings.warn(f'Unable to read input as a package due to: {pkg_ex}. Interpreting as a checkpoint instead')
                rich.print('checkpoint_stats = {}'.format(ub.urepr(checkpoint_stats, nl=2, sort=0, precision=2)))
        else:
            model_stats = row.get('model_stats', None)
            fit_config = row.pop('fit_config', None)
            config_cli_yaml = row.pop('config_cli_yaml', None)
            if config.hparams:
                rich.print('fit_config = {}'.format(ub.urepr(fit_config, nl=1)))
                rich.print('config_cli_yaml = {}'.format(ub.urepr(config_cli_yaml, nl=2)))
            rich.print('model_stats = {}'.format(ub.urepr(model_stats, nl=2, sort=0, precision=2)))
            package_rows.append(row)
    print('package_rows = {}'.format(ub.urepr(package_rows, nl=2, sort=0)))


def torch_checkpoint_stats(checkpoint_fpath):
    """
    A fallback set of statistics we can make for checkpoints only.

    Summarizes a PyTorch checkpoint by extracting useful metadata about the
    model's state_dict and optimizer states.

    Args:
        checkpoint_fpath (str | PathLike): Path to the checkpoint file.

    Returns:
        dict: Summary statistics of the checkpoint.
    """
    import torch
    # Load onto the CPU so checkpoints trained on a GPU can still be inspected
    data = torch.load(checkpoint_fpath, map_location='cpu')
    top_level_keys = list(data.keys())
    print(f'top_level_keys = {ub.urepr(top_level_keys, nl=1)}')

    # TODO: Given a checkpoint path, we may be able to read more information
    # about it if we can find the corresponding hparams.yaml or
    # train_config.yaml

    nested_keys = {'state_dict', 'optimizer_states', 'loops'}
    nested = ub.udict.intersection(data, nested_keys)
    non_nested = ub.udict.difference(data, nested_keys)
    stats = {
        'path': checkpoint_fpath,
        **non_nested,
    }

    if 'optimizer_states' in nested:
        optimizer_states_stats = {}
        optimizer_states = nested['optimizer_states']
        optimizer_states_stats['num_optimizer_states'] = len(optimizer_states)
        opt_stat_list = []
        for opt_state in optimizer_states:
            opt_state_stats = {}
            # TODO: Can we do more here?
            opt_state_stats['num_param_groups'] = len(opt_state['param_groups'])
            opt_stat_list.append(opt_state_stats)
        optimizer_states_stats['total_num_param_groups'] = sum(
            r['num_param_groups'] for r in opt_stat_list)
    else:
        optimizer_states_stats = None

    if 'state_dict' in nested:
        state_dict_stats = {}
        state_dict = nested['state_dict']
        state_numel = ub.udict.map_values(state_dict, lambda x: x.numel())
        total_params = sum(state_numel.values())
        # TODO: Can we use the optimizer to determine the trainable params?
        state_dict_stats['num_tensors'] = len(state_dict)
        state_dict_stats['total_params'] = total_params
    else:
        state_dict_stats = None

    if 'loops' in nested:
        # TODO: extract appropriate stats; for now just pass the loops through
        loop_stats = nested['loops']
    else:
        loop_stats = None

    stats['loops'] = loop_stats
    stats['state_dict'] = state_dict_stats
    stats['optimizer_states_stats'] = optimizer_states_stats
    return stats
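
# A minimal usage sketch for ``torch_checkpoint_stats``; the checkpoint path
# below is hypothetical (any PyTorch Lightning ``.ckpt`` file should work):
#
#     stats = torch_checkpoint_stats('lightning_logs/version_0/checkpoints/epoch=3-step=85011.ckpt')
#     print(ub.urepr(stats, nl=2, sort=0, precision=2))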


def torch_model_stats(package_fpath, stem_stats=True, dvc_dpath=None):
    import kwcoco
    from geowatch.tasks.fusion import utils
    from geowatch.utils import util_netharn
    from geowatch.monkey import monkey_torchmetrics
    monkey_torchmetrics.fix_torchmetrics_compatability()

    package_fpath = ub.Path(package_fpath)
    if not package_fpath.exists():
        if package_fpath.augment(tail='.dvc').exists():
            raise Exception('model does not exist, but its dvc file does')
        else:
            raise Exception('model does not exist')
    file_stat = package_fpath.stat()

    # TODO: generalize the load-package
    raw_module, package_header = utils.load_model_from_package(
        package_fpath, with_header=True)
    if hasattr(raw_module, 'module'):
        module = raw_module.module
    else:
        module = raw_module

    # TODO: get the category freq
    model_stats = {}
    num_params = util_netharn.number_of_parameters(module)
    print(ub.urepr(utils.model_json(module, max_depth=3), nl=-1, sort=0))

    try:
        state = module.state_dict()
    except Exception:
        if hasattr(module, 'head_metrics'):
            module.head_metrics.clear()
            state = module.state_dict()
        else:
            raise
    state_keys = list(state.keys())

    unique_sensors = set()
    config_cli_yaml = None
    train_dataset = None
    prenorm_stats = None
    fit_config = {}
    known_input_stats = []
    unknown_input_stats = []

    if hasattr(module, 'dataset_stats') and module.dataset_stats is not None:
        dataset_stats = module.dataset_stats.copy()
        if 'modality_input_stats' in dataset_stats:
            # This is too much info to print
            dataset_stats.pop('modality_input_stats', None)

        sensor_modes_with_stats = set()
        for sens_chan_key, stats in dataset_stats['input_stats'].items():
            sensor, channel = sens_chan_key
            channel = kwcoco.ChannelSpec.coerce(channel).concise().spec
            sensor_modes_with_stats.add((sensor, channel))
            unique_sensors.add(sensor)
            sensor_stat = {
                'sensor': sensor,
                'channel': channel,
            }
            if stem_stats:
                import numpy as np
                sensor_stat.update({
                    'mean': np.asarray(stats['mean']).ravel().tolist(),
                    'std': np.asarray(stats['std']).ravel().tolist(),
                })
            known_input_stats.append(sensor_stat)

        unique_sensor_modes = list(dataset_stats['unique_sensor_modes'])
        for sensor, channel in unique_sensor_modes:
            channel = kwcoco.ChannelSpec.coerce(channel).concise().spec
            key = (sensor, channel)
            if key not in sensor_modes_with_stats:
                unique_sensors.add(sensor)
                unknown_input_stats.append({
                    'sensor': sensor,
                    'channel': channel,
                })

    mb_size = file_stat.st_size / (2.0 ** 20)
    size_str = ub.urepr(mb_size, precision=2) + ' MB'

    # Add in some params about how this model was trained
    if hasattr(raw_module, 'config_cli_yaml'):
        config_cli_yaml = raw_module.config_cli_yaml
    else:
        config_cli_yaml = None

    if hasattr(raw_module, 'fit_config'):
        # Old non-cli modules
        fit_config = raw_module.fit_config
    else:
        # New lightning cli modules
        fit_config = (
            ub.udict(getattr(raw_module, 'datamodule_hparams', {})) |
            ub.udict(raw_module.hparams)
        )

    if 'train_dataset' in fit_config:
        train_dataset = ub.Path(fit_config['train_dataset'])
    elif config_cli_yaml is not None:
        train_dataset = config_cli_yaml.get('data', {}).get('train_dataset', None)
    else:
        train_dataset = None

    if dvc_dpath is not None and train_dataset is not None:
        try:
            if str(train_dataset).startswith(str(dvc_dpath)):
                train_dataset = train_dataset.relative_to(dvc_dpath)
            if str(package_fpath).startswith(str(dvc_dpath)):
                package_fpath = package_fpath.relative_to(dvc_dpath)
        except Exception:
            ...

    heads = []
    if fit_config.get('global_class_weight'):
        heads.append('class')
    if fit_config.get('global_saliency_weight'):
        heads.append('saliency')

    spacetime_stats = ub.udict(fit_config) & [
        'chip_size', 'time_steps', 'time_sampling', 'time_span',
        'chip_dims', 'window_space_scale', 'input_space_scale',
    ]

    model_stats['size'] = size_str
    model_stats['num_params'] = num_params
    model_stats['num_states'] = len(state_keys)
    model_stats['heads'] = heads
    model_stats['train_dataset'] = None if train_dataset is None else str(train_dataset)
    model_stats['spacetime_stats'] = spacetime_stats
    model_stats['classes'] = list(module.classes)
    model_stats['known_inputs'] = known_input_stats
    model_stats['unknown_inputs'] = unknown_input_stats

    # Normalization done in the dataloader
    prenorm_stats = {
        'normalize_peritem': fit_config.get('normalize_peritem'),
    }

    param_stats = {
        name: {
            'size': param.size(),
            'min': param.min().item(),
            'max': param.max().item(),
            'mean': param.mean().item(),
            'std': param.std().item(),
        }
        for name, param in module.named_parameters()
    }
    param_stats_summary = {
        'min': min(summary['min'] for summary in param_stats.values()),
        'max': max(summary['max'] for summary in param_stats.values()),
        # Mean of the per-tensor means (not weighted by tensor size)
        'mean': sum(summary['mean'] for summary in param_stats.values()) / len(param_stats),
    }

    try:
        unique_sensors = sorted(unique_sensors)
    except TypeError:
        ...

    row = {
        'name': package_fpath.stem,
        'task': 'TODO',
        'file_name': str(package_fpath),
        'sensors': unique_sensors,
        'train_dataset': str(train_dataset),
        'fit_config': fit_config,
        'config_cli_yaml': config_cli_yaml,
        'model_stats': model_stats,
        'prenorm_stats': prenorm_stats,
        'param_stats': param_stats_summary,
        'package_header': package_header,
    }
    if hasattr(module, 'input_sensorchan'):
        row['input_sensorchan'] = module.input_sensorchan.concise().spec
    elif hasattr(module, 'input_channels'):
        row['input_channels'] = module.input_channels.concise().spec
    return row
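
# A minimal sketch of calling ``torch_model_stats`` directly; the package path
# is hypothetical (any model saved via torch.package by geowatch fusion
# training should work):
#
#     row = torch_model_stats('models/fusion/my_model.pt', stem_stats=False)
#     print(ub.urepr(row['model_stats'], nl=2, sort=0, precision=2))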


def fallback(fpath):
    """
    Print any YAML metadata stored inside a torch package by reading it as a
    plain zipfile. Useful when the package itself cannot be deserialized.
    """
    import zipfile
    zfile = zipfile.ZipFile(fpath)
    for internal_path in zfile.namelist():
        if internal_path.endswith('.yaml'):
            data = zfile.read(internal_path)
            print(data.decode('utf8'))
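
# This works because a torch.package file is a standard zip archive, so its
# yaml metadata can be listed and read without importing torch. Usage sketch
# (the path is hypothetical):
#
#     fallback('models/fusion/my_model.pt')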


__cli__ = TorchModelStatsConfig

if __name__ == '__main__':
    """
    CommandLine:
        python -m geowatch.cli.torch_model_stats
    """
    main(cmdline=True)