Exemplo n.º 1
0
    def test_build_and_extend_consistency(self):
        """
        Make sure that the index is built correctly no matter whether the input to build() / extend() is a single
        sentence, a list of sentences or a tokenized corpus.

        assertEqual() only compares its first two arguments (the third one is used as the failure message), so every
        result is compared pairwise against the first one.
        """
        # Test build()
        test_sentence = "This is a test sentence"
        test_corpus = ["This is a", "test sentence"]
        test_tokenized_corpus = [["This", "is", "a"], ["test", "sentence"]]

        t2i1 = T2I.build(test_sentence)
        t2i2 = T2I.build(test_corpus)
        t2i3 = T2I.build(test_tokenized_corpus)
        self.assertEqual(t2i1, t2i2)
        self.assertEqual(t2i1, t2i3)

        # Test extend()
        test_sentence2 = "These are new words"
        test_corpus2 = ["These are", "new words"]
        test_tokenized_corpus2 = [["These", "are"], ["new", "words"]]

        extended1 = t2i1.extend(test_sentence2)
        extended2 = t2i2.extend(test_corpus2)
        extended3 = t2i3.extend(test_tokenized_corpus2)
        self.assertEqual(extended1, extended2)
        self.assertEqual(extended1, extended3)

        # Test extend with a mix of types
        mixed1 = t2i1.extend(test_tokenized_corpus2)
        mixed2 = t2i2.extend(test_sentence2)
        mixed3 = t2i3.extend(test_corpus2)
        self.assertEqual(mixed1, mixed2)
        self.assertEqual(mixed1, mixed3)
Exemplo n.º 2
0
    def test_eq(self):
        """
        Check __eq__: two indices built from the same corpus compare equal, while an index built from another corpus
        or the index's own `_index` attribute do not.
        """
        reference = T2I.build(self.test_corpus1)
        twin = T2I.build(self.test_corpus1)
        different = T2I.build(self.test_corpus2)

        # Use the == operator explicitly so that __eq__ is what gets exercised
        equal_to_twin = reference == twin
        self.assertTrue(equal_to_twin)

        equal_to_different = reference == different
        self.assertFalse(equal_to_different)

        equal_to_raw_index = reference == reference._index
        self.assertFalse(equal_to_raw_index)
Exemplo n.º 3
0
    def test_count_build(self):
        """
        Check that build() drops tokens whose frequency is too low when a counter and min_freq are given.
        """
        # Passing a counter without min_freq is expected to produce exactly one warning.
        # NOTE(review): no warnings.simplefilter("always") is set here, so what gets recorded depends on the
        # warning filters active when the test runs - confirm this is intended.
        with warnings.catch_warnings(record=True) as recorded:
            T2I.build(self.corpus, counter=self.counter)

        self.assertEqual(len(recorded), 1)

        filtered_t2i = T2I.build(
            self.corpus, counter=self.counter, min_freq=self.min_freq
        )
        filtering_ok = self._check_freq_filtering(
            filtered_t2i, self.counter, self.min_freq
        )
        self.assertTrue(filtering_ok)
Exemplo n.º 4
0
    def test_special_token_init(self):
        """
        Test init where unk, eos and pad token are erroneously also specified as special tokens.

        Both the direct constructor and build() are expected to raise an AssertionError, for the default tokens as
        well as for custom tokens passed via the unk_token, pad_token and eos_token keywords.
        """
        # Default unk / eos / pad tokens
        for token in [STD_UNK, STD_EOS, STD_PAD]:
            with self.assertRaises(AssertionError):
                T2I(special_tokens=[token])

            with self.assertRaises(AssertionError):
                T2I.build(self.tokens, special_tokens=[token])

        # Custom tokens: each one is passed through its own keyword, so that the pad and eos clashes are
        # actually exercised (and not just the unk clash with differently spelled tokens)
        custom_tokens = (
            ("unk_token", "#UNK#"),
            ("pad_token", "#PAD#"),
            ("eos_token", "#EOS#"),
        )

        for keyword, token in custom_tokens:
            clashing_kwargs = {keyword: token, "special_tokens": [token]}

            with self.assertRaises(AssertionError):
                T2I(**clashing_kwargs)

            with self.assertRaises(AssertionError):
                T2I.build(self.tokens, **clashing_kwargs)
Exemplo n.º 5
0
    def test_custom_special_tokens_indexing(self):
        """
        Check indexing with custom unk / eos / pad tokens plus additional special tokens, before and after extend().
        """
        extra_tokens = ("#MASK#", "#FLASK#")
        t2i = T2I.build(
            self.test_corpus3,
            unk_token="#UNK#",
            eos_token="#EOS#",
            pad_token="#PAD#",
            special_tokens=("#MASK#", "#FLASK#"),
        )

        # Both corpus variants are expected to map onto the same indices
        for corpus in (self.test_corpus5, self.test_corpus5b):
            self.assertEqual(t2i.index(corpus), self.indexed_test_corpus45)
        self._assert_indexing_consistency(self.test_corpus5, t2i)

        # Special tokens have to be part of the index and of its string representation
        for special_token in extra_tokens:
            self.assertIn(special_token, t2i)
        string_repr = str(t2i)
        for special_token in extra_tokens:
            self.assertIn(special_token, string_repr)

        indexed_5c = t2i.index(self.test_corpus5c)
        self.assertEqual(indexed_5c, self.indexed_test_corpus5c)

        # The same has to hold for the index returned by extend()
        extended_t2i = t2i.extend(self.test_corpus4)
        for special_token in extra_tokens:
            self.assertIn(special_token, extended_t2i)
        extended_string_repr = str(extended_t2i)
        for special_token in extra_tokens:
            self.assertIn(special_token, extended_string_repr)

        extended_indexed_5c = extended_t2i.index(self.test_corpus5c)
        self.assertEqual(extended_indexed_5c, self.indexed_test_corpus5c2)
Exemplo n.º 6
0
 def test_immutability(self):
     """
     Check that assigning an item to an already built T2I raises a TypeError.
     """
     frozen_index = T2I.build(self.test_corpus1)

     # Item assignment is expected to be rejected
     with self.assertRaises(TypeError):
         frozen_index["banana"] = 66
Exemplo n.º 7
0
    def test_torchtext_compatibility(self):
        """
        Check the torchtext-style attribute names: stoi has to equal t2i and itos has to equal i2t.
        """
        vocab = T2I.build(self.test_corpus1)

        token_to_index, string_to_index = vocab.t2i, vocab.stoi
        self.assertEqual(token_to_index, string_to_index)

        index_to_token, index_to_string = vocab.i2t, vocab.itos
        self.assertEqual(index_to_token, index_to_string)
Exemplo n.º 8
0
    def test_eos_indexing(self):
        """
        Check indexing of a corpus when the default end-of-sequence token is used.
        """
        vocab = T2I.build(self.test_corpus3)

        indexed = vocab.index(self.test_corpus3b)
        self.assertEqual(indexed, self.indexed_test_corpus3)

        self._assert_indexing_consistency(self.test_corpus3b, vocab)
Exemplo n.º 9
0
    def test_default_indexing(self):
        """
        Check the plain case: index the same corpus the T2I object was built from.
        """
        vocab = T2I.build(self.test_corpus1)

        indexed = vocab.index(self.test_corpus1)
        self.assertEqual(indexed, self.indexed_test_corpus1)

        self._assert_indexing_consistency(self.test_corpus1, vocab)
Exemplo n.º 10
0
    def test_iter(self):
        """
        Check that iterating over a T2I object yields the expected (token, index) pairs.
        """
        vocab = T2I.build(self.test_corpus1)

        expected_contents = {
            ("A", 0),
            ("B", 1),
            ("C", 2),
            ("D", 3),
            ("E", 4),
            ("<unk>", 5),
            ("<eos>", 6),
            ("<pad>", 7),
        }
        # Order is irrelevant here, only the set of pairs is compared
        contents = {(token, idx) for token, idx in vocab}

        self.assertEqual(expected_contents, contents)
Exemplo n.º 11
0
    def test_serialization(self):
        """
        Check that a T2I object which is saved to disk and loaded again equals the original object.
        """
        # Build an index from a random number of random tokens
        num_tokens = random.randint(20, 40)
        random_tokens = [
            random_str(random.randint(3, 10)) for _ in range(num_tokens)
        ]
        original = T2I.build(" ".join(random_tokens))

        original.save(self.path)
        restored = T2I.load(self.path)

        self.assertEqual(restored, original)
Exemplo n.º 12
0
    def test_max_size(self):
        """
        Check that the size of a T2I object is capped at max_size, no matter whether it is created directly, via
        build() or via from_file(), with and without additional special tokens.
        """
        extra_tokens = ("<mask>", "<flask>")

        # 1. Test during init
        index = dict(zip(range(10), range(10)))
        init_t2i = T2I(index, max_size=3)
        self.assertEqual(len(init_t2i), 3)
        for i in range(3, 10):
            self.assertNotIn(i, init_t2i)

        # With special tokens
        init_special_t2i = T2I(index, max_size=10, special_tokens=extra_tokens)
        self.assertEqual(len(init_special_t2i), 10)
        for i in range(6, 10):
            self.assertNotIn(i, init_special_t2i)

        # 2. Test using build()
        corpus = "this is a long test sentence with exactly boring words"
        build_t2i = T2I.build(corpus, max_size=3)
        self.assertEqual(len(build_t2i), 3)
        for token in corpus.split()[3:]:
            self.assertNotIn(token, build_t2i)
        for i in range(3, 10):
            self.assertNotIn(i, build_t2i.indices())

        # With special tokens
        build_special_t2i = T2I.build(
            corpus, max_size=10, special_tokens=extra_tokens
        )
        self.assertEqual(len(build_special_t2i), 10)
        for token in corpus.split()[6:]:
            self.assertNotIn(token, build_special_t2i)

        # 3. Test when building from file
        file_t2i = T2I.from_file(self.vocab_path, max_size=18)
        self.assertEqual(len(file_t2i), 18)
        for token in self.tokens[16:]:
            self.assertNotIn(token, file_t2i)

        # With special tokens
        file_special_t2i = T2I.from_file(
            self.vocab_path, max_size=21, special_tokens=extra_tokens
        )
        self.assertEqual(len(file_special_t2i), 21)
        for token in self.tokens[17:]:
            self.assertNotIn(token, file_special_t2i)
Exemplo n.º 13
0
    def test_unk_indexing(self):
        """
        Check indexing of corpora containing words that were not seen when the index was built.
        """
        vocab = T2I.build(self.test_corpus3)

        # Both corpus variants are expected to map onto the same indices
        for corpus in (self.test_corpus4, self.test_corpus4b):
            self.assertEqual(vocab.index(corpus), self.indexed_test_corpus45)

        self._assert_indexing_consistency(self.test_corpus4, vocab)
Exemplo n.º 14
0
    def test_delimiter_indexing(self):
        """
        Check building and indexing when "-" is used as the delimiter.
        """
        vocab = T2I.build(self.test_corpus2, delimiter="-")

        indexed = vocab.index(self.test_corpus2, delimiter="-")
        self.assertEqual(indexed, self.indexed_test_corpus2)

        self._assert_indexing_consistency(
            self.test_corpus2, vocab, joiner="-", delimiter="-"
        )
Exemplo n.º 15
0
    def test_iter(self):
        """
        Check __iter__ as well as the tokens() and indices() functions.
        """
        vocab = T2I.build(self.test_corpus1)

        expected_contents = [
            ("A", 0),
            ("B", 1),
            ("C", 2),
            ("D", 3),
            ("E", 4),
            ("<unk>", 5),
            ("<eos>", 6),
            ("<pad>", 7),
        ]
        # Split the expected pairs into one tuple of tokens and one tuple of indices
        expected_tokens, expected_indices = zip(*expected_contents)

        contents = {(token, idx) for token, idx in vocab}
        self.assertEqual(set(expected_contents), contents)
        self.assertEqual(expected_tokens, vocab.tokens())
        self.assertEqual(expected_indices, vocab.indices())
Exemplo n.º 16
0
    def setUp(self):
        """
        Store the test sentences and a T2I object built from a small corpus on the test case.
        """
        training_corpus = [
            "this is a test sentence .",
            "the mailman bites the dog .",
            "colorless green ideas sleep furiously .",
            "the horse raised past the barn fell .",
        ]

        # Sentences already contain literal <eos> / <pad> tokens
        self.sentences = [
            "this green dog furiously bites the horse . <eos>",
            "sleep is a past test . <eos> <pad> <pad>",
            "the test is a barn . <eos> <pad> <pad>",
        ]
        self.t2i = T2I.build(training_corpus)
Exemplo n.º 17
0
    def test_representation(self):
        """
        Check that the string representation contains the size of the index and the custom unk / eos / pad tokens.
        """
        vocab = T2I.build(
            self.test_corpus1,
            unk_token=">UNK<",
            eos_token=">EOS<",
            pad_token=">PAD<",
        )
        str_representation = str(vocab)

        self.assertIn(str(len(vocab)), str_representation)
        for custom_token in (">UNK<", ">EOS<", ">PAD<"):
            self.assertIn(custom_token, str_representation)
Exemplo n.º 18
0
    def test_constant_memory_usage(self):
        """
        Check that looking up unknown tokens neither changes the length nor the reported size of a T2I object
        (unlike what happens with a defaultdict).
        """
        vocab = T2I.build(self.test_corpus1)

        # NOTE: sys.getsizeof() is shallow, it does not follow references into contained objects
        len_before, mem_before = len(vocab), sys.getsizeof(vocab)

        # Look up ten random tokens of length five
        for _ in range(10):
            vocab[random_str(5)]

        len_after, mem_after = len(vocab), sys.getsizeof(vocab)

        self.assertEqual(len_before, len_after)
        self.assertEqual(mem_before, mem_after)
Exemplo n.º 19
0
    def test_automatic_padding(self):
        """
        Check the pad_to argument: invalid values, padding to the longest sequence, padding to a fixed length and
        the warnings for a pad_to value of 2 on a corpus with longer sequences.
        """
        vocab = T2I.build(self.test_corpus1)

        # Corpus with sequences of uneven length
        corpus = ["A A A", "D", "D A", "C B A B A B"]

        # pad_to argument error cases
        invalid_cases = (
            (AssertionError, "min"),
            (AssertionError, 0),
            (TypeError, bool),
        )
        for expected_error, bad_pad_to in invalid_cases:
            with self.assertRaises(expected_error):
                vocab(corpus, pad_to=bad_pad_to)

        # Max padding: every sequence ends up with length 6
        max_padded = vocab(corpus, pad_to="max")
        max_lengths = [len(seq) for seq in max_padded]

        self.assertEqual(len(set(max_lengths)), 1)
        self.assertTrue(all(seq_len == 6 for seq_len in max_lengths))
        # The first three sequences have to contain padding
        for seq in max_padded[:3]:
            self.assertIn(vocab[vocab.pad_token], seq)

        # Specified padding: every sequence ends up with length 10 and contains padding
        fixed_padded = vocab(corpus, pad_to=10)
        fixed_lengths = [len(seq) for seq in fixed_padded]

        self.assertEqual(len(set(fixed_lengths)), 1)
        self.assertTrue(all(seq_len == 10 for seq_len in fixed_lengths))
        for seq in fixed_padded:
            self.assertIn(vocab[vocab.pad_token], seq)

        # Test warning: two warnings are expected for pad_to=2
        with warnings.catch_warnings(record=True) as recorded:
            vocab(corpus, pad_to=2)

        self.assertEqual(len(recorded), 2)
Exemplo n.º 20
0
    def test_extend(self):
        """
        Check that extending an existing index with an additional corpus adds the new tokens to t2i and i2t.
        """
        additional_corpus = "These are new words"
        new_tokens = additional_corpus.split(" ")

        extended = T2I.build(self.test_corpus3).extend(additional_corpus)

        # Every new token has to be retrievable in both directions
        for token in new_tokens:
            self.assertIn(token, extended)
            self.assertEqual(token, extended.i2t[extended[token]])

        test_sentence = "This is a new sentence"
        expected_indices = [0, 1, 2, 10, 4]
        self.assertEqual(extended.index(test_sentence), expected_indices)
        self._assert_indexing_consistency(test_sentence, extended)

        # Test whether i2t was updated
        for token in additional_corpus.split():
            self.assertIn(extended[token], extended.i2t)
Exemplo n.º 21
0
 def setUp(self):
     """
     Build a T2I object from a single test sentence and store it on the test case.
     """
     self.t2i = T2I.build(
         "This is a long test sentence . It contains many words."
     )