def beam_search(self, eouts, elens, params, idx2token=None,
                    lm=None, lm_second=None, lm_second_bwd=None, ctc_log_probs=None,
                    nbest=1, exclude_eos=False,
                    refs_id=None, utt_ids=None, speakers=None,
                    ensmbl_eouts=None, ensmbl_elens=None, ensmbl_decs=None, cache_states=True):
        """Beam search decoding.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            params (dict): decoding hyperparameters (`recog_*` keys)
            idx2token (): converter from index to token (used only for logging)
            lm (torch.nn.module): first-pass LM (shallow fusion)
            lm_second (torch.nn.module): second-pass LM
            lm_second_bwd (torch.nn.module): second-pass backward LM
            ctc_log_probs (FloatTensor): CTC log-probabilities for joint
                CTC-attention decoding (presumably `[B, T, vocab]`)
            nbest (int): number of N-best list
            exclude_eos (bool): exclude <eos> from hypothesis
            refs_id (List): reference list (used only for logging)
            utt_ids (List): utterance id list (used only for logging)
            speakers (List): speaker list (used for LM state carry-over)
            ensmbl_eouts (List[FloatTensor]): encoder outputs for ensemble models
                (defaults to an empty list)
            ensmbl_elens (List[IntTensor]): encoder output lengths for ensemble
                models (defaults to an empty list; currently unused, `elens` is used)
            ensmbl_decs (List[torch.nn.Module]): decoders for ensemble models
                (defaults to an empty list)
            cache_states (bool): cache decoder states for fast decoding
        Returns:
            nbest_hyps_idx (List): length `[B]`, each of which contains list of N hypotheses
            aws (List): length `[B]`, each of which contains arrays of size `[H, L, T]`
                (H: heads of all collected layers flattened together)
            scores (List): length `[B]`, each of which contains the attention
                scores (`score_att`) of the N hypotheses

        """
        # Use fresh lists instead of shared mutable default arguments
        if ensmbl_eouts is None:
            ensmbl_eouts = []
        if ensmbl_elens is None:
            ensmbl_elens = []
        if ensmbl_decs is None:
            ensmbl_decs = []

        bs, xmax, _ = eouts.size()
        n_models = len(ensmbl_decs) + 1

        beam_width = params.get('recog_beam_width')
        assert 1 <= nbest <= beam_width
        ctc_weight = params.get('recog_ctc_weight')
        max_len_ratio = params.get('recog_max_len_ratio')
        min_len_ratio = params.get('recog_min_len_ratio')
        lp_weight = params.get('recog_length_penalty')
        length_norm = params.get('recog_length_norm')
        cache_emb = params.get('recog_cache_embedding')
        lm_weight = params.get('recog_lm_weight')
        lm_weight_second = params.get('recog_lm_second_weight')
        lm_weight_second_bwd = params.get('recog_lm_bwd_weight')
        eos_threshold = params.get('recog_eos_threshold')
        lm_state_carry_over = params.get('recog_lm_state_carry_over')
        softmax_smoothing = params.get('recog_softmax_smoothing')
        eps_wait = params.get('recog_mma_delay_threshold')

        helper = BeamSearch(beam_width, self.eos, ctc_weight, lm_weight, self.device)
        lm = helper.verify_lm_eval_mode(lm, lm_weight, cache_emb)
        lm_second = helper.verify_lm_eval_mode(lm_second, lm_weight_second, cache_emb)
        lm_second_bwd = helper.verify_lm_eval_mode(lm_second_bwd, lm_weight_second_bwd, cache_emb)

        # cache token embeddings
        if cache_emb:
            self.cache_embedding(eouts.device)

        if ctc_log_probs is not None:
            assert ctc_weight > 0
            ctc_log_probs = tensor2np(ctc_log_probs)

        nbest_hyps_idx, aws, scores = [], [], []
        eos_flags = []
        for b in range(bs):
            # Initialization per utterance
            lmstate = None
            ys = eouts.new_zeros((1, 1), dtype=torch.int64).fill_(self.eos)
            for layer in self.layers:
                layer.reset()

            # For joint CTC-Attention decoding
            ctc_prefix_scorer = None
            if ctc_log_probs is not None:
                if self.bwd:
                    # backward decoder: reverse along the first axis (presumably time)
                    ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b][::-1], self.blank, self.eos)
                else:
                    ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b], self.blank, self.eos)

            # Reuse the final RNNLM state of the previous utterance of the same speaker
            if speakers is not None:
                if speakers[b] == self.prev_spk:
                    if lm_state_carry_over and isinstance(lm, RNNLM):
                        lmstate = self.lmstate_final
                self.prev_spk = speakers[b]

            end_hyps = []
            hyps = [{'hyp': [self.eos],
                     'ys': ys,
                     'cache': None,
                     'score': 0.,
                     'score_att': 0.,
                     'score_ctc': 0.,
                     'score_lm': 0.,
                     'aws': [None],
                     'lmstate': lmstate,
                     'ensmbl_cache': [[None] * dec.n_layers for dec in ensmbl_decs] if n_models > 1 else None,
                     'ctc_state': ctc_prefix_scorer.initial_state() if ctc_prefix_scorer is not None else None,
                     'quantity_rate': 1.,
                     'streamable': True,
                     'streaming_failed_point': 1000}]
            streamable_global = True
            ymax = math.ceil(elens[b] * max_len_ratio)
            for i in range(ymax):
                # batchfy all hypotheses for batch decoding
                cache = [None] * self.n_layers
                if cache_states and i > 0:
                    for lth in range(self.n_layers):
                        cache[lth] = torch.cat([beam['cache'][lth] for beam in hyps], dim=0)
                ys = eouts.new_zeros((len(hyps), i + 1), dtype=torch.int64)
                for j, beam in enumerate(hyps):
                    ys[j, :] = beam['ys']
                if i > 0:
                    xy_aws_prev = torch.cat([beam['aws'][-1] for beam in hyps], dim=0)  # `[B, n_layers, H_ma, 1, klen]`
                else:
                    xy_aws_prev = None

                # Update LM states for shallow fusion
                y_lm = ys[:, -1:].clone()  # NOTE: this is important
                _, lmstate, scores_lm = helper.update_rnnlm_state_batch(lm, hyps, y_lm)

                # for the main model
                causal_mask = eouts.new_ones(i + 1, i + 1, dtype=torch.uint8)
                causal_mask = torch.tril(causal_mask).unsqueeze(0).repeat([ys.size(0), 1, 1])
                out = self.pos_enc(self.embed_token_id(ys), scale=True)  # scaled + dropout
                # NOTE(review): n_heads_total is never incremented, so in the 'mocha'
                # branch below `n_tokens_hyp_k * n_heads_total` is always 0: quantity_rate
                # becomes 0 (and streamable_global is cleared for non-<eos> tokens)
                # whenever any boundary is counted. It presumably should accumulate the
                # head count of each layer whose xy_aws is collected -- confirm.
                n_heads_total = 0
                eouts_b = eouts[b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])  # `[n_hyps, T, d_model]`
                new_cache = [None] * self.n_layers
                xy_aws_layers = []
                xy_aws = None
                lth_s = self.mma_first_layer - 1
                # Autoregressive decoding through all decoder layers
                for lth, layer in enumerate(self.layers):
                    out = layer(
                        out, causal_mask, eouts_b, None,
                        cache=cache[lth],
                        xy_aws_prev=xy_aws_prev[:, lth - lth_s] if lth >= lth_s and i > 0 else None,
                        eps_wait=eps_wait)
                    xy_aws = layer.xy_aws

                    new_cache[lth] = out
                    if xy_aws is not None:
                        xy_aws_layers.append(xy_aws)
                logits = self.output(self.norm_out(out[:, -1]))  # output of the current (last) step only
                probs = torch.softmax(logits * softmax_smoothing, dim=1)
                # layer axis is inserted at dim=1: `[B, n_layers, H, L, T]`
                xy_aws_layers = torch.stack(xy_aws_layers, dim=1)

                # Ensemble initialization
                ensmbl_cache = [[None] * dec.n_layers for dec in ensmbl_decs]
                if n_models > 1 and cache_states and i > 0:
                    for i_e, dec in enumerate(ensmbl_decs):
                        for lth in range(dec.n_layers):
                            ensmbl_cache[i_e][lth] = torch.cat([beam['ensmbl_cache'][i_e][lth]
                                                                for beam in hyps], dim=0)

                # for the ensemble
                ensmbl_new_cache = [[None] * dec.n_layers for dec in ensmbl_decs]
                for i_e, dec in enumerate(ensmbl_decs):
                    out_e = dec.pos_enc(dec.embed(ys))  # scaled + dropout
                    eouts_e = ensmbl_eouts[i_e][b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])
                    for lth in range(dec.n_layers):
                        out_e = dec.layers[lth](out_e, causal_mask, eouts_e, None,
                                                cache=ensmbl_cache[i_e][lth])
                        ensmbl_new_cache[i_e][lth] = out_e
                    logits_e = dec.output(dec.norm_out(out_e[:, -1]))
                    probs += torch.softmax(logits_e * softmax_smoothing, dim=1)
                    # NOTE: sum in the probability scale (not log-scale)

                # Ensemble: average the probabilities of the main and ensemble models
                scores_att = torch.log(probs / n_models)  # `[n_hyps, vocab]`
                new_hyps = []
                for j, beam in enumerate(hyps):  # each hypothesis proposes up to beam_width extensions
                    # Attention scores
                    total_scores_att = beam['score_att'] + scores_att[j:j + 1]  # `[1, vocab]`
                    total_scores = total_scores_att * (1 - ctc_weight)

                    # Add LM score <before> top-K selection
                    if lm is not None:
                        total_scores_lm = beam['score_lm'] + scores_lm[j:j + 1, -1]
                        total_scores += total_scores_lm * lm_weight
                    else:
                        total_scores_lm = eouts.new_zeros(1, self.vocab)

                    total_scores_topk, topk_ids = torch.topk(
                        total_scores, k=beam_width, dim=1, largest=True, sorted=True)

                    # Add length penalty
                    if lp_weight > 0:
                        total_scores_topk += (len(beam['hyp'][1:]) + 1) * lp_weight

                    # Add CTC score
                    new_ctc_states, total_scores_ctc, total_scores_topk = helper.add_ctc_score(
                        beam['hyp'], topk_ids, beam['ctc_state'],
                        total_scores_topk, ctc_prefix_scorer)

                    # NOTE: this list object is shared by every extension of this beam
                    new_aws = beam['aws'] + [xy_aws_layers[j:j + 1, :, :, -1:]]
                    aws_j = torch.cat(new_aws[1:], dim=3)  # `[1, n_layers, H, L, T]`

                    # forward direction
                    for k in range(beam_width):
                        idx = topk_ids[0, k].item()  # token id of the k-th candidate
                        length_norm_factor = len(beam['hyp'][1:]) + 1 if length_norm else 1
                        total_score = total_scores_topk[0, k].item() / length_norm_factor

                        if idx == self.eos:
                            # Exclude short hypotheses (e.g. a premature <eos> emitted
                            # during silence in the middle of the utterance)
                            if len(beam['hyp'][1:]) < elens[b] * min_len_ratio:
                                continue
                            # EOS threshold: compare <eos> with the best non-<eos> score
                            max_score_no_eos = scores_att[j, :idx].max(0)[0].item()
                            max_score_no_eos = max(max_score_no_eos, scores_att[j, idx + 1:].max(0)[0].item())
                            if scores_att[j, idx].item() <= eos_threshold * max_score_no_eos:
                                # keep decoding: drop this <eos> candidate
                                continue

                        streaming_failed_point = beam['streaming_failed_point']
                        quantity_rate = 1.
                        # Streaming (monotonic attention) statistics
                        if self.attn_type == 'mocha':
                            n_tokens_hyp_k = i + 1
                            n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                            quantity_diff = n_tokens_hyp_k * n_heads_total - n_quantity_k

                            if quantity_diff != 0:
                                if idx == self.eos:
                                    n_tokens_hyp_k -= 1  # NOTE: do not count <eos> for streamability
                                    n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                                else:
                                    streamable_global = False
                                if n_tokens_hyp_k * n_heads_total == 0:
                                    quantity_rate = 0
                                else:
                                    quantity_rate = n_quantity_k / (n_tokens_hyp_k * n_heads_total)

                            if beam['streamable'] and not streamable_global:
                                streaming_failed_point = i

                        new_hyps.append(
                            {'hyp': beam['hyp'] + [idx],
                             'ys': torch.cat([beam['ys'], eouts.new_zeros((1, 1), dtype=torch.int64).fill_(idx)], dim=-1),
                             'cache': [new_cache_l[j:j + 1] for new_cache_l in new_cache] if cache_states else cache,
                             'score': total_score,
                             'score_att': total_scores_att[0, idx].item(),
                             'score_ctc': total_scores_ctc[k].item(),
                             'score_lm': total_scores_lm[0, idx].item(),
                             'aws': new_aws,
                             'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                         'cxs': lmstate['cxs'][:, j:j + 1]} if lmstate is not None else None,
                             'ctc_state': new_ctc_states[k] if ctc_prefix_scorer is not None else None,
                             'ensmbl_cache': [[new_cache_e_l[j:j + 1] for new_cache_e_l in new_cache_e]
                                              for new_cache_e in ensmbl_new_cache] if cache_states else None,
                             'streamable': streamable_global,
                             'streaming_failed_point': streaming_failed_point,
                             'quantity_rate': quantity_rate})

                # Local pruning: keep the best beam_width candidates
                new_hyps_sorted = sorted(new_hyps, key=lambda x: x['score'], reverse=True)[:beam_width]

                # Remove complete hypotheses
                new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(
                    new_hyps_sorted, end_hyps, prune=True)
                hyps = new_hyps[:]
                if is_finish:
                    break

            # Global pruning (decoding of this utterance is finished)
            if len(end_hyps) == 0:
                end_hyps = hyps[:]
            elif len(end_hyps) < nbest and nbest > 1:
                end_hyps.extend(hyps[:nbest - len(end_hyps)])

            # forward/backward second-pass LM rescoring
            end_hyps = helper.lm_rescoring(end_hyps, lm_second, lm_weight_second,
                                           length_norm=length_norm, tag='second')
            end_hyps = helper.lm_rescoring(end_hyps, lm_second_bwd, lm_weight_second_bwd,
                                           length_norm=length_norm, tag='second_bwd')

            # Sort by score
            end_hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

            # Flatten the per-step attention weights of EVERY final hypothesis from
            # `[1, n_layers, H, 1, T]` to `[1, n_layers * H, 1, T]`, so all N-best
            # entries share the same shape when concatenated below. A new list is
            # bound (rather than mutating in place) because sibling hypotheses share
            # the same `aws` list object.
            for hyp in end_hyps:
                hyp['aws'] = hyp['aws'][:1] + [
                    a.view(1, -1, a.size(-2), a.size(-1)) for a in hyp['aws'][1:]]

            # metrics for streaming inference
            self.streamable = end_hyps[0]['streamable']
            self.quantity_rate = end_hyps[0]['quantity_rate']
            self.last_success_frame_ratio = None

            if idx2token is not None:
                if utt_ids is not None:
                    logger.info('Utt-id: %s' % utt_ids[b])
                assert self.vocab == idx2token.vocab
                logger.info('=' * 200)
                for k in range(len(end_hyps)):
                    if refs_id is not None:
                        logger.info('Ref: %s' % idx2token(refs_id[b]))
                    logger.info('Hyp: %s' % idx2token(
                        end_hyps[k]['hyp'][1:][::-1] if self.bwd else end_hyps[k]['hyp'][1:]))
                    logger.info('num tokens (hyp): %d' % len(end_hyps[k]['hyp'][1:]))
                    logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                    logger.info('log prob (hyp, att): %.7f' %
                                (end_hyps[k]['score_att'] * (1 - ctc_weight)))
                    if ctc_prefix_scorer is not None:
                        logger.info('log prob (hyp, ctc): %.7f' %
                                    (end_hyps[k]['score_ctc'] * ctc_weight))
                    if lm is not None:
                        logger.info('log prob (hyp, first-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm'] * lm_weight))
                    if lm_second is not None:
                        logger.info('log prob (hyp, second-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm_second'] * lm_weight_second))
                    if lm_second_bwd is not None:
                        logger.info('log prob (hyp, second-pass lm, reverse): %.7f' %
                                    (end_hyps[k]['score_lm_second_bwd'] * lm_weight_second_bwd))
                    if self.attn_type == 'mocha':
                        logger.info('streamable: %s' % end_hyps[k]['streamable'])
                        logger.info('streaming failed point: %d' %
                                    (end_hyps[k]['streaming_failed_point'] + 1))
                        logger.info('quantity rate [%%]: %.2f' %
                                    (end_hyps[k]['quantity_rate'] * 100))
                    logger.info('-' * 50)

                if self.attn_type == 'mocha' and end_hyps[0]['streaming_failed_point'] < 1000:
                    assert not self.streamable
                    aws_last_success = end_hyps[0]['aws'][1:][end_hyps[0]['streaming_failed_point'] - 1]
                    rightmost_frame = max(0, aws_last_success[0, :, 0].nonzero()[:, -1].max().item()) + 1
                    frame_ratio = rightmost_frame * 100 / xmax
                    self.last_success_frame_ratio = frame_ratio
                    logger.info('streaming last success frame ratio: %.2f' % frame_ratio)

            # N-best list
            if self.bwd:
                # Reverse the order
                nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:][::-1]) for n in range(nbest)]]
                aws += [[tensor2np(torch.cat(end_hyps[n]['aws'][1:][::-1], dim=2).squeeze(0)) for n in range(nbest)]]
            else:
                nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]
                aws += [[tensor2np(torch.cat(end_hyps[n]['aws'][1:], dim=2).squeeze(0)) for n in range(nbest)]]
            scores += [[end_hyps[n]['score_att'] for n in range(nbest)]]

            # Check <eos>
            eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos) for n in range(nbest)])

        # Exclude <eos> (<sos> in case of the backward decoder)
        if exclude_eos:
            if self.bwd:
                nbest_hyps_idx = [[nbest_hyps_idx[b][n][1:] if eos_flags[b][n]
                                   else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
                aws = [[aws[b][n][:, 1:] if eos_flags[b][n] else aws[b][n] for n in range(nbest)] for b in range(bs)]
            else:
                nbest_hyps_idx = [[nbest_hyps_idx[b][n][:-1] if eos_flags[b][n]
                                   else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
                aws = [[aws[b][n][:, :-1] if eos_flags[b][n] else aws[b][n] for n in range(nbest)] for b in range(bs)]

        # Store ASR/LM state (only for single-utterance batches, used for carry-over)
        if bs == 1:
            self.lmstate_final = end_hyps[0]['lmstate']

        return nbest_hyps_idx, aws, scores
Beispiel #2
0
    def beam_search(self, eouts, elens, params, idx2token,
                    lm=None, lm_second=None, lm_second_bwd=None,
                    nbest=1, refs_id=None, utt_ids=None, speakers=None):
        """Beam search decoding (CTC prefix beam search).

        Args:
            eouts (FloatTensor): `[B, T, enc_n_units]`
            elens (List): length `B`
            params (dict):
                recog_beam_width (int): size of beam
                recog_length_penalty (float): length penalty
                recog_lm_weight (float): weight of first-pass LM score
                recog_lm_second_weight (float): weight of second-pass LM score
                recog_lm_bwd_weight (float): weight of second-pass backward LM score
            idx2token (): converter from index to token (used only for logging)
            lm: first-pass LM (shallow fusion)
            lm_second: second-pass LM
            lm_second_bwd: second-pass backward LM
            nbest (int): currently unused; every hypothesis left in the beam is returned
            refs_id (List): reference list (used only for logging)
            utt_ids (List): utterance id list (used only for logging)
            speakers (List): speaker list (currently unused)
        Returns:
            nbest_hyps_idx (List[List[List]]): for each utterance, the token-id
                sequences of the hypotheses in the final beam (leading <eos> removed)

        """
        bs = eouts.size(0)

        beam_width = params['recog_beam_width']
        lp_weight = params['recog_length_penalty']
        lm_weight = params['recog_lm_weight']
        lm_weight_second = params['recog_lm_second_weight']
        lm_weight_second_bwd = params['recog_lm_bwd_weight']

        helper = BeamSearch(beam_width, self.eos, 1.0, eouts.device)
        lm = helper.verify_lm_eval_mode(lm, lm_weight)
        lm_second = helper.verify_lm_eval_mode(lm_second, lm_weight_second)
        lm_second_bwd = helper.verify_lm_eval_mode(lm_second_bwd, lm_weight_second_bwd)

        nbest_hyps_idx = []
        log_probs = torch.log_softmax(self.output(eouts), dim=-1)
        for b in range(bs):
            # Elements in the beam are (prefix, (p_b, p_no_blank))
            # Initialize the beam with the empty sequence, a probability of
            # 1 for ending in blank and zero for ending in non-blank (in log space).
            beam = [{'hyp': [self.eos],  # <eos> is used for LM
                     'p_b': LOG_1,
                     'p_nb': LOG_0,
                     'score_lm': LOG_1,
                     'lmstate': None}]

            for t in range(elens[b]):
                new_beam = []

                # Pick up the top-k token ids at this frame (only these are expanded);
                # computed once per frame since they do not depend on the hypothesis
                _, topk_ids = torch.topk(
                    log_probs[b:b + 1, t], k=min(beam_width, self.vocab), dim=-1, largest=True, sorted=True)
                topk_ids = tensor2np(topk_ids)[0]
                p_blank = log_probs[b, t, self.blank].item()

                for prev in beam:
                    hyp = prev['hyp'][:]
                    p_b = prev['p_b']
                    p_nb = prev['p_nb']
                    score_lm = prev['score_lm']

                    # case 1. hyp is not extended (blank, or repetition of the last token)
                    new_p_b = np.logaddexp(p_b + p_blank, p_nb + p_blank)
                    if len(hyp) > 1:
                        new_p_nb = p_nb + log_probs[b, t, hyp[-1]].item()
                    else:
                        new_p_nb = LOG_0
                    score_ctc = np.logaddexp(new_p_b, new_p_nb)
                    score_lp = len(hyp[1:]) * lp_weight
                    new_beam.append({'hyp': hyp,
                                     'score': score_ctc + score_lm + score_lp,
                                     'p_b': new_p_b,
                                     'p_nb': new_p_nb,
                                     'score_ctc': score_ctc,
                                     'score_lm': score_lm,
                                     'score_lp': score_lp,
                                     'lmstate': prev['lmstate']})

                    # Update LM states for shallow fusion
                    if lm is not None:
                        _, lmstate, lm_log_probs = lm.predict(
                            eouts.new_zeros(1, 1).fill_(hyp[-1]), prev['lmstate'])
                    else:
                        lmstate = None

                    # case 2. hyp is extended by a non-blank token c
                    new_p_b = LOG_0
                    c_prev = hyp[-1] if len(hyp) > 1 else None
                    for c in topk_ids:
                        if c == self.blank:
                            continue
                        p_t = log_probs[b, t, c].item()

                        if c == c_prev:
                            # a repeated token is a new emission only after a blank
                            new_p_nb = p_b + p_t
                            # TODO(hirofumi): apply character LM here
                        else:
                            new_p_nb = np.logaddexp(p_b + p_t, p_nb + p_t)
                            # TODO(hirofumi): apply character LM here
                            # TODO(hirofumi): apply word LM here (when c == self.space)

                        score_ctc = np.logaddexp(new_p_b, new_p_nb)
                        score_lp = (len(hyp[1:]) + 1) * lp_weight
                        # LM score of THIS candidate: start from the parent's score for
                        # every c instead of accumulating across sibling candidates
                        new_score_lm = score_lm
                        if lm_weight > 0 and lm is not None:
                            new_score_lm = score_lm + lm_log_probs[0, 0, c].item() * lm_weight
                        new_beam.append({'hyp': hyp + [c],
                                         'score': score_ctc + new_score_lm + score_lp,
                                         'p_b': new_p_b,
                                         'p_nb': new_p_nb,
                                         'score_ctc': score_ctc,
                                         'score_lm': new_score_lm,
                                         'score_lp': score_lp,
                                         'lmstate': lmstate})

                # Pruning
                beam = sorted(new_beam, key=lambda x: x['score'], reverse=True)[:beam_width]

            # NOTE(review): the return values of lm_rescoring are discarded and the beam
            # is not re-sorted, i.e. this assumes rescoring happens in place -- confirm.
            # forward second-pass LM rescoring
            helper.lm_rescoring(beam, lm_second, lm_weight_second, tag='second')

            # backward second-pass LM rescoring
            helper.lm_rescoring(beam, lm_second_bwd, lm_weight_second_bwd, tag='second_bwd')

            # Exclude <eos>
            nbest_hyps_idx.append([hyp['hyp'][1:] for hyp in beam])

            if idx2token is not None:
                if utt_ids is not None:
                    logger.info('Utt-id: %s' % utt_ids[b])
                assert self.vocab == idx2token.vocab
                logger.info('=' * 200)
                for k in range(len(beam)):
                    if refs_id is not None:
                        logger.info('Ref: %s' % idx2token(refs_id[b]))
                    logger.info('Hyp: %s' % idx2token(beam[k]['hyp'][1:]))
                    logger.info('log prob (hyp): %.7f' % beam[k]['score'])
                    logger.info('log prob (hyp, ctc): %.7f' % (beam[k]['score_ctc']))
                    # score_lp / score_lm are stored already multiplied by their weights
                    logger.info('log prob (hyp, lp): %.7f' % (beam[k]['score_lp']))
                    if lm is not None:
                        logger.info('log prob (hyp, first-path lm): %.7f' %
                                    (beam[k]['score_lm']))
                    if lm_second is not None:
                        logger.info('log prob (hyp, second-path lm): %.7f' %
                                    (beam[k]['score_lm_second'] * lm_weight_second))
                    logger.info('-' * 50)

        return nbest_hyps_idx
Beispiel #3
0
    def beam_search(self,
                    eouts,
                    elens,
                    params,
                    idx2token=None,
                    lm=None,
                    lm_second=None,
                    lm_second_bwd=None,
                    ctc_log_probs=None,
                    nbest=1,
                    exclude_eos=False,
                    refs_id=None,
                    utt_ids=None,
                    speakers=None,
                    ensmbl_eouts=[],
                    ensmbl_elens=[],
                    ensmbl_decs=[]):
        """Beam search decoding.

        Args:
            eouts (FloatTensor): `[B, T, enc_n_units]`
            elens (IntTensor): `[B]`
            params (dict): decoding hyperparameters
            idx2token (): converter from index to token
            lm (torch.nn.module): firsh-pass LM
            lm_second (torch.nn.module): second-pass LM
            lm_second_bwd (torch.nn.module): second-pass backward LM
            ctc_log_probs (FloatTensor): `[B, T, vocab]`
            nbest (int): number of N-best list
            exclude_eos (bool): exclude <eos> from hypothesis
            refs_id (List): reference list
            utt_ids (List): utterance id list
            speakers (List): speaker list
            ensmbl_eouts (List[FloatTensor]): encoder outputs for ensemble models
            ensmbl_elens (List[IntTensor]) encoder outputs for ensemble models
            ensmbl_decs (List[torch.nn.Module): decoders for ensemble models
        Returns:
            nbest_hyps_idx (List): length `[B]`, each of which contains list of N hypotheses
            aws: dummy
            scores: dummy

        """
        bs = eouts.size(0)

        beam_width = params.get('recog_beam_width')
        assert 1 <= nbest <= beam_width
        ctc_weight = params.get('recog_ctc_weight')
        assert ctc_weight == 0
        assert ctc_log_probs is None
        cache_emb = params.get('recog_cache_embedding')
        lm_weight = params.get('recog_lm_weight')
        lm_weight_second = params.get('recog_lm_second_weight')
        lm_weight_second_bwd = params.get('recog_lm_bwd_weight')
        lm_state_CO = params.get('recog_lm_state_carry_over')
        softmax_smoothing = params.get('recog_softmax_smoothing')
        beam_search_type = params.get('recog_rnnt_beam_search_type')

        helper = BeamSearch(beam_width, self.eos, ctc_weight, lm_weight,
                            eouts.device)
        lm = helper.verify_lm_eval_mode(lm, lm_weight, cache_emb)
        if lm is not None:
            assert isinstance(lm, RNNLM)
        lm_second = helper.verify_lm_eval_mode(lm_second, lm_weight_second,
                                               cache_emb)
        lm_second_bwd = helper.verify_lm_eval_mode(lm_second_bwd,
                                                   lm_weight_second_bwd,
                                                   cache_emb)

        # cache token embeddings
        if cache_emb:
            self.cache_embedding(eouts.device)

        nbest_hyps_idx = []
        for b in range(bs):
            # Initialization per utterance
            dstate = {
                'hxs': eouts.new_zeros(self.n_layers, 1, self.dec_n_units),
                'cxs': eouts.new_zeros(self.n_layers, 1, self.dec_n_units)
            }
            lmstate = {
                'hxs': eouts.new_zeros(lm.n_layers, 1, lm.n_units),
                'cxs': eouts.new_zeros(lm.n_layers, 1, lm.n_units)
            } if lm is not None else None

            if speakers is not None:
                if speakers[b] == self.prev_spk:
                    if lm_state_CO:
                        lmstate = self.lmstate_final
                self.prev_spk = speakers[b]

            end_hyps = []
            hyps = self.initialize_beam([self.eos], dstate, lmstate)
            self.state_cache = OrderedDict()

            if beam_search_type == 'time_sync_mono':
                hyps, new_hyps = self._time_sync_mono(
                    hyps, helper, eouts[b:b + 1, :elens[b]], softmax_smoothing,
                    lm)
            elif beam_search_type == 'time_sync':
                hyps, new_hyps = self._time_sync(hyps, helper,
                                                 eouts[b:b + 1, :elens[b]],
                                                 softmax_smoothing, lm)
            else:
                raise NotImplementedError(beam_search_type)

            # Global pruning
            end_hyps = hyps[:]
            if len(end_hyps) < nbest and nbest > 1:
                end_hyps.extend(new_hyps[:nbest - len(end_hyps)])

            # forward/backward second-pass LM rescoring
            end_hyps = helper.lm_rescoring(end_hyps,
                                           lm_second,
                                           lm_weight_second,
                                           tag='second')
            end_hyps = helper.lm_rescoring(end_hyps,
                                           lm_second_bwd,
                                           lm_weight_second_bwd,
                                           tag='second_bwd')

            # Normalize by length
            end_hyps = sorted(
                end_hyps,
                key=lambda x: x['score'] / max(len(x['hyp'][1:]), 1),
                reverse=True)
            # NOTE: See Algorithm 1 in https://arxiv.org/abs/1211.3711

            if idx2token is not None:
                if utt_ids is not None:
                    logger.info('Utt-id: %s' % utt_ids[b])
                assert self.vocab == idx2token.vocab
                logger.info('=' * 200)
                for k in range(len(end_hyps)):
                    if refs_id is not None:
                        logger.info('Ref: %s' % idx2token(refs_id[b]))
                    logger.info('Hyp: %s' % idx2token(end_hyps[k]['hyp'][1:]))
                    if len(end_hyps[k]['hyp']) > 1:
                        logger.info('num tokens (hyp): %d' %
                                    len(end_hyps[k]['hyp'][1:]))
                    logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                    logger.info('log prob (hyp, rnnt): %.7f' %
                                end_hyps[k]['score_rnnt'])
                    if lm is not None:
                        logger.info('log prob (hyp, first-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm'] * lm_weight))
                    if lm_second is not None:
                        logger.info('log prob (hyp, second-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm_second'] *
                                     lm_weight_second))
                    if lm_second_bwd is not None:
                        logger.info(
                            'log prob (hyp, second-pass lm, reverse): %.7f' %
                            (end_hyps[k]['score_lm_second_bwd'] *
                             lm_weight_second_bwd))
                    logger.info('-' * 50)

            # N-best list (exclude <eos>)
            nbest_hyps_idx += [[
                np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)
            ]]

        # Store ASR/LM state
        if bs == 1:
            self.dstates_final = end_hyps[0]['dstate']
            self.lmstate_final = end_hyps[0]['lmstate']

        return nbest_hyps_idx, None, None
Beispiel #4
0
    def beam_search(self,
                    eouts,
                    elens,
                    params,
                    idx2token,
                    lm=None,
                    lm_second=None,
                    lm_second_bwd=None,
                    nbest=1,
                    refs_id=None,
                    utt_ids=None,
                    speakers=None):
        """Beam search decoding over CTC posteriors.

        Args:
            eouts (FloatTensor): `[B, T, enc_n_units]`
            elens (List): length `[B]`; only the first ``elens[b]`` frames of
                utterance ``b`` are searched (the remainder is padding)
            params (dict): decoding hyperparameters (``recog_*`` keys)
            idx2token (): converter from index to token; when not None, the
                hypotheses of every utterance are logged
            lm (torch.nn.module): first-pass LM; must be an ``RNNLM``
            lm_second (torch.nn.module): second-pass LM
            lm_second_bwd (torch.nn.module): second-pass backward LM
            nbest (int): number of N-best list
            refs_id (List): reference list (used for logging only)
            utt_ids (List): utterance id list (used for logging only)
            speakers (List): speaker list; when consecutive utterances share a
                speaker and ``recog_lm_state_carry_over`` is set, the LM state
                stored by an earlier single-utterance call seeds the LM
        Returns:
            nbest_hyps_idx (List[List[np.ndarray]]): length `[B]`; each entry
                holds up to ``nbest`` token-id arrays, best first, without the
                leading ``self.eos`` that seeds every hypothesis

        """
        bs = eouts.size(0)

        beam_width = params.get('recog_beam_width')
        lp_weight = params.get('recog_length_penalty')
        cache_emb = params.get('recog_cache_embedding')
        lm_weight = params.get('recog_lm_weight')
        lm_weight_second = params.get('recog_lm_second_weight')
        lm_weight_second_bwd = params.get('recog_lm_bwd_weight')
        lm_state_CO = params.get('recog_lm_state_carry_over')
        softmax_smoothing = params.get('recog_softmax_smoothing')

        helper = BeamSearch(beam_width, self.eos, 1.0, lm_weight, eouts.device)
        lm = helper.verify_lm_eval_mode(lm, lm_weight, cache_emb)
        if lm is not None:
            assert isinstance(lm, RNNLM)
        lm_second = helper.verify_lm_eval_mode(lm_second, lm_weight_second,
                                               cache_emb)
        lm_second_bwd = helper.verify_lm_eval_mode(lm_second_bwd,
                                                   lm_weight_second_bwd,
                                                   cache_emb)

        # Frame-level log-posteriors; logits are scaled by softmax_smoothing
        # before the softmax.
        log_probs = torch.log_softmax(self.output(eouts) * softmax_smoothing,
                                      dim=-1)

        nbest_hyps_idx = []
        for b in range(bs):
            # Initialization per utterance
            lmstate = {
                'hxs': eouts.new_zeros(lm.n_layers, 1, lm.n_units),
                'cxs': eouts.new_zeros(lm.n_layers, 1, lm.n_units)
            } if lm is not None else None

            # Carry the LM state over between utterances of the same speaker
            if speakers is not None:
                if speakers[b] == self.prev_spk:
                    if lm_state_CO:
                        lmstate = self.lmstate_final
                self.prev_spk = speakers[b]

            hyps = self.initialize_beam([self.eos], lmstate)
            self.state_cache = OrderedDict()

            # Search only the valid frames of this utterance: padded frames of
            # shorter utterances in the batch would otherwise be decoded too.
            hyps, new_hyps_sorted = self._beam_search(hyps, helper,
                                                      log_probs[b, :elens[b]],
                                                      lm, lp_weight)

            # Global pruning
            end_hyps = hyps[:]
            if len(end_hyps) < nbest and nbest > 1:
                end_hyps.extend(new_hyps_sorted[:nbest - len(end_hyps)])

            # forward/backward second-pass LM rescoring
            end_hyps = helper.lm_rescoring(end_hyps,
                                           lm_second,
                                           lm_weight_second,
                                           tag='second')
            end_hyps = helper.lm_rescoring(end_hyps,
                                           lm_second_bwd,
                                           lm_weight_second_bwd,
                                           tag='second_bwd')

            # Normalize by length (the leading <eos> is not counted)
            end_hyps = sorted(
                end_hyps,
                key=lambda x: x['score'] / max(len(x['hyp'][1:]), 1),
                reverse=True)

            if idx2token is not None:
                if utt_ids is not None:
                    logger.info('Utt-id: %s' % utt_ids[b])
                assert self.vocab == idx2token.vocab
                logger.info('=' * 200)
                for k in range(len(end_hyps)):
                    if refs_id is not None:
                        logger.info('Ref: %s' % idx2token(refs_id[b]))
                    logger.info('Hyp: %s' % idx2token(end_hyps[k]['hyp'][1:]))
                    logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                    logger.info('log prob (hyp, ctc): %.7f' %
                                (end_hyps[k]['score_ctc']))
                    logger.info('log prob (hyp, lp): %.7f' %
                                (end_hyps[k]['score_lp'] * lp_weight))
                    if lm is not None:
                        logger.info('log prob (hyp, first-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm'] * lm_weight))
                    if lm_second is not None:
                        logger.info('log prob (hyp, second-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm_second'] *
                                     lm_weight_second))
                    if lm_second_bwd is not None:
                        logger.info(
                            'log prob (hyp, second-pass lm, reverse): %.7f' %
                            (end_hyps[k]['score_lm_second_bwd'] *
                             lm_weight_second_bwd))
                    logger.info('-' * 50)

            # N-best list (exclude the leading <eos>). Slicing returns fewer
            # than nbest entries instead of raising IndexError when not enough
            # hypotheses survive.
            nbest_hyps_idx.append(
                [np.array(hyp['hyp'][1:]) for hyp in end_hyps[:nbest]])

        # Store LM state for carry-over (single-utterance batches only)
        if bs == 1:
            self.lmstate_final = end_hyps[0]['lmstate']

        return nbest_hyps_idx
    def beam_search(self,
                    eouts,
                    elens,
                    params,
                    idx2token=None,
                    lm=None,
                    lm_second=None,
                    lm_second_bwd=None,
                    ctc_log_probs=None,
                    nbest=1,
                    exclude_eos=False,
                    refs_id=None,
                    utt_ids=None,
                    speakers=None,
                    ensmbl_eouts=[],
                    ensmbl_elens=[],
                    ensmbl_decs=[]):
        """Beam search decoding.

        At every encoder frame each hypothesis is extended by its top
        ``beam_width`` joint-network outputs: blank keeps the label sequence,
        a non-blank label is appended and advances the prediction network
        (and the first-pass LM, if any). Prediction-network/LM states are
        memoized per label sequence in ``self.state_cache``.

        Args:
            eouts (FloatTensor): `[B, T, enc_n_units]`
            elens (IntTensor): `[B]`
            params (dict): hyperparameters for decoding
            idx2token (): converter from index to token; when not None, the
                hypotheses are logged
            lm (torch.nn.module): first-pass LM
            lm_second (torch.nn.module): second-pass LM
            lm_second_bwd (torch.nn.module): second-pass backward LM
            ctc_log_probs (FloatTensor): must be None (CTC fusion unsupported)
            nbest (int): number of N-best list
            exclude_eos (bool): accepted for interface compatibility; unused
            refs_id (List): reference list (used for logging only)
            utt_ids (List): utterance id list (used for logging only)
            speakers (List): speaker list; when consecutive utterances share a
                speaker and ``recog_lm_state_carry_over`` is set, the RNNLM
                state stored by an earlier single-utterance call seeds the LM
            ensmbl_eouts (List[FloatTensor]): unused in this implementation
            ensmbl_elens (List[IntTensor]): unused in this implementation
            ensmbl_decs (List[torch.nn.Module]): unused in this implementation
        Returns:
            nbest_hyps_idx (List): length `[B]`, each of which contains a list
                of up to N hypotheses (without the leading ``self.eos``)
            aws: dummy (None)
            scores: dummy (None)

        """
        bs = eouts.size(0)

        beam_width = params['recog_beam_width']
        assert 1 <= nbest <= beam_width
        ctc_weight = params['recog_ctc_weight']
        assert ctc_weight == 0
        assert ctc_log_probs is None
        lm_weight = params['recog_lm_weight']
        lm_weight_second = params['recog_lm_second_weight']
        lm_weight_second_bwd = params['recog_lm_bwd_weight']
        lm_state_carry_over = params['recog_lm_state_carry_over']
        merge_prob = True  # TODO: make this parameter

        helper = BeamSearch(beam_width, self.eos, ctc_weight, eouts.device)
        lm = helper.verify_lm_eval_mode(lm, lm_weight)
        lm_second = helper.verify_lm_eval_mode(lm_second, lm_weight_second)
        lm_second_bwd = helper.verify_lm_eval_mode(lm_second_bwd,
                                                   lm_weight_second_bwd)

        nbest_hyps_idx = []
        end_hyps = []
        for b in range(bs):
            # Start each utterance from an empty state cache, as the other
            # decoders do. Entries are keyed only by the label sequence, so
            # states computed from a different initial LM state (speaker
            # carry-over below) must not be reused; this also guarantees the
            # attribute exists before the first lookup.
            self.state_cache = OrderedDict()

            # Initialization per utterance: feed <eos> to the prediction network
            y = eouts.new_zeros((1, 1), dtype=torch.int64).fill_(self.eos)
            y_emb = self.dropout_emb(self.embed(y))
            dout, dstate = self.recurrency(y_emb, None)
            lmstate = None

            if speakers is not None:
                if speakers[b] == self.prev_spk:
                    if lm_state_carry_over and isinstance(lm, RNNLM):
                        lmstate = self.lmstate_final
                self.prev_spk = speakers[b]

            end_hyps = []
            hyps = [{
                'hyp': [self.eos],
                'hyp_ids_str': '',
                'ys': [self.eos],
                'score': 0.,
                'score_rnnt': 0.,
                'score_lm': 0.,
                'dout': dout,
                'dstate': dstate,
                'lmstate': lmstate
            }]
            for t in range(elens[b]):
                # batchfy all hypotheses for batch decoding
                douts = torch.cat([beam['dout'] for beam in hyps], dim=0)
                logits = self.joint(
                    eouts[b:b + 1, t:t + 1].repeat([douts.size(0), 1, 1]),
                    douts)
                # one row per hypothesis (presumably `[n_hyps, vocab]`)
                scores_rnnt = torch.log_softmax(logits.squeeze(2).squeeze(1),
                                                dim=-1)

                new_hyps = []
                for j, beam in enumerate(hyps):
                    # Transducer scores
                    total_scores_rnnt = beam['score_rnnt'] + scores_rnnt[j:j + 1]
                    total_scores_topk, topk_ids = torch.topk(total_scores_rnnt,
                                                             k=beam_width,
                                                             dim=-1,
                                                             largest=True,
                                                             sorted=True)

                    for k in range(beam_width):
                        idx = topk_ids[0, k].item()
                        total_score = total_scores_topk[0, k].item()
                        total_score_lm = beam['score_lm']

                        if idx == self.blank:
                            # Blank: keep the label sequence, only add the score
                            new_hyps.append(beam.copy())
                            new_hyps[-1]['score'] += scores_rnnt[
                                j, self.blank].item()
                            new_hyps[-1]['score_rnnt'] += scores_rnnt[
                                j, self.blank].item()
                            continue

                        # Update prediction network only when predicting non-blank labels
                        hyp_ids = beam['hyp'] + [idx]
                        hyp_ids_str = ' '.join(list(map(str, hyp_ids)))
                        if hyp_ids_str in self.state_cache:
                            # from cache
                            cached = self.state_cache[hyp_ids_str]
                            dout = cached['dout']
                            dstate = cached['dstate']
                            lmstate = cached['lmstate']
                            total_score_lm = cached['total_score_lm']
                        else:
                            y = eouts.new_zeros((1, 1),
                                                dtype=torch.int64).fill_(idx)
                            y_emb = self.dropout_emb(self.embed(y))
                            dout, dstate = self.recurrency(
                                y_emb, beam['dstate'])

                            # Update LM states for shallow fusion
                            y_prev = eouts.new_zeros(
                                (1, 1),
                                dtype=torch.int64).fill_(beam['hyp'][-1])
                            _, lmstate, scores_lm = helper.update_rnnlm_state(
                                lm, beam, y_prev)
                            if lm is not None:
                                total_score_lm += scores_lm[0, -1, idx].item()

                            self.state_cache[hyp_ids_str] = {
                                'dout': dout,
                                'dstate': dstate,
                                'lmstate': {
                                    'hxs': lmstate['hxs'],
                                    'cxs': lmstate['cxs']
                                } if lmstate is not None else None,
                                'total_score_lm': total_score_lm,
                            }

                        if lm is not None:
                            total_score += total_score_lm * lm_weight

                        new_hyps.append({
                            'hyp': hyp_ids,
                            'hyp_ids_str': hyp_ids_str,
                            'score': total_score,
                            'score_rnnt': total_scores_rnnt[0, idx].item(),
                            'score_lm': total_score_lm,
                            'dout': dout,
                            'dstate': dstate,
                            'lmstate': {
                                'hxs': lmstate['hxs'],
                                'cxs': lmstate['cxs']
                            } if lmstate is not None else None
                        })

                # Local pruning
                new_hyps_sorted = sorted(new_hyps,
                                         key=lambda x: x['score'],
                                         reverse=True)
                new_hyps_sorted = helper.merge_rnnt_path(
                    new_hyps_sorted, merge_prob)[:beam_width]

                # Remove complete hypotheses
                new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(
                    new_hyps_sorted, end_hyps)
                hyps = new_hyps[:]
                if is_finish:
                    break

            # Global pruning
            if len(end_hyps) == 0:
                end_hyps = hyps[:]
            elif len(end_hyps) < nbest and nbest > 1:
                end_hyps.extend(hyps[:nbest - len(end_hyps)])

            # forward second path LM rescoring
            # NOTE(review): return value ignored; this relies on lm_rescoring
            # updating end_hyps in place (the log below reads 'score_lm_second').
            helper.lm_rescoring(end_hyps,
                                lm_second,
                                lm_weight_second,
                                tag='second')

            # backward second path LM rescoring
            helper.lm_rescoring(end_hyps,
                                lm_second_bwd,
                                lm_weight_second_bwd,
                                tag='second_bwd')

            # Sort by length-normalized score (the leading <eos> is not counted)
            end_hyps = sorted(
                end_hyps,
                key=lambda x: x['score'] / max(len(x['hyp'][1:]), 1),
                reverse=True)

            # Release cached states of this utterance
            self.state_cache = OrderedDict()

            if idx2token is not None:
                if utt_ids is not None:
                    logger.info('Utt-id: %s' % utt_ids[b])
                assert self.vocab == idx2token.vocab
                logger.info('=' * 200)
                for k in range(len(end_hyps)):
                    if refs_id is not None:
                        logger.info('Ref: %s' % idx2token(refs_id[b]))
                    logger.info('Hyp: %s' % idx2token(end_hyps[k]['hyp'][1:]))
                    logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                    logger.info('log prob (hyp, rnnt): %.7f' %
                                end_hyps[k]['score_rnnt'])
                    if lm is not None:
                        logger.info('log prob (hyp, first-path lm): %.7f' %
                                    (end_hyps[k]['score_lm'] * lm_weight))
                    if lm_second is not None:
                        logger.info('log prob (hyp, second-path lm): %.7f' %
                                    (end_hyps[k]['score_lm_second'] *
                                     lm_weight_second))
                    if lm_second_bwd is not None:
                        logger.info(
                            'log prob (hyp, second-path lm, reverse): %.7f' %
                            (end_hyps[k]['score_lm_second_bwd'] *
                             lm_weight_second_bwd))
                    logger.info('-' * 50)

            # N-best list (exclude the leading <eos>). Slicing returns fewer
            # than nbest entries instead of raising IndexError when path
            # merging leaves too few hypotheses.
            nbest_hyps_idx.append(
                [np.array(hyp['hyp'][1:]) for hyp in end_hyps[:nbest]])

        # Store ASR/LM state so speaker carry-over above has a state to read
        # (single-utterance batches only, as in the other decoders)
        if bs == 1 and end_hyps:
            self.dstates_final = end_hyps[0]['dstate']
            self.lmstate_final = end_hyps[0]['lmstate']

        return nbest_hyps_idx, None, None