Trainer
class Trainer
Performs model training.
__init__
def __init__(checkpoint_dir, device, rank, use_ddp, loss_type)
Initializes a Trainer object.
Args
- checkpoint_dir (Path): Directory to store the model checkpoints.
- device (torch.device): Device used for training.
- rank (int): Rank of the current device.
- use_ddp (bool): Whether DistributedDataParallel (DDP) is used for training.
- loss_type (str): Type of loss.
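The constructor arguments can be pictured with the minimal sketch below. It is not the library's implementation: TrainerSettings, is_main_process and ensure_checkpoint_dir are hypothetical names, device is a plain string so the sketch runs without PyTorch, and the rule that only rank 0 writes checkpoints under DDP is an assumed (though common) convention.

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass
class TrainerSettings:
    """Hypothetical container mirroring the Trainer constructor arguments.

    device is a plain string here (e.g. "cpu", "cuda:0") so the sketch runs
    without PyTorch; the documented Trainer expects a torch.device.
    """
    checkpoint_dir: Path
    device: str
    rank: int
    use_ddp: bool
    loss_type: str

    def is_main_process(self) -> bool:
        # Assumed convention: without DDP there is a single process; with DDP,
        # only rank 0 writes checkpoints so processes do not overwrite each other.
        return (not self.use_ddp) or self.rank == 0

    def ensure_checkpoint_dir(self) -> Path:
        # Create the checkpoint directory (and parents) only on the main process.
        if self.is_main_process():
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        return self.checkpoint_dir
```

Keeping the rank check in one place means every write path (checkpoints, logs) can share the same guard.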
train
def train(model, checkpoint, store_phoneme_dict_in_model)
Performs training of a transformer model.
Args
- model (Model): Model to be trained (can be a fresh model or restored from a checkpoint).
- checkpoint (Dict[str, Any]): Dictionary with entries 'optimizer'
- store_phoneme_dict_in_model (bool): Whether to store a dictionary of word-phoneme mappings in the model checkpoint so that it can be automatically loaded by a Phonemizer object.
Returns
- None: checkpoints are written to checkpoint_dir, the directory passed when instantiating the Trainer.
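How a checkpoint might optionally carry the word-phoneme dictionary is sketched below under stated assumptions. prepare_checkpoint, save_checkpoint, the 'phoneme_dict' key and the file naming scheme are all hypothetical and not taken from the library; pickle is used only so the sketch runs without PyTorch (torch.save, which builds on pickle, would be the usual choice).

```python
import pickle
from pathlib import Path
from typing import Any, Dict, Optional


def prepare_checkpoint(checkpoint: Dict[str, Any],
                       phoneme_dict: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
    # Shallow copy so the caller's checkpoint dictionary is not mutated.
    prepared = dict(checkpoint)
    if phoneme_dict is not None:
        # Hypothetical key: embeds word -> phoneme mappings next to the model state.
        prepared['phoneme_dict'] = dict(phoneme_dict)
    return prepared


def save_checkpoint(checkpoint_dir: Path,
                    checkpoint: Dict[str, Any],
                    step: int,
                    phoneme_dict: Optional[Dict[str, str]] = None) -> Path:
    # Hypothetical naming scheme: one file per training step.
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    path = checkpoint_dir / f'model_step_{step}.pkl'
    with path.open('wb') as f:
        pickle.dump(prepare_checkpoint(checkpoint, phoneme_dict), f)
    return path
```

A loader can then check for the 'phoneme_dict' key after unpickling and, if present, use the stored mappings without a separate dictionary file.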