def beam_search(self, eouts, elens, params, idx2token=None,
                    lm=None, lm_second=None, lm_second_bwd=None, ctc_log_probs=None,
                    nbest=1, exclude_eos=False,
                    refs_id=None, utt_ids=None, speakers=None,
                    ensmbl_eouts=[], ensmbl_elens=[], ensmbl_decs=[], cache_states=True):
        """Beam search decoding.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            params (dict): decoding hyperparameters
            idx2token (): converter from index to token
            lm (torch.nn.Module): first-pass LM
            lm_second (torch.nn.Module): second-pass LM
            lm_second_bwd (torch.nn.Module): second-pass backward LM
            ctc_log_probs (FloatTensor): `[B, T, vocab]`
            nbest (int): size of the N-best list
            exclude_eos (bool): exclude <eos> from hypothesis
            refs_id (List): reference list
            utt_ids (List): utterance id list
            speakers (List): speaker list
            ensmbl_eouts (List[FloatTensor]): encoder outputs for ensemble models
            ensmbl_elens (List[IntTensor]): encoder output lengths for ensemble models
            ensmbl_decs (List[torch.nn.Module]): decoders for ensemble models
            cache_states (bool): cache decoder states for fast decoding
        Returns:
            nbest_hyps_idx (List): length `[B]`, each of which contains list of N hypotheses
            aws (List): length `[B]`, each of which contains arrays of size `[H, L, T]`
            scores (List): length `[B]`, each of which contains attention scores of N hypotheses

        """
        bs, xmax, _ = eouts.size()
        n_models = len(ensmbl_decs) + 1

        beam_width = params.get('recog_beam_width')
        assert 1 <= nbest <= beam_width
        ctc_weight = params.get('recog_ctc_weight')
        max_len_ratio = params.get('recog_max_len_ratio')
        min_len_ratio = params.get('recog_min_len_ratio')
        lp_weight = params.get('recog_length_penalty')
        length_norm = params.get('recog_length_norm')
        cache_emb = params.get('recog_cache_embedding')
        lm_weight = params.get('recog_lm_weight')
        lm_weight_second = params.get('recog_lm_second_weight')
        lm_weight_second_bwd = params.get('recog_lm_bwd_weight')
        eos_threshold = params.get('recog_eos_threshold')
        lm_state_carry_over = params.get('recog_lm_state_carry_over')
        softmax_smoothing = params.get('recog_softmax_smoothing')
        eps_wait = params.get('recog_mma_delay_threshold')

        helper = BeamSearch(beam_width, self.eos, ctc_weight, lm_weight, self.device)
        lm = helper.verify_lm_eval_mode(lm, lm_weight, cache_emb)
        lm_second = helper.verify_lm_eval_mode(lm_second, lm_weight_second, cache_emb)
        lm_second_bwd = helper.verify_lm_eval_mode(lm_second_bwd, lm_weight_second_bwd, cache_emb)

        # cache token embeddings
        if cache_emb:
            self.cache_embedding(eouts.device)

        if ctc_log_probs is not None:
            assert ctc_weight > 0
            ctc_log_probs = tensor2np(ctc_log_probs)

        nbest_hyps_idx, aws, scores = [], [], []
        eos_flags = []
        for b in range(bs):
            # Initialization per utterance
            lmstate = None
            ys = eouts.new_zeros((1, 1), dtype=torch.int64).fill_(self.eos)
            for layer in self.layers:
                layer.reset()

            # For joint CTC-Attention decoding
            ctc_prefix_scorer = None
            if ctc_log_probs is not None:
                if self.bwd:
                    ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b][::-1], self.blank, self.eos)
                else:
                    ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b], self.blank, self.eos)

            if speakers is not None:
                if speakers[b] == self.prev_spk:
                    if lm_state_carry_over and isinstance(lm, RNNLM):
                        lmstate = self.lmstate_final
                self.prev_spk = speakers[b]

            end_hyps = []
            hyps = [{'hyp': [self.eos],
                     'ys': ys,
                     'cache': None,
                     'score': 0.,
                     'score_att': 0.,
                     'score_ctc': 0.,
                     'score_lm': 0.,
                     'aws': [None],
                     'lmstate': lmstate,
                     'ensmbl_cache': [[None] * dec.n_layers for dec in ensmbl_decs] if n_models > 1 else None,
                     'ctc_state': ctc_prefix_scorer.initial_state() if ctc_prefix_scorer is not None else None,
                     'quantity_rate': 1.,
                     'streamable': True,
                     'streaming_failed_point': 1000}]
            streamable_global = True
            ymax = math.ceil(elens[b] * max_len_ratio)
            for i in range(ymax):
                # batchify all hypotheses for batch decoding
                cache = [None] * self.n_layers
                if cache_states and i > 0:
                    for lth in range(self.n_layers):
                        cache[lth] = torch.cat([beam['cache'][lth] for beam in hyps], dim=0)
                ys = eouts.new_zeros((len(hyps), i + 1), dtype=torch.int64)
                for j, beam in enumerate(hyps):
                    ys[j, :] = beam['ys']
                if i > 0:
                    xy_aws_prev = torch.cat([beam['aws'][-1] for beam in hyps], dim=0)  # `[B, n_layers, H_ma, 1, klen]`
                else:
                    xy_aws_prev = None

                # Update LM states for shallow fusion
                y_lm = ys[:, -1:].clone()  # NOTE: this is important
                _, lmstate, scores_lm = helper.update_rnnlm_state_batch(lm, hyps, y_lm)

                # for the main model
                causal_mask = eouts.new_ones(i + 1, i + 1, dtype=torch.uint8)
                causal_mask = torch.tril(causal_mask).unsqueeze(0).repeat([ys.size(0), 1, 1])
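                # lower-triangular (causal) mask: position i may attend only to positions <= i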
                out = self.pos_enc(self.embed_token_id(ys), scale=True)  # scaled + dropout
                n_heads_total = 0
                eouts_b = eouts[b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])  # `[n_hyps, T, d_model]`
                new_cache = [None] * self.n_layers
                xy_aws_layers = []
                xy_aws = None
                lth_s = self.mma_first_layer - 1
                # autoregressive decoding through the decoder layers
                for lth, layer in enumerate(self.layers):
                    out = layer(
                        out, causal_mask, eouts_b, None,
                        cache=cache[lth],
                        xy_aws_prev=xy_aws_prev[:, lth - lth_s] if lth >= lth_s and i > 0 else None,
                        eps_wait=eps_wait)
                    xy_aws = layer.xy_aws

                    new_cache[lth] = out
                    if xy_aws is not None:
                        xy_aws_layers.append(xy_aws)
                logits = self.output(self.norm_out(out[:, -1]))  # output distribution at the current step
                probs = torch.softmax(logits * softmax_smoothing, dim=1)
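                # NOTE: softmax_smoothing acts like an inverse temperature; values > 1 sharpen the distribution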
                xy_aws_layers = torch.stack(xy_aws_layers, dim=1)  # `[B, H, n_layers, L, T]`

                # Ensemble initialization
                ensmbl_cache = [[None] * dec.n_layers for dec in ensmbl_decs]
                if n_models > 1 and cache_states and i > 0:
                    for i_e, dec in enumerate(ensmbl_decs):
                        for lth in range(dec.n_layers):
                            ensmbl_cache[i_e][lth] = torch.cat([beam['ensmbl_cache'][i_e][lth]
                                                                for beam in hyps], dim=0)

                # for the ensemble
                ensmbl_new_cache = [[None] * dec.n_layers for dec in ensmbl_decs]
                for i_e, dec in enumerate(ensmbl_decs):
                    out_e = dec.pos_enc(dec.embed(ys))  # scaled + dropout
                    eouts_e = ensmbl_eouts[i_e][b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])
                    for lth in range(dec.n_layers):
                        out_e = dec.layers[lth](out_e, causal_mask, eouts_e, None,
                                                cache=ensmbl_cache[i_e][lth])
                        ensmbl_new_cache[i_e][lth] = out_e
                    logits_e = dec.output(dec.norm_out(out_e[:, -1]))
                    probs += torch.softmax(logits_e * softmax_smoothing, dim=1)
                    # NOTE: sum in the probability scale (not log-scale)

                # Ensemble: fuse the probabilities of multiple models
                scores_att = torch.log(probs / n_models)  # `[n_hyps, vocab]`
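                # NOTE: log of the probability averaged over models (summed above, divided by n_models here)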
                new_hyps = []
                for j, beam in enumerate(hyps):  # expand each hypothesis in the beam
                    # Attention scores
                    total_scores_att = beam['score_att'] + scores_att[j:j + 1]  # `[1, vocab]`
                    total_scores = total_scores_att * (1 - ctc_weight)

                    # Add LM score <before> top-K selection
                    if lm is not None:
                        total_scores_lm = beam['score_lm'] + scores_lm[j:j + 1, -1]
                        total_scores += total_scores_lm * lm_weight
                    else:
                        total_scores_lm = eouts.new_zeros(1, self.vocab)

                    # Select the top-K candidate tokens
                    total_scores_topk, topk_ids = torch.topk(
                        total_scores, k=beam_width, dim=1, largest=True, sorted=True)

                    # Add length penalty
                    if lp_weight > 0:
                        total_scores_topk += (len(beam['hyp'][1:]) + 1) * lp_weight
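                        # NOTE: a constant bonus per generated token, which favors longer hypotheses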

                    # Add CTC score
                    new_ctc_states, total_scores_ctc, total_scores_topk = helper.add_ctc_score(
                        beam['hyp'], topk_ids, beam['ctc_state'],
                        total_scores_topk, ctc_prefix_scorer)

                    new_aws = beam['aws'] + [xy_aws_layers[j:j + 1, :, :, -1:]]
                    aws_j = torch.cat(new_aws[1:], dim=3)  # `[1, H, n_layers, L, T]`

                    # forward direction
                    for k in range(beam_width):
                        idx = topk_ids[0, k].item()  # token id of the k-th candidate
                        length_norm_factor = len(beam['hyp'][1:]) + 1 if length_norm else 1
                        total_score = total_scores_topk[0, k].item() / length_norm_factor  # normalize by the current hypothesis length

                        if idx == self.eos:
                            # Exclude short hypotheses
                            # (e.g., a premature <eos> emitted during mid-utterance silence)
                            if len(beam['hyp'][1:]) < elens[b] * min_len_ratio:
                                continue
                            # EOS threshold
                            # find the highest-scoring non-<eos> token
                            max_score_no_eos = scores_att[j, :idx].max(0)[0].item()
                            max_score_no_eos = max(max_score_no_eos, scores_att[j, idx + 1:].max(0)[0].item())
                            if scores_att[j, idx].item() <= eos_threshold * max_score_no_eos:
                                # reject this <eos> candidate and keep decoding
                                continue

                        streaming_failed_point = beam['streaming_failed_point']
                        quantity_rate = 1.
                        # streamability bookkeeping for monotonic attention
                        if self.attn_type == 'mocha':
                            n_tokens_hyp_k = i + 1
                            n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                            quantity_diff = n_tokens_hyp_k * n_heads_total - n_quantity_k

                            if quantity_diff != 0:
                                if idx == self.eos:
                                    n_tokens_hyp_k -= 1  # NOTE: do not count <eos> for streamability
                                    n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                                else:
                                    streamable_global = False
                                if n_tokens_hyp_k * n_heads_total == 0:
                                    quantity_rate = 0
                                else:
                                    quantity_rate = n_quantity_k / (n_tokens_hyp_k * n_heads_total)

                            if beam['streamable'] and not streamable_global:
                                streaming_failed_point = i

                        new_hyps.append(
                            {'hyp': beam['hyp'] + [idx],
                             'ys': torch.cat([beam['ys'], eouts.new_zeros((1, 1), dtype=torch.int64).fill_(idx)], dim=-1),
                             'cache': [new_cache_l[j:j + 1] for new_cache_l in new_cache] if cache_states else cache,
                             'score': total_score,
                             'score_att': total_scores_att[0, idx].item(),
                             'score_ctc': total_scores_ctc[k].item(),
                             'score_lm': total_scores_lm[0, idx].item(),
                             'aws': new_aws,
                             'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                         'cxs': lmstate['cxs'][:, j:j + 1]} if lmstate is not None else None,
                             'ctc_state': new_ctc_states[k] if ctc_prefix_scorer is not None else None,
                             'ensmbl_cache': [[new_cache_e_l[j:j + 1] for new_cache_e_l in new_cache_e]
                                              for new_cache_e in ensmbl_new_cache] if cache_states else None,
                             'streamable': streamable_global,
                             'streaming_failed_point': streaming_failed_point,
                             'quantity_rate': quantity_rate})

                # Local pruning
                # new_hyps holds up to len(hyps) * beam_width candidates here
                new_hyps_sorted = sorted(new_hyps, key=lambda x: x['score'], reverse=True)[:beam_width]

                # Remove complete hypotheses
                # prune down to a beam_width-sized list
                new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(
                    new_hyps_sorted, end_hyps, prune=True)
                hyps = new_hyps[:]
                if is_finish:
                    break

            # Global pruning (the utterance is finished)
            if len(end_hyps) == 0:
                end_hyps = hyps[:]
            elif len(end_hyps) < nbest and nbest > 1:
                end_hyps.extend(hyps[:nbest - len(end_hyps)])

            # forward/backward second-pass LM rescoring
            end_hyps = helper.lm_rescoring(end_hyps, lm_second, lm_weight_second,
                                           length_norm=length_norm, tag='second')
            end_hyps = helper.lm_rescoring(end_hyps, lm_second_bwd, lm_weight_second_bwd,
                                           length_norm=length_norm, tag='second_bwd')

            # Sort by score
            end_hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

            # flatten the head and layer dims of the best hypothesis's attention weights
            for j in range(len(end_hyps[0]['aws'][1:])):
                tmp = end_hyps[0]['aws'][j + 1]
                end_hyps[0]['aws'][j + 1] = tmp.view(1, -1, tmp.size(-2), tmp.size(-1))

            # metrics for streaming inference
            self.streamable = end_hyps[0]['streamable']
            self.quantity_rate = end_hyps[0]['quantity_rate']
            self.last_success_frame_ratio = None

            if idx2token is not None:
                if utt_ids is not None:
                    logger.info('Utt-id: %s' % utt_ids[b])
                assert self.vocab == idx2token.vocab
                logger.info('=' * 200)
                for k in range(len(end_hyps)):
                    if refs_id is not None:
                        logger.info('Ref: %s' % idx2token(refs_id[b]))
                    logger.info('Hyp: %s' % idx2token(
                        end_hyps[k]['hyp'][1:][::-1] if self.bwd else end_hyps[k]['hyp'][1:]))
                    logger.info('num tokens (hyp): %d' % len(end_hyps[k]['hyp'][1:]))
                    logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                    logger.info('log prob (hyp, att): %.7f' %
                                (end_hyps[k]['score_att'] * (1 - ctc_weight)))
                    if ctc_prefix_scorer is not None:
                        logger.info('log prob (hyp, ctc): %.7f' %
                                    (end_hyps[k]['score_ctc'] * ctc_weight))
                    if lm is not None:
                        logger.info('log prob (hyp, first-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm'] * lm_weight))
                    if lm_second is not None:
                        logger.info('log prob (hyp, second-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm_second'] * lm_weight_second))
                    if lm_second_bwd is not None:
                        logger.info('log prob (hyp, second-pass lm, reverse): %.7f' %
                                    (end_hyps[k]['score_lm_second_bwd'] * lm_weight_second_bwd))
                    if self.attn_type == 'mocha':
                        logger.info('streamable: %s' % end_hyps[k]['streamable'])
                        logger.info('streaming failed point: %d' %
                                    (end_hyps[k]['streaming_failed_point'] + 1))
                        logger.info('quantity rate [%%]: %.2f' %
                                    (end_hyps[k]['quantity_rate'] * 100))
                    logger.info('-' * 50)

                if self.attn_type == 'mocha' and end_hyps[0]['streaming_failed_point'] < 1000:
                    assert not self.streamable
                    aws_last_success = end_hyps[0]['aws'][1:][end_hyps[0]['streaming_failed_point'] - 1]
                    rightmost_frame = max(0, aws_last_success[0, :, 0].nonzero()[:, -1].max().item()) + 1
                    frame_ratio = rightmost_frame * 100 / xmax
                    self.last_success_frame_ratio = frame_ratio
                    logger.info('streaming last success frame ratio: %.2f' % frame_ratio)

            # N-best list
            if self.bwd:
                # Reverse the order
                nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:][::-1]) for n in range(nbest)]]
                aws += [[tensor2np(torch.cat(end_hyps[n]['aws'][1:][::-1], dim=2).squeeze(0)) for n in range(nbest)]]
            else:
                nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]
                aws += [[tensor2np(torch.cat(end_hyps[n]['aws'][1:], dim=2).squeeze(0)) for n in range(nbest)]]
            scores += [[end_hyps[n]['score_att'] for n in range(nbest)]]

            # Check <eos>
            eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos) for n in range(nbest)])

        # Exclude <eos> (<sos> in case of the backward decoder)
        if exclude_eos:
            if self.bwd:
                nbest_hyps_idx = [[nbest_hyps_idx[b][n][1:] if eos_flags[b][n]
                                   else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
                aws = [[aws[b][n][:, 1:] if eos_flags[b][n] else aws[b][n] for n in range(nbest)] for b in range(bs)]
            else:
                nbest_hyps_idx = [[nbest_hyps_idx[b][n][:-1] if eos_flags[b][n]
                                   else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
                aws = [[aws[b][n][:, :-1] if eos_flags[b][n] else aws[b][n] for n in range(nbest)] for b in range(bs)]

        # Store ASR/LM state
        if bs == 1:
            self.lmstate_final = end_hyps[0]['lmstate']

        return nbest_hyps_idx, aws, scores
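# --- Usage sketch (added for illustration; not part of the original source).
# A minimal, hypothetical call to the decoder above. `decoder`, `eouts`,
# `elens`, and `idx2token` are assumed to come from a trained neural_sp model;
# the params keys mirror the ones read at the top of beam_search, and the
# values below are plausible placeholders, not recommendations.
recog_params = {
    'recog_beam_width': 10,
    'recog_ctc_weight': 0.3,
    'recog_max_len_ratio': 1.0,
    'recog_min_len_ratio': 0.0,
    'recog_length_penalty': 0.0,
    'recog_length_norm': True,
    'recog_cache_embedding': True,
    'recog_lm_weight': 0.0,
    'recog_lm_second_weight': 0.0,
    'recog_lm_bwd_weight': 0.0,
    'recog_eos_threshold': 1.5,
    'recog_lm_state_carry_over': False,
    'recog_softmax_smoothing': 1.0,
    'recog_mma_delay_threshold': -1,
}
# eouts: `[B, T, d_model]` encoder outputs; elens: `[B]` encoder lengths
nbest_hyps_idx, aws, scores = decoder.beam_search(
    eouts, elens, recog_params, idx2token=idx2token, nbest=1)
best_hyp_ids = nbest_hyps_idx[0][0]  # token ids of the best hypothesis for utterance 0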
Example #2
    def beam_search(self,
                    eouts,
                    elens,
                    params,
                    idx2token=None,
                    lm=None,
                    lm_second=None,
                    lm_second_bwd=None,
                    ctc_log_probs=None,
                    nbest=1,
                    exclude_eos=False,
                    refs_id=None,
                    utt_ids=None,
                    speakers=None,
                    ensmbl_eouts=None,
                    ensmbl_elens=None,
                    ensmbl_decs=[]):
        """Beam search decoding.

        Args:
            eouts (FloatTensor): `[B, T, enc_n_units]`
            elens (IntTensor): `[B]`
            params (dict): hyperparameters for decoding
            idx2token (): converter from index to token
            lm: first-pass LM
            lm_second: second-pass LM
            lm_second_bwd: second-pass backward LM
            ctc_log_probs (FloatTensor): `[B, T, vocab]`
            nbest (int): size of the N-best list
            exclude_eos (bool): exclude <eos> from hypothesis
            refs_id (list): reference list
            utt_ids (list): utterance id list
            speakers (list): speaker list
            ensmbl_eouts (list): list of FloatTensor
            ensmbl_elens (list): list of list
            ensmbl_decs (list): list of torch.nn.Module
        Returns:
            nbest_hyps_idx (list): length `B`, each of which contains list of N hypotheses
            aws: dummy
            scores: dummy

        """
        bs = eouts.size(0)

        beam_width = params['recog_beam_width']
        assert 1 <= nbest <= beam_width
        ctc_weight = params['recog_ctc_weight']
        assert ctc_weight == 0
        assert ctc_log_probs is None
        # length_norm = params['recog_length_norm']
        lm_weight = params['recog_lm_weight']
        lm_weight_second = params['recog_lm_second_weight']
        lm_weight_second_bwd = params['recog_lm_bwd_weight']
        # asr_state_carry_over = params['recog_asr_state_carry_over']
        lm_state_carry_over = params['recog_lm_state_carry_over']

        if lm is not None:
            assert lm_weight > 0
            lm.eval()
        if lm_second is not None:
            assert lm_weight_second > 0
            lm_second.eval()
        if lm_second_bwd is not None:
            assert lm_weight_second_bwd > 0
            lm_second_bwd.eval()

        nbest_hyps_idx = []
        eos_flags = []
        for b in range(bs):
            # Initialization per utterance
            y = eouts.new_zeros((1, 1), dtype=torch.int64).fill_(self.eos)
            y_emb = self.dropout_emb(self.embed(y))
            dout, dstate = self.recurrency(y_emb, None)
            lmstate = None

            if speakers is not None:
                if speakers[b] == self.prev_spk:
                    if lm_state_carry_over and isinstance(lm, RNNLM):
                        lmstate = self.lmstate_final
                self.prev_spk = speakers[b]

            helper = BeamSearch(beam_width, self.eos, ctc_weight, eouts.device)

            end_hyps = []
            hyps = [{
                'hyp': [self.eos],
                'hyp_str': '',
                'ys': [self.eos],
                'score': 0.,
                'score_rnnt': 0.,
                'score_lm': 0.,
                'dout': dout,
                'dstate': dstate,
                'lmstate': lmstate
            }]
            for t in range(elens[b]):
                # batchify all hypotheses for batch decoding
                douts = torch.cat([beam['dout'] for beam in hyps], dim=0)
                logits = self.joint(
                    eouts[b:b + 1, t:t + 1].repeat([douts.size(0), 1, 1]),
                    douts)
                scores_rnnt = torch.log_softmax(logits.squeeze(2).squeeze(1),
                                                dim=-1)  # `[B, vocab]`
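                # the joint network fuses encoder frame t with each prediction-network output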

                new_hyps = []
                for j, beam in enumerate(hyps):
                    # Transducer scores
                    total_scores_rnnt = beam['score_rnnt'] + scores_rnnt[j:j +
                                                                         1]
                    total_scores_topk, topk_ids = torch.topk(total_scores_rnnt,
                                                             k=beam_width,
                                                             dim=-1,
                                                             largest=True,
                                                             sorted=True)

                    for k in range(beam_width):
                        idx = topk_ids[0, k].item()
                        # length_norm_factor = len(beam['hyp'][1:]) + 1 if length_norm else 1
                        # total_score = total_scores_topk[0, k].item() / length_norm_factor
                        total_score = total_scores_topk[0, k].item()
                        total_score_lm = beam['score_lm']

                        if idx == self.blank:
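                            # blank: no label is emitted; carry the hypothesis over to the next frame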
                            new_hyps.append(beam.copy())
                            new_hyps[-1]['score'] += scores_rnnt[
                                j, self.blank].item()
                            new_hyps[-1]['score_rnnt'] += scores_rnnt[
                                j, self.blank].item()
                            continue

                        # Update prediction network only when predicting non-blank labels
                        hyp_ids = beam['hyp'] + [idx]
                        hyp_str = ' '.join(list(map(str, hyp_ids)))
                        if hyp_str in self.state_cache.keys():
                            # from cache
                            dout = self.state_cache[hyp_str]['dout']
                            dstate = self.state_cache[hyp_str]['dstate']
                            lmstate = self.state_cache[hyp_str]['lmstate']
                            total_score_lm = self.state_cache[hyp_str][
                                'total_score_lm']
                        else:
                            y = eouts.new_zeros((1, 1),
                                                dtype=torch.int64).fill_(idx)
                            y_emb = self.dropout_emb(self.embed(y))
                            dout, dstate = self.recurrency(
                                y_emb, beam['dstate'])

                            # Update LM states for shallow fusion
                            y_prev = eouts.new_zeros(
                                (1, 1),
                                dtype=torch.int64).fill_(beam['hyp'][-1])
                            _, lmstate, scores_lm = helper.update_rnnlm_state(
                                lm, beam, y_prev)
                            if lm is not None:
                                total_score_lm += scores_lm[0, -1, idx].item()
                                # total_score_lm /= length_norm_factor

                            self.state_cache[hyp_str] = {
                                'dout': dout,
                                'dstate': dstate,
                                'lmstate': {
                                    'hxs': lmstate['hxs'],
                                    'cxs': lmstate['cxs']
                                } if lmstate is not None else None,
                                'total_score_lm': total_score_lm,
                            }

                        if lm is not None:
                            total_score += total_score_lm * lm_weight

                        new_hyps.append({
                            'hyp':
                            hyp_ids,
                            'hyp_str':
                            hyp_str,
                            'score':
                            total_score,
                            'score_rnnt':
                            total_scores_rnnt[0, idx].item(),
                            'score_lm':
                            total_score_lm,
                            'dout':
                            dout,
                            'dstate':
                            dstate,
                            'lmstate': {
                                'hxs': lmstate['hxs'],
                                'cxs': lmstate['cxs']
                            } if lmstate is not None else None
                        })

                # Merge hypotheses having the same token sequences
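                # (the same label sequence reached via different alignments; keep the best score)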
                new_hyps_merged = {}
                for beam in new_hyps:
                    hyp_str = ' '.join(list(map(str, beam['hyp'])))
                    if hyp_str not in new_hyps_merged.keys():
                        new_hyps_merged[hyp_str] = beam
                    elif beam['score'] > new_hyps_merged[hyp_str]['score']:
                        new_hyps_merged[hyp_str] = beam
                new_hyps = [v for v in new_hyps_merged.values()]

                # Local pruning
                new_hyps_sorted = sorted(new_hyps,
                                         key=lambda x: x['score'],
                                         reverse=True)[:beam_width]

                # Remove complete hypotheses
                new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(
                    new_hyps_sorted, end_hyps)
                hyps = new_hyps[:]
                if is_finish:
                    break

            # Global pruning
            if len(end_hyps) == 0:
                end_hyps = hyps[:]
            elif len(end_hyps) < nbest and nbest > 1:
                end_hyps.extend(hyps[:nbest - len(end_hyps)])

            # forward second-pass LM rescoring
            if lm_second is not None:
                self.lm_rescoring(end_hyps,
                                  lm_second,
                                  lm_weight_second,
                                  tag='second')

            # backward second-pass LM rescoring
            if lm_second_bwd is not None:
                self.lm_rescoring(end_hyps,
                                  lm_second_bwd,
                                  lm_weight_second_bwd,
                                  tag='second_bwd')

            # Sort by score
            end_hyps = sorted(
                end_hyps,
                key=lambda x: x['score'] / max(len(x['hyp'][1:]), 1),
                reverse=True)

            # Reset state cache
            self.state_cache = OrderedDict()

            if idx2token is not None:
                if utt_ids is not None:
                    logger.info('Utt-id: %s' % utt_ids[b])
                assert self.vocab == idx2token.vocab
                logger.info('=' * 200)
                for k in range(len(end_hyps)):
                    if refs_id is not None:
                        logger.info('Ref: %s' % idx2token(refs_id[b]))
                    logger.info('Hyp: %s' % idx2token(end_hyps[k]['hyp'][1:]))
                    logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                    logger.info('log prob (hyp, rnnt): %.7f' %
                                end_hyps[k]['score_rnnt'])
                    if lm is not None:
                        logger.info('log prob (hyp, first-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm'] * lm_weight))
                    if lm_second is not None:
                        logger.info('log prob (hyp, second-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm_second'] *
                                     lm_weight_second))
                    if lm_second_bwd is not None:
                        logger.info(
                            'log prob (hyp, second-pass lm, reverse): %.7f' %
                            (end_hyps[k]['score_lm_second_bwd'] *
                             lm_weight_second_bwd))
                    logger.info('-' * 50)

            # N-best list
            nbest_hyps_idx += [[
                np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)
            ]]

            # Check <eos>
            eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos)
                              for n in range(nbest)])

        return nbest_hyps_idx, None, None
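# --- Sketch: the prediction-network state cache (added for illustration).
# The transducer decoder above memoizes prediction-network outputs in
# self.state_cache, keyed by the space-joined token-id string of each prefix,
# so a prefix reached by several beams is expanded only once per utterance.
# A stripped-down version of that pattern; `step` is a hypothetical stand-in
# for the embed/recurrency update:
state_cache = {}

def cached_step(prefix_ids, step, prev_state):
    """Return (output, state) for a token prefix, computing it at most once."""
    key = ' '.join(map(str, prefix_ids))
    if key in state_cache:
        entry = state_cache[key]
        return entry['out'], entry['state']
    out, state = step(prefix_ids[-1], prev_state)  # feed only the last token
    state_cache[key] = {'out': out, 'state': state}
    return out, state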
Example #3
    def beam_search(self, eouts, elens, params, idx2token=None,
                    lm=None, lm_second=None, lm_bwd=None, ctc_log_probs=None,
                    nbest=1, exclude_eos=False,
                    refs_id=None, utt_ids=None, speakers=None,
                    ensmbl_eouts=None, ensmbl_elens=None, ensmbl_decs=[], cache_states=True):
        """Beam search decoding.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            params (dict): hyperparameters for decoding
            idx2token (): converter from index to token
            lm: first-pass LM
            lm_second: second-pass LM
            lm_bwd: first/second-pass backward LM
            ctc_log_probs (FloatTensor): `[B, T, vocab]`
            nbest (int): size of the N-best list
            exclude_eos (bool): exclude <eos> from hypothesis
            refs_id (list): reference list
            utt_ids (list): utterance id list
            speakers (list): speaker list
            ensmbl_eouts (list): list of FloatTensor
            ensmbl_elens (list): list of list
            ensmbl_decs (list): list of torch.nn.Module
            cache_states (bool): cache decoder states for fast decoding
        Returns:
            nbest_hyps_idx (list): length `B`, each of which contains list of N hypotheses
            aws (list): length `B`, each of which contains arrays of size `[H, L, T]`
            scores (list):

        """
        bs, xmax, _ = eouts.size()
        n_models = len(ensmbl_decs) + 1

        beam_width = params['recog_beam_width']
        assert 1 <= nbest <= beam_width
        ctc_weight = params['recog_ctc_weight']
        max_len_ratio = params['recog_max_len_ratio']
        min_len_ratio = params['recog_min_len_ratio']
        lp_weight = params['recog_length_penalty']
        length_norm = params['recog_length_norm']
        lm_weight = params['recog_lm_weight']
        lm_weight_second = params['recog_lm_second_weight']
        lm_weight_bwd = params['recog_lm_bwd_weight']
        eos_threshold = params['recog_eos_threshold']
        lm_state_carry_over = params['recog_lm_state_carry_over']
        softmax_smoothing = params['recog_softmax_smoothing']
        eps_wait = params['recog_mma_delay_threshold']

        if lm is not None:
            assert lm_weight > 0
            lm.eval()
        if lm_second is not None:
            assert lm_weight_second > 0
            lm_second.eval()
        if lm_bwd is not None:
            assert lm_weight_bwd > 0
            lm_bwd.eval()

        if ctc_log_probs is not None:
            assert ctc_weight > 0
            ctc_log_probs = tensor2np(ctc_log_probs)

        nbest_hyps_idx, aws, scores = [], [], []
        eos_flags = []
        for b in range(bs):
            # Initialization per utterance
            lmstate = None
            ys = eouts.new_zeros(1, 1).fill_(self.eos).long()

            # For joint CTC-Attention decoding
            ctc_prefix_scorer = None
            if ctc_log_probs is not None:
                if self.bwd:
                    ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b][::-1], self.blank, self.eos)
                else:
                    ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b], self.blank, self.eos)

            if speakers is not None:
                if speakers[b] == self.prev_spk:
                    if lm_state_carry_over and isinstance(lm, RNNLM):
                        lmstate = self.lmstate_final
                self.prev_spk = speakers[b]

            helper = BeamSearch(beam_width, self.eos, ctc_weight, self.device_id)

            end_hyps = []
            ymax = int(math.floor(elens[b] * max_len_ratio)) + 1
            hyps = [{'hyp': [self.eos],
                     'ys': ys,
                     'cache': None,
                     'score': 0.,
                     'score_attn': 0.,
                     'score_ctc': 0.,
                     'score_lm': 0.,
                     'aws': [None],
                     'lmstate': lmstate,
                     'ensmbl_aws': [[None]] * (n_models - 1),
                     'ctc_state': ctc_prefix_scorer.initial_state() if ctc_prefix_scorer is not None else None,
                     'streamable': True,
                     'streaming_failed_point': 1000}]
            streamable_global = True
            for t in range(ymax):
                # batchify all hypotheses for batch decoding
                cache = [None] * self.n_layers
                if cache_states and t > 0:
                    for lth in range(self.n_layers):
                        cache[lth] = torch.cat([beam['cache'][lth] for beam in hyps], dim=0)
                ys = eouts.new_zeros(len(hyps), t + 1).long()
                for j, beam in enumerate(hyps):
                    ys[j, :] = beam['ys']
                if t > 0:
                    xy_aws_prev = torch.cat([beam['aws'][-1] for beam in hyps], dim=0)  # `[B, n_layers, H_ma, 1, klen]`
                else:
                    xy_aws_prev = None

                # Update LM states for shallow fusion
                lmstate, scores_lm = None, None
                if lm is not None:
                    if hyps[0]['lmstate'] is not None:
                        lm_hxs = torch.cat([beam['lmstate']['hxs'] for beam in hyps], dim=1)
                        lm_cxs = torch.cat([beam['lmstate']['cxs'] for beam in hyps], dim=1)
                        lmstate = {'hxs': lm_hxs, 'cxs': lm_cxs}
                    y = ys[:, -1:].clone()  # NOTE: this is important
                    _, lmstate, scores_lm = lm.predict(y, lmstate)

                # for the main model
                causal_mask = eouts.new_ones(t + 1, t + 1).byte()
                causal_mask = torch.tril(causal_mask, out=causal_mask).unsqueeze(0).repeat([ys.size(0), 1, 1])

                out = self.pos_enc(self.embed(ys))  # scaled

                mlen = 0  # TODO: fix later
                if self.memory_transformer:
                    # NOTE: TransformerXL does not use positional encoding in the token embedding
                    mems = self.init_memory()
                    # adopt zero-centered offset
                    pos_idxs = torch.arange(mlen - 1, -(t + 1) - 1, -1.0, dtype=torch.float)
                    pos_embs = self.pos_emb(pos_idxs, self.device_id)
                    out = self.dropout_emb(out)
                    hidden_states = [out]

                n_heads_total = 0
                eouts_b = eouts[b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])
                new_cache = [None] * self.n_layers
                xy_aws_all_layers = []
                lth_s = self.mocha_first_layer - 1
                for lth, layer in enumerate(self.layers):
                    if self.memory_transformer:
                        out = layer(
                            out, causal_mask, eouts_b, None,
                            cache=cache[lth],
                            pos_embs=pos_embs, memory=mems[lth], u=self.u, v=self.v)
                        hidden_states.append(out)
                    else:
                        out = layer(
                            out, causal_mask, eouts_b, None,
                            cache=cache[lth],
                            xy_aws_prev=xy_aws_prev[:, lth - lth_s] if lth >= lth_s and t > 0 else None,
                            eps_wait=eps_wait)

                    new_cache[lth] = out
                    if layer.xy_aws is not None:
                        xy_aws_all_layers.append(layer.xy_aws)
                logits = self.output(self.norm_out(out))
                probs = torch.softmax(logits[:, -1] * softmax_smoothing, dim=1)
                xy_aws_all_layers = torch.stack(xy_aws_all_layers, dim=1)  # `[B, H, n_layers, L, T]`

                # for the ensemble
                ensmbl_new_cache = []
                if n_models > 1:
                    # Ensemble initialization
                    # ensmbl_cache = []
                    # cache_e = [None] * self.n_layers
                    # if cache_states and t > 0:
                    #     for lth in range(self.n_layers):
                    #         cache_e[lth] = torch.cat([beam['ensmbl_cache'][lth] for beam in hyps], dim=0)
                    for i_e, dec in enumerate(ensmbl_decs):
                        out_e = dec.pos_enc(dec.embed(ys))  # scaled
                        eouts_e = ensmbl_eouts[i_e][b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])
                        new_cache_e = [None] * dec.n_layers
                        for lth in range(dec.n_layers):
                            out_e, _, xy_aws_e, _, _ = dec.layers[lth](out_e, causal_mask, eouts_e, None,
                                                                       cache=cache[lth])
                            new_cache_e[lth] = out_e
                        ensmbl_new_cache.append(new_cache_e)
                        logits_e = dec.output(dec.norm_out(out_e))
                        probs += torch.softmax(logits_e[:, -1] * softmax_smoothing, dim=1)
                        # NOTE: sum in the probability scale (not log-scale)

                # Ensemble in log-scale
                scores_attn = torch.log(probs) / n_models

                new_hyps = []
                for j, beam in enumerate(hyps):
                    # Attention scores
                    total_scores_attn = beam['score_attn'] + scores_attn[j:j + 1]
                    total_scores = total_scores_attn * (1 - ctc_weight)

                    # Add LM score <before> top-K selection
                    if lm is not None:
                        total_scores_lm = beam['score_lm'] + scores_lm[j:j + 1, -1]
                        total_scores += total_scores_lm * lm_weight
                    else:
                        total_scores_lm = eouts.new_zeros(1, self.vocab)

                    total_scores_topk, topk_ids = torch.topk(
                        total_scores, k=beam_width, dim=1, largest=True, sorted=True)

                    # Add length penalty
                    if lp_weight > 0:
                        total_scores_topk += (len(beam['hyp'][1:]) + 1) * lp_weight

                    # Add CTC score
                    new_ctc_states, total_scores_ctc, total_scores_topk = helper.add_ctc_score(
                        beam['hyp'], topk_ids, beam['ctc_state'],
                        total_scores_topk, ctc_prefix_scorer)

                    new_aws = beam['aws'] + [xy_aws_all_layers[j:j + 1, :, :, -1:]]
                    aws_j = torch.cat(new_aws[1:], dim=3)  # `[1, H, n_layers, L, T]`
                    streaming_failed_point = beam['streaming_failed_point']

                    # forward direction
                    for k in range(beam_width):
                        idx = topk_ids[0, k].item()
                        length_norm_factor = len(beam['hyp'][1:]) + 1 if length_norm else 1
                        total_score = total_scores_topk[0, k].item() / length_norm_factor

                        if idx == self.eos:
                            # Exclude short hypotheses
                            if len(beam['hyp']) - 1 < elens[b] * min_len_ratio:
                                continue
                            # EOS threshold
                            max_score_no_eos = scores_attn[j, :idx].max(0)[0].item()
                            max_score_no_eos = max(max_score_no_eos, scores_attn[j, idx + 1:].max(0)[0].item())
                            if scores_attn[j, idx].item() <= eos_threshold * max_score_no_eos:
                                continue

                        quantity_rate = 1.
                        if 'mocha' in self.attn_type:
                            n_tokens_hyp_k = t + 1
                            n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                            quantity_diff = n_tokens_hyp_k * n_heads_total - n_quantity_k

                            if quantity_diff != 0:
                                if idx == self.eos:
                                    n_tokens_hyp_k -= 1  # NOTE: do not count <eos> for streamability
                                    n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                                else:
                                    streamable_global = False
                                if n_tokens_hyp_k * n_heads_total == 0:
                                    quantity_rate = 0
                                else:
                                    quantity_rate = n_quantity_k / (n_tokens_hyp_k * n_heads_total)

                            if beam['streamable'] and not streamable_global:
                                streaming_failed_point = t

                        new_hyps.append(
                            {'hyp': beam['hyp'] + [idx],
                             'ys': torch.cat([beam['ys'], eouts.new_zeros(1, 1).fill_(idx).long()], dim=-1),
                             'cache': [new_cache_l[j:j + 1] for new_cache_l in new_cache] if cache_states else cache,
                             'score': total_score,
                             'score_attn': total_scores_attn[0, idx].item(),
                             'score_ctc': total_scores_ctc[k].item(),
                             'score_lm': total_scores_lm[0, idx].item(),
                             'aws': new_aws,
                             'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                         'cxs': lmstate['cxs'][:, j:j + 1]} if lmstate is not None else None,
                             'ctc_state': new_ctc_states[k] if ctc_prefix_scorer is not None else None,
                             'ensmbl_cache': ensmbl_new_cache,
                             'streamable': streamable_global,
                             'streaming_failed_point': streaming_failed_point,
                             'quantity_rate': quantity_rate})

                # Local pruning
                new_hyps_sorted = sorted(new_hyps, key=lambda x: x['score'], reverse=True)[:beam_width]

                # Remove complete hypotheses
                new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(
                    new_hyps_sorted, end_hyps, prune=True)
                hyps = new_hyps[:]
                if is_finish:
                    break

            # Global pruning
            if len(end_hyps) == 0:
                end_hyps = hyps[:]
            elif len(end_hyps) < nbest and nbest > 1:
                end_hyps.extend(hyps[:nbest - len(end_hyps)])

            # forward second-pass LM rescoring
            if lm_second is not None:
                self.lm_rescoring(end_hyps, lm_second, lm_weight_second, tag='second')

            # backward second-pass LM rescoring
            if lm_bwd is not None and lm_weight_bwd > 0:
                self.lm_rescoring(end_hyps, lm_bwd, lm_weight_bwd, tag='second_bwd')

            # Sort by score
            end_hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

            for j in range(len(end_hyps[0]['aws'][1:])):
                tmp = end_hyps[0]['aws'][j + 1]
                end_hyps[0]['aws'][j + 1] = tmp.view(1, -1, tmp.size(-2), tmp.size(-1))

            # metrics for streaming inference
            self.streamable = end_hyps[0]['streamable']
            self.quantity_rate = end_hyps[0]['quantity_rate']
            self.last_success_frame_ratio = None

            if idx2token is not None:
                if utt_ids is not None:
                    logger.info('Utt-id: %s' % utt_ids[b])
                assert self.vocab == idx2token.vocab
                logger.info('=' * 200)
                for k in range(len(end_hyps)):
                    if refs_id is not None:
                        logger.info('Ref: %s' % idx2token(refs_id[b]))
                    logger.info('Hyp: %s' % idx2token(
                        end_hyps[k]['hyp'][1:][::-1] if self.bwd else end_hyps[k]['hyp'][1:]))
                    logger.info('num tokens (hyp): %d' % len(end_hyps[k]['hyp'][1:]))
                    logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                    logger.info('log prob (hyp, att): %.7f' % (end_hyps[k]['score_attn'] * (1 - ctc_weight)))
                    if ctc_prefix_scorer is not None:
                        logger.info('log prob (hyp, ctc): %.7f' % (end_hyps[k]['score_ctc'] * ctc_weight))
                    if lm is not None:
                        logger.info('log prob (hyp, first-pass lm): %.7f' % (end_hyps[k]['score_lm'] * lm_weight))
                    if lm_second is not None:
                        logger.info('log prob (hyp, second-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm_second'] * lm_weight_second))
                    if lm_bwd is not None:
                        logger.info('log prob (hyp, second-pass lm-bwd): %.7f' %
                                    (end_hyps[k]['score_lm_second_bwd'] * lm_weight_bwd))
                    if 'mocha' in self.attn_type:
                        logger.info('streamable: %s' % end_hyps[k]['streamable'])
                        logger.info('streaming failed point: %d' % (end_hyps[k]['streaming_failed_point'] + 1))
                        logger.info('quantity rate [%%]: %.2f' % (end_hyps[k]['quantity_rate'] * 100))
                    logger.info('-' * 50)

                if 'mocha' in self.attn_type and end_hyps[0]['streaming_failed_point'] < 1000:
                    assert not self.streamable
                    aws_last_success = end_hyps[0]['aws'][1:][end_hyps[0]['streaming_failed_point'] - 1]
                    rightmost_frame = max(0, aws_last_success[0, :, 0].nonzero()[:, -1].max().item()) + 1
                    frame_ratio = rightmost_frame * 100 / xmax
                    self.last_success_frame_ratio = frame_ratio
                    logger.info('streaming last success frame ratio: %.2f' % frame_ratio)

            # N-best list
            if self.bwd:
                # Reverse the order
                nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:][::-1]) for n in range(nbest)]]
                aws += [tensor2np(torch.cat(end_hyps[0]['aws'][1:][::-1], dim=2).squeeze(0))]
            else:
                nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]
                aws += [tensor2np(torch.cat(end_hyps[0]['aws'][1:], dim=2).squeeze(0))]
            scores += [[end_hyps[n]['score_attn'] for n in range(nbest)]]

            # Check <eos>
            eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos) for n in range(nbest)])

        # Exclude <eos> (<sos> in case of the backward decoder)
        if exclude_eos:
            if self.bwd:
                nbest_hyps_idx = [[nbest_hyps_idx[b][n][1:] if eos_flags[b][n]
                                   else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
            else:
                nbest_hyps_idx = [[nbest_hyps_idx[b][n][:-1] if eos_flags[b][n]
                                   else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]

        # Store ASR/LM state
        if len(end_hyps) > 0:
            self.lmstate_final = end_hyps[0]['lmstate']

        return nbest_hyps_idx, aws, scores
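# --- Sketch: the causal mask built inside the Transformer decoding loops
# (added for illustration). Both Transformer examples build a lower-triangular
# mask so that output position i attends only to positions <= i, then repeat
# it across the batched beam. A standalone, runnable illustration:
import torch

L = 4  # current output length (t + 1 in the loop above)
causal_mask = torch.tril(torch.ones(L, L, dtype=torch.uint8))
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.uint8)
n_hyps = 3  # number of active hypotheses (= ys.size(0))
mask_b = causal_mask.unsqueeze(0).repeat([n_hyps, 1, 1])  # `[n_hyps, L, L]`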
Example #4
    def beam_search(self,
                    eouts,
                    elens,
                    params,
                    idx2token=None,
                    lm=None,
                    lm_second=None,
                    lm_second_bwd=None,
                    ctc_log_probs=None,
                    nbest=1,
                    exclude_eos=False,
                    refs_id=None,
                    utt_ids=None,
                    speakers=None,
                    ensmbl_eouts=None,
                    ensmbl_elens=None,
                    ensmbl_decs=[]):
        """Beam search decoding.

        Args:
            eouts (FloatTensor): `[B, T, enc_n_units]`
            elens (IntTensor): `[B]`
            params (dict): hyperparameters for decoding
            idx2token (): converter from index to token
            lm: first-pass LM
            lm_second: second-pass LM
            lm_second_bwd: second-pass backward LM
            ctc_log_probs (FloatTensor): `[B, T, vocab]`
            nbest (int): size of the N-best list
            exclude_eos (bool): exclude <eos> from hypothesis
            refs_id (list): reference list
            utt_ids (list): utterance id list
            speakers (list): speaker list
            ensmbl_eouts (list): list of FloatTensor
            ensmbl_elens (list): list of list
            ensmbl_decs (list): list of torch.nn.Module
        Returns:
            nbest_hyps_idx (list): length `B`, each of which contains list of N hypotheses
            aws: dummy
            scores: dummy

        """
        bs = eouts.size(0)

        beam_width = params['recog_beam_width']
        assert 1 <= nbest <= beam_width
        ctc_weight = params['recog_ctc_weight']
        lm_weight = params['recog_lm_weight']
        lm_weight_second = params['recog_lm_second_weight']
        lm_weight_second_bwd = params['recog_lm_bwd_weight']
        # asr_state_carry_over = params['recog_asr_state_carry_over']
        lm_state_carry_over = params['recog_lm_state_carry_over']

        if lm is not None:
            assert lm_weight > 0
            lm.eval()
        if lm_second is not None:
            assert lm_weight_second > 0
            lm_second.eval()
        if lm_second_bwd is not None:
            assert lm_weight_second_bwd > 0
            lm_second_bwd.eval()

        if ctc_log_probs is not None:
            assert ctc_weight > 0
            ctc_log_probs = tensor2np(ctc_log_probs)

        nbest_hyps_idx = []
        eos_flags = []
        for b in range(bs):
            # Initialization per utterance
            y = eouts.new_zeros(1, 1).fill_(self.eos).long()  # one hypothesis per utterance
            y_emb = self.dropout_emb(self.embed(y))
            dout, dstate = self.recurrency(y_emb, None)
            lmstate = None

            # For joint CTC-Attention decoding
            ctc_prefix_scorer = None
            if ctc_log_probs is not None:
                ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b],
                                                   self.blank, self.eos)

            if speakers is not None:
                if speakers[b] == self.prev_spk:
                    if lm_state_carry_over and isinstance(lm, RNNLM):
                        lmstate = self.lmstate_final
                self.prev_spk = speakers[b]

            helper = BeamSearch(beam_width, self.eos, ctc_weight,
                                self.device_id)

            end_hyps = []
            hyps = [{
                'hyp': [self.eos],
                'ys': [self.eos],
                'score':
                0.,
                'score_rnnt':
                0.,
                'score_lm':
                0.,
                'score_ctc':
                0.,
                'dout':
                dout,
                'dstate':
                dstate,
                'lmstate':
                lmstate,
                'ctc_state':
                ctc_prefix_scorer.initial_state()
                if ctc_prefix_scorer is not None else None
            }]
            for t in range(elens[b]):
                # preprocess for batch decoding
                douts = torch.cat([beam['dout'] for beam in hyps], dim=0)
                outs = self.joint(
                    eouts[b:b + 1, t:t + 1].repeat([douts.size(0), 1, 1]),
                    douts)
                scores_rnnt = torch.log_softmax(outs.squeeze(2).squeeze(1),
                                                dim=-1)

                # Update LM states for shallow fusion
                y = eouts.new_zeros(len(hyps), 1).long()
                for j, beam in enumerate(hyps):
                    y[j, 0] = beam['hyp'][-1]
                lmstate, scores_lm = None, None
                if lm is not None:
                    if hyps[0]['lmstate'] is not None:
                        lm_hxs = torch.cat(
                            [beam['lmstate']['hxs'] for beam in hyps], dim=1)
                        lm_cxs = torch.cat(
                            [beam['lmstate']['cxs'] for beam in hyps], dim=1)
                        lmstate = {'hxs': lm_hxs, 'cxs': lm_cxs}
                    lmout, lmstate, scores_lm = lm.predict(y, lmstate)

                new_hyps = []
                for j, beam in enumerate(hyps):
                    dout = douts[j:j + 1]
                    dstate = beam['dstate']
                    lmstate = beam['lmstate']

                    # RNN-T scores
                    total_scores_rnnt = beam['score_rnnt'] + scores_rnnt[j:j + 1]
                    total_scores = total_scores_rnnt * (1 - ctc_weight)

                    # Add LM score <after> top-K selection
                    total_scores_topk, topk_ids = torch.topk(total_scores,
                                                             k=beam_width,
                                                             dim=-1,
                                                             largest=True,
                                                             sorted=True)
                    if lm is not None:
                        total_scores_lm = beam['score_lm'] + scores_lm[
                            j, -1, topk_ids[0]]
                        total_scores_topk += total_scores_lm * lm_weight
                    else:
                        total_scores_lm = eouts.new_zeros(beam_width)

                    # Add CTC score
                    new_ctc_states, total_scores_ctc, total_scores_topk = helper.add_ctc_score(
                        beam['hyp'], topk_ids, beam['ctc_state'],
                        total_scores_topk, ctc_prefix_scorer)
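                    # add_ctc_score mixes the CTC prefix scores (weighted by
                    # ctc_weight) into the top-k candidates and returns updated
                    # CTC states; scores are expected to pass through unchanged
                    # when the prefix scorer is None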

                    for k in range(beam_width):
                        idx = topk_ids[0, k].item()

                        if idx == self.blank:
                            # Blank advances to the next frame without updating
                            # the prediction network; copy the beam before
                            # mutating so the dict shared with `hyps` is not
                            # modified in place
                            new_beam = beam.copy()
                            new_beam['score'] = total_scores_topk[0, k].item()
                            new_beam['score_rnnt'] = total_scores_rnnt[0, idx].item()
                            new_hyps.append(new_beam)
                            continue

                        # skip blank-dominant frames
                        # if total_scores_topk[0, self.blank].item() > 0.7:
                        #     continue

                        # Update prediction network only when predicting non-blank labels
                        hyp_id = beam['hyp'] + [idx]
                        hyp_str = ' '.join(list(map(str, hyp_id)))
                        # if hyp_str in self.state_cache.keys():
                        #     # from cache
                        #     dout = self.state_cache[hyp_str]['dout']
                        #     new_dstate = self.state_cache[hyp_str]['dstate']
                        #     lmstate = self.state_cache[hyp_str]['lmstate']
                        # else:
                        y = eouts.new_zeros(1, 1).fill_(idx).long()
                        y_emb = self.dropout_emb(self.embed(y))
                        dout, new_dstate = self.recurrency(y_emb, dstate)

                        # store in cache
                        self.state_cache[hyp_str] = {
                            'dout': dout,
                            'dstate': new_dstate,
                            'lmstate': {
                                'hxs': lmstate['hxs'][:, j:j + 1],
                                'cxs': lmstate['cxs'][:, j:j + 1]
                            } if lmstate is not None else None,
                        }
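                        # NOTE: the matching cache lookup above is commented
                        # out, so these entries are reused only if that block
                        # is re-enabled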

                        new_hyps.append({
                            'hyp': hyp_id,
                            'score': total_scores_topk[0, k].item(),
                            'score_rnnt': total_scores_rnnt[0, idx].item(),
                            'score_ctc': total_scores_ctc[k].item(),
                            'score_lm': total_scores_lm[k].item(),
                            'dout': dout,
                            'dstate': new_dstate,
                            'lmstate': {
                                'hxs': lmstate['hxs'][:, j:j + 1],
                                'cxs': lmstate['cxs'][:, j:j + 1]
                            } if lmstate is not None else None,
                            'ctc_state': (new_ctc_states[k]
                                          if ctc_prefix_scorer is not None else None)
                        })

                # Merge hypotheses sharing the same token sequence, keeping the
                # higher-scoring one
                new_hyps_merged = {}
                for beam in new_hyps:
                    hyp_str = ' '.join(list(map(str, beam['hyp'])))
                    if hyp_str not in new_hyps_merged:
                        new_hyps_merged[hyp_str] = beam
                    elif beam['score'] > new_hyps_merged[hyp_str]['score']:
                        new_hyps_merged[hyp_str] = beam
                new_hyps = list(new_hyps_merged.values())

                # Local pruning
                new_hyps_sorted = sorted(new_hyps,
                                         key=lambda x: x['score'],
                                         reverse=True)[:beam_width]

                # Remove complete hypotheses
                new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(
                    new_hyps_sorted, end_hyps)
                hyps = new_hyps[:]
                if is_finish:
                    break

            # Global pruning
            if len(end_hyps) == 0:
                end_hyps = hyps[:]
            elif len(end_hyps) < nbest and nbest > 1:
                end_hyps.extend(hyps[:nbest - len(end_hyps)])
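            # NOTE: fall back to the still-active beams when too few
            # hypotheses completed within the encoder frames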

            # Forward second-pass LM rescoring
            if lm_second is not None:
                self.lm_rescoring(end_hyps,
                                  lm_second,
                                  lm_weight_second,
                                  tag='second')

            # Backward second-pass LM rescoring
            if lm_second_bwd is not None:
                self.lm_rescoring(end_hyps,
                                  lm_second_bwd,
                                  lm_weight_second_bwd,
                                  tag='second_bwd')

            # Sort by score
            end_hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

            # Reset state cache
            self.state_cache = OrderedDict()
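            # The cache is per-utterance; clearing it prevents stale prediction
            # network states from leaking into the next utterance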

            if idx2token is not None:
                if utt_ids is not None:
                    logger.info('Utt-id: %s' % utt_ids[b])
                assert self.vocab == idx2token.vocab
                logger.info('=' * 200)
                for k in range(len(end_hyps)):
                    if refs_id is not None:
                        logger.info('Ref: %s' % idx2token(refs_id[b]))
                    logger.info('Hyp: %s' % idx2token(end_hyps[k]['hyp'][1:]))
                    logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                    if ctc_prefix_scorer is not None:
                        logger.info('log prob (hyp, ctc): %.7f' %
                                    (end_hyps[k]['score_ctc'] * ctc_weight))
                    if lm is not None:
                        logger.info('log prob (hyp, first-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm'] * lm_weight))
                    if lm_second is not None:
                        logger.info('log prob (hyp, second-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm_second'] *
                                     lm_weight_second))
                    if lm_second_bwd is not None:
                        logger.info(
                            'log prob (hyp, second-pass lm, reverse): %.7f' %
                            (end_hyps[k]['score_lm_second_rev'] *
                             lm_weight_second_bwd))
                    logger.info('-' * 50)

            # N-best list
            nbest_hyps_idx += [[
                np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)
            ]]
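            # NOTE: hyp[1:] drops the leading <eos> that served as <sos>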

            # Check <eos>
            eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos)
                              for n in range(nbest)])

        # Attention weights and per-hypothesis score lists are not produced by
        # RNN-T decoding; None placeholders keep the return signature consistent
        return nbest_hyps_idx, None, None