def test_real_ctc_decode2():
    with open(os.path.join(data_dir, 'ctc-test.pickle'), 'rb') as f:
        seq, label = pickle.load(f, encoding='bytes')
    seq = torch.as_tensor(seq).squeeze().log_softmax(1)
    tokenizer = CharTokenizer(os.path.join(data_dir, 'toy-data-vocab.txt'))
    beam_width = 16

    # plain beam search (no language model)
    result = decoders.ctc_beam_search_decoder(seq,
                                              blank=tokenizer.blank_idx,
                                              beam_size=beam_width)
    txt_result = ''.join(tokenizer.idx2token(result[0][1]))
    assert txt_result == 'then seconds'
    assert np.allclose(1.1842575, result[0][0], atol=1e-3)

    # lm-based decoding
    scorer = KenLMScorer(os.path.join(data_dir, 'ctc-test-lm.binary'),
                         tokenizer,
                         alpha=2.0,
                         beta=0.5,
                         unit='word')
    result = decoders.ctc_beam_search_decoder(seq,
                                              lm_scorer=scorer,
                                              blank=tokenizer.blank_idx,
                                              beam_size=beam_width)
    txt_result = ''.join(tokenizer.idx2token(result[0][1]))
    assert txt_result == label
def test_beam_search_decoder(data):
    vocab_list, beam_size, log_probs_seq1, log_probs_seq2, _, beam_search_result = data

    beam_result = decoders.ctc_beam_search_decoder(log_probs_seq1, beam_size=beam_size)
    assert ''.join([vocab_list[s] for s in beam_result[0][1]]) == beam_search_result[0]

    beam_result = decoders.ctc_beam_search_decoder(log_probs_seq2, beam_size=beam_size)
    assert ''.join([vocab_list[s] for s in beam_result[0][1]]) == beam_search_result[1]
def test_real_ctc_decode():
    data = np.genfromtxt(os.path.join(data_dir, "rnn_output.csv"), delimiter=';')[:, :-1]
    inputs = torch.as_tensor(data).log_softmax(1)
    tokenizer = CharTokenizer(os.path.join(data_dir, 'labels.txt'))

    # greedy using beam
    result = decoders.ctc_greedy_decoder(inputs, blank=tokenizer.blank_idx)
    txt_result = ''.join(tokenizer.idx2token(result[1]))
    assert "the fak friend of the fomly hae tC" == txt_result

    # default beam decoding
    result = decoders.ctc_beam_search_decoder(inputs, blank=tokenizer.blank_idx, beam_size=25)
    txt_result = ''.join(tokenizer.idx2token(result[0][1]))
    # assert "the fak friend of the fomcly hae tC" == txt_result

    # lm-based decoding
    scorer = KenLMScorer(os.path.join(data_dir, 'bigram.arpa'),
                         tokenizer,
                         alpha=2.0,
                         beta=0,
                         unit='word')
    result = decoders.ctc_beam_search_decoder(inputs,
                                              lm_scorer=scorer,
                                              blank=tokenizer.blank_idx,
                                              beam_size=25)
    txt_result = ''.join(tokenizer.idx2token(result[0][1]))
    assert "the fake friend of the fomlyhaetC" == txt_result

    # lm-based decoding with trie
    scorer = KenLMScorer(os.path.join(data_dir, 'bigram.arpa'),
                         tokenizer,
                         trie_path=os.path.join(data_dir, 'bigram.fst'),
                         alpha=2.0,
                         beta=0,
                         unit='word')
    result = decoders.ctc_beam_search_decoder(inputs,
                                              lm_scorer=scorer,
                                              blank=tokenizer.blank_idx,
                                              beam_size=25)
    txt_result = ''.join(tokenizer.idx2token(result[0][1]))
    for r in result:
        print(tokenizer.idx2token(r[1]), r[0])
    assert "the fake friend of the family, like the" == txt_result
def decode_segment(segment, offset, mode, tokenizer, scorer):
    if mode == 'greedy':
        result = decoders.ctc_greedy_decoder(segment, blank=tokenizer.blank_idx)
        best_result = result
    elif mode == 'lm+fst':
        result = decoders.ctc_beam_search_decoder(segment,
                                                  lm_scorer=scorer,
                                                  blank=tokenizer.blank_idx,
                                                  beam_size=200)
        best_result = result[0]
    else:
        raise ValueError('mode must be one of: "greedy" | "lm+fst"')

    if mode == 'lm+fst' and not best_result[1]:
        return [], "", []

    # remove spaces at the end
    best_token_ids = np.trim_zeros(np.array(best_result[1]), 'b')
    new_l = len(best_token_ids)
    frame_ids = np.array(best_result[2])[:new_l]
    frame_ids = np.append(frame_ids, frame_ids[-1])
    # shift frame indices from segment-local to utterance-global positions
    frame_ids += offset

    txt_result = ''.join(tokenizer.idx2token(best_result[1]))

    # word boundaries: token id 0 is the space; frame index * 0.02 converts to seconds
    word_starts = [0] + list(frame_ids[np.where(best_token_ids == 0)[0] + 1] * 0.02)
    word_ends = list(frame_ids[np.nonzero(best_token_ids == 0)[0] - 1] * 0.02) + [frame_ids[-1] * 0.02]

    segment_transcription_json = [{
        'transcript': triple[0],
        'start': triple[1],
        'end': triple[2]
    } for triple in zip(txt_result.split(' '), word_starts, word_ends)]

    return frame_ids.tolist(), txt_result, segment_transcription_json
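# Hedged usage sketch (not from the original source): one way decode_segment could be
# driven over per-utterance frame segments to build a full transcript with word timings.
# decode_all_segments is a hypothetical helper; merge_segments, the tokenizer and the
# scorer are assumed to be constructed by the caller, as in real_ctc_decode below.
def decode_all_segments(inputs, frame_segments, tokenizer, scorer):
    transcription_json = []
    texts = []
    for start, end in merge_segments(frame_segments):
        # decode each merged segment independently; offset keeps frame ids global
        _, text, words = decode_segment(inputs[start:end, :],
                                        offset=start,
                                        mode='lm+fst',
                                        tokenizer=tokenizer,
                                        scorer=scorer)
        if text:
            texts.append(text)
            transcription_json += words
    return ' '.join(texts), transcription_json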
def test_ctc_beam_search_decoder_tf():
    log_input = torch.tensor([[0, 0.6, 0, 0.4, 0, 0],
                              [0, 0.5, 0, 0.5, 0, 0],
                              [0, 0.4, 0, 0.6, 0, 0],
                              [0, 0.4, 0, 0.6, 0, 0],
                              [0, 0.4, 0, 0.6, 0, 0]],
                             dtype=torch.float32).log()

    beam_results = decoders.ctc_beam_search_decoder(log_input, beam_size=30)

    assert beam_results[0][1] == (1, 3)
    assert beam_results[1][1] == (1, 3, 1)
    assert beam_results[2][1] == (3, 1, 3)
def test_ctc_decoder_beam_search_different_blank_idx():
    input_log_prob_matrix_0 = torch.tensor(
        [
            [0.173908, 0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352],
            [0.230517, 0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581],
            [0.238763, 0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289],
            [0.20655, 0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803],
            [0.129878, 0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297],
            # Random entry added in at time=5
            [0.160671, 0.155251, 0.164444, 0.173517, 0.176138, 0.169979]
        ],
        dtype=torch.float32).log()

    results = decoders.ctc_beam_search_decoder(input_log_prob_matrix_0,
                                               blank=0,
                                               beam_size=2)

    assert len(results[0][1]) == 2
    assert len(results[1][1]) == 3
    assert np.all(results[0][1] == (2, 1))
    assert np.all(results[1][1] == (2, 1, 4))
    assert np.allclose(4.73437, results[0][0])  # tf results: 7.0157223
    assert np.allclose(5.318605, results[1][0])
def real_ctc_decode(filename, data_dir, frame_segments, mode='lm+fst'):
    data = np.load(os.path.join(data_dir, filename + '.npy'))
    inputs = torch.as_tensor(data).log_softmax(1)
    tokenizer = CharTokenizer('labels.txt')

    # greedy using beam
    greedy_result = decoders.ctc_greedy_decoder(inputs, blank=tokenizer.blank_idx)
    txt_result_greedy = ''.join(tokenizer.idx2token(greedy_result[1]))
    print(txt_result_greedy)

    scorer = KenLMScorer(os.path.join(data_dir, 'lm.arpa'),
                         tokenizer,
                         trie_path=os.path.join(data_dir, filename + '.fst'),
                         alpha=0.01,
                         beta=0.0,
                         unit='word')

    print('unmerged segments:', frame_segments)
    merged_segments = merge_segments(frame_segments)
    print('merged segments:', merged_segments)

    transcription_json = []
    res = []
    frame_idss = []
    for start, end in tqdm(merged_segments):
        result = decoders.ctc_beam_search_decoder(inputs[start:end, :],
                                                  lm_scorer=scorer,
                                                  blank=tokenizer.blank_idx,
                                                  beam_size=200)
        best_result = result[0]
        if not best_result[1]:
            continue

        # remove trailing spaces and map segment-local frame ids to global ones
        best_token_ids = np.trim_zeros(np.array(best_result[1]), 'b')
        new_l = len(best_token_ids)
        frame_ids = np.array(best_result[2])[:new_l]
        frame_ids += start
        frame_idss += list(frame_ids)
        frame_idss += [frame_idss[-1]]

        txt_result = ''.join(tokenizer.idx2token(best_result[1]))
        res.append(txt_result)

        # word boundaries: token id 0 is the space; frame index * 0.02 converts to seconds
        word_starts = [0] + list(frame_ids[np.where(best_token_ids == 0)[0] + 1] * 0.02)
        word_ends = list(frame_ids[np.where(best_token_ids == 0)[0] - 1] * 0.02) + [frame_ids[-1] * 0.02]
        transcription_json += [{
            'transcript': triple[0],
            'start': triple[1],
            'end': triple[2]
        } for triple in zip(txt_result.split(' '), word_starts, word_ends)]

    for txt in res:
        print(txt)

    result_filename = os.path.join(data_dir, filename + '_result.txt')
    tlog_filename = os.path.join(data_dir, filename + '_result.json')
    with open(result_filename, 'w', encoding='utf8') as f:
        print(' '.join(res), file=f)
    with open(tlog_filename, 'w', encoding='utf8') as f:
        json.dump(transcription_json, f, indent=4, ensure_ascii=False)

    return np.array(frame_idss) * 0.02
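# Hedged usage sketch (illustrative, not part of the original file): how real_ctc_decode
# might be invoked for a single utterance. The file name, data directory and frame
# segment boundaries below are assumptions; the directory is expected to contain
# <name>.npy, lm.arpa and <name>.fst as used above.
if __name__ == '__main__':
    example_segments = [(0, 400), (450, 900)]  # assumed (start_frame, end_frame) pairs
    frame_times = real_ctc_decode('utterance-001', 'data', example_segments, mode='lm+fst')
    print(frame_times)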