def forward(self, padded_input, input_lengths):
    """Encode a padded batch of input feature sequences.

    Args:
        padded_input: padded input features, B x T x D.
        input_lengths: number of valid (unpadded) frames per sequence, B.

    Returns:
        A tuple ``(encoder_output, encoder_self_attn_list)`` where
            encoder_output: encoded sequence, B x T x H.
            encoder_self_attn_list: the self-attention tensor returned by
                each layer in ``self.layers``, in layer order.
    """
    max_len = padded_input.size(1)

    # NOTE(review): an earlier inline comment labelled this mask B x T x D;
    # the actual shape is whatever get_non_pad_mask returns — confirm there.
    non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
    # Self-attention padding mask over the full padded length (B x T x T).
    self_attn_mask = get_attn_pad_mask(padded_input, input_lengths, max_len)

    # Project features, layer-normalise, then add the positional encoding.
    projected = self.input_linear(padded_input)
    hidden = self.layer_norm_input(projected) + self.positional_encoding(padded_input)

    attn_per_layer = []
    for enc_layer in self.layers:
        hidden, attn = enc_layer(
            hidden,
            non_pad_mask=non_pad_mask,
            self_attn_mask=self_attn_mask,
        )
        attn_per_layer.append(attn)

    return hidden, attn_per_layer
def forward(self, padded_input, encoder_padded_outputs, encoder_input_lengths):
    """Decode target sequences while attending to the encoder outputs.

    Args:
        padded_input: padded target token ids, B x T.
        encoder_padded_outputs: encoder hidden states, B x T x H.
        encoder_input_lengths: number of valid encoder frames per sequence, B.

    Returns:
        A tuple ``(pred, gold, decoder_self_attn_list, decoder_encoder_attn_list)``:
            pred: logits from ``self.output_linear``, B x T x vocab.
            gold: the output-side sequence produced by ``self.preprocess``, B x T.
            decoder_self_attn_list: per-layer decoder self-attention tensors.
            decoder_encoder_attn_list: per-layer encoder-decoder attention tensors.
    """
    seq_in_pad, seq_out_pad = self.preprocess(padded_input)

    # EOS_TOKEN is passed as the pad index for the decoder-input masks —
    # presumably because preprocess pads the decoder input with EOS; confirm.
    eos = constant.EOS_TOKEN
    non_pad_mask = get_non_pad_mask(seq_in_pad, pad_idx=eos)
    subsequent_mask = get_subsequent_mask(seq_in_pad)
    key_pad_mask = get_attn_key_pad_mask(seq_k=seq_in_pad, seq_q=seq_in_pad, pad_idx=eos)
    # True wherever either mask is set (assumes both masks hold 0/1 values).
    self_attn_mask = (key_pad_mask + subsequent_mask).gt(0)

    # Encoder-side padding mask, sized to the decoder's query length.
    target_len = seq_in_pad.size(1)
    dec_enc_attn_mask = get_attn_pad_mask(
        encoder_padded_outputs, encoder_input_lengths, target_len)

    # Scaled token embedding plus positional encoding, followed by dropout.
    embedded = self.trg_embedding(seq_in_pad) * self.x_logit_scale
    hidden = self.dropout(embedded + self.positional_encoding(seq_in_pad))

    self_attns, enc_attns = [], []
    for dec_layer in self.layers:
        hidden, self_attn, enc_attn = dec_layer(
            hidden,
            encoder_padded_outputs,
            non_pad_mask=non_pad_mask,
            self_attn_mask=self_attn_mask,
            dec_enc_attn_mask=dec_enc_attn_mask,
        )
        self_attns.append(self_attn)
        enc_attns.append(enc_attn)

    return self.output_linear(hidden), seq_out_pad, self_attns, enc_attns