示例#1
0
    def test_register(self):
        """
        Test register.

        Registers an ``int`` handler on ``T.Transform`` (add 1) and an override
        on ``T.HFlipTransform`` (subtract ``t.width``). It then checks four things:

        - a ``TransformList`` dispatches ``apply_int`` through each member;
        - list concatenation and ``len`` work;
        - a one-argument handler is rejected;
        - unknown attributes raise ``AttributeError``.
        """
        dtype = "int"

        # The registration is reachable as ``apply_<dtype>`` (see
        # ``transforms.apply_int`` below), so it presumably lives on the
        # classes themselves, i.e. global state shared by every test.
        # Snapshot what each class defines on its own today and restore it
        # afterwards, so this test neither leaks into nor depends on others.
        attr_name = "apply_" + dtype

        def restore(cls, had_attr, old_value):
            # Put back the pre-test value, or drop the attribute this test
            # added. Only delete if it is actually present, so cleanup can
            # never raise.
            if had_attr:
                setattr(cls, attr_name, old_value)
            elif attr_name in vars(cls):
                delattr(cls, attr_name)

        for cls in (T.Transform, T.HFlipTransform):
            own_attrs = vars(cls)
            # Arguments are evaluated now, i.e. before register_type runs.
            self.addCleanup(
                restore, cls, attr_name in own_attrs, own_attrs.get(attr_name)
            )

        def add1(t, x):
            return x + 1

        def flip_sub_width(t, x):
            return x - t.width

        T.Transform.register_type(dtype, add1)
        T.HFlipTransform.register_type(dtype, flip_sub_width)

        transforms = T.TransformList([
            T.ScaleTransform(0, 0, 0, 0, 0),
            T.CropTransform(0, 0, 0, 0),
            T.HFlipTransform(3),
        ])
        # Scale and Crop fall back to add1 (3 -> 4 -> 5); HFlip's override
        # then subtracts its width of 3 (5 -> 2).
        self.assertEqual(transforms.apply_int(3), 2)

        # Testing __add__, __iadd__, __radd__, __len__.
        transforms = transforms + transforms  # 3 + 3 = 6
        transforms += transforms  # 6 + 6 = 12
        transforms = T.NoOpTransform() + transforms  # 1 + 12 = 13
        self.assertEqual(len(transforms), 13)

        # A handler taking only one argument is expected to be rejected.
        with self.assertRaises(AssertionError):
            T.HFlipTransform.register_type(dtype, lambda x: 1)

        # Attributes that were never defined or registered must not resolve.
        with self.assertRaises(AttributeError):
            _ = transforms.no_existing
示例#2
0
    def test_register(self):
        """
        Test register.

        Registers an ``int`` handler on ``T.Transform`` (add 1) and an override
        on ``T.HFlipTransform`` (subtract ``t.width``). It then checks that a
        ``TransformList`` dispatches ``apply_int`` through each member and that
        a one-argument handler is rejected.
        """
        dtype = "int"

        # The registration is reachable as ``apply_<dtype>`` (see
        # ``transforms.apply_int`` below), so it presumably lives on the
        # classes themselves, i.e. global state shared by every test.
        # Snapshot what each class defines on its own today and restore it
        # afterwards, so this test neither leaks into nor depends on others.
        attr_name = "apply_" + dtype

        def restore(cls, had_attr, old_value):
            # Put back the pre-test value, or drop the attribute this test
            # added. Only delete if it is actually present, so cleanup can
            # never raise.
            if had_attr:
                setattr(cls, attr_name, old_value)
            elif attr_name in vars(cls):
                delattr(cls, attr_name)

        for cls in (T.Transform, T.HFlipTransform):
            own_attrs = vars(cls)
            # Arguments are evaluated now, i.e. before register_type runs.
            self.addCleanup(
                restore, cls, attr_name in own_attrs, own_attrs.get(attr_name)
            )

        def add1(t, x):
            return x + 1

        def flip_sub_width(t, x):
            return x - t.width

        T.Transform.register_type(dtype, add1)
        T.HFlipTransform.register_type(dtype, flip_sub_width)

        transforms = T.TransformList([
            T.ScaleTransform(0, 0, 0, 0, 0),
            T.CropTransform(0, 0, 0, 0),
            T.HFlipTransform(3),
        ])
        # Scale and Crop fall back to add1 (3 -> 4 -> 5); HFlip's override
        # then subtracts its width of 3 (5 -> 2).
        self.assertEqual(transforms.apply_int(3), 2)

        # A handler taking only one argument is expected to be rejected.
        with self.assertRaises(AssertionError):
            T.HFlipTransform.register_type(dtype, lambda x: 1)
示例#3
0
 def test_transformlist_flatten(self):
     """Nesting a TransformList inside another should yield a flat list."""
     hflip = T.HFlipTransform(width=100)
     scale = T.ScaleTransform(3, 4, 5, 6)
     crop = T.CropTransform(4, 5, 6, 7)
     inner = T.TransformList([hflip, scale])
     outer = T.TransformList([inner, crop])
     # Flattened: the inner list's two members plus ``crop`` -> three entries,
     # not two (which is what a nested, unflattened list would report).
     self.assertEqual(len(outer.transforms), 3)