class Trainer
def __init__(max_input_len, max_output_len, batch_size, max_vocab_size_encoder, max_vocab_size_decoder, embedding_path_encoder, embedding_path_decoder, steps_per_epoch, tensorboard_dir, model_save_path, shuffle_buffer_size, use_bucketing, bucketing_buffer_size_batches, bucketing_batches_to_bucket, logging_level, num_print_predictions, steps_to_log, preprocessor)
Initializes the trainer.
max_input_len (optional): Maximum length of input sequences; longer sequences will be truncated.
max_output_len (optional): Maximum length of output sequences; longer sequences will be truncated.
batch_size: Size of mini-batches for stochastic gradient descent.
max_vocab_size_encoder: Maximum number of unique tokens to consider for encoder embeddings.
max_vocab_size_decoder: Maximum number of unique tokens to consider for decoder embeddings.
embedding_path_encoder: Path to embedding file for the encoder.
embedding_path_decoder: Path to embedding file for the decoder.
steps_per_epoch: Number of training steps after which the callbacks are invoked.
tensorboard_dir: Directory for saving TensorBoard logs.
model_save_path: Directory for saving the best model.
shuffle_buffer_size: Size of the buffer for shuffling the files before batching.
use_bucketing: Whether to bucket the sequences by length to reduce the amount of padding.
bucketing_buffer_size_batches: Number of batches to buffer when bucketing sequences.
bucketing_batches_to_bucket: Number of buffered batches from which sequences are collected for bucketing.
logging_level: Level of logging to use, e.g. logging.INFO or logging.DEBUG.
num_print_predictions: Number of sample predictions to print in each evaluation.
steps_to_log: Number of steps between logging outputs.
preprocessor (optional): Custom preprocessor; if None, a standard preprocessor will be created.
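The truncation and bucketing parameters above can be illustrated with a small, framework-free sketch. It assumes sequences are lists of token ids padded with a hypothetical pad id of 0, and models only `batch_size`, a `max_len` standing in for `max_input_len`/`max_output_len`, and `batches_to_bucket` standing in for `bucketing_batches_to_bucket`. `make_batches` and `count_padding` are hypothetical helpers, not the Trainer's actual implementation.

```python
from typing import List, Optional

Batch = List[List[int]]


def make_batches(sequences: List[List[int]], batch_size: int, max_len: int,
                 batches_to_bucket: Optional[int] = None,
                 pad_id: int = 0) -> List[Batch]:
    """Hypothetical sketch: truncate, optionally bucket by length, then pad."""
    # Truncate every sequence to max_len (cf. max_input_len / max_output_len).
    seqs = [seq[:max_len] for seq in sequences]
    # Without bucketing each window is a single batch; with bucketing a window
    # spans `batches_to_bucket` batches whose sequences are sorted by length.
    window = batch_size * (batches_to_bucket or 1)
    batches: List[Batch] = []
    for start in range(0, len(seqs), window):
        chunk = seqs[start:start + window]
        if batches_to_bucket:
            chunk = sorted(chunk, key=len)
        for i in range(0, len(chunk), batch_size):
            batch = chunk[i:i + batch_size]
            width = max(len(s) for s in batch)
            # Pad only up to the longest sequence in this batch.
            batches.append([s + [pad_id] * (width - len(s)) for s in batch])
    return batches


def count_padding(batches: List[Batch], pad_id: int = 0) -> int:
    """Count pad tokens across all batches (real tokens here are never 0)."""
    return sum(tok == pad_id for batch in batches for seq in batch for tok in seq)


lengths = [1, 9, 2, 8, 3, 7, 4, 6]
data = [[1] * n for n in lengths]
plain = make_batches(data, batch_size=2, max_len=8)
bucketed = make_batches(data, batch_size=2, max_len=8, batches_to_bucket=4)
print(count_padding(plain), count_padding(bucketed))  # 19 3
```

Sorting within a window of several batches is what cuts the padding from 19 to 3 tokens in this toy example. A real input pipeline would likely also shuffle the resulting batches so that training does not see them in length order.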
def from_config(cls, file_path, **kwargs)
def train(summarizer, train_data, val_data, num_epochs, scorers, callbacks)
Trains a summarizer or resumes training of a previously initialized summarizer.
summarizer: Model to train; can be either a freshly created model or a loaded model.
train_data: Data to train the model on.
val_data (optional): Validation data.
num_epochs: Number of epochs to train.
scorers (optional): Dictionary mapping score names to scorers, {score_name: scorer}, used to add validation scores to the logs.
callbacks (optional): Additional custom callbacks.
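The way `steps_per_epoch`, `scorers`, and `callbacks` fit together during training can be sketched as a minimal, framework-free loop. Everything here is an assumption for illustration: the scorer signature `(predictions, targets) -> float`, the `exact_match` scorer, and the `train_loop`, `train_step`, and `predict` names are hypothetical and not part of the Trainer API.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Assumed (hypothetical) scorer signature: (predictions, targets) -> float.
Scorer = Callable[[Sequence[str], Sequence[str]], float]
Callback = Callable[[int, Dict[str, float]], None]


def exact_match(predictions: Sequence[str], targets: Sequence[str]) -> float:
    """Hypothetical scorer: fraction of predictions identical to their target."""
    hits = sum(p == t for p, t in zip(predictions, targets))
    return hits / len(targets) if targets else 0.0


def train_loop(num_epochs: int, steps_per_epoch: int,
               train_step: Callable[[], None],
               predict: Callable[[str], str],
               val_data: List[Tuple[str, str]],
               scorers: Dict[str, Scorer],
               callbacks: List[Callback]) -> None:
    for epoch in range(1, num_epochs + 1):
        # One "epoch" here is simply steps_per_epoch optimization steps.
        for _ in range(steps_per_epoch):
            train_step()
        # After each epoch, score validation predictions and log them by name.
        predictions = [predict(source) for source, _ in val_data]
        targets = [target for _, target in val_data]
        logs = {name: scorer(predictions, targets) for name, scorer in scorers.items()}
        for callback in callbacks:
            callback(epoch, logs)


steps_taken: List[int] = []
history: List[Tuple[int, Dict[str, float]]] = []
train_loop(num_epochs=2, steps_per_epoch=3,
           train_step=lambda: steps_taken.append(1),
           predict=str.upper,
           val_data=[("a", "A"), ("b", "x")],
           scorers={"exact_match": exact_match},
           callbacks=[lambda epoch, logs: history.append((epoch, logs))])
print(history)  # [(1, {'exact_match': 0.5}), (2, {'exact_match': 0.5})]
```

Keying the scorers by name lets each score appear in the logs under a stable label, which is what callbacks such as a TensorBoard writer or a best-model checkpoint would typically read.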