Utils
get_dedup_tokens
def get_dedup_tokens(logits_batch)
Converts a batch of logits into the batch's most probable tokens and their probabilities.
Args
- logits_batch (Tensor): Batch of logits (N x T x V).
Returns
- Tuple: Deduplicated tokens. The first element is a tensor (token indices) and the second element is a tensor (token probabilities).
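A minimal sketch of how such a deduplication could look, assuming a CTC-style collapse with a blank index of 0. The name `dedup_tokens_sketch`, the `blank_index` parameter, and the rule of keeping the highest probability per collapsed run are assumptions for illustration, not the library's exact implementation.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def dedup_tokens_sketch(logits_batch: torch.Tensor, blank_index: int = 0):
    """CTC-style collapse: argmax per timestep, drop the blank index,
    merge consecutive repeats and keep the highest probability per run."""
    probs_batch = logits_batch.softmax(dim=-1)                  # (N, T, V)
    out_tokens, out_probs = [], []
    for probs in probs_batch:                                   # (T, V) per sample
        max_probs, max_indices = probs.max(dim=-1)
        keep = max_indices != blank_index                       # remove blank predictions
        max_probs, max_indices = max_probs[keep], max_indices[keep]
        tokens, counts = torch.unique_consecutive(max_indices, return_counts=True)
        runs = max_probs.split(counts.tolist())                 # probabilities per repeated run
        run_probs = (torch.stack([r.max() for r in runs])
                     if len(runs) > 0 else max_probs.new_zeros(0))
        out_tokens.append(tokens)
        out_probs.append(run_probs)
    # pad to a common length so the result is again a batched pair of tensors
    tokens = pad_sequence(out_tokens, batch_first=True, padding_value=0).long()
    probs = pad_sequence(out_probs, batch_first=True, padding_value=0.)
    return tokens, probs
```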
_generate_square_subsequent_mask
def _generate_square_subsequent_mask(sz)
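A sketch of the conventional causal mask used with `torch.nn.Transformer`, which this helper presumably produces; the function name below is hypothetical.

```python
import torch


def square_subsequent_mask_sketch(sz: int) -> torch.Tensor:
    """Causal mask of shape (sz, sz): position i may attend only to positions <= i.
    Allowed entries are 0.0, future entries are -inf (added to attention scores)."""
    future = torch.triu(torch.ones(sz, sz, dtype=torch.bool), diagonal=1)
    return torch.zeros(sz, sz).masked_fill(future, float('-inf'))
```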
_make_len_mask
def _make_len_mask(inp)
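A sketch of a length (key-padding) mask, assuming padding index 0 and a time-first `(T, N)` input layout; both assumptions are illustrative.

```python
import torch


def make_len_mask_sketch(inp: torch.Tensor, pad_index: int = 0) -> torch.Tensor:
    """Key-padding mask: True where the input equals the padding index.
    Assumes a time-first (T, N) index tensor; the (N, T) result matches the
    layout expected by torch.nn.Transformer's *_key_padding_mask arguments."""
    return (inp == pad_index).transpose(0, 1)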
_get_len_util_stop
def _get_len_util_stop(sequence, end_index)
_trim_util_stop
def _trim_util_stop(sequence, end_index)
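A sketch of how these two helpers could work together, assuming `end_index` marks an end-of-sequence token: the first returns the length up to and including that token, the second cuts the sequence there. The function names below are hypothetical.

```python
import torch


def get_len_until_stop_sketch(sequence: torch.Tensor, end_index: int) -> int:
    """Length up to and including the first end_index, or the full length
    if no end token is present."""
    hits = (sequence == end_index).nonzero(as_tuple=False)
    return int(hits[0]) + 1 if len(hits) > 0 else len(sequence)


def trim_until_stop_sketch(sequence: torch.Tensor, end_index: int) -> torch.Tensor:
    """Cuts the sequence right after the first end_index."""
    return sequence[:get_len_until_stop_sketch(sequence, end_index)]
```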
class PositionalEncoding
__init__
def __init__(d_model, dropout, max_len)
Initializes positional encoding.
Args
- d_model (int): Dimension of model.
- dropout (float): Dropout after positional encoding.
- max_len: Max length of precalculated position sequence.
forward
def forward(x)
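A minimal sketch of a standard sinusoidal positional encoding module with the constructor arguments listed above, assuming time-first `(T, N, d_model)` inputs and an even `d_model`; the actual class may differ in details such as scaling.

```python
import math
import torch


class PositionalEncodingSketch(torch.nn.Module):
    """Sinusoidal positional encoding for time-first (T, N, d_model) inputs."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        # Precompute the sin/cos position table once (assumes even d_model).
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(1))              # (max_len, 1, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (T, N, d_model); add the first T precomputed positions, then dropout.
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
```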