예제 #1
0
    def test_register(self):
        """
        Test register.

        Registers an ``apply_int`` handler on the base ``Transform`` class and a
        different one on ``HFlipTransform``, then checks that a TransformList
        dispatches each element to the matching handler. Also covers
        TransformList concatenation/length and the error paths of
        ``register_type`` and attribute lookup.
        """
        dtype = "int"
        attr = "apply_" + dtype

        def restore(cls, previous):
            # Undo the registration made below so it cannot leak into other
            # tests: put back whatever the class itself defined beforehand, or
            # drop the attribute if the class defined nothing.
            if previous is not None:
                setattr(cls, attr, previous)
            elif attr in vars(cls):
                delattr(cls, attr)

        # register_type is invoked on the classes themselves, so the handlers
        # outlive this test unless they are explicitly removed afterwards.
        for cls in (T.Transform, T.HFlipTransform):
            self.addCleanup(restore, cls, vars(cls).get(attr))

        def add1(t, x):
            return x + 1

        def flip_sub_width(t, x):
            return x - t.width

        T.Transform.register_type(dtype, add1)
        T.HFlipTransform.register_type(dtype, flip_sub_width)

        transforms = T.TransformList([
            T.ScaleTransform(0, 0, 0, 0, 0),
            T.CropTransform(0, 0, 0, 0),
            T.HFlipTransform(3),
        ])
        # Scale and Crop fall back to add1 (3 -> 4 -> 5); HFlip uses its own
        # handler and subtracts its width of 3 (5 -> 2).
        self.assertEqual(transforms.apply_int(3), 2)

        # Testing __add__, __iadd__, __radd__, __len__.
        # Lengths go 3 -> 6 -> 12 -> 13.
        transforms = transforms + transforms
        transforms += transforms
        transforms = T.NoOpTransform() + transforms
        self.assertEqual(len(transforms), 13)

        # A one-argument callable is rejected (the valid handlers above take
        # the transform and the value).
        with self.assertRaises(AssertionError):
            T.HFlipTransform.register_type(dtype, lambda x: 1)

        # Unknown attributes are not resolved on a TransformList.
        with self.assertRaises(AttributeError):
            _ = transforms.no_existing
예제 #2
0
    def test_print_transform(self):
        """
        Test the string representation of transforms.

        Covers a transform with a scalar argument, a TransformList wrapping
        other transforms, and a transform holding an image array whose value
        is elided as ``...`` in the output.
        """
        flip = T.HFlipTransform(width=100)
        self.assertEqual(str(flip), "HFlipTransform(width=100)")

        composed = T.TransformList([T.NoOpTransform(), flip])
        self.assertEqual(
            str(composed),
            f"TransformList[NoOpTransform(), {flip}]",
        )

        image = np.zeros((100, 100, 100))
        blend = T.BlendTransform(image, 1.0, 1.0)
        self.assertEqual(
            str(blend),
            "BlendTransform(src_image=..., src_weight=1.0, dst_weight=1.0)",
        )
예제 #3
0
    def test_register_with_decorator(self):
        """
        Test register using decorator.

        ``register_type(dtype)`` called without a function is used as a
        decorator; the decorated function becomes ``HFlipTransform``'s
        ``apply_float`` handler, which a TransformList then dispatches to.
        """
        dtype = "float"
        attr = "apply_" + dtype
        previous = vars(T.HFlipTransform).get(attr)

        def restore():
            # The registration below is made on the class, so it would
            # otherwise outlive this test; restore the class's prior state.
            if previous is not None:
                setattr(T.HFlipTransform, attr, previous)
            elif attr in vars(T.HFlipTransform):
                delattr(T.HFlipTransform, attr)

        self.addCleanup(restore)

        @T.HFlipTransform.register_type(dtype)
        def add1(t, x):
            return x + 1

        transforms = T.TransformList([T.HFlipTransform(3)])
        self.assertEqual(transforms.apply_float(3), 4)
예제 #4
0
    def test_register(self):
        """
        Test register.

        Registers an ``apply_int`` handler on the base ``Transform`` class and a
        different one on ``HFlipTransform``, then checks that a TransformList
        dispatches each element to the matching handler and that an invalid
        handler is rejected.
        """
        dtype = "int"
        attr = "apply_" + dtype

        def restore(cls, previous):
            # Undo the registration made below so it cannot leak into other
            # tests: put back whatever the class itself defined beforehand, or
            # drop the attribute if the class defined nothing.
            if previous is not None:
                setattr(cls, attr, previous)
            elif attr in vars(cls):
                delattr(cls, attr)

        # register_type is invoked on the classes themselves, so the handlers
        # outlive this test unless they are explicitly removed afterwards.
        for cls in (T.Transform, T.HFlipTransform):
            self.addCleanup(restore, cls, vars(cls).get(attr))

        def add1(t, x):
            return x + 1

        def flip_sub_width(t, x):
            return x - t.width

        T.Transform.register_type(dtype, add1)
        T.HFlipTransform.register_type(dtype, flip_sub_width)

        transforms = T.TransformList([
            T.ScaleTransform(0, 0, 0, 0, 0),
            T.CropTransform(0, 0, 0, 0),
            T.HFlipTransform(3),
        ])
        # Scale and Crop fall back to add1 (3 -> 4 -> 5); HFlip uses its own
        # handler and subtracts its width of 3 (5 -> 2).
        self.assertEqual(transforms.apply_int(3), 2)

        # A one-argument callable is rejected (the valid handlers above take
        # the transform and the value).
        with self.assertRaises(AssertionError):
            T.HFlipTransform.register_type(dtype, lambda x: 1)
예제 #5
0
 def test_transformlist_flatten(self):
     """
     Test that a TransformList nested inside another is flattened: two inner
     transforms plus one outer transform yield three entries.
     """
     hflip = T.HFlipTransform(width=100)
     scale = T.ScaleTransform(3, 4, 5, 6)
     crop = T.CropTransform(4, 5, 6, 7)
     inner = T.TransformList([hflip, scale])
     outer = T.TransformList([inner, crop])
     self.assertEqual(len(outer.transforms), 3)