def recognize_beam_batch(
    self,
    h,
    hlens,
    lpz,
    recog_args,
    char_list,
    rnnlm=None,
    normalize_score=True,
    strm_idx=0,
    lang_ids=None,
):
    # to support multiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        h = [h]
        hlens = [hlens]
        lpz = [lpz]
    if self.num_encs > 1 and lpz is None:
        lpz = [lpz] * self.num_encs

    att_idx = min(strm_idx, len(self.att) - 1)
    for idx in range(self.num_encs):
        logging.info(
            "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                self.num_encs, idx + 1, h[idx].size(1)
            )
        )
        h[idx] = mask_by_length(h[idx], hlens[idx], 0.0)

    # search params
    batch = len(hlens[0])
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = getattr(recog_args, "ctc_weight", 0)  # for NMT
    att_weight = 1.0 - ctc_weight
    ctc_margin = getattr(
        recog_args, "ctc_window_margin", 0
    )  # use getattr to keep compatibility
    # weights-ctc,
    # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
    if lpz[0] is not None and self.num_encs > 1:
        weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
            recog_args.weights_ctc_dec
        )  # normalize
        logging.info(
            "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
        )
    else:
        weights_ctc_dec = [1.0]

    n_bb = batch * beam
    pad_b = to_device(self, torch.arange(batch) * beam).view(-1, 1)

    max_hlen = np.amin([max(hlens[idx]) for idx in range(self.num_encs)])
    if recog_args.maxlenratio == 0:
        maxlen = max_hlen
    else:
        maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
    minlen = int(recog_args.minlenratio * max_hlen)
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialization
    c_prev = [
        to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    z_prev = [
        to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    c_list = [
        to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    z_list = [
        to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    vscores = to_device(self, torch.zeros(batch, beam))

    rnnlm_state = None
    if self.num_encs == 1:
        a_prev = [None]
        att_w_list, ctc_scorer, ctc_state = [None], [None], [None]
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        a_prev = [None] * (self.num_encs + 1)  # atts + han
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * self.num_encs  # atts
        ctc_scorer, ctc_state = [None] * self.num_encs, [None] * self.num_encs
        for idx in range(self.num_encs + 1):
            self.att[idx].reset()  # reset pre-computation of h in atts and han

    if self.replace_sos and recog_args.tgt_lang:
        logging.info("<sos> index: " + str(char_list.index(recog_args.tgt_lang)))
        logging.info("<sos> mark: " + recog_args.tgt_lang)
        yseq = [
            [char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)
        ]
    elif lang_ids is not None:
        # NOTE: used for evaluation during training
        yseq = [
            [lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)
        ]
    else:
        logging.info("<sos> index: " + str(self.sos))
        logging.info("<sos> mark: " + char_list[self.sos])
        yseq = [[self.sos] for _ in six.moves.range(n_bb)]

    accum_odim_ids = [self.sos for _ in six.moves.range(n_bb)]
    stop_search = [False for _ in six.moves.range(batch)]
    nbest_hyps = [[] for _ in six.moves.range(batch)]
    ended_hyps = [[] for _ in range(batch)]

    exp_hlens = [
        hlens[idx].repeat(beam).view(beam, batch).transpose(0, 1).contiguous()
        for idx in range(self.num_encs)
    ]
    exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)]
    exp_h = [
        h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
        for idx in range(self.num_encs)
    ]
    exp_h = [
        exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2])
        for idx in range(self.num_encs)
    ]

    if lpz[0] is not None:
        scoring_ratio = (
            CTC_SCORING_RATIO if att_weight > 0.0 and not lpz[0].is_cuda else 0
        )
        ctc_scorer = [
            CTCPrefixScoreTH(
                lpz[idx],
                hlens[idx],
                0,
                self.eos,
                beam,
                scoring_ratio,
                margin=ctc_margin,
            )
            for idx in range(self.num_encs)
        ]

    for i in six.moves.range(maxlen):
        logging.debug("position " + str(i))

        vy = to_device(self, torch.LongTensor(self._get_last_yseq(yseq)))
        ey = self.dropout_emb(self.embed(vy))
        if self.num_encs == 1:
            att_c, att_w = self.att[att_idx](
                exp_h[0], exp_hlens[0], self.dropout_dec[0](z_prev[0]), a_prev[0]
            )
            att_w_list = [att_w]
        else:
            for idx in range(self.num_encs):
                att_c_list[idx], att_w_list[idx] = self.att[idx](
                    exp_h[idx],
                    exp_hlens[idx],
                    self.dropout_dec[0](z_prev[0]),
                    a_prev[idx],
                )
            exp_h_han = torch.stack(att_c_list, dim=1)
            att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                exp_h_han,
                [self.num_encs] * n_bb,
                self.dropout_dec[0](z_prev[0]),
                a_prev[self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)

        # attention decoder
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        local_scores = att_weight * F.log_softmax(logits, dim=1)

        # rnnlm
        if rnnlm:
            rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_state, vy, n_bb)
            local_scores = local_scores + recog_args.lm_weight * local_lm_scores

        # ctc
        if ctc_scorer[0]:
            for idx in range(self.num_encs):
                att_w = att_w_list[idx]
                att_w_ = att_w if isinstance(att_w, torch.Tensor) else att_w[0]
                ctc_state[idx], local_ctc_scores = ctc_scorer[idx](
                    yseq, ctc_state[idx], local_scores, att_w_
                )
                local_scores = (
                    local_scores
                    + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
                )

        local_scores = local_scores.view(batch, beam, self.odim)
        if i == 0:
            local_scores[:, 1:, :] = self.logzero

        # accumulate scores
        eos_vscores = local_scores[:, :, self.eos] + vscores
        vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
        vscores[:, :, self.eos] = self.logzero
        vscores = (vscores + local_scores).view(batch, -1)

        # global pruning
        accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
        accum_odim_ids = (
            torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
        )
        accum_padded_beam_ids = (
            (torch.div(accum_best_ids, self.odim) + pad_b)
            .view(-1)
            .data.cpu()
            .tolist()
        )

        y_prev = yseq[:][:]
        yseq = self._index_select_list(yseq, accum_padded_beam_ids)
        yseq = self._append_ids(yseq, accum_odim_ids)
        vscores = accum_best_scores
        vidx = to_device(self, torch.LongTensor(accum_padded_beam_ids))

        a_prev = []
        num_atts = self.num_encs if self.num_encs == 1 else self.num_encs + 1
        for idx in range(num_atts):
            if isinstance(att_w_list[idx], torch.Tensor):
                _a_prev = torch.index_select(
                    att_w_list[idx].view(n_bb, *att_w_list[idx].shape[1:]), 0, vidx
                )
            elif isinstance(att_w_list[idx], list):
                # handle the case of multi-head attention
                _a_prev = [
                    torch.index_select(att_w_one.view(n_bb, -1), 0, vidx)
                    for att_w_one in att_w_list[idx]
                ]
            else:
                # handle the case of location_recurrent when return is a tuple
                _a_prev_ = torch.index_select(
                    att_w_list[idx][0].view(n_bb, -1), 0, vidx
                )
                _h_prev_ = torch.index_select(
                    att_w_list[idx][1][0].view(n_bb, -1), 0, vidx
                )
                _c_prev_ = torch.index_select(
                    att_w_list[idx][1][1].view(n_bb, -1), 0, vidx
                )
                _a_prev = (_a_prev_, (_h_prev_, _c_prev_))
            a_prev.append(_a_prev)

        z_prev = [
            torch.index_select(z_list[li].view(n_bb, -1), 0, vidx)
            for li in range(self.dlayers)
        ]
        c_prev = [
            torch.index_select(c_list[li].view(n_bb, -1), 0, vidx)
            for li in range(self.dlayers)
        ]

        # pick ended hyps
        if i >= minlen:
            k = 0
            penalty_i = (i + 1) * penalty
            thr = accum_best_scores[:, -1]
            for samp_i in six.moves.range(batch):
                if stop_search[samp_i]:
                    k = k + beam
                    continue
                for beam_j in six.moves.range(beam):
                    _vscore = None
                    if eos_vscores[samp_i, beam_j] > thr[samp_i]:
                        yk = y_prev[k][:]
                        if len(yk) <= min(
                            hlens[idx][samp_i] for idx in range(self.num_encs)
                        ):
                            _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                    elif i == maxlen - 1:
                        yk = yseq[k][:]
                        _vscore = vscores[samp_i][beam_j] + penalty_i
                    if _vscore:
                        yk.append(self.eos)
                        if rnnlm:
                            _vscore += recog_args.lm_weight * rnnlm.final(
                                rnnlm_state, index=k
                            )
                        _score = _vscore.data.cpu().numpy()
                        ended_hyps[samp_i].append(
                            {"yseq": yk, "vscore": _vscore, "score": _score}
                        )
                    k = k + 1

        # end detection
        stop_search = [
            stop_search[samp_i] or end_detect(ended_hyps[samp_i], i)
            for samp_i in six.moves.range(batch)
        ]
        stop_search_summary = list(set(stop_search))
        if len(stop_search_summary) == 1 and stop_search_summary[0]:
            break

        if rnnlm:
            rnnlm_state = self._index_select_lm_state(rnnlm_state, 0, vidx)
        if ctc_scorer[0]:
            for idx in range(self.num_encs):
                ctc_state[idx] = ctc_scorer[idx].index_select_state(
                    ctc_state[idx], accum_best_ids
                )

    torch.cuda.empty_cache()

    dummy_hyps = [
        {"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}
    ]
    ended_hyps = [
        ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
        for samp_i in six.moves.range(batch)
    ]
    if normalize_score:
        for samp_i in six.moves.range(batch):
            for x in ended_hyps[samp_i]:
                x["score"] /= len(x["yseq"])

    nbest_hyps = [
        sorted(ended_hyps[samp_i], key=lambda x: x["score"], reverse=True)[
            : min(len(ended_hyps[samp_i]), recog_args.nbest)
        ]
        for samp_i in six.moves.range(batch)
    ]

    return nbest_hyps
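# ---------------------------------------------------------------------------
# Illustrative sketch (not ESPnet code): the global-pruning index arithmetic
# used above, in isolation. After topk over the flattened (beam * odim) axis,
# fmod recovers the emitted token id and integer division recovers the source
# beam; adding pad_b shifts per-sample beam ids into the flat (batch * beam)
# numbering used to index yseq and the decoder states. All names below are
# local to this sketch; rounding_mode="floor" keeps integer division on
# recent PyTorch, matching the legacy torch.div behavior used above.
import torch

batch, beam, odim = 2, 3, 5
vscores = torch.randn(batch, beam * odim)  # accumulated hypothesis scores
pad_b = (torch.arange(batch) * beam).view(-1, 1)  # per-sample beam offsets

accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
token_ids = torch.fmod(accum_best_ids, odim)  # which output symbol
beam_ids = torch.div(accum_best_ids, odim, rounding_mode="floor")  # which beam
flat_beam_ids = (beam_ids + pad_b).view(-1)  # rows into the n_bb state tensors

assert int(flat_beam_ids.max()) < batch * beam
# ---------------------------------------------------------------------------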
def recognize_beam_batch(self, h, hlens, lpz, recog_args, char_list, rnnlm=None,
                         normalize_score=True, strm_idx=0, tgt_lang_ids=None):
    logging.info('input lengths: ' + str(h.size(1)))
    att_idx = min(strm_idx, len(self.att) - 1)
    h = mask_by_length(h, hlens, 0.0)

    # search params
    batch = len(hlens)
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = recog_args.ctc_weight
    att_weight = 1.0 - ctc_weight

    n_bb = batch * beam
    n_bo = beam * self.odim
    n_bbo = n_bb * self.odim
    pad_b = to_device(
        self,
        torch.LongTensor([i * beam for i in six.moves.range(batch)]).view(-1, 1))
    pad_bo = to_device(
        self,
        torch.LongTensor([i * n_bo for i in six.moves.range(batch)]).view(-1, 1))
    pad_o = to_device(
        self,
        torch.LongTensor([i * self.odim for i in six.moves.range(n_bb)]).view(-1, 1))

    max_hlen = int(max(hlens))
    if recog_args.maxlenratio == 0:
        maxlen = max_hlen
    else:
        maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
    minlen = int(recog_args.minlenratio * max_hlen)
    logging.info('max output length: ' + str(maxlen))
    logging.info('min output length: ' + str(minlen))

    # initialization
    c_prev = [
        to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    z_prev = [
        to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    c_list = [
        to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    z_list = [
        to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    vscores = to_device(self, torch.zeros(batch, beam))

    a_prev = None
    rnnlm_prev = None

    self.att[att_idx].reset()  # reset pre-computation of h

    if self.replace_sos and recog_args.tgt_lang:
        logging.info('<sos> index: ' + str(char_list.index(recog_args.tgt_lang)))
        logging.info('<sos> mark: ' + recog_args.tgt_lang)
        yseq = [[char_list.index(recog_args.tgt_lang)]
                for _ in six.moves.range(n_bb)]
    elif tgt_lang_ids is not None:
        # NOTE: used for evaluation during training
        yseq = [[tgt_lang_ids[b // recog_args.beam_size]]
                for b in six.moves.range(n_bb)]
    else:
        logging.info('<sos> index: ' + str(self.sos))
        logging.info('<sos> mark: ' + char_list[self.sos])
        yseq = [[self.sos] for _ in six.moves.range(n_bb)]

    accum_odim_ids = [self.sos for _ in six.moves.range(n_bb)]
    stop_search = [False for _ in six.moves.range(batch)]
    nbest_hyps = [[] for _ in six.moves.range(batch)]
    ended_hyps = [[] for _ in range(batch)]

    exp_hlens = hlens.repeat(beam).view(beam, batch).transpose(0, 1).contiguous()
    exp_hlens = exp_hlens.view(-1).tolist()
    exp_h = h.unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
    exp_h = exp_h.view(n_bb, h.size()[1], h.size()[2])

    if lpz is not None:
        device_id = torch.cuda.device_of(next(self.parameters()).data).idx
        ctc_prefix_score = CTCPrefixScoreTH(lpz, 0, self.eos, beam, exp_hlens,
                                            device_id)
        ctc_states_prev = ctc_prefix_score.initial_state()
        ctc_scores_prev = to_device(self, torch.zeros(batch, n_bo))

    for i in six.moves.range(maxlen):
        logging.debug('position ' + str(i))

        vy = to_device(self, torch.LongTensor(self._get_last_yseq(yseq)))
        ey = self.dropout_emb(self.embed(vy))
        att_c, att_w = self.att[att_idx](exp_h, exp_hlens,
                                         self.dropout_dec[0](z_prev[0]), a_prev)
        ey = torch.cat((ey, att_c), dim=1)

        # attention decoder
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        local_scores = att_weight * F.log_softmax(logits, dim=1)

        # rnnlm
        if rnnlm:
            rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_prev, vy, n_bb)
            local_scores = local_scores + recog_args.lm_weight * local_lm_scores
        local_scores = local_scores.view(batch, n_bo)

        # ctc
        if lpz is not None:
            ctc_scores, ctc_states = ctc_prefix_score(yseq, ctc_states_prev,
                                                      accum_odim_ids)
            ctc_scores = ctc_scores.view(batch, n_bo)
            local_scores = local_scores + ctc_weight * (ctc_scores - ctc_scores_prev)
        local_scores = local_scores.view(batch, beam, self.odim)

        if i == 0:
            local_scores[:, 1:, :] = self.logzero
        local_best_scores, local_best_odims = torch.topk(
            local_scores.view(batch, beam, self.odim), beam, 2)
        # local pruning (via xp)
        local_scores = np.full((n_bbo,), self.logzero)
        _best_odims = local_best_odims.view(n_bb, beam) + pad_o
        _best_odims = _best_odims.view(-1).cpu().numpy()
        _best_score = local_best_scores.view(-1).cpu().detach().numpy()
        local_scores[_best_odims] = _best_score
        local_scores = to_device(
            self, torch.from_numpy(local_scores).float()).view(
                batch, beam, self.odim)
        # (or indexing)
        # local_scores = to_device(
        #     self, torch.full((batch, beam, self.odim), self.logzero))
        # _best_odims = local_best_odims
        # _best_score = local_best_scores
        # for si in six.moves.range(batch):
        #     for bj in six.moves.range(beam):
        #         for bk in six.moves.range(beam):
        #             local_scores[si, bj, _best_odims[si, bj, bk]] = \
        #                 _best_score[si, bj, bk]

        eos_vscores = local_scores[:, :, self.eos] + vscores
        vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
        vscores[:, :, self.eos] = self.logzero
        vscores = (vscores + local_scores).view(batch, n_bo)

        # global pruning
        accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
        accum_odim_ids = torch.fmod(
            accum_best_ids, self.odim).view(-1).data.cpu().tolist()
        accum_padded_odim_ids = (torch.fmod(accum_best_ids, n_bo) +
                                 pad_bo).view(-1).data.cpu().tolist()
        accum_padded_beam_ids = (torch.div(accum_best_ids, self.odim) +
                                 pad_b).view(-1).data.cpu().tolist()

        y_prev = yseq[:][:]
        yseq = self._index_select_list(yseq, accum_padded_beam_ids)
        yseq = self._append_ids(yseq, accum_odim_ids)
        vscores = accum_best_scores
        vidx = to_device(self, torch.LongTensor(accum_padded_beam_ids))

        if isinstance(att_w, torch.Tensor):
            a_prev = torch.index_select(att_w.view(n_bb, *att_w.shape[1:]), 0, vidx)
        elif isinstance(att_w, list):
            # handle the case of multi-head attention
            a_prev = [
                torch.index_select(att_w_one.view(n_bb, -1), 0, vidx)
                for att_w_one in att_w
            ]
        else:
            # handle the case of location_recurrent when return is a tuple
            a_prev_ = torch.index_select(att_w[0].view(n_bb, -1), 0, vidx)
            h_prev_ = torch.index_select(att_w[1][0].view(n_bb, -1), 0, vidx)
            c_prev_ = torch.index_select(att_w[1][1].view(n_bb, -1), 0, vidx)
            a_prev = (a_prev_, (h_prev_, c_prev_))
        z_prev = [
            torch.index_select(z_list[li].view(n_bb, -1), 0, vidx)
            for li in range(self.dlayers)
        ]
        c_prev = [
            torch.index_select(c_list[li].view(n_bb, -1), 0, vidx)
            for li in range(self.dlayers)
        ]

        if rnnlm:
            rnnlm_prev = self._index_select_lm_state(rnnlm_state, 0, vidx)
        if lpz is not None:
            ctc_vidx = to_device(self, torch.LongTensor(accum_padded_odim_ids))
            ctc_scores_prev = torch.index_select(ctc_scores.view(-1), 0, ctc_vidx)
            ctc_scores_prev = ctc_scores_prev.view(-1, 1).repeat(
                1, self.odim).view(batch, n_bo)

            ctc_states = torch.transpose(ctc_states, 1, 3).contiguous()
            ctc_states = ctc_states.view(n_bbo, 2, -1)
            ctc_states_prev = torch.index_select(ctc_states, 0,
                                                 ctc_vidx).view(n_bb, 2, -1)
            ctc_states_prev = torch.transpose(ctc_states_prev, 1, 2)

        # pick ended hyps
        if i > minlen:
            k = 0
            penalty_i = (i + 1) * penalty
            thr = accum_best_scores[:, -1]
            for samp_i in six.moves.range(batch):
                if stop_search[samp_i]:
                    k = k + beam
                    continue
                for beam_j in six.moves.range(beam):
                    if eos_vscores[samp_i, beam_j] > thr[samp_i]:
                        yk = y_prev[k][:]
                        yk.append(self.eos)
                        if len(yk) < hlens[samp_i]:
                            _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                            if normalize_score:
                                _vscore = _vscore / len(yk)
                            _score = _vscore.data.cpu().numpy()
                            ended_hyps[samp_i].append(
                                {'yseq': yk, 'vscore': _vscore, 'score': _score})
                    k = k + 1

        # end detection
        stop_search = [
            stop_search[samp_i] or end_detect(ended_hyps[samp_i], i)
            for samp_i in six.moves.range(batch)
        ]
        stop_search_summary = list(set(stop_search))
        if len(stop_search_summary) == 1 and stop_search_summary[0]:
            break

    torch.cuda.empty_cache()

    dummy_hyps = [{'yseq': [self.sos, self.eos],
                   'score': np.array([-float('inf')])}]
    ended_hyps = [
        ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
        for samp_i in six.moves.range(batch)
    ]
    nbest_hyps = [
        sorted(ended_hyps[samp_i], key=lambda x: x['score'], reverse=True)[
            :min(len(ended_hyps[samp_i]), recog_args.nbest)]
        for samp_i in six.moves.range(batch)
    ]

    return nbest_hyps
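# ---------------------------------------------------------------------------
# Illustrative sketch (not ESPnet code): the "local pruning (via xp)" scatter
# used above, in isolation. Keeping only the per-beam top-k token scores and
# filling every other entry with logzero is equivalent to the commented-out
# triple loop, but runs as a single vectorized index assignment. All names
# below are local to this sketch.
import numpy as np
import torch

batch, beam, odim = 2, 3, 7
logzero = -1e10
local_scores = torch.randn(batch, beam, odim)

best_scores, best_odims = torch.topk(local_scores, beam, dim=2)

n_bb, n_bbo = batch * beam, batch * beam * odim
pad_o = (torch.arange(n_bb) * odim).view(-1, 1)  # row offsets into flat array
flat = np.full((n_bbo,), logzero)
flat[(best_odims.view(n_bb, beam) + pad_o).view(-1).numpy()] = (
    best_scores.view(-1).numpy())
pruned = torch.from_numpy(flat).float().view(batch, beam, odim)

# every (sample, beam) row now keeps exactly `beam` finite entries
assert int((pruned > logzero).sum()) == batch * beam * beam
# ---------------------------------------------------------------------------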
class CTCPrefixScorer(BatchPartialScorerInterface):
    """Decoder interface wrapper for CTCPrefixScore."""

    def __init__(self, ctc: torch.nn.Module, eos: int):
        """Initialize class.

        Args:
            ctc (torch.nn.Module): The CTC implementation.
                For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
            eos (int): The end-of-sequence id.

        """
        self.ctc = ctc
        self.eos = eos
        self.impl = None

    def init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state

        """
        logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy()
        # TODO(karita): use CTCPrefixScoreTH
        self.impl = CTCPrefixScore(logp, 0, self.eos, np)
        return 0, self.impl.initial_state()

    def select_state(self, state, i, new_id=None):
        """Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int): New label id to select a state if necessary

        Returns:
            state: pruned state

        """
        if type(state) == tuple:
            if len(state) == 2:  # for CTCPrefixScore
                sc, st = state
                return sc[i], st[i]
            else:  # for CTCPrefixScoreTH (need new_id > 0)
                r, log_psi, f_min, f_max, scoring_idmap = state
                s = log_psi[i, new_id].expand(log_psi.size(1))
                if scoring_idmap is not None:
                    return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
                else:
                    return r[:, :, i, new_id], s, f_min, f_max
        return None if state is None else state[i]

    def score_partial(self, y, ids, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D prefix token
            ids (torch.Tensor): torch.int64 next tokens to score
            state: decoder state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(ids),)`
                and next state for ys

        """
        prev_score, state = state
        presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
        tscore = torch.as_tensor(
            presub_score - prev_score, device=x.device, dtype=x.dtype
        )
        return tscore, (presub_score, new_st)

    def batch_init_state(self, x: torch.Tensor):
        """Get a batched initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state

        """
        logp = self.ctc.log_softmax(x.unsqueeze(0))  # assuming batch_size = 1
        xlen = torch.tensor([logp.size(1)])
        self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
        return None

    def batch_score_partial(self, y, ids, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D prefix token
            ids (torch.Tensor): torch.int64 next tokens to score
            state: decoder state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(ids),)`
                and next state for ys

        """
        batch_state = (
            (
                torch.stack([s[0] for s in state], dim=2),
                torch.stack([s[1] for s in state]),
                state[0][2],
                state[0][3],
            )
            if state[0] is not None
            else None
        )
        return self.impl(y, batch_state, ids)

    def extend_prob(self, x: torch.Tensor):
        """Extend probs for decoding.

        This extension is for streaming decoding
        as in Eq (14) in https://arxiv.org/abs/2006.14941

        Args:
            x (torch.Tensor): The encoded feature tensor

        """
        logp = self.ctc.log_softmax(x.unsqueeze(0))
        self.impl.extend_prob(logp)

    def extend_state(self, state):
        """Extend state for decoding.

        This extension is for streaming decoding
        as in Eq (14) in https://arxiv.org/abs/2006.14941

        Args:
            state: The states of hyps

        Returns: extended state

        """
        new_state = []
        for s in state:
            new_state.append(self.impl.extend_state(s))
        return new_state
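# ---------------------------------------------------------------------------
# Illustrative sketch (not ESPnet code): a minimal driver for CTCPrefixScorer.
# It assumes this module's imports (torch, numpy as np, CTCPrefixScore)
# resolve, i.e. an ESPnet environment. DummyCTC is a hypothetical stand-in
# for espnet.nets.pytorch_backend.ctc.CTC: the scorer only ever calls its
# log_softmax, so that is all we provide here.
import torch

class DummyCTC(torch.nn.Module):
    def __init__(self, eprojs: int, odim: int):
        super().__init__()
        self.lo = torch.nn.Linear(eprojs, odim)

    def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        # (B, T, eprojs) -> (B, T, odim) log-probabilities over the vocabulary
        return torch.log_softmax(self.lo(hs_pad), dim=-1)

odim, eos = 10, 9
scorer = CTCPrefixScorer(ctc=DummyCTC(4, odim), eos=eos)
x = torch.randn(20, 4)            # (T, eprojs) encoder output for one utterance
state = scorer.init_state(x)      # (0, initial CTC prefix state)
y = torch.tensor([eos])           # prefix starts from <sos> (= <eos> in ESPnet)
ids = torch.arange(odim)          # candidate next tokens to score
tscore, state = scorer.score_partial(y, ids, state, x)
print(tscore.shape)               # torch.Size([10])
# ---------------------------------------------------------------------------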