def make_masks_for_mt(self, src, tgt, pad_idx=0):
    """This method generates the masks for training the transformer model.

    Arguments
    ---------
    src : tensor
        The sequence to the encoder (required).
    tgt : tensor
        The sequence to the decoder (required).
    pad_idx : int
        The index for <pad> token (default=0).
    """
    src_key_padding_mask = None
    if self.training:
        src_key_padding_mask = get_key_padding_mask(src, pad_idx=pad_idx)
    tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)

    src_mask = None
    tgt_mask = get_lookahead_mask(tgt)
    return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask
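# Illustration only: a minimal, runnable sketch of the two kinds of masks
# returned above, on a toy MT batch. The two helpers below are simplified
# stand-ins assumed to behave like the get_key_padding_mask /
# get_lookahead_mask functions the method calls; they are not the
# library's actual implementations.
import torch


def _demo_key_padding_mask(padded_input, pad_idx):
    # True at <pad> positions, shape (batch, time).
    return padded_input.eq(pad_idx)


def _demo_lookahead_mask(padded_input):
    # Additive causal mask, shape (time, time): 0.0 on and below the
    # diagonal, -inf strictly above it.
    seq_len = padded_input.shape[1]
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask.masked_fill(mask == 1, float("-inf"))


src = torch.tensor([[5, 7, 9, 0], [4, 6, 0, 0]])  # padded source sentences
tgt = torch.tensor([[1, 5, 7], [1, 4, 0]])        # padded target sentences

print(_demo_key_padding_mask(src, pad_idx=0))
# tensor([[False, False, False,  True],
#         [False, False,  True,  True]])
print(_demo_lookahead_mask(tgt))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])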
def make_masks(
    self, src, pad_idx=0, look_ahead_mask=True, padding_mask=True
):
    """This method generates the masks for training the transformer model.

    Arguments
    ---------
    src : tensor
        The sequence to the encoder (required).
    pad_idx : int
        The index for <pad> token (default=0).
    look_ahead_mask : bool
        Whether to return a look-ahead (causal) mask (default=True).
    padding_mask : bool
        Whether to return a key padding mask (default=True).
    """
    src_mask = None
    if look_ahead_mask:
        src_mask = get_lookahead_mask(src)

    src_key_padding_mask = None
    if padding_mask:
        src_key_padding_mask = get_key_padding_mask(src, pad_idx)

    return src_mask, src_key_padding_mask
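# Illustration only: how the two masks produced above plug into a standard
# PyTorch encoder layer. The helper below is the same simplified stand-in
# as in the previous sketch (an assumption, not the library's
# implementation). Note that here the lookahead mask is built from src
# itself: the model attends causally over its own input, as in a
# language-model setup.
import torch
import torch.nn as nn


def _demo_lookahead_mask(padded_input):
    # Additive causal mask, shape (time, time).
    seq_len = padded_input.shape[1]
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask.masked_fill(mask == 1, float("-inf"))


src = torch.tensor([[2, 8, 3, 0], [7, 5, 0, 0]])  # 0 is the pad index
src_mask = _demo_lookahead_mask(src)              # (time, time), additive
src_key_padding_mask = src.eq(0)                  # (batch, time), boolean

emb = nn.Embedding(num_embeddings=10, embedding_dim=16)
layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, batch_first=True)
out = layer(
    emb(src), src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
)
print(out.shape)  # torch.Size([2, 4, 16])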
def make_masks(self, src, tgt, wav_len=None, pad_idx=0):
    """This method generates the masks for training the transformer model.

    Arguments
    ---------
    src : tensor
        The sequence to the encoder (required).
    tgt : tensor
        The sequence to the decoder (required).
    wav_len : tensor
        The relative lengths of the waveforms within the batch (optional);
        used to build the encoder key padding mask during training.
    pad_idx : int
        The index for <pad> token (default=0).
    """
    src_key_padding_mask = None
    if wav_len is not None and self.training:
        abs_len = torch.round(wav_len * src.shape[1])
        src_key_padding_mask = (1 - length_to_mask(abs_len)).bool()

    tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)

    src_mask = None
    tgt_mask = get_lookahead_mask(tgt)
    return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask
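# Illustration only: how wav_len (relative lengths in [0, 1]) becomes the
# boolean src_key_padding_mask above. The encoder input is a feature
# tensor rather than token ids, so padding cannot be recovered from a pad
# index. The helper below is a simplified stand-in assumed to behave like
# length_to_mask (1.0 at valid frames, 0.0 at padded ones); it is not the
# library's actual implementation.
import torch


def _demo_length_to_mask(length):
    # (batch, max_len): 1.0 up to each sequence's length, 0.0 afterwards.
    max_len = int(length.max().item())
    idx = torch.arange(max_len, device=length.device)
    return (idx.unsqueeze(0) < length.unsqueeze(1)).float()


src = torch.randn(2, 6, 80)         # (batch, time, features), e.g. fbanks
wav_len = torch.tensor([1.0, 0.5])  # second utterance is half as long

abs_len = torch.round(wav_len * src.shape[1])  # tensor([6., 3.])
src_key_padding_mask = (1 - _demo_length_to_mask(abs_len)).bool()
print(src_key_padding_mask)
# tensor([[False, False, False, False, False, False],
#         [False, False, False,  True,  True,  True]])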