def _gen_data(self):
    """Build a two-class ``TextClassifierDataLoader`` backed by ``self._gen``.

    The generator is wrapped in a ``tf.data.Dataset`` whose elements are
    declared as a scalar ``tf.string`` paired with a scalar ``tf.int64``.
    The loader is constructed with ``self.TEXT_PER_CLASS * 2`` as its second
    argument, 2 as its third and the label names ``'pos'`` / ``'neg'``.

    Returns:
        The constructed ``TextClassifierDataLoader``.
    """
    element_types = (tf.string, tf.int64)
    # Both components of every element are scalars.
    element_shapes = (tf.TensorShape([]), tf.TensorShape([]))
    dataset = tf.data.Dataset.from_generator(
        self._gen, element_types, element_shapes)

    total_examples = self.TEXT_PER_CLASS * 2
    return text_dataloader.TextClassifierDataLoader(
        dataset, total_examples, 2, ['pos', 'neg'])
def test_split(self):
    """Check that an unshuffled 50/50 split keeps row order and metadata.

    A four-row dataset is split in half with ``shuffle=False``. Each half
    must report a size of 2, yield its rows in the original order, and carry
    the same ``num_classes`` and ``index_to_label`` as the source loader.
    """
    rows = [[0, 1], [1, 1], [0, 0], [1, 0]]
    source = tf.data.Dataset.from_tensor_slices(rows)
    loader = text_dataloader.TextClassifierDataLoader(
        source, 4, 2, ['pos', 'neg'])

    first_half, second_half = loader.split(0.5, shuffle=False)

    # Rows in the first half end in 1 and rows in the second half end in 0,
    # while the leading value counts up from 0 within each half.
    for subset, trailing in ((first_half, 1), (second_half, 0)):
        self.assertEqual(subset.size, 2)
        for position, row in enumerate(subset.dataset):
            matches = row.numpy() == np.array([position, trailing])
            self.assertTrue(matches.all())
        self.assertEqual(subset.num_classes, 2)
        self.assertEqual(subset.index_to_label, ['pos', 'neg'])