예제 #1
0
    def test_all_policy_ops_video_with_bboxes(self):
        """Smoke-test every registered op on a 4-D image tensor with boxes.

        Walks `augment.NAME_TO_FUNC` and resolves each op through
        `augment._parse_policy_info`. Ops named in `ops_expected_to_raise`
        must raise `ValueError` for this input; every other op is applied and
        its output becomes the input of the next op. At the end, the image and
        box tensors must still have their original shapes.
        """
        clip_shape = (2, 224, 224, 3)
        boxes_shape = (2, 2, 4)

        # Positional tail for _parse_policy_info, in order:
        # prob, magnitude, replace_value, cutout_const, translate_const.
        policy_settings = (1, 10, [128] * 3, 100, 250)

        # Ops that are required to reject this 4-D input with a ValueError.
        ops_expected_to_raise = frozenset((
            'Rotate_BBox',
            'ShearX_BBox',
            'ShearY_BBox',
            'TranslateX_BBox',
            'TranslateY_BBox',
            'TranslateY_Only_BBoxes',
        ))

        image = tf.ones(clip_shape, dtype=tf.uint8)
        bboxes = tf.ones(boxes_shape, dtype=tf.float32)

        for op_name in augment.NAME_TO_FUNC:
            op_fn, _, op_args = augment._parse_policy_info(
                op_name, *policy_settings)
            if op_name not in ops_expected_to_raise:
                # Chain the result so each op sees the previous op's output.
                image, bboxes = op_fn(image, bboxes, *op_args)
                continue
            self.assertRaises(ValueError, op_fn, image, bboxes, *op_args)

        self.assertEqual(clip_shape, image.shape)
        self.assertEqual(boxes_shape, bboxes.shape)
예제 #2
0
    def test_all_policy_ops_with_bboxes(self):
        """Smoke-test every registered op on a single image with boxes.

        Each op in `augment.NAME_TO_FUNC` is resolved through
        `augment._parse_policy_info` and applied in turn, with the output of
        one op fed into the next. The test passes when no op raises and the
        final image and box tensors keep their original shapes.
        """
        image_shape = (224, 224, 3)
        boxes_shape = (2, 4)

        # Positional tail for _parse_policy_info, in order:
        # prob, magnitude, replace_value, cutout_const, translate_const.
        policy_settings = (1, 10, [128] * 3, 100, 250)

        image = tf.ones(image_shape, dtype=tf.uint8)
        bboxes = tf.ones(boxes_shape, dtype=tf.float32)

        for op_name in augment.NAME_TO_FUNC:
            op_fn, _, op_args = augment._parse_policy_info(
                op_name, *policy_settings)
            # Chain the result so each op sees the previous op's output.
            image, bboxes = op_fn(image, bboxes, *op_args)

        self.assertEqual(image_shape, image.shape)
        self.assertEqual(boxes_shape, bboxes.shape)
예제 #3
0
    def test_all_policy_ops_video(self):
        """Smoke-test the ops that do not need boxes on a 4-D image tensor.

        Only op names in `augment.NAME_TO_FUNC` that are absent from
        `augment.REQUIRE_BOXES_FUNCS` are exercised. Each op is called with
        `bboxes=None` initially, and its output is fed into the next op. The
        test passes when no op raises, the image keeps its shape, and the
        boxes value is still `None` at the end.
        """
        clip_shape = (2, 224, 224, 3)

        # Positional tail for _parse_policy_info, in order:
        # prob, magnitude, replace_value, cutout_const, translate_const.
        policy_settings = (1, 10, [128] * 3, 100, 250)

        image = tf.ones(clip_shape, dtype=tf.uint8)
        bboxes = None

        # Set difference: registered ops minus those that require boxes.
        box_free_ops = augment.NAME_TO_FUNC.keys() - augment.REQUIRE_BOXES_FUNCS
        for op_name in box_free_ops:
            op_fn, _, op_args = augment._parse_policy_info(
                op_name, *policy_settings)
            # Chain the result so each op sees the previous op's output.
            image, bboxes = op_fn(image, bboxes, *op_args)

        self.assertEqual(clip_shape, image.shape)
        self.assertIsNone(bboxes)