def test_phoneme_table_add_labels():
    """Check that ``add_labels`` registers every label it is given.

    After adding 'a' and 'i' to a new table, the table must report four
    labels, with 'a' at id 2 and 'i' at id 3, and the label <-> id lookup
    must agree in both directions.
    """
    table = PhonemeTable()
    table.add_labels(['a', 'i'])

    assert table.num_labels() == 4

    # Ids are expected to be handed out in insertion order, starting at 2.
    for expected_id, label in enumerate(('a', 'i'), start=2):
        assert table.get_label_id(label) == expected_id
        assert table.get_label(expected_id) == label
def test_phoneme_table_add_label():
    """Check that ``add_label`` registers a single label.

    After adding 'a' to a new table, the table must report three labels and
    map 'a' to id 2 (and id 2 back to 'a').
    """
    label, expected_id = 'a', 2

    table = PhonemeTable()
    table.add_label(label)

    assert table.num_labels() == 3
    assert table.get_label_id(label) == expected_id
    assert table.get_label(expected_id) == label
def test_phoneme_table_get_epsilon_id():
    """Check that a new table exposes '<epsilon>' at id 0.

    The dedicated accessor and the generic label <-> id lookups must all
    agree on that id.
    """
    epsilon_label = '<epsilon>'
    table = PhonemeTable()

    assert table.get_epsilon_id() == 0
    assert table.get_label_id(epsilon_label) == 0
    assert table.get_label(0) == epsilon_label
def test_phoneme_table_get_blank_id():
    """Check that a new table exposes '<blank>' at id 1.

    The dedicated accessor and the generic label <-> id lookups must all
    agree on that id.
    """
    blank_label = '<blank>'
    table = PhonemeTable()

    assert table.get_blank_id() == 1
    assert table.get_label_id(blank_label) == 1
    assert table.get_label(1) == blank_label
# Decode each wav file in `args.wav_files`: run the acoustic model on the
# whole batch, then print the non-blank acoustic labels and the WFST-decoded
# text for every file.
#
# `phonemes` and `args` are bound earlier, outside this fragment.
phoneme_table = PhonemeTable()
phoneme_table.add_labels(phonemes)
epsilon_id = phoneme_table.get_epsilon_id()
# Looked up once here instead of once per frame inside the loop below.
blank_id = phoneme_table.get_blank_id()

print('Loading model ...')
model_path = os.path.join(args.workdir, args.model_file)
model = EESENAcousticModel.load(model_path)
feature_params_path = os.path.join(args.workdir, args.feature_params_file)
feature_params = FeatureParams.load(feature_params_path)

# One feature tensor per wav file, padded into a single batch for the model.
batch = [
    torch.from_numpy(extract_feature_from_wavfile(wav_file, feature_params))
    for wav_file in args.wav_files
]
output = model.predict(pad_sequence(batch))

# The vocabulary symbols and the decoder FST do not depend on the wav file,
# so they are loaded once rather than re-read from disk on every iteration.
# NOTE(review): this reuses one WFSTDecoder for all files, which assumes
# decode() keeps no state between calls -- confirm against WFSTDecoder.
vocabulary_symbol_path = os.path.join(
    args.workdir, args.vocabulary_symbol_file)
vocab_symbol = VocabularySymbolTable.load_symbol(vocabulary_symbol_path)
decoder_fst_path = os.path.join(args.workdir, args.decoder_fst_file)
wfst_decoder = WFSTDecoder()
wfst_decoder.read_fst(decoder_fst_path)

for idx, wav_file in enumerate(args.wav_files):
    print('Decoding {} ... '.format(wav_file))
    # The second axis of `output` is indexed by position in args.wav_files.
    frame_labels = [int(frame_label) for frame_label in output[:, idx]]
    # Blank frames are dropped from the printed label sequence only; the
    # decoder still receives the full `frame_labels`.
    print(' acoustic labels = {}'.format(' '.join(
        phoneme_table.get_label(frame_label)
        for frame_label in frame_labels
        if frame_label != blank_id)))
    print(' text = {} '.format(wfst_decoder.decode(
        frame_labels, vocab_symbol, epsilon_id=epsilon_id)))