def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None):
    """Run attention-based beam search for a single utterance (chainer).

    Hypotheses are extended by one token per step, optionally rescored with
    a CTC prefix scorer (when ``lpz`` is given) and an RNNLM (when ``rnnlm``
    is given), and moved to the ended list once their last token is
    ``<eos>``.

    Args:
        h (chainer.Variable): Encoder output for one utterance;
            ``h.shape[0]`` is used as the input length.
        lpz (chainer.Variable | None): CTC log-probabilities handed to
            ``CTCPrefixScore``; ``None`` disables CTC rescoring.
        recog_args (Namespace): Decoding options. Reads ``beam_size``,
            ``penalty``, ``ctc_weight``, ``maxlenratio``, ``minlenratio``,
            ``nbest`` and, when ``rnnlm`` is given, ``lm_weight``.
        char_list (List[str]): Token inventory; only used for debug logging.
        rnnlm (Module): Optional RNNLM module, defined at
            `espnet.lm.chainer_backend.lm`.

    Returns:
        List[Dict[str, Any]]: At most ``recog_args.nbest`` ended hypotheses
        sorted by ``'score'`` in descending order. Each ``'yseq'`` is a list
        of 1-element ``self.xp`` int arrays that starts with ``<sos>`` and
        ends with ``<eos>``.

    """
    logging.info('input lengths: ' + str(h.shape[0]))
    # initialization
    c_list = [None]  # list of cell state of each layer
    z_list = [None]  # list of hidden state of each layer
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(None)
        z_list.append(None)
    a = None
    self.att.reset()  # reset pre-computation of h

    # search params
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    y = self.xp.full(1, self.sos, 'i')
    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * h.shape[0]))
    minlen = int(recog_args.minlenratio * h.shape[0])
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {
            'score': 0.0,
            'yseq': [y],
            'c_prev': c_list,
            'z_prev': z_list,
            'a_prev': a,
            'rnnlm_prev': None
        }
    else:
        hyp = {
            'score': 0.0,
            'yseq': [y],
            'c_prev': c_list,
            'z_prev': z_list,
            'a_prev': a
        }
    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz, 0, self.eos, self.xp)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores: only the top
            # ``ctc_beam`` attention candidates are CTC-scored
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            # pure CTC decoding: score every output symbol
            ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            # every live hypothesis has i + 1 tokens at step i, so
            # yseq[i] is the most recently emitted token
            ey = self.embed(hyp['yseq'][i])  # utt list (1) x zdim
            att_c, att_w = self.att([h], hyp['z_prev'][0], hyp['a_prev'])
            ey = F.hstack((ey, att_c))  # utt(1) x (zdim + hdim)

            # z_list / c_list are reassigned here and copied per new
            # hypothesis below
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, hyp['z_prev'], hyp['c_prev'])

            # get nbest local scores and their ids
            local_att_scores = F.log_softmax(self.output(z_list[-1])).data
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(
                    hyp['rnnlm_prev'], hyp['yseq'][i])
                local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            if lpz is not None:
                # candidates sorted by descending (attention [+ LM]) score
                local_best_ids = self.xp.argsort(
                    local_scores, axis=1)[0, ::-1][:ctc_beam]
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp['yseq'], local_best_ids, hyp['ctc_state_prev'])
                # subtracting the previous prefix score gives the increment
                # for this token (assumes ctc_scores are cumulative prefix
                # scores — NOTE(review): confirm against CTCPrefixScore)
                local_scores = \
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids] \
                    + ctc_weight * (ctc_scores - hyp['ctc_score_prev'])
                if rnnlm:
                    local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids]
                # joint_best_ids index into local_best_ids / ctc_states
                joint_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                local_best_scores = local_scores[:, joint_best_ids]
                local_best_ids = local_best_ids[joint_best_ids]
            else:
                local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                local_best_scores = local_scores[:, local_best_ids]

            for j in six.moves.range(beam):
                new_hyp = {}
                # do not copy {z,c}_list directly
                new_hyp['z_prev'] = z_list[:]
                new_hyp['c_prev'] = c_list[:]
                new_hyp['a_prev'] = att_w
                new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = self.xp.full(
                    1, local_best_ids[j], 'i')
                if rnnlm:
                    new_hyp['rnnlm_prev'] = rnnlm_state
                if lpz is not None:
                    new_hyp['ctc_state_prev'] = ctc_states[
                        joint_best_ids[j]]
                    new_hyp['ctc_score_prev'] = ctc_scores[
                        joint_best_ids[j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            # prune after each source hypothesis so the kept list never
            # exceeds 2 x beam entries
            hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
        logging.debug('best hypo: ' + ''.join(
            [char_list[int(x)] for x in hyps[0]['yseq'][1:]]).replace('<space>', ' '))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last position in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.xp.full(1, self.eos, 'i'))

        # add ended hypotheses to a final list, and removed them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty; shorter <eos>-ended hyps are dropped
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp['score'] += recog_args.lm_weight * rnnlm.final(
                            hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection (only when the output length is unconstrained)
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remaining hypotheses: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        for hyp in hyps:
            logging.debug('hypo: ' + ''.join(
                [char_list[int(x)] for x in hyp['yseq'][1:]]).replace('<space>', ' '))

        logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

    # check number of hypotheses
    if len(nbest_hyps) == 0:
        # retry with minlenratio lowered by 0.1.
        # NOTE(review): once minlenratio is 0.0 this retries with identical
        # arguments (e.g. zero-length input with maxlenratio == 0), which
        # would recurse until RecursionError.
        logging.warning(
            'there is no N-best results, perform recognition again with smaller minlenratio.'
        )
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

    logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
    logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

    return nbest_hyps
def recognize_beam_batch(
    self,
    h,
    hlens,
    lpz,
    recog_args,
    char_list,
    rnnlm=None,
    normalize_score=True,
    strm_idx=0,
    lang_ids=None,
):
    """Run batched beam search over several utterances at once.

    All ``batch * beam`` hypotheses live in flat tensors of size
    ``n_bb = batch * beam``. They are pruned jointly per utterance with
    ``torch.topk`` over the ``beam * odim`` extended candidates.

    Args:
        h: Encoder output, or a list of encoder outputs when
            ``self.num_encs > 1``. ``h[idx].size(1)`` is logged as the input
            length, so the layout is presumably (B, T, D) — TODO confirm.
        hlens: Input lengths (a list of them in multi-encoder mode).
        lpz: CTC log-probabilities (a list in multi-encoder mode) or ``None``.
        recog_args (Namespace): Decoding options. Reads ``beam_size``,
            ``penalty``, ``maxlenratio``, ``minlenratio``, ``nbest``, and
            optionally ``ctc_weight``, ``ctc_window_margin``,
            ``weights_ctc_dec``, ``lm_weight`` and ``tgt_lang``.
        char_list (List[str]): Token inventory.
        rnnlm: Optional LM exposing ``buff_predict`` and ``final``.
        normalize_score (bool): Divide each final score by ``len(yseq)``.
        strm_idx (int): Attention stream index, clamped to
            ``len(self.att) - 1``.
        lang_ids: Optional per-utterance start tokens (used for evaluation
            during training).

    Returns:
        List[List[dict]]: For each utterance, at most ``recog_args.nbest``
        hypotheses sorted by ``"score"`` in descending order. Each one has
        ``"yseq"`` and ``"score"`` keys; real ones also carry ``"vscore"``.
        An utterance with no ended hypothesis gets one placeholder
        ``[sos, eos]`` hypothesis scored ``-inf``.
    """
    # to support mutiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        h = [h]
        hlens = [hlens]
        lpz = [lpz]
    if self.num_encs > 1 and lpz is None:
        lpz = [lpz] * self.num_encs

    att_idx = min(strm_idx, len(self.att) - 1)
    for idx in range(self.num_encs):
        logging.info(
            "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                self.num_encs, idx + 1, h[idx].size(1)
            )
        )
        h[idx] = mask_by_length(h[idx], hlens[idx], 0.0)

    # search params
    batch = len(hlens[0])
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = getattr(recog_args, "ctc_weight", 0)  # for NMT
    att_weight = 1.0 - ctc_weight
    ctc_margin = getattr(
        recog_args, "ctc_window_margin", 0
    )  # use getattr to keep compatibility
    # weights-ctc,
    # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
    if lpz[0] is not None and self.num_encs > 1:
        weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
            recog_args.weights_ctc_dec
        )  # normalize
        logging.info(
            "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
        )
    else:
        weights_ctc_dec = [1.0]

    n_bb = batch * beam
    # offset of each utterance's first beam slot in the flat n_bb layout
    pad_b = to_device(h[0], torch.arange(batch) * beam).view(-1, 1)

    max_hlen = np.amin([max(hlens[idx]) for idx in range(self.num_encs)])
    if recog_args.maxlenratio == 0:
        maxlen = max_hlen
    else:
        maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
    minlen = int(recog_args.minlenratio * max_hlen)
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialization
    c_prev = [
        to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    z_prev = [
        to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    c_list = [
        to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    z_list = [
        to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    vscores = to_device(h[0], torch.zeros(batch, beam))

    rnnlm_state = None
    if self.num_encs == 1:
        a_prev = [None]
        att_w_list, ctc_scorer, ctc_state = [None], [None], [None]
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        a_prev = [None] * (self.num_encs + 1)  # atts + han
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * (self.num_encs)  # atts
        ctc_scorer, ctc_state = [None] * (self.num_encs), [None] * (self.num_encs)
        for idx in range(self.num_encs + 1):
            self.att[idx].reset()  # reset pre-computation of h in atts and han

    if self.replace_sos and recog_args.tgt_lang:
        logging.info("<sos> index: " + str(char_list.index(recog_args.tgt_lang)))
        logging.info("<sos> mark: " + recog_args.tgt_lang)
        yseq = [
            [char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)
        ]
    elif lang_ids is not None:
        # NOTE: used for evaluation during training
        yseq = [
            [lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)
        ]
    else:
        logging.info("<sos> index: " + str(self.sos))
        logging.info("<sos> mark: " + char_list[self.sos])
        yseq = [[self.sos] for _ in six.moves.range(n_bb)]

    stop_search = [False for _ in six.moves.range(batch)]
    ended_hyps = [[] for _ in range(batch)]

    # repeat lengths / encoder outputs beam times per utterance (flat n_bb)
    exp_hlens = [
        hlens[idx].repeat(beam).view(beam, batch).transpose(0, 1).contiguous()
        for idx in range(self.num_encs)
    ]
    exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)]
    exp_h = [
        h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
        for idx in range(self.num_encs)
    ]
    exp_h = [
        exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2])
        for idx in range(self.num_encs)
    ]

    if lpz[0] is not None:
        # partial CTC scoring (top candidates only) is disabled on CUDA
        scoring_num = min(
            int(beam * CTC_SCORING_RATIO)
            if att_weight > 0.0 and not lpz[0].is_cuda
            else 0,
            lpz[0].size(-1),
        )
        ctc_scorer = [
            CTCPrefixScoreTH(
                lpz[idx],
                hlens[idx],
                0,
                self.eos,
                margin=ctc_margin,
            )
            for idx in range(self.num_encs)
        ]

    for i in six.moves.range(maxlen):
        logging.debug("position " + str(i))

        vy = to_device(h[0], torch.LongTensor(self._get_last_yseq(yseq)))
        ey = self.dropout_emb(self.embed(vy))
        if self.num_encs == 1:
            att_c, att_w = self.att[att_idx](
                exp_h[0], exp_hlens[0], self.dropout_dec[0](z_prev[0]), a_prev[0]
            )
            att_w_list = [att_w]
        else:
            for idx in range(self.num_encs):
                att_c_list[idx], att_w_list[idx] = self.att[idx](
                    exp_h[idx],
                    exp_hlens[idx],
                    self.dropout_dec[0](z_prev[0]),
                    a_prev[idx],
                )
            exp_h_han = torch.stack(att_c_list, dim=1)
            att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                exp_h_han,
                [self.num_encs] * n_bb,
                self.dropout_dec[0](z_prev[0]),
                a_prev[self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)

        # attention decoder
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        local_scores = att_weight * F.log_softmax(logits, dim=1)

        # rnnlm
        if rnnlm:
            rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_state, vy, n_bb)
            local_scores = local_scores + recog_args.lm_weight * local_lm_scores

        # ctc
        if ctc_scorer[0]:
            local_scores[:, 0] = self.logzero  # avoid choosing blank
            part_ids = (
                torch.topk(local_scores, scoring_num, dim=-1)[1]
                if scoring_num > 0
                else None
            )
            for idx in range(self.num_encs):
                att_w = att_w_list[idx]
                att_w_ = att_w if isinstance(att_w, torch.Tensor) else att_w[0]
                local_ctc_scores, ctc_state[idx] = ctc_scorer[idx](
                    yseq, ctc_state[idx], part_ids, att_w_
                )
                local_scores = (
                    local_scores
                    + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
                )

        local_scores = local_scores.view(batch, beam, self.odim)
        if i == 0:
            # all beams start identical; keep only the first one alive
            local_scores[:, 1:, :] = self.logzero

        # accumulate scores; <eos> extensions are tracked separately and
        # excluded from the candidates that continue
        eos_vscores = local_scores[:, :, self.eos] + vscores
        vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
        vscores[:, :, self.eos] = self.logzero
        vscores = (vscores + local_scores).view(batch, -1)

        # global pruning
        accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
        accum_odim_ids = (
            torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
        )
        accum_padded_beam_ids = (
            (accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist()
        )

        y_prev = yseq[:][:]
        yseq = self._index_select_list(yseq, accum_padded_beam_ids)
        yseq = self._append_ids(yseq, accum_odim_ids)
        vscores = accum_best_scores
        vidx = to_device(h[0], torch.LongTensor(accum_padded_beam_ids))

        # reorder attention / decoder states to follow the surviving beams
        a_prev = []
        num_atts = self.num_encs if self.num_encs == 1 else self.num_encs + 1
        for idx in range(num_atts):
            if isinstance(att_w_list[idx], torch.Tensor):
                _a_prev = torch.index_select(
                    att_w_list[idx].view(n_bb, *att_w_list[idx].shape[1:]), 0, vidx
                )
            elif isinstance(att_w_list[idx], list):
                # handle the case of multi-head attention
                _a_prev = [
                    torch.index_select(att_w_one.view(n_bb, -1), 0, vidx)
                    for att_w_one in att_w_list[idx]
                ]
            else:
                # handle the case of location_recurrent when return is a tuple
                _a_prev_ = torch.index_select(
                    att_w_list[idx][0].view(n_bb, -1), 0, vidx
                )
                _h_prev_ = torch.index_select(
                    att_w_list[idx][1][0].view(n_bb, -1), 0, vidx
                )
                _c_prev_ = torch.index_select(
                    att_w_list[idx][1][1].view(n_bb, -1), 0, vidx
                )
                _a_prev = (_a_prev_, (_h_prev_, _c_prev_))
            a_prev.append(_a_prev)
        z_prev = [
            torch.index_select(z_list[li].view(n_bb, -1), 0, vidx)
            for li in range(self.dlayers)
        ]
        c_prev = [
            torch.index_select(c_list[li].view(n_bb, -1), 0, vidx)
            for li in range(self.dlayers)
        ]

        # pick ended hyps
        if i >= minlen:
            k = 0
            penalty_i = (i + 1) * penalty
            # worst surviving score per utterance: an <eos> extension is
            # kept only if it beats it
            thr = accum_best_scores[:, -1]
            for samp_i in six.moves.range(batch):
                if stop_search[samp_i]:
                    k = k + beam
                    continue
                for beam_j in six.moves.range(beam):
                    _vscore = None
                    if eos_vscores[samp_i, beam_j] > thr[samp_i]:
                        yk = y_prev[k][:]
                        if len(yk) <= min(
                            hlens[idx][samp_i] for idx in range(self.num_encs)
                        ):
                            _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                    elif i == maxlen - 1:
                        yk = yseq[k][:]
                        _vscore = vscores[samp_i][beam_j] + penalty_i
                    # explicit None test: a scalar tensor equal to 0.0 is
                    # falsy and must not be discarded
                    if _vscore is not None:
                        yk.append(self.eos)
                        if rnnlm:
                            _vscore += recog_args.lm_weight * rnnlm.final(
                                rnnlm_state, index=k
                            )
                        _score = _vscore.data.cpu().numpy()
                        ended_hyps[samp_i].append(
                            {"yseq": yk, "vscore": _vscore, "score": _score}
                        )
                    k = k + 1

        # end detection
        stop_search = [
            stop_search[samp_i] or end_detect(ended_hyps[samp_i], i)
            for samp_i in six.moves.range(batch)
        ]
        stop_search_summary = list(set(stop_search))
        if len(stop_search_summary) == 1 and stop_search_summary[0]:
            break

        if rnnlm:
            rnnlm_state = self._index_select_lm_state(rnnlm_state, 0, vidx)
        if ctc_scorer[0]:
            for idx in range(self.num_encs):
                ctc_state[idx] = ctc_scorer[idx].index_select_state(
                    ctc_state[idx], accum_best_ids
                )

    torch.cuda.empty_cache()

    # give each empty utterance its own placeholder so later in-place
    # normalization / caller mutation cannot leak across utterances
    ended_hyps = [
        ended_hyps[samp_i]
        if len(ended_hyps[samp_i]) != 0
        else [{"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}]
        for samp_i in six.moves.range(batch)
    ]
    if normalize_score:
        for samp_i in six.moves.range(batch):
            for x in ended_hyps[samp_i]:
                x["score"] /= len(x["yseq"])

    nbest_hyps = [
        sorted(ended_hyps[samp_i], key=lambda x: x["score"], reverse=True)[
            : min(len(ended_hyps[samp_i]), recog_args.nbest)
        ]
        for samp_i in six.moves.range(batch)
    ]
    return nbest_hyps
def translate(self, x, trans_args, char_list=None):
    """Translate source text with attention beam search.

    :param list x: input source text feature (T,); ``x[0]`` holds the source
        token ids (its first element is skipped when ``self.multilingual``)
    :param Namespace trans_args: argument Namespace containing options
        (``beam_size``, ``penalty``, ``maxlenratio``, ``minlenratio``,
        ``nbest``, ``tgt_lang``)
    :param list char_list: list of characters (used for the ``tgt_lang``
        lookup and for debug logging)
    :return: N-best decoding results, each ``{"score": float, "yseq": [int]}``,
        sorted by score, at most ``trans_args.nbest`` entries
    :rtype: list
    """
    self.eval()  # NOTE: this is important because self.encode() is not used
    assert isinstance(x, list)

    # make a utt list (1) to use the same interface for encoder.
    # The tensor gets its own name so that ``x`` stays the original list:
    # the minlenratio retry below calls translate() again with ``x``, and
    # translate() only accepts list input.
    if self.multilingual:
        src = to_device(
            self, torch.from_numpy(np.fromiter(map(int, x[0][1:]), dtype=np.int64))
        )
    else:
        src = to_device(
            self, torch.from_numpy(np.fromiter(map(int, x[0]), dtype=np.int64))
        )

    logging.info("input lengths: " + str(src.size(0)))
    xs_pad = src.unsqueeze(0)
    tgt_lang = None
    if trans_args.tgt_lang:
        tgt_lang = char_list.index(trans_args.tgt_lang)
    xs_pad, _ = self.target_forcing(xs_pad, tgt_lang=tgt_lang)
    h, _ = self.encoder(xs_pad, None)
    logging.info("encoder output lengths: " + str(h.size(1)))

    # search params
    beam = trans_args.beam_size
    penalty = trans_args.penalty

    if trans_args.maxlenratio == 0:
        maxlen = h.size(1)
    else:
        # maxlen >= 1
        maxlen = max(1, int(trans_args.maxlenratio * h.size(1)))
    minlen = int(trans_args.minlenratio * h.size(1))
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialize hypothesis
    hyp = {"score": 0.0, "yseq": [self.sos]}
    hyps = [hyp]
    ended_hyps = []

    for i in range(maxlen):
        logging.debug("position " + str(i))

        # batchfy: every live hypothesis has exactly i + 1 tokens here
        ys = h.new_zeros((len(hyps), i + 1), dtype=torch.int64)
        for j, hyp in enumerate(hyps):
            ys[j, :] = torch.tensor(hyp["yseq"])
        ys_mask = subsequent_mask(i + 1).unsqueeze(0).to(h.device)

        local_scores = self.decoder.forward_one_step(
            ys, ys_mask, h.repeat([len(hyps), 1, 1])
        )[0]

        hyps_best_kept = []
        for j, hyp in enumerate(hyps):
            local_best_scores, local_best_ids = torch.topk(
                local_scores[j : j + 1], beam, dim=1
            )

            for k in range(beam):
                new_hyp = {}
                new_hyp["score"] = hyp["score"] + float(local_best_scores[0, k])
                new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, k])
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda cand: cand["score"], reverse=True
            )[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug("number of pruned hypothes: " + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(t)] for t in hyps[0]["yseq"][1:]])
            )

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            for hyp in hyps:
                hyp["yseq"].append(self.eos)

        # add ended hypotheses to a final list, and remove them from the
        # current hypotheses (this can leave fewer than beam hypotheses)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp["yseq"]) > minlen:
                    hyp["score"] += (i + 1) * penalty
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0:
            logging.info("end detected at %d", i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug("remeined hypothes: " + str(len(hyps)))
        else:
            logging.info("no hypothesis. Finish decoding.")
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    "hypo: " + "".join([char_list[int(t)] for t in hyp["yseq"][1:]])
                )

        logging.debug("number of ended hypothes: " + str(len(ended_hyps)))

    nbest_hyps = sorted(ended_hyps, key=lambda cand: cand["score"], reverse=True)[
        : min(len(ended_hyps), trans_args.nbest)
    ]

    # check number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            "there is no N-best results, perform translation "
            "again with smaller minlenratio."
        )
        # should copy because Namespace will be overwritten globally
        trans_args = Namespace(**vars(trans_args))
        trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1)
        # ``x`` is still the original list input here
        return self.translate(x, trans_args, char_list)

    logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
    logging.info(
        "normalized log probability: "
        + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
    )
    return nbest_hyps
def recognize_beam(self, h, lpz, recog_args, char_list=None, rnnlm=None):
    """E2E beam search (chainer Transformer decoder).

    Args:
        h (ndarray): Encoder output features. ``h.shape[1]`` is used as the
            input length and ``h`` is passed to the decoder as-is, so a
            leading batch axis (1, T, D) appears to be required — TODO
            confirm that (T, D) input is not supported here.
        lpz (ndarray): Log probabilities from CTC, or ``None`` to disable
            CTC rescoring.
        recog_args (Namespace): Argument namespace containing options
            (``beam_size``, ``penalty``, ``ctc_weight``, ``maxlenratio``,
            ``minlenratio`` and, with ``rnnlm``, ``lm_weight``).
        char_list (List[str]): List of characters (debug logging only).
        rnnlm (chainer.Chain): Language model module defined at
            `espnet.lm.chainer_backend.lm`.

    Returns:
        List: All ended hypotheses (``{"score", "yseq", ...}``) sorted by
        score in descending order. Unlike the other beam searches, this list
        is not truncated to ``recog_args.nbest``. Each ``yseq`` still begins
        with ``<sos>``.

    """
    logging.info("input lengths: " + str(h.shape[1]))

    # initialization
    n_len = h.shape[1]
    xp = self.xp
    h_mask = xp.ones((1, n_len))

    # search params
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    y = self.sos
    if recog_args.maxlenratio == 0:
        maxlen = n_len
    else:
        maxlen = max(1, int(recog_args.maxlenratio * n_len))
    minlen = int(recog_args.minlenratio * n_len)
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None}
    else:
        hyp = {"score": 0.0, "yseq": [y]}

    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz, 0, self.eos, self.xp)
        hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
        hyp["ctc_score_prev"] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]

    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        logging.debug("position " + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            # the full prefix is re-decoded each step; only the last output
            # position is scored
            ys = F.expand_dims(xp.array(hyp["yseq"]), axis=0).data
            out = self.decoder(ys, h, h_mask)

            # get nbest local scores and their ids
            local_att_scores = F.log_softmax(out[:, -1], axis=-1).data
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(
                    hyp["rnnlm_prev"], hyp["yseq"][i]
                )
                local_scores = (
                    local_att_scores + recog_args.lm_weight * local_lm_scores
                )
            else:
                local_scores = local_att_scores

            if lpz is not None:
                local_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][
                    :ctc_beam
                ]
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp["yseq"], local_best_ids, hyp["ctc_state_prev"]
                )
                local_scores = (1.0 - ctc_weight) * local_att_scores[
                    :, local_best_ids
                ] + ctc_weight * (ctc_scores - hyp["ctc_score_prev"])
                if rnnlm:
                    local_scores += (
                        recog_args.lm_weight * local_lm_scores[:, local_best_ids]
                    )
                # joint_best_ids index into local_best_ids / ctc_states
                joint_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                local_best_scores = local_scores[:, joint_best_ids]
                local_best_ids = local_best_ids[joint_best_ids]
            else:
                local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][
                    :beam
                ]
                local_best_scores = local_scores[:, local_best_ids]

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j])
                new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[j])
                if rnnlm:
                    new_hyp["rnnlm_prev"] = rnnlm_state
                if lpz is not None:
                    new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[j]]
                    new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[j]]
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x["score"], reverse=True
            )[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug("number of pruned hypothesis: " + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
                + " score: "
                + str(hyps[0]["score"])
            )

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last postion in the loop")
            for hyp in hyps:
                hyp["yseq"].append(self.eos)

        # add ended hypotheses to a final list, and remove them from the
        # current hypotheses (this can leave fewer than beam hypotheses)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp["yseq"]) > minlen:
                    hyp["score"] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp["score"] += recog_args.lm_weight * rnnlm.final(
                            hyp["rnnlm_prev"]
                        )
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info("end detected at %d", i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug("remained hypothes: " + str(len(hyps)))
        else:
            logging.info("no hypothesis. Finish decoding.")
            break
        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                )

        logging.debug("number of ended hypothes: " + str(len(ended_hyps)))

    # NOTE: intentionally not truncated to recog_args.nbest (see docstring)
    nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)

    logging.debug(nbest_hyps)
    # check number of hypotheses
    if len(nbest_hyps) == 0:
        # logging.warn is a deprecated alias of logging.warning
        logging.warning(
            "there is no N-best results, perform recognition "
            "again with smaller minlenratio."
        )
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

    logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
    logging.info(
        "normalized log probability: "
        + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
    )
    return nbest_hyps
def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
    """beam search implementation

    :param torch.Tensor h: encoder hidden state (T, eprojs)
        [in multi-encoder case, list of torch.Tensor,
        [(T1, eprojs), (T2, eprojs), ...] ]
    :param torch.Tensor lpz: ctc log softmax output (T, odim)
        [in multi-encoder case, list of torch.Tensor,
        [(T1, odim), (T2, odim), ...] ]
    :param Namespace recog_args: argument Namespace containing options
    :param char_list: list of character strings
    :param torch.nn.Module rnnlm: language module
    :param int strm_idx:
        stream index for speaker parallel attention in multi-speaker case
    :return: N-best decoding results (at most ``recog_args.nbest``), sorted
        by score; each ``yseq`` begins with the start token and ends with
        ``<eos>``
    :rtype: list of dicts
    """
    # to support mutiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        h = [h]
        lpz = [lpz]
    if self.num_encs > 1 and lpz is None:
        lpz = [lpz] * self.num_encs

    for idx in range(self.num_encs):
        logging.info(
            "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                self.num_encs, idx + 1, h[idx].size(0)
            )
        )
    att_idx = min(strm_idx, len(self.att) - 1)
    # initialization
    c_list = [self.zero_state(h[0].unsqueeze(0))]
    z_list = [self.zero_state(h[0].unsqueeze(0))]
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(h[0].unsqueeze(0)))
        z_list.append(self.zero_state(h[0].unsqueeze(0)))
    if self.num_encs == 1:
        a = None
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        a = [None] * (self.num_encs + 1)  # atts + han
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * (self.num_encs)  # atts
        for idx in range(self.num_encs + 1):
            self.att[idx].reset()  # reset pre-computation of h in atts and han

    # search params
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = getattr(recog_args, "ctc_weight", False)  # for NMT

    if lpz[0] is not None and self.num_encs > 1:
        # weights-ctc,
        # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
        weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
            recog_args.weights_ctc_dec
        )  # normalize
        logging.info(
            "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
        )
    else:
        weights_ctc_dec = [1.0]

    # prepare sos
    if self.replace_sos and recog_args.tgt_lang:
        y = char_list.index(recog_args.tgt_lang)
    else:
        y = self.sos
    logging.info("<sos> index: " + str(y))
    logging.info("<sos> mark: " + char_list[y])
    vy = h[0].new_zeros(1).long()

    # the shortest input stream bounds the output length
    min_hlen = np.amin([h[idx].size(0) for idx in range(self.num_encs)])
    if recog_args.maxlenratio == 0:
        maxlen = min_hlen
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * min_hlen))
    # minlen is relative to the input length, not to the scaled maxlen
    # (consistent with the other beam-search implementations)
    minlen = int(recog_args.minlenratio * min_hlen)
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {
            "score": 0.0,
            "yseq": [y],
            "c_prev": c_list,
            "z_prev": z_list,
            "a_prev": a,
            "rnnlm_prev": None,
        }
    else:
        hyp = {
            "score": 0.0,
            "yseq": [y],
            "c_prev": c_list,
            "z_prev": z_list,
            "a_prev": a,
        }
    if lpz[0] is not None:
        # NOTE(review): .numpy() requires lpz on CPU; CUDA tensors would raise
        ctc_prefix_score = [
            CTCPrefixScore(lpz[idx].detach().numpy(), 0, self.eos, np)
            for idx in range(self.num_encs)
        ]
        hyp["ctc_state_prev"] = [
            ctc_prefix_score[idx].initial_state() for idx in range(self.num_encs)
        ]
        hyp["ctc_score_prev"] = [0.0] * self.num_encs
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz[0].shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz[0].shape[-1]
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        logging.debug("position " + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            # yseq has i + 1 entries at step i: feed the last token
            vy[0] = hyp["yseq"][i]
            ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    h[0].unsqueeze(0),
                    [h[0].size(0)],
                    self.dropout_dec[0](hyp["z_prev"][0]),
                    hyp["a_prev"],
                )
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        h[idx].unsqueeze(0),
                        [h[idx].size(0)],
                        self.dropout_dec[0](hyp["z_prev"][0]),
                        hyp["a_prev"][idx],
                    )
                h_han = torch.stack(att_c_list, dim=1)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    h_han,
                    [self.num_encs],
                    self.dropout_dec[0](hyp["z_prev"][0]),
                    hyp["a_prev"][self.num_encs],
                )
            ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
            z_list, c_list = self.rnn_forward(
                ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"]
            )

            # get nbest local scores and their ids
            if self.context_residual:
                logits = self.output(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )
            else:
                logits = self.output(self.dropout_dec[-1](z_list[-1]))
            local_att_scores = F.log_softmax(logits, dim=1)
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                local_scores = (
                    local_att_scores + recog_args.lm_weight * local_lm_scores
                )
            else:
                local_scores = local_att_scores

            if lpz[0] is not None:
                local_best_scores, local_best_ids = torch.topk(
                    local_att_scores, ctc_beam, dim=1
                )
                ctc_scores, ctc_states = (
                    [None] * self.num_encs,
                    [None] * self.num_encs,
                )
                for idx in range(self.num_encs):
                    ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx](
                        hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"][idx]
                    )
                local_scores = (1.0 - ctc_weight) * local_att_scores[
                    :, local_best_ids[0]
                ]
                if self.num_encs == 1:
                    local_scores += ctc_weight * torch.from_numpy(
                        ctc_scores[0] - hyp["ctc_score_prev"][0]
                    )
                else:
                    for idx in range(self.num_encs):
                        local_scores += (
                            ctc_weight
                            * weights_ctc_dec[idx]
                            * torch.from_numpy(
                                ctc_scores[idx] - hyp["ctc_score_prev"][idx]
                            )
                        )
                if rnnlm:
                    local_scores += (
                        recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    )
                # joint_best_ids index into local_best_ids / ctc_states
                local_best_scores, joint_best_ids = torch.topk(
                    local_scores, beam, dim=1
                )
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = torch.topk(
                    local_scores, beam, dim=1
                )

            for j in six.moves.range(beam):
                new_hyp = {}
                # [:] is needed!
                new_hyp["z_prev"] = z_list[:]
                new_hyp["c_prev"] = c_list[:]
                if self.num_encs == 1:
                    new_hyp["a_prev"] = att_w[:]
                else:
                    new_hyp["a_prev"] = [
                        att_w_list[idx][:] for idx in range(self.num_encs + 1)
                    ]
                new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp["rnnlm_prev"] = rnnlm_state
                if lpz[0] is not None:
                    new_hyp["ctc_state_prev"] = [
                        ctc_states[idx][joint_best_ids[0, j]]
                        for idx in range(self.num_encs)
                    ]
                    new_hyp["ctc_score_prev"] = [
                        ctc_scores[idx][joint_best_ids[0, j]]
                        for idx in range(self.num_encs)
                    ]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x["score"], reverse=True
            )[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug("number of pruned hypotheses: " + str(len(hyps)))
        logging.debug(
            "best hypo: "
            + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
        )

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            for hyp in hyps:
                hyp["yseq"].append(self.eos)

        # add ended hypotheses to a final list,
        # and removed them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp["yseq"]) > minlen:
                    hyp["score"] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp["score"] += recog_args.lm_weight * rnnlm.final(
                            hyp["rnnlm_prev"]
                        )
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info("end detected at %d", i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug("remaining hypotheses: " + str(len(hyps)))
        else:
            logging.info("no hypothesis. Finish decoding.")
            break

        for hyp in hyps:
            logging.debug(
                "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
            )

        logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

    nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
        : min(len(ended_hyps), recog_args.nbest)
    ]

    # check number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            "there is no N-best results, "
            "perform recognition again with smaller minlenratio."
        )
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        # forward strm_idx so the retry uses the same attention stream
        if self.num_encs == 1:
            return self.recognize_beam(
                h[0], lpz[0], recog_args, char_list, rnnlm, strm_idx=strm_idx
            )
        else:
            return self.recognize_beam(
                h, lpz, recog_args, char_list, rnnlm, strm_idx=strm_idx
            )

    logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
    logging.info(
        "normalized log probability: "
        + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
    )

    return nbest_hyps
def recognize_beam_batch(self, h, hlens, lpz, recog_args, char_list, rnnlm=None,
                         normalize_score=True, strm_idx=0, tgt_lang_ids=None):
    """Run batched joint attention/CTC(/LM) beam search.

    Args:
        h (torch.Tensor): Padded encoder output; it is treated as 3-D
            (batch, Tmax, D) by the ``view(n_bb, h.size()[1], h.size()[2])`` below.
        hlens (torch.Tensor): Valid length of each utterance in ``h``.
        lpz (torch.Tensor | None): CTC log-probabilities. When not None a
            ``CTCPrefixScoreTH`` scorer is added to the search.
        recog_args (Namespace): Decoding options. Reads beam_size, penalty,
            ctc_weight, ctc_window_margin, maxlenratio, minlenratio, lm_weight,
            nbest and tgt_lang.
        char_list (list[str]): Token inventory. Used to resolve
            ``recog_args.tgt_lang`` as start token and for logging.
        rnnlm (Module | None): Language model exposing ``buff_predict``.
        normalize_score (bool): Divide each final score by its sequence length.
        strm_idx (int): Attention stream index, clamped to the available ones.
        tgt_lang_ids (list[int] | None): Per-utterance start tokens.
            Used for evaluation during training.

    Returns:
        list[list[dict]]: For each utterance, up to ``recog_args.nbest``
        hypotheses sorted by ``score`` (keys ``yseq``, ``score`` and, for real
        hypotheses, ``vscore``). An utterance without any ended hypothesis gets
        a ``[sos, eos]`` placeholder with score ``-inf``.
    """
    logging.info('input lengths: ' + str(h.size(1)))
    att_idx = min(strm_idx, len(self.att) - 1)
    h = mask_by_length(h, hlens, 0.0)

    # search params
    batch = len(hlens)
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight
    att_weight = 1.0 - ctc_weight
    ctc_margin = recog_args.ctc_window_margin

    n_bb = batch * beam
    # offset of the first beam row of each utterance in the flattened (batch*beam) layout
    pad_b = to_device(self, torch.arange(batch) * beam).view(-1, 1)

    max_hlen = int(max(hlens))
    if recog_args.maxlenratio == 0:
        maxlen = max_hlen
    else:
        maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
    minlen = int(recog_args.minlenratio * max_hlen)
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialization
    c_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
    z_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
    c_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
    z_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
    vscores = to_device(self, torch.zeros(batch, beam))

    a_prev = None
    rnnlm_state = None
    ctc_scorer = None
    ctc_state = None

    self.att[att_idx].reset()  # reset pre-computation of h

    if self.replace_sos and recog_args.tgt_lang:
        logging.info('<sos> index: ' + str(char_list.index(recog_args.tgt_lang)))
        logging.info('<sos> mark: ' + recog_args.tgt_lang)
        yseq = [[char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)]
    elif tgt_lang_ids is not None:
        # NOTE: used for evaluation during training
        yseq = [[tgt_lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)]
    else:
        logging.info('<sos> index: ' + str(self.sos))
        logging.info('<sos> mark: ' + char_list[self.sos])
        yseq = [[self.sos] for _ in six.moves.range(n_bb)]

    stop_search = [False for _ in six.moves.range(batch)]
    ended_hyps = [[] for _ in range(batch)]

    # replicate encoder output / lengths for every beam of every utterance
    exp_hlens = hlens.repeat(beam).view(beam, batch).transpose(0, 1).contiguous()
    exp_hlens = exp_hlens.view(-1).tolist()
    exp_h = h.unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
    exp_h = exp_h.view(n_bb, h.size()[1], h.size()[2])

    if lpz is not None:
        scoring_ratio = CTC_SCORING_RATIO if att_weight > 0.0 and not lpz.is_cuda else 0
        ctc_scorer = CTCPrefixScoreTH(lpz, hlens, 0, self.eos, beam,
                                      scoring_ratio, margin=ctc_margin)

    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        vy = to_device(self, torch.LongTensor(self._get_last_yseq(yseq)))
        ey = self.dropout_emb(self.embed(vy))
        att_c, att_w = self.att[att_idx](exp_h, exp_hlens, self.dropout_dec[0](z_prev[0]), a_prev)
        ey = torch.cat((ey, att_c), dim=1)

        # attention decoder
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        local_scores = att_weight * F.log_softmax(logits, dim=1)

        # rnnlm
        if rnnlm:
            rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_state, vy, n_bb)
            local_scores = local_scores + recog_args.lm_weight * local_lm_scores

        # ctc
        if ctc_scorer:
            att_w_ = att_w if isinstance(att_w, torch.Tensor) else att_w[0]
            ctc_state, local_ctc_scores = ctc_scorer(yseq, ctc_state, local_scores, att_w_)
            local_scores = local_scores + ctc_weight * local_ctc_scores

        local_scores = local_scores.view(batch, beam, self.odim)
        if i == 0:
            # all beams start identical: keep only the first one alive
            local_scores[:, 1:, :] = self.logzero

        # accumulate scores
        eos_vscores = local_scores[:, :, self.eos] + vscores
        vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
        vscores[:, :, self.eos] = self.logzero
        vscores = (vscores + local_scores).view(batch, -1)

        # global pruning
        accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
        accum_odim_ids = torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
        # Integer (floor) division: torch.div performs true division on recent
        # PyTorch, which would yield float beam ids and break the list/tensor
        # indexing below.
        accum_padded_beam_ids = (accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist()

        y_prev = yseq[:][:]
        yseq = self._index_select_list(yseq, accum_padded_beam_ids)
        yseq = self._append_ids(yseq, accum_odim_ids)
        vscores = accum_best_scores
        vidx = to_device(self, torch.LongTensor(accum_padded_beam_ids))

        if isinstance(att_w, torch.Tensor):
            a_prev = torch.index_select(att_w.view(n_bb, *att_w.shape[1:]), 0, vidx)
        elif isinstance(att_w, list):
            # handle the case of multi-head attention
            a_prev = [torch.index_select(att_w_one.view(n_bb, -1), 0, vidx)
                      for att_w_one in att_w]
        else:
            # handle the case of location_recurrent when return is a tuple
            a_prev_ = torch.index_select(att_w[0].view(n_bb, -1), 0, vidx)
            h_prev_ = torch.index_select(att_w[1][0].view(n_bb, -1), 0, vidx)
            c_prev_ = torch.index_select(att_w[1][1].view(n_bb, -1), 0, vidx)
            a_prev = (a_prev_, (h_prev_, c_prev_))
        z_prev = [torch.index_select(z_list[li].view(n_bb, -1), 0, vidx)
                  for li in range(self.dlayers)]
        c_prev = [torch.index_select(c_list[li].view(n_bb, -1), 0, vidx)
                  for li in range(self.dlayers)]

        if rnnlm:
            rnnlm_state = self._index_select_lm_state(rnnlm_state, 0, vidx)
        if ctc_scorer:
            ctc_state = ctc_scorer.index_select_state(ctc_state, accum_best_ids)

        # pick ended hyps
        if i > minlen:
            k = 0  # flat (batch*beam) index into y_prev
            penalty_i = (i + 1) * penalty
            thr = accum_best_scores[:, -1]
            for samp_i in six.moves.range(batch):
                if stop_search[samp_i]:
                    k = k + beam
                    continue
                for beam_j in six.moves.range(beam):
                    # an <eos> extension is kept only if it would survive pruning
                    if eos_vscores[samp_i, beam_j] > thr[samp_i]:
                        yk = y_prev[k][:]
                        yk.append(self.eos)
                        if len(yk) < hlens[samp_i]:
                            _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                            _score = _vscore.data.cpu().numpy()
                            ended_hyps[samp_i].append(
                                {'yseq': yk, 'vscore': _vscore, 'score': _score})
                    k = k + 1

        # end detection
        stop_search = [stop_search[samp_i] or end_detect(ended_hyps[samp_i], i)
                       for samp_i in six.moves.range(batch)]
        if stop_search and all(stop_search):
            break

    torch.cuda.empty_cache()

    dummy_hyps = [{'yseq': [self.sos, self.eos], 'score': np.array([-float('inf')])}]
    ended_hyps = [ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
                  for samp_i in six.moves.range(batch)]
    if normalize_score:
        for samp_i in six.moves.range(batch):
            for x in ended_hyps[samp_i]:
                x['score'] /= len(x['yseq'])

    nbest_hyps = [
        sorted(ended_hyps[samp_i], key=lambda x: x['score'],
               reverse=True)[:min(len(ended_hyps[samp_i]), recog_args.nbest)]
        for samp_i in six.moves.range(batch)
    ]
    return nbest_hyps
def recognize_beam_batch(self, h, hlens, lpz, recog_args, char_list, rnnlm=None,
                         normalize_score=True, strm_idx=0):
    """Run batched joint attention/CTC(/LM) beam search (legacy CTC scorer).

    Args:
        h (torch.Tensor): Padded encoder output; it is treated as 3-D
            (batch, Tmax, D) by the ``view(n_bb, h.size()[1], h.size()[2])`` below.
        hlens (torch.Tensor): Valid length of each utterance in ``h``.
        lpz (torch.Tensor | None): CTC log-probabilities. When not None a
            ``CTCPrefixScoreTH`` scorer is added to the search.
        recog_args (Namespace): Decoding options. Reads beam_size, penalty,
            ctc_weight, maxlenratio, minlenratio, lm_weight and nbest.
        char_list (list[str]): Token inventory (not read in this variant).
        rnnlm (Module | None): Language model exposing ``buff_predict``.
        normalize_score (bool): Divide each ended hypothesis score by its length.
        strm_idx (int): Attention stream index, clamped to the available ones.

    Returns:
        list[list[dict]]: For each utterance, up to ``recog_args.nbest``
        hypotheses sorted by ``score``. An utterance without any ended
        hypothesis gets a ``[sos, eos]`` placeholder with score ``-inf``.
    """
    logging.info('input lengths: ' + str(h.size(1)))
    att_idx = min(strm_idx, len(self.att) - 1)
    h = mask_by_length(h, hlens, 0.0)

    # search params
    batch = len(hlens)
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight
    att_weight = 1.0 - ctc_weight

    n_bb = batch * beam
    n_bo = beam * self.odim
    n_bbo = n_bb * self.odim
    # offsets used to address rows of the flattened (batch, beam, odim) layouts
    pad_b = to_device(self, torch.LongTensor([i * beam for i in six.moves.range(batch)]).view(-1, 1))
    pad_bo = to_device(self, torch.LongTensor([i * n_bo for i in six.moves.range(batch)]).view(-1, 1))
    pad_o = to_device(self, torch.LongTensor([i * self.odim for i in six.moves.range(n_bb)]).view(-1, 1))

    max_hlen = int(max(hlens))
    if recog_args.maxlenratio == 0:
        maxlen = max_hlen
    else:
        maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
    minlen = int(recog_args.minlenratio * max_hlen)
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialization
    c_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
    z_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
    c_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
    z_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
    vscores = to_device(self, torch.zeros(batch, beam))

    a_prev = None
    rnnlm_prev = None

    self.att[att_idx].reset()  # reset pre-computation of h

    yseq = [[self.sos] for _ in six.moves.range(n_bb)]
    accum_odim_ids = [self.sos for _ in six.moves.range(n_bb)]
    stop_search = [False for _ in six.moves.range(batch)]
    ended_hyps = [[] for _ in range(batch)]

    # replicate encoder output / lengths for every beam of every utterance
    exp_hlens = hlens.repeat(beam).view(beam, batch).transpose(0, 1).contiguous()
    exp_hlens = exp_hlens.view(-1).tolist()
    exp_h = h.unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
    exp_h = exp_h.view(n_bb, h.size()[1], h.size()[2])

    if lpz is not None:
        device_id = torch.cuda.device_of(next(self.parameters()).data).idx
        ctc_prefix_score = CTCPrefixScoreTH(lpz, 0, self.eos, beam, exp_hlens, device_id)
        ctc_states_prev = ctc_prefix_score.initial_state()
        ctc_scores_prev = to_device(self, torch.zeros(batch, n_bo))

    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        vy = to_device(self, torch.LongTensor(get_last_yseq(yseq)))
        ey = self.dropout_emb(self.embed(vy))
        att_c, att_w = self.att[att_idx](exp_h, exp_hlens, self.dropout_dec[0](z_prev[0]), a_prev)
        ey = torch.cat((ey, att_c), dim=1)

        # attention decoder
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
        local_scores = att_weight * F.log_softmax(
            self.output(self.dropout_dec[-1](z_list[-1])), dim=1)

        # rnnlm
        if rnnlm:
            rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_prev, vy, n_bb)
            local_scores = local_scores + recog_args.lm_weight * local_lm_scores
        local_scores = local_scores.view(batch, n_bo)

        # ctc
        if lpz is not None:
            ctc_scores, ctc_states = ctc_prefix_score(yseq, ctc_states_prev, accum_odim_ids)
            ctc_scores = ctc_scores.view(batch, n_bo)
            local_scores = local_scores + ctc_weight * (ctc_scores - ctc_scores_prev)
        local_scores = local_scores.view(batch, beam, self.odim)

        if i == 0:
            # all beams start identical: keep only the first one alive
            local_scores[:, 1:, :] = self.logzero
        local_best_scores, local_best_odims = torch.topk(
            local_scores.view(batch, beam, self.odim), beam, 2)

        # local pruning (via numpy): keep the top-`beam` tokens of every beam and
        # set all other token scores to logzero
        local_scores = np.full((n_bbo,), self.logzero)
        _best_odims = local_best_odims.view(n_bb, beam) + pad_o
        _best_odims = _best_odims.view(-1).cpu().numpy()
        _best_score = local_best_scores.view(-1).cpu().detach().numpy()
        local_scores[_best_odims] = _best_score
        local_scores = to_device(
            self, torch.from_numpy(local_scores).float()).view(batch, beam, self.odim)

        eos_vscores = local_scores[:, :, self.eos] + vscores
        vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
        vscores[:, :, self.eos] = self.logzero
        vscores = (vscores + local_scores).view(batch, n_bo)

        # global pruning
        accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
        accum_odim_ids = torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
        accum_padded_odim_ids = (torch.fmod(accum_best_ids, n_bo) + pad_bo).view(-1).data.cpu().tolist()
        # Integer (floor) division: torch.div performs true division on recent
        # PyTorch, which would yield float beam ids and break the indexing below.
        accum_padded_beam_ids = (accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist()

        y_prev = yseq[:][:]
        yseq = index_select_list(yseq, accum_padded_beam_ids)
        yseq = append_ids(yseq, accum_odim_ids)
        vscores = accum_best_scores
        vidx = to_device(self, torch.LongTensor(accum_padded_beam_ids))

        # Dispatch on the attention weights just produced (att_w), not on a_prev:
        # a_prev is None at step 0, which used to send list-valued (multi-head)
        # weights into the tensor branch and crash on `.view`.
        if isinstance(att_w, torch.Tensor):
            a_prev = torch.index_select(att_w.view(n_bb, -1), 0, vidx)
        elif isinstance(att_w, list):
            # adapt for multi-head attention
            a_prev = [torch.index_select(att_w_one.view(n_bb, -1), 0, vidx)
                      for att_w_one in att_w]
        else:
            # location_recurrent attention returns a tuple
            a_prev_ = torch.index_select(att_w[0].view(n_bb, -1), 0, vidx)
            h_prev_ = torch.index_select(att_w[1][0].view(n_bb, -1), 0, vidx)
            c_prev_ = torch.index_select(att_w[1][1].view(n_bb, -1), 0, vidx)
            a_prev = (a_prev_, (h_prev_, c_prev_))
        z_prev = [torch.index_select(z_list[li].view(n_bb, -1), 0, vidx)
                  for li in range(self.dlayers)]
        c_prev = [torch.index_select(c_list[li].view(n_bb, -1), 0, vidx)
                  for li in range(self.dlayers)]

        if rnnlm:
            rnnlm_prev = index_select_lm_state(rnnlm_state, 0, vidx)
        if lpz is not None:
            ctc_vidx = to_device(self, torch.LongTensor(accum_padded_odim_ids))
            ctc_scores_prev = torch.index_select(ctc_scores.view(-1), 0, ctc_vidx)
            ctc_scores_prev = ctc_scores_prev.view(-1, 1).repeat(1, self.odim).view(batch, n_bo)

            ctc_states = torch.transpose(ctc_states, 1, 3).contiguous()
            ctc_states = ctc_states.view(n_bbo, 2, -1)
            ctc_states_prev = torch.index_select(ctc_states, 0, ctc_vidx).view(n_bb, 2, -1)
            ctc_states_prev = torch.transpose(ctc_states_prev, 1, 2)

        # pick ended hyps
        if i > minlen:
            k = 0  # flat (batch*beam) index into y_prev
            penalty_i = (i + 1) * penalty
            thr = accum_best_scores[:, -1]
            for samp_i in six.moves.range(batch):
                if stop_search[samp_i]:
                    k = k + beam
                    continue
                for beam_j in six.moves.range(beam):
                    # an <eos> extension is kept only if it would survive pruning
                    if eos_vscores[samp_i, beam_j] > thr[samp_i]:
                        yk = y_prev[k][:]
                        yk.append(self.eos)
                        if len(yk) < hlens[samp_i]:
                            _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                            if normalize_score:
                                _vscore = _vscore / len(yk)
                            _score = _vscore.data.cpu().numpy()
                            ended_hyps[samp_i].append(
                                {'yseq': yk, 'vscore': _vscore, 'score': _score})
                    k = k + 1

        # end detection
        stop_search = [stop_search[samp_i] or end_detect(ended_hyps[samp_i], i)
                       for samp_i in six.moves.range(batch)]
        if stop_search and all(stop_search):
            break

    torch.cuda.empty_cache()

    dummy_hyps = [{'yseq': [self.sos, self.eos], 'score': np.array([-float('inf')])}]
    ended_hyps = [ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
                  for samp_i in six.moves.range(batch)]
    nbest_hyps = [
        sorted(ended_hyps[samp_i], key=lambda x: x['score'],
               reverse=True)[:min(len(ended_hyps[samp_i]), recog_args.nbest)]
        for samp_i in six.moves.range(batch)
    ]
    return nbest_hyps
def infer(x, encoder_rt, ctc_softmax_rt, decoder_fos_rt, *, ctc_weight=0.5,
          beam_size=1, penalty=0.0, maxlenratio=0.0, minlenratio=0.0, nbest=1,
          sos=7442, eos=7442, ctc_beam=1):
    """Decode one utterance with joint CTC/attention beam search on ONNX Runtime.

    Args:
        x (torch.Tensor): Encoder input, fed to ``encoder_rt`` as ``{"x": x.numpy()}``.
        encoder_rt: ONNX Runtime session for the encoder.
        ctc_softmax_rt: Session mapping ``enc_output`` to CTC log-probabilities.
        decoder_fos_rt: Session computing one decoder step from ``ys``,
            ``ys_mask`` and ``enc_output``.
        ctc_weight (float): Interpolation weight of the CTC prefix score.
        beam_size (int): Number of hypotheses kept per step.
        penalty (float): Per-position insertion penalty added to ended hyps.
        maxlenratio (float): 0 uses the encoder length as max output length
            (with end detection). Otherwise the max length is
            ``max(1, int(maxlenratio * T))``.
        minlenratio (float): Hyps must be longer than ``int(minlenratio * T)``
            to be kept when they end.
        nbest (int): Number of ended hypotheses considered for the result.
        sos, eos (int): Start / end token ids. The blank id is 0.
        ctc_beam (int): Attention candidates rescored by CTC per hypothesis.
            It is raised to at least ``beam_size`` so ``topk`` over the joint
            scores is always valid.

    Returns:
        torch.Tensor: Token ids of the best hypothesis, including the leading
        ``sos`` and the trailing ``eos``.

    The defaults reproduce the previously hard-coded settings exactly.
    """
    ort_inputs = {"x": x.numpy()}
    enc_output = encoder_rt.run(None, ort_inputs)
    enc_output = torch.tensor(enc_output)
    enc_output = enc_output.squeeze(0)

    ort_inputs = {"enc_output": enc_output.numpy()}
    lpz = ctc_softmax_rt.run(None, ort_inputs)
    lpz = torch.tensor(lpz)
    lpz = lpz.squeeze(0).squeeze(0)

    h = enc_output.squeeze(0)
    # The decoder input never changes inside the search loop.
    enc_output_np = enc_output.numpy()

    # length bounds
    if maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        maxlen = max(1, int(maxlenratio * h.shape[0]))
    minlen = int(minlenratio * h.shape[0])
    # number of attention candidates passed to the CTC prefix scorer
    n_ctc_cand = min(lpz.shape[-1], max(ctc_beam, beam_size))

    # initialize hypothesis
    hyp = {'score': 0.0, 'yseq': [sos]}
    ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, eos, numpy)
    hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
    hyp['ctc_score_prev'] = 0.0
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        hyps_best_kept = []
        for hyp in hyps:
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp['yseq']).unsqueeze(0)
            ort_inputs = {"ys": ys.numpy(), "ys_mask": ys_mask.numpy(),
                          "enc_output": enc_output_np}
            local_att_scores = torch.tensor(decoder_fos_rt.run(None, ort_inputs)[0])

            # pre-pruning based on attention scores, then CTC rescoring
            _, local_best_ids = torch.topk(local_att_scores, n_ctc_cand, dim=1)
            ctc_scores, ctc_states = ctc_prefix_score(
                hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
            local_scores = \
                (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
            local_best_scores, joint_best_ids = torch.topk(local_scores, beam_size, dim=1)
            local_best_ids = local_best_ids[:, joint_best_ids[0]]

            for j in six.moves.range(beam_size):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam_size]

        # sort and get nbest
        hyps = hyps_best_kept

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            for hyp in hyps:
                hyp['yseq'].append(eos)

        # move ended hypotheses to the final list
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and maxlenratio == 0.0:
            break

        hyps = remained_hyps
        if not hyps:
            break

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), nbest)]
    return torch.tensor(nbest_hyps[0]['yseq'])
def forward(self, x: torch.Tensor, maxlenratio: float = 0.0,
            minlenratio: float = 0.0) -> List[Hypothesis]:
    """Run label-synchronous beam search over an encoded utterance.

    Args:
        x (torch.Tensor): Encoded speech feature (T, D)
        maxlenratio (float): Controls the maximum output length.
            ``0.0`` (default): use T and rely on end detection.
            ``< 0.0``: ``-int(maxlenratio)`` is used as a fixed maximum length.
            ``> 0.0``: ``max(1, int(maxlenratio * T))``.
        minlenratio (float): Input length ratio to obtain min output length.

    Returns:
        list[Hypothesis]: Ended hypotheses sorted by descending score. When none
        ended, the search is repeated with ``minlenratio`` reduced by 0.1, and an
        empty list is returned once ``minlenratio`` is below 0.1.
    """
    n_frames = x.shape[0]
    if maxlenratio == 0:
        maxlen = n_frames
    elif maxlenratio < 0:
        maxlen = -1 * int(maxlenratio)
    else:
        maxlen = max(1, int(maxlenratio * x.size(0)))
    minlen = int(minlenratio * x.size(0))
    for message in ("decoder input length: " + str(n_frames),
                    "max output length: " + str(maxlen),
                    "min output length: " + str(minlen)):
        logging.info(message)

    # main loop of prefix search
    running_hyps = self.init_hyp(x)
    ended_hyps = []
    for step in range(maxlen):
        logging.debug("position " + str(step))
        best = self.search(running_hyps, x)
        running_hyps = self.post_process(step, maxlen, maxlenratio, best, ended_hyps)

        if maxlenratio == 0.0 and end_detect(
                [ended.asdict() for ended in ended_hyps], step):
            logging.info(f"end detected at {step}")
            break
        if len(running_hyps) == 0:
            logging.info("no hypothesis. Finish decoding.")
            break
        logging.debug(f"remained hypotheses: {len(running_hyps)}")

    nbest_hyps = sorted(ended_hyps, key=lambda hyp: hyp.score, reverse=True)

    if not nbest_hyps:
        logging.warning("there is no N-best results, perform recognition "
                        "again with smaller minlenratio.")
        if minlenratio < 0.1:
            return []
        return self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))

    # report the best result
    top = nbest_hyps[0]
    for name, value in top.scores.items():
        weight = self.weights[name]
        logging.info(f"{value:6.2f} * {weight:3} = {value * weight:6.2f} for {name}")
    logging.info(f"total log probability: {top.score:.2f}")
    logging.info(f"normalized log probability: {top.score / len(top.yseq):.2f}")
    logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
    if self.token_list is not None:
        tokens = "".join(self.token_list[tok] for tok in top.yseq[1:-1])
        logging.info("best hypo: " + tokens + "\n")
    return nbest_hyps
def forward(self, x: torch.Tensor, maxlenratio: float = 0.0,
            minlenratio: float = 0.0) -> List[Hypothesis]:
    """Perform blockwise beam search that simulates streaming decoding.

    When ``self.block_size``, ``self.hop_size`` and ``self.look_ahead`` are all
    truthy, only the prefix ``x.narrow(0, 0, cur_end_frame)`` is handed to
    ``init_hyp`` / ``extend``. ``cur_end_frame`` starts at
    ``block_size - look_ahead`` and advances by ``hop_size`` (or jumps to the
    full length) whenever a block boundary is triggered. Otherwise the whole
    input is used from the start.

    Args:
        x (torch.Tensor): Encoded speech feature (T, D)
        maxlenratio (float): Input length ratio to obtain max output length.
            If maxlenratio=0.0 (default), it uses a end-detect function
            to automatically find maximum hypothesis lengths
        minlenratio (float): Input length ratio to obtain min output length.

    Returns:
        list[Hypothesis]: N-best decoding results, sorted by descending score.
        If no hypothesis ended, the search is retried with ``minlenratio``
        lowered by 0.1; ``[]`` is returned once ``minlenratio < 0.1``.

    """
    self.conservative = True # always true

    # first block end (in frames); without block settings decode the full input
    if self.block_size and self.hop_size and self.look_ahead:
        cur_end_frame = int(self.block_size - self.look_ahead)
    else:
        cur_end_frame = x.shape[0]
    process_idx = 0  # current output position (may step back on block change)
    if cur_end_frame < x.shape[0]:
        h = x.narrow(0, 0, cur_end_frame)
    else:
        h = x

    # set length bounds
    if maxlenratio == 0:
        maxlen = x.shape[0]
    else:
        maxlen = max(1, int(maxlenratio * x.size(0)))
    minlen = int(minlenratio * x.size(0))
    logging.info("decoder input length: " + str(x.shape[0]))
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # main loop of prefix search
    running_hyps = self.init_hyp(h)
    prev_hyps = []  # hypotheses before the last post_process, used for rollback
    ended_hyps = []
    prev_repeat = False  # set once a repeated token triggered a block move
    continue_decode = True

    # outer loop: one iteration per exposed block
    while continue_decode:
        move_to_next_block = False
        if cur_end_frame < x.shape[0]:
            h = x.narrow(0, 0, cur_end_frame)
        else:
            h = x

        # extend states for ctc
        self.extend(h, running_hyps)

        # inner loop: label-synchronous steps within the current block
        while process_idx < maxlen:
            logging.debug("position " + str(process_idx))
            # NOTE(review): search receives the full x, not the block prefix h;
            # presumably the block limit is enforced via the states set up by
            # self.extend(h, ...) above -- confirm.
            best = self.search(running_hyps, x)

            if process_idx == maxlen - 1:
                # end decoding
                # NOTE(review): post_process is called again with the same
                # `best` further below on this step -- confirm this is intended.
                running_hyps = self.post_process(process_idx, maxlen, maxlenratio, best, ended_hyps)
            n_batch = best.yseq.shape[0]
            local_ended_hyps = []
            # hypotheses whose last emitted token is <eos>
            is_local_eos = (best.yseq[torch.arange(n_batch), best.length - 1] == self.eos)
            for i in range(is_local_eos.shape[0]):
                if is_local_eos[i]:
                    hyp = self._select(best, i)
                    local_ended_hyps.append(hyp)
                # NOTE(tsunoo): check repetitions here
                # This is a implicit implementation of
                # Eq (11) in https://arxiv.org/abs/2006.14941
                # A flag prev_repeat is used instead of using set
                elif (not prev_repeat and best.yseq[i, -1] in best.yseq[i, :-1]
                      and cur_end_frame < x.shape[0]):
                    move_to_next_block = True
                    prev_repeat = True
            if maxlenratio == 0.0 and end_detect(
                    [lh.asdict() for lh in local_ended_hyps], process_idx):
                logging.info(f"end detected at {process_idx}")
                continue_decode = False
                break
            # an <eos> before the last block also means more input is needed
            if len(local_ended_hyps) > 0 and cur_end_frame < x.shape[0]:
                move_to_next_block = True

            if move_to_next_block:
                # advance by hop_size unless that would reach the end, in which
                # case expose the whole input
                if (self.hop_size and cur_end_frame + int(self.hop_size)
                        + int(self.look_ahead) < x.shape[0]):
                    cur_end_frame += int(self.hop_size)
                else:
                    cur_end_frame = x.shape[0]
                logging.debug("Going to next block: %d", cur_end_frame)
                # conservative mode: undo the last step and restart it with
                # the larger block
                if process_idx > 1 and len(
                        prev_hyps) > 0 and self.conservative:
                    running_hyps = prev_hyps
                    process_idx -= 1
                    prev_hyps = []
                break

            prev_repeat = False
            prev_hyps = running_hyps
            running_hyps = self.post_process(process_idx, maxlen, maxlenratio, best, ended_hyps)

            # locally ended hyps are only kept once the full input is visible
            if cur_end_frame >= x.shape[0]:
                for hyp in local_ended_hyps:
                    ended_hyps.append(hyp)

            if len(running_hyps) == 0:
                logging.info("no hypothesis. Finish decoding.")
                continue_decode = False
                break
            else:
                logging.debug(f"remained hypotheses: {len(running_hyps)}")
            # increment number
            process_idx += 1

    nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
    # check the number of hypotheses reaching to eos
    if len(nbest_hyps) == 0:
        logging.warning("there is no N-best results, perform recognition "
                        "again with smaller minlenratio.")
        return ([] if minlenratio < 0.1 else self.forward(
            x, maxlenratio, max(0.0, minlenratio - 0.1)))

    # report the best result
    best = nbest_hyps[0]
    for k, v in best.scores.items():
        logging.info(
            f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
        )
    logging.info(f"total log probability: {best.score:.2f}")
    logging.info(
        f"normalized log probability: {best.score / len(best.yseq):.2f}")
    logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
    if self.token_list is not None:
        logging.info("best hypo: " +
                     "".join([self.token_list[x] for x in best.yseq[1:-1]]) +
                     "\n")
    return nbest_hyps
def translate(self, x, trans_args, char_list=None, rnnlm=None, use_jit=False):
    """Translate source text.

    :param list x: input source text feature (T,); ``x[0]`` holds the token ids
        (with a leading language tag when ``self.multilingual``)
    :param Namespace trans_args: argment Namespace contraining options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :param bool use_jit: trace the one-step decoder with ``torch.jit.trace``
    :return: N-best decoding results (may be empty if no hypothesis ends even
        with ``minlenratio`` relaxed to 0)
    :rtype: list
    """
    import six
    from espnet.nets.e2e_asr_common import end_detect

    self.eval()  # NOTE: this is important because self.encode() is not used
    assert isinstance(x, list)

    # make a utt list (1) to use the same interface for encoder.
    # `x` itself is left untouched so that the minlenratio retry below can call
    # translate() again with the original list input.
    if self.multilingual:
        src_ids = to_device(
            self, torch.from_numpy(np.fromiter(map(int, x[0][1:]), dtype=np.int64)))
    else:
        src_ids = to_device(
            self, torch.from_numpy(np.fromiter(map(int, x[0]), dtype=np.int64)))

    xs_pad = src_ids.unsqueeze(0)
    tgt_lang = None
    if trans_args.tgt_lang:
        tgt_lang = char_list.index(trans_args.tgt_lang)
    xs_pad, _ = self.target_forcing(xs_pad, tgt_lang=tgt_lang)
    enc_output, _ = self.encoder(xs_pad, None)
    h = enc_output.squeeze(0)

    logging.info('input lengths: ' + str(h.size(0)))
    # search parms
    beam = trans_args.beam_size
    penalty = trans_args.penalty

    # preprare sos
    y = self.sos
    vy = h.new_zeros(1).long()  # last token, fed to the rnnlm

    if trans_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(trans_args.maxlenratio * h.size(0)))
    minlen = int(trans_args.minlenratio * h.size(0))
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
    else:
        hyp = {'score': 0.0, 'yseq': [y]}
    hyps = [hyp]
    ended_hyps = []

    traced_decoder = None
    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp['yseq'][i]

            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp['yseq']).unsqueeze(0)
            # FIXME: jit does not match non-jit result
            if use_jit:
                if traced_decoder is None:
                    traced_decoder = torch.jit.trace(
                        self.decoder.forward_one_step, (ys, ys_mask, enc_output))
                local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
            else:
                local_att_scores = self.decoder.forward_one_step(
                    ys, ys_mask, enc_output)[0]

            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                local_scores = local_att_scores + trans_args.lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp['rnnlm_prev'] = rnnlm_state
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypothes: ' + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last postion in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.eos)

        # add ended hypotheses to a final list and remove them from the current
        # hypotheses (this can leave fewer than `beam` running hypotheses)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp['score'] += trans_args.lm_weight * rnnlm.final(
                            hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remeined hypothes: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

        logging.debug('number of ended hypothes: ' + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), trans_args.nbest)]

    # check number of hypotheses
    if len(nbest_hyps) == 0:
        if trans_args.minlenratio <= 0.0:
            # nothing left to relax: stop instead of recursing forever
            logging.warning('there is no N-best results and minlenratio is already 0.')
            return nbest_hyps
        logging.warning(
            'there is no N-best results, perform recognition again with smaller minlenratio.'
        )
        # should copy because Namespace will be overwritten globally
        trans_args = Namespace(**vars(trans_args))
        trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1)
        # retry with the ORIGINAL list input and the caller's jit choice
        return self.translate(x, trans_args, char_list, rnnlm, use_jit=use_jit)

    logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
    logging.info('normalized log probability: '
                 + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
    return nbest_hyps
def forward(self, x):
    """Recognize input speech with exported ONNX encoder / CTC / decoder graphs.

    On every call this loads ``encoder.onnx``, ``ctc_softmax.onnx`` and
    ``decoder_fos.onnx`` (paths relative to the working directory) with
    ``onnxruntime`` and runs joint CTC/attention search using fixed settings:
    ctc_weight 0.5, beam size 1, one CTC candidate, no length penalty,
    end detection enabled, and sos = eos = 7442 (blank id 0).

    :param torch.Tensor x: encoder input, passed to the encoder session as
        ``{"x": x.numpy()}`` (expected layout depends on how encoder.onnx was
        exported -- TODO confirm)
    :return: token ids of the best hypothesis, including the leading <sos>
        and the trailing <eos>
    :rtype: torch.Tensor
    """
    ctc_weight = 0.5
    beam_size = 1
    penalty = 0.0
    maxlenratio = 0.0
    nbest = 1
    sos = eos = 7442
    import onnxruntime

    encoder_rt = onnxruntime.InferenceSession("encoder.onnx")
    ort_inputs = {"x": x.numpy()}
    enc_output = encoder_rt.run(None, ort_inputs)
    enc_output = torch.tensor(enc_output)
    enc_output = enc_output.squeeze(0)

    ctc_softmax_rt = onnxruntime.InferenceSession("ctc_softmax.onnx")
    ort_inputs = {"enc_output": enc_output.numpy()}
    lpz = ctc_softmax_rt.run(None, ort_inputs)
    lpz = torch.tensor(lpz)
    lpz = lpz.squeeze(0).squeeze(0)

    h = enc_output.squeeze(0)

    # prepare sos
    y = sos
    maxlen = h.shape[0]
    minlen = 0.0

    # initialize hypothesis
    hyp = {'score': 0.0, 'yseq': [y]}
    ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, eos, numpy)
    hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
    hyp['ctc_score_prev'] = 0.0
    # number of attention candidates passed to the CTC prefix scorer
    ctc_beam = 1
    hyps = [hyp]
    ended_hyps = []
    decoder_fos_rt = onnxruntime.InferenceSession("decoder_fos.onnx")
    # the decoder's encoder input never changes inside the search loop
    enc_output_np = enc_output.numpy()
    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))
        hyps_best_kept = []
        for hyp in hyps:
            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp['yseq']).unsqueeze(0)
            ort_inputs = {
                "ys": ys.numpy(),
                "ys_mask": ys_mask.numpy(),
                "enc_output": enc_output_np
            }
            local_att_scores = decoder_fos_rt.run(None, ort_inputs)
            local_att_scores = torch.tensor(local_att_scores[0])

            # pre-pruning based on attention scores, then CTC rescoring
            _, local_best_ids = torch.topk(local_att_scores, ctc_beam, dim=1)
            ctc_scores, ctc_states = ctc_prefix_score(
                hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
            local_scores = \
                (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
            local_best_scores, joint_best_ids = torch.topk(local_scores, beam_size, dim=1)
            local_best_ids = local_best_ids[:, joint_best_ids[0]]

            for j in six.moves.range(beam_size):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x['score'],
                                    reverse=True)[:beam_size]

        # sort and get nbest
        hyps = hyps_best_kept

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            for hyp in hyps:
                hyp['yseq'].append(eos)

        # move ended hypotheses to the final list
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and maxlenratio == 0.0:
            break

        hyps = remained_hyps
        if not hyps:
            break

    nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'],
                        reverse=True)[:min(len(ended_hyps), nbest)]
    return torch.tensor(nbest_hyps[0]['yseq'])
def recognize_beam(self, h, lpz, recog_args, char_list=None, rnnlm=None):
    """Joint attention/CTC beam search for a single utterance (Chainer backend).

    :param h: encoder output of one utterance; ``h.shape[0]`` is taken as the
        input length (assumed (T, adim) -- TODO confirm against caller)
    :param lpz: CTC log-posteriors handed to ``CTCPrefixScore``, or None to
        disable CTC prefix scoring
    :param Namespace recog_args: decoding options; reads ``beam_size``,
        ``penalty``, ``ctc_weight``, ``maxlenratio``, ``minlenratio`` and,
        when ``rnnlm`` is given, ``lm_weight``
    :param list char_list: token strings, used only for debug logging
    :param rnnlm: optional LM exposing ``predict(state, token)`` and
        ``final(state)``
    :return: every ended hypothesis sorted by descending ``score``; each dict
        has ``score`` and ``yseq`` (token ids, still starting with ``<sos>``).
        NOTE: the list is not truncated to ``recog_args.nbest``.
    :rtype: list[dict]
    """
    logging.info('input lengths: ' + str(h.shape[0]))

    # initialization
    xp = self.xp
    h_mask = xp.ones((1, h.shape[0]))
    batch = 1

    # search parameters
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    y = self.sos
    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        maxlen = max(1, int(recog_args.maxlenratio * h.shape[0]))
    minlen = int(recog_args.minlenratio * h.shape[0])
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
    else:
        hyp = {'score': 0.0, 'yseq': [y]}
    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz, 0, self.eos, self.xp)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    # Loop-invariant: import once instead of on every decoding step.
    from espnet.nets.e2e_asr_common import end_detect

    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            # Re-run the decoder on the whole prefix (no incremental cache).
            ys = F.expand_dims(xp.array(hyp['yseq']), axis=0).data
            yy_mask = self.make_attention_mask(ys, ys)
            yy_mask *= self.make_history_mask(ys)

            xy_mask = self.make_attention_mask(ys, h_mask)
            out = self.decoder(ys, yy_mask, h, xy_mask).reshape(batch, -1, self.odim)

            # get nbest local scores and their ids (last decoder position only)
            local_att_scores = F.log_softmax(out[:, -1], axis=-1).data
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(
                    hyp['rnnlm_prev'], hyp['yseq'][i])
                local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            if lpz is not None:
                # Prune to ctc_beam candidates, then rescore them jointly.
                local_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][:ctc_beam]
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp['yseq'], local_best_ids, hyp['ctc_state_prev'])
                local_scores = \
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids] \
                    + ctc_weight * (ctc_scores - hyp['ctc_score_prev'])
                if rnnlm:
                    local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids]
                joint_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                local_best_scores = local_scores[:, joint_best_ids]
                local_best_ids = local_best_ids[joint_best_ids]
            else:
                local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                local_best_scores = local_scores[:, local_best_ids]

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[j])
                if rnnlm:
                    new_hyp['rnnlm_prev'] = rnnlm_state
                if lpz is not None:
                    new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[j]]
                    new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[j]]
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypothesis: ' + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]) +
                ' score: ' + str(hyps[0]['score']))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last postion in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.eos)

        # move ended hypotheses to the final list, keep the rest running
        # (this can leave fewer than `beam` running hypotheses)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store sequences longer than minlen; also add the penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp['score'] += recog_args.lm_weight * rnnlm.final(
                            hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remained hypothes: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

        logging.debug('number of ended hypothes: ' + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'], reverse=True)  # [:min(len(ended_hyps), recog_args.nbest)]
    logging.debug(nbest_hyps)
    # check number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            'there is no N-best results, perform recognition again with smaller minlenratio.'
        )
        # copy because the Namespace object is shared with the caller
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

    logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
    logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
    # NOTE: yseq still starts with <sos>; callers strip it if needed
    return nbest_hyps
def translate(
    self,
    x,
    trans_args,
    char_list=None,
    rnnlm=None,
    use_jit=False,
):
    """Translate input speech with attention (+ optional LM) beam search.

    :param ndnarray x: input acoustic feature (B, T, D) or (T, D)
    :param Namespace trans_args: decoding options; reads ``beam_size``,
        ``penalty``, ``maxlenratio``, ``minlenratio``, ``nbest``, optional
        ``tgt_lang`` and, when ``rnnlm`` is given, ``lm_weight``
    :param list char_list: list of characters; needed to resolve
        ``tgt_lang`` when ``self.replace_sos`` is set, otherwise only logging
    :param torch.nn.Module rnnlm: language model module
    :param bool use_jit: trace ``decoder.forward_one_step`` with
        ``torch.jit.trace`` on first use and reuse the trace
    :return: N-best decoding results (at most ``trans_args.nbest`` dicts
        with ``score`` and ``yseq``, sorted by descending score)
    :rtype: list
    """
    # prepare sos: the target-language tag replaces <sos> only when both a
    # tgt_lang is requested and the model was trained with replace_sos;
    # otherwise fall back to the plain <sos> id (previously left unbound).
    y = self.sos
    if getattr(trans_args, "tgt_lang", False) and self.replace_sos:
        y = char_list.index(trans_args.tgt_lang)
    logging.info("<sos> index: " + str(y))
    if char_list is not None:
        logging.info("<sos> mark: " + char_list[y])

    enc_output = self.encode(x).unsqueeze(0)
    h = enc_output.squeeze(0)

    logging.info("input lengths: " + str(h.size(0)))
    # search parameters
    beam = trans_args.beam_size
    penalty = trans_args.penalty

    # single-token buffer fed to the LM
    vy = h.new_zeros(1).long()

    if trans_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(trans_args.maxlenratio * h.size(0)))
    minlen = int(trans_args.minlenratio * h.size(0))
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None}
    else:
        hyp = {"score": 0.0, "yseq": [y]}
    hyps = [hyp]
    ended_hyps = []

    import six

    traced_decoder = None
    for i in six.moves.range(maxlen):
        logging.debug("position " + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp["yseq"][i]

            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp["yseq"]).unsqueeze(0)
            # FIXME: jit does not match non-jit result
            if use_jit:
                if traced_decoder is None:
                    traced_decoder = torch.jit.trace(
                        self.decoder.forward_one_step, (ys, ys_mask, enc_output))
                local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
            else:
                local_att_scores = self.decoder.forward_one_step(
                    ys, ys_mask, enc_output)[0]

            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                local_scores = (local_att_scores +
                                trans_args.lm_weight * local_lm_scores)
            else:
                local_scores = local_att_scores

            local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j])
                new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                new_hyp["yseq"][:len(hyp["yseq"])] = hyp["yseq"]
                new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp["rnnlm_prev"] = rnnlm_state
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x["score"], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug("number of pruned hypothes: " + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                "best hypo: " + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last postion in the loop")
            for hyp in hyps:
                hyp["yseq"].append(self.eos)

        # move ended hypotheses to the final list, keep the rest running
        # (this can leave fewer than `beam` running hypotheses)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == self.eos:
                # only store sequences longer than minlen; also add the penalty
                if len(hyp["yseq"]) > minlen:
                    hyp["score"] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp["score"] += trans_args.lm_weight * rnnlm.final(
                            hyp["rnnlm_prev"])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0:
            logging.info("end detected at %d", i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug("remeined hypothes: " + str(len(hyps)))
        else:
            logging.info("no hypothesis. Finish decoding.")
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]))
        logging.debug("number of ended hypothes: " + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x["score"],
        reverse=True)[:min(len(ended_hyps), trans_args.nbest)]

    # check number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning("there is no N-best results, perform translation "
                        "again with smaller minlenratio.")
        # copy because the Namespace object is shared with the caller
        trans_args = Namespace(**vars(trans_args))
        trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1)
        # keep the caller's jit setting on the retry
        return self.translate(x, trans_args, char_list, rnnlm, use_jit=use_jit)

    logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
    logging.info("normalized log probability: " +
                 str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"])))
    return nbest_hyps
def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
    """Recognize input speech with attention/CTC (+ optional LM) beam search.

    :param ndnarray x: input acoustic feature (B, T, D) or (T, D)
    :param Namespace recog_args: decoding options; reads ``beam_size``,
        ``penalty``, ``ctc_weight``, ``maxlenratio``, ``minlenratio``,
        ``nbest`` and, when ``rnnlm`` is given, ``lm_weight``.
        NOTE: when ``self.mtlalpha == 1.0`` this sets
        ``recog_args.ctc_weight = 1.0`` on the caller's object.
    :param list char_list: list of characters, used only for debug logging
    :param torch.nn.Module rnnlm: language model module
    :param bool use_jit: trace ``decoder.forward_one_step`` with
        ``torch.jit.trace`` on first use and reuse the trace
    :return: N-best decoding results (at most ``recog_args.nbest`` dicts
        with ``score`` and ``yseq``, sorted by descending score)
    :rtype: list
    :raises NotImplementedError: for pure CTC decoding with ``beam_size > 1``
    """
    enc_output = self.encode(x).unsqueeze(0)
    if self.mtlalpha == 1.0:
        recog_args.ctc_weight = 1.0
        logging.info("Set to pure CTC decoding mode.")

    if self.mtlalpha > 0 and recog_args.ctc_weight == 1.0:
        # Pure CTC: greedy best path, collapse repeats, then drop blanks.
        from itertools import groupby

        lpz = self.ctc.argmax(enc_output)
        collapsed_indices = [key for key, _ in groupby(lpz[0])]
        hyp = [idx for idx in collapsed_indices if idx != self.blank]
        nbest_hyps = [{"score": 0.0, "yseq": [self.sos] + hyp}]
        if recog_args.beam_size > 1:
            raise NotImplementedError("Pure CTC beam search is not implemented.")
        # TODO(hirofumi0810): Implement beam search
        return nbest_hyps
    elif self.mtlalpha > 0 and recog_args.ctc_weight > 0.0:
        lpz = self.ctc.log_softmax(enc_output)
        lpz = lpz.squeeze(0)
    else:
        lpz = None

    h = enc_output.squeeze(0)

    logging.info("input lengths: " + str(h.size(0)))
    # search parameters
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    y = self.sos
    # single-token buffer fed to the LM
    vy = h.new_zeros(1).long()

    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
    minlen = int(recog_args.minlenratio * h.size(0))
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None}
    else:
        hyp = {"score": 0.0, "yseq": [y]}
    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
        hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
        hyp["ctc_score_prev"] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    import six

    traced_decoder = None
    for i in six.moves.range(maxlen):
        logging.debug("position " + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp["yseq"][i]

            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0)
            ys = torch.tensor(hyp["yseq"]).unsqueeze(0)
            # FIXME: jit does not match non-jit result
            if use_jit:
                if traced_decoder is None:
                    traced_decoder = torch.jit.trace(
                        self.decoder.forward_one_step, (ys, ys_mask, enc_output)
                    )
                local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
            else:
                local_att_scores = self.decoder.forward_one_step(
                    ys, ys_mask, enc_output
                )[0]

            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                local_scores = (
                    local_att_scores + recog_args.lm_weight * local_lm_scores
                )
            else:
                local_scores = local_att_scores

            if lpz is not None:
                # Prune to ctc_beam candidates, then rescore them jointly.
                local_best_scores, local_best_ids = torch.topk(
                    local_att_scores, ctc_beam, dim=1
                )
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"]
                )
                local_scores = (1.0 - ctc_weight) * local_att_scores[
                    :, local_best_ids[0]
                ] + ctc_weight * torch.from_numpy(
                    ctc_scores - hyp["ctc_score_prev"]
                )
                if rnnlm:
                    local_scores += (
                        recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    )
                local_best_scores, joint_best_ids = torch.topk(
                    local_scores, beam, dim=1
                )
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = torch.topk(
                    local_scores, beam, dim=1
                )

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j])
                new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp["rnnlm_prev"] = rnnlm_state
                if lpz is not None:
                    new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[0, j]]
                    new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[0, j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x["score"], reverse=True
            )[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug("number of pruned hypothes: " + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
            )

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last postion in the loop")
            for hyp in hyps:
                hyp["yseq"].append(self.eos)

        # move ended hypotheses to the final list, keep the rest running
        # (this can leave fewer than `beam` running hypotheses)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == self.eos:
                # only store sequences longer than minlen; also add the penalty
                if len(hyp["yseq"]) > minlen:
                    hyp["score"] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp["score"] += recog_args.lm_weight * rnnlm.final(
                            hyp["rnnlm_prev"]
                        )
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info("end detected at %d", i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug("remeined hypothes: " + str(len(hyps)))
        else:
            logging.info("no hypothesis. Finish decoding.")
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                )
        logging.debug("number of ended hypothes: " + str(len(ended_hyps)))

    nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
        : min(len(ended_hyps), recog_args.nbest)
    ]

    # check number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            "there is no N-best results, perform recognition "
            "again with smaller minlenratio."
        )
        # copy because the Namespace object is shared with the caller
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        # keep the caller's jit setting on the retry
        return self.recognize(x, recog_args, char_list, rnnlm, use_jit=use_jit)

    logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
    logging.info(
        "normalized log probability: "
        + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
    )
    return nbest_hyps
def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
    """beam search implementation

    :param torch.Tensor h: encoder hidden state (T, eprojs)
    :param torch.Tensor lpz: ctc log softmax output (T, odim), or None to
        disable CTC prefix scoring
    :param Namespace recog_args: argument Namespace containing options; reads
        ``beam_size``, ``penalty``, ``ctc_weight``, ``maxlenratio``,
        ``minlenratio``, ``nbest`` and, with ``rnnlm``, ``lm_weight``
    :param char_list: list of character strings (used for debug logging)
    :param torch.nn.Module rnnlm: language module
    :param int strm_idx: stream index for speaker parallel attention in
        multi-speaker case; clipped to the number of attention modules
    :return: N-best decoding results (at most ``recog_args.nbest`` dicts with
        ``score`` and ``yseq``, sorted by descending score)
    :rtype: list of dicts
    """
    logging.info('input lengths: ' + str(h.size(0)))
    att_idx = min(strm_idx, len(self.att) - 1)
    # initialization
    c_list = [self.zero_state(h.unsqueeze(0))]
    z_list = [self.zero_state(h.unsqueeze(0))]
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(h.unsqueeze(0)))
        z_list.append(self.zero_state(h.unsqueeze(0)))
    a = None
    self.att[att_idx].reset()  # reset pre-computation of h

    # search parameters
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    y = self.sos
    # single-token buffer fed to the embedding and the LM
    vy = h.new_zeros(1).long()

    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
    minlen = int(recog_args.minlenratio * h.size(0))
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list, 'z_prev': z_list,
               'a_prev': a, 'rnnlm_prev': None}
    else:
        hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list, 'z_prev': z_list,
               'a_prev': a}
    if lpz is not None:
        ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, np)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            # (the former bare `vy.unsqueeze(1)` / `ey.unsqueeze(0)` calls
            # discarded their out-of-place results and were removed)
            vy[0] = hyp['yseq'][i]
            ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
            att_c, att_w = self.att[att_idx](h.unsqueeze(0), [h.size(0)],
                                             self.dropout_dec[0](hyp['z_prev'][0]),
                                             hyp['a_prev'])
            ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, hyp['z_prev'], hyp['c_prev'])

            # get nbest local scores and their ids
            local_att_scores = F.log_softmax(self.output(self.dropout_dec[-1](z_list[-1])), dim=1)
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            if lpz is not None:
                # Prune to ctc_beam candidates, then rescore them jointly.
                local_best_scores, local_best_ids = torch.topk(
                    local_att_scores, ctc_beam, dim=1)
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                local_scores = \
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                    + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                if rnnlm:
                    local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

            for j in six.moves.range(beam):
                new_hyp = {}
                # [:] is needed! (shallow copies so hypotheses don't share lists)
                new_hyp['z_prev'] = z_list[:]
                new_hyp['c_prev'] = c_list[:]
                new_hyp['a_prev'] = att_w[:]
                new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp['rnnlm_prev'] = rnnlm_state
                if lpz is not None:
                    new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                    new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
        logging.debug(
            'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last position in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.eos)

        # add ended hypotheses to a final list, and remove them from current hypotheses
        # (this can leave fewer than `beam` running hypotheses)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp['score'] += recog_args.lm_weight * rnnlm.final(
                            hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remaining hypotheses: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        for hyp in hyps:
            logging.debug(
                'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

        logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

    # check number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning('there is no N-best results, perform recognition again with smaller minlenratio.')
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        # keep decoding the same speaker stream on the retry
        return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm, strm_idx)

    logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
    logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

    # NOTE: yseq still starts with <sos>; callers strip it if needed
    return nbest_hyps
def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
    """Recognize input speech with attention/CTC (+ optional LM) beam search.

    :param ndnarray x: input acoustic feature (B, T, D) or (T, D)
    :param Namespace recog_args: decoding options; reads ``beam_size``,
        ``penalty``, ``ctc_weight``, ``maxlenratio``, ``minlenratio``,
        ``nbest`` and, when ``rnnlm`` is given, ``lm_weight``
    :param list char_list: list of characters, used only for debug logging
    :param torch.nn.Module rnnlm: language model module
    :param bool use_jit: trace ``decoder.forward_one_step`` with
        ``torch.jit.trace`` on first use and reuse the trace
    :return: N-best decoding results (at most ``recog_args.nbest`` dicts
        with ``score`` and ``yseq``, sorted by descending score)
    :rtype: list
    """
    enc_output = self.encode(x).unsqueeze(0)
    if recog_args.ctc_weight > 0.0:
        lpz = self.ctc.log_softmax(enc_output)
        lpz = lpz.squeeze(0)
    else:
        lpz = None

    h = enc_output.squeeze(0)
    # Build decoder inputs on the encoder's device rather than a hard-coded
    # `.cuda()`, so decoding also works on CPU and on non-default GPUs.
    device = enc_output.device

    logging.info('input lengths: ' + str(h.size(0)))
    # search parameters
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight

    # prepare sos
    y = self.sos
    # single-token buffer fed to the LM
    vy = h.new_zeros(1).long()

    if recog_args.maxlenratio == 0:
        maxlen = h.shape[0]
    else:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
    minlen = int(recog_args.minlenratio * h.size(0))
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
    else:
        hyp = {'score': 0.0, 'yseq': [y]}
    if lpz is not None:
        import numpy
        from espnet.nets.ctc_prefix_score import CTCPrefixScore

        # CTC prefix scoring runs in numpy on the CPU
        ctc_prefix_score = CTCPrefixScore(lpz.cpu().detach().numpy(), 0, self.eos, numpy)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]
    hyps = [hyp]
    ended_hyps = []

    import six
    # Loop-invariant: import once instead of on every decoding step.
    from espnet.nets.e2e_asr_common import end_detect

    traced_decoder = None
    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp['yseq'][i]

            # get nbest local scores and their ids
            ys_mask = subsequent_mask(i + 1).unsqueeze(0).to(device)
            ys = torch.tensor(hyp['yseq']).unsqueeze(0).to(device)
            # FIXME: jit does not match non-jit result
            if use_jit:
                if traced_decoder is None:
                    traced_decoder = torch.jit.trace(
                        self.decoder.forward_one_step, (ys, ys_mask, enc_output))
                local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
            else:
                local_att_scores = self.decoder.forward_one_step(
                    ys, ys_mask, enc_output)[0]

            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            if lpz is not None:
                # Prune to ctc_beam candidates, then rescore them jointly on CPU.
                local_best_scores, local_best_ids = torch.topk(
                    local_att_scores, ctc_beam, dim=1)
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp['yseq'], local_best_ids[0].cpu(), hyp['ctc_state_prev'])
                local_scores = \
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]].cpu() \
                    + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                if rnnlm:
                    local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]].cpu()
                local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

            for j in six.moves.range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp['rnnlm_prev'] = rnnlm_state
                if lpz is not None:
                    new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                    new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug('number of pruned hypothes: ' + str(len(hyps)))
        if char_list is not None:
            logging.debug(
                'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info('adding <eos> in the last postion in the loop')
            for hyp in hyps:
                hyp['yseq'].append(self.eos)

        # move ended hypotheses to the final list, keep the rest running
        # (this can leave fewer than `beam` running hypotheses)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == self.eos:
                # only store sequences longer than minlen; also add the penalty
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp['score'] += recog_args.lm_weight * rnnlm.final(
                            hyp['rnnlm_prev'])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info('end detected at %d', i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug('remeined hypothes: ' + str(len(hyps)))
        else:
            logging.info('no hypothesis. Finish decoding.')
            break

        if char_list is not None:
            for hyp in hyps:
                logging.debug(
                    'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))
        logging.debug('number of ended hypothes: ' + str(len(ended_hyps)))

    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'],
        reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

    # check number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            'there is no N-best results, perform recognition again with smaller minlenratio.'
        )
        # copy because the Namespace object is shared with the caller
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        # keep the caller's jit setting on the retry
        return self.recognize(x, recog_args, char_list, rnnlm, use_jit=use_jit)

    logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
    logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
    return nbest_hyps
def process_one_block(self, h, is_final, maxlen, maxlenratio):
    """Recognize one block.

    Extends the running hypotheses with the new encoder block ``h`` and keeps
    searching until ``self.process_idx`` reaches ``maxlen`` or an early-exit
    condition fires. The early exits are: a repetition is detected, a
    hypothesis reaches <eos> in a non-final block, end detection fires on the
    final block, or no running hypotheses remain.

    :param h: encoder output for this block (passed to ``self.extend`` and
        ``self.search``)
    :param bool is_final: whether this is the last block of the input
    :param int maxlen: maximum output length
    :param float maxlenratio: end detection is only applied when this is 0.0
    :return: the value of ``self.assemble_hyps`` over the ended hypotheses
        (plus, for non-final blocks, the hypotheses that hit <eos> in this block)
    """
    # extend states for ctc
    self.extend(h, self.running_hyps)
    # Bound before the loop so the non-final tail below is safe even when the
    # loop body never runs (process_idx already >= maxlen); previously this
    # raised UnboundLocalError.
    local_ended_hyps = []
    while self.process_idx < maxlen:
        logging.debug("position " + str(self.process_idx))
        best = self.search(self.running_hyps, h)

        if self.process_idx == maxlen - 1:
            # end decoding
            self.running_hyps = self.post_process(
                self.process_idx, maxlen, maxlenratio, best, self.ended_hyps)
        n_batch = best.yseq.shape[0]
        local_ended_hyps = []
        # True where the last emitted token of each hypothesis is <eos>
        is_local_eos = best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
        prev_repeat = False
        for i in range(is_local_eos.shape[0]):
            if is_local_eos[i]:
                hyp = self._select(best, i)
                local_ended_hyps.append(hyp)
            # NOTE(tsunoo): check repetitions here
            # This is a implicit implementation of
            # Eq (11) in https://arxiv.org/abs/2006.14941
            # A flag prev_repeat is used instead of using set
            # NOTE(fujihara): I made it possible to turned off
            # the below lines using disable_repetition_detection flag,
            # because this criteria is too sensitive that the beam
            # search starts only after the entire inputs are available.
            # Empirically, this flag didn't affect the performance.
            elif (not self.disable_repetition_detection
                  and not prev_repeat
                  and best.yseq[i, -1] in best.yseq[i, :-1]
                  and not is_final):
                prev_repeat = True
        if prev_repeat:
            logging.info("Detected repetition.")
            break

        if (is_final and maxlenratio == 0.0 and end_detect(
                [lh.asdict() for lh in self.ended_hyps], self.process_idx)):
            logging.info(f"end detected at {self.process_idx}")
            return self.assemble_hyps(self.ended_hyps)

        if len(local_ended_hyps) > 0 and not is_final:
            logging.info("Detected hyp(s) reaching EOS in this block.")
            break

        # remember the pre-step hypotheses so a non-final block can roll back
        self.prev_hyps = self.running_hyps
        self.running_hyps = self.post_process(
            self.process_idx, maxlen, maxlenratio, best, self.ended_hyps)

        if is_final:
            for hyp in local_ended_hyps:
                self.ended_hyps.append(hyp)

        if len(self.running_hyps) == 0:
            logging.info("no hypothesis. Finish decoding.")
            return self.assemble_hyps(self.ended_hyps)
        else:
            logging.debug(f"remained hypotheses: {len(self.running_hyps)}")
        # increment number
        self.process_idx += 1

    if is_final:
        return self.assemble_hyps(self.ended_hyps)
    else:
        local_ended_hyps.extend(self.ended_hyps)
        rets = self.assemble_hyps(local_ended_hyps)

        # step back one position so the next block re-decodes it with more context
        if self.process_idx > 1 and len(self.prev_hyps) > 0:
            self.running_hyps = self.prev_hyps
            self.process_idx -= 1
            self.prev_hyps = []

        # N-best results
        return rets