Source code for geowatch.utils.lightning_ext.argparse_ext

"""
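Do a better job with default argparse: build command-line options from the
Google-style ``Args:`` docstrings and ``__init__`` signatures of classes.

TODO: work on this. The type-conversion helpers below were extracted from
pytorch_lightning with the following liberator snippet: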

import liberator

lib = liberator.Liberator()
lib.add_dynamic(get_init_arguments_and_types)
lib.add_dynamic(str_to_bool)
lib.add_dynamic(str_to_bool_or_int)
lib.add_dynamic(str_to_bool_or_str)
lib.add_dynamic(_int_or_float_type)
lib.add_dynamic(_gpus_allowed_type)
lib.expand(['pytorch_lightning'])
print(lib.current_sourcecode())

"""
import inspect


def get_init_arguments_and_types(cls):
    """
    Inspect the call signature of ``cls`` and describe every parameter.

    Args:
        cls: any callable accepted by :func:`inspect.signature` (a class, an
            ``__init__`` method, a plain function, ...).

    Returns:
        List[Tuple[str, tuple, Any]]: one ``(name, types, default)`` triple per
        parameter, in signature order. ``types`` is one of the following:

        - the annotation's ``__args__`` (e.g. the members of a ``Union``);
        - for a ``Literal`` annotation, the Python types of its values;
        - a 1-tuple holding the raw annotation when none of the above applies.

        ``default`` is the parameter default, or ``inspect.Parameter.empty``
        when the parameter has none.
    """
    triples = []
    for param_name, param in inspect.signature(cls).parameters.items():
        annotation = param.annotation
        try:
            if type(annotation).__name__ == '_LiteralGenericAlias':
                types = tuple({type(value) for value in annotation.__args__})
            elif any(tag in str(annotation)
                     for tag in ('typing.Literal', 'typing_extensions.Literal')):
                # Annotation whose members are themselves Literals: gather the
                # value types across every member.
                types = tuple({
                    type(value)
                    for member in annotation.__args__
                    for value in member.__args__
                })
            else:
                types = tuple(annotation.__args__)
        except (AttributeError, TypeError):
            # No usable ``__args__``: treat the annotation itself as the type.
            types = (annotation,)
        triples.append((param_name, types, param.default))
    return triples
def str_to_bool_or_str(val: str):
    """
    Convert a truth string to a bool when possible, else return it unchanged.

    Modelled on ``distutils.util.strtobool``. The comparison is
    case-insensitive.

    - True values: 'y', 'yes', 't', 'true', 'on', '1'.
    - False values: 'n', 'no', 'f', 'false', 'off', '0'.

    Any other input is returned as given.
    """
    lowered = val.lower()
    truthy = ('y', 'yes', 't', 'true', 'on', '1')
    falsy = ('n', 'no', 'f', 'false', 'off', '0')
    if lowered in truthy:
        return True
    return False if lowered in falsy else val
def str_to_bool_or_int(val: str):
    """
    Interpret ``val`` as a bool if it is a recognised truth string.

    Otherwise convert it to an int if it parses as one, or return it
    unchanged.

    >>> str_to_bool_or_int("FALSE")
    False
    >>> str_to_bool_or_int("1")
    True
    >>> str_to_bool_or_int("2")
    2
    >>> str_to_bool_or_int("abc")
    'abc'
    """
    result = str_to_bool_or_str(val)
    if not isinstance(result, bool):
        try:
            result = int(result)
        except ValueError:
            # Not an integer literal either: hand back the string as-is.
            pass
    return result
def str_to_bool(val: str) -> bool:
    """
    Strictly convert a truth string to a bool.

    The comparison is case-insensitive.

    - True values: 'y', 'yes', 't', 'true', 'on', '1'.
    - False values: 'n', 'no', 'f', 'false', 'off', '0'.

    Raises:
        ValueError: when ``val`` is none of the accepted values.

    >>> str_to_bool('YES')
    True
    >>> str_to_bool('FALSE')
    False
    """
    outcome = str_to_bool_or_str(val)
    if not isinstance(outcome, bool):
        raise ValueError(f'invalid truth value {outcome}')
    return outcome
def _gpus_allowed_type(x: str):
    """
    argparse converter for ``--gpus`` / ``--tpu_cores``.

    A comma-separated device list such as ``"0,1"`` is kept as a string.
    Anything else is converted with ``int``, which raises ``ValueError`` on
    failure.
    """
    if ',' in x:
        return str(x)
    return int(x)


def _int_or_float_type(x):
    """
    argparse converter for options that accept either an int or a float.

    Conversion rules:

    - Integer literals become ``int``.
    - Anything containing a decimal point becomes ``float``.
    - Anything else that ``int`` rejects but ``float`` accepts (e.g.
      scientific notation such as ``"1e-3"``) also becomes ``float``.
    - Values that are already floats are returned unchanged rather than
      truncated by ``int``.

    Raises:
        ValueError: if ``x`` is neither an int nor a float literal.
    """
    if isinstance(x, float):
        return x
    if '.' in str(x):
        return float(x)
    try:
        return int(x)
    except ValueError:
        # e.g. "1e-3" has no '.', but is still a valid float literal.
        pass
    return float(x)
def _docstring_google_arg_infos(docstr):
    """
    Return the entries of the Google-style ``Args:`` section of ``docstr``.

    ``parse_google_args`` requires a string, so a missing docstring (``None``)
    yields an empty list instead of being passed through.
    """
    from xdoctest.docstr import docscrape_google
    if docstr is None:
        return []
    return list(docscrape_google.parse_google_args(docstr))


def _docstring_candidate_types(arg_info):
    """
    Collect the distinct concrete types known for one documented argument.

    Combines the signature types (``sig_type``, if present) with the type
    named in the docstring (``str_type``) when that is ``int``, ``float`` or
    ``str``. Two kinds of entry are dropped because neither can convert a
    command-line string:

    - ``inspect.Parameter.empty`` (no annotation);
    - ``NoneType`` (e.g. from ``Optional[...]``).

    Duplicates are removed while keeping first-seen order, so that
    e.g. ``(int, float)`` plus a documented ``float`` is still recognised as
    the int-or-float case.
    """
    type_lut = {
        'int': int,
        'float': float,
        'str': str,
    }
    candidates = list(arg_info.get('sig_type', ()))
    candidates.append(type_lut.get(arg_info['str_type'], inspect._empty))
    arg_types = []
    for candidate in candidates:
        if candidate is inspect._empty or candidate is type(None):
            continue
        if candidate not in arg_types:
            arg_types.append(candidate)
    return arg_types


def _select_argparse_type(name, arg_types):
    """
    Choose the argparse ``type=`` converter and extra kwargs for one argument.

    Returns:
        Tuple[Any, dict]: ``(use_type, arg_kwargs)``. ``use_type`` is
        ``inspect.Parameter.empty`` when no converter could be determined.
    """
    arg_kwargs = {}
    if bool in arg_types:
        # Allow a bare ``--flag`` to mean True.
        arg_kwargs.update(nargs='?', const=True)
        if len(arg_types) == 1:
            # the only arg type is bool
            use_type = str_to_bool
        elif int in arg_types:
            use_type = str_to_bool_or_int
        elif str in arg_types:
            use_type = str_to_bool_or_str
        else:
            # filter out the bool as we need to use more general
            use_type = [at for at in arg_types if at is not bool][0]
    elif arg_types:
        use_type = arg_types[0]
    else:
        use_type = inspect._empty

    if name in ('gpus', 'tpu_cores'):
        use_type = _gpus_allowed_type
    # hack for types in (int, float)
    if len(arg_types) == 2 and int in arg_types and float in arg_types:
        use_type = _int_or_float_type
    # hack for track_grad_norm
    if name == 'track_grad_norm':
        use_type = float
    return use_type, arg_kwargs


def parse_docstring_args(cls):
    """
    Build argparse-ready descriptions of the arguments documented on ``cls``.

    The procedure is:

    1. Read argument names, type strings and descriptions from the
       Google-style ``Args:`` section of ``cls.__init__.__doc__``. Fall back
       to ``cls.__doc__`` when that yields nothing.
    2. Merge them with the annotations and defaults of the ``__init__``
       signature.
    3. Choose a converter suitable for ``argparse``'s ``type=`` for each one.

    Args:
        cls (type): the class whose constructor arguments should be exposed.
            ``self``, ``args`` and ``kwargs`` are always skipped. If ``cls``
            has ``get_deprecated_arg_names``, the names it returns are also
            skipped.

    Returns:
        List[Dict]: one dict per documented, non-ignored argument. Each holds
        the keys yielded by xdoctest's ``parse_google_args``, plus:

        - ``str_type``: the documented type string.
        - ``use_type``: the chosen converter, or ``inspect.Parameter.empty``
          when none could be determined.
        - ``arg_kwargs``: extra ``add_argument`` keyword arguments.
        - ``sig_type`` / ``sig_default``: present only when the argument also
          appears in the ``__init__`` signature.
    """
    symbol = cls.__init__
    arg_infos = _docstring_google_arg_infos(symbol.__doc__)
    if not arg_infos:
        # Try cls instead
        arg_infos = _docstring_google_arg_infos(cls.__doc__)

    ignore_arg_names = ['self', 'args', 'kwargs']
    if hasattr(cls, 'get_deprecated_arg_names'):
        ignore_arg_names += cls.get_deprecated_arg_names()

    # Get symbols from cls or init function.
    sig_lut = {
        name: (sig_type, sig_default)
        for name, sig_type, sig_default in get_init_arguments_and_types(symbol)
    }
    for arg_info in arg_infos:
        arg_info['str_type'] = arg_info.get('type')
        if arg_info['name'] in sig_lut:
            sig_type, sig_default = sig_lut[arg_info['name']]
            arg_info['sig_type'] = sig_type
            arg_info['sig_default'] = sig_default

    return_infos = []
    for arg_info in arg_infos:
        name = arg_info['name']
        if name in ignore_arg_names:
            continue
        arg_types = _docstring_candidate_types(arg_info)
        use_type, arg_kwargs = _select_argparse_type(name, arg_types)
        arg_info['use_type'] = use_type
        arg_info['arg_kwargs'] = arg_kwargs
        return_infos.append(arg_info)
    return return_infos
def add_arginfos_to_parser(parent_parser, arg_infos):
    """
    Register one ``--<name>`` option on ``parent_parser`` per ``arg_infos`` entry.

    Args:
        parent_parser (argparse.ArgumentParser): parser to extend. It is
            modified in place.
        arg_infos (List[Dict]): entries shaped like those returned by
            ``parse_docstring_args``. Each entry must contain ``name``,
            ``use_type`` and ``arg_kwargs``, and may contain ``sig_default``
            and ``desc``.

    Returns:
        argparse.ArgumentParser: ``parent_parser``, for chaining.
    """
    for arg_info in arg_infos:
        name = arg_info['name']
        # Arguments documented in the docstring but absent from the signature
        # have no ``sig_default``; use argparse's own default of None.
        arg_default = arg_info.get('sig_default', None)
        use_type = arg_info['use_type']
        desc = arg_info.get('desc')
        if desc is not None:
            # argparse %-formats help strings, so a literal '%' from a
            # docstring would otherwise make ``--help`` raise.
            desc = desc.replace('%', '%%')
        # Copy so the caller's ``arg_kwargs`` dict is never mutated.
        arg_kwargs = dict(arg_info['arg_kwargs'])
        if use_type is not inspect._empty:
            arg_kwargs['type'] = use_type
        parent_parser.add_argument(
            f"--{name}", dest=name, default=arg_default, help=desc,
            **arg_kwargs
        )
    return parent_parser
def add_argparse_args(cls, parent_parser):
    """
    Expose the documented constructor arguments of ``cls`` as CLI options.

    Args:
        cls (type): class whose ``__init__`` (or class) docstring ``Args:``
            section and ``__init__`` signature describe the options.
        parent_parser (argparse.ArgumentParser): parser to extend. It is
            modified in place.

    Returns:
        argparse.ArgumentParser: ``parent_parser``, so callers can write
        ``parser = add_argparse_args(cls, parser)``.
    """
    arg_infos = parse_docstring_args(cls)
    return add_arginfos_to_parser(parent_parser, arg_infos)