train
def train(rank, num_gpus, config_file, checkpoint_file)
Runs training of a transformer model.
Args
- rank (int): Device id
- num_gpus (int): Number of devices
- config_file (str): Path to the config.yaml that stores all necessary parameters.
- checkpoint_file (str, optional): Path to a model checkpoint to resume training from (e.g. latest_model.pt)
Returns
- None: The model checkpoints are stored in a folder provided by the config.
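
Because `train` takes a `rank` and a `num_gpus` argument, it is presumably meant to be started once per device. The following is a minimal sketch of such a launcher, not the project's actual entry point. It rests on several assumptions that the documentation above does not confirm: ranks run from 0 to `num_gpus - 1`, `checkpoint_file` defaults to `None`, and one process per device is the intended pattern. The `train` body here is a stand-in with the documented signature, and `worker_args` and `launch` are hypothetical helper names.

```python
import multiprocessing as mp


def train(rank, num_gpus, config_file, checkpoint_file=None):
    """Stand-in with the documented signature (assumed default: None).

    The real function runs training; this one only reports its arguments.
    """
    print(f"rank {rank} of {num_gpus}: config={config_file}, resume={checkpoint_file}")


def worker_args(num_gpus, config_file, checkpoint_file=None):
    """Build one argument tuple per device (assumes ranks 0 .. num_gpus - 1)."""
    return [
        (rank, num_gpus, config_file, checkpoint_file)
        for rank in range(num_gpus)
    ]


def launch(num_gpus, config_file, checkpoint_file=None):
    """Start one process per device and wait for all of them to finish."""
    ctx = mp.get_context("spawn")
    procs = [
        ctx.Process(target=train, args=args)
        for args in worker_args(num_gpus, config_file, checkpoint_file)
    ]
    for proc in procs:
        proc.start()
    for proc in procs:
        proc.join()
```

Calling `launch(2, "config.yaml", "latest_model.pt")` from inside an `if __name__ == "__main__":` guard would start two workers that resume from the given checkpoint; omitting the last argument would start training without one. If the project is built on PyTorch (the `.pt` extension suggests this, but it is not stated here), `torch.multiprocessing.spawn` provides the same one-process-per-rank pattern and passes the process index as the first argument.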