def _gen_cmy_data(self):
    """Build an ImageClassifierDataLoader over ``self._gen``.

    The generator's outputs are declared as a uint8 image of shape
    ``[IMAGE_SIZE, IMAGE_SIZE, 3]`` paired with a scalar int64 label.

    Returns:
      An ``image_dataloader.ImageClassifierDataLoader`` with size
      ``self.IMAGES_PER_CLASS * 3`` and labels cyan, magenta, yellow.
    """
    image_shape = tf.TensorShape([self.IMAGE_SIZE, self.IMAGE_SIZE, 3])
    label_shape = tf.TensorShape([])
    dataset = tf.data.Dataset.from_generator(
        self._gen, (tf.uint8, tf.int64), (image_shape, label_shape))

    total_images = self.IMAGES_PER_CLASS * 3
    # NOTE(review): this passes (dataset, size, labels), whereas test_split
    # passes (dataset, size, num_classes, labels). Confirm which signature
    # ImageClassifierDataLoader really has; one of the two calls is stale.
    return image_dataloader.ImageClassifierDataLoader(
        dataset, total_images, ['cyan', 'magenta', 'yellow'])
def test_split(self):
    """Check that split(0.5) halves a 4-element dataset in order.

    The first half must contain exactly ``[0, 1], [1, 1]`` and the second
    exactly ``[0, 0], [1, 0]``. Both halves must report size 2 and keep
    ``num_classes == 2`` and ``index_to_label == ['pos', 'neg']``.
    """
    ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
    # NOTE(review): this passes (dataset, size, num_classes, labels), whereas
    # _gen_cmy_data passes (dataset, size, labels). Confirm which signature
    # ImageClassifierDataLoader really has; one of the two calls is stale.
    data = image_dataloader.ImageClassifierDataLoader(ds, 4, 2, ['pos', 'neg'])
    train_data, test_data = data.split(0.5)

    expectations = (
        (train_data, [[0, 1], [1, 1]]),
        (test_data, [[0, 0], [1, 0]]),
    )
    for part, expected_rows in expectations:
        self.assertEqual(part.size, 2)
        # Materialize the whole split and compare it as a list. A per-element
        # assertion inside a loop passes vacuously when the dataset yields
        # fewer rows than `size` claims (e.g. none); this pins count, order
        # and content together.
        actual_rows = [elem.numpy().tolist() for elem in part.dataset]
        self.assertEqual(actual_rows, expected_rows)
        self.assertEqual(part.num_classes, 2)
        self.assertEqual(part.index_to_label, ['pos', 'neg'])