示例#1
0
    def encode(self, xs, task='all', flip=False):
        """Encode acoustic or text features.

        Args:
            xs (list): A list of length `[B]`; for speech input each entry is
                an array of size `[T, input_dim]`, for text input an iterable
                of token ids
            task (str): all or ys or ys_sub*; any task containing 'lmobj'
                skips the encoder entirely
            flip (bool): if True, flip acoustic features in the time-dimension
        Returns:
            enc_outs (dict): per-task encoder outputs (e.g. 'ys', 'ys_sub1',
                'ys_sub2'), each a dict with keys 'xs' and 'xlens'
            perm_ids (list): indices that sort the input by length in
                descending order; None for 'lmobj' tasks
        Raises:
            ValueError: if `self.input_type` is neither 'speech' nor 'text'

        """
        if 'lmobj' in task:
            # The LM objective does not use the encoder: return empty slots.
            eouts = {key: {'xs': None, 'xlens': None}
                     for key in ['ys', 'ys.ctc', 'ys_sub1', 'ys_sub1.ctc',
                                 'ys_sub2', 'ys_sub2.ctc']}
            return eouts, None

        # Sort by lengths in descending order
        # NOTE: must be descending order for pack_padded_sequence
        perm_ids = sorted(range(len(xs)), key=lambda i: len(xs[i]),
                          reverse=True)
        xs = [xs[i] for i in perm_ids]

        if self.input_type == 'speech':
            # Frame stacking
            if self.nstacks > 1:
                xs = [stack_frame(x, self.nstacks, self.nskips) for x in xs]

            # Splicing
            if self.nsplices > 1:
                xs = [splice(x, self.nsplices, self.nstacks) for x in xs]

            xlens = [len(x) for x in xs]
            # Flip acoustic features in the reverse order. The flipped arrays
            # go through np2tensor like the unflipped ones, so device
            # placement is identical (calling `.cuda()` directly fails on CPU);
            # `.copy()` materializes the negative-stride view np.flip returns.
            if flip:
                xs = [np.flip(x, axis=0).copy() for x in xs]
            xs = pad_list([np2tensor(x, self.device_id).float() for x in xs])

        elif self.input_type == 'text':
            xlens = [len(x) for x in xs]
            xs = [np2tensor(np.fromiter(x, dtype=np.int64), self.device_id).long()
                  for x in xs]
            xs = pad_list(xs, self.pad)
            xs = self.embed_in(xs)

        else:
            raise ValueError('Unsupported input_type: %s' % self.input_type)

        enc_outs = self.enc(xs, xlens, task)

        # With a CNN encoder, the sub-task outputs are copies of the main output
        if self.main_weight < 1 and self.enc_type == 'cnn':
            for sub in ['sub1', 'sub2']:
                enc_outs['ys_' + sub]['xs'] = enc_outs['ys']['xs'].clone()
                enc_outs['ys_' + sub]['xlens'] = copy.deepcopy(enc_outs['ys']['xlens'])

        # Bridge between the encoder and decoder (each bridge is only touched
        # when its task weight is positive)
        if self.main_weight > 0 and (self.enc_type == 'cnn' or self.bridge_layer) and (task in ['all', 'ys']):
            enc_outs['ys']['xs'] = self.bridge(enc_outs['ys']['xs'])
        if self.sub1_weight > 0 and (self.enc_type == 'cnn' or self.bridge_layer) and (task in ['all', 'ys_sub1']):
            enc_outs['ys_sub1']['xs'] = self.bridge_sub1(enc_outs['ys_sub1']['xs'])
        if self.sub2_weight > 0 and (self.enc_type == 'cnn' or self.bridge_layer) and (task in ['all', 'ys_sub2']):
            enc_outs['ys_sub2']['xs'] = self.bridge_sub2(enc_outs['ys_sub2']['xs'])

        return enc_outs, perm_ids
示例#2
0
    def beam_search(self,
                    eouts,
                    elens,
                    params,
                    idx2token,
                    lm=None,
                    lm_rev=None,
                    ctc_log_probs=None,
                    nbest=1,
                    exclude_eos=False,
                    refs_id=None,
                    utt_ids=None,
                    speakers=None,
                    ensmbl_eouts=None,
                    ensmbl_elens=None,
                    ensmbl_decs=None):
        """Transducer beam search decoding, one utterance at a time.

        Args:
            eouts (FloatTensor): `[B, T, dec_n_units]`
            elens (IntTensor): `[B]`
            params (dict):
                recog_oracle (bool): feed reference tokens (from `refs_id`) to
                    the prediction network instead of the predicted ones
                recog_beam_width (int): size of hyp
                recog_ctc_weight (float): only decides whether the CTC score
                    is logged
                recog_lm_weight (float): weight of LM score
                recog_lm_state_carry_over (bool): start the LM from
                    `self.lmstate_final`
                recog_lm_usage (str): when 'rescoring', the final hypotheses
                    are additionally rescored with `lm`
            idx2token (): converter from index to token; when None, token
                sequences are not logged
            lm (RNNLM or GatedConvLM or TransformerLM):
            lm_rev: unused
            ctc_log_probs (FloatTensor): only decides whether the CTC score
                is logged
            nbest (int): unused; only the best hypothesis is returned
            exclude_eos (bool): currently ignored
            refs_id (list): reference token ids per utterance; required when
                recog_oracle is set
            utt_ids (list):
            speakers (list): optional; when given, `self.prev_spk` is set to
                `speakers[b]`
            ensmbl_eouts (list): unused
            ensmbl_elens (list): unused
            ensmbl_decs (list): unused
        Returns:
            nbest_hyps_idx (np.ndarray): object array of length `[B]`; entry
                `b` is a one-element list holding the best token-id list
                (leading <eos> removed)
            aws: dummy
            scores: dummy
            cache_info: dummy

        """
        logger = logging.getLogger("decoding")

        bs = eouts.size(0)
        best_hyps = []

        oracle = params['recog_oracle']
        beam_width = params['recog_beam_width']
        ctc_weight = params['recog_ctc_weight']
        lm_weight = params['recog_lm_weight']
        lm_state_carry_over = params['recog_lm_state_carry_over']
        lm_usage = params['recog_lm_usage']

        if lm is not None:
            lm.eval()

        for b in range(bs):
            # Initialization: a single <eos> start token for this utterance.
            # The batch size must be 1 because every later prediction-network
            # call feeds one token together with this state.
            y = eouts.new_zeros(1, 1).fill_(self.eos)
            dout, dstate = self.recurrency(self.embed(y), None)
            lmstate = self.lmstate_final if lm_state_carry_over else None
            if speakers is not None:
                self.prev_spk = speakers[b]
            # In oracle mode, references are indexed with a leading <eos>
            ref_seq = [self.eos] + list(refs_id[b]) if oracle else None

            end_hyps = []
            hyps = [{
                'hyp': [self.eos],
                'lattice': [],
                'ref_id': [self.eos],
                'score': 0.0,
                'score_lm': 0.0,
                'score_ctc': 0.0,
                'dout': dout,
                'dstate': dstate,
                'lmstate': lmstate,
            }]
            for t in range(elens[b]):
                new_hyps = []
                for hyp in hyps:
                    prev_idx = ref_seq[t] if oracle else hyp['hyp'][-1]

                    # Pick up the top-k scores of the joint network at frame t
                    out = self.joint(eouts[b:b + 1, t:t + 1], hyp['dout'].squeeze(1))
                    log_probs = F.log_softmax(out.squeeze(2), dim=-1)
                    log_probs_topk, topk_ids = torch.topk(log_probs[0, 0],
                                                          k=min(beam_width, self.vocab),
                                                          dim=-1,
                                                          largest=True,
                                                          sorted=True)

                    # topk returns fewer than beam_width ids when vocab is small
                    for k in range(topk_ids.size(0)):
                        idx = topk_ids[k].item()
                        # Every candidate starts from its parent's scores and
                        # states; nothing is carried over between siblings.
                        score = hyp['score'] + log_probs_topk[k].item()
                        score_lm = hyp['score_lm']
                        lattice = hyp['lattice'] + [idx]

                        if idx == self.blank:
                            # blank: prediction network and LM are not advanced
                            hyp_id = hyp['hyp']
                            dout = hyp['dout']
                            dstate = hyp['dstate']
                            lmstate = hyp['lmstate']
                        else:
                            hyp_id = hyp['hyp'] + [idx]
                            hyp_str = ' '.join(map(str, hyp_id[1:]))
                            cached = self.state_cache.get(hyp_str)
                            if cached is None:
                                # Oracle mode feeds the reference token at this
                                # position (falling back to the prediction once
                                # the hypothesis is longer than the reference).
                                pos = len(hyp_id) - 1
                                if oracle and pos < len(ref_seq):
                                    y_idx = ref_seq[pos]
                                else:
                                    y_idx = idx
                                y = eouts.new_zeros(1, 1).fill_(y_idx)
                                new_dout, new_dstate = self.recurrency(
                                    self.embed(y), hyp['dstate'])

                                # Update LM states for shallow fusion
                                new_lmstate = hyp['lmstate']
                                local_score_lm = 0.0
                                if lm_weight > 0 and lm is not None:
                                    _, new_lmstate, lm_log_probs = lm.predict(
                                        eouts.new_zeros(1, 1).fill_(prev_idx),
                                        hyp['lmstate'])
                                    local_score_lm = lm_log_probs[0, idx].item() * lm_weight

                                # The LM contribution is cached too, so a cache
                                # hit scores the hypothesis exactly like a miss.
                                cached = {
                                    'lattice': lattice,
                                    'dout': new_dout,
                                    'dstate': new_dstate,
                                    'lmstate': new_lmstate,
                                    'score_lm': local_score_lm,
                                }
                                self.state_cache[hyp_str] = cached

                            dout = cached['dout']
                            dstate = cached['dstate']
                            lmstate = cached['lmstate']
                            score_lm += cached['score_lm']
                            score += cached['score_lm']

                        new_hyps.append({
                            'hyp': hyp_id,
                            'lattice': lattice,
                            'score': score,
                            'score_lm': score_lm,
                            'score_ctc': 0,  # TODO(hirofumi):
                            'dout': dout,
                            'dstate': dstate,
                            'lmstate': lmstate,
                        })

                # Local pruning
                new_hyps = sorted(new_hyps, key=lambda x: x['score'],
                                  reverse=True)[:beam_width]

                # Move complete hypotheses to end_hyps
                hyps = []
                for hyp in new_hyps:
                    if oracle:
                        is_end = t == len(refs_id[b])
                    else:
                        is_end = self.end_pointing and hyp['hyp'][-1] == self.eos
                    if is_end:
                        end_hyps.append(hyp)
                    else:
                        hyps.append(hyp)
                if len(end_hyps) >= beam_width:
                    end_hyps = end_hyps[:beam_width]
                    logger.info('End-pointed at %d / %d frames' %
                                (t, elens[b]))
                    break

            # Final candidates: completed hypotheses plus surviving ones
            hyps = sorted(end_hyps + hyps, key=lambda x: x['score'], reverse=True)

            # Rescoring lattice
            if lm_weight > 0 and lm is not None and lm_usage == 'rescoring':
                new_hyps = []
                for hyp in hyps:
                    ys = [
                        np2tensor(np.fromiter(hyp['hyp'], dtype=np.int64),
                                  self.device_id)
                    ]
                    ys_pad = pad_list(ys, lm.pad)
                    _, _, lm_log_probs = lm.predict(ys_pad, None)
                    score_ctc = 0  # TODO(hirofumi):
                    # NOTE(review): this sums every value lm.predict returns,
                    # not only the log-probs of the tokens in hyp; confirm.
                    score_lm = lm_log_probs.sum() * lm_weight
                    new_hyps.append({
                        'hyp': hyp['hyp'],
                        'score': hyp['score'] + score_lm,
                        'score_ctc': score_ctc,
                        'score_lm': score_lm
                    })
                hyps = sorted(new_hyps, key=lambda x: x['score'], reverse=True)

            # Drop the leading <eos> start symbol (exclude_eos is ignored)
            best_hyps.append([hyps[0]['hyp'][1:]])

            # Reset state cache
            self.state_cache = OrderedDict()

            if utt_ids is not None:
                logger.info('Utt-id: %s' % utt_ids[b])
            if idx2token is not None:
                if refs_id is not None and self.vocab == idx2token.vocab:
                    logger.info('Ref: %s' % idx2token(refs_id[b]))
                logger.info('Hyp: %s' % idx2token(hyps[0]['hyp'][1:]))
            logger.info('log prob (hyp): %.7f' % hyps[0]['score'])
            if ctc_weight > 0 and ctc_log_probs is not None:
                logger.info('log prob (hyp, ctc): %.7f' %
                            (hyps[0]['score_ctc']))
            if lm_weight > 0 and lm is not None:
                logger.info('log prob (hyp, lm): %.7f' % (hyps[0]['score_lm']))

        # Hypotheses have different lengths; build an object array explicitly
        # (np.array over ragged nested lists raises on recent NumPy).
        nbest_hyps_idx = np.empty(len(best_hyps), dtype=object)
        for i, hyp in enumerate(best_hyps):
            nbest_hyps_idx[i] = hyp
        return nbest_hyps_idx, None, None, None
示例#3
0
def test_forward_streaming_chunkwise(args):
    """Check that chunk-by-chunk streaming encoding matches full-input encoding.

    For input lengths 160..191 the RNN encoder is run once over the whole
    (optionally frame-stacked) input and then chunk by chunk with
    ``streaming=True``; the concatenated streaming outputs must be exactly
    equal to the full-input outputs.

    Args:
        args (dict): overrides forwarded to ``make_args``; the resulting
            configuration builds the ``RNNEncoder``. ``chunk_size_left`` must
            be positive.
    """
    args = make_args(**args)
    assert args['chunk_size_left'] > 0
    # rnn_type values this test treats as unidirectional
    unidir = args['rnn_type'] in ['conv_lstm', 'conv_gru', 'lstm', 'gru']

    batch_size = 1
    xmaxs = [t for t in range(160, 192, 1)]
    device_id = -1
    # Chunk sizes in stacked frames (computed before they may be zeroed below)
    # NOTE(review): N_l is 0 when chunk_size_left < n_stacks, which makes
    # math.ceil(xmax / N_l) below raise ZeroDivisionError.
    N_l = max(0, args['chunk_size_left']) // args['n_stacks']
    N_r = max(0, args['chunk_size_right']) // args['n_stacks']
    # Unidirectional encoders are built without chunk sizes; the test still
    # chunks the input with N_l / N_r.
    if unidir:
        args['chunk_size_left'] = 0
        args['chunk_size_right'] = 0
    module = importlib.import_module('neural_sp.models.seq2seq.encoders.rnn')
    enc = module.RNNEncoder(**args)

    factor = enc.subsampling_factor
    # Extra context frames on each side of a chunk: the conv frontend's
    # context size, or 0 without a conv frontend.
    lookback = enc.conv.n_frames_context if enc.conv is not None else 0
    lookahead = enc.conv.n_frames_context if enc.conv is not None else 0

    module_fs = importlib.import_module(
        'neural_sp.models.seq2seq.frontends.frame_stacking')

    # Presumably keeps conv output lengths consistent between full and
    # chunked inputs — confirm against RNNEncoder.turn_off_ceil_mode.
    if enc.conv is not None:
        enc.turn_off_ceil_mode(enc)

    enc.eval()
    with torch.no_grad():
        for xmax in xmaxs:
            xs = np.random.randn(batch_size, xmax,
                                 args['input_dim']).astype(np.float32)

            if args['n_stacks'] > 1:
                xs = [
                    module_fs.stack_frame(x, args['n_stacks'],
                                          args['n_stacks']) for x in xs
                ]

            # After stacking, xmax counts stacked frames
            xlens = torch.IntTensor([len(x) for x in xs])
            xmax = xlens.max().item()

            # all inputs
            xs_pad = pad_list([np2tensor(x, device_id).float() for x in xs],
                              0.)

            enc_out_dict = enc(xs_pad, xlens, task='all')
            assert enc_out_dict['ys']['xs'].size(0) == batch_size
            assert enc_out_dict['ys']['xs'].size(
                1) == enc_out_dict['ys']['xlens'][0]

            # Clear encoder cache before streaming
            enc.reset_cache()

            # chunk by chunk encoding
            eouts_stream = []
            n_chunks = math.ceil(xmax / N_l)
            j = 0  # time offset for input
            j_out = 0  # time offset for encoder output
            for chunk_idx in range(n_chunks):
                # Input window: N_l current + N_r future frames, widened by
                # the lookback/lookahead context (clipped at 0 on the left)
                start = j - lookback
                end = (j + N_l + N_r) + lookahead
                xs_pad_stream = pad_list([
                    np2tensor(x[max(0, start):end], device_id).float()
                    for x in xs
                ], 0.)
                xlens_stream = torch.IntTensor(
                    [xs_pad_stream.size(1) for x in xs])
                # lookback / lookahead flags presumably tell the encoder
                # whether context frames were included on each side — confirm
                enc_out_dict_stream = enc(xs_pad_stream,
                                          xlens_stream,
                                          task='all',
                                          streaming=True,
                                          lookback=start > 0,
                                          lookahead=end < xmax - 1)

                # Full-input output frames covering this chunk's N_l input
                # frames after subsampling by `factor`
                a = enc_out_dict['ys']['xs'][:, j_out:j_out + (N_l // factor)]
                b = enc_out_dict_stream['ys']['xs']
                b = b[:, :a.size(1)]
                # NOTE(review): debug output only; the same equality is
                # asserted on the concatenated stream after the loop.
                for t in range(a.size(1)):
                    print(torch.equal(a[:, t], b[:, t]))
                eouts_stream.append(b)

                j += N_l
                j_out += (N_l // factor)
                if j > xmax:
                    break

            enc.reset_cache()

            eouts_stream = torch.cat(eouts_stream, dim=1)
            assert enc_out_dict['ys']['xs'].size() == eouts_stream.size()
            assert torch.equal(enc_out_dict['ys']['xs'], eouts_stream)
def test_decoding(backward, params):
    """Check the outputs of TransformerDecoder greedy / beam-search decoding.

    Random encoder outputs (and, depending on ``params``, random CTC
    log-probabilities and RNN LMs) are decoded with ``dec.greedy`` when the
    beam width is 1, otherwise with ``dec.beam_search`` both for a single
    model and for a 3-member ensemble that reuses the same decoder. Only the
    structure and shapes of the returned values are checked.

    Args:
        backward (bool): stored as ``params['backward']``
        params (dict): overrides forwarded to ``make_decode_params``

    """
    args = make_args()
    params = make_decode_params(**params)
    params['backward'] = backward

    batch_size = params['recog_batch_size']
    emax = 40
    device = "cpu"

    eouts = np.random.randn(batch_size, emax, ENC_N_UNITS).astype(np.float32)
    elens = torch.IntTensor([len(x) for x in eouts])
    eouts = pad_list([np2tensor(x, device).float() for x in eouts], 0.)

    ctc_log_probs = None
    if params['recog_ctc_weight'] > 0:
        # Random (not uninitialized) logits, normalized in log space because
        # the decoder is handed log-probabilities.
        ctc_logits = torch.randn(batch_size, emax, VOCAB, device=device)
        ctc_log_probs = torch.log_softmax(ctc_logits, dim=-1)

    def _build_rnnlm():
        # Fresh RNNLM from the shared test configuration.
        args_lm = make_args_rnnlm()
        module_lm = importlib.import_module('neural_sp.models.lm.rnnlm')
        return module_lm.RNNLM(args_lm).to(device)

    lm = _build_rnnlm() if params['recog_lm_weight'] > 0 else None
    lm_second = _build_rnnlm() if params['recog_lm_second_weight'] > 0 else None
    lm_second_bwd = _build_rnnlm() if params['recog_lm_bwd_weight'] > 0 else None

    # NOTE(review): 4 references are generated regardless of batch_size and
    # passed as refs_id to greedy(); confirm recog_batch_size never exceeds 4.
    ylens = [4, 5, 3, 7]
    ys = [np.random.randint(0, VOCAB, ylen).astype(np.int32) for ylen in ylens]

    module = importlib.import_module(
        'neural_sp.models.seq2seq.decoders.transformer')
    dec = module.TransformerDecoder(**args).to(device)

    def _check_beam_search_out(out):
        # Structure expected from dec.beam_search: (nbest_hyps, aws, scores)
        assert len(out) == 3
        nbest_hyps, aws, scores = out
        assert isinstance(nbest_hyps, list)
        assert len(nbest_hyps) == batch_size
        assert len(nbest_hyps[0]) == params['nbest']
        ymax = len(nbest_hyps[0][0])
        assert isinstance(aws, list)
        assert aws[0][0].shape == (args['n_heads'] * args['n_layers'],
                                   ymax, emax)
        assert isinstance(scores, list)
        assert len(scores) == batch_size
        assert len(scores[0]) == params['nbest']

    # TODO(hirofumi0810):
    # recog_lm_state_carry_over

    dec.eval()
    with torch.no_grad():
        if params['recog_beam_width'] == 1:
            out = dec.greedy(eouts,
                             elens,
                             max_len_ratio=1.0,
                             idx2token=None,
                             exclude_eos=params['exclude_eos'],
                             refs_id=ys,
                             utt_ids=None,
                             speakers=None,
                             cache_states=params['cache_states'])
            assert len(out) == 2
            hyps, aws = out
            assert isinstance(hyps, list)
            assert len(hyps) == batch_size
            assert isinstance(aws, list)
            assert aws[0].shape == (args['n_heads'] * args['n_layers'],
                                    len(hyps[0]), emax)
        else:
            out = dec.beam_search(eouts,
                                  elens,
                                  params,
                                  idx2token=None,
                                  lm=lm,
                                  lm_second=lm_second,
                                  lm_second_bwd=lm_second_bwd,
                                  ctc_log_probs=ctc_log_probs,
                                  nbest=params['nbest'],
                                  exclude_eos=params['exclude_eos'],
                                  refs_id=None,
                                  utt_ids=None,
                                  speakers=None,
                                  cache_states=params['cache_states'])
            _check_beam_search_out(out)

            # ensemble: the same decoder and encoder outputs three times
            n_ensmbl = 3
            ensmbl_eouts = [eouts] * n_ensmbl
            ensmbl_elens = [elens] * n_ensmbl
            ensmbl_decs = [dec] * n_ensmbl

            out = dec.beam_search(eouts,
                                  elens,
                                  params,
                                  idx2token=None,
                                  lm=lm,
                                  lm_second=lm_second,
                                  lm_second_bwd=lm_second_bwd,
                                  ctc_log_probs=ctc_log_probs,
                                  nbest=params['nbest'],
                                  exclude_eos=params['exclude_eos'],
                                  refs_id=None,
                                  utt_ids=None,
                                  speakers=None,
                                  ensmbl_eouts=ensmbl_eouts,
                                  ensmbl_elens=ensmbl_elens,
                                  ensmbl_decs=ensmbl_decs,
                                  cache_states=params['cache_states'])
            _check_beam_search_out(out)
示例#5
0
文件: ctc.py 项目: rxhmdia/neural_sp
    def beam_search(self,
                    eouts,
                    elens,
                    params,
                    idx2token,
                    lm=None,
                    lm_second=None,
                    lm_second_rev=None,
                    nbest=1,
                    refs_id=None,
                    utt_ids=None,
                    speakers=None):
        """CTC prefix beam search with optional shallow fusion and LM rescoring.

        Args:
            eouts (FloatTensor): `[B, T, enc_n_units]`
            elens (list): length `B`
            params (dict):
                recog_beam_width (int): size of beam
                recog_length_penalty (float): length penalty
                recog_lm_weight (float): weight of first path LM score
                recog_lm_second_weight (float): weight of second path LM score
            idx2token (): converter from index to token; when None nothing
                is logged
            lm: first path LM (shallow fusion)
            lm_second: second path LM (rescoring of the final beam)
            lm_second_rev: second path backward LM (currently unused)
            nbest (int): currently unused; only the best hypothesis is returned
            refs_id (list): reference list
            utt_ids (list): utterance id list
            speakers (list): speaker list (currently unused)
        Returns:
            best_hyps (np.ndarray): object array of length `B`; entry `b` is
                an int array with the best token-id sequence, without the
                leading <eos>

        """
        bs = eouts.size(0)

        beam_width = params['recog_beam_width']
        lp_weight = params['recog_length_penalty']
        lm_weight = params['recog_lm_weight']
        lm_weight_second = params['recog_lm_second_weight']

        if lm is not None:
            assert lm_weight > 0
            lm.eval()
        if lm_second is not None:
            assert lm_weight_second > 0
            lm_second.eval()

        best_hyps = []
        log_probs = torch.log_softmax(self.output(eouts), dim=-1)
        for b in range(bs):
            # Elements in the beam are (prefix, (p_b, p_no_blank))
            # Initialize the beam with the empty sequence, a probability of
            # 1 for ending in blank and zero for ending in non-blank (in log space).
            beam = [{
                'hyp': [self.eos],  # <eos> is used for LM
                'p_b': LOG_1,
                'p_nb': LOG_0,
                'score_lm': LOG_1,
                'lmstate': None
            }]

            for t in range(elens[b]):
                new_beam = []
                p_blank = log_probs[b, t, self.blank].item()

                # Pick up the top-k token ids (shared by every beam entry)
                _, topk_ids = torch.topk(log_probs[b:b + 1, t],
                                         k=min(beam_width, self.vocab),
                                         dim=-1,
                                         largest=True,
                                         sorted=True)
                topk_ids = tensor2np(topk_ids)[0]

                for prev in beam:
                    hyp = prev['hyp'][:]
                    p_b = prev['p_b']
                    p_nb = prev['p_nb']
                    score_lm = prev['score_lm']

                    # case 1. hyp is not extended
                    new_p_b = np.logaddexp(p_b + p_blank, p_nb + p_blank)
                    if len(hyp) > 1:
                        new_p_nb = p_nb + log_probs[b, t, hyp[-1]].item()
                    else:
                        new_p_nb = LOG_0
                    score_ctc = np.logaddexp(new_p_b, new_p_nb)
                    score_lp = len(hyp[1:]) * lp_weight
                    new_beam.append({
                        'hyp': hyp,
                        'score': score_ctc + score_lm + score_lp,
                        'p_b': new_p_b,
                        'p_nb': new_p_nb,
                        'score_ctc': score_ctc,
                        'score_lm': score_lm,
                        'score_lp': score_lp,
                        'lmstate': prev['lmstate']
                    })

                    # Update LM states for shallow fusion
                    if lm is not None:
                        _, lmstate, lm_log_probs = lm.predict(
                            eouts.new_zeros(1, 1).fill_(hyp[-1]),
                            prev['lmstate'])
                    else:
                        lmstate = None

                    # case 2. hyp is extended by a non-blank token c
                    new_p_b = LOG_0
                    c_prev = hyp[-1] if len(hyp) > 1 else None
                    for c in topk_ids:
                        if c == self.blank:
                            continue
                        p_t = log_probs[b, t, c].item()

                        if c == c_prev:
                            # a repeated token is only new after a blank
                            new_p_nb = p_b + p_t
                            # TODO(hirofumi): apply character LM here
                        else:
                            new_p_nb = np.logaddexp(p_b + p_t, p_nb + p_t)
                            # TODO(hirofumi): apply character LM here
                            if c == self.space:
                                pass
                                # TODO(hirofumi): apply word LM here

                        score_ctc = np.logaddexp(new_p_b, new_p_nb)
                        score_lp = (len(hyp[1:]) + 1) * lp_weight
                        # LM score of this extension only: siblings must not
                        # accumulate each other's LM scores.
                        score_lm_c = score_lm
                        if lm_weight > 0 and lm is not None:
                            score_lm_c += lm_log_probs[0, 0, c].item() * lm_weight
                        new_beam.append({
                            'hyp': hyp + [c],
                            'score': score_ctc + score_lm_c + score_lp,
                            'p_b': new_p_b,
                            'p_nb': new_p_nb,
                            'score_ctc': score_ctc,
                            'score_lm': score_lm_c,
                            'score_lp': score_lp,
                            'lmstate': lmstate
                        })

                # Pruning
                beam = sorted(new_beam, key=lambda x: x['score'],
                              reverse=True)[:beam_width]

            # Rescoring lattice: the second-path LM score replaces the
            # first-path one in the total score; both are kept for logging.
            if lm_second is not None:
                new_beam = []
                for entry in beam:
                    ys = [
                        np2tensor(
                            np.fromiter(entry['hyp'], dtype=np.int64),
                            self.device_id)
                    ]
                    ys_pad = pad_list(ys, lm_second.pad)
                    _, _, lm_log_probs = lm_second.predict(ys_pad, None)
                    score_ctc = np.logaddexp(entry['p_b'], entry['p_nb'])
                    # NOTE(review): this sums every value lm_second.predict
                    # returns, not only the log-probs of the tokens in hyp;
                    # confirm against its output contract.
                    score_lm_second = lm_log_probs.sum() * lm_weight_second
                    score_lp = len(entry['hyp'][1:]) * lp_weight
                    new_beam.append({
                        'hyp': entry['hyp'],
                        'score': score_ctc + score_lm_second + score_lp,
                        'score_ctc': score_ctc,
                        'score_lp': score_lp,
                        'score_lm': entry['score_lm'],
                        'score_lm_second': score_lm_second
                    })
                beam = sorted(new_beam, key=lambda x: x['score'], reverse=True)

            best_hyps.append(np.array(beam[0]['hyp'][1:]))

            if idx2token is not None:
                if utt_ids is not None:
                    logger.info('Utt-id: %s' % utt_ids[b])
                assert self.vocab == idx2token.vocab
                logger.info('=' * 200)
                for entry in beam:
                    if refs_id is not None:
                        logger.info('Ref: %s' % idx2token(refs_id[b]))
                    logger.info('Hyp: %s' % idx2token(entry['hyp'][1:]))
                    logger.info('log prob (hyp): %.7f' % entry['score'])
                    logger.info('log prob (hyp, ctc): %.7f' %
                                (entry['score_ctc']))
                    # score_lp / score_lm* are already weighted
                    logger.info('log prob (hyp, lp): %.7f' %
                                (entry['score_lp']))
                    if lm is not None:
                        logger.info('log prob (hyp, first-path lm): %.7f' %
                                    (entry['score_lm']))
                    if lm_second is not None:
                        logger.info(
                            'log prob (hyp, second-path lm): %.7f' %
                            (entry['score_lm_second']))
                    logger.info('-' * 50)

        # Hypotheses have different lengths; build an object array explicitly
        # (np.array over ragged sequences raises on recent NumPy).
        result = np.empty(len(best_hyps), dtype=object)
        for i, hyp in enumerate(best_hyps):
            result[i] = hyp
        return result