geowatch.utils.lightning_ext.util_device module
- geowatch.utils.lightning_ext.util_device.coerce_devices(gpus)
Coerce a command line argument for GPUs into a valid set of torch devices.
This is a wrapper around lightning's pytorch_lightning.utilities.parse_gpu_ids() (which was deprecated in lightning 1.8, so we have to vendor it). It extends the cases that function can handle and is specific to torch devices. As of lightning 1.6, lightning's own device parsing is pretty good, so this wrapper may not be necessary.
- If gpus is a list of integers, then those devices are used.
- If gpus is None or “cpu”, then the CPU is used.
- If gpus is “cuda”, that is equivalent to gpus=[0].
- If gpus is a string without commas, then the string should be a number indicating how many gpus should be used.
- If gpus is a string with commas separating integers, then that indicates the device indexes that should be used.
- Parameters:
gpus (List[int] | str | int | None) – adds ability to parse “cpu”, “auto”, “auto:N”.
- Returns:
List[torch.device]
Note
Prefer coerce_accelerator_devices() over this, but even that new function may change in the future.
Example
>>> from geowatch.utils.lightning_ext import util_device
>>> import torch
>>> print(util_device.coerce_devices('cpu'))
>>> print(util_device.coerce_devices(None))
>>> # xdoctest: +SKIP
>>> # breaks without a cuda machine
>>> #print(util_device.coerce_devices("0"))
>>> print(util_device.coerce_devices("1"))
>>> print(util_device.coerce_devices("0"))
>>> print(util_device.coerce_devices("0,"))
>>> print(util_device.coerce_devices(1))
>>> print(util_device.coerce_devices([0, 1]))
>>> print(util_device.coerce_devices("0, 1"))
>>> print(util_device.coerce_devices("auto"))
>>> if torch.cuda.device_count() > 0:
>>>     print(util_device.coerce_devices("auto:1"))
>>> if torch.cuda.device_count() > 1:
>>>     print(util_device.coerce_devices("auto:2"))
>>> if torch.cuda.device_count() > 2:
>>>     print(util_device.coerce_devices("auto:3"))
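The coercion rules listed for coerce_devices can be sketched without torch or lightning. The following is a minimal illustration, not the geowatch implementation: sketch_coerce_devices and _first_n are hypothetical names, device strings such as "cuda:0" stand in for torch.device objects, and the handling of integer counts (expanding a count N to the first N indexes and rejecting counts below 1) is an assumption. The “auto” and “auto:N” forms are deliberately left out.

```python
from typing import List, Sequence, Union

GpuSpec = Union[Sequence[int], str, int, None]


def _first_n(count: int) -> List[int]:
    """Expand a GPU count into the first ``count`` device indexes (assumption)."""
    if count < 1:
        # Simplification of this sketch; the real function may behave differently.
        raise ValueError(f"expected a positive GPU count, got {count}")
    return list(range(count))


def sketch_coerce_devices(gpus: GpuSpec) -> List[str]:
    """Hypothetical stand-in that returns device strings, not torch.device."""
    # None or "cpu" selects the CPU.
    if gpus is None or gpus == "cpu":
        return ["cpu"]
    if gpus == "cuda":
        # Documented as equivalent to gpus=[0].
        indexes = [0]
    elif isinstance(gpus, str):
        text = gpus.strip()
        if text.startswith("auto"):
            raise NotImplementedError("'auto' and 'auto:N' are not covered here")
        if "," in text:
            # Comma-separated integers are explicit device indexes.
            indexes = [int(part) for part in text.split(",") if part.strip()]
        else:
            # A string without commas is a count of GPUs to use.
            indexes = _first_n(int(text))
    elif isinstance(gpus, int):
        # Assumption: a bare integer is treated like a count.
        indexes = _first_n(gpus)
    else:
        # A list of integers names the devices directly.
        indexes = [int(index) for index in gpus]
    return [f"cuda:{index}" for index in indexes]
```

Note how the trailing comma matters in this sketch: "1" is read as a count of one GPU, while "1," is read as the single device index 1.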
- geowatch.utils.lightning_ext.util_device.coerce_accelerator_devices(accelerator, devices, _use_private_api=False)
A simplified version of lightning’s accelerator connector, which is currently not a public API, so we are avoiding depending on it. If it becomes a public API we may just use it.
Example
>>> from geowatch.utils.lightning_ext.util_device import *  # NOQA
>>> coerce_accelerator_devices('cpu', 1)
[device(type='cpu')]
Example
>>> import ubelt as ub
>>> from geowatch.utils.lightning_ext.util_device import *  # NOQA
>>> import torch
>>> results = []
>>> if True:
>>>     basis = {
>>>         'accelerator': ['cpu'],
>>>         'devices': ['auto', 1],
>>>         '_use_private_api': [0, 1],
>>>     }
>>>     for kwargs in ub.named_product(basis):
>>>         row = {**kwargs, 'result': coerce_accelerator_devices(**kwargs)}
>>>         results.append(row)
>>> if True:
>>>     basis = {
>>>         'accelerator': ['auto'],
>>>         'devices': ['auto'],
>>>         '_use_private_api': [0, 1],
>>>     }
>>>     for kwargs in ub.named_product(basis):
>>>         row = {**kwargs, 'result': coerce_accelerator_devices(**kwargs)}
>>>         results.append(row)
>>> if torch.cuda.is_available():
>>>     basis = {
>>>         'accelerator': ['cuda'],
>>>         'devices': [1, (0,), '0,'],
>>>         '_use_private_api': [0, 1],
>>>     }
>>>     for kwargs in ub.named_product(basis):
>>>         row = {**kwargs, 'result': coerce_accelerator_devices(**kwargs)}
>>>         results.append(row)
>>> if torch.cuda.is_available() and torch.cuda.device_count() > 1:
>>>     basis = {
>>>         'accelerator': ['cuda'],
>>>         'devices': [1, (1, 0), (1,), '1,'],
>>>         '_use_private_api': [0, 1],
>>>     }
>>>     for kwargs in ub.named_product(basis):
>>>         row = {**kwargs, 'result': coerce_accelerator_devices(**kwargs)}
>>>         results.append(row)
>>> print(f'results = {ub.urepr(results, nl=1)}')
>>> groups = ub.group_items(results, key=lambda x: (x['accelerator'], x['devices']))
>>> # check that our function matches the lightning API
>>> for group in groups.values():
...     assert ub.allsame([g['result'] for g in group])
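The consistency check in the example above follows a general pattern: enumerate every combination of arguments, call the function with and without the private code path, group the results by the public arguments, and require that each group agrees. A stdlib-only sketch of that pattern is shown below. The named_product helper here is a local re-implementation written for this illustration, and resolve is a hypothetical stand-in for the function under test, so nothing in it reflects how coerce_accelerator_devices actually resolves devices.

```python
import itertools
from collections import defaultdict


def named_product(basis):
    """Yield one dict per combination of the values listed in ``basis``."""
    keys = list(basis)
    for values in itertools.product(*(basis[key] for key in keys)):
        yield dict(zip(keys, values))


def resolve(accelerator, devices, _use_private_api=False):
    """Hypothetical stand-in for the function under test.

    Both code paths return the same value here, so the check passes trivially.
    """
    return [accelerator]


basis = {
    "accelerator": ["cpu"],
    "devices": ["auto", 1],
    "_use_private_api": [0, 1],
}

# One row per argument combination, keeping the arguments next to the result.
results = [{**kwargs, "result": resolve(**kwargs)} for kwargs in named_product(basis)]

# Group by the public arguments only, so each group holds the results of the
# private and non-private code paths for the same request.
groups = defaultdict(list)
for row in results:
    groups[(row["accelerator"], row["devices"])].append(row["result"])

# Every code path must agree within a group.
for key, outcomes in groups.items():
    assert all(outcome == outcomes[0] for outcome in outcomes), key
```

Grouping on the public arguments keeps the comparison independent of how many code paths exist: adding a third value to _use_private_api would extend each group without changing the assertion.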