def test_get_set_meta_fns(self):
    """Each global tracking setter must be reflected by its paired getter."""
    # Table-driven: exercise both flags, toggling off then back on (the default).
    pairs = (
        (set_track_meta, get_track_meta),
        (set_track_transforms, get_track_transforms),
    )
    for setter, getter in pairs:
        for flag in (False, True):
            setter(flag)
            self.assertEqual(getter(), flag)
def test_correct(self, input_param, expected_shape, track_meta):
    """LoadImage returns a MetaTensor with an ``affine`` iff meta tracking is on.

    Args:
        input_param: keyword arguments forwarded to ``LoadImage``.
        expected_shape: expected shape of the loaded image.
        track_meta: global metadata-tracking state for this case.
    """
    set_track_meta(track_meta)
    try:
        r = LoadImage(image_only=True, **input_param)(self.test_data)
        self.assertTupleEqual(r.shape, expected_shape)
        if track_meta:
            self.assertIsInstance(r, MetaTensor)
            self.assertTrue(hasattr(r, "affine"))
            self.assertIsInstance(r.affine, torch.Tensor)
        else:
            self.assertIsInstance(r, torch.Tensor)
            self.assertNotIsInstance(r, MetaTensor)
            self.assertFalse(hasattr(r, "affine"))
    finally:
        # restore the library default so an assertion failure here
        # cannot leak track_meta=False into subsequent tests
        set_track_meta(True)
def test_torch(self, init_param, img: torch.Tensor, track_meta: bool, device):
    """Flipd preserves shape; result type follows the global track_meta flag.

    Without meta tracking the result is a plain tensor and ``inverse`` must
    raise (no MetaTensor applied-operations to invert).
    """
    set_track_meta(track_meta)
    try:
        img = img.to(device)
        xform = Flipd("image", init_param)
        res = xform({"image": img})
        self.assertEqual(img.shape, res["image"].shape)
        if track_meta:
            self.assertIsInstance(res["image"], MetaTensor)
        else:
            self.assertNotIsInstance(res["image"], MetaTensor)
            self.assertIsInstance(res["image"], torch.Tensor)
            # inverting requires tracked transform info, absent on plain tensors
            with self.assertRaisesRegex(ValueError, "MetaTensor"):
                xform.inverse(res)
    finally:
        # restore the library default so later tests are unaffected
        set_track_meta(True)
def test_spacing_torch(self, pixdim, img, track_meta: bool):
    """Spacing resamples to ``pixdim``; result type follows track_meta.

    The original restored ``set_track_meta(True)`` only on the success path;
    a failing assertion would leak ``track_meta=False`` into later tests, so
    the restoration now lives in a ``finally`` block.
    """
    set_track_meta(track_meta)
    try:
        tr = Spacing(pixdim=pixdim)
        res = tr(img)
        if track_meta:
            self.assertIsInstance(res, MetaTensor)
            # the output affine must encode the requested spacing
            new_spacing = affine_to_spacing(res.affine, 3)
            assert_allclose(new_spacing, pixdim, type_test=False)
            self.assertNotEqual(img.shape, res.shape)
        else:
            self.assertIsInstance(res, torch.Tensor)
            self.assertNotIsInstance(res, MetaTensor)
            self.assertNotEqual(img.shape, res.shape)
    finally:
        set_track_meta(True)
def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, device):
    """Spacingd resamples the ``seg`` key; result type follows track_meta."""
    set_track_meta(track_meta)
    try:
        tr = Spacingd(**init_param)
        data = {"seg": img.to(device)}
        res = tr(data)["seg"]
        if track_meta:
            self.assertIsInstance(res, MetaTensor)
            # output affine must encode the requested pixdim
            new_spacing = affine_to_spacing(res.affine, 3)
            assert_allclose(new_spacing, init_param["pixdim"], type_test=False)
            self.assertNotEqual(img.shape, res.shape)
        else:
            self.assertIsInstance(res, torch.Tensor)
            self.assertNotIsInstance(res, MetaTensor)
            self.assertNotEqual(img.shape, res.shape)
    finally:
        # restore the library default so later tests are unaffected
        set_track_meta(True)
def test_correct(self, input_param, expected_shape, track_meta):
    """LoadImaged loads every key in KEYS; result type follows track_meta.

    Args:
        input_param: keyword arguments forwarded to ``LoadImaged``.
        expected_shape: expected shape of each loaded image.
        track_meta: global metadata-tracking state for this case.
    """
    set_track_meta(track_meta)
    try:
        result = LoadImaged(image_only=True, **input_param)(self.test_data)
        # shouldn't have any extra meta data keys
        self.assertEqual(len(result), len(KEYS))
        for key in KEYS:
            r = result[key]
            self.assertTupleEqual(r.shape, expected_shape)
            if track_meta:
                self.assertIsInstance(r, MetaTensor)
                self.assertTrue(hasattr(r, "affine"))
                self.assertIsInstance(r.affine, torch.Tensor)
            else:
                self.assertIsInstance(r, torch.Tensor)
                self.assertNotIsInstance(r, MetaTensor)
                self.assertFalse(hasattr(r, "affine"))
    finally:
        # restore the library default so later tests are unaffected
        set_track_meta(True)
def test_ornt_torch(self, init_param, img: torch.Tensor, track_meta: bool, device):
    """Orientation keeps the data values; result type follows track_meta.

    When meta is tracked, the output affine's axis codes must match the
    transform's requested ``axcodes``.
    """
    set_track_meta(track_meta)
    try:
        ornt = Orientation(**init_param)
        img = img.to(device)
        expected_data = img.clone()
        expected_code = ornt.axcodes
        res = ornt(img)
        assert_allclose(res, expected_data)
        if track_meta:
            self.assertIsInstance(res, MetaTensor)
            new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels)
            self.assertEqual("".join(new_code), expected_code)
        else:
            self.assertIsInstance(res, torch.Tensor)
            self.assertNotIsInstance(res, MetaTensor)
    finally:
        # restore the library default so later tests are unaffected
        set_track_meta(True)
def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, device):
    """Orientationd reorients every key; result type follows track_meta."""
    set_track_meta(track_meta)
    try:
        ornt = Orientationd(**init_param)
        img = img.to(device)
        expected_shape = img.shape
        expected_code = ornt.ornt_transform.axcodes
        data = {k: img.clone() for k in ornt.keys}
        res = ornt(data)
        for k in ornt.keys:
            _im = res[k]
            np.testing.assert_allclose(_im.shape, expected_shape)
            if track_meta:
                self.assertIsInstance(_im, MetaTensor)
                # affine axis codes must match the requested orientation
                code = nib.aff2axcodes(_im.affine.cpu(), ornt.ornt_transform.labels)
                self.assertEqual("".join(code), expected_code)
            else:
                self.assertIsInstance(_im, torch.Tensor)
                self.assertNotIsInstance(_im, MetaTensor)
    finally:
        # restore the library default so later tests are unaffected
        set_track_meta(True)