def test_phoneme_table_get_all_labels():
    """A fresh table plus one added label exposes all three entries as a dict."""
    table = PhonemeTable()
    table.add_label('a')
    labels = table.get_all_labels()
    assert isinstance(labels, dict)
    assert len(labels) == 3
    for label_id, label in ((0, '<epsilon>'), (1, '<blank>'), (2, 'a')):
        assert labels[label_id] == label
def test_phoneme_table_add_labels():
    """add_labels appends entries after the two built-in symbols (ids 0 and 1)."""
    table = PhonemeTable()
    table.add_labels(['a', 'i'])
    assert table.num_labels() == 4
    for expected_id, label in ((2, 'a'), (3, 'i')):
        assert table.get_label_id(label) == expected_id
        assert table.get_label(expected_id) == label
# Esempio n. 3 ("Example no. 3" — separator left over from the example listing)
# 0
def test_compose_token_and_lexicon_fst(workdir, words_without_homophones):
    """Token∘Lexicon composition is determinizable when there are no homophones."""
    vocab = get_vocabulary_table(workdir, words_without_homophones)
    lexicon = get_lexicon(words_without_homophones)

    table = PhonemeTable()
    table.add_labels(phonemes)

    lexicon_fst = lexicon.create_fst(table, vocab)
    token_fst = Token().create_fst(table)

    composed = pywrapfst.compose(token_fst.arcsort('olabel'), lexicon_fst)
    # The test passes as long as determinization does not raise.
    pywrapfst.determinize(composed)
def test_phoneme_table_get_auxiliary_label_id():
    """Auxiliary labels receive ids directly after the regular labels."""
    table = PhonemeTable()
    table.add_label('a')
    for aux in ('#0', '#1'):
        table.set_auxiliary_label(aux)
    assert table.get_label_id('a') == 2
    assert table.get_auxiliary_label_id('#0') == 3
    assert table.get_auxiliary_label_id('#1') == 4
# Esempio n. 5 ("Example no. 5" — separator left over from the example listing)
# 0
def test_compose_token_and_lexicon_fst_with_homophones(workdir,
                                                       words_with_homophones):
    """Homophones make the composed Token∘Lexicon FST non-determinizable."""
    vocab = get_vocabulary_table(workdir, words_with_homophones)
    lexicon = get_lexicon(words_with_homophones)

    table = PhonemeTable()
    table.add_labels(phonemes)

    lexicon_fst = lexicon.create_fst(table, vocab, min_freq=0)
    token_fst = Token().create_fst(table)

    composed = pywrapfst.compose(token_fst.arcsort('olabel'), lexicon_fst)
    with pytest.raises(pywrapfst.FstOpError):
        pywrapfst.determinize(composed)
# Esempio n. 6 ("Example no. 6" — separator left over from the example listing)
# 0
def test_token_create_fst_with_auxiliary_labels():
    """Token FST topology with two phonemes plus auxiliary labels '#0'/'#1'.

    Based on the assertions below, the five states are:
      0 — start state, 1 — collector before the final state,
      2 — final/auxiliary state, 3 — emitting state for 'a',
      4 — emitting state for 'i'.
    """
    phoneme_table = PhonemeTable()
    phoneme_table.add_labels(['a', 'i'])
    epsilon_id = phoneme_table.get_epsilon_id()
    blank_id = phoneme_table.get_blank_id()
    a = phoneme_table.get_label_id('a')
    i = phoneme_table.get_label_id('i')
    phoneme_table.set_auxiliary_label('#0')
    phoneme_table.set_auxiliary_label('#1')
    aux0 = phoneme_table.get_auxiliary_label_id('#0')
    aux1 = phoneme_table.get_auxiliary_label_id('#1')

    fst = Token().create_fst(phoneme_table)
    assert (fst.num_states() == 5)
    # start state: blank self-loop plus one arc per phoneme
    state = 0
    assert (fst.num_arcs(state) == 3)
    gen_arc = fst.arcs(state)
    is_expected_arc(next(gen_arc), blank_id, epsilon_id, state)
    is_expected_arc(next(gen_arc), a, a, 3)
    is_expected_arc(next(gen_arc), i, i, 4)
    # second state: blank self-loop and epsilon transition to the final state
    state = 1
    assert (fst.num_arcs(state) == 2)
    gen_arc = fst.arcs(state)
    is_expected_arc(next(gen_arc), blank_id, epsilon_id, state)
    is_expected_arc(next(gen_arc), epsilon_id, epsilon_id, 2)
    # final(auxiliary) state: loops emitting the auxiliary output labels
    state = 2
    assert (fst.num_arcs(state) == 3)
    gen_arc = fst.arcs(state)
    is_expected_arc(next(gen_arc), epsilon_id, epsilon_id, 0)
    is_expected_arc(next(gen_arc), epsilon_id, aux0, state)
    is_expected_arc(next(gen_arc), epsilon_id, aux1, state)
    # state for phoneme 'a': self-loop on repeated 'a', then epsilon out
    state = 3
    assert (fst.num_arcs(state) == 2)
    gen_arc = fst.arcs(state)
    is_expected_arc(next(gen_arc), a, epsilon_id, state)
    is_expected_arc(next(gen_arc), epsilon_id, epsilon_id, 1)
    # state for phoneme 'i' (original comment said 'b', which was wrong)
    state = 4
    assert (fst.num_arcs(state) == 2)
    gen_arc = fst.arcs(state)
    is_expected_arc(next(gen_arc), i, epsilon_id, state)
    is_expected_arc(next(gen_arc), epsilon_id, epsilon_id, 1)
# Esempio n. 7 ("Example no. 7" — separator left over from the example listing)
# 0
def test_wfst_decoder_normal_transition():
    """normal_transition follows matching arcs and accumulates arc weights."""
    table = PhonemeTable()
    table.add_labels(phonemes)

    compiler = _FstCompiler()
    eps = table.get_epsilon_id()
    blank = table.get_blank_id()
    a = table.get_label_id('a')
    i = table.get_label_id('i')
    # Small chain: 0 -blank-> 1, then 1 branches on 'a' and 'i'.
    for src, dst, ilabel, weight in ((0, 1, blank, 0.2),
                                     (1, 2, a, 0.1),
                                     (1, 3, i, 0.2)):
        compiler.add_arc(src, dst, ilabel, eps, weight)
    fst = compiler.compile()

    decoder = WFSTDecoder(fst)
    prev_paths = {
        0: decoder.Path(score=0, prev_path=None, frame_index=0, olabel=None)
    }
    curr_paths = {}
    decoder.normal_transition(prev_paths, curr_paths, 0, blank)
    assert 1 in curr_paths
    assert round(curr_paths[1].score, 6) == 0.2
    assert round(curr_paths[1].prev_path.score, 6) == 0

    prev_paths, curr_paths = curr_paths, {}
    decoder.normal_transition(prev_paths, curr_paths, 1, a)
    assert 2 in curr_paths
    assert round(curr_paths[2].score, 6) == 0.3
    assert curr_paths[2].frame_index == 1
    assert round(curr_paths[2].prev_path.score, 6) == 0.2
# Esempio n. 8 ("Example no. 8" — separator left over from the example listing)
# 0
def test_lexicon_create_fst_without_homophones(workdir,
                                               words_without_homophones):
    """Lexicon FST layout: one phoneme chain per word, closed back to state 0.

    Expected structure (per the assertions below):
      0 -a/愛-> 1 -i/eps-> 2 -#0/eps-> 3 -eps/eps-> 0
      0 -a/青-> 4 -o/eps-> 5 -#0/eps-> 6 -eps/eps-> 0
    """
    vocab = get_vocabulary_table(workdir, words_without_homophones)
    lexicon = get_lexicon(words_without_homophones)

    phoneme_table = PhonemeTable()
    phoneme_table.add_labels(phonemes)
    epsilon_id = phoneme_table.get_epsilon_id()
    a = phoneme_table.get_label_id('a')
    i = phoneme_table.get_label_id('i')
    o = phoneme_table.get_label_id('o')

    fst = lexicon.create_fst(phoneme_table, vocab, min_freq=0)
    assert (fst.num_states() == 7)
    # Auxiliary label '#0' becomes available after create_fst registers it.
    aux0 = phoneme_table.get_auxiliary_label_id('#0')

    # Start state branches on the shared first phoneme 'a'.
    state = 0
    assert (fst.num_arcs(state) == 2)  # was num_arcs(0); use `state` consistently
    gen = fst.arcs(state)
    is_expected_arc(next(gen), a, vocab.get_label_id('愛'), 1)
    # was the non-idiomatic `gen.__next__()`; use next() instead
    is_expected_arc(next(gen), a, vocab.get_label_id('青'), 4)

    # Each remaining state carries exactly one outgoing arc.
    expected_single_arcs = {
        1: (i, epsilon_id, 2),
        2: (aux0, epsilon_id, 3),
        3: (epsilon_id, epsilon_id, 0),
        4: (o, epsilon_id, 5),
        5: (aux0, epsilon_id, 6),
        6: (epsilon_id, epsilon_id, 0),
    }
    for state, (ilabel, olabel, nextstate) in expected_single_arcs.items():
        assert (fst.num_arcs(state) == 1)
        is_expected_arc(next(fst.arcs(state)), ilabel, olabel, nextstate)
def test_phoneme_table_get_all_auxiliary_labels():
    """Only the auxiliary labels are returned, keyed by their label ids."""
    table = PhonemeTable()
    table.add_label('a')
    for aux in ('#0', '#1'):
        table.set_auxiliary_label(aux)
    aux_labels = table.get_all_auxiliary_labels()
    assert isinstance(aux_labels, dict)
    assert len(aux_labels) == 2
    assert aux_labels[3] == '#0'
    assert aux_labels[4] == '#1'
# Esempio n. 10 ("Example no. 10" — separator left over from the example listing)
# 0
def test_wfst_decoder_decode(workdir, words_for_corpus_with_homophones):
    """End-to-end decode: a frame-label sequence maps to the expected sentence."""
    corpus_path = os.path.join(workdir, 'corpus.txt')
    create_corpus(corpus_path, words_for_corpus_with_homophones)

    vocab_path = os.path.join(workdir, 'vocab.syms')
    vocab = create_vocabulary_symbol_table(vocab_path, corpus_path)

    table = PhonemeTable()
    table.add_labels(phonemes)

    lexicon_fst = get_lexicon(words_for_corpus_with_homophones).create_fst(
        table, vocab, min_freq=0)
    token_fst = Token().create_fst(table)

    grammar_path = os.path.join(workdir, 'grammar.fst')
    grammar_fst = Grammar().create_fst(grammar_path, vocab_path, corpus_path)

    decoder = WFSTDecoder()
    decoder.create_fst(token_fst, lexicon_fst, grammar_fst)

    blank_id = table.get_blank_id()
    a, i, d, e, s, o, m, r, u = (
        table.get_label_id(p)
        for p in ('a', 'i', 'd', 'e', 's', 'o', 'm', 'r', 'u'))
    frame_labels = [
        blank_id, blank_id, a, a, i, i, i, d, e, blank_id, s, s, o, o, o, m, e,
        r, r, u
    ]
    assert decoder.decode(frame_labels, vocab) == '藍で染める'
# Esempio n. 11 ("Example no. 11" — separator left over from the example listing)
# 0
def test_wfst_decoder_create_fst(workdir, words_for_corpus_without_homophones):
    """create_fst composes the token, lexicon and grammar FSTs without raising."""
    corpus_path = os.path.join(workdir, 'corpus.txt')
    create_corpus(corpus_path, words_for_corpus_without_homophones)

    vocab_path = os.path.join(workdir, 'vocab.syms')
    vocab = create_vocabulary_symbol_table(vocab_path, corpus_path)

    table = PhonemeTable()
    table.add_labels(phonemes)

    lexicon_fst = get_lexicon(words_for_corpus_without_homophones).create_fst(
        table, vocab, min_freq=0)
    token_fst = Token().create_fst(table)

    grammar_path = os.path.join(workdir, 'grammar.fst')
    grammar_fst = Grammar().create_fst(grammar_path, vocab_path, corpus_path)

    WFSTDecoder().create_fst(token_fst, lexicon_fst, grammar_fst)
# Esempio n. 12 ("Example no. 12" — separator left over from the example listing)
# 0
def test_wfst_decoder_epsilon_transition():
    """epsilon_transition relaxes epsilon arcs, keeping the best score per state."""
    table = PhonemeTable()
    table.add_labels(phonemes)

    compiler = _FstCompiler()
    eps = table.get_epsilon_id()
    a = table.get_label_id('a')
    for src, dst, ilabel, weight in ((0, 1, eps, 0.1),
                                     (1, 2, eps, 0.2),
                                     (1, 3, eps, 0.3),
                                     (0, 2, eps, 0.15),
                                     (0, 3, eps, 0.5),
                                     (0, 4, a, 0.15)):
        compiler.add_arc(src, dst, ilabel, eps, weight)
    compiler.set_final(4)
    fst = compiler.compile()

    decoder = WFSTDecoder(fst)
    paths = {
        0: decoder.Path(score=0, prev_path=None, frame_index=0, olabel=None)
    }
    frame_index = 0
    decoder.epsilon_transition(paths, table, frame_index)

    # New state 1 was added via the 0 -> 1 epsilon arc.
    # TODO(review): is it really OK that state 1 stays in `paths`?
    # It looks like this could loop forever. (Translated from the original
    # Japanese comment.)
    assert 1 in paths
    assert round(paths[1].score, 6) == 0.1
    assert paths[1].prev_path.score == 0
    # State 2 keeps the cheaper direct path 0 -> 2 (0.15 < 0.1 + 0.2).
    assert 2 in paths
    assert round(paths[2].score, 6) == 0.15
    assert paths[2].prev_path.score == 0
    # State 3 keeps the 0 -> 1 -> 3 path (0.4); the costlier 0 -> 3 arc
    # (0.5) does not replace it.
    assert round(paths[3].score, 6) == 0.4
    assert round(paths[3].prev_path.score, 6) == 0.1
    assert paths[3].prev_path.prev_path.score == 0
    # State 4 is reachable only via a non-epsilon arc, so it is not added.
    assert 4 not in paths
# Esempio n. 13 ("Example no. 13" — separator left over from the example listing)
# 0
# Script section: parse CLI arguments, load the trained acoustic model, and
# predict per-frame labels for each given wave file.
parser = argparse.ArgumentParser()
parser.add_argument('workdir', help='Directory path where files are saved')
parser.add_argument('wav_files', nargs='*', help='wave files to recognize')
parser.add_argument('--vocabulary-symbol-file', type=str, default='vocab.syms',
                    help='Vocabulary symbol file name')
parser.add_argument('--feature-params-file', type=str,
                    default='feature_params.json',
                    help='Feature params file name')
parser.add_argument('--model-file', type=str, default='model.bin',
                    help='model file name')
parser.add_argument('--decoder-fst-file', type=str, default='decoder.fst',
                    help='Decoder FST file name')
args = parser.parse_args()

phoneme_table = PhonemeTable()
phoneme_table.add_labels(phonemes)
epsilon_id = phoneme_table.get_epsilon_id()
print('Loading model ...')
model_path = os.path.join(args.workdir, args.model_file)
model = EESENAcousticModel.load(model_path)
feature_params_path = os.path.join(args.workdir, args.feature_params_file)
feature_params = FeatureParams.load(feature_params_path)
# Batch all wave files together; pad_sequence aligns variable-length features.
batch = []
for wav_file in args.wav_files:
    data = extract_feature_from_wavfile(wav_file, feature_params)
    batch.append(torch.from_numpy(data))
output = model.predict(pad_sequence(batch))
for idx, wav_file in enumerate(args.wav_files):
    print('Decoding {} ... '.format(wav_file))
    # NOTE(review): frame_labels is computed but never used here — the decode
    # call appears to be truncated in this extract; confirm against the
    # full script.
    frame_labels = [int(frame_label) for frame_label in output[:, idx]]
def test_phoneme_table_add_label():
    """A single added label gets the first free id after the built-in symbols."""
    table = PhonemeTable()
    table.add_label('a')
    assert table.num_labels() == 3
    assert table.get_label_id('a') == 2
    assert table.get_label(2) == 'a'
def test_phoneme_table_get_blank_id():
    """The blank symbol is predefined with id 1."""
    table = PhonemeTable()
    blank_id = table.get_blank_id()
    assert blank_id == 1
    assert table.get_label_id('<blank>') == blank_id
    assert table.get_label(blank_id) == '<blank>'
def test_phoneme_table_get_epsilon_id():
    """The epsilon symbol is predefined with id 0."""
    table = PhonemeTable()
    epsilon_id = table.get_epsilon_id()
    assert epsilon_id == 0
    assert table.get_label_id('<epsilon>') == epsilon_id
    assert table.get_label(epsilon_id) == '<epsilon>'
# Esempio n. 17 ("Example no. 17" — separator left over from the example listing)
# 0
                    help='Optimizer to use')
# Script section: training setup. NOTE(review): this fragment starts inside
# an add_argument call whose beginning lies outside this extract.
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
parser.add_argument('--epochs',
                    type=int,
                    default=5,
                    help='number of epochs for training')
parser.add_argument('--batch-size', type=int, default=32, help='batch size')
parser.add_argument('--device', type=str, default='cpu', help='Device string')
parser.add_argument('--model-file',
                    type=str,
                    default="model.bin",
                    help='Model file to save')
parser.add_argument('--resume', action='store_true')
args = parser.parse_args()

phoneme_table = PhonemeTable()
phoneme_table.add_labels(phonemes)
print('Loading training data ...')
training_data_dirpath = os.path.join(args.workdir, args.training_data_dirname)
repository_tr = TrainingDatasetRepository(training_data_dirpath)
dataset_tr = IterableAudioDataset(repository_tr, phoneme_table)
dataloader_tr = DataLoader(dataset_tr,
                           batch_size=args.batch_size,
                           collate_fn=collate_for_ctc)
print('Loading development data ...')
development_data_dirpath = os.path.join(args.workdir,
                                        args.development_data_dirname)
repository_dev = DevelopmentDatasetRepository(development_data_dirpath)
dataloaders_dev = []
# One DataLoader per development dataset. NOTE(review): the DataLoader call
# below is cut off at the end of this extract.
for dataset_dev in AudioDataset.load_all(repository_dev, phoneme_table):
    dataloader_dev = DataLoader(dataset_dev,
# Esempio n. 18 ("Example no. 18" — separator left over from the example listing)
# 0
                    help='Token FST file name')
# Script section: build the decoder's component FSTs. NOTE(review): this
# fragment starts inside an add_argument call whose beginning lies outside
# this extract.
parser.add_argument('--lexicon-fst-file',
                    type=str,
                    default='lexicon.fst',
                    help='Lexicon FST file name')
parser.add_argument('--grammar-fst-file',
                    type=str,
                    default='grammar.fst',
                    help='Grammar FST file name')
parser.add_argument('--decoder-fst-file',
                    type=str,
                    default='decoder.fst',
                    help='Decoder FST file name')
args = parser.parse_args()

phoneme_table = PhonemeTable()
phoneme_table.add_labels(phonemes)
print('Creating vocabulary symbol ...')
corpus_path = os.path.join(args.workdir, args.corpus_file)
vocabulary_symbol_path = os.path.join(args.workdir,
                                      args.vocabulary_symbol_file)
VocabularySymbolTable.create_symbol(vocabulary_symbol_path, corpus_path)
vocabulary_symbol_table = VocabularySymbolTable.load_symbol(
    vocabulary_symbol_path)
print('Creating lexicon FST ...')
lexicon_path = os.path.join(args.workdir, args.lexicon_file)
lexicon_fst_filepath = os.path.join(args.workdir, args.lexicon_fst_file)
lexicon = Lexicon()
lexicon.load(lexicon_path)
lexicon_fst = lexicon.create_fst(phoneme_table, vocabulary_symbol_table)
lexicon_fst.write(lexicon_fst_filepath)