Example 1
 def test_astype(self):
     t = MetaTensor([1.0], affine=torch.tensor(1), meta={"fname": "filename"})
     for np_types in ("float32", "np.float32", "numpy.float32", np.float32, float, "int", np.compat.long, np.uint16):
         self.assertIsInstance(t.astype(np_types), np.ndarray)
     for pt_types in ("torch.float", torch.float, "torch.float64"):
         self.assertIsInstance(t.astype(pt_types), torch.Tensor)
     self.assertIsInstance(t.astype("torch.float", device="cpu"), torch.Tensor)
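The dispatch rule the test exercises: `astype` with a NumPy dtype (or its string name) yields an `np.ndarray`, while a torch dtype or a `"torch.*"` string yields a `torch.Tensor`. A minimal sketch of that rule, assuming only that `monai.data.MetaTensor` is importable:

 import numpy as np
 import torch
 from monai.data import MetaTensor

 t = MetaTensor([1.0])
 assert isinstance(t.astype(np.float32), np.ndarray)       # numpy target -> ndarray
 assert isinstance(t.astype("torch.float"), torch.Tensor)  # torch target -> Tensor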
Example 2
 def test_constructor(self, device, dtype):
     m, t = self.get_im(device=device, dtype=dtype)
     # construct from pre-existing
     m1 = MetaTensor(m.clone())
     self.check(m, m1, ids=False, meta=False)
     # meta already has affine
     m2 = MetaTensor(t.clone(), meta=m.meta)
     self.check(m, m2, ids=False, meta=False)
     # meta doesn't have affine
     affine = m.meta.pop("affine")
     m3 = MetaTensor(t.clone(), affine=affine, meta=m.meta)
     self.check(m, m3, ids=False, meta=False)
Example 3
 def test_check(self):
     im = torch.zeros(1, 2, 3)
     with self.assertRaises(ValueError):  # not MetaTensor
         EnsureChannelFirst()(im)
     with self.assertRaises(ValueError):  # no meta
         EnsureChannelFirst()(MetaTensor(im))
     with self.assertRaises(ValueError):  # no meta channel
         EnsureChannelFirst()(MetaTensor(
             im, meta={"original_channel_dim": None}))
     EnsureChannelFirst(strict_check=False)(im)
     EnsureChannelFirst(strict_check=False)(MetaTensor(
         im, meta={"original_channel_dim": None}))
Example 4
 def test_exceptions(self):
     im = torch.zeros((1, 2, 3))
     with self.assertRaises(ValueError):  # no meta
         EnsureChannelFirstd("img")({"img": im})
     with self.assertRaises(ValueError):  # no meta channel
         EnsureChannelFirstd("img")({
             "img":
             MetaTensor(im, meta={"original_channel_dim": None})
         })
     EnsureChannelFirstd("img", strict_check=False)({"img": im})
     EnsureChannelFirstd("img", strict_check=False)({
         "img":
         MetaTensor(im, meta={"original_channel_dim": None})
     })
Example 5
 def test_multiprocessing(self, device=None, dtype=None):
     """multiprocessing sharing with 'device' and 'dtype'"""
     buf = io.BytesIO()
     t = MetaTensor([0.0, 0.0], device=device, dtype=dtype)
     t.is_batch = True
     if t.is_cuda:
         with self.assertRaises(NotImplementedError):
             ForkingPickler(buf).dump(t)
         return
     ForkingPickler(buf).dump(t)
     obj = ForkingPickler.loads(buf.getvalue())
     self.assertIsInstance(obj, MetaTensor)
     assert_allclose(obj.as_tensor(), t)
     assert_allclose(obj.is_batch, True)
Example 6
 def test_array_function(self, device="cpu", dtype=float):
     a = np.random.RandomState().randn(100, 100)
     b = MetaTensor(a, device=device)
     assert_allclose(np.sum(a), np.sum(b))
     assert_allclose(np.sum(a, axis=1), np.sum(b, axis=1))
     assert_allclose(np.linalg.qr(a), np.linalg.qr(b))
     c = MetaTensor([1.0, 2.0, 3.0], device=device, dtype=dtype)
     assert_allclose(np.argwhere(c == 1.0).astype(int).tolist(), [[0]])
     assert_allclose(np.concatenate([c, c]), np.asarray([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))
     if pytorch_after(1, 8, 1):
         assert_allclose(c > np.asarray([1.0, 1.0, 1.0]), np.asarray([False, True, True]))
         assert_allclose(
             c > torch.as_tensor([1.0, 1.0, 1.0], device=device), torch.as_tensor([False, True, True], device=device)
         )
Example 7
 def test_as_dict(self):
     m, _ = self.get_im()
     m_dict = m.as_dict("im")
     im, meta = m_dict["im"], m_dict[PostFix.meta("im")]
     affine = meta.pop("affine")
     m2 = MetaTensor(im, affine, meta)
     self.check(m2, m, check_ids=False)
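`PostFix.meta` derives the metadata key from the image key, which is how `as_dict` and the dictionary transforms agree on naming. A one-line sketch:

 from monai.utils import PostFix

 print(PostFix.meta("im"))  # 'im_meta_dict'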
Example 8
 def test_construct_with_pre_applied_transforms(self):
     key = "im"
     _, im = self.get_im()
     tr = Compose([BorderPadd(key, 1), DivisiblePadd(key, 16)])
     data = tr({key: im})
     m = MetaTensor(im, applied_operations=data["im"].applied_operations)
     self.assertEqual(len(m.applied_operations), len(tr.transforms))
Example 9
 def test_str(self):
     t = MetaTensor([1.0], affine=torch.tensor(1), meta={"fname": "filename"})
     s1 = str(t)
     s2 = t.__repr__()
     expected_out = (
         "tensor([1.])\n"
         + "MetaData\n"
         + "\tfname: filename\n"
         + "\taffine: 1\n"
         + "\n"
         + "Applied operations\n"
         + "[]\n"
         + "Is batch?: False"
     )
     for s in (s1, s2):
         self.assertEqual(s, expected_out)
Example 10
 def png_rw(self, test_data, reader, writer, dtype, resample=True):
     test_data = test_data.astype(dtype)
     ndim = len(test_data.shape) - 1
     for p in TEST_NDARRAYS:
         output_ext = ".png"
         filepath = f"testfile_{ndim}d"
         saver = SaveImage(output_dir=self.test_dir,
                           output_ext=output_ext,
                           resample=resample,
                           separate_folder=False,
                           writer=writer)
         test_data = MetaTensor(
             p(test_data),
             meta={"filename_or_obj": f"{filepath}.png", "spatial_shape": (6, 8)},
         )
         saver(test_data)
         saved_path = os.path.join(self.test_dir,
                                   filepath + "_trans" + output_ext)
         self.assertTrue(os.path.exists(saved_path))
         loader = LoadImage(image_only=True, reader=reader)
         data = loader(saved_path)
         meta = data.meta
         if meta["original_channel_dim"] == -1:
             _test_data = moveaxis(test_data, 0, -1)
         else:
             _test_data = test_data[0]
         assert_allclose(data, torch.as_tensor(_test_data))
Example 11
 def __call__(
     self, data: Mapping[Hashable, NdarrayOrTensor]
 ) -> Dict[Hashable, NdarrayOrTensor]:
     d = dict(data)
     for key in self.key_iterator(d):
         self.push_transform(d, key)
         im, meta = d[key], d.pop(PostFix.meta(key), None)
         im = MetaTensor(im, meta=meta)  # type: ignore
         d[key] = im
     return d
Example 12
 def get_im(shape=None, dtype=None, device=None):
     if shape is None:
         shape = (1, 10, 8)
     affine = torch.randint(0, 10, (4, 4))
     meta = {"fname": rand_string()}
     t = torch.rand(shape)
     if dtype is not None:
         t = t.to(dtype)
     if device is not None:
         t = t.to(device)
     m = MetaTensor(t.clone(), affine, meta)
     return m, t
Example 13
 def test_copy(self, device, dtype):
     m, _ = self.get_im(device=device, dtype=dtype)
     # shallow copy
     a = m
     self.check(a, m, ids=True)
     # deepcopy
     a = deepcopy(m)
     self.check(a, m, ids=False)
     # clone
     a = m.clone()
     self.check(a, m, ids=False)
     a = MetaTensor([[]], device=device, dtype=dtype)
     self.check(a, deepcopy(a), ids=False)
Example 14
 def inverse(
     self, data: Mapping[Hashable, NdarrayOrTensor]
 ) -> Dict[Hashable, NdarrayOrTensor]:
     d = deepcopy(dict(data))
     for key in self.key_iterator(d):
         # check transform
         _ = self.get_most_recent_transform(d, key)
         # do the inverse
         im, meta = d[key], d.pop(PostFix.meta(key), None)
         im = MetaTensor(im, meta=meta)  # type: ignore
         d[key] = im
         # Remove the applied transform
         self.pop_transform(d, key)
     return d
Example 15
 def test_ornt_meta(
     self,
     init_param,
     img: torch.Tensor,
     affine: torch.Tensor,
     expected_data: torch.Tensor,
     expected_code: str,
     device,
 ):
     img = MetaTensor(img, affine=affine).to(device)
     ornt = Orientation(**init_param)
     res: MetaTensor = ornt(img)
     assert_allclose(res, expected_data.to(device))
     new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels)
     self.assertEqual("".join(new_code), expected_code)
Example 16
 def test_dataloader(self, dtype):
     batch_size = 5
     ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)]
     ims = [MetaTensor(im, applied_operations=[f"t{i}"]) for i, im in enumerate(ims)]
     ds = Dataset(ims)
     im_shape = tuple(ims[0].shape)
     affine_shape = tuple(ims[0].affine.shape)
     expected_im_shape = (batch_size,) + im_shape
     expected_affine_shape = (batch_size,) + affine_shape
     dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size)
     for batch in dl:
         self.assertIsInstance(batch, MetaTensor)
         self.assertTupleEqual(tuple(batch.shape), expected_im_shape)
         self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape)
         self.assertEqual(len(batch.applied_operations), batch_size)
Example 17
 def test_orntd(self, init_param, img: torch.Tensor,
                affine: Optional[torch.Tensor], expected_shape,
                expected_code, device):
     ornt = Orientationd(**init_param)
     if affine is not None:
         img = MetaTensor(img, affine=affine)
     img = img.to(device)
     data = {k: img.clone() for k in ornt.keys}
     res = ornt(data)
     for k in ornt.keys:
         _im = res[k]
         self.assertIsInstance(_im, MetaTensor)
         np.testing.assert_allclose(_im.shape, expected_shape)
         code = nib.aff2axcodes(_im.affine.cpu(),
                                ornt.ornt_transform.labels)
         self.assertEqual("".join(code), expected_code)
Example 18
    def test_spacing(self, init_param, img, affine, data_param,
                     expected_output, device):
        img = MetaTensor(img, affine=affine).to(device)
        res: MetaTensor = Spacing(**init_param)(img, **data_param)
        self.assertEqual(img.device, res.device)

        assert_allclose(res, expected_output, atol=1e-1, rtol=1e-1)
        sr = min(len(res.shape) - 1, 3)
        if isinstance(init_param["pixdim"], float):
            init_param["pixdim"] = [init_param["pixdim"]] * sr
        init_pixdim = ensure_tuple(init_param["pixdim"])[:sr]
        norm = affine_to_spacing(res.affine, sr).cpu().numpy()
        assert_allclose(fall_back_tuple(init_pixdim, norm),
                        norm,
                        type_test=False)
Example 19
 def test_collate(self, device, dtype):
     numel = 3
     ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(numel)]
     ims = [MetaTensor(im, applied_operations=[f"t{i}"]) for i, im in enumerate(ims)]
     collated = list_data_collate(ims)
     # tensor
     self.assertIsInstance(collated, MetaTensor)
     expected_shape = (numel,) + tuple(ims[0].shape)
     self.assertTupleEqual(tuple(collated.shape), expected_shape)
     for i, im in enumerate(ims):
         self.check(im, ims[i], ids=True)
     # affine
     self.assertIsInstance(collated.affine, torch.Tensor)
     expected_shape = (numel,) + tuple(ims[0].affine.shape)
     self.assertTupleEqual(tuple(collated.affine.shape), expected_shape)
     self.assertEqual(len(collated.applied_operations), numel)
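The behaviour under test: collating a list of `MetaTensor`s stacks the arrays and the affines, and returns a batched `MetaTensor`. A minimal sketch:

 import torch
 from monai.data import MetaTensor
 from monai.data.utils import list_data_collate

 ims = [MetaTensor(torch.rand(1, 4, 4)) for _ in range(3)]
 batch = list_data_collate(ims)
 print(type(batch).__name__, tuple(batch.shape))  # MetaTensor (3, 1, 4, 4)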
Example 20
 def test_inverse(self, device):
     img_t = torch.rand((1, 10, 9, 8), dtype=torch.float32, device=device)
     affine = torch.tensor(
         [[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, device="cpu"
     )
     meta = {"fname": "somewhere"}
     img = MetaTensor(img_t, affine=affine, meta=meta)
     tr = Orientation("LPS")
     # check that image and affine have changed
     img = tr(img)
     self.assertNotEqual(img.shape, img_t.shape)
     self.assertGreater((affine - img.affine).max(), 0.5)
     # check that with inverse, image affine are back to how they were
     img = tr.inverse(img)
     self.assertEqual(img.shape, img_t.shape)
     self.assertLess((affine - img.affine).max(), 1e-2)
Example 21
    def test_saved_content(self, test_data, meta_data, output_ext, resample):
        if meta_data is not None:
            test_data = MetaTensor(test_data, meta=meta_data)

        with tempfile.TemporaryDirectory() as tempdir:
            trans = SaveImage(
                output_dir=tempdir,
                output_ext=output_ext,
                resample=resample,
                separate_folder=False,  # test saving into the same folder
            )
            trans(test_data)

            filepath = "testfile0" if meta_data is not None else "0"
            self.assertTrue(
                os.path.exists(
                    os.path.join(tempdir, filepath + "_trans" + output_ext)))
Example 22
    def test_flips_inverse(self, img, device, dst_affine, kwargs,
                           expected_output):
        img = MetaTensor(img, affine=torch.eye(4)).to(device)
        data = {"img": img, "dst_affine": dst_affine}

        xform = SpatialResampled(keys="img", **kwargs)
        output_data = xform(data)
        out = output_data["img"]

        assert_allclose(out, expected_output, rtol=1e-2, atol=1e-2)
        assert_allclose(out.affine, dst_affine, rtol=1e-2, atol=1e-2)

        inverted = xform.inverse(output_data)["img"]
        self.assertEqual(inverted.applied_operations, [])  # nothing further to invert after inverting
        expected_affine = to_affine_nd(len(out.affine) - 1, torch.eye(4))
        assert_allclose(inverted.affine, expected_affine, rtol=1e-2, atol=1e-2)
        assert_allclose(inverted, img, rtol=1e-2, atol=1e-2)
Example 23
    def test_kwargs(self):
        spatial_size = (32, 64, 128)
        test_image = np.random.rand(*spatial_size)
        with tempfile.TemporaryDirectory() as tempdir:
            filename = os.path.join(tempdir, "test_image.nii.gz")
            itk_np_view = itk.image_view_from_array(test_image)
            itk.imwrite(itk_np_view, filename)

            loader = LoadImage(image_only=True)
            reader = ITKReader(fallback_only=False)
            loader.register(reader)
            result = loader(filename)

            reader = ITKReader()
            img = reader.read(filename, fallback_only=False)
            result_raw = reader.get_data(img)
            result_raw = MetaTensor.ensure_torch_and_prune_meta(*result_raw)
            self.assertTupleEqual(result.shape, result_raw.shape)
Example 24
 def test_inverse(self, device):
     img_t = torch.rand((1, 10, 9, 8), dtype=torch.float32, device=device)
     affine = torch.tensor(
         [[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]],
         dtype=torch.float32,
         device="cpu")
     meta = {"fname": "somewhere"}
     img = MetaTensor(img_t, affine=affine, meta=meta)
     tr = Spacing(pixdim=[1.1, 1.2, 0.9])
     # check that image and affine have changed
     img = tr(img)
     self.assertNotEqual(img.shape, img_t.shape)
     l2_norm_affine = ((affine - img.affine)**2).sum()**0.5
     self.assertGreater(l2_norm_affine, 5e-2)
     # check that with inverse, image affine are back to how they were
     img = tr.inverse(img)
     self.assertEqual(img.applied_operations, [])
     self.assertEqual(img.shape, img_t.shape)
     l2_norm_affine = ((affine - img.affine)**2).sum()**0.5
     self.assertLess(l2_norm_affine, 5e-2)
Example 25
    def get_most_recent_transform(self,
                                  data,
                                  key: Hashable = None,
                                  check: bool = True,
                                  pop: bool = False):
        """
        Get most recent transform for the stack.

        Args:
            data: dictionary of data or `MetaTensor`.
            key: if data is a dictionary, data[key] will be modified.
            check: if true, check that `self` is the same type as the most recently-applied transform.
            pop: if true, remove the transform as it is returned.

        Returns:
            Dictionary describing the most recently applied transform.

        Raises:
            RuntimeError: if transform tracing is disabled.
            ValueError: if `data` is neither `MetaTensor` nor a dictionary.
        """
        if not self.tracing:
            raise RuntimeError(
                "Transform Tracing must be enabled to get the most recent transform."
            )
        if isinstance(data, MetaTensor):
            all_transforms = data.applied_operations
        elif isinstance(data, Mapping):
            if key in data and isinstance(data[key], MetaTensor):
                all_transforms = data[key].applied_operations
            else:
                all_transforms = data.get(
                    self.trace_key(key),
                    MetaTensor.get_default_applied_operations())
        else:
            raise ValueError(
                f"`data` should be either `MetaTensor` or dictionary, got {type(data)}."
            )
        if check:
            self.check_transforms_match(all_transforms[-1])
        return all_transforms.pop() if pop else all_transforms[-1]
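The dictionary branch falls back to `self.trace_key(key)`, the conventional `*_transforms` key under which non-`MetaTensor` data tracks its stack. A one-line sketch of the naming:

 from monai.transforms.inverse import TraceableTransform

 print(TraceableTransform.trace_key("img"))  # 'img_transforms'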
Example 26
 def nifti_rw(self, test_data, reader, writer, dtype, resample=True):
     test_data = test_data.astype(dtype)
     ndim = len(test_data.shape) - 1
     for p in TEST_NDARRAYS:
         output_ext = ".nii.gz"
         filepath = f"testfile_{ndim}d"
         saver = SaveImage(output_dir=self.test_dir,
                           output_ext=output_ext,
                           resample=resample,
                           separate_folder=False,
                           writer=writer)
         meta_dict = {
             "filename_or_obj": f"{filepath}.png",
             "affine": np.eye(4),
             "original_affine": np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
         }
         test_data = MetaTensor(p(test_data), meta=meta_dict)
         self.assertEqual(test_data.meta[MetaKeys.SPACE], "RAS")
         saver(test_data)
         saved_path = os.path.join(self.test_dir,
                                   filepath + "_trans" + output_ext)
         self.assertTrue(os.path.exists(saved_path))
         loader = LoadImage(image_only=True,
                            reader=reader,
                            squeeze_non_spatial_dims=True)
         data = loader(saved_path)
         meta = data.meta
         if meta["original_channel_dim"] == -1:
             _test_data = moveaxis(test_data, 0, -1)
         else:
             _test_data = test_data[0]
         if resample:
             _test_data = moveaxis(_test_data, 0, 1)
         assert_allclose(data, torch.as_tensor(_test_data))
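The `MetaKeys.SPACE` assertion above relies on the constructor defaulting the coordinate space to RAS when an affine is supplied; a minimal sketch of that default:

 import numpy as np
 from monai.data import MetaTensor
 from monai.utils import MetaKeys

 t = MetaTensor(np.zeros((1, 2, 2)), meta={"affine": np.eye(4)})
 print(t.meta[MetaKeys.SPACE])  # RAS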
Example 27
 def nrrd_rw(self, test_data, reader, writer, dtype, resample=True):
     test_data = test_data.astype(dtype)
     ndim = len(test_data.shape)
     for p in TEST_NDARRAYS:
         output_ext = ".nrrd"
         filepath = f"testfile_{ndim}d"
         saver = SaveImage(output_dir=self.test_dir,
                           output_ext=output_ext,
                           resample=resample,
                           separate_folder=False,
                           writer=writer)
         test_data = MetaTensor(
             p(test_data),
             meta={"filename_or_obj": f"{filepath}{output_ext}", "spatial_shape": test_data.shape},
         )
         saver(test_data)
         saved_path = os.path.join(self.test_dir,
                                   filepath + "_trans" + output_ext)
         loader = LoadImage(image_only=True, reader=reader)
         data = loader(saved_path)
         assert_allclose(data, torch.as_tensor(test_data))
Example 28
 def test_numpy(self, device=None, dtype=None):
     """device, dtype"""
     t = MetaTensor([0.0], device=device, dtype=dtype)
     self.assertIsInstance(t, MetaTensor)
     assert_allclose(t.array, np.asarray([0.0]))
     t.array = np.asarray([1.0])
     self.check_meta(t, MetaTensor([1.0]))
     assert_allclose(t.as_tensor(), torch.as_tensor([1.0]))
     t.array = [2.0]
     self.check_meta(t, MetaTensor([2.0]))
     assert_allclose(t.as_tensor(), torch.as_tensor([2.0]))
     if not t.is_cuda:
         t.array[0] = torch.as_tensor(3.0, device=device, dtype=dtype)
         self.check_meta(t, MetaTensor([3.0]))
         assert_allclose(t.as_tensor(), torch.as_tensor([3.0]))
Example 29
    def test_value_3d(
        self,
        keys,
        data,
        expected_convert_result,
        expected_zoom_result,
        expected_zoom_keepsize_result,
        expected_flip_result,
        expected_clip_result,
        expected_rotate_result,
    ):
        test_dtype = [torch.float32]
        for dtype in test_dtype:
            data = CastToTyped(keys=["image", "boxes"], dtype=dtype)(data)
            # test ConvertBoxToStandardModed
            transform_convert_mode = ConvertBoxModed(**keys)
            convert_result = transform_convert_mode(data)
            assert_allclose(convert_result["boxes"],
                            expected_convert_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)

            invert_transform_convert_mode = Invertd(
                keys=["boxes"],
                transform=transform_convert_mode,
                orig_keys=["boxes"])
            data_back = invert_transform_convert_mode(convert_result)
            if "boxes_transforms" in data_back:  # if the transform is tracked in dict:
                self.assertEqual(data_back["boxes_transforms"],
                                 [])  # it should be updated
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            # test ZoomBoxd
            transform_zoom = ZoomBoxd(image_keys="image",
                                      box_keys="boxes",
                                      box_ref_image_keys="image",
                                      zoom=[0.5, 3, 1.5],
                                      keep_size=False)
            zoom_result = transform_zoom(data)
            self.assertEqual(len(zoom_result["image"].applied_operations), 1)
            assert_allclose(zoom_result["boxes"],
                            expected_zoom_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            invert_transform_zoom = Invertd(keys=["image", "boxes"],
                                            transform=transform_zoom,
                                            orig_keys=["image", "boxes"])
            data_back = invert_transform_zoom(zoom_result)
            self.assertEqual(data_back["image"].applied_operations, [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            transform_zoom = ZoomBoxd(image_keys="image",
                                      box_keys="boxes",
                                      box_ref_image_keys="image",
                                      zoom=[0.5, 3, 1.5],
                                      keep_size=True)
            zoom_result = transform_zoom(data)
            self.assertEqual(len(zoom_result["image"].applied_operations), 1)
            assert_allclose(zoom_result["boxes"],
                            expected_zoom_keepsize_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)

            # test RandZoomBoxd
            transform_zoom = RandZoomBoxd(
                image_keys="image",
                box_keys="boxes",
                box_ref_image_keys="image",
                prob=1.0,
                min_zoom=(0.3, ) * 3,
                max_zoom=(3.0, ) * 3,
                keep_size=False,
            )
            zoom_result = transform_zoom(data)
            self.assertEqual(len(zoom_result["image"].applied_operations), 1)
            invert_transform_zoom = Invertd(keys=["image", "boxes"],
                                            transform=transform_zoom,
                                            orig_keys=["image", "boxes"])
            data_back = invert_transform_zoom(zoom_result)
            self.assertEqual(data_back["image"].applied_operations, [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=0.01)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            # test AffineBoxToImageCoordinated, AffineBoxToWorldCoordinated
            transform_affine = AffineBoxToImageCoordinated(
                box_keys="boxes", box_ref_image_keys="image")
            if not isinstance(data["image"], MetaTensor):
                # without a MetaTensor, the image meta dict is undefined, so an exception is expected
                with self.assertRaises(Exception) as context:
                    transform_affine(deepcopy(data))
                self.assertTrue(
                    "Please check whether it is the correct the image meta key."
                    in str(context.exception))

            data["image"] = MetaTensor(
                data["image"],
                meta={
                    "affine": torch.diag(1.0 / torch.Tensor([0.5, 3, 1.5, 1]))
                })
            affine_result = transform_affine(data)
            if "boxes_transforms" in affine_result:
                self.assertEqual(len(affine_result["boxes_transforms"]), 1)
            assert_allclose(affine_result["boxes"],
                            expected_zoom_result,
                            type_test=True,
                            device_test=True,
                            atol=0.01)
            invert_transform_affine = Invertd(keys=["boxes"],
                                              transform=transform_affine,
                                              orig_keys=["boxes"])
            data_back = invert_transform_affine(affine_result)
            if "boxes_transforms" in data_back:
                self.assertEqual(data_back["boxes_transforms"], [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=0.01)
            invert_transform_affine = AffineBoxToWorldCoordinated(
                box_keys="boxes", box_ref_image_keys="image")
            data_back = invert_transform_affine(affine_result)
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=0.01)

            # test FlipBoxd
            transform_flip = FlipBoxd(image_keys="image",
                                      box_keys="boxes",
                                      box_ref_image_keys="image",
                                      spatial_axis=[0, 1, 2])
            flip_result = transform_flip(data)
            if "boxes_transforms" in flip_result:
                self.assertEqual(len(flip_result["boxes_transforms"]), 1)
            assert_allclose(flip_result["boxes"],
                            expected_flip_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            invert_transform_flip = Invertd(keys=["image", "boxes"],
                                            transform=transform_flip,
                                            orig_keys=["image", "boxes"])
            data_back = invert_transform_flip(flip_result)
            if "boxes_transforms" in data_back:
                self.assertEqual(data_back["boxes_transforms"], [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            # test RandFlipBoxd
            for spatial_axis in [(0, ), (1, ), (2, ), (0, 1), (1, 2)]:
                transform_flip = RandFlipBoxd(
                    image_keys="image",
                    box_keys="boxes",
                    box_ref_image_keys="image",
                    prob=1.0,
                    spatial_axis=spatial_axis,
                )
                flip_result = transform_flip(data)
                if "boxes_transforms" in flip_result:
                    self.assertEqual(len(flip_result["boxes_transforms"]), 1)
                invert_transform_flip = Invertd(keys=["image", "boxes"],
                                                transform=transform_flip,
                                                orig_keys=["image", "boxes"])
                data_back = invert_transform_flip(flip_result)
                if "boxes_transforms" in data_back:
                    self.assertEqual(data_back["boxes_transforms"], [])
                assert_allclose(data_back["boxes"],
                                data["boxes"],
                                type_test=False,
                                device_test=False,
                                atol=1e-3)
                assert_allclose(data_back["image"],
                                data["image"],
                                type_test=False,
                                device_test=False,
                                atol=1e-3)

            # test ClipBoxToImaged
            transform_clip = ClipBoxToImaged(box_keys="boxes",
                                             box_ref_image_keys="image",
                                             label_keys=["labels", "scores"],
                                             remove_empty=True)
            clip_result = transform_clip(data)
            assert_allclose(clip_result["boxes"],
                            expected_clip_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            assert_allclose(clip_result["labels"],
                            data["labels"][1:],
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            assert_allclose(clip_result["scores"],
                            data["scores"][1:],
                            type_test=True,
                            device_test=True,
                            atol=1e-3)

            transform_clip = ClipBoxToImaged(
                box_keys="boxes",
                box_ref_image_keys="image",
                label_keys=[],
                remove_empty=True)  # corner case when label_keys is empty
            clip_result = transform_clip(data)
            assert_allclose(clip_result["boxes"],
                            expected_clip_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)

            # test RandCropBoxByPosNegLabeld
            transform_crop = RandCropBoxByPosNegLabeld(
                image_keys="image",
                box_keys="boxes",
                label_keys=["labels", "scores"],
                spatial_size=2,
                num_samples=3)
            crop_result = transform_crop(data)
            assert len(crop_result) == 3
            for ll in range(3):
                assert_allclose(
                    crop_result[ll]["boxes"].shape[0],
                    crop_result[ll]["labels"].shape[0],
                    type_test=True,
                    device_test=True,
                    atol=1e-3,
                )
                assert_allclose(
                    crop_result[ll]["boxes"].shape[0],
                    crop_result[ll]["scores"].shape[0],
                    type_test=True,
                    device_test=True,
                    atol=1e-3,
                )

            # test RotateBox90d
            transform_rotate = RotateBox90d(image_keys="image",
                                            box_keys="boxes",
                                            box_ref_image_keys="image",
                                            k=1,
                                            spatial_axes=[0, 1])
            rotate_result = transform_rotate(data)
            self.assertEqual(len(rotate_result["image"].applied_operations), 1)
            assert_allclose(rotate_result["boxes"],
                            expected_rotate_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            invert_transform_rotate = Invertd(keys=["image", "boxes"],
                                              transform=transform_rotate,
                                              orig_keys=["image", "boxes"])
            data_back = invert_transform_rotate(rotate_result)
            self.assertEqual(data_back["image"].applied_operations, [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            transform_rotate = RandRotateBox90d(image_keys="image",
                                                box_keys="boxes",
                                                box_ref_image_keys="image",
                                                prob=1.0,
                                                max_k=3,
                                                spatial_axes=[0, 1])
            rotate_result = transform_rotate(data)
            self.assertEqual(len(rotate_result["image"].applied_operations), 1)
            invert_transform_rotate = Invertd(keys=["image", "boxes"],
                                              transform=transform_rotate,
                                              orig_keys=["image", "boxes"])
            data_back = invert_transform_rotate(rotate_result)
            self.assertEqual(data_back["image"].applied_operations, [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
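As a small standalone illustration of the first transform under test, `ConvertBoxModed` rewrites box coordinates between modes; a sketch converting corner format (`xyzxyz`) to center-size format (`cccwhd`):

 import torch
 from monai.apps.detection.transforms.dictionary import ConvertBoxModed

 data = {"boxes": torch.tensor([[0.0, 0.0, 0.0, 2.0, 2.0, 2.0]])}
 out = ConvertBoxModed(box_keys="boxes", src_mode="xyzxyz", dst_mode="cccwhd")(data)
 print(out["boxes"])  # centers (1, 1, 1), sizes (2, 2, 2)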
Example 30

import os
import tempfile
import unittest

import torch
from parameterized import parameterized

from monai.data.meta_tensor import MetaTensor
from monai.transforms import SaveImaged

TEST_CASE_1 = [
    {"img": MetaTensor(torch.randint(0, 255, (1, 2, 3, 4)), meta={"filename_or_obj": "testfile0.nii.gz"})},
    ".nii.gz",
    False,
]

TEST_CASE_2 = [
    {
        "img": MetaTensor(torch.randint(0, 255, (1, 2, 3, 4)), meta={"filename_or_obj": "testfile0.nii.gz"}),
        "patch_index": 6,
    },
    ".nii.gz",
    False,
]
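These case lists are typically consumed by `parameterized.expand`, per the import at the top; a sketch of the usual wiring (the class name and the elided body are assumptions, not part of the original file):

 class TestSaveImaged(unittest.TestCase):
     @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
     def test_saved_content(self, data, output_ext, resample):
         ...  # apply SaveImaged and assert the expected file exists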