
Predictor

class Predictor

Performs model predictions on a batch of inputs.

__init__

def __init__(model, preprocessor)

Initializes a Predictor object with a trained transformer model and a preprocessor.

Args
  • model (Model): Trained transformer model.

  • preprocessor (Preprocessor): Preprocessor corresponding to the model configuration.
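The model and preprocessor have to belong together. Presumably this is because the preprocessor turns words into the token ids the model was trained on, though that is an assumption. The toy sketch below shows why a mismatch matters. `ToyTokenizer`, `ToyModel` and `ToyPredictor` are hypothetical stand-ins, not DeepPhonemizer classes: a predictor built with the wrong tokenizer sends ids the model does not recognise.

```python
from typing import Dict, List, Tuple


class ToyTokenizer:
    """Hypothetical character tokenizer standing in for a real preprocessor."""

    def __init__(self, symbols: str) -> None:
        # Reserve id 0 for padding; every known symbol gets a positive id.
        self.to_id: Dict[str, int] = {s: i + 1 for i, s in enumerate(symbols)}

    def encode(self, word: str) -> List[int]:
        # Characters outside the symbol set are silently dropped.
        return [self.to_id[c] for c in word if c in self.to_id]


class ToyModel:
    """Hypothetical 'model' that maps token-id sequences to phoneme strings."""

    def __init__(self, table: Dict[Tuple[int, ...], str]) -> None:
        self.table = table

    def generate(self, token_ids: List[int]) -> str:
        # Unknown id sequences produce an empty prediction.
        return self.table.get(tuple(token_ids), "")


class ToyPredictor:
    """Mirrors the documented constructor shape: (model, preprocessor)."""

    def __init__(self, model: ToyModel, preprocessor: ToyTokenizer) -> None:
        self.model = model
        self.preprocessor = preprocessor

    def predict_one(self, word: str) -> str:
        # The preprocessor's ids must be the ones the model was built with.
        return self.model.generate(self.preprocessor.encode(word))


tokenizer = ToyTokenizer("abcdefghijklmnopqrstuvwxyz")
model = ToyModel({tuple(tokenizer.encode("cat")): "kæt"})
predictor = ToyPredictor(model, tokenizer)
print(predictor.predict_one("cat"))  # kæt

# A tokenizer built from a different symbol order yields different ids,
# so the same model no longer recognises the word.
mismatched = ToyPredictor(model, ToyTokenizer("zyxwvutsrqponmlkjihgfedcba"))
print(repr(mismatched.predict_one("cat")))  # ''
```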

__call__

def __call__(words, lang, batch_size)

Predicts phonemes for a list of words.

Args
  • words (list): List of words to predict.

  • lang (str): Language of the words.

  • batch_size (int): Number of words passed to the model per batch; larger batches can speed up inference.

Returns
  • List[Prediction]: A list of Prediction objects, each containing word, phonemes, phoneme_tokens, token_probs and confidence.
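Below is a minimal sketch of the documented behaviour of `__call__`. It splits the words into slices of `batch_size`, runs each slice through a model, and wraps every result in an object with the listed fields. Several pieces are assumptions for illustration, not the library's code: the `Prediction` dataclass, the hypothetical `fake_model` function, and computing `confidence` as the product of `token_probs`, falling back to 0.0 for an empty sequence.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class Prediction:
    # Field names follow the Returns description; the types are assumed.
    word: str
    phonemes: str
    phoneme_tokens: List[str]
    token_probs: List[float]
    confidence: float


# Hypothetical batch model: returns (phoneme_tokens, token_probs) per word.
BatchModel = Callable[[List[str], str], List[Tuple[List[str], List[float]]]]


def predict(words: List[str], lang: str, batch_size: int,
            model: BatchModel) -> List[Prediction]:
    results: List[Prediction] = []
    # Feed the model fixed-size slices so each call handles several words.
    for start in range(0, len(words), batch_size):
        batch = words[start:start + batch_size]
        for word, (tokens, probs) in zip(batch, model(batch, lang)):
            results.append(Prediction(
                word=word,
                phonemes="".join(tokens),
                phoneme_tokens=tokens,
                token_probs=probs,
                # Assumed: sequence confidence = product of per-token probs.
                confidence=math.prod(probs) if probs else 0.0,
            ))
    return results


calls: List[int] = []


def fake_model(batch: List[str], lang: str) -> List[Tuple[List[str], List[float]]]:
    calls.append(len(batch))  # record batch sizes for illustration
    # Pretend each character is one phoneme predicted with probability 0.5.
    return [(list(w), [0.5] * len(w)) for w in batch]


preds = predict(["ab", "c", "de", "f", "g"], lang="en_us", batch_size=2, model=fake_model)
print(calls)                # [2, 2, 1]
print(preds[0].phonemes)    # ab
print(preds[0].confidence)  # 0.25
```

The sketch keeps the results in input order. Grouping words of similar length into the same batch would reduce padding in a real sequence model, but this page does not say whether the library does that.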

from_checkpoint

def from_checkpoint(cls, checkpoint_path, device)

Initializes the predictor from a checkpoint (.pt file).

Args
  • checkpoint_path (str): Path to the checkpoint file (.pt).

  • device (str): Device to load the model on ('cpu' or 'cuda'). Defaults to 'cpu'.

Returns
  • Predictor: Predictor object.
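`from_checkpoint` is an alternative constructor. It restores everything the predictor needs from one file and hands it to `__init__`. The sketch below shows that classmethod pattern with `pickle` standing in for PyTorch checkpoint loading. Both the `SimplePredictor` class and the checkpoint keys `"model"` and `"preprocessor"` are hypothetical. A real implementation would also move the model weights to `device`; here the device is only recorded.

```python
import os
import pickle
import tempfile
from typing import Any


class SimplePredictor:
    """Hypothetical stand-in showing the from_checkpoint constructor pattern."""

    def __init__(self, model: Any, preprocessor: Any) -> None:
        self.model = model
        self.preprocessor = preprocessor
        self.device = "cpu"

    @classmethod
    def from_checkpoint(cls, checkpoint_path: str, device: str = "cpu") -> "SimplePredictor":
        # Assumed layout: one dict holding both objects, so they stay paired.
        with open(checkpoint_path, "rb") as f:
            checkpoint = pickle.load(f)
        predictor = cls(model=checkpoint["model"], preprocessor=checkpoint["preprocessor"])
        # A real implementation would move the model weights to `device` here.
        predictor.device = device
        return predictor


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_checkpoint.pkl")
    with open(path, "wb") as f:
        pickle.dump({"model": {"cat": "kæt"}, "preprocessor": "char-tokenizer"}, f)
    loaded = SimplePredictor.from_checkpoint(path)

print(loaded.model["cat"], loaded.device)  # kæt cpu
```

Storing the model and its preprocessor in the same file means a caller cannot accidentally pair a model with the wrong preprocessor. Because `pickle` (and therefore PyTorch's default loader) can execute code while loading, only load checkpoints from trusted sources.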