def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
    """Recognize input speech.

    :param ndarray x: input acoustic feature (B, T, D) or (T, D)
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :param bool use_jit: whether to decode with a jit-traced decoder
    :return: N-best decoding results
    :rtype: list
    """
    enc_output = self.encode(x).unsqueeze(0)
    if self.mtlalpha == 1.0:
        recog_args.ctc_weight = 1.0
        logging.info("Set to pure CTC decoding mode.")

    if self.mtlalpha > 0 and recog_args.ctc_weight == 1.0:
        from itertools import groupby

        # greedy CTC decoding: frame-wise argmax, collapse repeats, drop blanks
        lpz = self.ctc.argmax(enc_output)
        collapsed_indices = [x[0] for x in groupby(lpz[0])]
        hyp = [x for x in filter(lambda x: x != self.blank, collapsed_indices)]
        nbest_hyps = [{"score": 0.0, "yseq": hyp}]
        if recog_args.beam_size > 1:
            raise NotImplementedError("Pure CTC beam search is not implemented.")
        # TODO(hirofumi0810): Implement beam search
        return nbest_hyps
    elif self.mtlalpha > 0 and recog_args.ctc_weight > 0.0:
        lpz = self.ctc.log_softmax(enc_output)
        lpz = lpz.squeeze(0)
    else:
        lpz = None

    h = enc_output.squeeze(0)

    logging.info("input lengths: " + str(h.size(0)))
    # search params
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    y = self.sos
    vy = h.new_zeros(1).long()

    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
    minlen = int(recog_args.minlenratio * h.size(0))
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None}
    else:
        hyp = {"score": 0.0, "yseq": [y]}
    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
        hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
        hyp["ctc_score_prev"] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    import six

    traced_decoder = None
    for i in six.moves.range(maxlen):
        logging.debug("position " + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp["yseq"][i]

            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp["yseq"]).unsqueeze(0)
            # FIXME: jit does not match non-jit result
            if use_jit:
                if traced_decoder is None:
                    traced_decoder = torch.jit.trace(
                        self.decoder.forward_one_step, (ys, ys_mask, enc_output)
                    )
                local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
            else:
                local_att_scores = self.decoder.forward_one_step(
                    ys, ys_mask, enc_output
                )[0]

            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                local_scores = (
                    local_att_scores + recog_args.lm_weight * local_lm_scores
                )
            else:
                local_scores = local_att_scores

            if lpz is not None:
                local_best_scores, local_best_ids = torch.topk(
                    local_att_scores, ctc_beam, dim=1
                )
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"]
                )
                local_scores = (1.0 - ctc_weight) * local_att_scores[
                    :, local_best_ids[0]
                ] + ctc_weight * torch.from_numpy(
                    ctc_scores - hyp["ctc_score_prev"]
                )
                if rnnlm:
                    local_scores += (
                        recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    )
                local_best_scores, joint_best_ids = torch.topk(
                    local_scores, beam, dim=1
                )
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = torch.topk(
                    local_scores, beam, dim=1
                )

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j])
                new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp["rnnlm_prev"] = rnnlm_state
                if lpz is not None:
                    new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[0, j]]
                    new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[0, j]]

                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x["score"], reverse=True
            )[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug("number of pruned hypotheses: " + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
            )

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            for hyp in hyps:
                hyp["yseq"].append(self.eos)

        # add ended hypotheses to a final list,
        # and remove them from the current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp["yseq"]) > minlen:
                    hyp["score"] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp["score"] += recog_args.lm_weight * rnnlm.final(
                            hyp["rnnlm_prev"]
                        )
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info("end detected at %d", i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug("remaining hypotheses: " + str(len(hyps)))
        else:
            logging.info("no hypothesis. Finish decoding.")
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                )

        logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

    nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
        : min(len(ended_hyps), recog_args.nbest)
    ]

    # check the number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            "there is no N-best results, perform recognition "
            "again with smaller minlenratio."
        )
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        return self.recognize(x, recog_args, char_list, rnnlm)

    logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
    logging.info(
        "normalized log probability: "
        + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
    )
    return nbest_hyps
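For reference, a minimal sketch of how the decoder above might be driven. The joint score it maximizes per step is (1 - ctc_weight) * attention + ctc_weight * CTC, plus an optional LM term; the `recog_args` fields below mirror the ones read in `recognize`, while `model`, `char_list`, and the feature shape are assumptions for illustration only.

# Hypothetical driver for recognize(); model, char_list, and the feature
# shape are assumed -- only the recog_args fields come from the code above.
from argparse import Namespace

import numpy as np

recog_args = Namespace(
    beam_size=10,     # beam width of the attention decoder
    penalty=0.0,      # per-token insertion penalty added to ended hypotheses
    ctc_weight=0.3,   # lambda in (1 - lambda) * att_score + lambda * ctc_score
    lm_weight=0.1,    # weight on RNNLM scores (unused when rnnlm is None)
    maxlenratio=0.0,  # 0 caps the output length at the encoder length
    minlenratio=0.0,
    nbest=5,
)

feats = np.random.randn(200, 83).astype(np.float32)  # (T, D) acoustic features
nbest_hyps = model.recognize(feats, recog_args, char_list=char_list, rnnlm=None)
print(nbest_hyps[0]["score"], nbest_hyps[0]["yseq"])  # best score and token ids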
def recognize_jca(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
    """Recognize input speech.

    :param ndarray x: input acoustic feature (B, T, D) or (T, D)
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :param bool use_jit: whether to decode with a jit-traced decoder
    :return: N-best decoding results
    :rtype: list
    """
    enc_output = self.encode(x).unsqueeze(0)  # (1, T, D)
    if recog_args.ctc_weight > 0.0:
        lpz = self.ctc.log_softmax(enc_output)
        lpz = lpz.squeeze(0)  # shape of (T, D)
    else:
        lpz = None

    h = enc_output.squeeze(0)  # (T, D)

    logging.info('input lengths: ' + str(h.size(0)))
    # search params
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    y = self.sos
    vy = h.new_zeros(1).long()

    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
    minlen = int(recog_args.minlenratio * h.size(0))
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
    else:
        hyp = {'score': 0.0, 'yseq': [y]}
    if lpz is not None:
        import numpy

        from espnet.nets.ctc_prefix_score import CTCPrefixScore

        ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO

            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            if self.remove_blank_in_ctc_mode:
                ctc_beam = lpz.shape[-1] - 1  # except blank
            else:
                ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    import six

    traced_decoder = None
    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp['yseq'][i]

            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)  # mask future positions
            ys = torch.tensor(hyp['yseq']).unsqueeze(0)
            # FIXME: jit does not match non-jit result
            if use_jit:
                if traced_decoder is None:
                    traced_decoder = torch.jit.trace(
                        self.decoder.forward_one_step, (ys, ys_mask, enc_output))
                local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
            else:
                local_att_scores = self.decoder.forward_one_step(
                    ys, ys_mask, enc_output)[0]

            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            if lpz is not None:
                if self.remove_blank_in_ctc_mode:
                    # filter out <blank> in local_best_ids; this happens in
                    # pure ctc-mode, when ctc_beam equals the vocabulary size
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores[:, 1:], ctc_beam, dim=1)
                    local_best_ids += 1  # shift ids back past the blank at 0
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                local_scores = \
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                    + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                if rnnlm:
                    local_scores += recog_args.lm_weight * \
                        local_lm_scores[:, local_best_ids[0]]
                local_best_scores, joint_best_ids = torch.topk(
                    local_scores, beam, dim=1)
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = torch.topk(
                    local_scores, beam, dim=1)

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp['rnnlm_prev'] = rnnlm_state
                if lpz is not None:
                    new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                    new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]

                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                'best hypo: ' +
                ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last position in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.eos)

        # add ended hypotheses to a final list,
        # and remove them from the current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp['score'] += recog_args.lm_weight * rnnlm.final(
                            hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        # from espnet.nets.e2e_asr_common import end_detect
        # if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
        from espnet.nets.e2e_asr_common import end_detect_yzl23
        if end_detect_yzl23(ended_hyps, remained_hyps, penalty) \
                and recog_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remaining hypotheses: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

        logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

    # check the number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            'there is no N-best results, perform recognition again '
            'with smaller minlenratio.')
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        return self.recognize(x, recog_args, char_list, rnnlm)

    logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
    logging.info('normalized log probability: ' +
                 str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
    return nbest_hyps
def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
    """beam search implementation

    :param torch.Tensor h: encoder hidden state (T, eprojs)
        [in multi-encoder case, list of torch.Tensor,
        [(T1, eprojs), (T2, eprojs), ...] ]
    :param torch.Tensor lpz: ctc log softmax output (T, odim)
        [in multi-encoder case, list of torch.Tensor,
        [(T1, odim), (T2, odim), ...] ]
    :param Namespace recog_args: argument Namespace containing options
    :param char_list: list of character strings
    :param torch.nn.Module rnnlm: language module
    :param int strm_idx: stream index for speaker parallel attention
        in multi-speaker case
    :return: N-best decoding results
    :rtype: list of dicts
    """
    # to support multiple-encoder asr mode; in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        h = [h]
        lpz = [lpz]
    if self.num_encs > 1 and lpz is None:
        lpz = [lpz] * self.num_encs

    for idx in range(self.num_encs):
        logging.info(
            "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                self.num_encs, idx + 1, h[idx].size(0)
            )
        )
    att_idx = min(strm_idx, len(self.att) - 1)
    # initialization
    c_list = [self.zero_state(h[0].unsqueeze(0))]
    z_list = [self.zero_state(h[0].unsqueeze(0))]
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(h[0].unsqueeze(0)))
        z_list.append(self.zero_state(h[0].unsqueeze(0)))
    if self.num_encs == 1:
        a = None
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        a = [None] * (self.num_encs + 1)  # atts + han
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * self.num_encs  # atts
        for idx in range(self.num_encs + 1):
            # reset pre-computation of h in atts and han
            self.att[idx].reset()

    # search params
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = getattr(recog_args, "ctc_weight", False)  # for NMT

    if lpz[0] is not None and self.num_encs > 1:
        # weights-ctc,
        # e.g. ctc_loss = w_1 * ctc_1_loss + w_2 * ctc_2_loss + ... + w_N * ctc_N_loss
        weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
            recog_args.weights_ctc_dec
        )  # normalize
        logging.info(
            "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
        )
    else:
        weights_ctc_dec = [1.0]

    # prepare sos
    if self.replace_sos and recog_args.tgt_lang:
        y = char_list.index(recog_args.tgt_lang)
    else:
        y = self.sos
    logging.info("<sos> index: " + str(y))
    logging.info("<sos> mark: " + char_list[y])
    vy = h[0].new_zeros(1).long()

    maxlen = np.amin([h[idx].size(0) for idx in range(self.num_encs)])
    if recog_args.maxlenratio != 0:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * maxlen))
    minlen = int(recog_args.minlenratio * maxlen)
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {
            "score": 0.0,
            "yseq": [y],
            "c_prev": c_list,
            "z_prev": z_list,
            "a_prev": a,
            "rnnlm_prev": None,
        }
    else:
        hyp = {
            "score": 0.0,
            "yseq": [y],
            "c_prev": c_list,
            "z_prev": z_list,
            "a_prev": a,
        }
    if lpz[0] is not None:
        ctc_prefix_score = [
            CTCPrefixScore(lpz[idx].detach().numpy(), 0, self.eos, np)
            for idx in range(self.num_encs)
        ]
        hyp["ctc_state_prev"] = [
            ctc_prefix_score[idx].initial_state() for idx in range(self.num_encs)
        ]
        hyp["ctc_score_prev"] = [0.0] * self.num_encs
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz[0].shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz[0].shape[-1]
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        logging.debug("position " + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp["yseq"][i]
            ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    h[0].unsqueeze(0),
                    [h[0].size(0)],
                    self.dropout_dec[0](hyp["z_prev"][0]),
                    hyp["a_prev"],
                )
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        h[idx].unsqueeze(0),
                        [h[idx].size(0)],
                        self.dropout_dec[0](hyp["z_prev"][0]),
                        hyp["a_prev"][idx],
                    )
                h_han = torch.stack(att_c_list, dim=1)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    h_han,
                    [self.num_encs],
                    self.dropout_dec[0](hyp["z_prev"][0]),
                    hyp["a_prev"][self.num_encs],
                )
            ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
            z_list, c_list = self.rnn_forward(
                ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"]
            )

            # get nbest local scores and their ids
            if self.context_residual:
                logits = self.output(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )
            else:
                logits = self.output(self.dropout_dec[-1](z_list[-1]))
            local_att_scores = F.log_softmax(logits, dim=1)
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                local_scores = (
                    local_att_scores + recog_args.lm_weight * local_lm_scores
                )
            else:
                local_scores = local_att_scores

            if lpz[0] is not None:
                local_best_scores, local_best_ids = torch.topk(
                    local_att_scores, ctc_beam, dim=1
                )
                ctc_scores, ctc_states = (
                    [None] * self.num_encs,
                    [None] * self.num_encs,
                )
                for idx in range(self.num_encs):
                    ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx](
                        hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"][idx]
                    )
                local_scores = (1.0 - ctc_weight) * local_att_scores[
                    :, local_best_ids[0]
                ]
                if self.num_encs == 1:
                    local_scores += ctc_weight * torch.from_numpy(
                        ctc_scores[0] - hyp["ctc_score_prev"][0]
                    )
                else:
                    for idx in range(self.num_encs):
                        local_scores += (
                            ctc_weight
                            * weights_ctc_dec[idx]
                            * torch.from_numpy(
                                ctc_scores[idx] - hyp["ctc_score_prev"][idx]
                            )
                        )
                if rnnlm:
                    local_scores += (
                        recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    )
                local_best_scores, joint_best_ids = torch.topk(
                    local_scores, beam, dim=1
                )
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = torch.topk(
                    local_scores, beam, dim=1
                )

            for j in six.moves.range(beam):
                new_hyp = {}
                # [:] is needed!
                new_hyp["z_prev"] = z_list[:]
                new_hyp["c_prev"] = c_list[:]
                if self.num_encs == 1:
                    new_hyp["a_prev"] = att_w[:]
                else:
                    new_hyp["a_prev"] = [
                        att_w_list[idx][:] for idx in range(self.num_encs + 1)
                    ]
                new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp["rnnlm_prev"] = rnnlm_state
                if lpz[0] is not None:
                    new_hyp["ctc_state_prev"] = [
                        ctc_states[idx][joint_best_ids[0, j]]
                        for idx in range(self.num_encs)
                    ]
                    new_hyp["ctc_score_prev"] = [
                        ctc_scores[idx][joint_best_ids[0, j]]
                        for idx in range(self.num_encs)
                    ]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x["score"], reverse=True
            )[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug("number of pruned hypotheses: " + str(len(hyps)))
        logging.debug(
            "best hypo: " + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
        )

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            for hyp in hyps:
                hyp["yseq"].append(self.eos)

        # add ended hypotheses to a final list,
        # and remove them from the current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp["yseq"]) > minlen:
                    hyp["score"] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp["score"] += recog_args.lm_weight * rnnlm.final(
                            hyp["rnnlm_prev"]
                        )
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info("end detected at %d", i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug("remaining hypotheses: " + str(len(hyps)))
        else:
            logging.info("no hypothesis. Finish decoding.")
            break

        for hyp in hyps:
            logging.debug(
                "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
            )

        logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

    nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
        : min(len(ended_hyps), recog_args.nbest)
    ]

    # check the number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            "there is no N-best results, "
            "perform recognition again with smaller minlenratio."
        )
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        if self.num_encs == 1:
            return self.recognize_beam(h[0], lpz[0], recog_args, char_list, rnnlm)
        else:
            return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

    logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
    logging.info(
        "normalized log probability: "
        + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
    )

    # remove sos
    return nbest_hyps
def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
    """beam search implementation

    :param torch.Tensor h: encoder hidden state (T, eprojs)
    :param torch.Tensor lpz: ctc log softmax output (T, odim)
    :param Namespace recog_args: argument Namespace containing options
    :param char_list: list of character strings
    :param torch.nn.Module rnnlm: language module
    :param int strm_idx: stream index for speaker parallel attention
        in multi-speaker case
    :return: N-best decoding results
    :rtype: list of dicts
    """
    logging.info('input lengths: ' + str(h.size(0)))
    att_idx = min(strm_idx, len(self.att) - 1)
    # initialization
    c_list = [self.zero_state(h.unsqueeze(0))]
    z_list = [self.zero_state(h.unsqueeze(0))]
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(h.unsqueeze(0)))
        z_list.append(self.zero_state(h.unsqueeze(0)))
    a = None
    self.att[att_idx].reset()  # reset pre-computation of h

    # search params
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    if self.replace_sos and recog_args.tgt_lang:
        y = char_list.index(recog_args.tgt_lang)
    else:
        y = self.sos
    logging.info('<sos> index: ' + str(y))
    logging.info('<sos> mark: ' + char_list[y])
    vy = h.new_zeros(1).long()

    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
    minlen = int(recog_args.minlenratio * h.size(0))
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list,
               'z_prev': z_list, 'a_prev': a, 'rnnlm_prev': None}
    else:
        hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list,
               'z_prev': z_list, 'a_prev': a}
    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, np)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp['yseq'][i]
            ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
            att_c, att_w = self.att[att_idx](
                h.unsqueeze(0), [h.size(0)],
                self.dropout_dec[0](hyp['z_prev'][0]), hyp['a_prev'])
            ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
            z_list, c_list = self.rnn_forward(
                ey, z_list, c_list, hyp['z_prev'], hyp['c_prev'])

            # get nbest local scores and their ids
            if self.context_residual:
                logits = self.output(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
            else:
                logits = self.output(self.dropout_dec[-1](z_list[-1]))
            local_att_scores = F.log_softmax(logits, dim=1)
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            if lpz is not None:
                local_best_scores, local_best_ids = torch.topk(
                    local_att_scores, ctc_beam, dim=1)
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                local_scores = \
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                    + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                if rnnlm:
                    local_scores += recog_args.lm_weight * \
                        local_lm_scores[:, local_best_ids[0]]
                local_best_scores, joint_best_ids = torch.topk(
                    local_scores, beam, dim=1)
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = torch.topk(
                    local_scores, beam, dim=1)

            for j in six.moves.range(beam):
                new_hyp = {}
                # [:] is needed!
                new_hyp['z_prev'] = z_list[:]
                new_hyp['c_prev'] = c_list[:]
                new_hyp['a_prev'] = att_w[:]
                new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp['rnnlm_prev'] = rnnlm_state
                if lpz is not None:
                    new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                    new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
        logging.debug(
            'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last position in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.eos)

        # add ended hypotheses to a final list,
        # and remove them from the current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp['score'] += recog_args.lm_weight * rnnlm.final(
                            hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remaining hypotheses: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        for hyp in hyps:
            logging.debug(
                'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

        logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

    # check the number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            'there is no N-best results, perform recognition again '
            'with smaller minlenratio.')
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

    logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
    logging.info('normalized log probability: ' +
                 str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

    # remove sos
    return nbest_hyps
def recognize_beam(self, h, lpz, recog_args, char_list=None, rnnlm=None):
    """beam search implementation

    :param h: encoder output (T, D)
    :param lpz: ctc log softmax output (T, odim), or None
    :param recog_args: argument Namespace containing options
    :param char_list: list of characters
    :param rnnlm: language model module
    :return: N-best decoding results
    """
    logging.info('input lengths: ' + str(h.shape[0]))
    # initialization
    xp = self.xp
    h_mask = xp.ones((1, h.shape[0]))
    batch = 1

    # search params
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    y = self.sos
    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        maxlen = max(1, int(recog_args.maxlenratio * h.shape[0]))
    minlen = int(recog_args.minlenratio * h.shape[0])
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
    else:
        hyp = {'score': 0.0, 'yseq': [y]}

    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz, 0, self.eos, self.xp)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO

            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            ys = F.expand_dims(xp.array(hyp['yseq']), axis=0).data
            yy_mask = self.make_attention_mask(ys, ys)
            yy_mask *= self.make_history_mask(ys)
            xy_mask = self.make_attention_mask(ys, h_mask)
            out = self.decoder(ys, yy_mask, h, xy_mask).reshape(batch, -1, self.odim)

            # get nbest local scores and their ids
            local_att_scores = F.log_softmax(out[:, -1], axis=-1).data
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(
                    hyp['rnnlm_prev'], hyp['yseq'][i])
                local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            if lpz is not None:
                local_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][:ctc_beam]
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp['yseq'], local_best_ids, hyp['ctc_state_prev'])
                local_scores = \
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids] \
                    + ctc_weight * (ctc_scores - hyp['ctc_score_prev'])
                if rnnlm:
                    local_scores += recog_args.lm_weight * \
                        local_lm_scores[:, local_best_ids]
                joint_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                local_best_scores = local_scores[:, joint_best_ids]
                local_best_ids = local_best_ids[joint_best_ids]
            else:
                local_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                local_best_scores = local_scores[:, local_best_ids]

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[j])
                if rnnlm:
                    new_hyp['rnnlm_prev'] = rnnlm_state
                if lpz is not None:
                    new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[j]]
                    new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[j]]
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                'best hypo: ' +
                ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]) +
                ' score: ' + str(hyps[0]['score']))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last position in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.eos)

        # add ended hypotheses to a final list,
        # and remove them from the current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp['score'] += recog_args.lm_weight * rnnlm.final(
                            hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        from espnet.nets.e2e_asr_common import end_detect
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remaining hypotheses: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

        logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)  # [:min(len(ended_hyps), recog_args.nbest)]
    logging.debug(nbest_hyps)

    # check the number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            'there is no N-best results, perform recognition again '
            'with smaller minlenratio.')
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

    logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
    logging.info('normalized log probability: ' +
                 str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

    # remove sos
    return nbest_hyps
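The `end_detect` helper imported above lives in `espnet.nets.e2e_asr_common`. A simplified sketch of the heuristic follows, after the joint CTC/attention decoding strategy of Hori et al.; the constants and exact bookkeeping in the library may differ. The idea: stop once hypotheses that end at the most recent lengths all score far below the best ended hypothesis, so longer outputs are no longer improving.

# Simplified sketch only; M and d_end are assumed defaults, not the
# library's actual signature.
def end_detect_sketch(ended_hyps, i, M=3, d_end=-10.0):
    """Return True when the last M hypothesis lengths produced nothing
    within d_end (log domain) of the best ended score."""
    if len(ended_hyps) == 0:
        return False
    best_score = max(h['score'] for h in ended_hyps)
    count = 0
    for m in range(M):
        # best score among hypotheses of length exactly i - m
        same_len = [h for h in ended_hyps if len(h['yseq']) == i - m]
        if same_len and max(h['score'] for h in same_len) - best_score < d_end:
            count += 1
    return count == M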
def recognize_beam(self, h, lpz, recog_args, char_list=None, rnnlm=None):
    """E2E beam search.

    Args:
        h (ndarray): Encoder output features (B, T, D) or (T, D).
        lpz (ndarray): Log probabilities from CTC.
        recog_args (Namespace): Argument namespace containing options.
        char_list (List[str]): List of characters.
        rnnlm (chainer.Chain): Language model module defined at
            `espnet.lm.chainer_backend.lm`.

    Returns:
        List: N-best decoding results.

    """
    logging.info("input lengths: " + str(h.shape[1]))
    # initialization
    n_len = h.shape[1]
    xp = self.xp
    h_mask = xp.ones((1, n_len))

    # search params
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    y = self.sos
    if recog_args.maxlenratio == 0:
        maxlen = n_len
    else:
        maxlen = max(1, int(recog_args.maxlenratio * n_len))
    minlen = int(recog_args.minlenratio * n_len)
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None}
    else:
        hyp = {"score": 0.0, "yseq": [y]}

    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz, 0, self.eos, self.xp)
        hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
        hyp["ctc_score_prev"] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        logging.debug("position " + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            ys = F.expand_dims(xp.array(hyp["yseq"]), axis=0).data
            out = self.decoder(ys, h, h_mask)

            # get nbest local scores and their ids
            local_att_scores = F.log_softmax(out[:, -1], axis=-1).data
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(
                    hyp["rnnlm_prev"], hyp["yseq"][i])
                local_scores = (
                    local_att_scores + recog_args.lm_weight * local_lm_scores
                )
            else:
                local_scores = local_att_scores

            if lpz is not None:
                local_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][:ctc_beam]
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp["yseq"], local_best_ids, hyp["ctc_state_prev"])
                local_scores = (
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids]
                    + ctc_weight * (ctc_scores - hyp["ctc_score_prev"])
                )
                if rnnlm:
                    local_scores += (
                        recog_args.lm_weight * local_lm_scores[:, local_best_ids]
                    )
                joint_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                local_best_scores = local_scores[:, joint_best_ids]
                local_best_ids = local_best_ids[joint_best_ids]
            else:
                local_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                local_best_scores = local_scores[:, local_best_ids]

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j])
                new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                new_hyp["yseq"][:len(hyp["yseq"])] = hyp["yseq"]
                new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[j])
                if rnnlm:
                    new_hyp["rnnlm_prev"] = rnnlm_state
                if lpz is not None:
                    new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[j]]
                    new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[j]]
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x["score"], reverse=True
            )[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug("number of pruned hypotheses: " + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
                + " score: " + str(hyps[0]["score"]))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            for hyp in hyps:
                hyp["yseq"].append(self.eos)

        # add ended hypotheses to a final list,
        # and remove them from the current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp["yseq"]) > minlen:
                    hyp["score"] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp["score"] += recog_args.lm_weight * rnnlm.final(
                            hyp["rnnlm_prev"])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info("end detected at %d", i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug("remaining hypotheses: " + str(len(hyps)))
        else:
            logging.info("no hypothesis. Finish decoding.")
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]))

        logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x["score"],
        reverse=True)  # [:min(len(ended_hyps), recog_args.nbest)]
    logging.debug(nbest_hyps)

    # check the number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning("there is no N-best results, perform recognition "
                        "again with smaller minlenratio.")
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

    logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
    logging.info("normalized log probability: "
                 + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"])))

    # remove sos
    return nbest_hyps
def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None):
    """Beam search implementation.

    Args:
        h (chainer.Variable): One of the output from the encoder.
        lpz (chainer.Variable | None): Result of net propagation.
        recog_args (Namespace): The argument.
        char_list (List[str]): List of all characters.
        rnnlm (Module): RNNLM module. Defined at `espnet.lm.chainer_backend.lm`

    Returns:
        List[Dict[str, Any]]: Result of recognition.

    """
    logging.info('input lengths: ' + str(h.shape[0]))
    # initialization
    c_list = [None]  # list of cell state of each layer
    z_list = [None]  # list of hidden state of each layer
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(None)
        z_list.append(None)
    a = None
    self.att.reset()  # reset pre-computation of h

    # search params
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    y = self.xp.full(1, self.sos, 'i')
    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * h.shape[0]))
    minlen = int(recog_args.minlenratio * h.shape[0])
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list,
               'z_prev': z_list, 'a_prev': a, 'rnnlm_prev': None}
    else:
        hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list,
               'z_prev': z_list, 'a_prev': a}
    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz, 0, self.eos, self.xp)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            ey = self.embed(hyp['yseq'][i])  # utt list (1) x zdim
            att_c, att_w = self.att([h], hyp['z_prev'][0], hyp['a_prev'])
            ey = F.hstack((ey, att_c))  # utt(1) x (zdim + hdim)
            z_list, c_list = self.rnn_forward(
                ey, z_list, c_list, hyp['z_prev'], hyp['c_prev'])

            # get nbest local scores and their ids
            local_att_scores = F.log_softmax(self.output(z_list[-1])).data
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(
                    hyp['rnnlm_prev'], hyp['yseq'][i])
                local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            if lpz is not None:
                local_best_ids = self.xp.argsort(
                    local_scores, axis=1)[0, ::-1][:ctc_beam]
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp['yseq'], local_best_ids, hyp['ctc_state_prev'])
                local_scores = \
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids] \
                    + ctc_weight * (ctc_scores - hyp['ctc_score_prev'])
                if rnnlm:
                    local_scores += recog_args.lm_weight * \
                        local_lm_scores[:, local_best_ids]
                joint_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                local_best_scores = local_scores[:, joint_best_ids]
                local_best_ids = local_best_ids[joint_best_ids]
            else:
                local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                local_best_scores = local_scores[:, local_best_ids]

            for j in six.moves.range(beam):
                new_hyp = {}
                # do not copy {z,c}_list directly
                new_hyp['z_prev'] = z_list[:]
                new_hyp['c_prev'] = c_list[:]
                new_hyp['a_prev'] = att_w
                new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = self.xp.full(
                    1, local_best_ids[j], 'i')
                if rnnlm:
                    new_hyp['rnnlm_prev'] = rnnlm_state
                if lpz is not None:
                    new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[j]]
                    new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
        logging.debug('best hypo: ' + ''.join(
            [char_list[int(x)] for x in hyps[0]['yseq'][1:]]).replace('<space>', ' '))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last position in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.xp.full(1, self.eos, 'i'))

        # add ended hypotheses to a final list,
        # and remove them from the current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp['score'] += recog_args.lm_weight * rnnlm.final(
                            hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remaining hypotheses: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        for hyp in hyps:
            logging.debug('hypo: ' + ''.join(
                [char_list[int(x)] for x in hyp['yseq'][1:]]).replace('<space>', ' '))

        logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

    # check the number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            'there is no N-best results, perform recognition again '
            'with smaller minlenratio.')
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

    logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
    logging.info('normalized log probability: ' +
                 str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

    return nbest_hyps
class CTCPrefixScorer(PartialScorerInterface):
    """Decoder interface wrapper for CTCPrefixScore."""

    def __init__(self, ctc: torch.nn.Module, eos: int):
        """Initialize class.

        Args:
            ctc (torch.nn.Module): The CTC implementation.
                For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
            eos (int): The end-of-sequence id.

        """
        self.ctc = ctc
        self.eos = eos
        self.impl = None

    def init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns:
            initial state

        """
        logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy()
        # TODO(karita): use CTCPrefixScoreTH
        self.impl = CTCPrefixScore(logp, 0, self.eos, np)
        return 0, self.impl.initial_state()

    def select_state(self, state, i):
        """Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search

        Returns:
            state: pruned state

        """
        sc, st = state
        return sc[i], st[i]

    def score_partial(self, y, ids, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D prefix token
            ids (torch.Tensor): torch.int64 next tokens to score
            state: decoder state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]: Tuple of a score tensor for y
                that has a shape `(len(ids),)` and next state for ys

        """
        prev_score, state = state
        presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
        tscore = torch.as_tensor(
            presub_score - prev_score, device=x.device, dtype=x.dtype
        )
        return tscore, (presub_score, new_st)
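A minimal sketch of the call sequence this wrapper implements. `ctc_module`, `enc_output`, `eos_id`, `sos_id`, and the candidate token ids are placeholders; in practice these calls are issued by the beam search driver, not by hand.

# Hypothetical by-hand walk through CTCPrefixScorer; all inputs are assumed.
import torch

scorer = CTCPrefixScorer(ctc_module, eos=eos_id)

state = scorer.init_state(enc_output)   # enc_output: (T, D) encoded features
y = torch.tensor([sos_id])              # current 1D token prefix
ids = torch.tensor([3, 17, 42])         # candidate next tokens to score
scores, new_state = scorer.score_partial(y, ids, state, enc_output)
# scores[k] is the CTC prefix-score increment for appending ids[k]; once the
# beam keeps candidate k, its per-candidate state is extracted with:
k = int(scores.argmax())
state_k = scorer.select_state(new_state, k)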
def infer(x, encoder_rt, ctc_softmax_rt, decoder_fos_rt):
    ctc_weight = 0.5
    beam_size = 1
    penalty = 0.0
    maxlenratio = 0.0
    nbest = 1
    sos = eos = 7442

    # enc_output = self.encode(x).unsqueeze(0)
    ort_inputs = {"x": x.numpy()}
    enc_output = encoder_rt.run(None, ort_inputs)
    enc_output = torch.tensor(enc_output)
    enc_output = enc_output.squeeze(0)
    # print(f"enc_output shape: {enc_output.shape}")

    # lpz = self.ctc.log_softmax(enc_output)
    # lpz = lpz.squeeze(0)
    ort_inputs = {"enc_output": enc_output.numpy()}
    lpz = ctc_softmax_rt.run(None, ort_inputs)
    lpz = torch.tensor(lpz)
    lpz = lpz.squeeze(0).squeeze(0)
    # print(f"lpz shape: {lpz.shape}")

    h = enc_output.squeeze(0)
    # print(f"h shape: {h.shape}")

    # prepare sos
    y = sos
    maxlen = h.shape[0]
    minlen = 0.0
    ctc_beam = 1

    # initialize hypothesis
    hyp = {'score': 0.0, 'yseq': [y]}
    ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, eos, numpy)
    hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
    hyp['ctc_score_prev'] = 0.0
    # pre-pruning based on attention scores
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        hyps_best_kept = []
        for hyp in hyps:
            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp['yseq']).unsqueeze(0)
            ort_inputs = {
                "ys": ys.numpy(),
                "ys_mask": ys_mask.numpy(),
                "enc_output": enc_output.numpy(),
            }
            local_att_scores = decoder_fos_rt.run(None, ort_inputs)
            local_att_scores = torch.tensor(local_att_scores[0])

            local_scores = local_att_scores
            local_best_scores, local_best_ids = torch.topk(
                local_att_scores, ctc_beam, dim=1)
            ctc_scores, ctc_states = ctc_prefix_score(
                hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
            local_scores = \
                (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
            local_best_scores, joint_best_ids = torch.topk(
                local_scores, beam_size, dim=1)
            local_best_ids = local_best_ids[:, joint_best_ids[0]]

            for j in six.moves.range(beam_size):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam_size]

        # sort and get nbest
        hyps = hyps_best_kept

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            for hyp in hyps:
                hyp['yseq'].append(eos)

        # add ended hypotheses to a final list,
        # and remove them from the current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and maxlenratio == 0.0:
            break

        hyps = remained_hyps
        if len(hyps) == 0:
            break

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), nbest)]

    # return nbest_hyps
    return torch.tensor(nbest_hyps[0]['yseq'])
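A sketch of the setup `infer` expects: the three graphs referenced by name in `forward` below are loaded as `onnxruntime` sessions and passed in explicitly. The filenames match the ones hard-coded in `forward`; the feature shape is an assumption.

# Assumed driver for infer(); only the .onnx filenames are taken from the
# code in this file -- the feature shape is a placeholder.
import onnxruntime
import torch

encoder_rt = onnxruntime.InferenceSession("encoder.onnx")
ctc_softmax_rt = onnxruntime.InferenceSession("ctc_softmax.onnx")
decoder_fos_rt = onnxruntime.InferenceSession("decoder_fos.onnx")

feats = torch.randn(200, 83)  # (T, D) acoustic features
yseq = infer(feats, encoder_rt, ctc_softmax_rt, decoder_fos_rt)
print(yseq)  # token ids of the best hypothesis, bracketed by <sos>/<eos> (7442)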
def forward(self, x):
    """Recognize input speech with ONNX runtime sessions.

    :param ndarray x: input acoustic feature (T, D)
    :return: token id sequence of the best hypothesis
    :rtype: torch.Tensor
    """
    ctc_weight = 0.5
    beam_size = 1
    penalty = 0.0
    maxlenratio = 0.0
    nbest = 1
    sos = eos = 7442

    import onnxruntime

    # enc_output = self.encode(x).unsqueeze(0)
    encoder_rt = onnxruntime.InferenceSession("encoder.onnx")
    ort_inputs = {"x": x.numpy()}
    enc_output = encoder_rt.run(None, ort_inputs)
    enc_output = torch.tensor(enc_output)
    enc_output = enc_output.squeeze(0)
    # print(f"enc_output shape: {enc_output.shape}")

    # lpz = self.ctc.log_softmax(enc_output)
    # lpz = lpz.squeeze(0)
    ctc_softmax_rt = onnxruntime.InferenceSession("ctc_softmax.onnx")
    ort_inputs = {"enc_output": enc_output.numpy()}
    lpz = ctc_softmax_rt.run(None, ort_inputs)
    lpz = torch.tensor(lpz)
    lpz = lpz.squeeze(0).squeeze(0)
    # print(f"lpz shape: {lpz.shape}")

    h = enc_output.squeeze(0)
    # print(f"h shape: {h.shape}")

    # prepare sos
    y = sos
    maxlen = h.shape[0]
    minlen = 0.0

    # initialize hypothesis
    hyp = {'score': 0.0, 'yseq': [y]}
    ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, eos, numpy)
    hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
    hyp['ctc_score_prev'] = 0.0
    # pre-pruning based on attention scores
    # ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
    ctc_beam = 1
    hyps = [hyp]
    ended_hyps = []

    decoder_fos_rt = onnxruntime.InferenceSession("decoder_fos.onnx")
    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp['yseq']).unsqueeze(0)
            ort_inputs = {
                "ys": ys.numpy(),
                "ys_mask": ys_mask.numpy(),
                "enc_output": enc_output.numpy(),
            }
            local_att_scores = decoder_fos_rt.run(None, ort_inputs)
            local_att_scores = torch.tensor(local_att_scores[0])
            # local_att_scores = self.decoder.forward_one_step(ys, ys_mask, enc_output)[0]

            local_scores = local_att_scores
            # print(local_scores.shape)  # (1, 7443)
            local_best_scores, local_best_ids = torch.topk(
                local_att_scores, ctc_beam, dim=1)
            # print(local_best_scores.shape)  # (1, 1)
            ctc_scores, ctc_states = ctc_prefix_score(
                hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
            local_scores = \
                (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
            local_best_scores, joint_best_ids = torch.topk(
                local_scores, beam_size, dim=1)
            local_best_ids = local_best_ids[:, joint_best_ids[0]]

            for j in six.moves.range(beam_size):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam_size]

        # sort and get nbest
        hyps = hyps_best_kept

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            for hyp in hyps:
                hyp['yseq'].append(eos)

        # add ended hypotheses to a final list,
        # and remove them from the current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and maxlenratio == 0.0:
            break

        hyps = remained_hyps
        if len(hyps) == 0:
            break

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), nbest)]

    # return nbest_hyps
    return torch.tensor(nbest_hyps[0]['yseq'])
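The sessions above imply three exported graphs. One way they might be produced is with `torch.onnx.export` on small wrapper modules; the wrapper below is an assumption, but the input and output names match the `ort_inputs` keys used in `forward`.

# Hypothetical export of encoder.onnx; ctc_softmax.onnx and decoder_fos.onnx
# would follow the same pattern around self.ctc.log_softmax and
# self.decoder.forward_one_step. The wrapper, model object, and shapes are
# assumptions, not part of this file.
import torch


class EncoderWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # mirrors the commented-out call in forward(): self.encode(x).unsqueeze(0)
        return self.model.encode(x).unsqueeze(0)


dummy_x = torch.randn(200, 83)  # (T, D) dummy acoustic features
torch.onnx.export(
    EncoderWrapper(model), (dummy_x,), "encoder.onnx",
    input_names=["x"], output_names=["enc_output"],
    dynamic_axes={"x": {0: "time"}},
)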