def test_value_cifar_c(self, split):
    """Check the element shape of a value-corrupted dataset for `split`.

    Builds a dataset configured with the 'brightness' corruption at
    corruption_value=0.25 and asserts that the first component of its first
    element has shape `data_lib.CIFAR_SHAPE`. The test body is a no-op when
    the `fake_data` flag is set.
    """
    if flags.FLAGS.fake_data:
        return

    data_config = image_data_utils.DataConfig(
        split, corruption_value=0.25, corruption_type='brightness')
    first_element = next(iter(data_lib.build_dataset(data_config)))
    self.assertEqual(first_element[0].numpy().shape, data_lib.CIFAR_SHAPE)
def test_array_cifar_c(self, split):
    """Check the element shape of a level-corrupted dataset for `split`.

    Builds a dataset configured with the 'glass_blur' corruption at
    corruption_level=4 and asserts that the first component of its first
    element has shape `data_lib.CIFAR_SHAPE`. The test body is a no-op when
    the `fake_data` flag is set.
    """
    if flags.FLAGS.fake_data:
        return

    data_config = image_data_utils.DataConfig(
        split, corruption_level=4, corruption_type='glass_blur')
    first_element = next(iter(data_lib.build_dataset(data_config)))
    self.assertEqual(first_element[0].numpy().shape, data_lib.CIFAR_SHAPE)
def test_roll_pixels(self, split):
    """Check a batched dataset built with roll_pixels=5 for `split`.

    Asserts that the first component of the first batch has shape
    `BATCHED_IMAGES_SHAPE`, that all of its values lie in [0, 1], and that
    at least one value exceeds 1/255. The config is always constructed; the
    dataset checks are skipped when the `fake_data` flag is set.
    """
    # Built before the flag check so config construction is always exercised.
    data_config = image_data_utils.DataConfig(split, roll_pixels=5)
    if flags.FLAGS.fake_data:
        return

    batches = data_lib.build_dataset(data_config, BATCH_SIZE)
    first_batch = next(iter(batches))
    pixels = first_batch[0].numpy()

    self.assertEqual(pixels.shape, BATCHED_IMAGES_SHAPE)
    self.assertAllInRange(pixels, 0.0, 1.0)
    # At least one value must be above 1/255.
    above_threshold = pixels > 1.0 / 255
    self.assertTrue(above_threshold.any())
def test_fake_data(self, split):
    """Check a batched dataset built with fake_data=True for `split`.

    Asserts that the first component of the first batch has shape
    `BATCHED_IMAGES_SHAPE`, that all of its values lie in [0, 1], and that
    at least one value exceeds 1/255.
    """
    # The config is ignored when fake data is requested.
    data_config = image_data_utils.DataConfig(split)
    batches = data_lib.build_dataset(data_config, BATCH_SIZE, fake_data=True)
    first_batch = next(iter(batches))
    pixels = first_batch[0].numpy()

    self.assertEqual(pixels.shape, BATCHED_IMAGES_SHAPE)
    self.assertAllInRange(pixels, 0.0, 1.0)
    # At least one value must be above 1/255.
    above_threshold = pixels > 1.0 / 255
    self.assertTrue(above_threshold.any())
def test_alt_dataset(self):
    """Check a batched 'test' dataset built with alt_dataset_name='celeb_a'.

    Asserts that the first component of the first batch has shape
    `BATCHED_IMAGES_SHAPE`, that all of its values lie in [0, 1], and that
    at least one value exceeds 1/255. The test body is a no-op when the
    `fake_data` flag is set.
    """
    if flags.FLAGS.fake_data:
        return

    data_config = image_data_utils.DataConfig(
        'test', alt_dataset_name='celeb_a')
    batches = data_lib.build_dataset(data_config, BATCH_SIZE)
    first_batch = next(iter(batches))
    pixels = first_batch[0].numpy()

    self.assertEqual(pixels.shape, BATCHED_IMAGES_SHAPE)
    self.assertAllInRange(pixels, 0.0, 1.0)
    # At least one value must be above 1/255.
    above_threshold = pixels > 1.0 / 255
    self.assertTrue(above_threshold.any())
def test_value_imagenet_c(self, split):
    """Check a batched, value-corrupted dataset for `split`.

    Builds a dataset configured with the 'brightness' corruption at
    corruption_value=0.25, then asserts that the first component of the
    first batch has shape `BATCHED_IMAGES_SHAPE`, that all of its values lie
    in [0, 1], and that at least one value exceeds 1/255. The test body is a
    no-op when the `fake_data` flag is set.
    """
    if flags.FLAGS.fake_data:
        return

    data_config = image_data_utils.DataConfig(
        split, corruption_value=0.25, corruption_type='brightness')
    batches = data_lib.build_dataset(data_config, BATCH_SIZE)
    first_batch = next(iter(batches))
    pixels = first_batch[0].numpy()

    self.assertEqual(pixels.shape, BATCHED_IMAGES_SHAPE)
    self.assertAllInRange(pixels, 0.0, 1.0)
    # At least one value must be above 1/255.
    above_threshold = pixels > 1.0 / 255
    self.assertTrue(above_threshold.any())
def test_static_cifar_c(self, split):
    """Check static 'pixelate' corruption (level 3) handling for `split`.

    For the 'train' and 'valid' splits, building the dataset is expected to
    raise ValueError. For any other split, the first component of the first
    element must have shape `data_lib.CIFAR_SHAPE`. The test body is a no-op
    when the `fake_data` flag is set.
    """
    if flags.FLAGS.fake_data:
        return

    data_config = image_data_utils.DataConfig(
        split,
        corruption_static=True,
        corruption_level=3,
        corruption_type='pixelate')

    if split not in ('train', 'valid'):
        first_element = next(iter(data_lib.build_dataset(data_config)))
        self.assertEqual(first_element[0].numpy().shape, data_lib.CIFAR_SHAPE)
        return

    # Building with a static corruption must be rejected for these splits.
    with self.assertRaises(ValueError):
        data_lib.build_dataset(data_config)
def test_static_imagenet_c(self, split):
    """Check batched static 'pixelate' corruption (level 3) for `split`.

    For the 'train' and 'valid' splits, building the dataset is expected to
    raise ValueError. For any other split, the first component of the first
    batch must have shape `BATCHED_IMAGES_SHAPE`, all of its values must lie
    in [0, 1], and at least one value must exceed 1/255. The test body is a
    no-op when the `fake_data` flag is set.
    """
    if flags.FLAGS.fake_data:
        return

    data_config = image_data_utils.DataConfig(
        split,
        corruption_static=True,
        corruption_level=3,
        corruption_type='pixelate')

    if split not in ('train', 'valid'):
        batches = data_lib.build_dataset(data_config, BATCH_SIZE)
        first_batch = next(iter(batches))
        pixels = first_batch[0].numpy()

        self.assertEqual(pixels.shape, BATCHED_IMAGES_SHAPE)
        self.assertAllInRange(pixels, 0.0, 1.0)
        # At least one value must be above 1/255.
        above_threshold = pixels > 1.0 / 255
        self.assertTrue(above_threshold.any())
        return

    # Building with a static corruption must be rejected for these splits.
    with self.assertRaises(ValueError):
        data_lib.build_dataset(data_config, BATCH_SIZE)
def test_roll_pixels(self, split):
    """Check the element shape of a dataset built with roll_pixels=5.

    Asserts that the first component of the first element has shape
    `data_lib.CIFAR_SHAPE`. The config is always constructed; the dataset
    check is skipped when the `fake_data` flag is set.
    """
    # Built before the flag check so config construction is always exercised.
    data_config = image_data_utils.DataConfig(split, roll_pixels=5)
    if flags.FLAGS.fake_data:
        return

    first_element = next(iter(data_lib.build_dataset(data_config)))
    self.assertEqual(first_element[0].numpy().shape, data_lib.CIFAR_SHAPE)
def test_fake_data(self, split):
    """Check the element shape of a dataset built with fake_data=True.

    Asserts that the first component of the first element has shape
    `data_lib.CIFAR_SHAPE`.
    """
    # The config is ignored when fake data is requested.
    data_config = image_data_utils.DataConfig(split)
    fake_dataset = data_lib.build_dataset(data_config, fake_data=True)
    first_element = next(iter(fake_dataset))
    self.assertEqual(first_element[0].numpy().shape, data_lib.CIFAR_SHAPE)