
Utils

get_dedup_tokens

def get_dedup_tokens(logits_batch)

Converts a batch of logits into the most probable deduplicated tokens for each sequence in the batch, together with their probabilities.

Args
  • logits_batch (Tensor): Batch of logits (N x T x V).
Returns
  • Tuple: Deduplicated tokens. The first element is a tensor of token indices and the second element is a tensor of the corresponding token probabilities.
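The docstring does not spell out what "deduplicated" means. A common reading, used in CTC-style greedy decoding, is: take the softmax over V, pick the argmax token at each of the T steps, collapse runs of consecutive repeats into one token, and keep one probability per run. The sketch below assumes exactly that, keeps the highest probability within each run, and right-pads shorter sequences with 0 so the batch stays rectangular. These choices, and the hypothetical name `dedup_tokens_sketch`, are assumptions, not the library's code; the actual function may also drop blank or padding tokens. Plain Python lists stand in for tensors to keep the example self-contained.

```python
import math
from typing import List, Tuple


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def dedup_tokens_sketch(
    logits_batch: List[List[List[float]]],  # N x T x V
) -> Tuple[List[List[int]], List[List[float]]]:
    """Hypothetical sketch: greedy argmax, then collapse consecutive repeats."""
    batch_tokens, batch_probs = [], []
    for logits in logits_batch:  # logits: T x V
        tokens: List[int] = []
        probs: List[float] = []
        for row in logits:
            p = softmax(row)
            idx = max(range(len(p)), key=p.__getitem__)  # most probable token
            if tokens and tokens[-1] == idx:
                # Same run as the previous step: keep the highest probability.
                probs[-1] = max(probs[-1], p[idx])
            else:
                tokens.append(idx)
                probs.append(p[idx])
        batch_tokens.append(tokens)
        batch_probs.append(probs)
    # Right-pad with 0 so every sequence in the batch has the same length.
    max_len = max((len(t) for t in batch_tokens), default=0)
    padded_tokens = [t + [0] * (max_len - len(t)) for t in batch_tokens]
    padded_probs = [p + [0.0] * (max_len - len(p)) for p in batch_probs]
    return padded_tokens, padded_probs
```

For a single sequence whose per-step argmax is `1, 1, 2`, the sketch returns the tokens `[1, 2]`, where the probability of token 1 is the larger of its two per-step probabilities.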

_generate_square_subsequent_mask

def _generate_square_subsequent_mask(sz)
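This helper is undocumented. Its name matches PyTorch's `nn.Transformer.generate_square_subsequent_mask`, which returns an `sz x sz` float mask with `0.0` where attention is allowed and `-inf` where it is blocked, so that position i can only attend to positions j <= i. Assuming this helper follows the same convention, a minimal list-based sketch (the name `square_subsequent_mask_sketch` is hypothetical) looks like this:

```python
from typing import List


def square_subsequent_mask_sketch(sz: int) -> List[List[float]]:
    """Hypothetical causal mask: row i may attend to columns j <= i.

    Allowed positions hold 0.0 and future positions hold -inf, so adding the
    mask to attention scores drives future positions to zero after softmax.
    """
    return [[0.0 if j <= i else float("-inf") for j in range(sz)] for i in range(sz)]
```

For `sz = 3` the first row is `[0.0, -inf, -inf]` and the last row is all zeros.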

_make_len_mask

def _make_len_mask(inp)
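`_make_len_mask` is also undocumented. A plausible reading is a padding ("length") mask for attention: PyTorch's `nn.Transformer` takes time-major inputs by default and a `key_padding_mask` of shape (N x T) in which `True` marks positions to ignore. The sketch below assumes that `inp` holds token indices laid out as (T x N) and that index 0 is padding; both are assumptions, and the names `len_mask_sketch` and `pad_index` are hypothetical.

```python
from typing import List


def len_mask_sketch(inp: List[List[int]], pad_index: int = 0) -> List[List[bool]]:
    """Hypothetical padding mask.

    inp is time-major (T x N). The result is batch-major (N x T) and is True
    wherever the token equals pad_index, i.e. where attention should be blocked.
    """
    t_len = len(inp)
    n_len = len(inp[0]) if t_len else 0
    return [[inp[t][n] == pad_index for t in range(t_len)] for n in range(n_len)]
```

With `inp = [[5, 7], [3, 0], [0, 0]]` (two sequences, `5 3 0` and `7 0 0`), the sketch returns `[[False, False, True], [False, True, True]]`.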

_get_len_util_stop

def _get_len_util_stop(sequence, end_index)
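Judging from its name and arguments, `_get_len_util_stop` presumably returns the length of `sequence` up to its first `end_index` token. Whether that length counts the end token itself is not documented; the sketch below assumes it does and falls back to the full length when no end token is present. The name `len_until_stop_sketch` is hypothetical.

```python
from typing import Sequence


def len_until_stop_sketch(sequence: Sequence[int], end_index: int) -> int:
    """Hypothetical: length up to and including the first end_index token."""
    for i, val in enumerate(sequence):
        if val == end_index:
            return i + 1  # assumption: the end token is counted
    return len(sequence)  # no end token: keep the whole sequence
```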

_trim_util_stop

def _trim_util_stop(sequence, end_index)
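`_trim_util_stop` presumably cuts `sequence` after its first `end_index` token, for example to drop padding that follows an end-of-sequence marker. Under the same assumption as for the length helper (the end token is kept), trimming reduces to slicing by that length. Both function names below are hypothetical, and the length helper is repeated so the block runs on its own.

```python
from typing import List, Sequence


def len_until_stop_sketch(sequence: Sequence[int], end_index: int) -> int:
    """Hypothetical: length up to and including the first end_index token."""
    for i, val in enumerate(sequence):
        if val == end_index:
            return i + 1
    return len(sequence)


def trim_until_stop_sketch(sequence: Sequence[int], end_index: int) -> List[int]:
    """Hypothetical: keep everything up to and including the first end token."""
    return list(sequence[: len_until_stop_sketch(sequence, end_index)])
```

For example, trimming `[4, 5, 2, 0, 0]` with `end_index = 2` yields `[4, 5, 2]`.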

class PositionalEncoding

__init__

def __init__(d_model, dropout, max_len)

Initializes positional encoding.

Args
  • d_model (int): Dimension of the model.

  • dropout (float): Dropout rate applied after adding the positional encoding.

  • max_len (int): Maximum length of the precalculated position sequence.
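The docstring does not say which encoding is precalculated. A common choice is the sinusoidal encoding from "Attention Is All You Need", where PE(pos, 2k) = sin(pos / 10000^(2k/d_model)) and PE(pos, 2k+1) = cos(pos / 10000^(2k/d_model)). Assuming that scheme, the table of `max_len` rows and `d_model` columns could be built as follows; `sinusoidal_table` is a hypothetical name and the sketch uses plain lists instead of a tensor buffer.

```python
import math
from typing import List


def sinusoidal_table(max_len: int, d_model: int) -> List[List[float]]:
    """Hypothetical sinusoidal positional-encoding table (max_len x d_model)."""
    table = []
    for pos in range(max_len):
        row = []
        for i in range(d_model):
            # Dimensions 2k and 2k+1 share the frequency 1 / 10000^(2k / d_model).
            angle = pos / (10000 ** ((i - i % 2) / d_model))
            # Even dimensions use sine, odd dimensions use cosine.
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table
```

Precomputing the table once up to `max_len` means each forward pass only needs to slice the first T rows.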

forward

def forward(x)
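`forward` is undocumented here. Based on the constructor arguments, it presumably adds the first T rows of the precalculated table to the input and then applies dropout. The sketch below assumes a single sequence `x` of shape (T x d_model), a precomputed table `pe` of shape (max_len x d_model), and inverted dropout (survivors rescaled by 1 / (1 - p)), as in `torch.nn.Dropout`. The function name and the `rng` parameter are hypothetical, and any batch dimension is left out.

```python
import random
from typing import List, Optional


def add_positional_encoding(
    x: List[List[float]],   # T x d_model
    pe: List[List[float]],  # max_len x d_model, precomputed
    dropout: float = 0.0,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Hypothetical forward pass: x + pe[:T], followed by inverted dropout."""
    if len(x) > len(pe):
        raise ValueError("sequence is longer than max_len")
    rng = rng or random.Random()
    keep = 1.0 - dropout
    out = []
    for t, row in enumerate(x):
        summed = [v + p for v, p in zip(row, pe[t])]
        if dropout > 0.0:
            # Zero each value with probability `dropout`, rescale the rest by 1/keep.
            summed = [v / keep if rng.random() >= dropout else 0.0 for v in summed]
        out.append(summed)
    return out
```

With `dropout = 0.0` the result is exactly `x + pe[:T]`. Inputs longer than `max_len` are rejected because no encoding was precalculated for those positions.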