Exemplo n.º 1
0
    def test_custom_policy(self):
        """A user-supplied policy should augment an image without reshaping it.

        Builds AutoAugment from ``self._generate_test_policy()``, distorts an
        all-zero uint8 image, and checks the output keeps the input shape.
        """
        expected_shape = (224, 224, 3)
        blank = tf.zeros(expected_shape, dtype=tf.uint8)
        test_policy = self._generate_test_policy()
        auto_augment = augment.AutoAugment(policies=test_policy)

        distorted = auto_augment.distort(blank)
        self.assertEqual(expected_shape, distorted.shape)
Exemplo n.º 2
0
    def test_invalid_custom_policy_shape(self):
        """Constructing AutoAugment with 4-field ops must raise ValueError.

        Each op tuple below carries four entries, giving a (2, 2, 4) policy;
        the constructor is expected to reject anything not shaped (:, :, 3).
        """
        first_sub_policy = [('Equalize', 0.8, 1, 1), ('Shear', 0.8, 4, 1)]
        second_sub_policy = [('TranslateY', 0.6, 3, 1), ('Rotate', 0.9, 3, 1)]
        bad_policy = [first_sub_policy, second_sub_policy]
        expected_message = r'Expected \(:, :, 3\) but got \(2, 2, 4\)'

        with self.assertRaisesRegex(ValueError, expected_message):
            augment.AutoAugment(policies=bad_policy)
Exemplo n.º 3
0
    def test_invalid_custom_policy_key(self):
        """An unknown op name in a custom policy raises KeyError on distort.

        Construction sits outside the assertion context, so it is expected to
        succeed; the bogus 'AAAAA' name must only fail when the policy is
        applied to an image.
        """
        blank = tf.zeros((224, 224, 3), dtype=tf.uint8)
        unknown_op_policy = [
            [('AAAAA', 0.8, 1), ('Shear', 0.8, 4)],
            [('TranslateY', 0.6, 3), ('Rotate', 0.9, 3)],
        ]
        auto_augment = augment.AutoAugment(policies=unknown_op_policy)

        with self.assertRaisesRegex(KeyError, '\'AAAAA\''):
            auto_augment.distort(blank)
Exemplo n.º 4
0
    def test_autoaugment(self):
        """Smoke test to be sure there are no syntax errors.

        Builds AutoAugment for every name in ``self.AVAILABLE_POLICIES``,
        distorts an all-zero uint8 image, and checks the output shape matches
        the input. Each policy runs in its own ``subTest`` so a failure names
        the offending policy and the remaining policies still run.
        """
        image = tf.zeros((224, 224, 3), dtype=tf.uint8)

        for policy in self.AVAILABLE_POLICIES:
            # Without subTest, the first failing policy aborts the loop and the
            # report does not say which policy broke.
            with self.subTest(policy=policy):
                augmenter = augment.AutoAugment(augmentation_name=policy)
                aug_image = augmenter.distort(image)

                self.assertEqual((224, 224, 3), aug_image.shape)
Exemplo n.º 5
0
    def test_invalid_custom_policy_ndim(self):
        """Test autoaugment with wrong dimension in the custom policy.

        Wrapping a (2, 2, 3) policy in two extra list levels yields a
        (1, 1, 2, 2, 3) policy, which the constructor must reject with a
        ValueError.
        """
        policy = [[('Equalize', 0.8, 1), ('Shear', 0.8, 4)],
                  [('TranslateY', 0.6, 3), ('Rotate', 0.9, 3)]]
        policy = [[policy]]

        with self.assertRaisesRegex(
                ValueError,
                # The trailing period is escaped so it matches a literal '.'
                # rather than any character.
                r'Expected \(:, :, 3\) but got \(1, 1, 2, 2, 3\)\.'):
            augment.AutoAugment(policies=policy)
Exemplo n.º 6
0
    def test_invalid_custom_sub_policy(self, sub_policy, value):
        """Out-of-range values in a custom policy trip a runtime assertion.

        The first entry of the generated test policy is overwritten with
        ``sub_policy``; distorting an image must then raise
        ``tf.errors.InvalidArgumentError``.

        Args:
          sub_policy: Entry written to ``policy[0][0]`` of the test policy.
          value: Text expected after 'Summarized data: ' in the error message.
            It is interpolated into the regex without escaping.
        """
        blank = tf.zeros((224, 224, 3), dtype=tf.uint8)
        patched_policy = self._generate_test_policy()
        patched_policy[0][0] = sub_policy
        auto_augment = augment.AutoAugment(policies=patched_policy)

        expected_pattern = (
            r'Expected \'tf.Tensor\(False, shape=\(\), dtype=bool\)\' to be true. '
            r'Summarized data: ({})'.format(value))
        with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                    expected_pattern):
            auto_augment.distort(blank)