def test_register_and_transformlist_ops(self):
    """
    Test register_type together with TransformList arithmetic.

    This test was originally also named ``test_register``. Because another
    method with that name is defined later in the class body, it was shadowed
    and never collected. It has been renamed so that it actually runs.
    """
    dtype = "int"

    def add1(t, x):
        return x + 1

    def flip_sub_width(t, x):
        return x - t.width

    # Register a fallback on the base class, then a specific override on
    # HFlipTransform.
    T.Transform.register_type(dtype, add1)
    T.HFlipTransform.register_type(dtype, flip_sub_width)

    transforms = T.TransformList([
        T.ScaleTransform(0, 0, 0, 0, 0),
        T.CropTransform(0, 0, 0, 0),
        T.HFlipTransform(3),
    ])
    # Expected: 3 -> +1 (scale) -> +1 (crop) -> -width 3 (hflip) == 2.
    self.assertEqual(transforms.apply_int(3), 2)

    # Testing __add__, __iadd__, __radd__, __len__.
    transforms = transforms + transforms
    self.assertEqual(len(transforms), 6)
    transforms += transforms
    self.assertEqual(len(transforms), 12)
    transforms = T.NoOpTransform() + transforms
    self.assertEqual(len(transforms), 13)

    # A one-argument function is rejected. Presumably register_type requires
    # the (transform, data) signature -- confirm against its implementation.
    with self.assertRaises(AssertionError):
        T.HFlipTransform.register_type(dtype, lambda x: 1)
    # Unknown attributes must raise instead of being silently resolved.
    with self.assertRaises(AttributeError):
        _ = transforms.no_existing
def test_print_transform(self):
    """Check the ``str()`` output of single, list, and array-holding transforms."""
    flip = T.HFlipTransform(width=100)
    self.assertEqual(str(flip), "HFlipTransform(width=100)")

    wrapped = T.TransformList([T.NoOpTransform(), flip])
    expected_list_repr = f"TransformList[NoOpTransform(), {flip}]"
    self.assertEqual(str(wrapped), expected_list_repr)

    # The image array argument is expected to be rendered as "..." rather
    # than printed in full.
    blend = T.BlendTransform(np.zeros((100, 100, 100)), 1.0, 1.0)
    self.assertEqual(
        str(blend),
        "BlendTransform(src_image=..., src_weight=1.0, dst_weight=1.0)",
    )
def test_register_with_decorator(self):
    """
    Check that register_type can also be used as a decorator factory.
    """
    kind = "float"

    @T.HFlipTransform.register_type(kind)
    def plus_one(t, x):
        return x + 1

    flips = T.TransformList([T.HFlipTransform(3)])
    outcome = flips.apply_float(3)
    self.assertEqual(outcome, 4)
def test_register(self):
    """
    Check register_type on the base Transform and an HFlipTransform override.
    """
    dtype = "int"

    def increment(t, x):
        return x + 1

    def subtract_width(t, x):
        return x - t.width

    T.Transform.register_type(dtype, increment)
    T.HFlipTransform.register_type(dtype, subtract_width)

    steps = [
        T.ScaleTransform(0, 0, 0, 0, 0),
        T.CropTransform(0, 0, 0, 0),
        T.HFlipTransform(3),
    ]
    pipeline = T.TransformList(steps)
    result = pipeline.apply_int(3)
    self.assertEqual(result, 2)

    # Registering a one-argument callable must be rejected.
    with self.assertRaises(AssertionError):
        T.HFlipTransform.register_type(dtype, lambda x: 1)
def test_transformlist_flatten(self):
    """
    Check that a TransformList built from a nested TransformList is flattened.

    Besides the element count, this verifies that the flattened list holds the
    original transform objects in their original order, so a reordering or
    re-wrapping bug cannot slip past a count-only check.
    """
    t0 = T.HFlipTransform(width=100)
    t1 = T.ScaleTransform(3, 4, 5, 6)
    t2 = T.CropTransform(4, 5, 6, 7)
    t = T.TransformList([T.TransformList([t0, t1]), t2])
    self.assertEqual(len(t.transforms), 3)
    for actual, expected in zip(t.transforms, [t0, t1, t2]):
        self.assertIs(actual, expected)