
Trainer

class Trainer

Performs model training.

__init__

def __init__(checkpoint_dir, device, rank, use_ddp, loss_type)

Initializes a Trainer object.

Args
  • checkpoint_dir (Path): Directory to store the model checkpoints.

  • device (torch.device): Device used for training.

  • rank (int): Rank of the current process (device) in the training group.

  • use_ddp (bool): Whether DistributedDataParallel (DDP) is used for training.

  • loss_type (str): Type of loss function used during training.
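The way `device`, `rank` and `use_ddp` typically interact can be sketched in plain Python. This is an illustration under stated assumptions, not the Trainer's actual code. `TrainerSetup` and `device_for_rank` are hypothetical names. The device is modelled as a string rather than a `torch.device`. The sketch assumes a single-node setup where a process's rank doubles as its GPU index. The "only rank 0 writes checkpoints" rule is a common DDP convention and is assumed here, not confirmed by this page.

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class TrainerSetup:
    """Hypothetical container mirroring the Trainer constructor arguments."""
    checkpoint_dir: Path
    device: str      # stand-in for torch.device, e.g. "cuda:0" or "cpu"
    rank: int
    use_ddp: bool
    loss_type: str   # loss identifier; accepted values are not listed on this page

    def __post_init__(self) -> None:
        if self.rank < 0:
            raise ValueError(f"rank must be non-negative, got {self.rank}")

    @property
    def is_main_process(self) -> bool:
        # Assumed convention: without DDP there is one process; with DDP only
        # rank 0 writes checkpoints, so processes do not race on the same files.
        return (not self.use_ddp) or self.rank == 0


def device_for_rank(rank: int, use_ddp: bool, cuda_available: bool) -> str:
    """Pick a device string, assuming one GPU per process on a single node."""
    if not cuda_available:
        return "cpu"
    return f"cuda:{rank}" if use_ddp else "cuda:0"
```

With four DDP processes on one machine, process 2 would train on `cuda:2` and skip checkpoint writing, while process 0 trains on `cuda:0` and writes to `checkpoint_dir`.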

train

def train(model, checkpoint, store_phoneme_dict_in_model)

Performs training of a transformer model.

Args
  • model (Model): Model to be trained (can be a fresh model or restored from a checkpoint).

  • checkpoint (Dict[str, Any]): Dictionary with entries 'optimizer'

  • store_phoneme_dict_in_model (bool): Whether to store a dictionary of word-phoneme mappings in the model checkpoint so that it can be automatically loaded by a Phonemizer object.

Returns
  • None: checkpoints are written to the checkpoint_dir passed when the Trainer was instantiated.
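The effect of `store_phoneme_dict_in_model` on what lands in `checkpoint_dir` can be sketched as below. This is a minimal illustration, not the Trainer's real checkpoint format. The key names (`"model"`, `"optimizer"`, `"step"`, `"phoneme_dict"`), the helpers `build_checkpoint` and `save_checkpoint`, and the `.pt` file name are all hypothetical. The phoneme dictionary is assumed to be a flat word-to-phonemes mapping. `pickle` stands in for `torch.save` so the sketch runs without PyTorch.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional


def build_checkpoint(
    model_state: Dict[str, Any],
    optimizer_state: Dict[str, Any],
    step: int,
    phoneme_dict: Optional[Dict[str, str]] = None,
    store_phoneme_dict_in_model: bool = True,
) -> Dict[str, Any]:
    """Assemble a checkpoint dict; all key names here are hypothetical."""
    checkpoint: Dict[str, Any] = {
        "model": model_state,
        "optimizer": optimizer_state,
        "step": step,
    }
    # Embed the word-phoneme mapping only when requested, so a loader
    # (such as a Phonemizer) could use it without a separate file.
    if store_phoneme_dict_in_model and phoneme_dict is not None:
        checkpoint["phoneme_dict"] = phoneme_dict
    return checkpoint


def save_checkpoint(checkpoint: Dict[str, Any], checkpoint_dir: Path, name: str) -> Path:
    """Write the checkpoint into checkpoint_dir (pickle stands in for torch.save)."""
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    path = checkpoint_dir / f"{name}.pt"
    with path.open("wb") as f:
        pickle.dump(checkpoint, f)
    return path
```

Keeping the dictionary inside the checkpoint trades a larger file for a single self-contained artifact. With the flag off, the mapping would have to be shipped and loaded separately.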