Code Example #1
def test_greedy_decoder(data):
    vocab_list, _, log_probs_seq1, log_probs_seq2, greedy_result, _ = data
    _, bst_result, _ = decoders.ctc_greedy_decoder(log_probs_seq1)
    assert ''.join(vocab_list[i] for i in bst_result) == greedy_result[0]

    _, bst_result, _ = decoders.ctc_greedy_decoder(log_probs_seq2)
    assert ''.join(vocab_list[i] for i in bst_result) == greedy_result[1]
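For reference, greedy CTC decoding is just a per-frame argmax followed by collapsing repeats and dropping blanks. A minimal NumPy sketch of that rule (a reference implementation, not this library's internals; it assumes a (T, V) matrix of log-probabilities and a caller-supplied blank index):

import numpy as np

def greedy_ctc_reference(log_probs, blank):
    # Argmax per frame, collapse consecutive repeats, drop blanks.
    best = np.argmax(np.asarray(log_probs), axis=1)
    out, prev = [], None
    for idx in best:
        if idx != prev and idx != blank:
            out.append(int(idx))
        prev = idx
    return out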
Code Example #2
def decode_segment(segment, offset, mode, tokenizer, scorer):
    if mode == 'greedy':
        result = decoders.ctc_greedy_decoder(segment,
                                             blank=tokenizer.blank_idx)
        best_result = result
    elif mode == 'lm+fst':
        result = decoders.ctc_beam_search_decoder(segment,
                                                  lm_scorer=scorer,
                                                  blank=tokenizer.blank_idx,
                                                  beam_size=200)
        best_result = result[0]
    else:
        raise ValueError('mode must be one of: "greedy" | "lm+fst"')

    # An empty hypothesis would crash the timing code below.
    if not best_result[1]:
        return [], "", []
    # Drop trailing space tokens (id 0) so timings end at the last word.
    best_token_ids = np.trim_zeros(np.array(best_result[1]), 'b')
    new_l = len(best_token_ids)
    frame_ids = np.array(best_result[2])[:new_l]
    # Repeat the last frame id so the final word gets an end time.
    frame_ids = np.append(frame_ids, frame_ids[-1])
    # Shift segment-local frame ids into whole-utterance coordinates.
    frame_ids += offset
    txt_result = ''.join(tokenizer.idx2token(best_result[1]))
    # Space tokens (id 0) mark word boundaries; 0.02 converts 20 ms
    # frames to seconds.
    word_starts = [0] + list(
        frame_ids[np.where(best_token_ids == 0)[0] + 1] * 0.02)
    word_ends = list(frame_ids[np.where(best_token_ids == 0)[0] - 1] *
                     0.02) + [frame_ids[-1] * 0.02]
    segment_transcription_json = [{
        'transcript': triple[0],
        'start': triple[1],
        'end': triple[2]
    } for triple in zip(txt_result.split(' '), word_starts, word_ends)]
    return frame_ids.tolist(), txt_result, segment_transcription_json
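A hypothetical call site for decode_segment, assuming 20 ms frames (the 0.02 factor above) and tokenizer/scorer objects set up as in the other examples; the slice bounds here are illustrative:

# Decode frames 500..900 of a full-utterance log-prob matrix; `offset`
# shifts segment-local frame ids back into utterance coordinates.
frame_ids, text, words = decode_segment(inputs[500:900, :],
                                        offset=500,
                                        mode='lm+fst',
                                        tokenizer=tokenizer,
                                        scorer=scorer)
for w in words:
    print(f"{w['transcript']}: {w['start']:.2f}s - {w['end']:.2f}s")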
Code Example #3
def test_real_ctc_decode():
    data = np.genfromtxt(os.path.join(data_dir, "rnn_output.csv"),
                         delimiter=';')[:, :-1]
    inputs = torch.as_tensor(data).log_softmax(1)

    tokenizer = CharTokenizer(os.path.join(data_dir, 'labels.txt'))

    # greedy (best-path) decoding
    result = decoders.ctc_greedy_decoder(inputs, blank=tokenizer.blank_idx)
    txt_result = ''.join(tokenizer.idx2token(result[1]))

    assert "the fak friend of the fomly hae tC" == txt_result

    # default beam decoding
    result = decoders.ctc_beam_search_decoder(inputs,
                                              blank=tokenizer.blank_idx,
                                              beam_size=25)

    txt_result = ''.join(tokenizer.idx2token(result[0][1]))
    # assert "the fak friend of the fomcly hae tC" == txt_result

    # lm-based decoding
    scorer = KenLMScorer(os.path.join(data_dir, 'bigram.arpa'),
                         tokenizer,
                         alpha=2.0,
                         beta=0,
                         unit='word')
    result = decoders.ctc_beam_search_decoder(inputs,
                                              lm_scorer=scorer,
                                              blank=tokenizer.blank_idx,
                                              beam_size=25)
    txt_result = ''.join(tokenizer.idx2token(result[0][1]))
    assert "the fake friend of the fomlyhaetC" == txt_result

    # lm-based decoding with trie
    scorer = KenLMScorer(os.path.join(data_dir, 'bigram.arpa'),
                         tokenizer,
                         trie_path=os.path.join(data_dir, 'bigram.fst'),
                         alpha=2.0,
                         beta=0,
                         unit='word')
    result = decoders.ctc_beam_search_decoder(inputs,
                                              lm_scorer=scorer,
                                              blank=tokenizer.blank_idx,
                                              beam_size=25)
    txt_result = ''.join(tokenizer.idx2token(result[0][1]))

    for r in result:
        print(tokenizer.idx2token(r[1]), r[0])
    assert "the fake friend of the family, like the" == txt_result
Code Example #4
def real_ctc_decode(filename, data_dir, frame_segments, mode='lm+fst'):
    data = np.load(os.path.join(data_dir, filename + '.npy'))

    inputs = torch.as_tensor(data).log_softmax(1)

    tokenizer = CharTokenizer('labels.txt')

    # greedy (best-path) decoding
    greedy_result = decoders.ctc_greedy_decoder(inputs,
                                                blank=tokenizer.blank_idx)
    txt_result_greedy = ''.join(tokenizer.idx2token(greedy_result[1]))

    print(txt_result_greedy)
    transcription_json = []

    scorer = KenLMScorer(os.path.join(data_dir, 'lm.arpa'),
                         tokenizer,
                         trie_path=os.path.join(data_dir, filename + '.fst'),
                         alpha=0.01,
                         beta=0.0,
                         unit='word')
    print('unmerged segments:', frame_segments)
    merged_segments = merge_segments(frame_segments)
    print('merged segments:', merged_segments)

    res = []
    frame_idss = []
    for start, end in tqdm(merged_segments):
        result = decoders.ctc_beam_search_decoder(
            inputs[start:end, :],
            lm_scorer=scorer,
            blank=tokenizer.blank_idx,
            beam_size=200)
        best_result = result[0]
        if not best_result[1]:
            continue
        # Drop trailing space tokens (id 0) so timings end at the last word.
        best_token_ids = np.trim_zeros(np.array(best_result[1]), 'b')
        new_l = len(best_token_ids)
        frame_ids = np.array(best_result[2])[:new_l]
        # Shift segment-local frame ids to whole-utterance coordinates.
        frame_ids += start
        frame_idss += list(frame_ids)
        frame_idss += [frame_idss[-1]]
        txt_result = ''.join(tokenizer.idx2token(best_result[1]))
        res.append(txt_result)

        # Space tokens (id 0) mark word boundaries; 0.02 converts 20 ms
        # frames to seconds.
        word_starts = [0] + list(
            frame_ids[np.where(best_token_ids == 0)[0] + 1] * 0.02)
        word_ends = list(frame_ids[np.where(best_token_ids == 0)[0] - 1] *
                         0.02) + [frame_ids[-1] * 0.02]
        transcription_json += [{
            'transcript': triple[0],
            'start': triple[1],
            'end': triple[2]
        } for triple in zip(txt_result.split(' '), word_starts, word_ends)]
    for txt in res:
        print(txt)
    result_filename = os.path.join(data_dir, filename + '_result.txt')
    tlog_filename = os.path.join(data_dir, filename + '_result.json')

    with open(result_filename, 'w', encoding='utf8') as f:
        print(' '.join(res), file=f)
    with open(tlog_filename, 'w', encoding='utf8') as f:
        json.dump(transcription_json, f, indent=4, ensure_ascii=False)
    return np.array(frame_idss) * 0.02
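merge_segments is not shown in these examples. A plausible sketch, assuming it coalesces overlapping or adjacent (start, end) frame ranges so the beam search runs once per merged region (hypothetical; the real helper may differ):

def merge_segments(frame_segments, gap=0):
    # Hypothetical: merge (start, end) ranges that overlap or sit
    # within `gap` frames of each other.
    merged = []
    for start, end in sorted(frame_segments):
        if merged and start <= merged[-1][1] + gap:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged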