Source code for engine.trainer

"""
Implements tools for training neural networks
"""

import torch
from engine.data import DataModule
from engine.model import Model
from utils import devices
from utils.progress_board import ProgressBoard
from tqdm import tqdm


[docs] class Trainer: """Class for training a model on a selected dataset""" _train_batch_idx: int = 0 _test_batch_idx: int = 0 _val_batch_idx: int = 0 _stop_flag: bool = False def __init__( self, board: ProgressBoard, gpu_index: int = 0, epoch_size: int = 60 ) -> None: """ :param board: The board that plots data points in animation. :type board: ProgressBoard :param gpu_index: CUDA index for the GPU selected for training. See :external:ref:`CUDA semantics <cuda-semantics>`. Defaults to 0. :type gpu_index: int, optional :param epoch_size: Size of one epoch, defaults to 60 :type epoch_size: int, optional """ self.board, self.epoch_size = board, epoch_size self.gpu = devices.try_gpu(gpu_index)
[docs] def prepare(self, model: Model, data: DataModule) -> None: """Prepares the model and data module Must be called before training begins. :param model: Model for training :type model: Model :param data: Data used for training :type data: DataModule """ self._prepare_data(data) self._prepare_model(model) self.optim = model.configure_optimizers()
def _prepare_data(self, data: DataModule) -> None: self.train_dataloader = data.train_dataloader() self.test_dataloader = data.test_dataloader() self.val_dataloader = data.val_dataloader() self.train_dataloader_iter = iter(self.train_dataloader) self.test_dataloader_iter = iter(self.test_dataloader) self.val_dataloader_iter = iter(self.val_dataloader) def _prepare_model(self, model: Model) -> None: if self.gpu: model.to(self.gpu) self.model = model def _prepare_batch( self, batch: tuple[torch.Tensor, torch.Tensor] ) -> tuple[torch.Tensor, torch.Tensor]: if not self.gpu: return batch return batch[0].to(self.gpu), batch[1].to(self.gpu) def _plot(self, loss: torch.Tensor, split: str) -> None: match split: case "train": x = self._train_batch_idx case "test": x = ( self._train_batch_idx - self.epoch_size + self._test_batch_idx % self.epoch_size ) case "val": x = ( self._train_batch_idx - self.epoch_size + self._val_batch_idx % self.epoch_size ) case _: raise ValueError(f'The split parameter cannot be "{split}"!') self.board.draw( x, loss.to(devices.cpu()).item(), split + " loss", )
[docs] def stop(self) -> None: """Interrupts training The state is saved and training can be continued. """ self._stop_flag = True
[docs] def fit(self, num_epochs: int = 1) -> None: """Begins training the model :param num_epochs: Number of training epochs, defaults to 1. :type num_epochs: int, optional """ self._stop_flag = False for self.epoch in tqdm(range(num_epochs), leave=False, desc="[Epoch]"): if self._stop_flag: return self.fit_epoch() self.test()
[docs] def fit_epoch(self) -> None: """Starts one epoch of model training Error values are saved in :class:`utils.progress_board.ProgressBoard`, progress is displayed in console. """ self.model.train() for _ in (pbar := tqdm(range(self.epoch_size), leave=False, desc="[Train]")): if self._stop_flag: return batch = next(self.train_dataloader_iter) train_loss = self.model.training_step(self._prepare_batch(batch)) self.optim.zero_grad() with torch.no_grad(): train_loss.backward() self.optim.step() self._plot(train_loss, split="train") pbar.set_description("[Train] Loss %.4f / Progress " % train_loss) self._train_batch_idx += 1
[docs] def test(self): """Starts one epoch of model testing Error values are saved in :class:`utils.progress_board.ProgressBoard`, progress is displayed in console. """ self.model.eval() for _ in tqdm(range(self.epoch_size), leave=False, desc="[Test]"): if self._stop_flag: return batch = next(self.test_dataloader_iter) with torch.no_grad(): test_loss = self.model.test_step(self._prepare_batch(batch)) self._plot(test_loss, split="test") self._test_batch_idx += 1
[docs] def validation(self): """Starts one epoch of model evaluation Error values are saved in :class:`utils.progress_board.ProgressBoard`, progress is displayed in console.""" self.model.eval() for _ in tqdm(range(self.epoch_size), leave=False, desc="[Val] "): if self._stop_flag: return batch = next(self.val_dataloader_iter) with torch.no_grad(): val_loss = self.model.validation_step(self._prepare_batch(batch)) self._plot(val_loss, split="val") self._val_batch_idx += 1
[docs] def predict(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Returns the network's prediction for a random sample :return: Three tensors: data, predictions and targets. :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor] """ with torch.no_grad(): tensors, targets = next(self.test_dataloader_iter) if self.gpu: tensors = tensors.to(self.gpu) predictions = self.model.predict(tensors).to(devices.cpu()) if self.gpu: tensors = tensors.to(devices.cpu()) return tensors, predictions, targets