def forward(self, padded_input, input_lengths, return_attns=False):
    """Run the padded batch through the input projection and every layer.

    Args:
        padded_input: N x T x D
        input_lengths: N
        return_attns: when true, the self-attention produced by each layer
            is collected and returned as well.

    Returns:
        ``(enc_output,)`` -- a one-element tuple, enc_output: N x T x H --
        or ``(enc_output, enc_slf_attn_list)`` when ``return_attns`` is set.
    """
    # Masks shared by every layer of the stack.
    max_len = padded_input.size(1)
    non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
    slf_attn_mask = get_attn_pad_mask(padded_input, input_lengths, max_len)

    # Project + normalise the input, add positional information, apply dropout.
    projected = self.layer_norm_in(self.linear_in(padded_input))
    enc_output = self.dropout(
        projected + self.positional_encoding(padded_input))

    attn_per_layer = []
    for layer in self.layer_stack:
        enc_output, layer_attn = layer(
            enc_output,
            non_pad_mask=non_pad_mask,
            slf_attn_mask=slf_attn_mask)
        if return_attns:
            attn_per_layer.append(layer_attn)

    if not return_attns:
        # Callers receive a 1-tuple here, not a bare tensor.
        return (enc_output,)
    return enc_output, attn_per_layer
def recognize_pit(self, padded_input, encoder_padded_outputs,
                  encoder_input_lengths, return_attns=False):
    """Run the decoder stack over the targets and return logits and states.

    Args:
        padded_input: N x To
        encoder_padded_outputs: N x Ti x H
        encoder_input_lengths: lengths passed to ``get_attn_pad_mask`` to
            mask padded encoder positions.
        return_attns: when true, attention maps are returned as well.

    Returns:
        ``(pred, gold, dec_output, dec_output_input)`` where ``pred`` is the
        pre-softmax projection of the final decoder output, ``gold`` is the
        target sequence from ``self.preprocess``, ``dec_output`` is the
        output of the last decoder layer and ``dec_output_input`` is the
        embedded (scaled + positional + dropout) decoder input.

        When ``return_attns`` is true two more items are appended: the list
        of per-layer decoder self-attentions, and the last layer's
        decoder-encoder attention viewed as
        ``(n_head, -1, <its last two dimensions>)``.
    """
    dec_slf_attn_list, dec_enc_attn_list = [], []

    # Get Decoder Input and Output
    ys_in_pad, ys_out_pad = self.preprocess(padded_input)

    # Prepare masks
    non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)
    slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)
    slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
                                                 seq_q=ys_in_pad,
                                                 pad_idx=self.eos_id)
    # A position is masked if either the padding mask or the subsequent
    # mask flags it.
    slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

    output_length = ys_in_pad.size(1)
    dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                          encoder_input_lengths,
                                          output_length)

    # Forward
    dec_output_input = self.dropout(
        self.tgt_word_emb(ys_in_pad) * self.x_logit_scale
        + self.positional_encoding(ys_in_pad))

    # Chain the layers: each layer consumes the previous layer's output.
    # (Previously every layer was fed ``dec_output_input``, so only the last
    # layer contributed to the result and the rest of the stack was wasted.)
    # Seeding with the embedding also keeps ``dec_output`` defined when the
    # stack is empty.
    dec_output = dec_output_input
    for dec_layer in self.layer_stack:
        dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
            dec_output,
            encoder_padded_outputs,
            non_pad_mask=non_pad_mask,
            slf_attn_mask=slf_attn_mask,
            dec_enc_attn_mask=dec_enc_attn_mask)

        # Only keep attention maps when the caller asked for them.
        if return_attns:
            dec_slf_attn_list.append(dec_slf_attn)
            dec_enc_attn_list.append(dec_enc_attn)

    # before softmax
    seq_logit = self.tgt_word_prj(dec_output)

    # Return
    pred, gold = seq_logit, ys_out_pad

    if return_attns:
        last_dec_enc_attn = dec_enc_attn_list[-1]
        # Use the tensor's own trailing dimensions instead of the former
        # hard-coded (3, 751), which only worked for one specific pair of
        # lengths. NOTE(review): assumes the layer returns attention with
        # (query_len, key_len) as its last two dims -- confirm.
        last_dec_enc_attn = last_dec_enc_attn.view(
            self.n_head, -1, *last_dec_enc_attn.shape[-2:])
        return (pred, gold, dec_output, dec_output_input,
                dec_slf_attn_list, last_dec_enc_attn)

    return pred, gold, dec_output, dec_output_input