Trainer

class Trainer

__init__

def __init__(max_input_len, max_output_len, batch_size, max_vocab_size_encoder, max_vocab_size_decoder, embedding_path_encoder, embedding_path_decoder, steps_per_epoch, tensorboard_dir, model_save_path, shuffle_buffer_size, use_bucketing, bucketing_buffer_size_batches, bucketing_batches_to_bucket, logging_level, num_print_predictions, steps_to_log, preprocessor)

Initializes the trainer.

Args
  • max_input_len: Maximum length of input sequences; longer sequences are truncated.

  • max_output_len: Maximum length of output sequences; longer sequences are truncated.

  • batch_size: Size of mini-batches for stochastic gradient descent.

  • max_vocab_size_encoder: Maximum number of unique tokens to consider for encoder embeddings.

  • max_vocab_size_decoder: Maximum number of unique tokens to consider for decoder embeddings.

  • embedding_path_encoder: Path to embedding file for the encoder.

  • embedding_path_decoder: Path to embedding file for the decoder.

  • steps_per_epoch: Number of steps to train until callbacks are invoked.

  • tensorboard_dir: Directory for saving tensorboard logs.

  • model_save_path: Directory for saving the best model.

  • shuffle_buffer_size: Size of the buffer for shuffling the files before batching.

  • use_bucketing: Whether to bucket the sequences by length to reduce the amount of padding.

  • bucketing_buffer_size_batches: Number of batches to buffer when bucketing sequences.

  • bucketing_batches_to_bucket: Number of buffered batches from which sequences are collected for bucketing.

  • logging_level: Level of logging to use, e.g. logging.INFO or logging.DEBUG.

  • num_print_predictions: Number of sample predictions to print in each evaluation.

  • steps_to_log: Number of training steps between logging outputs.

  • preprocessor (optional): Custom preprocessor; if None, a standard preprocessor is created.
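To make the effect of max_input_len and use_bucketing more concrete, here is a minimal pure-Python sketch. It is not the trainer's actual implementation: the helpers truncate, make_batches, bucket_batches, and padding_cost are hypothetical names introduced for illustration, and the sketch assumes that bucketing amounts to sorting a buffer of sequences by length before cutting it into batches.

```python
from typing import List, Sequence


def truncate(tokens: Sequence[int], max_len: int) -> List[int]:
    """Keep at most max_len tokens (hypothetical helper)."""
    return list(tokens[:max_len])


def make_batches(seqs: List[List[int]], batch_size: int) -> List[List[List[int]]]:
    """Cut the sequences into consecutive batches of batch_size."""
    return [seqs[i:i + batch_size] for i in range(0, len(seqs), batch_size)]


def bucket_batches(seqs: List[List[int]], batch_size: int) -> List[List[List[int]]]:
    """Assumed bucketing strategy: sort the buffer by length, then batch."""
    return make_batches(sorted(seqs, key=len), batch_size)


def padding_cost(batches: List[List[List[int]]]) -> int:
    """Count pad tokens needed to pad every batch to its longest sequence."""
    return sum(
        max(len(s) for s in batch) * len(batch) - sum(len(s) for s in batch)
        for batch in batches
    )


# Toy sequences; anything longer than 8 tokens is truncated to 8.
lengths = [2, 9, 3, 8, 1, 10, 2, 7]
sequences = [truncate(list(range(n)), max_len=8) for n in lengths]

plain = make_batches(sequences, batch_size=2)
bucketed = bucket_batches(sequences, batch_size=2)

print("padding without bucketing:", padding_cost(plain))
print("padding with bucketing:", padding_cost(bucketed))
```

In this toy setting, grouping sequences of similar length leaves far fewer padded positions per batch, which is the motivation given for use_bucketing.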

from_config

def from_config(cls, file_path, **kwargs)
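The signature suggests a class-method constructor that reads settings from a file and accepts extra keyword arguments. The following sketch shows that general pattern only. ConfigurableTrainer is a hypothetical stand-in rather than the library's Trainer, and the JSON file format, the default values, and the rule that keyword arguments override file values are all assumptions made for illustration.

```python
import json
import os
import tempfile


class ConfigurableTrainer:
    """Hypothetical stand-in used to illustrate the pattern."""

    def __init__(self, batch_size=16, steps_per_epoch=500):
        self.batch_size = batch_size
        self.steps_per_epoch = steps_per_epoch

    @classmethod
    def from_config(cls, file_path, **kwargs):
        # Assumption: the config file is JSON with constructor argument names as keys.
        with open(file_path, encoding="utf-8") as f:
            settings = json.load(f)
        # Assumption: explicit keyword arguments take precedence over the file.
        settings.update(kwargs)
        return cls(**settings)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "trainer_config.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"batch_size": 32, "steps_per_epoch": 1000}, f)

    trainer = ConfigurableTrainer.from_config(path, batch_size=64)

print(trainer.batch_size, trainer.steps_per_epoch)
```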

train

def train(summarizer, train_data, val_data, num_epochs, scorers, callbacks)

Trains a summarizer or resumes training of a previously initialized summarizer.

Args
  • summarizer: Model to train, can be either a freshly created model or a loaded model.

  • train_data: Data to train the model on.

  • val_data (optional): Validation data.

  • num_epochs: Number of epochs to train.

  • scorers (optional): Dictionary mapping score_name to scorer, used to add validation scores to the logs.

  • callbacks (optional): Additional custom callbacks.
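As an illustration of the scorers argument, the sketch below builds a dictionary from score names to scoring callables and averages each score over validation pairs. This is a sketch under stated assumptions, not the library's evaluation code: the call signature scorer(predicted, target), the two toy scorers, and the helper validation_scores are hypothetical, and the library's own scorers may expect different inputs.

```python
from typing import Callable, Dict, List, Tuple

# Assumed scorer signature: (predicted text, target text) -> score.
Scorer = Callable[[str, str], float]


def token_overlap(predicted: str, target: str) -> float:
    """Fraction of distinct target tokens that also occur in the prediction."""
    pred, tgt = set(predicted.split()), set(target.split())
    return len(pred & tgt) / len(tgt) if tgt else 0.0


def length_ratio(predicted: str, target: str) -> float:
    """Predicted length divided by target length, in tokens."""
    n = len(target.split())
    return len(predicted.split()) / n if n else 0.0


# Dictionary of score name to scorer, the shape described for the scorers argument.
scorers: Dict[str, Scorer] = {"overlap": token_overlap, "length_ratio": length_ratio}


def validation_scores(
    pairs: List[Tuple[str, str]], scorers: Dict[str, Scorer]
) -> Dict[str, float]:
    """Average every scorer over (predicted, target) pairs (hypothetical helper)."""
    return {
        name: sum(scorer(p, t) for p, t in pairs) / len(pairs)
        for name, scorer in scorers.items()
    }


pairs = [("the cat sat", "the cat sat down"), ("a dog", "a dog")]
logs = validation_scores(pairs, scorers)
print(logs)
```

Keying the scores by name lets each value appear under its own label in the logs.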