# Source code for run.train

"""
Script for background network training

Does not create windows and saves training progress to the ``log`` folder.
"""

import time
import traceback

import engine


def train_spin(
    model: "engine.Model",
    trainer: "engine.Trainer",
    params_file: str,
    load_parameters=True,
    num_train_rounds=-1,
    num_round_epochs=60,
):
    """Script for background network training

    Does not create windows. Runs ``trainer.fit`` in rounds. After every
    round, including a round cut short by a user interruption or an error,
    the model parameters are saved with ``model.save_params`` under the name
    ``params_file + "_" + <YYYYmmdd-HHMMSS timestamp>``, so progress is never
    lost. Once training stops, the trainer's board plot is saved with
    ``trainer.board.save_plot()``.

    :param model: Network model.
    :type model: engine.Model
    :param trainer: Training tool. It is used through ``fit(num_epochs)``
        and ``board.save_plot()``.
    :type trainer: engine.Trainer
    :param params_file: Parameters file name. See
        :meth:`engine.model.Model.load_params`.
    :type params_file: str
    :param load_parameters: If True, loads parameters from ``params_file``
        before training. Otherwise no parameters are loaded and the model is
        trained from whatever state it is currently in. Defaults to True.
    :type load_parameters: bool, optional
    :param num_train_rounds: Number of training rounds. If -1, training
        continues until the user stops the process (KeyboardInterrupt) or an
        error occurs. Defaults to -1.
    :type num_train_rounds: int, optional
    :param num_round_epochs: Number of epochs in one round. Defaults to 60.
    :type num_round_epochs: int, optional
    """
    if load_parameters:
        model.load_params(params_file)

    round_idx = 1
    while num_train_rounds == -1 or round_idx <= num_train_rounds:
        print(f"[INFO]: Starting round {round_idx} ({num_round_epochs} epochs)")
        stopped = False
        try:
            trainer.fit(num_round_epochs)
        except KeyboardInterrupt:
            print("[INFO]: Training was stopped!")
            stopped = True
        except RuntimeError as exc:
            print("Error description: ", exc)
            # The message alone rarely locates the fault; keep the stack trace.
            traceback.print_exc()
            print("[ERROR]: Training stopped due to error!")
            stopped = True
        except Exception as exc:  # best-effort background job: report, save, stop
            print("Error description: ", exc)
            traceback.print_exc()
            print("[ERROR]: Training stopped due to unexpected error!")
            stopped = True

        # Save even when the round was cut short, so partial progress survives.
        timestr = time.strftime("%Y%m%d-%H%M%S")
        model.save_params(params_file + "_" + timestr)
        if stopped:
            # Do not claim the round finished when it was interrupted.
            print(
                f"[INFO]: Round {round_idx} interrupted, "
                f"progress saved at {timestr}"
            )
            break
        print(f"[INFO]: Round {round_idx} finished at {timestr}")
        round_idx += 1

    trainer.board.save_plot()
    print("[INFO]: Training complete")