コード例 #1
0
 def test_value_cifar_c(self, split):
   """Checks a brightness-corrupted CIFAR-C dataset built via corruption_value.

   The corruption is selected with `corruption_value=.25` (instead of a
   `corruption_level`), and the first dataset element's image must have
   `data_lib.CIFAR_SHAPE`.

   This test needs real data. When --fake_data is set it is reported as
   skipped rather than silently passing with no assertions executed.

   Args:
     split: name of the dataset split to build (presumably supplied by a
       parameterization decorator not visible here).
   """
   if flags.FLAGS.fake_data:
     # An empty pass would hide that nothing was checked; make it visible.
     self.skipTest('Requires real CIFAR-C data; --fake_data is set.')
   config = image_data_utils.DataConfig(
       split, corruption_value=.25, corruption_type='brightness')
   dataset = data_lib.build_dataset(config)
   # Element 0 of each dataset item is treated as the image tensor.
   image_shape = next(iter(dataset))[0].numpy().shape
   self.assertEqual(image_shape, data_lib.CIFAR_SHAPE)
コード例 #2
0
 def test_array_cifar_c(self, split):
   """Checks a glass-blur CIFAR-C dataset built via corruption_level.

   The corruption is selected with `corruption_level=4` and
   `corruption_type='glass_blur'`. The first dataset element's image must
   have `data_lib.CIFAR_SHAPE`.

   This test needs real data. When --fake_data is set it is reported as
   skipped rather than silently passing with no assertions executed.

   Args:
     split: name of the dataset split to build (presumably supplied by a
       parameterization decorator not visible here).
   """
   if flags.FLAGS.fake_data:
     # An empty pass would hide that nothing was checked; make it visible.
     self.skipTest('Requires real CIFAR-C data; --fake_data is set.')
   config = image_data_utils.DataConfig(
       split, corruption_level=4, corruption_type='glass_blur')
   dataset = data_lib.build_dataset(config)
   # Element 0 of each dataset item is treated as the image tensor.
   image_shape = next(iter(dataset))[0].numpy().shape
   self.assertEqual(image_shape, data_lib.CIFAR_SHAPE)
コード例 #3
0
 def test_roll_pixels(self, split):
   """Checks a batched dataset built with `roll_pixels=5`.

   The first batch must:
   - have shape `BATCHED_IMAGES_SHAPE`;
   - have every pixel value in the range 0.-1.;
   - contain at least one pixel brighter than 1/255, so the batch is not
     essentially blank.

   This test needs real data. When --fake_data is set it is reported as
   skipped rather than silently passing with no assertions executed.

   Args:
     split: name of the dataset split to build (presumably supplied by a
       parameterization decorator not visible here).
   """
   if flags.FLAGS.fake_data:
     # An empty pass would hide that nothing was checked; make it visible.
     self.skipTest('Requires real data; --fake_data is set.')
   config = image_data_utils.DataConfig(split, roll_pixels=5)
   dataset = data_lib.build_dataset(config, BATCH_SIZE)
   image = next(iter(dataset))[0].numpy()
   self.assertEqual(image.shape, BATCHED_IMAGES_SHAPE)
   self.assertAllInRange(image, 0., 1.)
   self.assertTrue((image > 1. / 255).any())
コード例 #4
0
 def test_fake_data(self, split):
   """Checks the batch produced by `build_dataset(..., fake_data=True)`.

   The config is ignored when fake data is requested, so only the split is
   passed to it. The first batch must:
   - have shape `BATCHED_IMAGES_SHAPE`;
   - lie in the range 0.-1.;
   - contain at least one pixel above 1/255.
   """
   cfg = image_data_utils.DataConfig(split)
   fake_ds = data_lib.build_dataset(cfg, BATCH_SIZE, fake_data=True)
   first_element = next(iter(fake_ds))
   batch = first_element[0].numpy()
   # Smallest non-zero 8-bit intensity after scaling to [0, 1].
   lowest_visible = 1. / 255
   self.assertEqual(batch.shape, BATCHED_IMAGES_SHAPE)
   self.assertAllInRange(batch, 0., 1.)
   self.assertTrue((batch > lowest_visible).any())
コード例 #5
0
 def test_alt_dataset(self):
   """Checks the 'test' split of the alternate dataset 'celeb_a'.

   The first batch must:
   - have shape `BATCHED_IMAGES_SHAPE`;
   - lie in the range 0.-1.;
   - contain at least one pixel above 1/255.

   This test needs real data. When --fake_data is set it is reported as
   skipped rather than silently passing with no assertions executed.
   """
   if flags.FLAGS.fake_data:
     # An empty pass would hide that nothing was checked; make it visible.
     self.skipTest('Requires real celeb_a data; --fake_data is set.')
   config = image_data_utils.DataConfig('test', alt_dataset_name='celeb_a')
   dataset = data_lib.build_dataset(config, BATCH_SIZE)
   image = next(iter(dataset))[0].numpy()
   self.assertEqual(image.shape, BATCHED_IMAGES_SHAPE)
   self.assertAllInRange(image, 0., 1.)
   self.assertTrue((image > 1. / 255).any())
コード例 #6
0
 def test_value_imagenet_c(self, split):
   """Checks a brightness-corrupted batched dataset built via corruption_value.

   The corruption is selected with `corruption_value=.25`. The first batch
   must:
   - have shape `BATCHED_IMAGES_SHAPE`;
   - lie in the range 0.-1.;
   - contain at least one pixel above 1/255.

   This test needs real data. When --fake_data is set it is reported as
   skipped rather than silently passing with no assertions executed.

   Args:
     split: name of the dataset split to build (presumably supplied by a
       parameterization decorator not visible here).
   """
   if flags.FLAGS.fake_data:
     # An empty pass would hide that nothing was checked; make it visible.
     self.skipTest('Requires real ImageNet-C data; --fake_data is set.')
   config = image_data_utils.DataConfig(
       split, corruption_value=.25, corruption_type='brightness')
   dataset = data_lib.build_dataset(config, BATCH_SIZE)
   image = next(iter(dataset))[0].numpy()
   self.assertEqual(image.shape, BATCHED_IMAGES_SHAPE)
   self.assertAllInRange(image, 0., 1.)
   self.assertTrue((image > 1. / 255).any())
コード例 #7
0
 def test_static_cifar_c(self, split):
   """Checks static pixelate corruption (level 3) on CIFAR-C splits.

   For the 'train' and 'valid' splits, `build_dataset` must raise ValueError.
   For any other split, the first dataset element's image must have
   `data_lib.CIFAR_SHAPE`.

   This test needs real data. When --fake_data is set it is reported as
   skipped rather than silently passing with no assertions executed.

   Args:
     split: name of the dataset split to build (presumably supplied by a
       parameterization decorator not visible here).
   """
   if flags.FLAGS.fake_data:
     # An empty pass would hide that nothing was checked; make it visible.
     self.skipTest('Requires real CIFAR-C data; --fake_data is set.')
   config = image_data_utils.DataConfig(
       split, corruption_static=True, corruption_level=3,
       corruption_type='pixelate')
   if split in ('train', 'valid'):
     # Static corruption is rejected for these splits.
     with self.assertRaises(ValueError):
       data_lib.build_dataset(config)
     return
   dataset = data_lib.build_dataset(config)
   image_shape = next(iter(dataset))[0].numpy().shape
   self.assertEqual(image_shape, data_lib.CIFAR_SHAPE)
コード例 #8
0
 def test_static_imagenet_c(self, split):
   """Checks static pixelate corruption (level 3) on batched dataset splits.

   For the 'train' and 'valid' splits, `build_dataset` must raise ValueError.
   For any other split, the first batch must:
   - have shape `BATCHED_IMAGES_SHAPE`;
   - lie in the range 0.-1.;
   - contain at least one pixel above 1/255.

   This test needs real data. When --fake_data is set it is reported as
   skipped rather than silently passing with no assertions executed.

   Args:
     split: name of the dataset split to build (presumably supplied by a
       parameterization decorator not visible here).
   """
   if flags.FLAGS.fake_data:
     # An empty pass would hide that nothing was checked; make it visible.
     self.skipTest('Requires real ImageNet-C data; --fake_data is set.')
   config = image_data_utils.DataConfig(
       split, corruption_static=True, corruption_level=3,
       corruption_type='pixelate')
   if split in ('train', 'valid'):
     # Static corruption is rejected for these splits.
     with self.assertRaises(ValueError):
       data_lib.build_dataset(config, BATCH_SIZE)
     return
   dataset = data_lib.build_dataset(config, BATCH_SIZE)
   image = next(iter(dataset))[0].numpy()
   self.assertEqual(image.shape, BATCHED_IMAGES_SHAPE)
   self.assertAllInRange(image, 0., 1.)
   self.assertTrue((image > 1. / 255).any())
コード例 #9
0
 def test_roll_pixels(self, split):
   """Checks that a dataset built with `roll_pixels=5` yields CIFAR-shaped images.

   This test needs real data. When --fake_data is set it is reported as
   skipped rather than silently passing with no assertions executed.

   Args:
     split: name of the dataset split to build (presumably supplied by a
       parameterization decorator not visible here).
   """
   if flags.FLAGS.fake_data:
     # An empty pass would hide that nothing was checked; make it visible.
     self.skipTest('Requires real CIFAR data; --fake_data is set.')
   config = image_data_utils.DataConfig(split, roll_pixels=5)
   dataset = data_lib.build_dataset(config)
   image_shape = next(iter(dataset))[0].numpy().shape
   self.assertEqual(image_shape, data_lib.CIFAR_SHAPE)
コード例 #10
0
 def test_fake_data(self, split):
   """Checks that `build_dataset(..., fake_data=True)` yields CIFAR-shaped images.

   The config is ignored when fake data is requested, so only the split is
   passed to it.
   """
   cfg = image_data_utils.DataConfig(split)
   fake_ds = data_lib.build_dataset(cfg, fake_data=True)
   first_element = next(iter(fake_ds))
   observed_shape = first_element[0].numpy().shape
   self.assertEqual(observed_shape, data_lib.CIFAR_SHAPE)