Source code for geowatch.monkey.monkey_torch

def fix_gelu_issue(method):
    """
    Backfill the ``approximate`` attribute on GELU submodules that lack it.

    Torch 1.12 added an ``approximate`` parameter to GELU which modules from
    older models do not have. Every submodule whose class is named "GELU"
    and that has no ``approximate`` attribute gets it set to ``"none"``.
    Submodules that already define the attribute are left untouched.

    Args:
        method: object exposing ``named_modules()`` that yields
            ``(name, module)`` pairs. Matching modules are patched in place.

    References:
        https://github.com/pytorch/pytorch/pull/61439
    """
    for _name, submodule in method.named_modules():
        # Match on the class name only, so no torch import is needed here.
        is_gelu = submodule.__class__.__name__ == "GELU"
        if is_gelu and not hasattr(submodule, "approximate"):
            submodule.approximate = "none"
def fix_package_modules():
    """
    Monkey patch ``torch.package`` to use the interpreter's stdlib name list.

    On Python 3.10 and newer, ``torch.package._stdlib._get_stdlib_modules``
    is replaced with a callable that returns ``sys.stdlib_module_names``.
    On older interpreters this function does nothing.

    The patch is best effort: any exception raised while importing or
    patching (for example torch not being importable) is silently ignored.
    """
    import sys

    if sys.version_info[0:2] < (3, 10):
        # The patch is only applied on Python 3.10 and newer.
        return

    try:
        from torch.package import _stdlib as stdlib_shim

        stdlib_shim._get_stdlib_modules = lambda: sys.stdlib_module_names
    except Exception:
        # Deliberately swallowed: failing to patch must never break callers.
        pass
def add_safe_globals():
    """
    Register commonly pickled classes with torch's safe-globals allowlist.

    Stop gap to allow for loading of common classes now that torch load
    disallows de-serialization of unknown classes.

    The registered objects are, in order: ``kwcoco``'s ``CategoryTree``,
    NumPy's ``multiarray._reconstruct`` and ``multiarray.scalar``,
    ``np.ndarray``, ``np.dtype``, and whichever of the per-type dtype
    classes (``UInt8DType``, ``Float32DType``, ...) exist in ``np.dtypes``.

    Version differences that are tolerated:

    * NumPy's core package is looked up as ``np._core`` first and falls
      back to ``np.core``.
    * dtype classes (or the whole ``np.dtypes`` namespace) that are
      missing are skipped.
    * If ``torch.serialization.add_safe_globals`` raises ``AttributeError``
      (e.g. the function does not exist), nothing is registered.

    Note that ``torch``, ``kwcoco`` and ``numpy`` must be importable, and
    the ``kwcoco`` / ``multiarray`` attributes are accessed unguarded.

    References:
        https://github.com/huggingface/transformers/pull/34632
    """
    import torch
    import kwcoco
    import numpy as np

    # Prefer the ``_core`` spelling; fall back to ``core`` when it is absent.
    core_pkg = getattr(np, "_core", None)
    if core_pkg is None:
        core_pkg = np.core

    # Fixed entries, evaluated in the same order they are registered.
    allowed = [
        kwcoco.category_tree.CategoryTree,
        core_pkg.multiarray._reconstruct,
        core_pkg.multiarray.scalar,
        np.ndarray,
        np.dtype,
    ]

    dtype_class_names = (
        'UInt8DType', 'UInt16DType', 'UInt32DType', 'UInt64DType',
        'Int8DType', 'Int16DType', 'Int32DType', 'Int64DType',
        'Float16DType', 'Float32DType', 'Float64DType',
        'Complex64DType', 'Complex128DType',
        'BoolDType',
    )

    # ``np.dtypes`` may be missing entirely; then every lookup yields None.
    dtypes_namespace = getattr(np, "dtypes", None)
    for class_name in dtype_class_names:
        dtype_cls = getattr(dtypes_namespace, class_name, None)
        if dtype_cls is not None:
            allowed.append(dtype_cls)

    if not allowed:
        return

    try:
        torch.serialization.add_safe_globals(allowed)
    except AttributeError:
        # Tolerated so that a torch without this API does not break callers.
        pass