"""
This module is an exension of jsonargparse and lightning CLI that will respect
scriptconfig style arguments
References:
https://github.com/Lightning-AI/lightning/issues/15038
"""
from kwutil.util_environ import envflag
from packaging.version import parse as Version
if envflag('USE_PATCHED_JSONARGPARSE'):
# Use old patching system, instead of new forking system
import jsonargparse
try:
from pytorch_lightning.cli import ActionConfigFile
except Exception:
from jsonargparse import ActionConfigFile # NOQA
from pytorch_lightning.cli import LightningArgumentParser
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.cli import Namespace
JSONARGPARSE_VERSION = Version(jsonargparse.__version__)
if Version('4.0.0') <= JSONARGPARSE_VERSION < Version('4.21.0'):
from geowatch.utils.lightning_ext import _jsonargparse_ext_ge_4_00_and_lt_4_21 as _jsonargparse_ext
elif Version('4.21.0') <= JSONARGPARSE_VERSION < Version('4.22.0'):
from geowatch.utils.lightning_ext import _jsonargparse_ext_ge_4_21_and_lt_4_22 as _jsonargparse_ext
elif Version('4.22.0') <= JSONARGPARSE_VERSION < Version('4.24.0'):
from geowatch.utils.lightning_ext import _jsonargparse_ext_ge_4_22_and_lt_4_24 as _jsonargparse_ext
elif Version('4.24.0') <= JSONARGPARSE_VERSION < Version('4.24.2'):
from geowatch.utils.lightning_ext import _jsonargparse_ext_ge_4_24_and_lt_4_24_2 as _jsonargparse_ext
elif Version('4.30.0') <= JSONARGPARSE_VERSION < Version('5.0.0'):
from geowatch.utils.lightning_ext import _jsonargparse_ext_ge_4_30_and_lt_5_xx as _jsonargparse_ext
else:
...
if JSONARGPARSE_VERSION < Version('4.30.0'):
class LightningArgumentParser_Extension(_jsonargparse_ext.ArgumentParserPatches, LightningArgumentParser):
"""
CommandLine:
xdoctest -m geowatch.utils.lightning_ext.lightning_cli_ext LightningArgumentParser_Extension
Example:
>>> from geowatch.utils.lightning_ext.lightning_cli_ext import * # NOQA
>>> LightningArgumentParser_Extension()
Refactor references:
~/.pyenv/versions/3.10.5/envs/pyenv3.10.5/lib/python3.10/site-packages/pytorch_lightning/cli.py
~/.pyenv/versions/3.10.5/envs/pyenv3.10.5/lib/python3.10/site-packages/jsonargparse/core.py
~/.pyenv/versions/3.10.5/envs/pyenv3.10.5/lib/python3.10/site-packages/jsonargparse/signatures.py
"""
# Monkey patch jsonargparse so its subcommands use our extended functionality
jsonargparse.ArgumentParser = LightningArgumentParser_Extension
if JSONARGPARSE_VERSION < Version('4.22.0'):
jsonargparse.core.ArgumentParser = LightningArgumentParser_Extension
jsonargparse.core._find_action_and_subcommand = _jsonargparse_ext._find_action_and_subcommand
jsonargparse.actions._find_action_and_subcommand = _jsonargparse_ext._find_action_and_subcommand
elif JSONARGPARSE_VERSION < Version('4.22.0'):
jsonargparse._core.ArgumentParser = LightningArgumentParser_Extension
jsonargparse._core._find_action_and_subcommand = _jsonargparse_ext._find_action_and_subcommand
jsonargparse._actions._find_action_and_subcommand = _jsonargparse_ext._find_action_and_subcommand
else:
# Monkey patching is simpler in newer versions
_jsonargparse_ext.apply_monkeypatch()
LightningArgumentParser_Extension = LightningArgumentParser
else:
# Use new TPL forking systems with a pre-patched jsonargparse
import geowatch_tpl
import sys
jsonargparse = geowatch_tpl.import_submodule('jsonargparse_fork')
assert 'pytorch_lightning.cli' not in sys.modules, (
'We need to import our fork of jsonargparse before we import lightning CLI')
sys.modules['jsonargparse'] = jsonargparse
print(f'jsonargparse={jsonargparse}')
try:
from pytorch_lightning.cli import ActionConfigFile
except Exception:
from jsonargparse_fork import ActionConfigFile # NOQA
from pytorch_lightning.cli import LightningArgumentParser
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.cli import Namespace
LightningArgumentParser_Extension = LightningArgumentParser
# Should try to patch into upstream
[docs]
class LightningCLI_Extension(LightningCLI):
"""
Our customized :class:`LightningCLI` extension.
"""
...
[docs]
def init_parser(self, **kwargs):
# Hack in our modified parser
DEBUG = 0
if DEBUG:
kwargs['error_handler'] = None
import pytorch_lightning as pl
kwargs.setdefault("dump_header", [f"pytorch_lightning=={pl.__version__}"])
parser = LightningArgumentParser_Extension(**kwargs)
parser.add_argument(
"-c", "--config", action=ActionConfigFile,
help="Path to a configuration file in json or yaml format."
)
return parser
[docs]
def parse_arguments(self, parser: LightningArgumentParser, args) -> None:
"""Parses command line arguments and stores it in ``self.config``."""
import sys
if args is not None and len(sys.argv) > 1:
# Please let us shoot ourselves in the foot.
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
# import warnings
rank_zero_warn(
"LightningCLI's args parameter is intended to run from within Python like if it were from the command "
"line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: "
f"sys.argv[1:]={sys.argv[1:]}, args={args}."
)
if isinstance(args, (dict, Namespace)):
self.config = parser.parse_object(args)
else:
self.config = parser.parse_args(args, _skip_check=True)
# def _add_instantiators(self) -> None:
# import yaml
# from pytorch_lightning.cli import _InstantiatorFn
# from pytorch_lightning.cli import _get_module_type
# self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_check=True, skip_link_targets=False, skip_none=False))
# if "subcommand" in self.config:
# self.config_dump = self.config_dump[self.config.subcommand]
# self.parser.add_instantiator(
# _InstantiatorFn(cli=self, key="model"),
# _get_module_type(self._model_class),
# subclasses=self.subclass_mode_model,
# )
# self.parser.add_instantiator(
# _InstantiatorFn(cli=self, key="data"),
# _get_module_type(self._datamodule_class),
# subclasses=self.subclass_mode_data,
# )