示例#1
0
class TestTextModule(unittest.TestCase):
    """Unit tests for ``TextModule`` built from a three-document toy corpus."""

    def setUp(self):
        """Build a module over a small corpus and record its token ids."""
        self.tokens = ['a', 'b', 'c', 'd', 'e', 'f']
        corpus = ['a b c', 'b c d d', 'c b e c f']
        ids = ['u1', 'u2', 'u3']
        # frequency ranking: c > b > d > a > e > f
        self.module = TextModule(corpus=corpus, ids=ids, max_vocab=6)
        self.module.build({'u1': 0, 'u2': 1, 'u3': 2})
        # Stored as a list rather than a generator so a test can unpack or
        # iterate the ids more than once without silently exhausting them.
        self.token_ids = [self.module.vocab.tok2idx[tok]
                          for tok in self.tokens]

    def test_init(self):
        """The vocabulary holds the special tokens plus every corpus token."""
        self.assertCountEqual(self.module.vocab.idx2tok,
                              SPECIAL_TOKENS + self.tokens)

    def test_build(self):
        """Smoke test: ``build`` runs for several corpus/ids/id-map combos."""
        TextModule().build()
        TextModule(corpus=['abc']).build()
        TextModule(corpus=['abc']).build({'b': 0})
        TextModule(corpus=['abc'], ids=['a']).build({'b': 0})

    def test_sequences(self):
        """Each document is stored as the list of its token ids, in order."""
        (a, b, c, d, e, f) = self.token_ids

        self.assertListEqual(self.module.sequences,
                             [[a, b, c], [b, c, d, d], [c, b, e, c, f]])

    def test_batch_seq(self):
        """``batch_seq`` pads short rows with 0 and honours ``max_length``."""
        (a, b, c, d, e, f) = self.token_ids

        # Without max_length the width is that of the longest sequence.
        batch_seqs = self.module.batch_seq([2, 1])
        self.assertEqual((2, 5), batch_seqs.shape)
        npt.assert_array_equal(batch_seqs,
                               np.asarray([[c, b, e, c, f], [b, c, d, d, 0]]))

        # With max_length=4 the five-token document is cut to four tokens.
        batch_seqs = self.module.batch_seq([0, 2], max_length=4)
        self.assertEqual((2, 4), batch_seqs.shape)
        npt.assert_array_equal(batch_seqs,
                               np.asarray([[a, b, c, 0], [c, b, e, c]]))

        self.module.sequences = None
        # NOTE(review): this check passes whether or not ValueError is
        # raised. If batch_seq is guaranteed to raise when sequences is
        # None, tighten it to ``with self.assertRaises(ValueError):``.
        try:
            self.module.batch_seq([0])
        except ValueError:
            pass

    def test_count_matrix(self):
        """``count_matrix`` holds per-document token counts."""
        (a, b, c, d, e, f) = self.token_ids
        # Columns are token ids offset by the number of special tokens.
        shift = len(SPECIAL_TOKENS)
        # toarray() replaces the ``.A`` shorthand, which newer SciPy releases
        # no longer provide (assumes count_matrix is a SciPy sparse matrix).
        actual_counts = self.module.count_matrix.toarray()
        expected_counts = np.zeros_like(actual_counts)
        expected_counts[0, a - shift] = 1
        expected_counts[0, b - shift] = 1
        expected_counts[0, c - shift] = 1
        expected_counts[1, b - shift] = 1
        expected_counts[1, c - shift] = 1
        expected_counts[1, d - shift] = 2
        expected_counts[2, b - shift] = 1
        expected_counts[2, c - shift] = 2
        expected_counts[2, e - shift] = 1
        expected_counts[2, f - shift] = 1
        npt.assert_array_equal(actual_counts, expected_counts)

    def test_batch_bow(self):
        """``batch_bow`` returns count or binary bag-of-words rows."""
        (a, b, c, d, e, f) = self.token_ids
        shift = len(SPECIAL_TOKENS)

        # Default call: raw counts, rows ordered as requested.
        batch_bows = self.module.batch_bow([2, 1])
        self.assertEqual((2, self.module.max_vocab), batch_bows.shape)
        expected_bows = np.zeros_like(batch_bows)
        expected_bows[0, b - shift] = 1
        expected_bows[0, c - shift] = 2
        expected_bows[0, e - shift] = 1
        expected_bows[0, f - shift] = 1
        expected_bows[1, b - shift] = 1
        expected_bows[1, c - shift] = 1
        expected_bows[1, d - shift] = 2
        npt.assert_array_equal(batch_bows, expected_bows)

        # binary=True collapses counts to 0/1; keep_sparse=True is converted
        # to dense here only for comparison.
        batch_bows = self.module.batch_bow([0, 2],
                                           binary=True,
                                           keep_sparse=True)
        self.assertEqual((2, 6), batch_bows.shape)
        dense_bows = batch_bows.toarray()
        expected_bows = np.zeros_like(dense_bows)
        expected_bows[0, np.asarray([a, b, c]) - shift] = 1
        expected_bows[1, np.asarray([b, c, e, f]) - shift] = 1
        npt.assert_array_equal(dense_bows, expected_bows)

        self.module.count_matrix = None
        # NOTE(review): this check passes whether or not ValueError is
        # raised. If batch_bow is guaranteed to raise when count_matrix is
        # None, tighten it to ``with self.assertRaises(ValueError):``.
        try:
            self.module.batch_bow([0])
        except ValueError:
            pass

    def test_batch_bow_fallback(self):
        """With features and no corpus, ``batch_bow`` returns feature rows."""
        module = TextModule(features=np.asarray([[3, 2, 1], [4, 5, 6]]),
                            ids=['a', 'b'])
        module.build()
        npt.assert_array_equal(np.asarray([[3, 2, 1]]),
                               module.batch_bow(batch_ids=[0]))
示例#2
0
 def test_batch_bow_fallback(self):
     """Check ``batch_bow`` on a module built from features only (no corpus).

     The row requested via ``batch_ids`` must equal the matching feature row.
     """
     raw_features = np.asarray([[3, 2, 1], [4, 5, 6]])
     text_module = TextModule(features=raw_features, ids=['a', 'b'])
     text_module.build()

     first_row_bow = text_module.batch_bow(batch_ids=[0])
     expected_row = np.asarray([[3, 2, 1]])
     npt.assert_array_equal(expected_row, first_row_bow)
示例#3
0
def test_init():
    """Smoke test: a default ``TextModule`` builds with no global id map."""
    TextModule().build(global_id_map=None)