[docs]
def coerce_devices(gpus):
    """
    Coerce a command line argument for GPUs into a list of torch devices.

    This is a wrapper around lightning's ``parse_gpu_ids`` (which was
    deprecated in lightning 1.8, so a vendored copy from
    ``geowatch.utils.lightning_ext.old_parser_devices`` is used). It extends
    the cases that it can handle and is specific to torch devices. As of
    lightning 1.6 their own device parsing is pretty good, so this may not be
    necessary.

    Accepted forms of ``gpus``:

    * A list of integers: those device indexes are used.
    * ``None`` or ``"cpu"``: the CPU is used.
    * ``"cuda"``: equivalent to ``gpus=[0]``.
    * An integer, or a string without commas that holds an integer: the
      *number* of GPUs to use (``0`` means CPU, ``-1`` means every
      available GPU).
    * A string with commas separating integers: the device *indexes* to
      use. A trailing comma selects a single index, e.g. ``"0,"``.
    * A string starting with ``"auto"``: ``"auto"`` or ``"auto:N"`` lets
      lightning's ``pick_multiple_gpus`` choose all or ``N`` GPUs.

    Args:
        gpus (List[int] | str | int | None): the device specification.

    Returns:
        List[torch.device]: the resolved devices. This is ``[cpu]`` when no
        GPU is requested.

    Raises:
        ValueError: if the lightning parser rejects ``gpus`` and the value
            is a string or a negative count that cannot be interpreted
            without it, or if ``"auto:N"`` is given a non-integer ``N``.
        AssertionError: if the resolved device ids are not all integers.

    Example:
        >>> from geowatch.utils.lightning_ext import util_device
        >>> print(util_device.coerce_devices('cpu'))
        >>> print(util_device.coerce_devices(None))
        >>> print(util_device.coerce_devices("0"))
        >>> # xdoctest: +SKIP
        >>> # breaks without a cuda machine
        >>> import torch
        >>> print(util_device.coerce_devices("1"))
        >>> print(util_device.coerce_devices("0,"))
        >>> print(util_device.coerce_devices(1))
        >>> print(util_device.coerce_devices([0, 1]))
        >>> print(util_device.coerce_devices("0, 1"))
        >>> print(util_device.coerce_devices("auto"))
        >>> if torch.cuda.device_count() > 0:
        >>>     print(util_device.coerce_devices("auto:1"))
        >>> if torch.cuda.device_count() > 1:
        >>>     print(util_device.coerce_devices("auto:2"))
        >>> if torch.cuda.device_count() > 2:
        >>>     print(util_device.coerce_devices("auto:3"))
    """
    import torch

    needs_gpu_coerce = True
    auto_select_gpus = False
    gpu_ids = None

    if isinstance(gpus, str):
        if gpus == 'cpu':
            needs_gpu_coerce = False
        elif gpus == 'cuda':
            gpu_ids = [0]
            needs_gpu_coerce = False
        elif gpus.startswith('auto'):
            auto_select_gpus = True
            parts = gpus.split(':')
            # A bare "auto" requests every GPU (-1); "auto:N" requests N.
            gpus = -1 if len(parts) == 1 else int(parts[1])
        else:
            try:
                indexes = [int(p.strip()) for p in gpus.split(',') if p.strip()]
            except ValueError:
                # Not a plain integer spec (e.g. "[]"). Hand the raw string
                # to the lightning parser below.
                pass
            else:
                if ',' not in gpus and len(indexes) == 1:
                    # A single integer without a comma is a GPU *count*,
                    # not a device index. Use "0," to select index 0.
                    gpus = indexes[0]
                else:
                    gpus = indexes

    print(f'gpus={gpus}')
    print(f'auto_select_gpus={auto_select_gpus}')

    if auto_select_gpus:
        from pytorch_lightning.tuner import auto_gpu_select
        gpu_ids = auto_gpu_select.pick_multiple_gpus(gpus)
    elif needs_gpu_coerce:
        try:
            # Vendored replacement for pytorch_lightning.utilities.device_parser
            from geowatch.utils.lightning_ext import old_parser_devices
            gpu_ids = old_parser_devices.parse_gpu_ids(gpus)
        except Exception as ex:
            # Best effort: the parser may reject devices it cannot see.
            # Fall back to interpreting the request directly.
            print(f'WARNING. Ignoring ex={ex}')
            gpu_ids = _fallback_gpu_ids(gpus)

    if gpu_ids is None:
        return [torch.device('cpu')]

    gpu_ids = list(gpu_ids)
    assert all(isinstance(g, int) for g in gpu_ids), (
        f'device ids must be integers, got {gpu_ids!r}')
    return [torch.device(_id) for _id in gpu_ids]


def _fallback_gpu_ids(gpus):
    """
    Interpret ``gpus`` without the lightning parser.

    Used only when the parser raised. Returns ``None`` to request the CPU,
    otherwise a list of device indexes.
    """
    import torch

    if gpus is None:
        return None
    if isinstance(gpus, str):
        raise ValueError(
            f'Unable to interpret gpus={gpus!r} as a set of devices')
    if isinstance(gpus, int):
        # An integer is a count: -1 means every visible GPU, 0 means CPU.
        num = torch.cuda.device_count() if gpus == -1 else gpus
        if num < 0:
            raise ValueError(
                f'Unable to interpret gpus={gpus!r} as a number of devices')
        return list(range(num)) or None
    # Anything else is taken to be an explicit collection of device indexes.
    return list(gpus)
def _test_lightning_is_sane():
    """
    Sanity check the vendored lightning ``parse_gpu_ids`` function.

    Specs that request no GPU must parse to ``None``. The remaining checks
    only run when enough CUDA devices are visible to torch.
    """
    from geowatch.utils.lightning_ext import old_parser_devices as device_parser
    # from pytorch_lightning.utilities import device_parser
    import torch

    parse = device_parser.parse_gpu_ids
    available = torch.cuda.device_count()

    # Requests for zero GPUs resolve to None.
    for cpu_spec in ('0', '[]', 0):
        assert parse(cpu_spec) is None

    if available >= 1:
        # A count of one, or the explicit index "0,", both give device 0.
        for spec in ('1', 1, '0,'):
            assert parse(spec) == [0]

    if available >= 2:
        # A count of two, or explicit indexes 0 and 1, give the first two.
        for spec in ('2', 2, [0, 1], '0, 1'):
            assert parse(spec) == [0, 1]

    if available != 0:
        # -1 selects every visible device.
        assert parse(-1) == list(range(available))
def _test_coerce_is_sane():
    """
    Sanity check that ``coerce_devices`` maps device specs to torch devices.

    CPU expectations are always checked. GPU expectations are only checked
    when enough CUDA devices are visible to torch.
    """
    import torch

    available = torch.cuda.device_count()

    if available != 0:
        # "-1" (as text or int) and "auto" select every visible device.
        every_device = [torch.device(idx) for idx in range(available)]
        for spec in ('-1', -1, 'auto'):
            assert coerce_devices(spec) == every_device

    # Requests for zero GPUs resolve to the CPU.
    cpu_only = [torch.device('cpu')]
    for spec in ('0', '[]', 0):
        assert coerce_devices(spec) == cpu_only

    if available >= 1:
        first_device = [torch.device(0)]
        for spec in ('1', 1, '0,', 'auto:1'):
            assert coerce_devices(spec) == first_device

    if available >= 2:
        first_two = [torch.device(0), torch.device(1)]
        for spec in ('2', 2, [0, 1], '0, 1', 'auto:2'):
            assert coerce_devices(spec) == first_two