Esempio n. 1
0
    def test_rand_affined(self, input_param, input_data, expected_val, track_meta):
        """Check RandAffined outputs, metadata bookkeeping and inversion.

        ``track_meta`` is applied globally with ``set_track_meta``. A cleanup
        restores it to ``True`` (the value the other tests in this suite reset
        to), so a ``track_meta=False`` case, which returns early below, or a
        failing assertion cannot leave metadata tracking disabled for later
        tests.
        """
        set_track_meta(track_meta)
        self.addCleanup(set_track_meta, True)
        g = RandAffined(**input_param).set_random_state(123)
        res = g(input_data)
        if input_param.get("cache_grid", False):
            self.assertIsNotNone(g.rand_affine._cached_grid)
        for key in res:
            # skip bookkeeping entries; only the transformed data is compared
            if isinstance(key, str) and key.endswith("_transforms"):
                continue
            result = res[key]
            if track_meta:
                self.assertIsInstance(result, MetaTensor)
                self.assertEqual(len(result.applied_operations), 1)
            expected = expected_val[key] if isinstance(expected_val, dict) else expected_val
            assert_allclose(result, expected, rtol=_rtol, atol=1e-3, type_test=False)

        # re-run with another seed; everything below inspects metadata, so it
        # only applies when tracking is enabled
        g.set_random_state(4)
        res = g(input_data)
        if not track_meta:
            return

        # affine should be tensor because the resampler only supports pytorch backend
        if isinstance(res["img"], MetaTensor) and "extra_info" in res["img"].applied_operations[0]:
            if not res["img"].applied_operations[-1]["extra_info"]["do_resampling"]:
                return
            affine_img = res["img"].applied_operations[0]["extra_info"]["rand_affine_info"]["extra_info"]["affine"]
            affine_seg = res["seg"].applied_operations[0]["extra_info"]["rand_affine_info"]["extra_info"]["affine"]
            # "img" and "seg" must have been transformed with the same random affine
            assert_allclose(affine_img, affine_seg, rtol=_rtol, atol=1e-3)

        # inverting must clear the recorded operations and restore input shapes
        res_inv = g.inverse(res)
        for k, v in res_inv.items():
            self.assertIsInstance(v, MetaTensor)
            self.assertEqual(len(v.applied_operations), 0)
            self.assertTupleEqual(v.shape, input_data[k].shape)
Esempio n. 2
0
 def test_affine_grid(self, input_param, input_data, expected_val):
     """Check AffineGrid returns a plain tensor without metadata and matches ``expected_val``.

     Metadata tracking is disabled for the call and restored afterwards. A
     cleanup also restores it if an assertion fails, so later tests are not
     affected.
     """
     g = AffineGrid(**input_param)
     self.addCleanup(set_track_meta, True)
     set_track_meta(False)
     result, _ = g(**input_data)
     self.assertNotIsInstance(result, MetaTensor)
     self.assertIsInstance(result, torch.Tensor)
     set_track_meta(True)
     # index with the string key whose presence was just checked
     # (previously used an undefined name ``device`` -> NameError)
     if "device" in input_data:
         self.assertEqual(result.device, input_data["device"])
     assert_allclose(result, expected_val, type_test=False, rtol=_rtol)
Esempio n. 3
0
 def test_rand_2d_elastic(self, input_param, input_data, expected_val):
     """Check Rand2DElastic output type without metadata, then compare a seeded run.

     A cleanup restores global metadata tracking even if an assertion fails.
     """
     g = Rand2DElastic(**input_param)
     self.addCleanup(set_track_meta, True)
     set_track_meta(False)
     result = g(**input_data)
     self.assertNotIsInstance(result, MetaTensor)
     self.assertIsInstance(result, torch.Tensor)
     set_track_meta(True)
     # seed right before the compared call so it is reproducible regardless of
     # the random draw consumed above
     g.set_random_state(123)
     result = g(**input_data)
     assert_allclose(result, expected_val, type_test=False, rtol=_rtol, atol=1e-4)
Esempio n. 4
0
    def test_longest_shape(self, input_param, expected_shape):
        """Check Resize with ``size_mode="longest"`` yields ``expected_shape`` (spatial dims).

        Runs once with metadata tracking on and once with it off. The latter
        must not return a ``MetaTensor``.
        """
        input_data = np.random.randint(0, 2, size=[3, 4, 7, 10])
        # copy rather than mutate the (possibly shared) parameterized test-case dict
        input_param = {**input_param, "size_mode": "longest"}
        result = Resize(**input_param)(input_data)
        np.testing.assert_allclose(result.shape[1:], expected_shape)

        # restore global metadata tracking even if an assertion below fails
        self.addCleanup(set_track_meta, True)
        set_track_meta(False)
        result = Resize(**input_param)(input_data)
        self.assertNotIsInstance(result, MetaTensor)
        np.testing.assert_allclose(result.shape[1:], expected_shape)
        set_track_meta(True)
Esempio n. 5
0
 def test_correct_shape(self, im_type, angle, mode, padding_mode, align_corners):
     """Check Rotate keeps the input shape with and without metadata tracking.

     Also checks that the tracked result can be locally inverted. A cleanup
     restores global metadata tracking if an assertion fails.
     """
     rotate_fn = Rotate(angle, True, align_corners=align_corners, dtype=np.float64)
     im = im_type(self.imt[0])
     self.addCleanup(set_track_meta, True)
     set_track_meta(False)
     rotated = rotate_fn(im, mode=mode, padding_mode=padding_mode)
     self.assertNotIsInstance(rotated, MetaTensor)
     np.testing.assert_allclose(self.imt[0].shape, rotated.shape)
     set_track_meta(True)
     rotated = rotate_fn(im, mode=mode, padding_mode=padding_mode)
     np.testing.assert_allclose(self.imt[0].shape, rotated.shape)
     test_local_inversion(rotate_fn, rotated, im)
Esempio n. 6
0
 def test_longest_shape(self, input_param, expected_shape):
     """Check Resized with ``size_mode="longest"`` yields ``expected_shape`` for every key.

     Runs once with metadata tracking on and once with it off. The latter
     must not return ``MetaTensor`` values.
     """
     input_data = {
         "img": np.random.randint(0, 2, size=[3, 4, 7, 10]),
         "label": np.random.randint(0, 2, size=[3, 4, 7, 10]),
     }
     # copy rather than mutate the (possibly shared) parameterized test-case dict
     input_param = {**input_param, "size_mode": "longest"}
     rescaler = Resized(**input_param)
     result = rescaler(input_data)
     for k in rescaler.keys:
         np.testing.assert_allclose(result[k].shape[1:], expected_shape)

     # restore global metadata tracking even if an assertion below fails
     self.addCleanup(set_track_meta, True)
     set_track_meta(False)
     result = Resized(**input_param)(input_data)
     # check every transformed key, consistent with the tracked run above
     for k in rescaler.keys:
         self.assertNotIsInstance(result[k], MetaTensor)
         np.testing.assert_allclose(result[k].shape[1:], expected_shape)
     set_track_meta(True)
Esempio n. 7
0
    def test_correct_results(self):
        """Check RandAxisFlip (prob=1) flips every channel along the drawn axis.

        For each array type, the result is compared to ``np.flip`` along
        ``flip._axis`` and checked for local invertibility. With metadata
        tracking disabled, the output must be a plain ``torch.Tensor``.
        """
        # restore global metadata tracking even if an assertion below fails
        self.addCleanup(set_track_meta, True)
        for p in TEST_NDARRAYS_ALL:
            flip = RandAxisFlip(prob=1.0)
            im = p(self.imt[0])
            result = flip(im)
            # read the axis after the call: it is the one used for this flip
            expected = np.stack([np.flip(channel, flip._axis) for channel in self.imt[0]])
            assert_allclose(result, p(expected), type_test="tensor")
            test_local_inversion(flip, result, im)

            set_track_meta(False)
            result = flip(im)
            self.assertNotIsInstance(result, MetaTensor)
            self.assertIsInstance(result, torch.Tensor)
            set_track_meta(True)
Esempio n. 8
0
    def test_keep_size(self):
        """Check Zoom with ``keep_size=True`` preserves the input shape.

        Uses zoom factors 0.6 and 1.3 for every array type and checks local
        invertibility. Finally checks that with metadata tracking disabled the
        output is not a ``MetaTensor`` and still has the expected shape.
        """
        # restore global metadata tracking even if an assertion below fails
        self.addCleanup(set_track_meta, True)
        for p in TEST_NDARRAYS_ALL:
            zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True)
            im = p(self.imt[0])
            zoomed = zoom_fn(im, mode="bilinear")
            assert_allclose(zoomed.shape, self.imt.shape[1:], type_test=False)
            test_local_inversion(zoom_fn, zoomed, im)

            zoom_fn = Zoom(zoom=[1.3, 1.3], keep_size=True)
            im = p(self.imt[0])
            zoomed = zoom_fn(im)
            assert_allclose(zoomed.shape, self.imt.shape[1:], type_test=False)
            test_local_inversion(zoom_fn, zoomed, p(self.imt[0]))

            set_track_meta(False)
            zoomed_no_meta = zoom_fn(im)
            self.assertNotIsInstance(zoomed_no_meta, MetaTensor)
            # check the output produced without metadata; the previous version
            # re-checked the earlier ``zoomed`` result instead
            np.testing.assert_allclose(zoomed_no_meta.shape, self.imt.shape[1:])
            set_track_meta(True)
Esempio n. 9
0
    def test_default(self):
        """Check RandRotate90d with default settings, using ``None`` as the dict key.

        With random state 1323 the expected result equals ``np.rot90(..., 0)``
        per channel, i.e. an unrotated image. With metadata tracking disabled,
        the output must be a plain ``torch.Tensor``.
        """
        key = None
        rotate = RandRotate90d(keys=key)
        # restore global metadata tracking even if an assertion below fails
        self.addCleanup(set_track_meta, True)
        for p in TEST_NDARRAYS_ALL:
            rotate.set_random_state(1323)
            im = {key: p(self.imt[0])}
            rotated = rotate(im)
            test_local_inversion(rotate, rotated, im, key)
            expected = np.stack([np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]])
            assert_allclose(rotated[key], p(expected), type_test="tensor")

            set_track_meta(False)
            rotated = rotate(im)[key]
            self.assertNotIsInstance(rotated, MetaTensor)
            self.assertIsInstance(rotated, torch.Tensor)
            set_track_meta(True)
Esempio n. 10
0
    def test_affine(self, input_param, input_data, expected_val):
        """Check Affine output values, local invertibility, and the untracked output type.

        ``g`` may return a tuple, in which case its first element is the
        compared result. A cleanup restores global metadata tracking if an
        assertion fails.
        """
        # keep a pristine copy of the input for the inversion check
        input_copy = deepcopy(input_data["img"])
        g = Affine(**input_param)
        result = g(**input_data)
        if isinstance(result, tuple):
            result = result[0]
        test_local_inversion(g, result, input_copy)
        assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4, type_test=False)

        self.addCleanup(set_track_meta, True)
        set_track_meta(False)
        result = g(**input_data)
        if isinstance(result, tuple):
            result = result[0]
        self.assertNotIsInstance(result, MetaTensor)
        self.assertIsInstance(result, torch.Tensor)
        set_track_meta(True)
Esempio n. 11
0
    def test_k(self):
        """Check RandRotate90 with ``max_k=2``.

        Without metadata tracking the output must be a plain ``torch.Tensor``.
        With tracking and random state 123, the result equals
        ``np.rot90(..., 0)`` per channel and is locally invertible.
        """
        rotate = RandRotate90(max_k=2)
        # restore global metadata tracking even if an assertion below fails
        self.addCleanup(set_track_meta, True)
        for p in TEST_NDARRAYS_ALL:
            im = p(self.imt[0])
            set_track_meta(False)
            rotated = rotate(im)
            self.assertNotIsInstance(rotated, MetaTensor)
            self.assertIsInstance(rotated, torch.Tensor)

            set_track_meta(True)
            rotate.set_random_state(123)
            rotated = rotate(im)
            test_local_inversion(rotate, rotated, im)
            expected = np.stack([np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]])
            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test="tensor")
Esempio n. 12
0
    def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected):
        """Check RandRotate (prob=1, seed 243) produces an output of shape ``expected``.

        Also checks local invertibility and that, with metadata tracking
        disabled, the output is a plain ``torch.Tensor``.
        """
        rotate_fn = RandRotate(
            range_x=x,
            range_y=y,
            range_z=z,
            prob=1.0,
            keep_size=keep_size,
            mode=mode,
            padding_mode=padding_mode,
            align_corners=align_corners,
            dtype=np.float64,
        )
        rotate_fn.set_random_state(243)
        im = im_type(self.imt[0])
        rotated = rotate_fn(im)
        # torch.testing.assert_allclose is deprecated; compare the shapes with
        # numpy using the same tolerances
        np.testing.assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0)
        test_local_inversion(rotate_fn, rotated, im)

        # restore global metadata tracking even if an assertion below fails
        self.addCleanup(set_track_meta, True)
        set_track_meta(False)
        rotated = rotate_fn(im)
        self.assertNotIsInstance(rotated, MetaTensor)
        self.assertIsInstance(rotated, torch.Tensor)
        set_track_meta(True)
Esempio n. 13
0
 def test_rotate90_default(self):
     """Check Rotate90d with default settings rotates each channel once in axes (0, 1).

     Also checks local invertibility and that the output is not a
     ``MetaTensor`` when metadata tracking is disabled.
     """
     key = "test"
     rotate = Rotate90d(keys=key)
     # restore global metadata tracking even if an assertion below fails
     self.addCleanup(set_track_meta, True)
     for p in TEST_NDARRAYS_ALL:
         im = p(self.imt[0])
         set_track_meta(True)
         rotated = rotate({key: im})
         test_local_inversion(rotate, rotated, {key: im}, key)
         expected = np.stack([np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]])
         assert_allclose(rotated[key], p(expected), type_test="tensor")
         set_track_meta(False)
         rotated = rotate({key: im})
         self.assertNotIsInstance(rotated[key], MetaTensor)
         set_track_meta(True)