Model
create_model
def create_model(model_type, config)
Initializes a model from a config for a given model type.
Args
- model_type (ModelType): Type of model to be initialized.
- config (dict): Configuration containing hyperparams.
load_checkpoint
def load_checkpoint(checkpoint_path, device)
Initializes a model from a checkpoint (.pt file).
Args
- checkpoint_path (str): Path to checkpoint file (.pt).
- device (str): Device to place the model on ('cpu' or 'cuda').
Returns
class ModelType
is_autoregressive
def is_autoregressive()
Returns
- bool: Whether the model is autoregressive.
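As a rough illustration of how `create_model` and `ModelType.is_autoregressive` might fit together, the sketch below dispatches on the model type to pick a class. The enum member names and values, and the two stub classes, are placeholders invented for this example; they are not taken from the library.

```python
from enum import Enum


class ModelType(Enum):
    # Member names and values are placeholders, not the library's actual ones.
    FORWARD = "forward"
    AUTOREGRESSIVE = "autoregressive"

    def is_autoregressive(self) -> bool:
        """Return True only for the autoregressive variant."""
        return self is ModelType.AUTOREGRESSIVE


class ForwardStub:
    """Stand-in for a non-autoregressive model class."""

    def __init__(self, config: dict) -> None:
        self.config = config


class AutoregressiveStub:
    """Stand-in for an autoregressive model class."""

    def __init__(self, config: dict) -> None:
        self.config = config


def create_model(model_type: ModelType, config: dict):
    """Pick a model class based on the model type and build it from the config."""
    if model_type.is_autoregressive():
        return AutoregressiveStub(config)
    return ForwardStub(config)


# Look the type up by value, as one might when it is read from a config file.
model = create_model(ModelType("autoregressive"), {"d_model": 512})
print(type(model).__name__)  # AutoregressiveStub
```

Keeping the autoregressive check on the enum itself means callers can branch on behavior without comparing against individual members.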
class Model
__init__
def __init__()
generate
def generate(batch)
Generates phonemes for a text batch.
Args
- batch (Dict[str, torch.Tensor]): Dictionary containing 'text' (tokenized text tensor), 'text_len' (text length tensor), and 'start_index' (phoneme start indices, used by AutoregressiveTransformer).
Returns
- Tuple[torch.Tensor, torch.Tensor]: The predictions. The first element is a tensor (phoneme tokens) and the second element is a tensor (phoneme token probabilities).
class ForwardTransformer
__init__
def __init__(encoder_vocab_size, decoder_vocab_size, d_model, d_fft, layers, dropout, heads)
forward
def forward(batch)
Forward pass of the model on a data batch.
Args
- batch (Dict[str, torch.Tensor]): Input batch with entry 'text' (text tensor).
Returns
- Tensor: Predictions.
generate
def generate(batch)
Inference pass on a batch of tokenized texts.
Args
- batch (Dict[str, torch.Tensor]): Input batch with entry 'text' (text tensor).
Returns
- Tuple: Predictions. The first element is a Tensor of phoneme tokens and the second element is a Tensor of phoneme token probabilities.
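One plausible way to obtain a (tokens, probabilities) pair from per-position model outputs is to take a softmax at each position and keep the most likely token together with its probability. The sketch below assumes that approach, which is not confirmed by this reference, and uses plain Python lists in place of tensors so it runs without PyTorch. The helper names and the 3-symbol vocabulary are hypothetical.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one position's logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def tokens_and_probs(logits: List[List[float]]) -> Tuple[List[int], List[float]]:
    """For each position, return the most likely token and its probability."""
    tokens, probs = [], []
    for position in logits:
        dist = softmax(position)
        best = max(range(len(dist)), key=dist.__getitem__)
        tokens.append(best)
        probs.append(dist[best])
    return tokens, probs


# Two positions over a hypothetical 3-symbol phoneme vocabulary.
tokens, probs = tokens_and_probs([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]])
print(tokens)  # [1, 0]
```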
from_config
def from_config(cls, config)
class AutoregressiveTransformer
__init__
def __init__(encoder_vocab_size, decoder_vocab_size, end_index, d_model, d_fft, encoder_layers, decoder_layers, dropout, heads)
forward
def forward(batch)
Forward pass of the model on a data batch.
Args
- batch (Dict[str, torch.Tensor]): Input batch with entries 'text' (text tensor) and 'phonemes' (phoneme tensor for teacher forcing).
Returns
- Tensor: Predictions.
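Teacher forcing, as mentioned for the 'phonemes' entry, usually means feeding the ground-truth sequence shifted by one position as decoder input, so that each step predicts the next token. The sketch below shows that shift on a plain list; whether the model slices the tensor in exactly this way is an assumption, and the helper name and token ids are hypothetical.

```python
from typing import List, Tuple


def teacher_forcing_pair(phonemes: List[int]) -> Tuple[List[int], List[int]]:
    """Split a target sequence into decoder input and prediction target.

    The decoder sees tokens 0..n-2 and is trained to predict tokens 1..n-1.
    """
    if len(phonemes) < 2:
        raise ValueError("need at least a start token and one further token")
    return phonemes[:-1], phonemes[1:]


# Hypothetical token ids: 1 = start, 2 = end, others = phonemes.
decoder_input, target = teacher_forcing_pair([1, 7, 9, 4, 2])
print(decoder_input)  # [1, 7, 9, 4]
print(target)         # [7, 9, 4, 2]
```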
generate
def generate(batch, max_len)
Inference pass on a batch of tokenized texts.
Args
- batch (Dict[str, torch.Tensor]): Dictionary containing the input to the model with entries 'text' and 'start_index'.
- max_len (int): Max steps of the autoregressive inference loop.
Returns
- Tuple: Predictions. The first element is a Tensor of phoneme tokens and the second element is a Tensor of phoneme token probabilities.
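The roles of 'start_index', `end_index` and `max_len` can be illustrated with a generic greedy decoding loop: begin with the start token, ask the model for one token at a time, and stop at the end token or after `max_len` steps. This is a minimal sketch of that general pattern for a single sequence, not the library's implementation; `greedy_generate`, `StepFn` and `toy_step` are hypothetical names, and the toy step function stands in for a real decoder.

```python
from typing import Callable, List, Tuple

# A step function maps the tokens generated so far to (next_token, probability).
StepFn = Callable[[List[int]], Tuple[int, float]]


def greedy_generate(step: StepFn, start_index: int, end_index: int,
                    max_len: int) -> Tuple[List[int], List[float]]:
    """Run at most `max_len` decoding steps, stopping early at `end_index`."""
    tokens = [start_index]
    probs = [1.0]  # the start token is given, so it gets probability 1
    for _ in range(max_len):
        next_token, prob = step(tokens)
        tokens.append(next_token)
        probs.append(prob)
        if next_token == end_index:
            break
    return tokens, probs


def toy_step(tokens: List[int]) -> Tuple[int, float]:
    """Toy model: emit 5, 6, 7 and then the end token 2."""
    script = [5, 6, 7, 2]
    return script[len(tokens) - 1], 0.9


print(greedy_generate(toy_step, start_index=1, end_index=2, max_len=10)[0])
# [1, 5, 6, 7, 2]
print(greedy_generate(toy_step, start_index=1, end_index=2, max_len=2)[0])
# [1, 5, 6]
```

The second call shows `max_len` acting as a hard cap: decoding stops after two steps even though the end token has not been produced.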
from_config
def from_config(cls, config)
Initializes an autoregressive Transformer model from a config.
Args
- config (dict): Configuration containing the hyperparams.
Returns
- AutoregressiveTransformer: Model object.