engine.trainer.Trainer

class Trainer(board: ProgressBoard, gpu_index: int = 0, epoch_size: int = 60)

Bases: object

Class for training a model on a selected dataset

Parameters:
  • board (ProgressBoard) – The board that plots data points as an animation.

  • gpu_index (int, optional) – CUDA index of the GPU selected for training. See the PyTorch "CUDA semantics" notes. Defaults to 0.

  • epoch_size (int, optional) – Size of one epoch. Defaults to 60.
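The gpu_index parameter follows PyTorch's device naming, in which the n-th CUDA device is addressed as `cuda:n`. The sketch below shows one plausible way a trainer could turn the index into a device string. The helper name `select_device` and the fallback to the CPU when no GPU is present are assumptions for illustration, not documented behavior of Trainer.

```python
from typing import Optional


def select_device(gpu_index: int = 0, cuda_available: Optional[bool] = None) -> str:
    """Hypothetical helper: map a CUDA index to a PyTorch-style device string."""
    if gpu_index < 0:
        raise ValueError("gpu_index must be non-negative")
    if cuda_available is None:
        try:
            import torch  # optional dependency, used only to probe for a GPU
            cuda_available = torch.cuda.is_available()
        except ImportError:
            cuda_available = False
    # Assumption: fall back to the CPU when no CUDA device is available.
    return f"cuda:{gpu_index}" if cuda_available else "cpu"


print(select_device(1, cuda_available=True))   # cuda:1
print(select_device(0, cuda_available=False))  # cpu
```

The returned string is accepted by `torch.device(...)` and by `Module.to(...)`, so it could be used directly to place the model.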

Methods

  • fit – Begins training the model

  • fit_epoch – Starts one epoch of model training

  • predict – Returns the network's prediction for a random sample

  • prepare – Prepares the model and data module

  • stop – Interrupts training

  • test – Starts one epoch of model testing

  • validation – Starts one epoch of model evaluation

fit(num_epochs: int = 1) → None

Begins training the model

Parameters:

num_epochs (int, optional) – Number of training epochs, defaults to 1.

fit_epoch() → None

Starts one epoch of model training

Error values are saved to the utils.progress_board.ProgressBoard, and progress is displayed in the console.
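The sketch below illustrates what one training epoch typically involves, using a one-parameter toy model (y_hat = w * x) in plain Python so that it runs without PyTorch. `ListBoard` and its `draw(x, y, label)` method are hypothetical stand-ins for ProgressBoard, whose actual interface is not documented on this page; recording the mean batch loss once per epoch under the label `train_loss` is likewise an assumption.

```python
class ListBoard:
    """Hypothetical stand-in for ProgressBoard that just stores points per label."""

    def __init__(self):
        self.points = {}

    def draw(self, x, y, label):
        self.points.setdefault(label, []).append((x, y))


def fit_epoch(w, batches, board, epoch, lr=0.1):
    """One SGD epoch for the toy model y_hat = w * x with mean squared error.

    Returns the updated weight; the epoch's mean batch loss is sent to the board.
    """
    total = 0.0
    for i, batch in enumerate(batches, start=1):
        residuals = [(w * x - y, x) for x, y in batch]
        loss = sum(r * r for r, _ in residuals) / len(batch)
        grad = sum(2 * r * x for r, x in residuals) / len(batch)  # d(loss)/dw
        w -= lr * grad                                             # parameter update
        total += loss
        # Console progress, overwritten in place on each batch.
        print(f"\repoch {epoch}: batch {i}/{len(batches)}, loss {loss:.4f}", end="")
    print()
    board.draw(epoch, total / len(batches), label="train_loss")
    return w


train_batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]  # samples of y = 2x
board, w = ListBoard(), 0.0
for epoch in range(1, 11):
    w = fit_epoch(w, train_batches, board, epoch)
```

After ten epochs the weight is close to 2 and the recorded training loss has dropped sharply.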

predict() → tuple[Tensor, Tensor, Tensor]

Returns the network’s prediction for a random sample

Returns:

Three tensors: data, predictions and targets.

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]
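A minimal sketch of the predict contract, assuming it draws one random (data, target) pair and runs the network on it. Plain lists and the built-in `sum` stand in for tensors and the network, and the function and parameter names are hypothetical. A PyTorch implementation would typically also call `model.eval()` and wrap the forward pass in `torch.no_grad()`.

```python
import random


def predict(model, samples, rng=None):
    """Hypothetical sketch: run the model on one randomly chosen (data, target) pair."""
    rng = rng if rng is not None else random.Random()
    data, target = rng.choice(samples)
    prediction = model(data)
    # Same order as Trainer.predict: data, predictions, targets.
    return data, prediction, target


samples = [([1.0, 2.0], 3.0), ([2.0, 2.0], 4.0), ([0.5, 1.5], 2.0)]
data, prediction, target = predict(sum, samples, rng=random.Random(0))
print(data, prediction, target)
```

Passing a seeded `random.Random` makes the choice of sample reproducible, which is convenient when comparing predictions across training runs.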

prepare(model: Model, data: DataModule) → None

Prepares the model and data module

Must be called before training begins.

Parameters:
  • model (Model) – Model for training

  • data (DataModule) – Data used for training
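Because prepare must be called before training begins, one simple way to enforce the ordering is a guard in fit. The sketch below assumes such a guard; `SketchTrainer` is a hypothetical stand-in, and the model is any callable applied to each batch.

```python
class SketchTrainer:
    """Hypothetical stand-in showing why prepare() has to run before fit()."""

    def __init__(self):
        self.model = None
        self.data = None

    def prepare(self, model, data):
        # A real trainer might also move the model to the selected GPU here (assumption).
        self.model = model
        self.data = data

    def fit(self, num_epochs=1):
        if self.model is None or self.data is None:
            raise RuntimeError("call prepare(model, data) before fit()")
        for _ in range(num_epochs):
            for batch in self.data:
                self.model(batch)  # stands in for the forward/backward/update step


trainer = SketchTrainer()
try:
    trainer.fit()
except RuntimeError as err:
    print(err)  # call prepare(model, data) before fit()

seen = []
trainer.prepare(model=seen.append, data=[[1, 2], [3]])
trainer.fit(num_epochs=2)
print(seen)  # [[1, 2], [3], [1, 2], [3]]
```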

stop() → None

Interrupts training

The training state is saved, so training can be resumed later.
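One way to get the "interrupt, then continue" behavior is to have stop only set a flag that the training loop checks, while the epoch counter (the saved state) persists across calls to fit. The sketch below assumes that the flag is checked at epoch boundaries, that a later fit(num_epochs) runs num_epochs additional epochs, and that stop is triggered from a per-epoch hook; `ResumableTrainer`, `on_epoch_end`, and `stop_at_epoch_two` are hypothetical names.

```python
class ResumableTrainer:
    """Hypothetical sketch of stop()/resume semantics; not the documented implementation."""

    def __init__(self, on_epoch_end=None):
        self.epoch = 0                    # training state that survives stop()
        self.on_epoch_end = on_epoch_end  # optional hook, e.g. one that calls stop()
        self._stop_requested = False

    def stop(self):
        # Only raises a flag; the loop exits at the next epoch boundary.
        self._stop_requested = True

    def fit_epoch(self):
        self.epoch += 1  # a real trainer would run one pass over the training data here

    def fit(self, num_epochs=1):
        self._stop_requested = False
        for _ in range(num_epochs):
            if self._stop_requested:
                break
            self.fit_epoch()
            if self.on_epoch_end is not None:
                self.on_epoch_end(self)


def stop_at_epoch_two(trainer):
    if trainer.epoch == 2:
        trainer.stop()


trainer = ResumableTrainer(on_epoch_end=stop_at_epoch_two)
trainer.fit(num_epochs=5)
print(trainer.epoch)  # 2
trainer.on_epoch_end = None
trainer.fit(num_epochs=3)  # resumes from the saved epoch counter
print(trainer.epoch)  # 5
```

If stop were called from another thread (for example, a UI button while training runs in the background), a `threading.Event` would be a safer choice for the flag than a plain attribute.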

test()

Starts one epoch of model testing

Error values are saved to the utils.progress_board.ProgressBoard, and progress is displayed in the console.

validation()

Starts one epoch of model evaluation

Error values are saved to the utils.progress_board.ProgressBoard, and progress is displayed in the console.
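test() and validation() differ from a training epoch mainly in that they only measure the error and never update the parameters. The plain-Python sketch below, which applies to both methods, evaluates the toy model y_hat = w * x on a list of batches. A plain dict stands in for ProgressBoard, and the function name and the per-split labels (`val_loss`, `test_loss`) are assumptions. In a PyTorch implementation one would typically also call `model.eval()` and disable gradient tracking with `torch.no_grad()`.

```python
def evaluation_epoch(w, batches, epoch, split, board):
    """Hypothetical sketch of one test/validation epoch for the toy model y_hat = w * x.

    Unlike a training epoch, the parameter w is only read, never updated.
    """
    # Mean squared error of each batch, then the mean over batches.
    losses = [sum((w * x - y) ** 2 for x, y in batch) / len(batch) for batch in batches]
    mean_loss = sum(losses) / len(losses)
    board.setdefault(f"{split}_loss", []).append((epoch, mean_loss))  # one point per epoch
    print(f"{split} epoch {epoch}: loss {mean_loss:.4f}")
    return mean_loss


board = {}
val_batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]  # samples of y = 2x
val_loss = evaluation_epoch(1.5, val_batches, epoch=1, split="val", board=board)
# prints: val epoch 1: loss 1.4375
```

Averaging per-batch means weights each batch equally; if batch sizes differ and a per-sample average is wanted, the losses should be summed over samples and divided by the total sample count instead.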