train

def train(rank, num_gpus, config_file, checkpoint_file)

Trains a transformer model on the device identified by rank.

Args
  • rank (int): ID of the device (GPU) this call trains on.

  • num_gpus (int): Total number of devices (GPUs) used for training.

  • config_file (str): Path to the config.yaml file that stores all required parameters.

  • checkpoint_file (str, optional): Path to a model checkpoint from which to resume training (e.g. latest_model.pt).

Returns
  • None: Model checkpoints are saved to the folder specified in the config.
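The rank/num_gpus pair suggests that train is meant to run once per device, each call in its own process. The sketch below illustrates that calling convention with the standard-library multiprocessing module. It assumes nothing about the real implementation: train here is a hypothetical stand-in, and the helpers per_rank_args and launch are illustrative names, not part of this API. If the project uses PyTorch, the same launch is commonly written as torch.multiprocessing.spawn(train, args=(num_gpus, config_file, checkpoint_file), nprocs=num_gpus), which likewise passes the process index as the first argument.

```python
import multiprocessing as mp
from typing import List, Optional, Tuple


def train(rank: int, num_gpus: int, config_file: str,
          checkpoint_file: Optional[str] = None) -> None:
    # Hypothetical placeholder for the documented train(); the real function
    # would build the model on device `rank` and save checkpoints to the
    # folder specified in the config.
    print(f"[rank {rank}/{num_gpus}] config={config_file} resume={checkpoint_file}")


def per_rank_args(num_gpus: int, config_file: str,
                  checkpoint_file: Optional[str] = None
                  ) -> List[Tuple[int, int, str, Optional[str]]]:
    # One argument tuple per device; rank comes first, matching train()'s signature.
    return [(rank, num_gpus, config_file, checkpoint_file)
            for rank in range(num_gpus)]


def launch(num_gpus: int, config_file: str,
           checkpoint_file: Optional[str] = None) -> List[Optional[int]]:
    # Start one process per device, then wait for all of them to finish.
    processes = [mp.Process(target=train, args=args)
                 for args in per_rank_args(num_gpus, config_file, checkpoint_file)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
    # Exit code of each process in rank order (0 means it finished cleanly).
    return [p.exitcode for p in processes]


if __name__ == "__main__":
    # Example paths only; substitute your own config and checkpoint.
    print(launch(2, "config.yaml", "latest_model.pt"))
```

Every process receives the same config_file and checkpoint_file; only rank differs. Passing checkpoint_file=None in this sketch starts each process without a checkpoint to resume from.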