def _gen_data(self):
    """Wrap ``self._gen`` in a two-class ``TextClassifierDataLoader``.

    The generator is registered with ``tf.data`` using a declared output
    signature of a scalar ``tf.string`` paired with a scalar ``tf.int64``.
    The resulting loader is told it holds ``self.TEXT_PER_CLASS * 2``
    examples spread over 2 classes labelled ``['pos', 'neg']``.
    """
    # Declared element signature: (scalar string, scalar int64).
    element_types = (tf.string, tf.int64)
    element_shapes = (tf.TensorShape([]), tf.TensorShape([]))
    dataset = tf.data.Dataset.from_generator(
        self._gen, element_types, element_shapes)

    total_examples = self.TEXT_PER_CLASS * 2
    return text_dataloader.TextClassifierDataLoader(
        dataset, total_examples, 2, ['pos', 'neg'])
Example #2
0
    def test_split(self):
        """Check that ``split(0.5, shuffle=False)`` halves the data in order.

        The loader wraps the four rows ``[0, 1], [1, 1], [0, 0], [1, 0]``.
        Without shuffling, the first two rows must land in the train split
        and the last two in the test split. Both splits must keep the
        parent's ``num_classes`` and ``index_to_label``.
        """
        ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
                                                 [1, 0]])
        data = text_dataloader.TextClassifierDataLoader(
            ds, 4, 2, ['pos', 'neg'])
        train_data, test_data = data.split(0.5, shuffle=False)

        cases = (
            (train_data, [[0, 1], [1, 1]]),
            (test_data, [[0, 0], [1, 0]]),
        )
        for split_data, expected_rows in cases:
            self.assertEqual(split_data.size, 2)

            # Materialize the split so the number of rows it actually yields
            # is checked, not just the declared ``size``. A bare loop over the
            # dataset would pass vacuously if it produced fewer rows.
            rows = [elem.numpy() for elem in split_data.dataset]
            self.assertEqual(len(rows), len(expected_rows))
            for row, expected in zip(rows, expected_rows):
                # assert_array_equal also checks shape (no broadcasting) and
                # prints a readable diff on failure, unlike
                # assertTrue((a == b).all()).
                np.testing.assert_array_equal(row, expected)

            self.assertEqual(split_data.num_classes, 2)
            self.assertEqual(split_data.index_to_label, ['pos', 'neg'])