def forward(self, encoded_attentioned, target):
    """
    Args:
        encoded_attentioned: N x T x H (acoustic embeddings aligned with the targets)
        target: N x T
    Returns:
        logits: N x T x V
    """
    # Prepare masks
    ys_in = self.preprocess(target)
    non_pad_mask = (target > 0).unsqueeze(-1)
    slf_attn_mask_subseq = get_subsequent_mask(ys_in)
    slf_attn_mask_keypad = get_attn_key_pad_mask(
        seq_k=ys_in, seq_q=ys_in, pad_idx=0)
    slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

    # Fuse the acoustic embeddings with the target embeddings position by position
    ys_in_emb = self.dropout(self.tgt_word_emb(ys_in) + self.positional_encoding(ys_in))
    dec_output = self.input_affine(torch.cat([encoded_attentioned, ys_in_emb], -1))

    for dec_layer in self.layer_stack:
        dec_output = dec_layer(
            dec_output,
            non_pad_mask=non_pad_mask,
            slf_attn_mask=slf_attn_mask)

    dec_output = torch.cat([encoded_attentioned, dec_output], -1)
    logits = self.tgt_word_prj(dec_output)

    return logits
def forward(self, padded_input, encoder_padded_outputs, encoder_input_lengths, return_attns=False):
    """
    Args:
        padded_input: N x To
        encoder_padded_outputs: N x Ti x H
        encoder_input_lengths: N
    Returns:
        pred: N x To x V (logits before softmax)
        gold: N x To
    """
    dec_slf_attn_list, dec_enc_attn_list = [], []

    # Get Decoder Input and Output
    ys_in_pad, ys_out_pad = self.preprocess(padded_input)

    # Prepare masks
    non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)
    slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)
    slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
                                                 seq_q=ys_in_pad,
                                                 pad_idx=self.eos_id)
    slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
    output_length = ys_in_pad.size(1)
    dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                          encoder_input_lengths,
                                          output_length)

    # Forward
    dec_output = self.dropout(
        self.tgt_word_emb(ys_in_pad) * self.x_logit_scale +
        self.positional_encoding(ys_in_pad))

    for dec_layer in self.layer_stack:
        dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
            dec_output, encoder_padded_outputs,
            non_pad_mask=non_pad_mask,
            slf_attn_mask=slf_attn_mask,
            dec_enc_attn_mask=dec_enc_attn_mask)
        if return_attns:
            dec_slf_attn_list += [dec_slf_attn]
            dec_enc_attn_list += [dec_enc_attn]

    # before softmax
    seq_logit = self.tgt_word_prj(dec_output)

    # Return
    pred, gold = seq_logit, ys_out_pad
    if return_attns:
        return pred, gold, dec_slf_attn_list, dec_enc_attn_list
    return pred, gold
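# A minimal sketch of how the (pred, gold) pair returned by the forward above
# might feed a cross-entropy loss. `decoder`, `pad_idx`, and this helper are
# illustrative assumptions, not part of the original code; in this decoder the
# padding index appears to be self.eos_id, but adapt pad_idx to the actual setup.
def _example_decoder_loss(decoder, padded_input, encoder_padded_outputs,
                          encoder_input_lengths, pad_idx):
    import torch.nn.functional as F
    pred, gold = decoder(padded_input, encoder_padded_outputs, encoder_input_lengths)
    # pred: N x To x V, gold: N x To; flatten both and skip padded positions
    loss = F.cross_entropy(pred.reshape(-1, pred.size(-1)),
                           gold.reshape(-1),
                           ignore_index=pad_idx)
    return loss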
def step_forward(self, ys, encoded_attentioned, t):
    # -- Prepare masks
    non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1)  # 1 x i x 1
    slf_attn_mask = get_subsequent_mask(ys)

    # -- Forward
    target_emb = self.tgt_word_emb(ys) + self.positional_encoding(ys)
    dec_output = self.input_affine(
        torch.cat([encoded_attentioned[:, :t+1, :], target_emb], -1))

    for dec_layer in self.layer_stack:
        dec_output = dec_layer(
            dec_output,
            non_pad_mask=non_pad_mask,
            slf_attn_mask=slf_attn_mask)

    dec_output = torch.cat([encoded_attentioned[:, :t+1, :], dec_output], -1)
    seq_logit = self.tgt_word_prj(dec_output[:, -1])
    local_scores = F.log_softmax(seq_logit, dim=1)

    return local_scores
def step(self, prefixs, encoded, len_encoded):
    # Prepare masks
    non_pad_mask = torch.ones_like(prefixs).float().unsqueeze(-1)  # N x i x 1
    slf_attn_mask = get_subsequent_mask(prefixs)
    output_length = prefixs.size(1)
    dec_enc_attn_mask = get_attn_pad_mask(len_encoded, output_length)

    # Forward
    dec_output = self.tgt_word_emb(prefixs) + self.positional_encoding(prefixs)

    for dec_layer in self.layer_stack:
        dec_output = dec_layer(
            dec_output, encoded,
            non_pad_mask=non_pad_mask,
            slf_attn_mask=slf_attn_mask,
            dec_enc_attn_mask=dec_enc_attn_mask)

    # before softmax
    logits = self.tgt_word_prj(dec_output[:, -1, :])
    scores = F.log_softmax(logits, -1)  # [batch*beam, size_output]

    return scores
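# A minimal greedy-decoding sketch built on `step` above. It assumes the decoder
# exposes `sos_id`/`eos_id` and that `encoded` / `len_encoded` come from the
# encoder; this helper is a hypothetical illustration, not part of the original code.
def _example_greedy_decode(decoder, encoded, len_encoded, maxlen=100):
    import torch
    N = encoded.size(0)
    prefixs = torch.full((N, 1), decoder.sos_id, dtype=torch.long, device=encoded.device)
    for _ in range(maxlen):
        scores = decoder.step(prefixs, encoded, len_encoded)  # N x V log-probs for the next token
        next_token = scores.argmax(-1, keepdim=True)          # N x 1
        prefixs = torch.cat([prefixs, next_token], dim=1)
        # stop once every sequence in the batch has emitted eos
        if (next_token == decoder.eos_id).all():
            break
    return prefixs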
def forward(self, targets, encoder_padded_outputs, encoder_input_lengths):
    """
    Args:
        targets: N x To
        encoder_padded_outputs: N x Ti x H
        encoder_input_lengths: N
    Returns:
        logits: N x To x V
        targets_eos: N x To
    """
    # Get Decoder Input and Output
    targets_sos, targets_eos = self.preprocess(targets)

    # Prepare masks
    non_pad_mask = (targets_sos > 0).unsqueeze(-1)
    slf_attn_mask_subseq = get_subsequent_mask(targets_sos)
    slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=targets_sos,
                                                 seq_q=targets_sos,
                                                 pad_idx=0)
    slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
    output_length = targets_sos.size(1)
    dec_enc_attn_mask = get_attn_pad_mask(encoder_input_lengths, output_length)

    # Forward
    dec_output = self.dropout(self.tgt_word_emb(targets_sos) +
                              self.positional_encoding(targets_sos))

    for dec_layer in self.layer_stack:
        dec_output = dec_layer(
            dec_output, encoder_padded_outputs,
            non_pad_mask=non_pad_mask,
            slf_attn_mask=slf_attn_mask,
            dec_enc_attn_mask=dec_enc_attn_mask)

    # before softmax
    logits = self.tgt_word_prj(dec_output)

    return logits, targets_eos
def beam_decode(self, encoded, beam=5, nbest=1, maxlen=100):
    """Beam search; decodes one utterance at a time.
    Args:
        encoded: encoder output for a single utterance (1 x Ti x H)
        beam: beam width
        nbest: number of hypotheses to return
        maxlen: maximum decoding length
    Returns:
        nbest_hyps: list of dicts with 'score' and 'yseq' (token id list)
    """
    # prepare sos
    ys = torch.ones(1, 1).fill_(self.sos_id).type_as(encoded).long()  # yseq: 1 x T

    hyp = {'score': 0.0, 'yseq': ys}
    hyps = [hyp]
    ended_hyps = []

    for i in range(maxlen):
        hyps_best_kept = []
        for hyp in hyps:
            ys = hyp['yseq']  # 1 x i

            # -- Prepare masks
            non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1)  # 1 x i x 1
            slf_attn_mask = get_subsequent_mask(ys)

            # -- Forward
            dec_output = self.dropout(
                self.tgt_word_emb(ys) + self.positional_encoding(ys))

            for dec_layer in self.layer_stack:
                dec_output, *_ = dec_layer(
                    dec_output, encoded,
                    non_pad_mask=non_pad_mask,
                    slf_attn_mask=slf_attn_mask,
                    dec_enc_attn_mask=None)

            seq_logit = self.tgt_word_prj(dec_output[:, -1])
            local_scores = F.log_softmax(seq_logit, dim=1)

            # topk scores
            local_best_scores, local_best_ids = torch.topk(
                local_scores, beam, dim=1)

            for j in range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                new_hyp['yseq'] = torch.ones(1, (1 + ys.size(1))).type_as(encoded).long()
                new_hyp['yseq'][:, :ys.size(1)] = hyp['yseq']
                new_hyp['yseq'][:, ys.size(1)] = int(local_best_ids[0, j])
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(hyps_best_kept,
                                    key=lambda x: x['score'],
                                    reverse=True)[:beam]
        # end for hyp in hyps
        hyps = hyps_best_kept

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            for hyp in hyps:
                hyp['yseq'] = torch.cat(
                    [hyp['yseq'],
                     torch.ones(1, 1).fill_(self.eos_id).type_as(encoded).long()],
                    dim=1)

        # add ended hypotheses to a final list, and remove them from the current
        # hypotheses (this can be a problem: the number of hyps may drop below beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][0, -1] == self.eos_id:
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        hyps = remained_hyps
        if len(hyps) > 0:
            print('remained hypotheses: ' + str(len(hyps)))
        else:
            print('no hypothesis. Finish decoding.')
            break
    # end for i in range(maxlen)

    nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[
        :min(len(ended_hyps), nbest)]

    # compatible with LAS implementation
    for hyp in nbest_hyps:
        hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist()

    return nbest_hyps
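# A minimal usage sketch for `beam_decode`. It assumes `encoded` is the encoder
# output for one utterance (leading batch dimension of 1) and that a `char_list`
# mapping token ids to symbols exists; both names and this helper are illustrative.
def _example_beam_decode_usage(decoder, encoded, char_list):
    nbest_hyps = decoder.beam_decode(encoded, beam=5, nbest=1, maxlen=100)
    best = nbest_hyps[0]
    # 'yseq' is a plain list of token ids including sos/eos; strip them and map to symbols
    text = ''.join(char_list[idx] for idx in best['yseq'][1:-1])
    return best['score'], text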