Example #1
 def test_correct(self):
     with tempfile.TemporaryDirectory() as temp_dir:
         transforms = Compose([
             LoadImaged(("im1", "im2")),
             EnsureChannelFirstd(("im1", "im2")),
             CopyItemsd(("im2", "im2_meta_dict"),
                        names=("im3", "im3_meta_dict")),
             ResampleToMatchd("im3", "im1_meta_dict"),
             Lambda(update_fname),
             SaveImaged("im3",
                        output_dir=temp_dir,
                        output_postfix="",
                        separate_folder=False),
         ])
         data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
         # check that output sizes match
         assert_allclose(data["im1"].shape, data["im3"].shape)
         # and that the meta data has been updated accordingly
         assert_allclose(data["im3"].shape[1:],
                         data["im3_meta_dict"]["spatial_shape"],
                         type_test=False)
         assert_allclose(data["im3_meta_dict"]["affine"],
                         data["im1_meta_dict"]["affine"])
         # check we're different from the original
         self.assertTrue(
             any(i != j
                 for i, j in zip(data["im3"].shape, data["im2"].shape)))
         self.assertTrue(
             any(i != j for i, j in zip(
                 data["im3_meta_dict"]["affine"].flatten(),
                 data["im2_meta_dict"]["affine"].flatten())))
         # test the inverse
         data = Invertd("im3", transforms, "im3")(data)
         assert_allclose(data["im2"].shape, data["im3"].shape)
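
Both this test and Example #2 apply `Lambda(update_fname)` without showing the helper's definition. A minimal sketch of such a helper, assuming the meta-dict convention used above (the target file name is illustrative):

def update_fname(d):
    # rename the resampled copy so SaveImaged writes it under a new file name
    d["im3_meta_dict"]["filename_or_obj"] = "file3.nii.gz"
    return d
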
Example #2
 def test_correct(self):
     transforms = Compose([
         LoadImaged(("im1", "im2")),
         EnsureChannelFirstd(("im1", "im2")),
         CopyItemsd(("im2"), names=("im3")),
         ResampleToMatchd("im3", "im1"),
         Lambda(update_fname),
         SaveImaged("im3",
                    output_dir=self.tmpdir,
                    output_postfix="",
                    separate_folder=False,
                    resample=False),
     ])
     data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
     # check that output sizes match
     assert_allclose(data["im1"].shape, data["im3"].shape)
     # and that the meta data has been updated accordingly
     assert_allclose(data["im3"].affine, data["im1"].affine)
     # check we're different from the original
     self.assertTrue(
         any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape)))
     self.assertTrue(
         any(i != j for i, j in zip(data["im3"].affine.flatten(),
                                    data["im2"].affine.flatten())))
     # test the inverse
     data = Invertd("im3", transforms)(data)
     assert_allclose(data["im2"].shape, data["im3"].shape)
Example #3
    def test_invert(self):
        set_determinism(seed=0)
        im_fname, seg_fname = (
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100))
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True,
                        dtype=np.float64),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # test EnsureTyped for complicated dict data and invert it
            CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"),
            # test to support Tensor, Numpy array and dictionary when inverting
            EnsureTyped(keys=["image", "test_dict"]),
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
            CopyItemsd("label",
                       times=2,
                       names=["label_inverted", "label_inverted1"]),
            CopyItemsd("image",
                       times=2,
                       names=["image_inverted", "image_inverted1"]),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num_workers = 0 on non-linux platforms or for gpu transforms
        num_workers = 0 if (sys.platform != "linux"
                            or torch.cuda.is_available()) else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
        inverter = Invertd(
            # `image` was not copied, invert the original value directly
            keys=["image_inverted", "label_inverted", "test_dict"],
            transform=transform,
            orig_keys=["label", "label", "test_dict"],
            meta_keys=[
                PostFix.meta("image_inverted"),
                PostFix.meta("label_inverted"), None
            ],
            orig_meta_keys=[
                PostFix.meta("label"),
                PostFix.meta("label"), None
            ],
            nearest_interp=True,
            to_tensor=[True, False, False],
            device="cpu",
        )

        inverter_1 = Invertd(
            # `image` was not copied, invert the original value directly
            keys=["image_inverted1", "label_inverted1"],
            transform=transform,
            orig_keys=["image", "image"],
            meta_keys=[
                PostFix.meta("image_inverted1"),
                PostFix.meta("label_inverted1")
            ],
            orig_meta_keys=[PostFix.meta("image"),
                            PostFix.meta("image")],
            nearest_interp=[True, False],
            to_tensor=[True, True],
            device="cpu",
        )

        expected_keys = [
            "image",
            "image_inverted",
            "image_inverted1",
            PostFix.meta("image_inverted1"),
            PostFix.meta("image_inverted"),
            PostFix.meta("image"),
            "image_transforms",
            "label",
            "label_inverted",
            "label_inverted1",
            PostFix.meta("label_inverted1"),
            PostFix.meta("label_inverted"),
            PostFix.meta("label"),
            "label_transforms",
            "test_dict",
            "test_dict_transforms",
        ]
        # execute 1 epoch
        for d in loader:
            d = decollate_batch(d)
            for item in d:
                item = inverter(item)
                item = inverter_1(item)

                self.assertListEqual(sorted(item), expected_keys)
                self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
                self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
                # check the nearest interpolation mode
                i = item["image_inverted"]
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
                i = item["label_inverted"]
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
                # test inverted test_dict
                self.assertTrue(
                    isinstance(item["test_dict"]["affine"], np.ndarray))
                self.assertTrue(
                    isinstance(item["test_dict"]["filename_or_obj"], str))

                # check the case where different items use different interpolation modes to invert transforms
                d = item["image_inverted1"]
                # if the interpolation mode is nearest, accumulated diff should be smaller than 1
                self.assertLess(
                    torch.sum(
                        d.to(torch.float) -
                        d.to(torch.uint8).to(torch.float)).item(), 1.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

                d = item["label_inverted1"]
                # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
                self.assertGreater(
                    torch.sum(
                        d.to(torch.float) -
                        d.to(torch.uint8).to(torch.float)).item(), 10000.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

        # check labels match
        reverted = item["label_inverted"].detach().cpu().numpy().astype(
            np.int32)
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        # 34007: 2 workers (cpu, non-macos)
        # 1812: 0 workers (gpu or macos)
        # 1821: windows torch 1.10.0
        self.assertTrue((reverted.size - n_good) in (34007, 1812, 1821),
                        f"diff.  {reverted.size - n_good}")

        set_determinism(seed=None)
Example #4
    def __init__(
        self,
        transform: InvertibleTransform,
        loader: TorchDataLoader,
        output_keys: Union[str, Sequence[str]] = CommonKeys.PRED,
        batch_keys: Union[str, Sequence[str]] = CommonKeys.IMAGE,
        meta_key_postfix: str = "meta_dict",
        collate_fn: Optional[Callable] = no_collation,
        postfix: str = "inverted",
        nearest_interp: Union[bool, Sequence[bool]] = True,
        to_tensor: Union[bool, Sequence[bool]] = True,
        device: Union[Union[str, torch.device],
                      Sequence[Union[str, torch.device]]] = "cpu",
        post_func: Union[Callable, Sequence[Callable]] = lambda x: x,
        num_workers: Optional[int] = 0,
    ) -> None:
        """
        Args:
            transform: a callable data transform on input data.
            loader: data loader used to run transforms and generate the batch of data.
            output_keys: the key of expected data in `ignite.engine.output`, invert transforms on it.
                it also can be a list of keys, will invert transform for each of them. Default to "pred".
            batch_keys: the key of input data in `ignite.engine.batch`. will get the applied transforms
                for this input data, then invert them for the expected data with `output_keys`.
                It can also be a list of keys, each matches to the `output_keys` data. default to "image".
            meta_key_postfix: use `{batch_key}_{meta_key_postfix}` to fetch the meta data according to the key data,
                default is `meta_dict`, the meta data is a dictionary object.
                For example, to handle key `image`, read/write affine matrices from the
                metadata `image_meta_dict` dictionary's `affine` field.
            collate_fn: how to collate data after inverse transformations.
                default won't do any collation, so the output will be a list of size batch size.
            postfix: will save the inverted result into `ignite.engine.output` with key `{output_key}_{postfix}`.
            nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
                default to `True`. If `False`, use the same interpolation mode as the original transform.
                it also can be a list of bool, each matches to the `output_keys` data.
            to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.
                it also can be a list of bool, each matches to the `output_keys` data.
            device: if converted to Tensor, move the inverted results to target device before `post_func`,
                default to "cpu", it also can be a list of string or `torch.device`,
                each matches to the `output_keys` data.
            post_func: post processing for the inverted data, should be a callable function.
                it also can be a list of callable, each matches to the `output_keys` data.
            num_workers: number of workers when running the data loader for inverse transforms,
                default to 0, as it only runs one iteration and multi-processing may be even slower.
                Set to `None` to use the `num_workers` of the input transform data loader.

        """
        self.inverter = Invertd(
            keys=output_keys,
            transform=transform,
            loader=loader,
            orig_keys=batch_keys,
            meta_key_postfix=meta_key_postfix,
            collate_fn=collate_fn,
            postfix=postfix,
            nearest_interp=nearest_interp,
            to_tensor=to_tensor,
            device=device,
            post_func=post_func,
            num_workers=num_workers,
        )
        self.output_keys = ensure_tuple(output_keys)
        self.meta_key_postfix = meta_key_postfix
        self.postfix = postfix
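
The `__init__` above is a thin wrapper: it simply delegates to `Invertd`, so the same inversion can be run without the handler. A minimal sketch under that assumption (the transform, loader, and key names are illustrative):

inverter = Invertd(
    keys="pred",                  # corresponds to output_keys
    transform=val_transforms,     # the Compose that was applied to the inputs
    loader=val_loader,
    orig_keys="image",            # corresponds to batch_keys
    postfix="inverted",           # result is stored under "pred_inverted"
    nearest_interp=True,
)
output = inverter(engine_output)  # a batch dict, e.g. `ignite.engine.output`
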
Example #5
    def test_invert(self):
        set_determinism(seed=0)
        im_fname, seg_fname = [
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100)
        ]
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # test to support both Tensor and Numpy array when inverting
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num workers = 0 for mac or gpu transforms
        num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available(
        ) else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
        inverter = Invertd(
            keys=["image", "label"],
            transform=transform,
            loader=loader,
            orig_keys="label",
            nearest_interp=True,
            postfix="inverted",
            to_tensor=[True, False],
            device="cpu",
            num_workers=0
            if sys.platform == "darwin" or torch.cuda.is_available() else 2,
        )

        # execute 1 epoch
        for d in loader:
            d = inverter(d)
            # this unit test only covers the basic function; test_handler_transform_inverter covers more
            self.assertTupleEqual(d["image"].shape[1:], (1, 100, 100, 100))
            self.assertTupleEqual(d["label"].shape[1:], (1, 100, 100, 100))
            # check the nearest interpolation mode
            for i in d["image_inverted"]:
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape, (1, 100, 101, 107))
            for i in d["label_inverted"]:
                np.testing.assert_allclose(
                    i.astype(np.uint8).astype(np.float32),
                    i.astype(np.float32))
                self.assertTupleEqual(i.shape, (1, 100, 101, 107))

        set_determinism(seed=None)
Example #6
    def test_value_3d(
        self,
        keys,
        data,
        expected_convert_result,
        expected_zoom_result,
        expected_zoom_keepsize_result,
        expected_flip_result,
        expected_clip_result,
        expected_rotate_result,
    ):
        test_dtype = [torch.float32]
        for dtype in test_dtype:
            data = CastToTyped(keys=["image", "boxes"], dtype=dtype)(data)
            # test ConvertBoxToStandardModed
            transform_convert_mode = ConvertBoxModed(**keys)
            convert_result = transform_convert_mode(data)
            assert_allclose(convert_result["boxes"],
                            expected_convert_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)

            invert_transform_convert_mode = Invertd(
                keys=["boxes"],
                transform=transform_convert_mode,
                orig_keys=["boxes"])
            data_back = invert_transform_convert_mode(convert_result)
            if "boxes_transforms" in data_back:  # if the transform is tracked in dict:
                self.assertEqual(data_back["boxes_transforms"],
                                 [])  # it should be updated
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            # test ZoomBoxd
            transform_zoom = ZoomBoxd(image_keys="image",
                                      box_keys="boxes",
                                      box_ref_image_keys="image",
                                      zoom=[0.5, 3, 1.5],
                                      keep_size=False)
            zoom_result = transform_zoom(data)
            self.assertEqual(len(zoom_result["image"].applied_operations), 1)
            assert_allclose(zoom_result["boxes"],
                            expected_zoom_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            invert_transform_zoom = Invertd(keys=["image", "boxes"],
                                            transform=transform_zoom,
                                            orig_keys=["image", "boxes"])
            data_back = invert_transform_zoom(zoom_result)
            self.assertEqual(data_back["image"].applied_operations, [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            transform_zoom = ZoomBoxd(image_keys="image",
                                      box_keys="boxes",
                                      box_ref_image_keys="image",
                                      zoom=[0.5, 3, 1.5],
                                      keep_size=True)
            zoom_result = transform_zoom(data)
            self.assertEqual(len(zoom_result["image"].applied_operations), 1)
            assert_allclose(zoom_result["boxes"],
                            expected_zoom_keepsize_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)

            # test RandZoomBoxd
            transform_zoom = RandZoomBoxd(
                image_keys="image",
                box_keys="boxes",
                box_ref_image_keys="image",
                prob=1.0,
                min_zoom=(0.3, ) * 3,
                max_zoom=(3.0, ) * 3,
                keep_size=False,
            )
            zoom_result = transform_zoom(data)
            self.assertEqual(len(zoom_result["image"].applied_operations), 1)
            invert_transform_zoom = Invertd(keys=["image", "boxes"],
                                            transform=transform_zoom,
                                            orig_keys=["image", "boxes"])
            data_back = invert_transform_zoom(zoom_result)
            self.assertEqual(data_back["image"].applied_operations, [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=0.01)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            # test AffineBoxToImageCoordinated, AffineBoxToWorldCoordinated
            transform_affine = AffineBoxToImageCoordinated(
                box_keys="boxes", box_ref_image_keys="image")
            # without a MetaTensor, the image meta dict is undefined, so an
            # exception is expected
            if not isinstance(data["image"], MetaTensor):
                with self.assertRaises(Exception) as context:
                    transform_affine(deepcopy(data))
                self.assertTrue(
                    "Please check whether it is the correct the image meta key."
                    in str(context.exception))

            data["image"] = MetaTensor(
                data["image"],
                meta={
                    "affine": torch.diag(1.0 / torch.Tensor([0.5, 3, 1.5, 1]))
                })
            affine_result = transform_affine(data)
            if "boxes_transforms" in affine_result:
                self.assertEqual(len(affine_result["boxes_transforms"]), 1)
            assert_allclose(affine_result["boxes"],
                            expected_zoom_result,
                            type_test=True,
                            device_test=True,
                            atol=0.01)
            invert_transform_affine = Invertd(keys=["boxes"],
                                              transform=transform_affine,
                                              orig_keys=["boxes"])
            data_back = invert_transform_affine(affine_result)
            if "boxes_transforms" in data_back:
                self.assertEqual(data_back["boxes_transforms"], [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=0.01)
            invert_transform_affine = AffineBoxToWorldCoordinated(
                box_keys="boxes", box_ref_image_keys="image")
            data_back = invert_transform_affine(affine_result)
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=0.01)

            # test FlipBoxd
            transform_flip = FlipBoxd(image_keys="image",
                                      box_keys="boxes",
                                      box_ref_image_keys="image",
                                      spatial_axis=[0, 1, 2])
            flip_result = transform_flip(data)
            if "boxes_transforms" in flip_result:
                self.assertEqual(len(flip_result["boxes_transforms"]), 1)
            assert_allclose(flip_result["boxes"],
                            expected_flip_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            invert_transform_flip = Invertd(keys=["image", "boxes"],
                                            transform=transform_flip,
                                            orig_keys=["image", "boxes"])
            data_back = invert_transform_flip(flip_result)
            if "boxes_transforms" in data_back:
                self.assertEqual(data_back["boxes_transforms"], [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            # test RandFlipBoxd
            for spatial_axis in [(0, ), (1, ), (2, ), (0, 1), (1, 2)]:
                transform_flip = RandFlipBoxd(
                    image_keys="image",
                    box_keys="boxes",
                    box_ref_image_keys="image",
                    prob=1.0,
                    spatial_axis=spatial_axis,
                )
                flip_result = transform_flip(data)
                if "boxes_transforms" in flip_result:
                    self.assertEqual(len(flip_result["boxes_transforms"]), 1)
                invert_transform_flip = Invertd(keys=["image", "boxes"],
                                                transform=transform_flip,
                                                orig_keys=["image", "boxes"])
                data_back = invert_transform_flip(flip_result)
                if "boxes_transforms" in data_back:
                    self.assertEqual(data_back["boxes_transforms"], [])
                assert_allclose(data_back["boxes"],
                                data["boxes"],
                                type_test=False,
                                device_test=False,
                                atol=1e-3)
                assert_allclose(data_back["image"],
                                data["image"],
                                type_test=False,
                                device_test=False,
                                atol=1e-3)

            # test ClipBoxToImaged
            transform_clip = ClipBoxToImaged(box_keys="boxes",
                                             box_ref_image_keys="image",
                                             label_keys=["labels", "scores"],
                                             remove_empty=True)
            clip_result = transform_clip(data)
            assert_allclose(clip_result["boxes"],
                            expected_clip_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            assert_allclose(clip_result["labels"],
                            data["labels"][1:],
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            assert_allclose(clip_result["scores"],
                            data["scores"][1:],
                            type_test=True,
                            device_test=True,
                            atol=1e-3)

            transform_clip = ClipBoxToImaged(
                box_keys="boxes",
                box_ref_image_keys="image",
                label_keys=[],
                remove_empty=True)  # corner case when label_keys is empty
            clip_result = transform_clip(data)
            assert_allclose(clip_result["boxes"],
                            expected_clip_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)

            # test RandCropBoxByPosNegLabeld
            transform_crop = RandCropBoxByPosNegLabeld(
                image_keys="image",
                box_keys="boxes",
                label_keys=["labels", "scores"],
                spatial_size=2,
                num_samples=3)
            crop_result = transform_crop(data)
            assert len(crop_result) == 3
            for ll in range(3):
                assert_allclose(
                    crop_result[ll]["boxes"].shape[0],
                    crop_result[ll]["labels"].shape[0],
                    type_test=True,
                    device_test=True,
                    atol=1e-3,
                )
                assert_allclose(
                    crop_result[ll]["boxes"].shape[0],
                    crop_result[ll]["scores"].shape[0],
                    type_test=True,
                    device_test=True,
                    atol=1e-3,
                )

            # test RotateBox90d
            transform_rotate = RotateBox90d(image_keys="image",
                                            box_keys="boxes",
                                            box_ref_image_keys="image",
                                            k=1,
                                            spatial_axes=[0, 1])
            rotate_result = transform_rotate(data)
            self.assertEqual(len(rotate_result["image"].applied_operations), 1)
            assert_allclose(rotate_result["boxes"],
                            expected_rotate_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            invert_transform_rotate = Invertd(keys=["image", "boxes"],
                                              transform=transform_rotate,
                                              orig_keys=["image", "boxes"])
            data_back = invert_transform_rotate(rotate_result)
            self.assertEqual(data_back["image"].applied_operations, [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            transform_rotate = RandRotateBox90d(image_keys="image",
                                                box_keys="boxes",
                                                box_ref_image_keys="image",
                                                prob=1.0,
                                                max_k=3,
                                                spatial_axes=[0, 1])
            rotate_result = transform_rotate(data)
            self.assertEqual(len(rotate_result["image"].applied_operations), 1)
            invert_transform_rotate = Invertd(keys=["image", "boxes"],
                                              transform=transform_rotate,
                                              orig_keys=["image", "boxes"])
            data_back = invert_transform_rotate(rotate_result)
            self.assertEqual(data_back["image"].applied_operations, [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
Example #7
def main(tempdir):
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128,
                                     128,
                                     128,
                                     num_seg_classes=1,
                                     channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose([
        LoadImaged(keys="img"),
        EnsureChannelFirstd(keys="img"),
        Orientationd(keys="img", axcodes="RAS"),
        Resized(keys="img",
                spatial_size=(96, 96, 96),
                mode="trilinear",
                align_corners=True),
        ScaleIntensityd(keys="img"),
        EnsureTyped(keys="img"),
    ])
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
    # define post transforms
    post_transforms = Compose([
        EnsureTyped(keys="pred"),
        Activationsd(keys="pred", sigmoid=True),
        Invertd(
            # invert the `pred` data field; multiple fields are also supported
            keys="pred",
            transform=pre_transforms,
            # get the previously applied pre_transforms information from the
            # `img` data field, then invert `pred` based on it. the same info
            # can be reused for multiple fields, and different fields may use
            # different orig_keys
            orig_keys="img",
            # key field to save the inverted meta data, every item maps to `keys`
            meta_keys="pred_meta_dict",
            # get the meta data from the `img_meta_dict` field when inverting,
            # e.g. the `affine` may be needed to invert the `Spacingd` transform;
            # multiple fields can use the same meta data to invert
            orig_meta_keys="img_meta_dict",
            # if `meta_keys=None`, use "{keys}_{meta_key_postfix}" as the meta
            # key; if `orig_meta_keys=None`, use "{orig_keys}_{meta_key_postfix}";
            # otherwise this arg is not needed during inverting
            meta_key_postfix="meta_dict",
            # don't change the interpolation mode to "nearest" when inverting, to
            # ensure a smooth output for the following `AsDiscreted` transform
            nearest_interp=False,
            to_tensor=True,  # convert to PyTorch Tensor after inverting
        ),
        AsDiscreted(keys="pred", threshold=0.5),
        SaveImaged(keys="pred",
                   meta_keys="pred_meta_dict",
                   output_dir="./out",
                   output_postfix="seg",
                   resample=False),
    ])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    net.load_state_dict(
        torch.load("best_metric_model_segmentation3d_dict.pth"))

    net.eval()
    with torch.no_grad():
        for d in dataloader:
            images = d["img"].to(device)
            # define window size and batch size for sliding window inference
            d["pred"] = sliding_window_inference(inputs=images,
                                                 roi_size=(96, 96, 96),
                                                 sw_batch_size=4,
                                                 predictor=net)
            # decollate the batch data into a list of dictionaries, then execute postprocessing transforms
            d = [post_transforms(i) for i in decollate_batch(d)]
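
Scripts like this are usually driven by a small entry point that supplies the working directory; a minimal sketch (assuming `import tempfile` at the top of the script):

if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)
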
Example #8
    def __init__(
        self,
        transform: InvertibleTransform,
        loader: TorchDataLoader,
        output_keys: KeysCollection = CommonKeys.PRED,
        batch_keys: KeysCollection = CommonKeys.IMAGE,
        meta_keys: Optional[KeysCollection] = None,
        batch_meta_keys: Optional[KeysCollection] = None,
        meta_key_postfix: str = "meta_dict",
        collate_fn: Optional[Callable] = no_collation,
        nearest_interp: Union[bool, Sequence[bool]] = True,
        to_tensor: Union[bool, Sequence[bool]] = True,
        device: Union[Union[str, torch.device],
                      Sequence[Union[str, torch.device]]] = "cpu",
        post_func: Union[Callable, Sequence[Callable]] = lambda x: x,
        num_workers: Optional[int] = 0,
    ) -> None:
        """
        Args:
            transform: a callable data transform on input data.
            loader: data loader used to run transforms and generate the batch of data.
            output_keys: the key of expected data in `ignite.engine.output`, invert transforms on it.
                it also can be a list of keys, will invert transform for each of them.
                Default to "pred". it's in-place operation.
            batch_keys: the key of input data in `ignite.engine.batch`. will get the applied transforms
                for this input data, then invert them for the expected data with `output_keys`.
                It can also be a list of keys, each matches to the `output_keys` data. default to "image".
            meta_keys: explicitly indicate the key for the inverted meta data dictionary.
                the meta data is a dictionary object which contains: filename, original_shape, etc.
                it can be a sequence of string, map to the `keys`.
                if None, will try to construct meta_keys by `{key}_{meta_key_postfix}`.
            batch_meta_keys: the key of the meta data of input data in `ignite.engine.batch`,
                will get the `affine`, `data_shape`, etc.
                the meta data is a dictionary object which contains: filename, original_shape, etc.
                it can be a sequence of string, map to the `keys`.
                if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`.
                meta data will also be inverted and stored in `meta_keys`.
            meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to fetch the
                meta data from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`.
                default is `meta_dict`, the meta data is a dictionary object.
                For example, to handle orig_key `image`, read/write `affine` matrices from the
                metadata `image_meta_dict` dictionary's `affine` field.
                the inverted meta dict will be stored with key: "{key}_{meta_key_postfix}".
            collate_fn: how to collate data after inverse transformations. default won't do any collation,
                so the output will be a list of PyTorch Tensor or numpy array without batch dim.
            nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
                default to `True`. If `False`, use the same interpolation mode as the original transform.
                it also can be a list of bool, each matches to the `output_keys` data.
            to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.
                it also can be a list of bool, each matches to the `output_keys` data.
            device: if converted to Tensor, move the inverted results to target device before `post_func`,
                default to "cpu", it also can be a list of string or `torch.device`,
                each matches to the `output_keys` data.
            post_func: post processing for the inverted data, should be a callable function.
                it also can be a list of callable, each matches to the `output_keys` data.
            num_workers: number of workers when running the data loader for inverse transforms,
                default to 0, as it only runs one iteration and multi-processing may be even slower.
                Set to `None` to use the `num_workers` of the input transform data loader.

        """
        self.inverter = Invertd(
            keys=output_keys,
            transform=transform,
            loader=loader,
            orig_keys=batch_keys,
            meta_keys=meta_keys,
            orig_meta_keys=batch_meta_keys,
            meta_key_postfix=meta_key_postfix,
            collate_fn=collate_fn,
            nearest_interp=nearest_interp,
            to_tensor=to_tensor,
            device=device,
            post_func=post_func,
            num_workers=num_workers,
        )
        self.output_keys = ensure_tuple(output_keys)
        self.meta_keys = (ensure_tuple_rep(None, len(self.output_keys))
                          if meta_keys is None else ensure_tuple(meta_keys))
        if len(self.output_keys) != len(self.meta_keys):
            raise ValueError(
                "meta_keys should have the same length as output_keys.")
        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix,
                                                 len(self.output_keys))
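
A small sketch of the default meta-key convention described in the docstring: when `meta_keys` is None, each output key resolves to `{key}_{meta_key_postfix}` (the keys here are illustrative):

output_keys = ("pred", "seg")
meta_key_postfix = "meta_dict"
meta_keys = tuple(f"{k}_{meta_key_postfix}" for k in output_keys)
# -> ("pred_meta_dict", "seg_meta_dict")
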
Example #9
def main(tempdir):
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128,
                                     128,
                                     128,
                                     num_seg_classes=1,
                                     channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose([
        LoadImaged(keys="img"),
        EnsureChannelFirstd(keys="img"),
        Orientationd(keys="img", axcodes="RAS"),
        Resized(keys="img",
                spatial_size=(96, 96, 96),
                mode="trilinear",
                align_corners=True),
        ScaleIntensityd(keys="img"),
        ToTensord(keys="img"),
    ])
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
    # define post transforms
    post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        Invertd(keys="pred",
                transform=pre_transforms,
                loader=dataloader,
                orig_keys="img",
                nearest_interp=True),
        SaveImaged(keys="pred_inverted",
                   output_dir="./output",
                   output_postfix="seg",
                   resample=False),
    ])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    net.load_state_dict(
        torch.load("best_metric_model_segmentation3d_dict.pth"))

    net.eval()
    with torch.no_grad():
        for d in dataloader:
            images = d["img"].to(device)
            # define window size and batch size for sliding window inference
            d["pred"] = sliding_window_inference(inputs=images,
                                                 roi_size=(96, 96, 96),
                                                 sw_batch_size=4,
                                                 predictor=net)
            # execute post transforms to invert spatial transforms and save to NIfTI files
            post_transforms(d)