def forward(self, padded_input, encoder_padded_outputs, encoder_input_lengths):
    """
    args:
        padded_input: B x T
        encoder_padded_outputs: B x T x H
        encoder_input_lengths: B
    returns:
        pred: B x T x vocab
        gold: B x T
    """
    decoder_self_attn_list, decoder_encoder_attn_list = [], []
    seq_in_pad, seq_out_pad = self.preprocess(padded_input)

    # Prepare masks; EOS_TOKEN doubles as the padding index on the target side
    non_pad_mask = get_non_pad_mask(seq_in_pad, pad_idx=constant.EOS_TOKEN)
    self_attn_mask_subseq = get_subsequent_mask(seq_in_pad)
    self_attn_mask_keypad = get_attn_key_pad_mask(
        seq_k=seq_in_pad, seq_q=seq_in_pad, pad_idx=constant.EOS_TOKEN)
    # A position is masked if it is padding or lies in the future (causal mask)
    self_attn_mask = (self_attn_mask_keypad + self_attn_mask_subseq).gt(0)

    output_length = seq_in_pad.size(1)
    dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                          encoder_input_lengths,
                                          output_length)

    # Embed the (scaled) targets and add positional encodings
    decoder_output = self.dropout(
        self.trg_embedding(seq_in_pad) * self.x_logit_scale
        + self.positional_encoding(seq_in_pad))

    for layer in self.layers:
        decoder_output, decoder_self_attn, decoder_enc_attn = layer(
            decoder_output, encoder_padded_outputs,
            non_pad_mask=non_pad_mask,
            self_attn_mask=self_attn_mask,
            dec_enc_attn_mask=dec_enc_attn_mask)

        decoder_self_attn_list += [decoder_self_attn]
        decoder_encoder_attn_list += [decoder_enc_attn]

    seq_logit = self.output_linear(decoder_output)
    pred, gold = seq_logit, seq_out_pad

    return pred, gold, decoder_self_attn_list, decoder_encoder_attn_list
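# A minimal training-step sketch, not the repo's actual training loop
# (assumptions: `decoder` is an instance of this class, `encoder_out` and
# `input_lengths` come from the encoder, and the target-side padding label to
# ignore is EOS_TOKEN, as used by the masks above):
#
#   pred, gold, _, _ = decoder(padded_transcripts, encoder_out, input_lengths)
#   loss = F.cross_entropy(pred.view(-1, pred.size(-1)), gold.view(-1),
#                          ignore_index=constant.EOS_TOKEN)
#   loss.backward()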
def beam_search(self, encoder_padded_outputs, beam_width=2, nbest=5, lm_rescoring=False, lm=None, lm_weight=0.1, c_weight=1, prob_weight=1.0):
    """
    Beam search, decode nbest utterances
    args:
        encoder_padded_outputs: B x T x H
        beam_width: int
        nbest: int
    output:
        batch_ids_nbest_hyps: list of nbest in ids (size B)
        batch_strs_nbest_hyps: list of nbest in strings (size B)
    """
    batch_size = encoder_padded_outputs.size(0)
    max_len = encoder_padded_outputs.size(1)

    batch_ids_nbest_hyps = []
    batch_strs_nbest_hyps = []

    for x in range(batch_size):
        encoder_output = encoder_padded_outputs[x].unsqueeze(0)  # 1 x T x H

        # Start the hypothesis with SOS_TOKEN
        ys = torch.ones(1, 1).fill_(
            constant.SOS_TOKEN).type_as(encoder_output).long()

        hyp = {'score': 0.0, 'yseq': ys}
        hyps = [hyp]
        ended_hyps = []

        for i in range(300):
        # for i in range(self.trg_max_length):
            hyps_best_kept = []
            for hyp in hyps:
                ys = hyp['yseq']  # 1 x i

                # Prepare masks
                non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1)  # 1 x i x 1
                self_attn_mask = get_subsequent_mask(ys)

                decoder_output = self.dropout(
                    self.trg_embedding(ys) * self.x_logit_scale
                    + self.positional_encoding(ys))

                for layer in self.layers:
                    decoder_output, _, _ = layer(
                        decoder_output, encoder_output,
                        non_pad_mask=non_pad_mask,
                        self_attn_mask=self_attn_mask,
                        dec_enc_attn_mask=None)

                # Score only the last position of the partial hypothesis
                seq_logit = self.output_linear(decoder_output[:, -1])
                local_scores = F.log_softmax(seq_logit, dim=1)
                local_best_scores, local_best_ids = torch.topk(
                    local_scores, beam_width, dim=1)

                # Expand the hypothesis with the beam_width best next tokens
                for j in range(beam_width):
                    new_hyp = {}
                    new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                    new_hyp["yseq"] = torch.ones(
                        1, (1 + ys.size(1))).type_as(encoder_output).long()
                    new_hyp["yseq"][:, :ys.size(1)] = hyp["yseq"].cpu()
                    new_hyp["yseq"][:, ys.size(1)] = int(
                        local_best_ids[0, j])  # append the new token
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x["score"],
                                        reverse=True)[:beam_width]

            hyps = hyps_best_kept

            # Force EOS_TOKEN on hypotheses that reached the maximum length
            if i == max_len - 1:
                for hyp in hyps:
                    hyp["yseq"] = torch.cat([
                        hyp["yseq"],
                        torch.ones(1, 1).fill_(constant.EOS_TOKEN).type_as(
                            encoder_output).long()
                    ], dim=1)

            # Move hypotheses that end with EOS_TOKEN to ended_hyps
            unended_hyps = []
            for hyp in hyps:
                if hyp["yseq"][0, -1] == constant.EOS_TOKEN:
                    if lm_rescoring:
                        hyp["lm_score"], hyp["num_words"], oov_token = calculate_lm_score(
                            hyp["yseq"], lm, self.id2label)
                        num_words = hyp["num_words"]
                        hyp["lm_score"] -= oov_token * 2  # penalize OOV tokens
                        hyp["final_score"] = hyp["score"] + lm_weight * hyp["lm_score"] \
                            + math.sqrt(num_words) * c_weight
                    else:
                        seq_str = "".join(
                            self.id2label[char.item()]
                            for char in hyp["yseq"][0]).replace(
                                constant.PAD_CHAR, "").replace(
                                    constant.SOS_CHAR, "").replace(
                                        constant.EOS_CHAR, "")
                        seq_str = seq_str.replace("  ", " ")  # collapse double spaces
                        num_words = len(seq_str.split())
                        hyp["final_score"] = hyp["score"] + math.sqrt(num_words) * c_weight
                    ended_hyps.append(hyp)
                else:
                    unended_hyps.append(hyp)
            hyps = unended_hyps

            if len(hyps) == 0:
                # decoding process is finished
                break

        num_nbest = min(len(ended_hyps), nbest)
        nbest_hyps = sorted(ended_hyps,
                            key=lambda x: x["final_score"],
                            reverse=True)[:num_nbest]
        a_nbest_hyps = sorted(ended_hyps,
                              key=lambda x: x["final_score"],
                              reverse=True)[:beam_width]

        if lm_rescoring:
            # (debug) inspect the LM-rescored hypotheses
            for hyp in a_nbest_hyps:
                seq_str = "".join(
                    self.id2label[char.item()]
                    for char in hyp["yseq"][0]).replace(
                        constant.PAD_CHAR, "").replace(
                            constant.SOS_CHAR, "").replace(
                                constant.EOS_CHAR, "")
                seq_str = seq_str.replace("  ", " ")
                num_words = len(seq_str.split())
                # print("{} || final:{} e2e:{} lm:{} num words:{}".format(seq_str, hyp["final_score"], hyp["score"], hyp["lm_score"], hyp["num_words"]))

        for hyp in nbest_hyps:
            hyp["yseq"] = hyp["yseq"][0].cpu().numpy().tolist()
            hyp_strs = self.post_process_hyp(hyp)

            batch_ids_nbest_hyps.append(hyp["yseq"])
            batch_strs_nbest_hyps.append(hyp_strs)

    return batch_ids_nbest_hyps, batch_strs_nbest_hyps
def greedy_search(self, encoder_padded_outputs, beam_width=2, lm_rescoring=False, lm=None, lm_weight=0.1, c_weight=1):
    """
    Greedy search, decode 1-best utterance
    args:
        encoder_padded_outputs: B x T x H
    output:
        sent: list of 1-best decoded strings (size B)
    """
    max_seq_len = self.trg_max_length

    # Start every utterance with SOS_TOKEN
    ys = torch.ones(encoder_padded_outputs.size(0),
                    1).fill_(constant.SOS_TOKEN).long()  # batch_size x 1
    if constant.args.cuda:
        ys = ys.cuda()

    decoded_words = []
    for t in range(300):
    # for t in range(max_seq_len):
        # Prepare masks
        non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1)  # batch_size x t x 1
        self_attn_mask = get_subsequent_mask(ys)  # batch_size x t x t

        decoder_output = self.dropout(
            self.trg_embedding(ys) * self.x_logit_scale
            + self.positional_encoding(ys))

        for layer in self.layers:
            decoder_output, _, _ = layer(decoder_output,
                                         encoder_padded_outputs,
                                         non_pad_mask=non_pad_mask,
                                         self_attn_mask=self_attn_mask,
                                         dec_enc_attn_mask=None)

        prob = self.output_linear(decoder_output)  # batch_size x t x label_size

        if lm_rescoring:
            # Rescore the beam_width best next tokens of the last step with the LM.
            # Note: this branch assumes batch_size == 1 (indexing [0, j]).
            local_scores = F.log_softmax(prob[:, -1], dim=1)
            local_best_scores, local_best_ids = torch.topk(local_scores,
                                                           beam_width,
                                                           dim=1)
            best_score = -float('inf')
            best_word = None

            for j in range(beam_width):
                cur_seq = " ".join(word for word in decoded_words)
                lm_score, num_words, oov_token = calculate_lm_score(
                    cur_seq, lm, self.id2label)
                score = local_best_scores[0, j] + lm_score
                if best_score < score:
                    best_score = score
                    best_word = local_best_ids[0, j]

            next_word = best_word.view(1, 1)
            decoded_words.append(self.id2label[int(best_word)])
        else:
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append([
                constant.EOS_CHAR if ni.item() == constant.EOS_TOKEN
                else self.id2label[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.unsqueeze(-1)

        if constant.args.cuda:
            ys = torch.cat([ys, next_word.cuda()], dim=1)
            ys = ys.cuda()
        else:
            ys = torch.cat([ys, next_word], dim=1)

    # Concatenate each utterance's per-step characters up to EOS_CHAR
    sent = []
    for _, row in enumerate(np.transpose(decoded_words)):
        st = ''
        for e in row:
            if e == constant.EOS_CHAR:
                break
            else:
                st += e
        sent.append(st)
    return sent
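# Greedy decoding sketch under the same assumptions as the beam-search example
# above; greedy_search returns one decoded string per utterance in the batch:
#
#   with torch.no_grad():
#       enc_out = model.encoder(spectrogram, feat_lengths)
#       hyps = model.decoder.greedy_search(enc_out)
#   for hyp in hyps:
#       print(hyp)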