Source code for engine.data

"""
Interface for the data module
"""

from torch.utils.data import IterableDataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch
from typing import List, Optional


class DataModule:
    """Base class of interfaces for data modules.

    .. warning::
        This class can only be used as a base class for inheritance.
        The read_data and get_labels methods must be overridden in the child class.
    """

    def __init__(
        self,
        root: str = "./data",
        num_workers: int = 4,
        batch_size: int = 1,
    ):
        """
        :param root: The directory where datasets are stored. Defaults to "./data".
        :type root: str, optional
        :param num_workers: A positive integer turns on multi-process data loading
            with the specified number of loader worker processes. Defaults to 4.
        :type num_workers: int, optional
        :param batch_size: Number of elements in a batch. Defaults to 1.
        :type batch_size: int, optional
        """
        self._root = root
        self._num_workers = num_workers
        self._train_dataset: Optional[IterableDataset] = None
        self._test_dataset: Optional[IterableDataset] = None
        self._val_dataset: Optional[IterableDataset] = None
        self.batch_size = batch_size

    def _get_dataloader(self, batch_size: int, split: str = "train") -> DataLoader:
        # Make sure the requested split has been read before building its dataloader.
        self._update_dataset(split)
        return DataLoader(
            self._get_dataset(split),
            batch_size=batch_size,
            num_workers=self._num_workers,
            collate_fn=_stack_data,
            persistent_workers=True,
        )

    def _get_dataset(self, split: str) -> Optional[IterableDataset]:
        # Map the split name to the corresponding dataset attribute.
        match split:
            case "train":
                return self._train_dataset
            case "test":
                return self._test_dataset
            case "val":
                return self._val_dataset
            case _:
                raise ValueError(f'The split parameter cannot be "{split}"!')

    def train_dataloader(self) -> DataLoader:
        """Returns the training dataloader."""
        return self._get_dataloader(self.batch_size, split="train")

    def test_dataloader(self) -> DataLoader:
        """Returns the test dataloader."""
        return self._get_dataloader(self.batch_size, split="test")

    def val_dataloader(self) -> DataLoader:
        """Returns the validation dataloader."""
        return self._get_dataloader(self.batch_size, split="val")

    def _update_dataset(self, split: str) -> None:
        # Read the data for a split only if it has not been loaded yet.
        if self._get_dataset(split) is None:
            self.read_data(split)

    def read_data(self, split: str) -> None:
        """Read the dataset images and labels.

        :param split: "train", "test" or "val"
        :type split: str
        """
        raise NotImplementedError

    def get_labels(self) -> List[str]:
        """Returns a list of class names."""
        return []
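

# --- Illustrative sketch (not part of engine.data) ---------------------------
# DataModule is meant to be subclassed, so the minimal child class below shows
# the two required overrides, read_data and get_labels. _ToyDataset and
# _ToyDataModule are hypothetical names used only for this example.
class _ToyDataset(IterableDataset):
    """Yields (features, targets) pairs of random tensors for demonstration."""

    def __iter__(self):
        for _ in range(8):
            features = torch.randn(16, 10)       # (time_steps, feature_dim)
            targets = torch.randint(0, 3, (4,))  # label sequence of length 4
            yield features, targets


class _ToyDataModule(DataModule):
    def read_data(self, split: str) -> None:
        dataset = _ToyDataset()
        if split == "train":
            self._train_dataset = dataset
        elif split == "test":
            self._test_dataset = dataset
        elif split == "val":
            self._val_dataset = dataset

    def get_labels(self) -> List[str]:
        return ["class_0", "class_1", "class_2"]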


def _stack_data(batch):
    """Combines samples into a batch, taking the time dimension into account."""
    # Features are stacked along dim=1, giving a (time, batch, ...) layout;
    # variable-length targets are padded with -1 to the longest sequence in the batch.
    features = torch.stack([sample[0] for sample in batch], dim=1)
    targets = pad_sequence(
        [sample[1] for sample in batch],
        batch_first=True,
        padding_value=-1,
    )
    return features, targets
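

# --- Usage sketch (illustrative, uses the hypothetical _ToyDataModule above) --
# _stack_data places the batch dimension between time and features, and pads
# targets with -1; the shapes below assume the toy dataset's fixed sample sizes.
if __name__ == "__main__":
    dm = _ToyDataModule(batch_size=2, num_workers=1)
    features, targets = next(iter(dm.train_dataloader()))
    print(features.shape)  # torch.Size([16, 2, 10]) -> (time_steps, batch, feature_dim)
    print(targets.shape)   # torch.Size([2, 4])      -> (batch, max_target_len)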