Trainer
class Trainer
__init__
def __init__(max_input_len, max_output_len, batch_size, max_vocab_size_encoder, max_vocab_size_decoder, embedding_path_encoder, embedding_path_decoder, steps_per_epoch, tensorboard_dir, model_save_path, shuffle_buffer_size, use_bucketing, bucketing_buffer_size_batches, bucketing_batches_to_bucket, logging_level, num_print_predictions, steps_to_log, preprocessor)
Initializes the trainer.
Args
- max_input_len: Maximum length of input sequences; longer sequences are truncated.
- max_output_len: Maximum length of output sequences; longer sequences are truncated.
- batch_size: Size of mini-batches for stochastic gradient descent.
- max_vocab_size_encoder: Maximum number of unique tokens to consider for encoder embeddings.
- max_vocab_size_decoder: Maximum number of unique tokens to consider for decoder embeddings.
- embedding_path_encoder: Path to the embedding file for the encoder.
- embedding_path_decoder: Path to the embedding file for the decoder.
- steps_per_epoch: Number of training steps after which callbacks are invoked.
- tensorboard_dir: Directory for saving TensorBoard logs.
- model_save_path: Directory for saving the best model.
- shuffle_buffer_size: Size of the buffer used for shuffling the files before batching.
- use_bucketing: Whether to bucket sequences by length to reduce the amount of padding.
- bucketing_buffer_size_batches: Number of batches to buffer when bucketing sequences.
- bucketing_batches_to_bucket: Number of buffered batches from which sequences are collected for bucketing.
- logging_level: Logging level to use, e.g. logging.INFO or logging.DEBUG.
- num_print_predictions: Number of sample predictions to print in each evaluation.
- steps_to_log: Number of training steps between successive log outputs.
- preprocessor (optional): Custom preprocessor; if None, a standard preprocessor is created.
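To illustrate why length bucketing reduces padding, here is a minimal, framework-free sketch. It is not the library's implementation: `pad_batch`, `bucket_batches`, and `padding_count` are hypothetical helpers, token ids are plain integers, id 0 is assumed to be reserved for padding, and only the role of `bucketing_batches_to_bucket` is modeled (the prefetch buffer controlled by `bucketing_buffer_size_batches` is left out).

```python
from typing import Iterator, List, Sequence

PAD_ID = 0  # assumption: token id 0 is reserved for padding


def pad_batch(batch: Sequence[List[int]], pad_id: int = PAD_ID) -> List[List[int]]:
    """Right-pad every sequence in the batch to the length of its longest sequence."""
    width = max(len(seq) for seq in batch)
    return [list(seq) + [pad_id] * (width - len(seq)) for seq in batch]


def bucket_batches(
    sequences: Sequence[List[int]],
    batch_size: int,
    batches_to_bucket: int,
    max_len: int,
) -> Iterator[List[List[int]]]:
    """Hypothetical length bucketing.

    Collects batches_to_bucket * batch_size sequences at a time, truncates them
    to max_len, sorts them by length so that similar lengths share a batch,
    and yields padded batches.
    """
    window = batches_to_bucket * batch_size
    for start in range(0, len(sequences), window):
        # Truncate first, as max_input_len / max_output_len do.
        chunk = [list(seq[:max_len]) for seq in sequences[start:start + window]]
        chunk.sort(key=len)
        for b in range(0, len(chunk), batch_size):
            yield pad_batch(chunk[b:b + batch_size])


def padding_count(batches: List[List[List[int]]]) -> int:
    """Count padding positions across all batches (relies on PAD_ID being reserved)."""
    return sum(row.count(PAD_ID) for batch in batches for row in batch)
```

For example, with sequence lengths [2, 9, 3, 8, 2, 9, 3, 8] and `batch_size=2`, batching in arrival order pads 24 positions, while bucketing over all four batches (`batches_to_bucket=4`) pads none. Real pipelines usually shuffle the resulting batches afterwards so that training does not see them in length order.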
from_config
def from_config(cls, file_path, **kwargs)
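No description is given for `from_config`. The `cls` parameter marks it as a class method, and the name suggests an alternative constructor that reads its settings from a file. The sketch below shows one common way such a constructor is written. Everything in it is an assumption rather than documented behaviour: the JSON file format, the rule that keyword arguments override file values, and the hypothetical class `ConfigurableTrainer`, which stands in for `Trainer` with only two arguments.

```python
import json
from typing import Any, Dict


class ConfigurableTrainer:
    """Hypothetical stand-in for Trainer, reduced to two constructor arguments."""

    def __init__(self, batch_size: int = 16, max_input_len: int = 100) -> None:
        self.batch_size = batch_size
        self.max_input_len = max_input_len

    @classmethod
    def from_config(cls, file_path: str, **kwargs: Any) -> "ConfigurableTrainer":
        # Assumption: the config file is JSON whose keys match __init__ arguments.
        with open(file_path, encoding="utf-8") as f:
            config: Dict[str, Any] = json.load(f)
        # Assumption: explicit keyword arguments take precedence over file values.
        config.update(kwargs)
        return cls(**config)
```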
train
def train(summarizer, train_data, val_data, num_epochs, scorers, callbacks)
Trains a summarizer or resumes training of a previously initialized summarizer.
Args
- summarizer: Model to train; either a freshly created model or a loaded one.
- train_data: Data to train the model on.
- val_data (optional): Validation data.
- num_epochs: Number of epochs to train.
- scorers (optional): Dictionary of the form {score_name: scorer}; each scorer's result is added to the validation logs under its score_name.
- callbacks (optional): Additional custom callbacks.
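How `num_epochs` combines with the constructor's `steps_per_epoch` and `steps_to_log` and with the `scorers` and `callbacks` arguments can be sketched as a plain Python loop. This is an illustration under assumptions, not the library's implementation: an epoch is taken to mean `steps_per_epoch` training steps (matching the description of `steps_per_epoch` above), and `train_loop`, `train_step`, `predict`, `exact_match`, the scorer signature (a list of `(target, prediction)` pairs in, a float out), and the callback signature are all hypothetical.

```python
import itertools
from typing import Callable, Dict, Iterable, List, Optional, Tuple

Example = Tuple[str, str]  # (input text, target text)
Scorer = Callable[[List[Tuple[str, str]]], float]  # assumed: (target, prediction) pairs -> score
Callback = Callable[[int, Dict[str, float]], None]  # assumed: (epoch, logs) -> None


def exact_match(pairs: List[Tuple[str, str]]) -> float:
    """Hypothetical scorer: fraction of predictions identical to their target."""
    return sum(target == pred for target, pred in pairs) / len(pairs)


def train_loop(
    train_step: Callable[[Example], float],
    predict: Callable[[str], str],
    train_data: Iterable[Example],
    val_data: Optional[List[Example]],
    num_epochs: int,
    steps_per_epoch: int,
    steps_to_log: int,
    scorers: Optional[Dict[str, Scorer]] = None,
    callbacks: Optional[List[Callback]] = None,
) -> List[Dict[str, float]]:
    history: List[Dict[str, float]] = []
    stream = itertools.cycle(train_data)  # epochs are step counts, not data passes
    for epoch in range(num_epochs):
        loss = float("nan")
        for step in range(1, steps_per_epoch + 1):
            loss = train_step(next(stream))
            if step % steps_to_log == 0:
                print(f"epoch {epoch} step {step} loss {loss:.4f}")
        logs: Dict[str, float] = {"loss": loss}
        # Validation scores are added to the logs under each scorer's name.
        if val_data and scorers:
            pairs = [(target, predict(source)) for source, target in val_data]
            for name, scorer in scorers.items():
                logs[name] = scorer(pairs)
        # Callbacks run once every steps_per_epoch steps.
        for callback in callbacks or []:
            callback(epoch, logs)
        history.append(logs)
    return history
```

Keeping scorers as a name-to-callable dictionary lets each score appear in the logs under a readable key without the loop knowing anything about the individual metrics.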