def test_build_and_extend_consistency(self):
    """
    Make sure that the index is built correctly no matter whether the input
    to build() / extend() is a single sentence, a list of sentences or a
    list of already tokenized sentences.
    """
    # Test build()
    test_sentence = "This is a test sentence"
    test_corpus = ["This is a", "test sentence"]
    test_tokenized_corpus = [["This", "is", "a"], ["test", "sentence"]]

    t2i1 = T2I.build(test_sentence)
    t2i2 = T2I.build(test_corpus)
    t2i3 = T2I.build(test_tokenized_corpus)
    # assertEqual(first, second, msg) compares only two objects; a third
    # positional argument is used as the failure message. Compare pairwise
    # so that all three inputs are actually checked.
    self.assertEqual(t2i1, t2i2)
    self.assertEqual(t2i2, t2i3)

    # Test extend()
    test_sentence2 = "These are new words"
    test_corpus2 = ["These are", "new words"]
    test_tokenized_corpus2 = [["These", "are"], ["new", "words"]]

    extended1 = t2i1.extend(test_sentence2)
    extended2 = t2i2.extend(test_corpus2)
    extended3 = t2i3.extend(test_tokenized_corpus2)
    self.assertEqual(extended1, extended2)
    self.assertEqual(extended2, extended3)

    # Test extend with a mix of types
    mixed1 = t2i1.extend(test_tokenized_corpus2)
    mixed2 = t2i2.extend(test_sentence2)
    mixed3 = t2i3.extend(test_corpus2)
    self.assertEqual(mixed1, mixed2)
    self.assertEqual(mixed2, mixed3)
def test_improper_min_freq(self):
    """
    Check that non-positive values for min_freq are rejected with an
    AssertionError.
    """
    for bad_min_freq in (0, -1):
        with self.assertRaises(AssertionError):
            T2I({}, min_freq=bad_min_freq)
def test_serialization(self):
    """
    Check that saving a T2I object built from random tokens to self.path and
    loading it again yields an object equal to the original.
    """
    num_tokens = random.randint(20, 40)
    random_sentence = " ".join(
        random_str(random.randint(3, 10)) for _ in range(num_tokens)
    )
    original = T2I.build(random_sentence)

    original.save(self.path)
    restored = T2I.load(self.path)
    self.assertEqual(restored, original)
def _assert_indexing_consistency(self, corpus: Corpus, t2i: T2I,
                                 joiner: str = " ", delimiter: str = " "):
    """
    Assert that indexing ``corpus`` with ``t2i`` and un-indexing the result
    again reproduces ``corpus`` exactly.

    Args:
        corpus: The corpus to round-trip.
        t2i: The T2I object used for indexing and un-indexing.
        joiner: Passed to ``t2i.unindex()`` to join tokens back together.
        delimiter: Passed to ``t2i.index()`` to split the corpus into tokens.
    """
    indexed = t2i.index(corpus, delimiter=delimiter)
    restored = t2i.unindex(indexed, joiner=joiner)
    self.assertEqual(restored, corpus)
def test_eq(self):
    """
    Check __eq__: T2I objects built from the same corpus compare equal,
    objects built from different corpora do not, and a T2I object does not
    compare equal to its underlying ``_index``.
    """
    first = T2I.build(self.test_corpus1)
    same_as_first = T2I.build(self.test_corpus1)
    different = T2I.build(self.test_corpus2)

    self.assertTrue(first == same_as_first)
    self.assertFalse(first == different)
    # Comparison against an object that is not a T2I instance
    self.assertFalse(first == first._index)
def test_count_init(self):
    """
    Test whether tokens are ignored during the normal T2I initialization when
    their frequency is too low.
    """
    # A warning should be given when a counter is passed but min_freq is
    # still 1.
    with warnings.catch_warnings(record=True) as caught_warnings:
        # Record every warning regardless of the active warning filters
        # (e.g. -W ignore, pytest's filterwarnings) or earlier occurrences,
        # so that the count below is deterministic.
        warnings.simplefilter("always")
        T2I(counter=self.counter)
        self.assertEqual(len(caught_warnings), 1)

    t2i = T2I(self.index, counter=self.counter, min_freq=self.min_freq)
    self.assertTrue(
        self._check_freq_filtering(t2i, self.counter, self.min_freq))
def test_count_vocab_file(self):
    """
    Test whether tokens are ignored when building the T2I object from a
    vocab file.
    """
    # A warning should be given when a counter is passed but min_freq is not.
    with warnings.catch_warnings(record=True) as caught_warnings:
        # Record every warning regardless of the active warning filters, so
        # that the count below is deterministic.
        warnings.simplefilter("always")
        T2I.from_file(self.vocab_path, counter=self.counter)
        self.assertEqual(len(caught_warnings), 1)

    t2i = T2I.from_file(self.vocab_path, counter=self.counter,
                        min_freq=self.min_freq)
    self.assertTrue(
        self._check_freq_filtering(t2i, self.counter, self.min_freq))
def test_count_build(self):
    """
    Test whether tokens are ignored when using build() when their frequency
    is too low.
    """
    # A warning should be given when a counter is passed but min_freq is not.
    with warnings.catch_warnings(record=True) as caught_warnings:
        # Record every warning regardless of the active warning filters, so
        # that the count below is deterministic.
        warnings.simplefilter("always")
        T2I.build(self.corpus, counter=self.counter)
        self.assertEqual(len(caught_warnings), 1)

    t2i = T2I.build(self.corpus, counter=self.counter, min_freq=self.min_freq)
    self.assertTrue(
        self._check_freq_filtering(t2i, self.counter, self.min_freq))
def test_init(self):
    """
    Check direct T2I construction. An empty T2I object has length 3 and
    stores an ``Index`` internally. An object built from an index containing
    <eos> and <unk> maps <unk>, <eos> and <pad> to 0, 1 and 2 respectively.
    """
    # Empty T2I object
    empty_t2i = T2I()
    self.assertEqual(3, len(empty_t2i))
    self.assertEqual(Index, type(empty_t2i._index))

    # T2I object initialised with an eos and an unk token
    t2i = T2I({"<eos>": 10, "<unk>": 14})
    expected_indices = {"<unk>": 0, "<eos>": 1, "<pad>": 2}
    for token, expected_idx in expected_indices.items():
        self.assertEqual(t2i[token], expected_idx)
def test_custom_special_tokens_indexing(self):
    """
    Test indexing with custom unk / eos / pad tokens plus additional special
    tokens, and make sure the special tokens are still present after
    extend().
    """
    extra_tokens = ("#MASK#", "#FLASK#")

    def assert_has_special_tokens(vocab):
        # Special tokens must be members of the vocab and must also appear
        # in its string representation.
        for token in extra_tokens:
            self.assertIn(token, vocab)
        representation = str(vocab)
        for token in extra_tokens:
            self.assertIn(token, representation)

    t2i = T2I.build(
        self.test_corpus3,
        unk_token="#UNK#",
        eos_token="#EOS#",
        pad_token="#PAD#",
        special_tokens=extra_tokens,
    )
    for corpus in (self.test_corpus5, self.test_corpus5b):
        self.assertEqual(t2i.index(corpus), self.indexed_test_corpus45)
    self._assert_indexing_consistency(self.test_corpus5, t2i)
    assert_has_special_tokens(t2i)
    self.assertEqual(t2i.index(self.test_corpus5c), self.indexed_test_corpus5c)

    # The special tokens have to survive extend()
    extended_t2i = t2i.extend(self.test_corpus4)
    assert_has_special_tokens(extended_t2i)
    self.assertEqual(extended_t2i.index(self.test_corpus5c),
                     self.indexed_test_corpus5c2)
def test_immutability(self):
    """
    Check that item assignment on a built T2I object raises a TypeError,
    i.e. the object cannot be modified after construction.
    """
    vocab = T2I.build(self.test_corpus1)
    new_token, new_index = "banana", 66
    with self.assertRaises(TypeError):
        vocab[new_token] = new_index
def test_torchtext_compatibility(self):
    """
    Check that the torchtext-style ``stoi`` / ``itos`` attributes match the
    ``t2i`` / ``i2t`` attributes, so that the object can stand in for a
    torchtext Vocab.
    """
    vocab = T2I.build(self.test_corpus1)
    attribute_pairs = (("t2i", "stoi"), ("i2t", "itos"))
    for own_name, torchtext_name in attribute_pairs:
        self.assertEqual(getattr(vocab, own_name),
                         getattr(vocab, torchtext_name))
def test_eos_indexing(self):
    """
    Test indexing with the (default) end-of-sequence token, and check that
    un-indexing restores the corpus.
    """
    vocab = T2I.build(self.test_corpus3)
    indexed = vocab.index(self.test_corpus3b)
    self.assertEqual(indexed, self.indexed_test_corpus3)
    self._assert_indexing_consistency(self.test_corpus3b, vocab)
def test_default_indexing(self):
    """
    Test the plain indexing case: a corpus indexed by a T2I object built
    from that same corpus, plus an index / un-index round trip.
    """
    corpus = self.test_corpus1
    vocab = T2I.build(corpus)
    self.assertEqual(vocab.index(corpus), self.indexed_test_corpus1)
    self._assert_indexing_consistency(corpus, vocab)
def test_iter(self):
    """
    Check that iterating over a T2I object yields exactly the expected set
    of (token, index) pairs.
    """
    # NOTE(review): another method named ``test_iter`` appears later in this
    # file. If both live in the same TestCase class, the later definition
    # replaces this one and this test never runs. Confirm, and rename if so.
    vocab = T2I.build(self.test_corpus1)
    contents = {(token, idx) for token, idx in vocab}
    expected_contents = {
        ("A", 0), ("B", 1), ("C", 2), ("D", 3), ("E", 4),
        ("<unk>", 5), ("<eos>", 6), ("<pad>", 7),
    }
    self.assertEqual(expected_contents, contents)
def test_unk_indexing(self):
    """
    Test indexing of corpora containing words unknown to the T2I object,
    and an index / un-index round trip.
    """
    vocab = T2I.build(self.test_corpus3)
    expected = self.indexed_test_corpus45
    for corpus in (self.test_corpus4, self.test_corpus4b):
        self.assertEqual(vocab.index(corpus), expected)
    self._assert_indexing_consistency(self.test_corpus4, vocab)
def test_delimiter_indexing(self):
    """
    Test building and indexing with "-" as token delimiter instead of the
    default, and an index / un-index round trip using the same separator.
    """
    separator = "-"
    vocab = T2I.build(self.test_corpus2, delimiter=separator)
    indexed = vocab.index(self.test_corpus2, delimiter=separator)
    self.assertEqual(indexed, self.indexed_test_corpus2)
    self._assert_indexing_consistency(
        self.test_corpus2, vocab, joiner=separator, delimiter=separator)
def test_check_corpus(self):
    """
    Check that T2I._check_corpus() accepts a sentence, a list of sentences
    and a list of tokenized sentences, and raises an AssertionError for
    other inputs.
    """
    test_sentence = "This is a test sentence"
    test_corpus = ["This is a", "test sentence"]
    test_tokenized_corpus = [["This", "is", "a"], ["test", "sentence"]]

    # Accepted inputs must not raise
    for valid_corpus in (test_sentence, test_corpus, test_tokenized_corpus):
        T2I._check_corpus(valid_corpus)

    invalid_corpora = (
        list(range(10)),          # completely unexpected type
        [test_tokenized_corpus],  # one level of nesting too many
    )
    for invalid_corpus in invalid_corpora:
        with self.assertRaises(AssertionError):
            T2I._check_corpus(invalid_corpus)
def test_iter(self):
    """
    Test __iter__, tokens() and indices(). Iteration must yield the expected
    (token, index) pairs, and tokens() / indices() must equal the tuples of
    expected tokens / indices in that order.
    """
    vocab = T2I.build(self.test_corpus1)
    contents = {(token, idx) for token, idx in vocab}
    expected_contents = [
        ("A", 0), ("B", 1), ("C", 2), ("D", 3), ("E", 4),
        ("<unk>", 5), ("<eos>", 6), ("<pad>", 7),
    ]
    self.assertEqual(set(expected_contents), contents)

    expected_tokens, expected_indices = zip(*expected_contents)
    self.assertEqual(expected_tokens, vocab.tokens())
    self.assertEqual(expected_indices, vocab.indices())
def test_representation(self):
    """
    Check that str() of a T2I object contains its length as well as the
    custom unk, eos and pad tokens.
    """
    custom_tokens = {
        "unk_token": ">UNK<",
        "eos_token": ">EOS<",
        "pad_token": ">PAD<",
    }
    vocab = T2I.build(self.test_corpus1, **custom_tokens)
    representation = str(vocab)

    self.assertIn(str(len(vocab)), representation)
    for token in custom_tokens.values():
        self.assertIn(token, representation)
def setUp(self):
    """
    Prepare ``self.sentences``: three sentences of nine tokens each, ending
    in ``<eos>`` and, where needed, ``<pad>`` tokens. Also build
    ``self.t2i`` from a small four-sentence corpus.
    """
    self.sentences = [
        "this green dog furiously bites the horse . <eos>",
        "sleep is a past test . <eos> <pad> <pad>",
        "the test is a barn . <eos> <pad> <pad>",
    ]
    self.t2i = T2I.build([
        "this is a test sentence .",
        "the mailman bites the dog .",
        "colorless green ideas sleep furiously .",
        "the horse raised past the barn fell .",
    ])
def test_correct_indexing(self):
    """
    Test that new tokens are indexed correctly when the existing indices of
    the T2I object are arbitrary. In that case, numbering must continue from
    the highest existing index.
    """
    vocab = T2I.from_file(self.vocab_path3)
    highest_index = max(vocab.indices())

    new_sentence = "These are definitely new non-random tokens ."
    vocab = vocab.extend(new_sentence)
    new_indices = [vocab[token] for token in new_sentence.split(" ")]
    self.assertTrue(all(idx > highest_index for idx in new_indices))
def test_constant_memory_usage(self):
    """
    Make sure that a T2I object doesn't allocate more memory when unknown
    tokens are being looked up (like defaultdicts do).
    """
    t2i = T2I.build(self.test_corpus1)
    old_len = len(t2i)
    # sys.getsizeof() is shallow: it does not follow references, so the size
    # of the T2I wrapper alone would not change if its internal index grew.
    # Measure the internal index as well.
    old_mem_usage = sys.getsizeof(t2i)
    old_index_mem_usage = sys.getsizeof(t2i._index)

    # Look up unknown tokens
    unknown_tokens = [random_str(5) for _ in range(10)]
    for token in unknown_tokens:
        t2i[token]

    self.assertEqual(old_len, len(t2i))
    self.assertEqual(old_mem_usage, sys.getsizeof(t2i))
    self.assertEqual(old_index_mem_usage, sys.getsizeof(t2i._index))
    # The lookups must not have silently inserted the unknown tokens.
    for token in unknown_tokens:
        self.assertNotIn(token, t2i)
def test_automatic_padding(self):
    """
    Test whether the automatic padding functionality works as expected.
    """
    t2i = T2I.build(self.test_corpus1)
    # Now index corpus with sequences of uneven length (3, 1, 2 and 6 tokens)
    corpus = ["A A A", "D", "D A", "C B A B A B"]

    # pad_to argument error cases
    with self.assertRaises(AssertionError):
        t2i(corpus, pad_to="min")
    with self.assertRaises(AssertionError):
        t2i(corpus, pad_to=0)
    with self.assertRaises(TypeError):
        t2i(corpus, pad_to=bool)

    # Max padding: every sequence ends up with the longest length (6), and
    # the three shorter sequences contain the pad index.
    indexed_corpus = t2i(corpus, pad_to="max")
    seq_lengths = [len(seq) for seq in indexed_corpus]
    self.assertEqual(len(set(seq_lengths)), 1)
    self.assertTrue(all(seq_len == 6 for seq_len in seq_lengths))
    self.assertTrue(
        all(t2i[t2i.pad_token] in seq for seq in indexed_corpus[:3]))

    # Specified padding: every sequence is padded to length 10
    indexed_corpus2 = t2i(corpus, pad_to=10)
    seq_lengths2 = [len(seq) for seq in indexed_corpus2]
    self.assertEqual(len(set(seq_lengths2)), 1)
    self.assertTrue(all(seq_len == 10 for seq_len in seq_lengths2))
    self.assertTrue(
        all(t2i[t2i.pad_token] in seq for seq in indexed_corpus2))

    # Test warning: pad_to=2 is shorter than two of the sequences (3 and 6
    # tokens), and two warnings are expected.
    with warnings.catch_warnings(record=True) as caught_warnings:
        # Record every warning regardless of the active warning filters or
        # deduplication of repeated warnings, so the count is deterministic.
        warnings.simplefilter("always")
        t2i(corpus, pad_to=2)
        self.assertEqual(len(caught_warnings), 2)
def test_extend(self):
    """
    Test extending an existing index with an additional corpus. The new
    tokens must be members of the result, be reachable through i2t, and be
    used when indexing.
    """
    vocab = T2I.build(self.test_corpus3)
    additional_corpus = "These are new words"
    vocab = vocab.extend(additional_corpus)
    new_tokens = additional_corpus.split(" ")

    for token in new_tokens:
        self.assertIn(token, vocab)
        self.assertEqual(token, vocab.i2t[vocab[token]])

    test_sentence = "This is a new sentence"
    self.assertEqual(vocab.index(test_sentence), [0, 1, 2, 10, 4])
    self._assert_indexing_consistency(test_sentence, vocab)

    # i2t must have been updated with the indices of the new tokens
    self.assertTrue(all(vocab[token] in vocab.i2t for token in new_tokens))
def test_max_size(self):
    """
    Test whether indexing stops once the maximum specified size of the T2I
    object is reached. This is checked for direct init, build() and
    from_file(), each with and without additional special tokens.
    """
    extra_special_tokens = ("<mask>", "<flask>")

    def assert_truncated(vocab, expected_size, absent):
        # The object has the requested size and contains none of `absent`.
        self.assertEqual(len(vocab), expected_size)
        self.assertTrue(all(item not in vocab for item in absent))

    # 1. During init
    index = {n: n for n in range(10)}
    assert_truncated(T2I(index, max_size=3), 3, range(3, 10))
    assert_truncated(
        T2I(index, max_size=10, special_tokens=extra_special_tokens),
        10, range(6, 10))

    # 2. Using build()
    corpus = "this is a long test sentence with exactly boring words"
    corpus_tokens = corpus.split()
    t2i3 = T2I.build(corpus, max_size=3)
    assert_truncated(t2i3, 3, corpus_tokens[3:])
    self.assertTrue(all(i not in t2i3.indices() for i in range(3, 10)))
    assert_truncated(
        T2I.build(corpus, max_size=10, special_tokens=extra_special_tokens),
        10, corpus_tokens[6:])

    # 3. When building from file
    assert_truncated(
        T2I.from_file(self.vocab_path, max_size=18), 18, self.tokens[16:])
    assert_truncated(
        T2I.from_file(self.vocab_path, max_size=21,
                      special_tokens=extra_special_tokens),
        21, self.tokens[17:])
def test_building_from_file(self):
    """
    Test building a T2I object from a vocab file: the supported formats must
    produce the expected token indices, and malformed files must raise a
    ValueError.
    """
    def assert_indices(t2i, indices):
        # Check each (token, expected index) pair explicitly. Passing a bare
        # list of booleans to assertTrue() would always succeed, because a
        # non-empty list is truthy regardless of its contents.
        for token, idx in zip(self.tokens, indices):
            self.assertEqual(t2i[token], idx, msg=repr(token))

    # ### Proper vocab files ###
    # First vocab file format: One token per line
    t2i1 = T2I.from_file(self.vocab_path1)
    assert_indices(t2i1, range(len(self.tokens)))

    # Second vocab file format: Token and index, separated by tab
    t2i2 = T2I.from_file(self.vocab_path2)
    assert_indices(t2i2, self.indices2)

    # Second vocab file format, this time with higher indices
    t2i3 = T2I.from_file(self.vocab_path3)
    assert_indices(t2i3, self.indices3)

    # Second vocab file format, but with different delimiter
    t2i4 = T2I.from_file(self.vocab_path4, delimiter="###")
    assert_indices(t2i4, self.indices2)

    # unk, eos, special tokens already in vocab file
    t2i5 = T2I.from_file(self.vocab_path5,
                         special_tokens=("<mask>", "<flask>"))
    self.assertEqual(t2i1["<eos>"], t2i5["<eos>"])
    self.assertEqual(t2i1["<unk>"], t2i5["<unk>"])

    # unk, eos, special tokens already in vocab file, second format
    t2i5b = T2I.from_file(self.vocab_path5b,
                          special_tokens=("<mask>", "<flask>"))
    self.assertEqual(t2i1["<eos>"], t2i5b["<eos>"])
    self.assertEqual(t2i1["<unk>"], t2i5b["<unk>"])

    # ### Improper vocab files ###
    # Nonsensical format
    with self.assertRaises(ValueError):
        T2I.from_file(self.vocab_path6)

    # Mixed format
    with self.assertRaises(ValueError):
        T2I.from_file(self.vocab_path7)

    # Too many columns
    with self.assertRaises(ValueError):
        T2I.from_file(self.vocab_path8)

    # Second format but no ints in second column
    with self.assertRaises(ValueError):
        T2I.from_file(self.vocab_path9)
def test_special_token_init(self):
    """
    Test init where the unk, eos and pad tokens are erroneously also
    specified as special tokens. This covers both the default tokens and
    custom tokens, for T2I() as well as T2I.build().
    """
    # Default unk / eos / pad tokens repeated as special tokens
    for token in [STD_UNK, STD_EOS, STD_PAD]:
        with self.assertRaises(AssertionError):
            T2I(special_tokens=[token])
        with self.assertRaises(AssertionError):
            T2I.build(self.tokens, special_tokens=[token])

    # Custom unk / pad / eos tokens repeated as special tokens. Each custom
    # token is passed through its matching keyword argument, so that the pad
    # and eos checks actually exercise pad_token and eos_token.
    custom_token_cases = (
        ("unk_token", "#UNK#"),
        ("pad_token", "#PAD#"),
        ("eos_token", "#EOS#"),
    )
    for token_kwarg, custom_token in custom_token_cases:
        token_kwargs = {token_kwarg: custom_token}
        with self.assertRaises(AssertionError):
            T2I(special_tokens=[custom_token], **token_kwargs)
        with self.assertRaises(AssertionError):
            T2I.build(self.tokens, special_tokens=[custom_token],
                      **token_kwargs)
def setUp(self):
    """Build ``self.t2i`` from a short string containing two sentences."""
    self.t2i = T2I.build(
        "This is a long test sentence . It contains many words."
    )