def test_set_from_file_reads_padded_files(self):
    """Check loading a padded vocabulary file with a custom OOV token.

    The file lists seven tokens, one per line, with ``<UNK>`` on the third
    line. After ``set_from_file(..., is_padded=True, oov_token="<UNK>")``
    the test expects:

    * index 0 to map to the vocabulary's padding token,
    * the file's tokens to occupy indices 1-7 in file order,
    * the ``<UNK>`` entry to be exposed as ``DEFAULT_OOV_TOKEN`` (index 3),
    * a string absent from the file to resolve to that same index 3.
    """
    # pylint: disable=protected-access
    file_lines = ['<S>\n', '</S>\n', '<UNK>\n', 'a\n',
                  'tricky\x0bchar\n', 'word\n', 'another\n']
    vocab_filename = self.TEST_DIR / 'vocab_file'
    with codecs.open(vocab_filename, 'w', 'utf-8') as vocab_file:
        for line in file_lines:
            vocab_file.write(line)

    vocab = Vocabulary()
    vocab.set_from_file(vocab_filename, is_padded=True, oov_token="<UNK>")

    # (index, token) pairs expected after loading; the file's "<UNK>" line
    # is expected to surface as DEFAULT_OOV_TOKEN rather than "<UNK>".
    expected_entries = [
        (1, "<S>"),
        (2, "</S>"),
        (3, DEFAULT_OOV_TOKEN),
        (4, "a"),
        (5, "tricky\x0bchar"),
        (6, "word"),
        (7, "another"),
    ]

    assert vocab._oov_token == DEFAULT_OOV_TOKEN
    # A token that is not in the file resolves to the OOV index.
    assert vocab.get_token_index("random string") == 3

    # Token -> index direction.
    for index, token in expected_entries:
        assert vocab.get_token_index(token) == index

    # Index -> token direction; slot 0 holds the padding token.
    assert vocab.get_token_from_index(0) == vocab._padding_token
    for index, token in expected_entries:
        assert vocab.get_token_from_index(index) == token
def test_set_from_file_reads_non_padded_files(self):
    """Check loading a non-padded vocabulary file into the ``tags`` namespace.

    The file lists five tag strings, one per line. After
    ``set_from_file(..., is_padded=False, namespace='tags')`` the test
    expects the tags to occupy indices 0-4 in file order, in both the
    token -> index and index -> token directions.
    """
    # pylint: disable=protected-access
    tags_in_file_order = ["B-PERS", "I-PERS", "O", "B-ORG", "I-ORG"]

    vocab_filename = self.TEST_DIR / 'vocab_file'
    with codecs.open(vocab_filename, 'w', 'utf-8') as vocab_file:
        for tag in tags_in_file_order:
            vocab_file.write(tag + '\n')

    vocab = Vocabulary()
    vocab.set_from_file(vocab_filename, is_padded=False, namespace='tags')

    # Token -> index: numbering is expected to start at 0.
    for position, tag in enumerate(tags_in_file_order):
        assert vocab.get_token_index(tag, namespace='tags') == position

    # Index -> token: the reverse mapping must agree.
    for position, tag in enumerate(tags_in_file_order):
        assert vocab.get_token_from_index(position, namespace='tags') == tag