def test_wfst_decoder_normal_transition():
    """Verify ``WFSTDecoder.normal_transition`` across two consecutive frames.

    A three-arc FST is compiled (0 -> 1 on blank, 1 -> 2 on 'a', 1 -> 3 on
    'i').  The decoder is fed ``blank`` at frame 0 and ``a`` at frame 1, and
    after each step the reached state, its score and its back-pointer are
    checked.
    """
    table = PhonemeTable()
    table.add_labels(phonemes)
    compiler = _FstCompiler()
    eps = table.get_epsilon_id()
    blank = table.get_blank_id()
    a = table.get_label_id('a')
    i = table.get_label_id('i')

    # Arguments as passed to add_arc; presumably
    # (source, destination, input label, output label, weight) -- confirm
    # against _FstCompiler.add_arc.
    arc_specs = (
        (0, 1, blank, eps, 0.2),
        (1, 2, a, eps, 0.1),
        (1, 3, i, eps, 0.2),
    )
    for spec in arc_specs:
        compiler.add_arc(*spec)
    decoder = WFSTDecoder(compiler.compile())

    initial_paths = {
        0: decoder.Path(score=0, prev_path=None, frame_index=0, olabel=None),
    }

    # Frame 0: consuming `blank` from state 0 must reach state 1.
    frame0_paths = {}
    decoder.normal_transition(initial_paths, frame0_paths, 0, blank)
    assert 1 in frame0_paths
    reached = frame0_paths[1]
    assert round(reached.score, 6) == 0.2
    assert round(reached.prev_path.score, 6) == 0

    # Frame 1: consuming `a` from state 1 must reach state 2 with the
    # summed score 0.2 + 0.1.
    frame1_paths = {}
    decoder.normal_transition(frame0_paths, frame1_paths, 1, a)
    assert 2 in frame1_paths
    reached = frame1_paths[2]
    assert round(reached.score, 6) == 0.3
    assert reached.frame_index == 1
    assert round(reached.prev_path.score, 6) == 0.2
def test_lexicon_create_fst_without_homophones(workdir,
                                               words_without_homophones):
    """Check the lexicon FST built from a word list without homophones.

    The expected FST has 7 states.  State 0 has two arcs on phoneme 'a',
    one per word; every other state has exactly one arc, forming two chains
    (1 -> 2 -> 3 -> 0 and 4 -> 5 -> 6 -> 0) that each pass through the
    auxiliary label '#0' before returning to state 0.
    """
    vocab = get_vocabulary_table(workdir, words_without_homophones)
    lexicon = get_lexicon(words_without_homophones)
    table = PhonemeTable()
    table.add_labels(phonemes)
    eps = table.get_epsilon_id()
    a = table.get_label_id('a')
    i = table.get_label_id('i')
    o = table.get_label_id('o')

    fst = lexicon.create_fst(table, vocab, min_freq=0)
    assert fst.num_states() == 7
    # Looked up only after create_fst, exactly as the original test does.
    aux0 = table.get_auxiliary_label_id('#0')

    # State 0: two arcs on 'a'.  The arguments after the arc are presumably
    # (input label, output label, next state) -- confirm against
    # is_expected_arc.
    assert fst.num_arcs(0) == 2
    start_arcs = fst.arcs(0)
    is_expected_arc(next(start_arcs), a, vocab.get_label_id('愛'), 1)
    is_expected_arc(next(start_arcs), a, vocab.get_label_id('青'), 4)

    # Every remaining state has one arc whose third argument is epsilon:
    # (state, expected second argument, expected last argument).
    single_arc_states = (
        (1, i, 2),
        (2, aux0, 3),
        (3, eps, 0),
        (4, o, 5),
        (5, aux0, 6),
        (6, eps, 0),
    )
    for state, ilabel, next_state in single_arc_states:
        assert fst.num_arcs(state) == 1
        is_expected_arc(next(fst.arcs(state)), ilabel, eps, next_state)
def test_token_create_fst_with_auxiliary_labels():
    """Check the token FST built for phonemes 'a', 'i' and two aux labels.

    The phoneme table is given labels 'a' and 'i' plus auxiliary labels
    '#0' and '#1'.  The resulting FST must have 5 states, and every arc of
    every state is compared, in iteration order, with the expected arc.
    """
    table = PhonemeTable()
    table.add_labels(['a', 'i'])
    eps = table.get_epsilon_id()
    blank = table.get_blank_id()
    a = table.get_label_id('a')
    i = table.get_label_id('i')
    table.set_auxiliary_label('#0')
    table.set_auxiliary_label('#1')
    aux0 = table.get_auxiliary_label_id('#0')
    aux1 = table.get_auxiliary_label_id('#1')

    fst = Token().create_fst(table)
    assert fst.num_states() == 5

    # (state, expected arcs).  Each arc tuple holds the three arguments
    # passed to is_expected_arc after the arc itself; presumably
    # (input label, output label, next state) -- confirm against the helper.
    expected = (
        # start state
        (0, ((blank, eps, 0), (a, a, 3), (i, i, 4))),
        # second state
        (1, ((blank, eps, 1), (eps, eps, 2))),
        # final (auxiliary) state
        (2, ((eps, eps, 0), (eps, aux0, 2), (eps, aux1, 2))),
        # state entered on phoneme 'a'
        (3, ((a, eps, 3), (eps, eps, 1))),
        # state entered on phoneme 'i'
        (4, ((i, eps, 4), (eps, eps, 1))),
    )
    for state, arc_specs in expected:
        assert fst.num_arcs(state) == len(arc_specs)
        arcs = fst.arcs(state)
        for spec in arc_specs:
            is_expected_arc(next(arcs), *spec)
def test_wfst_decoder_epsilon_transition():
    """Verify ``WFSTDecoder.epsilon_transition`` on competing epsilon paths.

    Starting from a single path at state 0, the call must add the states
    reachable through epsilon arcs, keep the lowest score for each of them,
    and leave out state 4, whose only incoming arc consumes 'a'.
    """
    table = PhonemeTable()
    table.add_labels(phonemes)
    compiler = _FstCompiler()
    eps = table.get_epsilon_id()
    a = table.get_label_id('a')

    arc_specs = (
        (0, 1, eps, eps, 0.1),
        (1, 2, eps, eps, 0.2),
        (1, 3, eps, eps, 0.3),
        (0, 2, eps, eps, 0.15),
        (0, 3, eps, eps, 0.5),
        (0, 4, a, eps, 0.15),
    )
    for spec in arc_specs:
        compiler.add_arc(*spec)
    compiler.set_final(4)
    decoder = WFSTDecoder(compiler.compile())

    paths = {
        0: decoder.Path(score=0, prev_path=None, frame_index=0, olabel=None),
    }
    decoder.epsilon_transition(paths, table, 0)

    # State 1 is newly added via the 0 -> 1 epsilon arc.
    # TODO: is it really fine for state 1 to stay in paths?  It looks like
    # this could lead to an infinite loop.
    assert 1 in paths
    via_state1 = paths[1]
    assert round(via_state1.score, 6) == 0.1
    assert via_state1.prev_path.score == 0

    # State 2 keeps the cheaper direct arc 0 -> 2 (0.15) rather than the
    # route through state 1 (0.1 + 0.2).
    assert 2 in paths
    direct_state2 = paths[2]
    assert round(direct_state2.score, 6) == 0.15
    assert direct_state2.prev_path.score == 0

    # State 3 keeps the route through state 1 (0.1 + 0.3 = 0.4) rather than
    # the direct arc 0 -> 3 (0.5).
    routed_state3 = paths[3]
    assert round(routed_state3.score, 6) == 0.4
    assert round(routed_state3.prev_path.score, 6) == 0.1
    assert routed_state3.prev_path.prev_path.score == 0

    # State 4 must not appear: its arc is not an epsilon transition.
    assert 4 not in paths
def test_phoneme_table_get_epsilon_id():
    """Check that a fresh ``PhonemeTable`` maps '<epsilon>' to id 0 both ways."""
    expected_id = 0
    table = PhonemeTable()
    assert table.get_epsilon_id() == expected_id
    assert table.get_label_id('<epsilon>') == expected_id
    assert table.get_label(expected_id) == '<epsilon>'
parser.add_argument('workdir', help='Directory path where files are saved') parser.add_argument('wav_files', nargs='*', help='wave files to recognize') parser.add_argument('--vocabulary-symbol-file', type=str, default='vocab.syms', help='Vocabulary symbol file name') parser.add_argument('--feature-params-file', type=str, default='feature_params.json', help='Feature params file name') parser.add_argument('--model-file', type=str, default='model.bin', help='model file name') parser.add_argument('--decoder-fst-file', type=str, default='decoder.fst', help='Decoder FST file name') args = parser.parse_args() phoneme_table = PhonemeTable() phoneme_table.add_labels(phonemes) epsilon_id = phoneme_table.get_epsilon_id() print('Loading model ...') model_path = os.path.join(args.workdir, args.model_file) model = EESENAcousticModel.load(model_path) feature_params_path = os.path.join(args.workdir, args.feature_params_file) feature_params = FeatureParams.load(feature_params_path) batch = [] for wav_file in args.wav_files: data = extract_feature_from_wavfile(wav_file, feature_params) batch.append(torch.from_numpy(data)) output = model.predict(pad_sequence(batch)) for idx, wav_file in enumerate(args.wav_files): print('Decoding {} ... '.format(wav_file)) frame_labels = [int(frame_label) for frame_label in output[:, idx]] print(' acoustic labels = {}'.format(' '.join( [phoneme_table.get_label(frame_label) for frame_label in frame_labels