Пример #1
0
    def test_inverse_inferred_seg(self):
        """Invert a model's segmentation output, both per item and per batch.

        Image/label pairs are channel-expanded, padded and centre-cropped; a small
        UNet is run on one collated batch of labels, and inverting its output
        (first for a single decollated item, then for the whole batch) must
        restore the original label spatial shape.
        """
        test_data = []
        for _ in range(20):
            image, label = create_test_image_2d(100, 101)
            test_data.append({"image": image, "label": label.astype(np.float32)})

        batch_size = 10
        # num workers = 0 for mac
        num_workers = 2 if sys.platform != "darwin" else 0
        transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), CenterSpatialCropd(KEYS, (110, 99))])
        num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform))

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(
            dimensions=2,
            in_channels=1,
            out_channels=1,
            channels=(2, 4),
            strides=(2,),
        ).to(device)

        data = first(loader)
        labels = data["label"].to(device)
        segs = model(labels).detach().cpu()
        label_transform_key = "label" + InverseKeys.KEY_SUFFIX
        segs_dict = {"label": segs, label_transform_key: data[label_transform_key]}

        segs_dict_decollated = decollate_batch(segs_dict)
        # inverse of individual segmentation
        seg_dict = first(segs_dict_decollated)
        # Test converting the interpolation mode for one item of the model output batch.
        # The converter works on the list of transform records for a key (this is what
        # the inverse handler passes as `trans_info`), not on the whole data dict.
        # Passing the dict would leave every transform record untouched.
        convert_inverse_interp_mode(seg_dict[label_transform_key], mode="nearest", align_corners=None)

        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse(seg_dict)["label"]
        # reuse the computed key instead of a hard-coded "label_transforms" literal
        self.assertEqual(len(data[label_transform_key]), num_invertible_transforms)
        self.assertEqual(len(seg_dict[label_transform_key]), num_invertible_transforms)
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)

        # Inverse of batch
        batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation)
        with allow_missing_keys_mode(transforms):
            inv_batch = batch_inverter(segs_dict)
        self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape)
Пример #2
0
    def __call__(self, engine: Engine) -> None:
        """
        Invert each configured model output and store the result in ``engine.state.output``.

        For every ``(output_key, batch_key, nearest_interp)`` triple, the transform
        records saved in the batch under ``batch_key + InverseKeys.KEY_SUFFIX`` are
        used by ``self.inverter`` to invert ``engine.state.output[output_key]``.
        The per-item results, converted with ``self._totensor``, are stored as a
        list under ``f"{output_key}_{self.postfix}"``. Keys without transform
        records trigger a warning and are skipped.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        state = engine.state
        per_key = zip(self.output_keys, self.batch_keys, self.nearest_interp)
        for out_key, in_key, use_nearest in per_key:
            trans_key = in_key + InverseKeys.KEY_SUFFIX
            if trans_key not in state.batch:
                # no transform records were kept for this key, so it cannot be inverted
                warnings.warn(f"all the transforms on `{in_key}` are not InvertibleTransform.")
                continue

            trans_info = state.batch[trans_key]
            if use_nearest:
                # convert a deep copy so the records stored in the batch stay untouched
                trans_info = convert_inverse_interp_mode(
                    trans_info=deepcopy(trans_info),
                    mode="nearest",
                    align_corners=None,
                )

            to_invert = {
                in_key: state.output[out_key].detach().cpu(),
                trans_key: trans_info,
            }
            meta_key = f"{in_key}_{self.meta_key_postfix}"
            if meta_key in state.batch:
                to_invert[meta_key] = state.batch[meta_key]

            target_key = f"{out_key}_{self.postfix}"
            # allow missing keys on self.transform while the inverter runs
            with allow_missing_keys_mode(self.transform):  # type: ignore
                state.output[target_key] = [self._totensor(item[in_key]) for item in self.inverter(to_invert)]
Пример #3
0
    def test_inverse_inferred_seg(self, extra_transform):
        """Invert a model prediction whose pipeline ends with ``extra_transform``.

        ``extra_transform`` is appended after channel-adding and padding. Its
        record (the last applied operation) is undone by hand with
        ``ResizeWithPadOrCrop``. The remaining operations are inverted through
        ``transforms.inverse``, and the result must match the original label's
        spatial shape.
        """
        test_data = [
            {"image": image, "label": label.astype(np.float32)}
            for image, label in (create_test_image_2d(100, 101) for _ in range(20))
        ]

        batch_size = 10
        # worker processes only on Linux; everywhere else load in the main process
        num_workers = 0 if sys.platform != "linux" else 2
        transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform])

        loader = DataLoader(
            CacheDataset(test_data, transform=transforms, progress=False),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(1,)).to(device)

        data = first(loader)
        # each dataset item is expected to contribute NUM_SAMPLES samples to the batch
        self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)

        labels = data["label"].to(device)
        self.assertIsInstance(labels, MetaTensor)
        segs_decollated = decollate_batch(model(labels).detach().cpu())
        self.assertIsInstance(segs_decollated[0], MetaTensor)

        # take one decollated prediction and switch the interpolation settings
        # recorded in its applied operations to nearest-neighbour
        seg = first(segs_decollated)
        convert_applied_interp_mode(seg.applied_operations, mode="nearest", align_corners=None)

        # undo the final (extra) transform by hand: pop its record and resize back
        # to the size it recorded, with transform tracing disabled for this resize
        last_op = seg.applied_operations.pop(-1)
        resizer = ResizeWithPadOrCrop(spatial_size=last_op["orig_size"])
        with resizer.trace_transform(False):
            seg = resizer(seg)

        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse({"label": seg})["label"]
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
Пример #4
0
    def __call__(self, engine: Engine) -> None:
        """
        Invert each configured output and store it, with its meta dicts, on the engine state.

        For every configured key, the transform records saved in the batch are used by
        ``self.inverter`` to invert ``engine.state.output[output_key]``. Each inverted
        item is optionally converted to a tensor on ``device``, then passed through
        ``post_func``. The resulting list is stored in ``engine.state.output`` under
        ``f"{output_key}_{self.postfix}"``. When the batch carries a meta dict for the
        key, the inverted meta dicts are also stored in ``engine.state.batch``. Keys
        without transform records trigger a warning and are skipped.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        per_key_settings = zip(
            self.output_keys,
            self.batch_keys,
            self.nearest_interp,
            self.to_tensor,
            self.device,
            self.post_func,
        )
        for out_key, in_key, use_nearest, as_tensor, target_device, post in per_key_settings:
            trans_key = in_key + InverseKeys.KEY_SUFFIX
            if trans_key not in engine.state.batch:
                # no transform records were kept for this key, so it cannot be inverted
                warnings.warn(f"all the transforms on `{in_key}` are not InvertibleTransform.")
                continue

            trans_info = engine.state.batch[trans_key]
            if use_nearest:
                # convert a deep copy so the records stored in the batch stay untouched
                trans_info = convert_inverse_interp_mode(
                    trans_info=deepcopy(trans_info),
                    mode="nearest",
                    align_corners=None,
                )

            pred = engine.state.output[out_key]
            # detach tensors; non-tensor outputs are forwarded unchanged
            to_invert = {
                in_key: pred.detach() if isinstance(pred, torch.Tensor) else pred,
                trans_key: trans_info,
            }
            meta_key = f"{in_key}_{self.meta_key_postfix}"
            if meta_key in engine.state.batch:
                to_invert[meta_key] = engine.state.batch[meta_key]

            with allow_missing_keys_mode(self.transform):  # type: ignore
                inverted = self.inverter(to_invert)

            # store the inverted data in state.output
            inverted_key = f"{out_key}_{self.postfix}"
            results = []
            for item in inverted:
                value = item[in_key]
                if as_tensor:
                    value = self._totensor(value).to(target_device)
                results.append(post(value))
            engine.state.output[inverted_key] = results

            # store the inverted meta dicts in state.batch
            if meta_key in engine.state.batch:
                engine.state.batch[f"{inverted_key}_{self.meta_key_postfix}"] = [
                    item.get(meta_key) for item in inverted
                ]
Пример #5
0
 def test_multiple(self):
     """``allow_missing_keys_mode`` on a list of transforms turns each flag on, then restores it."""
     expected = [True, False]
     pads = [SpatialPadd(["image", "label"], 10, allow_missing_keys=flag) for flag in expected]
     with allow_missing_keys_mode(pads):
         for pad in pads:
             self.assertTrue(pad.allow_missing_keys)
             # applying the transform must succeed even though a key is missing
             pad(self.data)
     # on exit every transform is back to its own original setting
     for pad, flag in zip(pads, expected):
         self.assertEqual(pad.allow_missing_keys, flag)
Пример #6
0
 def test_map_transform(self):
     """A single dict transform has missing keys allowed inside the context and restored after it."""
     for original_flag in (True, False):
         pad = SpatialPadd(["image", "label"], 10, allow_missing_keys=original_flag)
         with allow_missing_keys_mode(pad):
             # inside the context the flag is forced on ...
             self.assertTrue(pad.allow_missing_keys)
             # ... so the call succeeds even though a key is missing
             pad(self.data)
         # leaving the context restores the constructor's setting
         self.assertEqual(pad.allow_missing_keys, original_flag)
         if original_flag:
             continue
         # without the override, the missing key must raise
         with self.assertRaises(KeyError):
             pad(self.data)
Пример #7
0
 def test_compose(self):
     """``allow_missing_keys_mode`` on a ``Compose`` toggles every wrapped transform."""
     flags = [True, False, True]
     pipeline = Compose([SpatialPadd(["image", "label"], 10, allow_missing_keys=flag) for flag in flags])
     with allow_missing_keys_mode(pipeline):
         # every wrapped transform is switched on ...
         for inner in pipeline.transforms:
             self.assertTrue(inner.allow_missing_keys)
         # ... so the whole pipeline runs even though a key is missing
         pipeline(self.data)
     # every wrapped transform regains its original flag
     for inner, flag in zip(pipeline.transforms, flags):
         self.assertEqual(inner.allow_missing_keys, flag)
     # at least one flag is False, so the missing key now makes the pipeline fail
     with self.assertRaises((KeyError, RuntimeError)):
         pipeline(self.data)
Пример #8
0
 def test_array_transform(self):
     """Array transforms (``SpatialPad``), bare or inside ``Compose``, make the mode raise ``TypeError``."""
     candidates = (SpatialPad(10), Compose([SpatialPad(10)]))
     for candidate in candidates:
         # the second manager is created and entered inside the first,
         # so assertRaises sees any TypeError it raises
         with self.assertRaises(TypeError), allow_missing_keys_mode(candidate):
             pass