Source code for utils.devices
"""Helpers for creating and querying ``torch.device`` objects (CPU and CUDA GPUs)."""
import torch
[docs]
def cpu() -> torch.device:
    """Build and return the ``torch.device`` that refers to the host CPU."""
    device_type = "cpu"
    return torch.device(device_type)
[docs]
def gpu(i: int = 0) -> torch.device:
    """Build a ``torch.device`` for CUDA device number ``i``.

    No check is made that the device actually exists; see ``try_gpu`` for that.

    Args:
        i: CUDA device index, formatted into the ``"cuda:<i>"`` spec.
    """
    spec = f"cuda:{i}"
    return torch.device(spec)
[docs]
def num_gpus() -> int:
    """Return the number of CUDA GPUs visible to PyTorch.

    Returns:
        The value of ``torch.cuda.device_count()``, which is 0 when no GPU is
        available.
    """
    return torch.cuda.device_count()
[docs]
def try_gpu(i: int = 0) -> torch.device | None:
    """Return ``gpu(i)`` if GPU index ``i`` exists, otherwise return None.

    Args:
        i: Zero-based CUDA device index to look up.

    Returns:
        ``gpu(i)`` when ``0 <= i < num_gpus()``, else None.
    """
    # Bound the index on both sides. The previous check ``num_gpus() >= i + 1``
    # accepted negative indices (e.g. i=-1 with no GPUs) and built "cuda:-1".
    if 0 <= i < num_gpus():
        return gpu(i)
    return None
[docs]
def try_all_gpus() -> list[torch.device]:
    """Return a device for every visible GPU.

    Returns:
        ``[gpu(0), ..., gpu(num_gpus() - 1)]``. The list is empty when no GPU
        is available; no CPU fallback is added, so callers that need one
        should substitute ``[cpu()]`` themselves.
    """
    return [gpu(i) for i in range(num_gpus())]