def recognize_pit(self, padded_input, encoder_padded_outputs,
                  encoder_input_lengths, return_attns=False):
    """
    Args:
        padded_input: N x To
        encoder_padded_outputs: N x Ti x H
    Returns:
        pred, gold, the decoder hidden states and the decoder input
        embeddings (plus the attention lists when return_attns is True).
    """
    dec_slf_attn_list, dec_enc_attn_list = [], []

    # Get decoder input and output
    ys_in_pad, ys_out_pad = self.preprocess(padded_input)

    # Prepare masks
    non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)
    slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)
    slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
                                                 seq_q=ys_in_pad,
                                                 pad_idx=self.eos_id)
    slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
    output_length = ys_in_pad.size(1)
    dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                          encoder_input_lengths,
                                          output_length)

    # Forward
    dec_output_input = self.dropout(
        self.tgt_word_emb(ys_in_pad) * self.x_logit_scale +
        self.positional_encoding(ys_in_pad))

    # Chain the decoder layers: each layer consumes the previous layer's
    # output, starting from the embedded input.
    dec_output = dec_output_input
    for dec_layer in self.layer_stack:
        dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
            dec_output, encoder_padded_outputs,
            non_pad_mask=non_pad_mask,
            slf_attn_mask=slf_attn_mask,
            dec_enc_attn_mask=dec_enc_attn_mask)

        # Attention maps are always collected (needed for the return below).
        dec_slf_attn_list += [dec_slf_attn]
        dec_enc_attn_list += [dec_enc_attn]

    # Before softmax; dec_output: (bs, dec_len, 512)
    seq_logit = self.tgt_word_prj(dec_output)

    # Return
    pred, gold = seq_logit, ys_out_pad
    if return_attns:
        # The last layer's encoder-decoder attention is reshaped with
        # hard-coded target/source lengths (3 and 751).
        return (pred, gold, dec_output, dec_output_input,
                dec_slf_attn_list,
                dec_enc_attn_list[-1].view(self.n_head, -1, 3, 751))
    return pred, gold, dec_output, dec_output_input
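# Illustrative sketch (not part of the original decoder): how the combined
# self-attention mask used in recognize_pit behaves. The hypothetical helper
# below rebuilds the causal (subsequent) mask and the key-padding mask with
# plain torch and takes their union, mirroring
# `(slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)` above.
def _combined_slf_attn_mask_demo(ys_in_pad, pad_idx):
    # ys_in_pad: (N, T) LongTensor of target ids padded with pad_idx.
    n, t = ys_in_pad.size()
    # Causal part: query position q must not attend to key positions k > q.
    subsequent = torch.triu(torch.ones(t, t, dtype=torch.bool,
                                       device=ys_in_pad.device), diagonal=1)
    subsequent = subsequent.unsqueeze(0).expand(n, -1, -1)          # N x T x T
    # Padding part: no query may attend to a padded key position.
    keypad = ys_in_pad.eq(pad_idx).unsqueeze(1).expand(-1, t, -1)   # N x T x T
    # Union of both constraints, as in recognize_pit.
    return subsequent | keypad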
def sample_one(self, input, soft_score, tmp_hiddens, contexts, mask):
    # soft_score and tmp_hiddens are not used in this sampler.
    input = input.unsqueeze(0)
    non_pad_mask = torch.ones_like(input).float().unsqueeze(-1)  # 1 x i x 1
    slf_attn_mask = get_subsequent_mask(input)

    dec_output = self.dropout(
        self.tgt_word_emb(input) * self.x_logit_scale +
        self.positional_encoding(input))

    for dec_layer in self.layer_stack:
        dec_output, _, _ = dec_layer(dec_output, contexts,
                                     non_pad_mask=non_pad_mask,
                                     slf_attn_mask=slf_attn_mask,
                                     dec_enc_attn_mask=None)

    output = self.compute_score(dec_output, targets=None)

    # Forbid the ids listed in `mask` (already-selected tokens) by driving
    # their scores to a very negative value.
    if mask is not None:
        output = output.scatter_(1, mask, -9999999999)

    # return output, state, attn_weights, hidden, emb
    return output, dec_output, dec_output, dec_output, dec_output
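# Illustrative sketch (standalone, hypothetical helper): what the in-place
# `scatter_` call in sample_one does, i.e. push the scores of already-selected
# token ids to a very negative value so a later argmax/softmax cannot pick
# them again.
def _forbid_selected_ids_demo():
    scores = torch.zeros(2, 5)                  # (batch, vocab) dummy scores
    selected = torch.tensor([[1, 3], [0, 2]])   # ids already emitted per sample
    scores.scatter_(1, selected, -9999999999)   # same pattern as in sample_one
    return scores.argmax(dim=1)                 # argmax now avoids those ids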
def recognize_beam(self, encoder_outputs, char_list, args):
    """Beam search; decodes one utterance at a time.
    Args:
        encoder_outputs: T x H
        char_list: list of characters
        args: args.beam
    Returns:
        nbest_hyps:
    """
    # Search params (hard-coded here instead of read from args)
    beam = 5
    nbest = 1
    # if args.decode_max_len == 0:
    #     maxlen = encoder_outputs.size(0)
    # else:
    #     maxlen = args.decode_max_len
    maxlen = 5
    # encoder_outputs = encoder_outputs.unsqueeze(0)

    # Prepare sos
    ys = torch.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long()

    # yseq: 1 x T
    hyp = {'score': 0.0, 'yseq': ys}
    hyps = [hyp]
    ended_hyps = []

    for i in range(maxlen):
        hyps_best_kept = []
        for hyp in hyps:
            ys = hyp['yseq']  # 1 x i

            # -- Prepare masks
            non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1)  # 1 x i x 1
            slf_attn_mask = get_subsequent_mask(ys)

            # -- Forward
            dec_output_input = self.dropout(
                self.tgt_word_emb(ys) * self.x_logit_scale +
                self.positional_encoding(ys))

            # Chain the decoder layers, starting from the embedded input.
            dec_output = dec_output_input
            for dec_layer in self.layer_stack:
                dec_output, _, _ = dec_layer(
                    dec_output, encoder_outputs,
                    non_pad_mask=non_pad_mask,
                    slf_attn_mask=slf_attn_mask,
                    dec_enc_attn_mask=None)

            seq_logit = self.tgt_word_prj(dec_output[:, -1])
            local_scores = F.log_softmax(seq_logit, dim=1)

            # topk scores
            local_best_scores, local_best_ids = torch.topk(local_scores,
                                                           beam, dim=1)

            for j in range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                new_hyp['yseq'] = torch.ones(
                    1, (1 + ys.size(1))).type_as(encoder_outputs).long()
                new_hyp['yseq'][:, :ys.size(1)] = hyp['yseq']
                new_hyp['yseq'][:, ys.size(1)] = int(local_best_ids[0, j])
                new_hyp['dec_hiddens'] = dec_output
                new_hyp['dec_embs_input'] = dec_output_input
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(hyps_best_kept,
                                    key=lambda x: x['score'],
                                    reverse=True)[:beam]
        # end for hyp in hyps
        hyps = hyps_best_kept

        # Add eos in the final loop so that there is at least one ended hyp
        if i == maxlen - 1:
            for hyp in hyps:
                hyp['yseq'] = torch.cat([
                    hyp['yseq'],
                    torch.ones(1, 1).fill_(self.eos_id).type_as(
                        encoder_outputs).long()
                ], dim=1)

        # Move ended hypotheses to the final list and keep only the active
        # ones (this can leave fewer than `beam` hyps).
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][0, -1] == self.eos_id:
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        hyps = remained_hyps
        if len(hyps) > 0:
            print('remaining hypotheses: ' + str(len(hyps)))
        else:
            print('no hypothesis. Finish decoding.')
            break

        for hyp in hyps:
            print('hypo: ' + ''.join(
                [char_list[int(x)] for x in hyp['yseq'][0, 1:]]),
                hyp['score'])
    # end for i in range(maxlen)

    # Drop hypotheses that predict the same token more than once.
    tmp_list = []
    for hyp in ended_hyps:
        if len(set(list(hyp['yseq'][0].data.cpu().numpy()))) == len(
                hyp['yseq'][0]):
            tmp_list.append(hyp)
        else:
            print('Repeated prediction:', hyp)
    ended_hyps = tmp_list

    nbest_hyps = sorted(ended_hyps,
                        key=lambda x: x['score'],
                        reverse=True)[:min(len(ended_hyps), nbest)]
    # Compatible with the LAS implementation
    for hyp in nbest_hyps:
        hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist()
    return nbest_hyps
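# Illustrative sketch (standalone, not called by the decoder): the hypothesis
# expansion step at the core of recognize_beam, isolated with dummy scores.
# A hypothesis is a dict {'score': float, 'yseq': 1 x L LongTensor}; each step
# appends the top-`beam` next tokens and keeps only the best `beam` candidates.
def _expand_hypothesis_demo(beam=3, vocab=10):
    hyp = {'score': 0.0, 'yseq': torch.tensor([[1]])}   # start from <sos> = 1
    local_scores = F.log_softmax(torch.randn(1, vocab), dim=1)
    local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)
    candidates = []
    for j in range(beam):
        candidates.append({
            'score': hyp['score'] + float(local_best_scores[0, j]),
            'yseq': torch.cat([hyp['yseq'],
                               local_best_ids[0, j].view(1, 1)], dim=1),
        })
    # Keep the `beam` highest-scoring candidates, as recognize_beam does.
    return sorted(candidates, key=lambda x: x['score'], reverse=True)[:beam]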
def recognize_beam_greddy(self, encoder_outputs, char_list, tgt):
    """Greedy decoding (beam search with beam width 1) over a batch.
    Args:
        encoder_outputs: N x T x H
        char_list: list of characters
        tgt: bs, dec_len (4: BOS, SPK1, SPK2, EOS)
    Returns:
        nbest_hyps:
    """
    bs = encoder_outputs.shape[0]
    # Search params
    beam = 1
    nbest = 1
    # if args.decode_max_len == 0:
    #     maxlen = encoder_outputs.size(0)
    # else:
    #     maxlen = args.decode_max_len
    maxlen = 5
    # encoder_outputs = encoder_outputs.unsqueeze(0)

    # Prepare sos
    ys = torch.ones(bs, 1).fill_(self.sos_id).type_as(encoder_outputs).long()

    # yseq: bs x T
    hyp = {'score': 0.0, 'yseq': ys}
    hyps = [hyp]
    ended_hyps = []
    # remained_tgts = tgt[:, 1:-1]

    for i in range(maxlen):
        hyps_best_kept = []
        for hyp in hyps:
            ys = hyp['yseq']  # 1 x i

            # -- Prepare masks
            non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1)  # 1 x i x 1
            slf_attn_mask = get_subsequent_mask(ys)

            # -- Forward
            # Positions ordered by sequence index (0 .. i), used as the
            # decoder input ids instead of the predicted tokens.
            ys_seq = torch.arange(i + 1).unsqueeze(0).cuda()
            dec_output_input = self.dropout(
                self.tgt_word_emb(ys_seq) * self.x_logit_scale +
                self.positional_encoding(ys_seq))

            # Chain the decoder layers, starting from the embedded input.
            dec_output = dec_output_input
            for dec_layer in self.layer_stack:
                dec_output, _, _ = dec_layer(
                    dec_output, encoder_outputs,
                    non_pad_mask=non_pad_mask,
                    slf_attn_mask=slf_attn_mask,
                    dec_enc_attn_mask=None)

            seq_logit = self.tgt_word_prj(dec_output[:, -1])

            # Try to stop repeated predictions: for every sample, forbid the
            # ids that have already been emitted.
            for yy, logit in zip(ys, seq_logit):
                logit[yy] = -999999

            local_scores = F.log_softmax(seq_logit, dim=1)
            # topk scores
            local_best_scores, local_best_ids = torch.topk(local_scores,
                                                           beam, dim=1)
            '''
            best_tgt_idx_list = []
            best_tgt_score_list = []
            for tgt_sample, score_sample in zip(remained_tgts, local_scores):
                best_tgt_score, best_tgt_idx = -10000, 0
                for tgt_sample_idx in tgt_sample:
                    if score_sample[tgt_sample_idx] > best_tgt_score:
                        best_tgt_score = score_sample[tgt_sample_idx]
                        best_tgt_idx = tgt_sample_idx
                best_tgt_idx_list.append(best_tgt_idx)
                best_tgt_score_list.append(best_tgt_score)
            '''
            for j in range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                new_hyp['yseq'] = torch.ones(
                    bs, (1 + ys.size(1))).type_as(encoder_outputs).long()
                new_hyp['yseq'][:, :ys.size(1)] = hyp['yseq']
                new_hyp['yseq'][:, ys.size(1)] = int(local_best_ids[0, j])
                new_hyp['dec_hiddens'] = dec_output
                new_hyp['dec_embs_input'] = dec_output_input
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(hyps_best_kept,
                                    key=lambda x: x['score'],
                                    reverse=True)[:beam]
        # end for hyp in hyps
        hyps = hyps_best_kept

        # Add eos in the final loop so that there is at least one ended hyp
        if i == maxlen - 1:
            for hyp in hyps:
                hyp['yseq'] = torch.cat([
                    hyp['yseq'],
                    torch.ones(bs, 1).fill_(self.eos_id).type_as(
                        encoder_outputs).long()
                ], dim=1)

        # Move ended hypotheses to the final list and keep only the active
        # ones (this can leave fewer than `beam` hyps).
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][0, -1] == self.eos_id:
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        hyps = remained_hyps
        if len(hyps) > 0:
            print('remaining hypotheses: ' + str(len(hyps)))
        else:
            print('no hypothesis. Finish decoding.')
            break

        for hyp in hyps:
            print('hypo: ' + ''.join(
                [char_list[int(x)] for x in hyp['yseq'][0, 1:]]),
                hyp['score'])
    # end for i in range(maxlen)

    # The repeated-hypothesis filter used in recognize_beam is disabled here,
    # so every ended hypothesis is kept.
    nbest_hyps = sorted(ended_hyps,
                        key=lambda x: x['score'],
                        reverse=True)[:min(len(ended_hyps), nbest)]
    # Compatible with the LAS implementation
    for hyp in nbest_hyps:
        hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist()
    return nbest_hyps
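# Illustrative sketch (standalone): the repeated-prediction suppression used
# inside recognize_beam_greddy. For every sample, the ids already emitted in
# `ys` are pushed to a large negative logit before log_softmax, so the greedy
# step cannot output the same speaker id twice.
def _suppress_emitted_ids_demo():
    seq_logit = torch.zeros(2, 6)            # (batch, vocab) dummy logits
    ys = torch.tensor([[1, 4], [2, 5]])      # ids emitted so far, per sample
    for yy, logit in zip(ys, seq_logit):     # same loop as in the method above
        logit[yy] = -999999
    return F.log_softmax(seq_logit, dim=1).argmax(dim=1)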