Ejemplo n.º 1
0
    def test_device_dtype_change(self):
        """Check that ``Compose.to(dtype)`` converts buffers of every transform.

        Two transforms each register a float32 buffer ``tmp``. After casting
        the surrounding ``Compose`` to float64, the buffer of each contained
        transform must be float64. Only the dtype conversion is exercised
        here; no device move is performed.
        """
        class DummyTrafo(AbstractTransform):
            def __init__(self, a):
                super().__init__(False)
                # Registered as a buffer so that ``.to(dtype)`` casts it along
                # with the module.
                self.register_buffer("tmp", a)

            def __call__(self, *args, **kwargs):
                return self.tmp

        trafo_a = DummyTrafo(torch.tensor([1.0], dtype=torch.float32))
        trafo_a = trafo_a.to(torch.float32)
        trafo_b = DummyTrafo(torch.tensor([2.0], dtype=torch.float32))
        trafo_b = trafo_b.to(torch.float32)
        self.assertEqual(trafo_a.tmp.dtype, torch.float32)
        self.assertEqual(trafo_b.tmp.dtype, torch.float32)

        compose = Compose(trafo_a, trafo_b)
        compose = compose.to(torch.float64)

        # Every contained transform must be converted, not just the first one.
        self.assertEqual(len(compose.transforms), 2)
        for idx, trafo in enumerate(compose.transforms):
            with self.subTest(transform_index=idx):
                self.assertEqual(trafo.tmp.dtype, torch.float64)
Ejemplo n.º 2
0
    def test_wrapping_non_module_trafos(self):
        """Check that a plain (non-module) callable gets wrapped by ``Compose``.

        The stored entry must be a ``_TransformWrapper`` whose ``trafo``
        attribute holds an instance of the original callable class.
        """
        class PlainCallable:
            def __init__(self):
                self.a = 5

            def __call__(self, *args, **kwargs):
                return 5

        compose = Compose([PlainCallable()])
        wrapped = compose.transforms[0]

        self.assertIsInstance(wrapped, _TransformWrapper)
        self.assertIsInstance(wrapped.trafo, PlainCallable)
Ejemplo n.º 3
0
    def test_compose_shuffle(self):
        """Check that ``shuffle=True`` yields the seeded random permutation.

        The test assumes ``Compose`` draws its execution order from the global
        ``random`` module. Reseeding it identically before shuffling a
        reference list must then reproduce ``compose.transform_order``.
        """
        # Restore the global RNG afterwards so the fixed seed does not leak
        # into other tests.
        self.addCleanup(random.setstate, random.getstate())

        compose = Compose([Mirror(dims=(0, ))] * 10, shuffle=True)

        random.seed(0)
        # Only the side effect of the call (drawing the order) matters here.
        compose(**self.batch)

        order = list(range(len(compose.transforms)))
        expected_order = list(order)
        random.seed(0)
        random.shuffle(expected_order)

        # Guard against a vacuous pass: the seeded permutation must differ
        # from the identity, or the equality check below proves nothing.
        self.assertNotEqual(expected_order, order)
        self.assertEqual(compose.transform_order, expected_order)
Ejemplo n.º 4
0
 def test_compose_multiple_tuple(self):
     """Check that ``Compose`` accepts the fixture transforms as a tuple.

     The output data must equal the input data. Presumably this holds because
     the fixture transforms cancel each other out; TODO confirm against setUp.
     Every transform passed in must also be scheduled for execution.
     """
     compose = Compose(tuple(self.transforms))
     outp = compose(**self.batch)
     self.assertTrue((self.batch["data"] == outp["data"]).all())
     # A tuple input must register every transform, just like a list input.
     self.assertEqual(len(compose.transform_order), len(self.transforms))
Ejemplo n.º 5
0
 def test_compose_multiple(self):
     """Run a ``Compose`` built from the fixture list and check its result.

     The composed output data must equal the input data. The execution order
     must contain exactly two entries.
     """
     composed = Compose(self.transforms)
     result = composed(**self.batch)

     data_unchanged = (self.batch["data"] == result["data"]).all()
     self.assertTrue(data_unchanged)

     num_scheduled = len(composed.transform_order)
     self.assertEqual(num_scheduled, 2)
Ejemplo n.º 6
0
 def test_compose_single(self):
     """Check that ``Compose`` accepts a bare transform (not in a list).

     The composed result must match applying ``Mirror(dims=(0, ))`` directly
     to the same batch.
     """
     composed = Compose(self.transforms[0])
     result = composed(**self.batch)
     reference = Mirror(dims=(0, ))(**self.batch)
     matches = (reference["data"] == result["data"]).all()
     self.assertTrue(matches)