def make_augmentation_config(data_config, num_labels):
    """Builds the augmentation config described by `data_config`.

    Args:
      data_config: object exposing the randomized-augmentation settings
        (`rotation_probability`, `smooth_probability`, `contrast_probability`,
        `resize_probability`, `negate_probability`, `roll_probability`,
        `angle_range`, `rotate_by_90`) and the `per_label_augmentation` flag.
      num_labels: number of child configs to create when
        `per_label_augmentation` is set; unused otherwise.

    Returns:
      A `datasets.AugmentationConfig`. When `per_label_augmentation` is set it
      holds `num_labels` children, all built (together with the parent) under
      the 'augmentations' variable scope and all sharing the same
      `RandomizedAugmentationConfig` instance; otherwise it is a single config
      wrapping that randomized config.
    """
    forwarded_settings = (
        'rotation_probability',
        'smooth_probability',
        'contrast_probability',
        'resize_probability',
        'negate_probability',
        'roll_probability',
        'angle_range',
        'rotate_by_90',
    )
    settings = {name: getattr(data_config, name) for name in forwarded_settings}
    shared_random = datasets.RandomizedAugmentationConfig(**settings)

    if not data_config.per_label_augmentation:
        return datasets.AugmentationConfig(random_config=shared_random)

    # NOTE(review): presumably scoped so the per-label configs' TF variables
    # are grouped under 'augmentations' -- confirm before changing; renaming
    # the scope would change variable names.
    with tf.variable_scope('augmentations'):
        children = [
            datasets.AugmentationConfig(random_config=shared_random)
            for _ in range(num_labels)
        ]
        return datasets.AugmentationConfig(children=children)
def test_full_augmentations(self):
    """Testing image processing with all randomized augmentations enabled.

    Every augmentation probability accepted by
    `datasets.RandomizedAugmentationConfig` (rotation, smoothing, contrast,
    resize, negation and roll) is forced to 1.0, so each augmentation is
    applied and exercised by `_test_augmentations`.
    """
    aug_config = datasets.AugmentationConfig(
        random_config=datasets.RandomizedAugmentationConfig(
            rotation_probability=1.0,
            smooth_probability=1.0,
            contrast_probability=1.0,
            resize_probability=1.0,
            negate_probability=1.0,
            # Roll is also a randomized augmentation; without this it stays
            # at its default and the test would not cover "all" of them.
            roll_probability=1.0))
    self._test_augmentations(aug_config)
def test_augmentation_config_randomization(self):
    """Checks that the sampled angle changes only when re-randomized.

    After one run of the randomization op, two consecutive reads of the angle
    must agree. After a second run of the op, the angle must differ from the
    first read.
    """
    config = datasets.AugmentationConfig(
        random_config=datasets.RandomizedAugmentationConfig())
    randomize = config.randomize_op()
    angle_tensor = config.angle.value
    with self.session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(randomize)
        first_read, second_read = sess.run(angle_tensor), sess.run(angle_tensor)
        sess.run(randomize)
        after_rerandomize = sess.run(angle_tensor)
        self.assertAlmostEqual(first_read, second_read)
        self.assertNotAlmostEqual(first_read, after_rerandomize)
def test_get_batch(self):
    """Checks the shapes of the tensors returned by `TaskGenerator.get_batch`.

    Rotation, smoothing and contrast augmentations are disabled. The evaluated
    images must be (batch_size, image_size, image_size, 1), and labels and
    classes must both be vectors of length batch_size.
    """
    batch_size = 8
    image_size = 4
    data = self._make_data(batch_size=batch_size, image_size=image_size)
    generator = datasets.TaskGenerator(data, num_labels=4, image_size=image_size)
    disabled_random = datasets.RandomizedAugmentationConfig(
        rotation_probability=0.0,
        smooth_probability=0.0,
        contrast_probability=0.0)
    config = datasets.AugmentationConfig(random_config=disabled_random)
    img_t, lbl_t, cls_t = generator.get_batch(batch_size=batch_size,
                                              config=config)
    with self.session() as sess:
        sess.run(tf.global_variables_initializer())
        out_images, out_labels, out_classes = sess.run((img_t, lbl_t, cls_t))
        self.assertEqual(out_images.shape,
                         (batch_size, image_size, image_size, 1))
        for vector in (out_labels, out_classes):
            self.assertEqual(vector.shape, (batch_size,))