def test_register(self):
    """
    Register ``int`` handlers on ``Transform`` and ``HFlipTransform``.

    Checks three things:
    - how a ``TransformList`` applies the handlers;
    - how lists combine with ``+`` / ``+=`` / ``len``;
    - which errors are raised for invalid usage.
    """
    type_name = "int"

    def add1(t, x):
        # Handler registered on the Transform base class.
        return x + 1

    def flip_sub_width(t, x):
        # HFlipTransform-specific handler; reads the transform's ``width``.
        return x - t.width

    # Same registration order as before: base class first, then HFlip.
    for owner, handler in (
        (T.Transform, add1),
        (T.HFlipTransform, flip_sub_width),
    ):
        owner.register_type(type_name, handler)

    pipeline = T.TransformList(
        [
            T.ScaleTransform(0, 0, 0, 0, 0),
            T.CropTransform(0, 0, 0, 0),
            T.HFlipTransform(3),
        ]
    )
    # Expected 2: presumably add1 runs for the scale and crop steps
    # (3 -> 4 -> 5), then flip_sub_width subtracts the flip's width of 3.
    self.assertEqual(pipeline.apply_int(3), 2)

    # Exercise __add__, __iadd__, __radd__ and __len__: doubling twice gives
    # 12 transforms, and prepending a NoOpTransform makes 13.
    pipeline = pipeline + pipeline
    pipeline += pipeline
    pipeline = T.NoOpTransform() + pipeline
    self.assertEqual(len(pipeline), 13)

    # A one-argument handler is expected to be rejected with AssertionError
    # (the accepted handlers above take two arguments).
    self.assertRaises(
        AssertionError, T.HFlipTransform.register_type, type_name, lambda x: 1
    )
    # Looking up an attribute that does not exist must raise AttributeError.
    self.assertRaises(AttributeError, getattr, pipeline, "no_existing")
def test_register_basic(self):
    """
    Check that ``int`` handlers registered with ``register_type`` are used by
    ``TransformList.apply_int``, and that a one-argument handler is rejected.

    Deliberately not named ``test_register``. Another ``test_register`` is
    defined earlier in this file. Redefining the name here rebinds it, so
    the earlier, broader test would never be collected or run
    (pyflakes/ruff F811).
    """
    dtype = "int"

    def add1(t, x):
        # Handler registered on the Transform base class.
        return x + 1

    def flip_sub_width(t, x):
        # HFlipTransform-specific handler; reads the transform's ``width``.
        return x - t.width

    T.Transform.register_type(dtype, add1)
    T.HFlipTransform.register_type(dtype, flip_sub_width)

    transforms = T.TransformList(
        [
            T.ScaleTransform(0, 0, 0, 0, 0),
            T.CropTransform(0, 0, 0, 0),
            T.HFlipTransform(3),
        ]
    )
    # Expected 2: presumably add1 runs for the scale and crop steps
    # (3 -> 4 -> 5), then flip_sub_width subtracts the flip's width of 3.
    self.assertEqual(transforms.apply_int(3), 2)

    # A one-argument handler is expected to be rejected with AssertionError.
    with self.assertRaises(AssertionError):
        T.HFlipTransform.register_type(dtype, lambda x: 1)
def test_transformlist_flatten(self):
    """
    A TransformList nested inside another TransformList is flattened.

    The outer list is expected to hold three transforms rather than two
    entries (one inner list plus one crop).
    """
    hflip = T.HFlipTransform(width=100)
    scale = T.ScaleTransform(3, 4, 5, 6)
    crop = T.CropTransform(4, 5, 6, 7)

    inner = T.TransformList([hflip, scale])
    outer = T.TransformList([inner, crop])

    self.assertEqual(len(outer.transforms), 3)