def test_phoneme_table_add_labels():
    """Check bulk insertion: 'a' and 'i' map to ids 2 and 3, in both directions.

    The expected count of 4 after adding two labels implies that a fresh
    table already holds two entries.
    """
    labels = ('a', 'i')
    table = PhonemeTable()
    table.add_labels(list(labels))
    assert table.num_labels() == 4
    # Ids are expected to be handed out consecutively, starting at 2.
    for label_id, label in enumerate(labels, start=2):
        assert table.get_label_id(label) == label_id
        assert table.get_label(label_id) == label
def test_phoneme_table_add_label():
    """Check single insertion: 'a' is stored under id 2 and the count is 3."""
    label, label_id = 'a', 2
    table = PhonemeTable()
    table.add_label(label)
    assert table.num_labels() == 3
    # The mapping must hold in both directions.
    assert table.get_label_id(label) == label_id
    assert table.get_label(label_id) == label
def test_phoneme_table_get_epsilon_id():
    """Check that a fresh table maps '<epsilon>' to id 0, in both directions."""
    epsilon_label, epsilon_id = '<epsilon>', 0
    table = PhonemeTable()
    assert table.get_epsilon_id() == epsilon_id
    assert table.get_label_id(epsilon_label) == epsilon_id
    assert table.get_label(epsilon_id) == epsilon_label
def test_phoneme_table_get_blank_id():
    """Check that a fresh table maps '<blank>' to id 1, in both directions."""
    blank_label, blank_id = '<blank>', 1
    table = PhonemeTable()
    assert table.get_blank_id() == blank_id
    assert table.get_label_id(blank_label) == blank_id
    assert table.get_label(blank_id) == blank_label
Example #5
0
# Decode every file in args.wav_files: extract features, run the acoustic
# model on the whole batch at once, then print the per-frame phoneme labels
# and the WFST-decoded text for each file.
phoneme_table = PhonemeTable()
phoneme_table.add_labels(phonemes)
epsilon_id = phoneme_table.get_epsilon_id()
# Fetched once here rather than once per frame inside the loop below.
blank_id = phoneme_table.get_blank_id()

print('Loading model ...')
model_path = os.path.join(args.workdir, args.model_file)
model = EESENAcousticModel.load(model_path)
feature_params_path = os.path.join(args.workdir, args.feature_params_file)
feature_params = FeatureParams.load(feature_params_path)

# None of these depend on the wav file being decoded, so they are resolved
# once up front instead of on every loop iteration. A missing vocabulary
# file now fails before any feature extraction or inference is done.
vocabulary_symbol_path = os.path.join(
    args.workdir, args.vocabulary_symbol_file)
vocab_symbol = VocabularySymbolTable.load_symbol(vocabulary_symbol_path)
decoder_fst_path = os.path.join(args.workdir, args.decoder_fst_file)

batch = [
    torch.from_numpy(extract_feature_from_wavfile(wav_file, feature_params))
    for wav_file in args.wav_files
]
# NOTE(review): pad_sequence pads shorter sequences up to the longest one,
# so output[:, idx] may include predictions for padded frames of shorter
# files -- confirm whether model.predict / the decoder account for this.
output = model.predict(pad_sequence(batch))

for idx, wav_file in enumerate(args.wav_files):
    print('Decoding {} ... '.format(wav_file))
    # Column idx of the output holds this file's per-frame label ids.
    frame_labels = [int(frame_label) for frame_label in output[:, idx]]
    # Blank frames are dropped from the printed labels only; the decoder
    # below still receives the full frame sequence.
    acoustic_labels = ' '.join(
        phoneme_table.get_label(frame_label)
        for frame_label in frame_labels
        if frame_label != blank_id)
    print('  acoustic labels = {}'.format(acoustic_labels))
    # A fresh decoder is still built per file, as before.
    # NOTE(review): if WFSTDecoder.decode keeps no state between calls,
    # construction and read_fst could be hoisted out of the loop as well.
    wfst_decoder = WFSTDecoder()
    wfst_decoder.read_fst(decoder_fst_path)
    print('  text = {} '.format(wfst_decoder.decode(
        frame_labels, vocab_symbol, epsilon_id=epsilon_id)))