def test_vocabulary_table_save_and_load(workdir):
    """Round-trip a VocabularyTable through save() and load().

    Adds '私' once and 'あなた' twice, saves the table to a file under
    ``workdir``, then reloads it with the default settings, with
    ``min_freq=2`` and with ``min_freq=3``, checking the label <-> id
    mapping after each load.
    """
    saved_file = os.path.join(workdir, 'vocab.txt')
    original = VocabularyTable()
    for label in ('私', 'あなた', 'あなた'):
        original.add_label(label)
    original.save(saved_file)

    # Default load: the four special labels are expected at ids 0-3,
    # followed by the added labels in insertion order.
    reloaded = VocabularyTable.load(saved_file)
    expected_order = ('<pad>', '<unk>', '<bos>', '<eos>', '私', 'あなた')
    for label_id, label in enumerate(expected_order):
        assert reloaded.get_label_id(label) == label_id
        assert reloaded.get_label(label_id) == label

    # min_freq=2: only 'あなた' (added twice) keeps an id of its own;
    # '私' is expected to resolve to the unknown id.
    reloaded = VocabularyTable.load(saved_file, min_freq=2)
    assert reloaded.get_label_id('あなた') == 4
    assert reloaded.get_label(4) == 'あなた'
    assert reloaded.get_label_id('私') == reloaded.get_unk_id()

    # min_freq=3: neither label was added often enough.
    reloaded = VocabularyTable.load(saved_file, min_freq=3)
    for label in ('あなた', '私'):
        assert reloaded.get_label_id(label) == reloaded.get_unk_id()
def test_vocabulary_table_add_label():
    """Check add_label() with the default table and with ``min_freq=2``.

    With the default table a single add is expected to register the label
    at id 4. With ``min_freq=2`` the first add is expected to leave the
    label mapped to the unknown id, and the second add to register it.
    """
    word = '私'

    # Default construction: one occurrence is enough.
    default_table = VocabularyTable()
    default_table.add_label(word)
    assert default_table.num_labels() == 5
    assert default_table.get_label_id(word) == 4
    assert default_table.get_label(4) == word

    # min_freq=2: the label only gets its own id after the second add.
    thresholded_table = VocabularyTable(min_freq=2)
    thresholded_table.add_label(word)
    assert thresholded_table.num_labels() == 4
    assert thresholded_table.get_label_id(word) == thresholded_table.get_unk_id()

    thresholded_table.add_label(word)
    assert thresholded_table.num_labels() == 5
    assert thresholded_table.get_label_id(word) == 4
# Example no. 3
# 0
def get_vocabulary_table(workdir, words):
    """Build a VocabularyTable from the ``'word'`` entry of each item.

    Args:
        workdir: Not used by this function.
        words: Iterable of mappings, each indexed with the key ``'word'``.

    Returns:
        A new VocabularyTable to which every ``item['word']`` has been
        passed via ``add_label``, in iteration order.
    """
    table = VocabularyTable()
    for entry in words:
        label = entry['word']
        table.add_label(label)
    return table