Model

create_model

def create_model(model_type, config)

Initializes a model from a config for a given model type.

Args
  • model_type (ModelType): Type of model to be initialized.

  • config (dict): Configuration containing hyperparams.
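
As a rough illustration of the dispatch described here, the sketch below picks a model class based on the model type and builds it from the config. The enum members, the stub classes, the name `create_model_sketch`, and the `d_model` config key are assumptions made for this example, not the library's confirmed internals.

```python
from enum import Enum


class ModelType(Enum):
    # Hypothetical members; the real enum values are not listed on this page.
    TRANSFORMER = "transformer"
    AUTOREG_TRANSFORMER = "autoreg_transformer"

    def is_autoregressive(self) -> bool:
        return self is ModelType.AUTOREG_TRANSFORMER


class ForwardStub:
    """Placeholder standing in for ForwardTransformer."""

    def __init__(self, d_model: int) -> None:
        self.d_model = d_model

    @classmethod
    def from_config(cls, config: dict) -> "ForwardStub":
        return cls(d_model=config["d_model"])


class AutoregStub(ForwardStub):
    """Placeholder standing in for AutoregressiveTransformer."""


def create_model_sketch(model_type: ModelType, config: dict) -> ForwardStub:
    # Map each model type to the class that knows how to build itself.
    builders = {
        ModelType.TRANSFORMER: ForwardStub,
        ModelType.AUTOREG_TRANSFORMER: AutoregStub,
    }
    if model_type not in builders:
        raise ValueError(f"Unsupported model type: {model_type}")
    return builders[model_type].from_config(config)


model = create_model_sketch(ModelType.AUTOREG_TRANSFORMER, {"d_model": 512})
print(type(model).__name__, model.d_model)
```

Keeping the mapping in a dictionary means a new model type only needs one extra entry rather than another branch.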

load_checkpoint

def load_checkpoint(checkpoint_path, device)

Initializes a model from a checkpoint (.pt file).

Args
  • checkpoint_path (str): Path to checkpoint file (.pt).

  • device (str): Device to load the model onto ('cpu' or 'cuda').

Returns

class ModelType

is_autoregressive

def is_autoregressive()

Returns
  • bool: Whether the model is autoregressive.

class Model

__init__

def __init__()

generate

def generate(batch)

Generates phonemes for a text batch.

Args
  • batch (Dict[str, torch.Tensor]): Dictionary containing 'text' (tokenized text tensor), 'text_len' (text length tensor), and 'start_index' (phoneme start indices for AutoregressiveTransformer).
Returns
  • Tuple[torch.Tensor, torch.Tensor]: The predictions. The first element is a tensor (phoneme tokens) and the second element is a tensor (phoneme token probabilities).
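
To make the expected batch layout concrete, here is a minimal sketch that assembles the three entries named above from raw token ids. Plain Python lists stand in for `torch.Tensor` so the example runs without PyTorch. The helper name `build_batch_sketch`, the padding index `0`, and the per-item shape of `start_index` are assumptions for illustration only.

```python
from typing import Dict, List


def build_batch_sketch(
    token_ids: List[List[int]], start_index: int, pad_index: int = 0
) -> Dict[str, list]:
    # Pad every sequence to the length of the longest one in the batch.
    max_len = max(len(seq) for seq in token_ids)
    text = [seq + [pad_index] * (max_len - len(seq)) for seq in token_ids]
    return {
        "text": text,                                    # padded token ids
        "text_len": [len(seq) for seq in token_ids],     # unpadded lengths
        "start_index": [start_index] * len(token_ids),   # one start token per item
    }


batch = build_batch_sketch([[5, 8, 2], [7, 3]], start_index=1)
print(batch)
```

In real use, each list would be converted to a tensor before being passed to `generate`.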

class ForwardTransformer

__init__

def __init__(encoder_vocab_size, decoder_vocab_size, d_model, d_fft, layers, dropout, heads)

forward

def forward(batch)

Forward pass of the model on a data batch.

Args
  • batch (Dict[str, torch.Tensor]): Input batch with entry 'text' (text tensor).
Returns
  • Tensor: Predictions.

generate

def generate(batch)

Inference pass on a batch of tokenized texts.

Args
  • batch (Dict[str, torch.Tensor]): Input batch with entry 'text' (text tensor).
Returns
  • Tuple: The first element is a Tensor (phoneme tokens) and the second element is a tensor (phoneme token probabilities).
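
This page does not say how tokens and probabilities are derived from the model output. One plausible reading, sketched below in plain Python, is a greedy choice per position: apply a softmax over the vocabulary and keep the most likely token together with its probability. The function name `greedy_tokens_sketch` and the greedy strategy itself are assumptions, not the documented behavior.

```python
import math
from typing import List, Tuple


def greedy_tokens_sketch(logits: List[List[float]]) -> Tuple[List[int], List[float]]:
    tokens, probs = [], []
    for position in logits:
        # Numerically stable softmax over the vocabulary dimension.
        peak = max(position)
        exps = [math.exp(value - peak) for value in position]
        total = sum(exps)
        softmax = [e / total for e in exps]
        # Keep the highest-probability token and its probability.
        best = max(range(len(softmax)), key=softmax.__getitem__)
        tokens.append(best)
        probs.append(softmax[best])
    return tokens, probs


tokens, probs = greedy_tokens_sketch([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]])
print(tokens, probs)
```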

from_config

def from_config(cls, config)

class AutoregressiveTransformer

__init__

def __init__(encoder_vocab_size, decoder_vocab_size, end_index, d_model, d_fft, encoder_layers, decoder_layers, dropout, heads)

forward

def forward(batch)

Forward pass of the model on a data batch.

Args
  • batch (Dict[str, torch.Tensor]): Input batch with entries 'text' (text tensor) and 'phonemes' (phoneme tensor for teacher forcing).
Returns
  • Tensor: Predictions.

generate

def generate(batch, max_len)

Inference pass on a batch of tokenized texts.

Args
  • batch (Dict[str, torch.Tensor]): Dictionary containing the input to the model with entries 'text' and 'start_index'.

  • max_len (int): Max steps of the autoregressive inference loop.

Returns
  • Tuple: Predictions. The first element is a Tensor of phoneme tokens and the second element is a Tensor of phoneme token probabilities.
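
The roles of 'start_index', the end index, and `max_len` can be sketched with a toy decoding loop. This is a simplified, framework-free illustration: `step_fn` is a hypothetical stand-in for one decoder step, and the actual stopping logic of the model may differ.

```python
from typing import Callable, List


def autoregressive_decode_sketch(
    step_fn: Callable[[List[int]], int],
    start_index: int,
    end_index: int,
    max_len: int,
) -> List[int]:
    # Decoding begins from the start token.
    tokens = [start_index]
    for _ in range(max_len):
        # Each step sees everything generated so far.
        next_token = step_fn(tokens)
        tokens.append(next_token)
        if next_token == end_index:
            break  # stop early once the end token appears
    return tokens


def toy_step(tokens: List[int]) -> int:
    # Toy "model": always emits the previous token plus one.
    return tokens[-1] + 1


finished = autoregressive_decode_sketch(toy_step, start_index=1, end_index=3, max_len=10)
capped = autoregressive_decode_sketch(toy_step, start_index=1, end_index=99, max_len=4)
print(finished, capped)
```

The first call stops as soon as the end token is produced. The second never sees it, so `max_len` bounds the loop.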

from_config

def from_config(cls, config)

Initializes an autoregressive Transformer model from a config.

Args
  • config (dict): Configuration containing the hyperparams.

Returns
  • AutoregressiveTransformer: Model object.
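
The sketch below shows one way a `from_config` classmethod could map a config dictionary onto the `__init__` parameters listed for this class. The flat config layout, the class name `AutoregHyperparamsSketch`, and the example values are assumptions. The real config structure is not described on this page.

```python
from dataclasses import dataclass, fields


@dataclass
class AutoregHyperparamsSketch:
    # Field names mirror the __init__ signature documented for this class.
    encoder_vocab_size: int
    decoder_vocab_size: int
    end_index: int
    d_model: int
    d_fft: int
    encoder_layers: int
    decoder_layers: int
    dropout: float
    heads: int

    @classmethod
    def from_config(cls, config: dict) -> "AutoregHyperparamsSketch":
        # Assumed layout: a flat dict keyed by parameter name.
        names = [f.name for f in fields(cls)]
        missing = [name for name in names if name not in config]
        if missing:
            raise KeyError(f"Missing hyperparams: {missing}")
        return cls(**{name: config[name] for name in names})


example_config = {
    "encoder_vocab_size": 40, "decoder_vocab_size": 60, "end_index": 2,
    "d_model": 512, "d_fft": 1024, "encoder_layers": 4, "decoder_layers": 4,
    "dropout": 0.1, "heads": 4,
}
params = AutoregHyperparamsSketch.from_config(example_config)
print(params)
```

Checking for missing keys up front gives a single clear error instead of a failure deep inside model construction.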