def decode(self, tgt, encoder_out):
    """This method implements a decoding step for the transformer model.

    Arguments
    ---------
    tgt : torch.Tensor
        The sequence to the decoder.
    encoder_out : torch.Tensor
        Hidden output of the encoder.
    """
    tgt_mask = get_lookahead_mask(tgt)
    tgt = self.custom_tgt_module(tgt)
    if self.attention_type == "RelPosMHAXL":
        # we use fixed positional encodings in the decoder
        tgt = tgt + self.positional_encoding_decoder(tgt)
        encoder_out = encoder_out + self.positional_encoding_decoder(
            encoder_out
        )
        # pos_embs_target = self.positional_encoding(tgt)
        pos_embs_encoder = None  # self.positional_encoding(src)
        pos_embs_target = None
    elif self.positional_encoding_type == "fixed_abs_sine":
        tgt = tgt + self.positional_encoding(tgt)  # add the encodings here
        pos_embs_target = None
        pos_embs_encoder = None

    prediction, self_attns, multihead_attns = self.decoder(
        tgt,
        encoder_out,
        tgt_mask=tgt_mask,
        pos_embs_tgt=pos_embs_target,
        pos_embs_src=pos_embs_encoder,
    )
    return prediction, multihead_attns[-1]
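
# A minimal greedy-decoding sketch built on top of `decode` above: the full
# target prefix is re-decoded at every step and only the last position is used
# to pick the next token. This is an illustration, not part of this module;
# `model`, `output_proj` (a projection from decoder states to vocabulary
# scores), `bos_index`, and `max_decode_len` are assumed/hypothetical names.
import torch


def greedy_decode_sketch(model, encoder_out, bos_index, output_proj, max_decode_len=10):
    batch_size = encoder_out.shape[0]
    # Start every hypothesis with the <bos> token.
    tgt = torch.full(
        (batch_size, 1), bos_index, dtype=torch.long, device=encoder_out.device
    )
    for _ in range(max_decode_len):
        # `decode` re-runs the decoder over the whole prefix at each step.
        prediction, _ = model.decode(tgt, encoder_out)
        # Project the last decoder state to vocabulary scores and take argmax.
        next_token = output_proj(prediction[:, -1]).argmax(dim=-1, keepdim=True)
        tgt = torch.cat([tgt, next_token], dim=1)
    return tgt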
def make_masks(
    self, src, pad_idx=0, look_ahead_mask=True, padding_mask=True
):
    """Generates the causal (look-ahead) and key padding masks for `src`."""
    src_mask = None
    if look_ahead_mask:
        # Mask out future positions so each step only attends to the past.
        src_mask = get_lookahead_mask(src)

    src_key_padding_mask = None
    if padding_mask:
        # Mask out positions that hold the <pad> token.
        src_key_padding_mask = get_key_padding_mask(src, pad_idx)

    return src_mask, src_key_padding_mask
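
# A torch-only sketch of the two masks returned above. The shapes/semantics
# follow the usual PyTorch transformer convention and are an assumption about
# what `get_lookahead_mask` / `get_key_padding_mask` produce: a (T, T) float
# mask with -inf above the diagonal, and a (B, T) bool mask that is True at
# padded positions. The token values are illustrative.
import torch

src = torch.tensor([[5, 7, 9, 0, 0]])  # one sequence of token ids, pad_idx = 0
seq_len = src.shape[1]
look_ahead = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
key_padding = src.eq(0)  # True where the token is <pad>
# `look_ahead` blocks attention to future positions; `key_padding` hides padding.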
def decode(self, tgt, encoder_out):
    """This method implements a decoding step for the transformer model.

    Arguments
    ---------
    tgt : torch.Tensor
        The sequence to the decoder (required).
    encoder_out : torch.Tensor
        Hidden output of the encoder (required).
    """
    tgt_mask = get_lookahead_mask(tgt)
    tgt = self.custom_tgt_module(tgt)
    tgt = tgt + self.positional_encoding(tgt)
    prediction, self_attns, multihead_attns = self.decoder(
        tgt, encoder_out, tgt_mask=tgt_mask
    )
    return prediction, multihead_attns[-1]
def forward(self, x, src_key_padding_mask=None):
    """Processes the input through the (optionally causal) encoder."""
    if self.causal:
        # Rebuild the look-ahead mask for the current sequence length.
        self.attn_mask = get_lookahead_mask(x)
    else:
        self.attn_mask = None

    if self.custom_emb_module is not None:
        x = self.custom_emb_module(x)

    encoder_output, _ = self.encoder(
        src=x,
        src_mask=self.attn_mask,
        src_key_padding_mask=src_key_padding_mask,
    )
    output = self.output_layer(encoder_output)
    output = self.output_activation(output)
    return output
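
# A hedged usage sketch for the forward pass above: feed a batch of token ids
# plus an optional key padding mask and read back per-position scores after
# `output_layer` / `output_activation`. Model construction is omitted; `lm`
# and `pad_idx` are illustrative names only.
def score_tokens_sketch(lm, tokens, pad_idx=0):
    # Hide padded positions from self-attention; the causal (look-ahead) mask
    # is rebuilt inside `forward` whenever `lm.causal` is True.
    padding_mask = tokens.eq(pad_idx)
    return lm(tokens, src_key_padding_mask=padding_mask)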
def make_masks(self, src, tgt, wav_len=None, pad_idx=0):
    """This method generates the masks for training the transformer model.

    Arguments
    ---------
    src : torch.Tensor
        The sequence to the encoder (required).
    tgt : torch.Tensor
        The sequence to the decoder (required).
    wav_len : torch.Tensor
        The relative lengths of the source signals in the batch (optional).
    pad_idx : int
        The index for <pad> token (default=0).
    """
    src_key_padding_mask = None
    if wav_len is not None and self.training:
        abs_len = torch.round(wav_len * src.shape[1])
        src_key_padding_mask = (1 - length_to_mask(abs_len)).bool()
    tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)

    src_mask = None
    tgt_mask = get_lookahead_mask(tgt)
    return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask
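
# A torch-only sketch of the `wav_len` branch above: relative lengths are
# scaled to absolute frame counts and turned into a bool padding mask that is
# True at padded frames (mirroring the `(1 - length_to_mask(...)).bool()`
# inversion). The numbers are illustrative.
import torch

wav_len = torch.tensor([1.0, 0.5])  # relative length per utterance
n_frames = 4                        # src.shape[1] in the method above
abs_len = torch.round(wav_len * n_frames)                 # -> [4., 2.]
valid = torch.arange(n_frames).unsqueeze(0) < abs_len.unsqueeze(1)
src_key_padding_mask = ~valid       # True at padded frames
# tensor([[False, False, False, False],
#         [False, False,  True,  True]])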
def make_masks_for_mt(self, src, tgt, pad_idx=0):
    """This method generates the masks for training the transformer model.

    Arguments
    ---------
    src : torch.Tensor
        The sequence to the encoder (required).
    tgt : torch.Tensor
        The sequence to the decoder (required).
    pad_idx : int
        The index for <pad> token (default=0).
    """
    src_key_padding_mask = None
    if self.training:
        src_key_padding_mask = get_key_padding_mask(src, pad_idx=pad_idx)
    tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)

    src_mask = None
    tgt_mask = get_lookahead_mask(tgt)
    return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask