Source code for utils.devices

"""Functions for creating and querying torch.device objects"""

import torch


def cpu() -> torch.device:
    """Build and return the ``torch.device`` that designates the CPU."""
    cpu_device = torch.device("cpu")
    return cpu_device
def gpu(i: int = 0) -> torch.device:
    """Build and return the ``torch.device`` for the CUDA GPU with index ``i``.

    Args:
        i: Index placed after ``cuda:`` in the device string (defaults to 0).

    Returns:
        The device constructed from the string ``cuda:<i>``.
    """
    device_name = f"cuda:{i}"
    return torch.device(device_name)
def num_gpus() -> int:
    """Get the number of available GPUs.

    Returns:
        The count reported by ``torch.cuda.device_count()``.
    """
    # The result is an integer count, not a device object, so the return
    # annotation is ``int``.
    return torch.cuda.device_count()
def try_gpu(i: int = 0) -> torch.device | None:
    """Return ``gpu(i)`` if that GPU exists, otherwise return ``None``.

    Args:
        i: Zero-based index of the GPU to look up (defaults to 0).

    Returns:
        The device from ``gpu(i)`` when ``0 <= i < num_gpus()``, else ``None``.
    """
    # The lower bound matters: a negative index can never name an existing
    # GPU, yet a bare upper-bound check (``num_gpus() >= i + 1``) lets it
    # through and forwards it to ``gpu``.
    if 0 <= i < num_gpus():
        return gpu(i)
    return None
def try_all_gpus() -> list[torch.device]:
    """Return a list with one device per available GPU.

    Returns:
        ``[gpu(0), ..., gpu(num_gpus() - 1)]``. The list is empty when
        ``num_gpus()`` is 0; no CPU device is substituted in that case.
    """
    # NOTE(review): an earlier docstring promised a ``[cpu(),]`` fallback that
    # this code never produced. The runtime behaviour is kept as-is; confirm
    # with callers whether a CPU fallback is actually wanted.
    return [gpu(i) for i in range(num_gpus())]