def test_build_and_extend_consistency(self):
    """
    Make sure that the index is built consistently, no matter whether the input to build() is a single
    sentence, a list of sentences, or a list of tokenized sentences.
    """
    # Test build()
    test_sentence = "This is a test sentence"
    test_corpus = ["This is a", "test sentence"]
    test_tokenized_corpus = [["This", "is", "a"], ["test", "sentence"]]

    t2i1 = T2I.build(test_sentence)
    t2i2 = T2I.build(test_corpus)
    t2i3 = T2I.build(test_tokenized_corpus)

    self.assertEqual(t2i1, t2i2)
    self.assertEqual(t2i1, t2i3)

    # Test extend()
    test_sentence2 = "These are new words"
    test_corpus2 = ["These are", "new words"]
    test_tokenized_corpus2 = [["These", "are"], ["new", "words"]]

    extended1 = t2i1.extend(test_sentence2)
    extended2 = t2i2.extend(test_corpus2)
    extended3 = t2i3.extend(test_tokenized_corpus2)
    self.assertEqual(extended1, extended2)
    self.assertEqual(extended1, extended3)

    # Test extend() with a mix of input types
    mixed1 = t2i1.extend(test_tokenized_corpus2)
    mixed2 = t2i2.extend(test_sentence2)
    mixed3 = t2i3.extend(test_corpus2)
    self.assertEqual(mixed1, mixed2)
    self.assertEqual(mixed1, mixed3)
def test_eq(self):
    """
    Test the __eq__ method.
    """
    t2i1 = T2I.build(self.test_corpus1)
    t2i2 = T2I.build(self.test_corpus1)
    t2i3 = T2I.build(self.test_corpus2)

    self.assertTrue(t2i1 == t2i2)
    self.assertFalse(t2i1 == t2i3)
    self.assertFalse(t2i1 == t2i1._index)
def test_count_build(self):
    """
    Test whether tokens are ignored by build() when their frequency is too low.
    """
    # Passing a counter without a min_freq should trigger a warning
    with warnings.catch_warnings(record=True) as caught_warnings:
        T2I.build(self.corpus, counter=self.counter)
        self.assertEqual(len(caught_warnings), 1)

    t2i = T2I.build(self.corpus, counter=self.counter, min_freq=self.min_freq)
    self.assertTrue(self._check_freq_filtering(t2i, self.counter, self.min_freq))
def test_special_token_init(self):
    """
    Test init where the unk, eos or pad token is erroneously also specified among the special tokens.
    """
    for token in [STD_UNK, STD_EOS, STD_PAD]:
        with self.assertRaises(AssertionError):
            T2I(special_tokens=[token])

        with self.assertRaises(AssertionError):
            T2I.build(self.tokens, special_tokens=[token])

    with self.assertRaises(AssertionError):
        T2I(unk_token="#UNK#", special_tokens=["#UNK#"])

    with self.assertRaises(AssertionError):
        T2I.build(self.tokens, unk_token="#UNK#", special_tokens=["#UNK#"])

    with self.assertRaises(AssertionError):
        T2I(pad_token="#PAD#", special_tokens=["#PAD#"])

    with self.assertRaises(AssertionError):
        T2I.build(self.tokens, pad_token="#PAD#", special_tokens=["#PAD#"])

    with self.assertRaises(AssertionError):
        T2I(eos_token="#EOS#", special_tokens=["#EOS#"])

    with self.assertRaises(AssertionError):
        T2I.build(self.tokens, eos_token="#EOS#", special_tokens=["#EOS#"])
def test_custom_special_tokens_indexing(self):
    """
    Test indexing with custom unk, eos and pad tokens as well as additional special tokens.
    """
    t2i = T2I.build(
        self.test_corpus3,
        unk_token="#UNK#",
        eos_token="#EOS#",
        pad_token="#PAD#",
        special_tokens=("#MASK#", "#FLASK#"),
    )

    self.assertEqual(t2i.index(self.test_corpus5), self.indexed_test_corpus45)
    self.assertEqual(t2i.index(self.test_corpus5b), self.indexed_test_corpus45)
    self._assert_indexing_consistency(self.test_corpus5, t2i)

    self.assertIn("#MASK#", t2i)
    self.assertIn("#FLASK#", t2i)

    string_repr = str(t2i)
    self.assertIn("#MASK#", string_repr)
    self.assertIn("#FLASK#", string_repr)

    self.assertEqual(t2i.index(self.test_corpus5c), self.indexed_test_corpus5c)

    # Make sure special tokens are still there after extend()
    extended_t2i = t2i.extend(self.test_corpus4)
    self.assertIn("#MASK#", extended_t2i)
    self.assertIn("#FLASK#", extended_t2i)

    extended_string_repr = str(extended_t2i)
    self.assertIn("#MASK#", extended_string_repr)
    self.assertIn("#FLASK#", extended_string_repr)

    self.assertEqual(extended_t2i.index(self.test_corpus5c), self.indexed_test_corpus5c2)
def test_immutability(self):
    """
    Test whether the T2I index stays immutable after object init.
    """
    t2i = T2I.build(self.test_corpus1)

    with self.assertRaises(TypeError):
        t2i["banana"] = 66
def test_torchtext_compatibility(self):
    """
    Test whether the vocab object is compatible with the torchtext Vocab class.
    """
    t2i = T2I.build(self.test_corpus1)

    self.assertEqual(t2i.t2i, t2i.stoi)
    self.assertEqual(t2i.i2t, t2i.itos)
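# A minimal usage sketch of what the assertions above suggest (assumption: stoi / itos are plain
# aliases for t2i / i2t), illustrating how a T2I could stand in for code written against
# torchtext-style vocabularies; the corpus string is hypothetical:
#
#   vocab = T2I.build("the cat sat")
#   idx = vocab.stoi["cat"]    # token -> index via the torchtext-style attribute name
#   tok = vocab.itos[idx]      # index -> token round trip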
def test_eos_indexing(self):
    """
    Test indexing with (default) end-of-sequence token.
    """
    t2i = T2I.build(self.test_corpus3)

    self.assertEqual(t2i.index(self.test_corpus3b), self.indexed_test_corpus3)
    self._assert_indexing_consistency(self.test_corpus3b, t2i)
def test_default_indexing(self):
    """
    Test normal indexing case.
    """
    t2i = T2I.build(self.test_corpus1)

    self.assertEqual(t2i.index(self.test_corpus1), self.indexed_test_corpus1)
    self._assert_indexing_consistency(self.test_corpus1, t2i)
def test_iter(self):
    """
    Test the __iter__ method.
    """
    t2i = T2I.build(self.test_corpus1)

    contents = set((k, v) for k, v in t2i)
    expected_contents = {
        ("A", 0), ("B", 1), ("C", 2), ("D", 3), ("E", 4),
        ("<unk>", 5), ("<eos>", 6), ("<pad>", 7),
    }
    self.assertEqual(expected_contents, contents)
def test_serialization(self):
    """
    Test whether a T2I object can be saved to and loaded from disk without changing.
    """
    t2i = T2I.build(
        " ".join([random_str(random.randint(3, 10)) for _ in range(random.randint(20, 40))])
    )
    t2i.save(self.path)

    self.assertEqual(T2I.load(self.path), t2i)
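# Usage sketch of the save / load round trip exercised above; self.path is whatever location the
# test fixture provides, and the path and extension below are hypothetical:
#
#   t2i = T2I.build("some training corpus")
#   t2i.save("/tmp/vocab.t2i")                # persist the index to disk
#   restored = T2I.load("/tmp/vocab.t2i")     # load it back
#   assert restored == t2i                    # __eq__ makes the round trip checkable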
def test_max_size(self):
    """
    Test whether indexing stops once the maximum specified size of the T2I object is reached.
    """
    # 1. Test during init
    index = {n: n for n in range(10)}
    t2i1 = T2I(index, max_size=3)
    self.assertEqual(len(t2i1), 3)
    self.assertTrue(all([i not in t2i1 for i in range(3, 10)]))

    # With special tokens
    t2i2 = T2I(index, max_size=10, special_tokens=("<mask>", "<flask>"))
    self.assertEqual(len(t2i2), 10)
    self.assertTrue(all([i not in t2i2 for i in range(6, 10)]))

    # 2. Test using build()
    corpus = "this is a long test sentence with exactly boring words"
    t2i3 = T2I.build(corpus, max_size=3)
    self.assertEqual(len(t2i3), 3)
    self.assertTrue(all([token not in t2i3 for token in corpus.split()[3:]]))
    self.assertTrue(all([i not in t2i3.indices() for i in range(3, 10)]))

    # With special tokens
    t2i4 = T2I.build(corpus, max_size=10, special_tokens=("<mask>", "<flask>"))
    self.assertEqual(len(t2i4), 10)
    self.assertTrue(all([token not in t2i4 for token in corpus.split()[6:]]))

    # 3. Test when building from file
    t2i5 = T2I.from_file(self.vocab_path, max_size=18)
    self.assertEqual(len(t2i5), 18)
    self.assertTrue(all([token not in t2i5 for token in self.tokens[16:]]))

    # With special tokens
    t2i6 = T2I.from_file(self.vocab_path, max_size=21, special_tokens=("<mask>", "<flask>"))
    self.assertEqual(len(t2i6), 21)
    self.assertTrue(all([token not in t2i6 for token in self.tokens[17:]]))
def test_unk_indexing(self):
    """
    Test indexing with unknown words.
    """
    t2i = T2I.build(self.test_corpus3)

    self.assertEqual(t2i.index(self.test_corpus4), self.indexed_test_corpus45)
    self.assertEqual(t2i.index(self.test_corpus4b), self.indexed_test_corpus45)
    self._assert_indexing_consistency(self.test_corpus4, t2i)
def test_delimiter_indexing(self):
    """
    Test indexing with a different delimiter.
    """
    t2i = T2I.build(self.test_corpus2, delimiter="-")

    self.assertEqual(t2i.index(self.test_corpus2, delimiter="-"), self.indexed_test_corpus2)
    self._assert_indexing_consistency(self.test_corpus2, t2i, joiner="-", delimiter="-")
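# Sketch of the delimiter behaviour being tested; the corpus contents below are hypothetical,
# only the delimiter keyword is taken from the calls above:
#
#   t2i = T2I.build("flour-sugar-eggs", delimiter="-")
#   t2i.index("eggs-flour", delimiter="-")   # sentences are split on "-" instead of whitespace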
def test_iter(self):
    """
    Test the __iter__, tokens() and indices() functions.
    """
    t2i = T2I.build(self.test_corpus1)

    contents = set((k, v) for k, v in t2i)
    expected_contents = [
        ("A", 0), ("B", 1), ("C", 2), ("D", 3), ("E", 4),
        ("<unk>", 5), ("<eos>", 6), ("<pad>", 7),
    ]
    self.assertEqual(set(expected_contents), contents)
    self.assertEqual(list(zip(*expected_contents))[0], t2i.tokens())
    self.assertEqual(list(zip(*expected_contents))[1], t2i.indices())
def setUp(self):
    corpus = [
        "this is a test sentence .",
        "the mailman bites the dog .",
        "colorless green ideas sleep furiously .",
        "the horse raised past the barn fell .",
    ]
    self.sentences = [
        "this green dog furiously bites the horse . <eos>",
        "sleep is a past test . <eos> <pad> <pad>",
        "the test is a barn . <eos> <pad> <pad>",
    ]
    self.t2i = T2I.build(corpus)
def test_representation(self):
    """
    Test whether the string representation works correctly.
    """
    t2i = T2I.build(self.test_corpus1, unk_token=">UNK<", eos_token=">EOS<", pad_token=">PAD<")

    str_representation = str(t2i)
    self.assertIn(str(len(t2i)), str_representation)
    self.assertIn(">UNK<", str_representation)
    self.assertIn(">EOS<", str_representation)
    self.assertIn(">PAD<", str_representation)
def test_constant_memory_usage(self):
    """
    Make sure that a T2I object doesn't allocate more memory when unknown tokens are being looked up
    (like defaultdicts do).
    """
    t2i = T2I.build(self.test_corpus1)
    old_len = len(t2i)
    old_mem_usage = sys.getsizeof(t2i)

    # Look up unknown tokens
    for token in [random_str(5) for _ in range(10)]:
        t2i[token]

    new_len = len(t2i)
    new_mem_usage = sys.getsizeof(t2i)
    self.assertEqual(old_len, new_len)
    self.assertEqual(old_mem_usage, new_mem_usage)
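# For contrast, the defaultdict behaviour this test guards against (stdlib only, independent of
# T2I): a missed lookup inserts the key and grows the mapping.
#
#   from collections import defaultdict
#   d = defaultdict(int)
#   d["missing"]          # the lookup silently adds the key ...
#   assert len(d) == 1    # ... so repeated unknown-token lookups would keep allocating memory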
def test_automatic_padding(self):
    """
    Test whether the automatic padding functionality works as expected.
    """
    t2i = T2I.build(self.test_corpus1)

    # Now index a corpus with sequences of uneven length
    corpus = ["A A A", "D", "D A", "C B A B A B"]

    # pad_to argument error cases
    with self.assertRaises(AssertionError):
        t2i(corpus, pad_to="min")

    with self.assertRaises(AssertionError):
        t2i(corpus, pad_to=0)

    with self.assertRaises(TypeError):
        t2i(corpus, pad_to=bool)

    # Max padding
    indexed_corpus = t2i(corpus, pad_to="max")
    seq_lengths = [len(seq) for seq in indexed_corpus]
    self.assertEqual(len(set(seq_lengths)), 1)
    self.assertTrue(all([seq_len == 6 for seq_len in seq_lengths]))
    self.assertTrue(all([t2i[t2i.pad_token] in seq for seq in indexed_corpus[:3]]))

    # Specified padding
    indexed_corpus2 = t2i(corpus, pad_to=10)
    seq_lengths2 = [len(seq) for seq in indexed_corpus2]
    self.assertEqual(len(set(seq_lengths2)), 1)
    self.assertTrue(all([seq_len == 10 for seq_len in seq_lengths2]))
    self.assertTrue(all([t2i[t2i.pad_token] in seq for seq in indexed_corpus2]))

    # Test warning when pad_to is shorter than some sequences
    with warnings.catch_warnings(record=True) as caught_warnings:
        t2i(corpus, pad_to=2)
        self.assertEqual(len(caught_warnings), 2)
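# Shape of the padded output checked above, spelled out on a toy input (the exact index values
# depend on how the vocabulary was built and are not shown; only the lengths matter):
#
#   t2i(["A A A", "D"], pad_to="max")   # both rows padded to the length of the longest sequence
#   t2i(["A A A", "D"], pad_to=5)       # both rows padded to exactly 5 entries with the pad index
#   t2i(["A A A", "D"], pad_to=1)       # pad_to shorter than a sequence: one warning per affected
#                                       # row, as the assertions above suggest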
def test_extend(self):
    """
    Test extending an existing index with an additional corpus.
    """
    t2i = T2I.build(self.test_corpus3)

    additional_corpus = "These are new words"
    t2i = t2i.extend(additional_corpus)

    for token in additional_corpus.split(" "):
        self.assertIn(token, t2i)
        self.assertEqual(token, t2i.i2t[t2i[token]])

    test_sentence = "This is a new sentence"
    indexed_test_sentence = [0, 1, 2, 10, 4]
    self.assertEqual(t2i.index(test_sentence), indexed_test_sentence)
    self._assert_indexing_consistency(test_sentence, t2i)

    # Test whether i2t was updated
    self.assertTrue(all([t2i[token] in t2i.i2t for token in additional_corpus.split()]))
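# Note on the pattern above: the result of extend() is re-assigned rather than relying on in-place
# mutation, which is consistent with test_immutability (the index itself cannot be written to).
# A sketch of the intended usage, under that assumption:
#
#   base = T2I.build("This is a test")
#   larger = base.extend("with new words")   # returns an extended T2I; `base` presumably stays unchanged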
def setUp(self):
    test_corpus = "This is a long test sentence . It contains many words."
    self.t2i = T2I.build(test_corpus)