def test_get_set_meta_fns(self):
    """Round-trip the global tracking flags through their setters and getters.

    Both the meta-tracking and the transform-tracking flag are switched off
    and then back on; after each switch the matching getter must report the
    value that was just set.
    """
    flag_pairs = (
        (set_track_meta, get_track_meta),
        (set_track_transforms, get_track_transforms),
    )
    for setter, getter in flag_pairs:
        for value in (False, True):
            setter(value)
            self.assertEqual(getter(), value)
Example #2
0
 def test_correct(self, input_param, expected_shape, track_meta):
     """Check LoadImage output type and shape under the given meta-tracking mode.

     With ``track_meta`` on, the loaded image must be a ``MetaTensor`` whose
     ``affine`` is a ``torch.Tensor``; with it off, it must be a plain
     ``torch.Tensor`` without an ``affine`` attribute.

     Args:
         input_param: extra keyword arguments forwarded to ``LoadImage``.
         expected_shape: shape the loaded image must have.
         track_meta: value passed to ``set_track_meta`` for this test.
     """
     # set_track_meta flips a global flag. Register the reset before changing
     # it so later tests see tracking enabled even if this test fails.
     self.addCleanup(set_track_meta, True)
     set_track_meta(track_meta)
     r = LoadImage(image_only=True, **input_param)(self.test_data)
     self.assertTupleEqual(r.shape, expected_shape)
     if track_meta:
         self.assertIsInstance(r, MetaTensor)
         self.assertTrue(hasattr(r, "affine"))
         self.assertIsInstance(r.affine, torch.Tensor)
     else:
         self.assertIsInstance(r, torch.Tensor)
         self.assertNotIsInstance(r, MetaTensor)
         self.assertFalse(hasattr(r, "affine"))
Example #3
0
 def test_torch(self, init_param, img: torch.Tensor, track_meta: bool, device):
     """Check Flipd on a torch tensor under the given meta-tracking mode.

     The flipped image must keep the input shape. With tracking on, the output
     is a ``MetaTensor``. With tracking off, it is a plain ``torch.Tensor``, and
     ``Flipd.inverse`` must raise a ``ValueError`` whose message mentions
     ``MetaTensor``.

     Args:
         init_param: second positional argument passed to ``Flipd``.
         img: input image; moved to ``device`` before flipping.
         track_meta: value passed to ``set_track_meta`` for this test.
         device: device the image is moved to.
     """
     # set_track_meta flips a global flag. Restore it even if this test fails.
     self.addCleanup(set_track_meta, True)
     set_track_meta(track_meta)
     img = img.to(device)
     xform = Flipd("image", init_param)
     res = xform({"image": img})
     out = res["image"]
     self.assertEqual(img.shape, out.shape)
     if track_meta:
         self.assertIsInstance(out, MetaTensor)
     else:
         self.assertNotIsInstance(out, MetaTensor)
         self.assertIsInstance(out, torch.Tensor)
         # inverse is expected to reject data that is not a MetaTensor
         with self.assertRaisesRegex(ValueError, "MetaTensor"):
             xform.inverse(res)
Example #4
0
 def test_spacing_torch(self, pixdim, img, track_meta: bool):
     """Check Spacing under the given meta-tracking mode.

     In both modes, the output shape must differ from the input shape. With
     tracking on, the output is a ``MetaTensor`` whose affine yields
     ``pixdim`` via ``affine_to_spacing``. With tracking off, it is a plain
     ``torch.Tensor``.

     Args:
         pixdim: target spacing passed to ``Spacing``.
         img: input image.
         track_meta: value passed to ``set_track_meta`` for this test.
     """
     # Register the reset up front: a trailing set_track_meta(True) is skipped
     # when an assertion fails, leaking the global flag into later tests.
     self.addCleanup(set_track_meta, True)
     set_track_meta(track_meta)
     tr = Spacing(pixdim=pixdim)
     res = tr(img)
     self.assertNotEqual(img.shape, res.shape)
     if track_meta:
         self.assertIsInstance(res, MetaTensor)
         new_spacing = affine_to_spacing(res.affine, 3)
         assert_allclose(new_spacing, pixdim, type_test=False)
     else:
         self.assertIsInstance(res, torch.Tensor)
         self.assertNotIsInstance(res, MetaTensor)
Example #5
0
    def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool,
                         device):
        """Check Spacingd on the ``"seg"`` key under the given meta-tracking mode.

        In both modes, the output shape must differ from the input shape. With
        tracking on, the output is a ``MetaTensor`` whose affine yields
        ``init_param["pixdim"]`` via ``affine_to_spacing``. With tracking off,
        it is a plain ``torch.Tensor``.

        NOTE(review): the name says "orntd", but this exercises ``Spacingd``.

        Args:
            init_param: keyword arguments for ``Spacingd``; must contain ``"pixdim"``.
            img: input image; moved to ``device`` and stored under ``"seg"``.
            track_meta: value passed to ``set_track_meta`` for this test.
            device: device the image is moved to.
        """
        # set_track_meta flips a global flag. Restore it even if this test fails.
        self.addCleanup(set_track_meta, True)
        set_track_meta(track_meta)
        tr = Spacingd(**init_param)
        data = {"seg": img.to(device)}
        res = tr(data)["seg"]

        self.assertNotEqual(img.shape, res.shape)
        if track_meta:
            self.assertIsInstance(res, MetaTensor)
            new_spacing = affine_to_spacing(res.affine, 3)
            assert_allclose(new_spacing, init_param["pixdim"], type_test=False)
        else:
            self.assertIsInstance(res, torch.Tensor)
            self.assertNotIsInstance(res, MetaTensor)
Example #6
0
    def test_correct(self, input_param, expected_shape, track_meta):
        """Check LoadImaged output per key under the given meta-tracking mode.

        The result must hold exactly ``len(KEYS)`` entries, i.e. no extra
        metadata keys. Each entry in ``KEYS`` must have ``expected_shape``.
        With tracking on, it must be a ``MetaTensor`` with a ``torch.Tensor``
        affine; with tracking off, a plain ``torch.Tensor`` without ``affine``.

        Args:
            input_param: extra keyword arguments forwarded to ``LoadImaged``.
            expected_shape: shape every loaded image must have.
            track_meta: value passed to ``set_track_meta`` for this test.
        """
        # set_track_meta flips a global flag. Restore it even if this test fails.
        self.addCleanup(set_track_meta, True)
        set_track_meta(track_meta)
        result = LoadImaged(image_only=True, **input_param)(self.test_data)

        # shouldn't have any extra meta data keys
        self.assertEqual(len(result), len(KEYS))
        for key in KEYS:
            r = result[key]
            self.assertTupleEqual(r.shape, expected_shape)
            if track_meta:
                self.assertIsInstance(r, MetaTensor)
                self.assertTrue(hasattr(r, "affine"))
                self.assertIsInstance(r.affine, torch.Tensor)
            else:
                self.assertIsInstance(r, torch.Tensor)
                self.assertNotIsInstance(r, MetaTensor)
                self.assertFalse(hasattr(r, "affine"))
Example #7
0
    def test_ornt_torch(self, init_param, img: torch.Tensor, track_meta: bool, device):
        """Check Orientation under the given meta-tracking mode.

        The output data must match a copy of the input taken before the call.
        With tracking on, the output is a ``MetaTensor`` whose affine maps
        (via nibabel ``aff2axcodes``) to ``ornt.axcodes``. With tracking off,
        it is a plain ``torch.Tensor``.

        Args:
            init_param: keyword arguments for ``Orientation``.
            img: input image; moved to ``device`` before the transform.
            track_meta: value passed to ``set_track_meta`` for this test.
            device: device the image is moved to.
        """
        # set_track_meta flips a global flag. Restore it even if this test fails.
        self.addCleanup(set_track_meta, True)
        set_track_meta(track_meta)
        ornt = Orientation(**init_param)

        img = img.to(device)
        # Copy taken before the call, presumably so that an in-place change
        # by the transform would be caught.
        expected_data = img.clone()
        expected_code = ornt.axcodes

        res = ornt(img)
        assert_allclose(res, expected_data)
        if track_meta:
            self.assertIsInstance(res, MetaTensor)
            new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels)
            self.assertEqual("".join(new_code), expected_code)
        else:
            self.assertIsInstance(res, torch.Tensor)
            self.assertNotIsInstance(res, MetaTensor)
Example #8
0
 def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool,
                      device):
     """Check Orientationd on every configured key under the given tracking mode.

     Each output must keep the input shape. With tracking on, it must be a
     ``MetaTensor`` whose affine maps (via nibabel ``aff2axcodes``) to the
     wrapped transform's ``axcodes``. With tracking off, it must be a plain
     ``torch.Tensor``.

     Args:
         init_param: keyword arguments for ``Orientationd``.
         img: input image; moved to ``device`` and cloned for every key.
         track_meta: value passed to ``set_track_meta`` for this test.
         device: device the image is moved to.
     """
     # set_track_meta flips a global flag. Restore it even if this test fails.
     self.addCleanup(set_track_meta, True)
     set_track_meta(track_meta)
     ornt = Orientationd(**init_param)
     img = img.to(device)
     expected_shape = img.shape
     expected_code = ornt.ornt_transform.axcodes
     # each key gets an independent copy of the image
     data = {k: img.clone() for k in ornt.keys}
     res = ornt(data)
     for k in ornt.keys:
         _im = res[k]
         # shapes are integer tuples; compare them exactly rather than with allclose
         self.assertEqual(_im.shape, expected_shape)
         if track_meta:
             self.assertIsInstance(_im, MetaTensor)
             code = nib.aff2axcodes(_im.affine.cpu(),
                                    labels=ornt.ornt_transform.labels)
             self.assertEqual("".join(code), expected_code)
         else:
             self.assertIsInstance(_im, torch.Tensor)
             self.assertNotIsInstance(_im, MetaTensor)