コード例 #1
0
ファイル: decoder.py プロジェクト: zhy6599/speech-recognize
def beam_search_decode_with_ctc(x, han_vocab, sess):
    """Decode one utterance with joint attention/CTC beam search in a TF session.

    The encoder/decoder graph is reached through the module-level
    ``return_tensors`` list, and the search settings come from module-level
    names (``ctc_weight``, ``beam``, ``CTC_SCORING_RATIO``, ``penalty``,
    ``nbest``) together with the top-k graph handles ``score``, ``ids``,
    ``topk_input`` and ``beam_size``; none of these are parameters.

    Args:
        x: features of a single utterance; wrapped as a batch of one
            (``np.array([x])``) before being fed to ``return_tensors[0]``.
        han_vocab: token list, indexed by id and searched with ``.index``
            for the ``'<EOS>'`` and ``'*'`` entries.
        sess: object exposing ``run(fetches, feed_dict)`` (presumably a
            ``tf.Session`` -- confirm).

    Returns:
        list: tokens of the best surviving hypothesis with its first and
        last entries (the ``'<EOS>'`` start/end markers) removed; ``[]``
        when no hypothesis ended (or ``nbest`` selects none).

    Side effects:
        Prints the score and token sequence of every ended hypothesis kept
        by the final length/score filter.
    """
    # One encoder pass: return_tensors[3] -> enc, return_tensors[4] -> lpz
    # (CTC scores; presumably log-probabilities -- confirm).
    enc, lpz = sess.run([return_tensors[3], return_tensors[4]], {return_tensors[0]: np.array([x])})
    # Drop the batch-of-one axis.
    lpz = np.squeeze(lpz, axis=0)

    # A zero CTC weight disables CTC prefix scoring entirely.
    if ctc_weight == 0.0:
        lpz =None

    # At most one decoding step per enc.shape[1] entry (assumed to be the
    # encoder time axis of a (1, T, D) array -- confirm).
    maxlen = enc.shape[1]
    minlen = 0

    # '<EOS>' is used both as the start symbol of every hypothesis and as
    # its end marker.
    eos = han_vocab.index('<EOS>')

    hyp = {'score': 0.0, 'yseq': [eos]}

    if lpz is not None:
        # '*' is the CTC blank symbol.
        ctc_prefix_score = CTCPrefixScore(lpz, han_vocab.index('*'), eos)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0

        # Number of attention-ranked candidates that get CTC-rescored; with
        # ctc_weight == 1 the attention score carries no weight, so the
        # whole vocabulary is scored.
        if ctc_weight != 1.0:
            ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz.shape[-1]

    hyps = [hyp]
    ended_hyps = []
    for i in range(maxlen):
        hyps_best_kept = []

        # Batch all live hypotheses: right-pad every yseq with eos to the
        # longest one (ys is a float array; presumably cast by the graph).
        max_len = max(len(hyp['yseq']) for hyp in hyps)
        ys = np.ones(shape=(len(hyps),max_len))*eos
        for k, hyp in enumerate(hyps):
            ys[k, :len(hyp['yseq'])] =hyp['yseq']

        # One decoder pass for all hypotheses: the encoder output is tiled
        # per hypothesis and return_tensors[2] gets maxlen for every row.
        local_att_scores_all = sess.run(return_tensors[5], {return_tensors[3]: np.tile(enc,(len(hyps),1,1)),
                                                            return_tensors[1]: ys,
                                                            return_tensors[2]: np.tile(np.array([[maxlen]]),(len(hyps),1))})

        # '*' is the CTC placeholder (blank); it never occurs in the
        # language model output, so it is effectively forbidden here.
        local_att_scores_all[:, han_vocab.index('*')] = -10000000000.0
        local_scores_all = local_att_scores_all

        if lpz is not None:
            # Pre-prune: top ctc_beam token ids by attention score.
            local_best_scores_all, local_best_ids_all = sess.run([score, ids],
                                                                 {topk_input: local_att_scores_all,
                                                                  beam_size: ctc_beam})

            # Stack the per-hypothesis CTC states/scores for batched scoring.
            ctc_state_prev = np.array([hyp['ctc_state_prev'] for hyp in hyps])
            ctc_score_prev = np.array([hyp['ctc_score_prev'] for hyp in hyps])
            ctc_score_prev = ctc_score_prev[:, np.newaxis]

            ctc_scores_all, ctc_states_all = ctc_prefix_score(
                    ys, local_best_ids_all, ctc_state_prev)

            # Attention scores of the pre-pruned candidates, row by row.
            local_att_scores_all_now = local_att_scores_all[
                np.arange(0, local_best_ids_all.shape[0])[:, np.newaxis], local_best_ids_all]

            # Joint score: weighted attention score plus the weighted
            # increment of the CTC prefix score over the previous prefix.
            local_scores_all = (1.0 - ctc_weight) * local_att_scores_all_now \
                               + ctc_weight * (ctc_scores_all - ctc_score_prev)


            # joint_best_ids_all index into the pre-pruned candidate list;
            # map them back to vocabulary ids.
            local_best_scores_all, joint_best_ids_all = sess.run([score, ids],
                                                                 {topk_input: local_scores_all, beam_size: beam})
            local_best_ids_all = local_best_ids_all[
                np.arange(0, joint_best_ids_all.shape[0])[:, np.newaxis], joint_best_ids_all]

        else:
            local_best_scores_all, local_best_ids_all = sess.run([score, ids],
                                                                 {topk_input: local_scores_all,
                                                           beam_size: beam})
        # Expand every hypothesis by its top `beam` tokens.
        for k, hyp in enumerate(hyps):
            for j in range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + float(local_best_scores_all[k, j])
                new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids_all[k, j])
                if lpz is not None:
                    new_hyp['ctc_state_prev'] = ctc_states_all[k, joint_best_ids_all[k, j]]
                    new_hyp['ctc_score_prev'] = ctc_scores_all[k, joint_best_ids_all[k, j]]
                hyps_best_kept.append(new_hyp)
            # Pruned after each source hypothesis, so at most `beam`
            # expansions survive overall.
            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]
        hyps = hyps_best_kept

        # On the last step force an end marker so some hypotheses end.
        if i == maxlen - 1:
            for hyp in hyps:
                if hyp['yseq'][-1] != eos:
                    hyp['yseq'].append(eos)

        # Move ended hypotheses (last token eos) out of the beam, adding a
        # length-dependent penalty of (i + 1) * penalty to their score.
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][-1] == eos:
                if len(hyp['yseq']) > minlen:
                    hyp['score'] += (i+1) * penalty
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)
        if end_detect(ended_hyps, i):
            break
        hyps = remained_hyps
        if len(hyps) <= 0:
            break
    # Heuristic post-filter over the ended hypotheses, best score first:
    # keep the window ended_hyps[max(index_l):min(index_r)].
    ended_hyps=sorted(
        ended_hyps, key=lambda x: x['score'], reverse=True)
    lenth = len(ended_hyps)
    if lenth == 0:
        return []
    index_l = [0]
    index_r = [len(ended_hyps)]
    for i in range(1,lenth):
        diff = len(ended_hyps[i]['yseq'])-len(ended_hyps[i-1]['yseq'])
        # Neighbours with similar scores should have similar lengths; when the
        # lengths differ by more than 3, drop the shorter side.
        if abs(diff) > 3:
            if diff < 0:
                index_r.append(i)
            else:
                index_l.append(i)
        # Large score drop (> 10) between neighbours: cut the kept window
        # before hypothesis i. NOTE(review): the original comment said to
        # drop the earlier, higher-scoring (fewer-character) hypotheses, but
        # the code drops i and everything after it -- confirm intent.
        elif ended_hyps[i]['score'] - ended_hyps[i - 1]['score'] < -10:
            index_r.append(i)
    m_l = max(index_l)
    m_r = min(index_r)

    # Empty window: keep everything from m_l to the end instead.
    if m_l >= m_r:
        m_r = lenth
    ended_hyps = ended_hyps[m_l:m_r]

    # Debug output: score and token sequence of each kept hypothesis.
    a_s = [i['score'] for i in ended_hyps]
    a_seq = [[han_vocab[b] for b in i['yseq']] for i in ended_hyps]
    for a, b in zip(a_s, a_seq):
        print(a, b)
    nbest_hyps = sorted(
        ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), nbest)]
    if len(nbest_hyps) == 0:
        return []
    # Strip the leading start marker and the trailing end marker.
    result = [han_vocab[i] for i in nbest_hyps[0]['yseq'][1:-1]]
    return result
コード例 #2
0
ファイル: e2e_asr.py プロジェクト: zane678/espnet
    def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None):
        '''Beam search over the attention decoder, optionally fused with CTC and an RNN LM.

        :param h: encoder output of one utterance; ``h.shape[0]`` is used as
            the input length (assumed frames x dim -- confirm).
        :param lpz: CTC scores for the utterance, or None to disable CTC
            prefix scoring; ``lpz.shape[-1]`` caps the CTC pre-pruning beam.
        :param recog_args: Namespace providing beam_size, penalty,
            ctc_weight, maxlenratio, minlenratio, nbest and (when ``rnnlm``
            is given) lm_weight.
        :param char_list: id -> token list, used only for debug logging.
        :param rnnlm: optional LM exposing ``predict(state, y)`` and
            ``final(state)``.
        :return: up to ``recog_args.nbest`` hypothesis dicts sorted by
            descending 'score'; each 'yseq' starts with the sos id. If no
            hypothesis survives, decoding is retried with minlenratio
            lowered by 0.1 (floored at 0).
        '''
        logging.info('input lengths: ' + str(h.shape[0]))
        # initialization
        c_list = [None]  # list of cell state of each layer
        z_list = [None]  # list of hidden state of each layer
        for l in six.moves.range(1, self.dlayers):
            c_list.append(None)
            z_list.append(None)
        a = None
        self.att.reset()  # reset pre-computation of h

        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # prepare sos
        y = self.xp.full(1, self.sos, 'i')
        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.shape[0]))
        minlen = int(recog_args.minlenratio * h.shape[0])
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list, 'z_prev': z_list, 'a_prev': a, 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list, 'z_prev': z_list, 'a_prev': a}
        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz, 0, self.eos, self.xp)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                # attention carries no weight: CTC-score the whole vocabulary
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                ey = self.embed(hyp['yseq'][i])           # utt list (1) x zdim
                att_c, att_w = self.att([h], hyp['z_prev'][0], hyp['a_prev'])
                ey = F.hstack((ey, att_c))   # utt(1) x (zdim + hdim)
                c_list[0], z_list[0] = self.lstm0(hyp['c_prev'][0], hyp['z_prev'][0], ey)
                for l in six.moves.range(1, self.dlayers):
                    c_list[l], z_list[l] = self['lstm%d' % l](
                        hyp['c_prev'][l], hyp['z_prev'][l], z_list[l - 1])

                # get nbest local scores and their ids
                local_att_scores = F.log_softmax(self.output(z_list[-1])).data
                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], hyp['yseq'][i])
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    # pre-prune to ctc_beam ids, then rescore them jointly
                    local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:ctc_beam]
                    ctc_scores, ctc_states = ctc_prefix_score(hyp['yseq'], local_best_ids, hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids] \
                        + ctc_weight * (ctc_scores - hyp['ctc_score_prev'])
                    if rnnlm:
                        local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids]
                    # joint_best_ids index into the pre-pruned list
                    joint_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                    local_best_scores = local_scores[:, joint_best_ids]
                    local_best_ids = local_best_ids[joint_best_ids]
                else:
                    local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                    local_best_scores = local_scores[:, local_best_ids]

                for j in six.moves.range(beam):
                    new_hyp = {}
                    # do not copy {z,c}_list directly
                    new_hyp['z_prev'] = z_list[:]
                    new_hyp['c_prev'] = c_list[:]
                    new_hyp['a_prev'] = att_w
                    new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = self.xp.full(
                        1, local_best_ids[j], 'i')
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypothes: ' + str(len(hyps)))
            logging.debug('best hypo: ' + ''.join([char_list[int(x)]
                                                   for x in hyps[0]['yseq'][1:]]).replace('<space>', ' '))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last postion in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.xp.full(1, self.eos, 'i'))

            # add ended hypotheses to a final list, and remove them from the current hypotheses
            # (this can leave fewer than beam hypotheses alive)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += recog_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remeined hypothes: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            for hyp in hyps:
                logging.debug('hypo: ' + ''.join([char_list[int(x)]
                                                  for x in hyp['yseq'][1:]]).replace('<space>', ' '))

            logging.debug('number of ended hypothes: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

        # check number of hypotheses
        if len(nbest_hyps) == 0:
            # logging.warn is a deprecated alias; use logging.warning
            logging.warning('there is no N-best results, perform recognition again with smaller minlenratio.')
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

        return nbest_hyps
コード例 #3
0
    def recognize_beam(self, encoder_outputs, char_list, lpz, aligns_pad,
                       args):
        """Beam search, decode one utterance.

        Attention-decoder beam search with optional CTC prefix-score fusion
        and optional alignment-windowed attention (``args.trun`` /
        ``args.align_trun``).

        Args:
            encoder_outputs: T x H encoder states of one utterance.
            char_list: id -> token list; used only for the printed trace.
            lpz: CTC scores for the utterance (``lpz.shape[-1]`` caps the CTC
                beam; id 0 is passed to CTCPrefixScore as the blank), or None
                to disable CTC fusion.
            aligns_pad: padded alignment tensor used for the attention
                window; rebuilt from the CTC greedy path when ``args.trun``
                is set and ``lpz`` is given.
            args: namespace providing beam_size, nbest, ctc_weight,
                decode_max_len, trun and align_trun.

        Returns:
            nbest_hyps: up to ``args.nbest`` hypothesis dicts sorted by
            descending 'score'; each 'yseq' starts with ``self.sos_id`` and
            ends with ``self.eos_id``.
        """
        # search params
        beam = args.beam_size
        nbest = args.nbest
        ctc_weight = args.ctc_weight
        CTC_SCORING_RATIO = 1.5
        # CTC fusion needs both the scores and a positive weight; gating on
        # the weight alone raised NameError (no ctc_prefix_score/ctc_beam)
        # when lpz was None.
        use_ctc = lpz is not None and ctc_weight > 0
        # New tensors follow the encoder's device instead of a hard-coded
        # .cuda(), so CPU and non-default GPUs work too.
        device = encoder_outputs.device
        if args.decode_max_len != 0:
            maxlen = args.decode_max_len
        elif lpz is not None:
            # 1.5x the number of frames whose CTC argmax is not id 0 (blank)
            maxlen = int(len(torch.nonzero(torch.max(lpz, dim=-1)[1])) * 1.5)
        elif args.align_trun:
            maxlen = int(aligns_pad.size(1) * 1.5)
        else:
            # maxlen used to be left unbound here; fall back to one step per
            # encoder frame.
            maxlen = encoder_outputs.size(0)

        # Init decoder rnn
        enc_batch = encoder_outputs.unsqueeze(0)  # 1 x T x H
        h_list = [self.zero_state(enc_batch)]
        c_list = [self.zero_state(enc_batch)]
        for l in range(1, self.num_layers):
            h_list.append(self.zero_state(enc_batch))
            c_list.append(self.zero_state(enc_batch))
        att_c = self.zero_state(enc_batch, H=enc_batch.size(2))
        # prepare sos
        y = self.sos_id
        vy = encoder_outputs.new_zeros(1).long()

        hyp = {
            'score': 0.0,
            'yseq': [y],
            'c_prev': c_list,
            'h_prev': h_list,
            'a_prev': att_c
        }
        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz.detach().cpu().numpy(), 0,
                                              self.eos_id, np)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
            if args.trun:
                # Alignment from the CTC greedy path: 0 followed by
                # (index + 1) of every frame whose argmax is not id 0.
                ctc_greedy = torch.max(lpz, dim=-1)[1].unsqueeze(dim=0)
                aligns = []
                for k in range(ctc_greedy.size()[0]):
                    align = (torch.nonzero(ctc_greedy[k]) +
                             1).reshape(-1).cpu().numpy().tolist()
                    align.insert(0, 0)
                    aligns.append(align)
                aligns_pad = pad_list([torch.Tensor(a).long() for a in aligns],
                                      IGNORE_ID)

        hyps = [hyp]
        ended_hyps = []

        for i in range(maxlen):
            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp['yseq'][i]
                embedded = self.embedding(vy)
                # step 1. decoder RNN: s_i = RNN(s_i-1, y_i-1, c_i-1)
                rnn_input = torch.cat((embedded, hyp['a_prev']), dim=1)
                h_list[0], c_list[0] = self.rnn[0](
                    rnn_input, (hyp['h_prev'][0], hyp['c_prev'][0]))
                for l in range(1, self.num_layers):
                    h_list[l], c_list[l] = self.rnn[l](
                        h_list[l - 1], (hyp['h_prev'][l], hyp['c_prev'][l]))
                rnn_output = h_list[-1]
                # step 2. attention: c_i = AttentionContext(s_i, h),
                # optionally restricted to a window of encoder frames.
                mask = None
                if args.trun or args.align_trun:
                    # Start all ones and zero the allowed window.
                    # NOTE(review): assumes self.attention excludes the
                    # nonzero positions -- confirm.
                    mask = torch.ones(enc_batch.size(0), enc_batch.size(1),
                                      dtype=torch.uint8, device=device)
                    if i + 1 < aligns_pad.size(1):
                        # Clip to the number of encoder frames; the old code
                        # clipped to rnn_output.size(1), the decoder hidden size.
                        num_frames = mask.size(1)
                        for m in range(mask.size(0)):
                            if self.peak_left != 0:
                                left_id = max(i - self.peak_left + 1, 0)
                            else:
                                left_id = 0
                            right_id = min(i + 1 + self.peak_right,
                                           aligns_pad.size(1) - 1)
                            left_bound = min(
                                aligns_pad[m][left_id] + self.offset,
                                num_frames)
                            right_bound = max(
                                min(aligns_pad[m][right_id] + self.offset,
                                    num_frames), 0)
                            mask[m][left_bound:right_bound] = 0

                # below unsqueeze: (N x H) -> (N x 1 x H)
                att_c, att_w = self.attention(rnn_output.unsqueeze(dim=1),
                                              enc_batch, mask)
                att_c = att_c.squeeze(dim=1)
                # step 3. concatenate s_i and c_i, and input to MLP
                mlp_input = torch.cat((rnn_output, att_c), dim=1)
                predicted_y_t = self.mlp(mlp_input)
                local_att_scores = F.log_softmax(predicted_y_t, dim=1)

                if use_ctc:
                    # Pre-prune to ctc_beam ids by attention score, then
                    # rescore them jointly with the CTC prefix score.
                    _, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                    ctc_delta = torch.from_numpy(
                        ctc_scores - hyp['ctc_score_prev']).to(
                            local_att_scores.device)
                    local_scores = ((1.0 - ctc_weight) *
                                    local_att_scores[:, local_best_ids[0]] +
                                    ctc_weight * ctc_delta)
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1)
                    # positions in the pre-pruned list -> token ids
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    # topk scores
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, beam, dim=1)

                for j in range(beam):
                    new_hyp = {}
                    new_hyp['h_prev'] = h_list[:]
                    new_hyp['c_prev'] = c_list[:]
                    new_hyp['a_prev'] = att_c[:]
                    new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                    new_hyp['yseq'] = hyp['yseq'] + [int(local_best_ids[0, j])]
                    if use_ctc:
                        new_hyp['ctc_state_prev'] = ctc_states[
                            joint_best_ids[0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[
                            joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)
                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]
            # end for hyp in hyps
            hyps = hyps_best_kept

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                for hyp in hyps:
                    hyp['yseq'].append(self.eos_id)

            # move ended hypotheses to the final list (no length penalty is
            # applied here); this can leave fewer than beam alive
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos_id:
                    ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            hyps = remained_hyps
            if len(hyps) > 0:
                print('remeined hypothes: ' + str(len(hyps)))
            else:
                print('no hypothesis. Finish decoding.')
                break
            for hyp in hyps:
                print('hypo: ' +
                      ' '.join([char_list[int(x)] for x in hyp['yseq'][1:]]))
        # end for i in range(maxlen)
        nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'],
                            reverse=True)[:min(len(ended_hyps), nbest)]
        return nbest_hyps
コード例 #4
0
ファイル: e2e_asr_attctc_th.py プロジェクト: zhoutf/espnet
    def recognize_beam(self, h, lpz, recog_args, char_list):
        '''Beam search over the attention decoder, optionally fused with CTC.

        :param Variable h: encoder output of one utterance; ``h.size(0)`` is
            used as the input length.
        :param lpz: CTC scores as a tensor convertible with ``.numpy()``, or
            None to disable CTC prefix scoring.
        :param Namespace recog_args: beam_size, penalty, ctc_weight,
            maxlenratio and minlenratio.
        :param char_list: id -> token list, used only for debug logging.
        :return: token ids of the best hypothesis without the leading sos
            (it still ends with eos). If no hypothesis survives the
            minlen filter, decoding is retried with minlenratio lowered by
            0.1 (only while minlenratio > 0).
        '''
        logging.info('input lengths: ' + str(h.size(0)))
        # initialization
        c_list = [self.zero_state(h.unsqueeze(0))]
        z_list = [self.zero_state(h.unsqueeze(0))]
        for l in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(h.unsqueeze(0)))
            z_list.append(self.zero_state(h.unsqueeze(0)))
        a = None
        self.att.reset()  # reset pre-computation of h

        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # prepare sos
        y = self.sos
        vy = Variable(h.data.new(1).zero_().long(), volatile=True)
        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        hyp = {
            'score': 0.0,
            'yseq': [y],
            'c_prev': c_list,
            'z_prev': z_list,
            'a_prev': a
        }
        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz.numpy(), 0, self.eos, np)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-prune candidates by attention score before CTC scoring
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                # attention carries no weight: CTC-score the whole vocabulary
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                # (the former vy.unsqueeze(1) / ey.unsqueeze(0) lines were
                # no-ops: unsqueeze returns a new tensor that was discarded)
                vy[0] = hyp['yseq'][i]
                ey = self.embed(vy)  # utt list (1) x zdim
                att_c, att_w = self.att(h.unsqueeze(0), [h.size(0)],
                                        hyp['z_prev'][0], hyp['a_prev'])
                ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
                z_list[0], c_list[0] = self.decoder[0](
                    ey, (hyp['z_prev'][0], hyp['c_prev'][0]))
                for l in six.moves.range(1, self.dlayers):
                    z_list[l], c_list[l] = self.decoder[l](
                        z_list[l - 1], (hyp['z_prev'][l], hyp['c_prev'][l]))

                # get nbest local scores and their ids
                local_att_scores = F.log_softmax(self.output(z_list[-1]),
                                                 dim=1).data
                if lpz is not None:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                        + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1)
                    # positions in the pre-pruned list -> token ids
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, beam, dim=1)

                for j in six.moves.range(beam):
                    new_hyp = {}
                    new_hyp['z_prev'] = [z_list[0]]
                    new_hyp['c_prev'] = [c_list[0]]
                    for l in six.moves.range(1, self.dlayers):
                        new_hyp['z_prev'].append(z_list[l])
                        new_hyp['c_prev'].append(c_list[l])
                    new_hyp['a_prev'] = att_w
                    new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = local_best_ids[0, j]
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[
                            0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[
                            0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypothes: ' + str(len(hyps)))
            logging.debug(
                'best hypo: ' +
                ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last postion in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypotheses to a final list, and remove them from the current hypotheses
            # (this can leave fewer than beam hypotheses alive)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remeined hypothes: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            for hyp in hyps:
                logging.debug(
                    'hypo: ' +
                    ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypothes: ' + str(len(ended_hyps)))

        # Every hypothesis may have been dropped by the minlen filter; indexing
        # [0] below would then raise IndexError. Retry with a smaller
        # minlenratio while it can still be lowered.
        if not ended_hyps and recog_args.minlenratio > 0.0:
            import argparse
            logging.warning('no ended hypotheses; retrying recognition '
                            'with smaller minlenratio')
            # copy: the caller's Namespace must not be modified
            recog_args = argparse.Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            return self.recognize_beam(h, lpz, recog_args, char_list)

        best_hyp = sorted(ended_hyps, key=lambda x: x['score'],
                          reverse=True)[0]
        logging.info('total log probability: ' + str(best_hyp['score']))
        logging.info('normalized log probability: ' +
                     str(best_hyp['score'] / len(best_hyp['yseq'])))

        # remove sos
        return best_hyp['yseq'][1:]
コード例 #5
0
    def recognize_beam(self, encoder_outputs, char_list, lpz, args):
        """Beam search, decoding one utterance.

        Args:
            encoder_outputs: encoder outputs of a single utterance, T x H.
            char_list: list of output symbols; only used to print partial
                hypotheses.
            lpz: CTC log-posteriors for the utterance (a tensor, converted to
                numpy for ``CTCPrefixScore``), or None to score with the
                attention decoder alone.
            args: namespace with ``beam_size``, ``nbest``, ``ctc_weight`` and
                ``decode_max_len`` (0 means "use the number of encoder
                frames").

        Returns:
            nbest_hyps: at most ``args.nbest`` ended hypotheses, sorted by
            descending ``'score'``. Each is a dict whose ``'yseq'`` starts
            with ``self.sos_id`` and ends with ``self.eos_id``. It also holds
            the decoder state (``'h_prev'``, ``'c_prev'``, ``'a_prev'``) and,
            when ``lpz`` is given, ``'ctc_state_prev'``/``'ctc_score_prev'``.
        """
        # search params
        beam = args.beam_size
        nbest = args.nbest
        ctc_weight = args.ctc_weight
        # CTC prefix scoring is run only on this many attention candidates
        # (beam * ratio), unless CTC is the sole scorer.
        CTC_SCORING_RATIO = 1.5
        if args.decode_max_len == 0:
            maxlen = encoder_outputs.size(0)
        else:
            maxlen = args.decode_max_len

        # Batch of one (1 x T x H), shared by every decoder step.
        enc_batch = encoder_outputs.unsqueeze(0)

        # Initial decoder state: a zero hidden/cell state per RNN layer and a
        # zero attention context sized to the encoder feature dimension.
        h_init = [self.zero_state(enc_batch) for _ in range(self.num_layers)]
        c_init = [self.zero_state(enc_batch) for _ in range(self.num_layers)]
        att_init = self.zero_state(enc_batch, H=enc_batch.size(2))
        # Reusable (1,) LongTensor that holds the previously emitted token.
        vy = encoder_outputs.new_zeros(1).long()

        hyp = {
            'score': 0.0,
            'yseq': [self.sos_id],
            'c_prev': c_init,
            'h_prev': h_init,
            'a_prev': att_init,
        }
        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz.detach().cpu().numpy(), 0,
                                              self.eos_id, np)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        for i in range(maxlen):
            hyps_best_kept = []
            for hyp in hyps:
                # Every live hypothesis holds i + 1 tokens; feed the last one.
                vy[0] = hyp['yseq'][i]
                embedded = self.embedding(vy)
                # step 1. decoder RNN: s_i = RNN(s_i-1, y_i-1, c_i-1).
                # Build fresh per-hypothesis state lists so that no stored
                # hypothesis state is overwritten in place.
                h_new, c_new = [], []
                layer_input = torch.cat((embedded, hyp['a_prev']), dim=1)
                for l in range(self.num_layers):
                    h_l, c_l = self.rnn[l](
                        layer_input, (hyp['h_prev'][l], hyp['c_prev'][l]))
                    h_new.append(h_l)
                    c_new.append(c_l)
                    layer_input = h_l
                rnn_output = h_new[-1]
                # step 2. attention: c_i = AttentionContext(s_i, h)
                # the query goes (N x H) -> (N x 1 x H)
                att_c, _ = self.attention(rnn_output.unsqueeze(dim=1),
                                          enc_batch)
                att_c = att_c.squeeze(dim=1)
                # step 3. concatenate s_i and c_i, and feed the MLP
                mlp_input = torch.cat((rnn_output, att_c), dim=1)
                predicted_y_t = self.mlp(mlp_input)
                local_att_scores = F.log_softmax(predicted_y_t, dim=1)

                if lpz is not None:
                    # Pre-select ctc_beam candidates by attention score, then
                    # rescore them jointly with the CTC prefix score.
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    # The scorer was built on numpy data, so hand it numpy
                    # candidate ids (works whether the ids live on CPU or GPU).
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0].cpu().numpy(),
                        hyp['ctc_state_prev'])
                    # Put the CTC score delta on the attention scores'
                    # device/dtype instead of assuming CUDA is available.
                    ctc_delta = torch.as_tensor(
                        ctc_scores - hyp['ctc_score_prev'],
                        dtype=local_att_scores.dtype,
                        device=local_att_scores.device)
                    local_scores = (
                        (1.0 - ctc_weight) *
                        local_att_scores[:, local_best_ids[0]] +
                        ctc_weight * ctc_delta)
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    # topk scores
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, beam, dim=1)

                for j in range(beam):
                    new_hyp = {
                        'h_prev': list(h_new),
                        'c_prev': list(c_new),
                        'a_prev': att_c,
                        'score': hyp['score'] + local_best_scores[0, j],
                        'yseq': hyp['yseq'] + [int(local_best_ids[0, j])],
                    }
                    if lpz is not None:
                        k = int(joint_best_ids[0, j])
                        new_hyp['ctc_state_prev'] = ctc_states[k]
                        new_hyp['ctc_score_prev'] = ctc_scores[k]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)
                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]
            # end for hyp in hyps
            hyps = hyps_best_kept

            # add eos in the final loop so that at least one hyp always ends
            if i == maxlen - 1:
                for hyp in hyps:
                    hyp['yseq'].append(self.eos_id)

            # Move ended hypotheses to the final list and drop them from the
            # live set (so fewer than `beam` hyps may remain live).
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos_id:
                    ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            hyps = remained_hyps
            if len(hyps) > 0:
                print('remeined hypothes: ' + str(len(hyps)))
            else:
                print('no hypothesis. Finish decoding.')
                break

            for hyp in hyps:
                print('hypo: ' +
                      ' '.join([char_list[int(x)] for x in hyp['yseq'][1:]]))
        # end for i in range(maxlen)
        nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'],
                            reverse=True)[:min(len(ended_hyps), nbest)]
        return nbest_hyps