Exemplo n.º 1
0
 def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d,
                   acceptable_diff):
     """Assert that a forward+inverse round trip stays close to the original.

     For every key whose original value is a NumPy array, the mean absolute
     difference between the original and the round-tripped value must not
     exceed ``acceptable_diff``. On failure a diagnostic line is printed
     (including the difference to the unmodified data after it has been
     padded/cropped to the original spatial shape) and the assertion error
     is re-raised. Keys whose original value is not a NumPy array are only
     looked up, not compared.

     Args:
         name: label used in the failure message.
         keys: keys to look up in the three dictionaries.
         orig_d: mapping holding the original data.
         fwd_bck_d: mapping holding the data after forward and inverse.
         unmodified_d: mapping holding the data before any transform.
         acceptable_diff: upper bound for the mean absolute difference.
     """
     for key in keys:
         reference = orig_d[key]
         round_trip = fwd_bck_d[key]
         if isinstance(round_trip, torch.Tensor):
             round_trip = round_trip.cpu().numpy()
         untouched = unmodified_d[key]

         # only array data can be compared numerically
         if not isinstance(reference, np.ndarray):
             continue

         mean_diff = np.mean(np.abs(reference - round_trip))
         # bring the unmodified data to the reference spatial shape so the
         # two can be subtracted
         resized = ResizeWithPadOrCrop(reference.shape[1:])(untouched)
         if isinstance(resized, torch.Tensor):
             resized = resized.detach().cpu().numpy()
         unmodded_diff = np.mean(np.abs(reference - resized))

         try:
             self.assertLessEqual(mean_diff, acceptable_diff)
         except AssertionError:
             print(
                 f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}"
             )
             # small 1D payloads are cheap to dump in full
             if reference[0].ndim == 1:
                 print("orig", reference[0])
                 print("fwd_bck", round_trip[0])
                 print("unmod", untouched[0])
             raise
Exemplo n.º 2
0
    def test_inverse_inferred_seg(self, extra_transform):
        """Check that a model output can be inverted back to the input label shape.

        Builds 20 synthetic 2D image/label pairs, runs them through a
        pad + ``extra_transform`` pipeline, feeds one batch of labels through
        a small UNet, and inverts the first decollated prediction.

        Args:
            extra_transform: final transform appended to the pipeline; its
                recorded operation is undone manually with a pad/crop before
                the remaining pipeline is inverted.
        """

        def _make_sample():
            img, seg = create_test_image_2d(100, 101)
            return {"image": img, "label": seg.astype(np.float32)}

        test_data = [_make_sample() for _ in range(20)]

        batch_size = 10
        # num workers = 0 for mac
        num_workers = 0 if sys.platform != "linux" else 2
        pipeline = [
            AddChanneld(KEYS),
            SpatialPadd(KEYS, (150, 153)),
            extra_transform,
        ]
        transforms = Compose(pipeline)

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda"
        net = UNet(
            spatial_dims=2,
            in_channels=1,
            out_channels=1,
            channels=(2, 4),
            strides=(1, ),
        )
        model = net.to(device)

        data = first(loader)
        expected_batch = batch_size * NUM_SAMPLES
        self.assertEqual(data["image"].shape[0], expected_batch)

        labels = data["label"].to(device)
        self.assertIsInstance(labels, MetaTensor)
        predictions = model(labels)
        segs = predictions.detach().cpu()
        segs_decollated = decollate_batch(segs)
        self.assertIsInstance(segs_decollated[0], MetaTensor)
        # inverse of individual segmentation
        seg_metatensor = first(segs_decollated)
        # test to convert interpolation mode for 1 data of model output batch
        convert_applied_interp_mode(
            seg_metatensor.applied_operations,
            mode="nearest",
            align_corners=None,
        )

        # manually invert the last crop samples
        last_op = seg_metatensor.applied_operations.pop(-1)
        resizer = ResizeWithPadOrCrop(spatial_size=last_op["orig_size"])
        with resizer.trace_transform(False):
            seg_metatensor = resizer(seg_metatensor)

        with allow_missing_keys_mode(transforms):
            inverted = transforms.inverse({"label": seg_metatensor})
            inv_seg = inverted["label"]
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
Exemplo n.º 3
0
 def __call__(self, image, label, params=None):
     """Pad/crop ``image`` and ``label`` to the size given by the parameters.

     Args:
         image: data passed to ``ResizeWithPadOrCrop``.
         label: data passed to ``ResizeWithPadOrCrop`` with the same size.
         params: mapping with ``'width'`` and ``'height'`` entries; when
             ``None`` the result of ``self._random()`` is used instead.

     Returns:
         The ``(image, label)`` pair, resized only when both width and
         height are strictly positive.

     Side effects:
         Stores the parameters that were used on ``self.params``.
     """
     self.params = self._random() if params is None else params
     size = self.params

     # non-positive target size means "leave the inputs untouched"
     if not (size['width'] > 0 and size['height'] > 0):
         return image, label

     image = ResizeWithPadOrCrop([size['width'], size['height']])(image)
     label = ResizeWithPadOrCrop([size['width'], size['height']])(label)
     return image, label
Exemplo n.º 4
0
    def __call__(
        self,
        img: np.ndarray,
        msk: Optional[np.ndarray] = None,
        center: Optional[tuple] = None,
        z_axis: Optional[int] = None,
    ):
        """
        Apply the transform to `img`, assuming `img` is channel-first and
        slicing doesn't apply to the channel dim.

        Args:
            img: channel-first input array to crop slices from.
            msk: optional mask; when given it takes precedence over
                `self.mask_data`. Values greater than 0 are foreground.
            center: crop center; when None it is obtained from
                `self.get_center_pos(mask, z_axis)`.
            z_axis: slicing axis; when None `self.z_axis` is used.

        Returns:
            For crop modes "single" and "parallel": the crop with the
            channel dim squeezed out and `z_axis` moved to the front.
            Otherwise: an array of shape `(3,) + self.roi_size` holding one
            slice per axis.

        Raises:
            ValueError: if neither `msk` nor `self.mask_data` is available,
                or if a multi-channel mask does not match the channel count
                of `img`.
        """
        mask_source = msk if msk is not None else self.mask_data
        if mask_source is None:
            raise ValueError("Unknown mask_data.")
        mask_data_ = np.asarray(mask_source > 0)

        if mask_data_.shape[0] != 1 and mask_data_.shape[0] != img.shape[0]:
            raise ValueError(
                "When mask_data is not single channel, mask_data channels must match img, "
                f"got img={img.shape[0]} mask_data={mask_data_.shape[0]}.")

        z_axis_ = z_axis if z_axis is not None else self.z_axis

        if center is None:
            center = self.get_center_pos(mask_data_, z_axis_)

        if self.crop_mode in ["single", "parallel"]:
            size_ = self.get_new_spatial_size(z_axis_)
            size_ = list(map(int, size_))
            slice_ = SpatialCrop(roi_center=center, roi_size=size_)(img)
            # Compare as tuples: `tuple != list` is always True in Python, so
            # the previous check forced a pad/crop on every call.
            if tuple(slice_.shape[1:]) != tuple(size_):
                slice_ = ResizeWithPadOrCrop(spatial_size=size_)(slice_)

            return np.moveaxis(slice_.squeeze(0), z_axis_, 0)
        else:
            cross_slices = np.zeros(shape=(3, ) + self.roi_size)
            for k in range(3):
                # a thickness-1 slab perpendicular to axis k
                size_ = np.insert(self.roi_size, k, 1)
                slice_ = SpatialCrop(roi_center=center, roi_size=size_)(img)
                # explicit tuple comparison avoids relying on tuple-vs-ndarray
                # broadcasting, which misbehaves when the ranks differ
                target_shape = tuple(int(v) for v in size_)
                if tuple(slice_.shape[1:]) != target_shape:
                    slice_ = ResizeWithPadOrCrop(spatial_size=size_)(slice_)

                cross_slices[k] = slice_.squeeze()
            return cross_slices
Exemplo n.º 5
0
    def __call__(self, data):
        """Re-map guidance points after padding/cropping to ``self.spatial_size``.

        Each non-empty guidance entry is rasterised into a zero volume with
        the spatial shape of the reference image, pad/cropped, and read back
        as the coordinates whose value is exactly 1.0. Empty entries stay
        empty lists.

        Args:
            data: mapping containing ``self.ref_image`` (channel-first array)
                and ``self.guidance`` (iterable of point lists).

        Returns:
            A shallow copy of ``data`` with ``self.guidance`` replaced.
        """
        d = dict(data)

        ref_spatial_shape = d[self.ref_image].shape[1:]
        is_2d = len(ref_spatial_shape) == 2
        croppad = ResizeWithPadOrCrop(spatial_size=self.spatial_size)

        def _remap(points):
            # nothing to rasterise for an empty guidance entry
            if not points:
                return []
            canvas = np.zeros(ref_spatial_shape)
            for point in points:
                if is_2d:
                    index = (point[0], point[1])
                else:
                    index = (point[0], point[1], point[2])
                canvas[index] = 1.0
            # croppad requires channel dim
            resized = croppad(canvas[np.newaxis, :]).squeeze(0)
            return np.argwhere(resized == 1.0).astype(int).tolist()

        # get guidance following pad or crop to spatial_size
        d[self.guidance] = [_remap(entry) for entry in d[self.guidance]]
        return d
Exemplo n.º 6
0
 def test_pad_shape(self, input_param, input_shape, expected_shape):
     """Check output shape, tracing and inversion of ``ResizeWithPadOrCrop``.

     Runs once per container type in ``TEST_NDARRAYS_ALL``. Tensor
     containers are skipped when the parameters use ``constant_values`` or
     the ``"reflect"`` mode.
     """
     for to_container in TEST_NDARRAYS_ALL:
         is_tensor = isinstance(to_container(0), torch.Tensor)
         if is_tensor and ("constant_values" in input_param
                           or input_param["mode"] == "reflect"):
             continue

         padcropper = ResizeWithPadOrCrop(**input_param)

         # default mode taken from the constructor arguments
         default_out = padcropper(to_container(np.zeros(input_shape)))
         np.testing.assert_allclose(default_out.shape, expected_shape)

         # mode overridden at call time
         result = padcropper(to_container(np.zeros(input_shape)),
                             mode="constant")
         np.testing.assert_allclose(result.shape, expected_shape)
         self.assertIsInstance(result, MetaTensor)
         self.assertEqual(len(result.applied_operations), 1)

         # the inverse restores the input shape and clears the trace
         restored = padcropper.inverse(result)
         self.assertTupleEqual(restored.shape, input_shape)
         self.assertIsInstance(restored, MetaTensor)
         self.assertEqual(restored.applied_operations, [])
Exemplo n.º 7
0
 def test_pad_shape(self, input_param, input_shape, expected_shape):
     """Check the output shape of ``ResizeWithPadOrCrop`` per container type.

     Runs once per container type in ``TEST_NDARRAYS``. Tensor containers
     are skipped when the parameters use ``constant_values`` or the
     ``"reflect"`` mode.
     """
     for to_container in TEST_NDARRAYS:
         is_tensor = isinstance(to_container(0), torch.Tensor)
         if is_tensor and ("constant_values" in input_param
                           or input_param["mode"] == "reflect"):
             continue

         pad_cropper = ResizeWithPadOrCrop(**input_param)

         # default mode taken from the constructor arguments
         default_out = pad_cropper(to_container(np.zeros(input_shape)))
         np.testing.assert_allclose(default_out.shape, expected_shape)

         # mode overridden at call time
         constant_out = pad_cropper(to_container(np.zeros(input_shape)),
                                    mode="constant")
         np.testing.assert_allclose(constant_out.shape, expected_shape)
Exemplo n.º 8
0
    def __call__(
        self,
        img: np.ndarray,
        label: Optional[np.ndarray] = None,
        image: Optional[np.ndarray] = None,
        fg_indices: Optional[np.ndarray] = None,
        bg_indices: Optional[np.ndarray] = None,
    ) -> List[np.ndarray]:
        """
        Crop one sample per randomly selected center from `img`.

        Args:
            img: input data to crop samples from based on the pos/neg ratio of `label` and `image`.
                Assumes `img` is a channel-first array.
            label: the label image that is used for finding foreground/background, if None, use `self.label`.
            image: optional image data to help select valid area, can be same as `img` or another image array.
                use ``label == 0 & image > image_threshold`` to select the negative sample(background) center.
                so the crop center will only exist on valid image area. if None, use `self.image`.
            fg_indices: foreground indices to randomly select crop centers,
                need to provide `fg_indices` and `bg_indices` together.
            bg_indices: background indices to randomly select crop centers,
                need to provide `fg_indices` and `bg_indices` together.

        Returns:
            A list with one cropped array per entry of `self.centers`
            (empty when `self.centers` is None).

        Raises:
            ValueError: if no label is available from the argument or `self.label`.

        """
        if label is None:
            label = self.label
        if label is None:
            raise ValueError("label should be provided.")
        if image is None:
            image = self.image
        if fg_indices is None or bg_indices is None:
            if self.fg_indices is not None and self.bg_indices is not None:
                # fall back to the indices stored on the transform
                fg_indices = self.fg_indices
                bg_indices = self.bg_indices
            else:
                fg_indices, bg_indices = map_binary_to_indices(
                    label, image, self.image_threshold)

        if self.target_label is not None:
            # binarise the label so only the target class counts as foreground
            label = (label == self.target_label).astype(np.uint8)

        self.randomize(label, fg_indices, bg_indices, image)
        results: List[np.ndarray] = []
        if self.centers is not None:
            for center in self.centers:
                if np.any(np.greater(self.spatial_size, img.shape[1:])):
                    # requested patch exceeds the image in some dim:
                    # pad/crop the whole image instead of cropping at center
                    cropper = ResizeWithPadOrCrop(
                        spatial_size=self.spatial_size)
                else:
                    # `SpatialCrop` takes the patch size as `roi_size`
                    # (as every other call in this file does); the previous
                    # `spatial_size=` keyword is not the name used elsewhere.
                    cropper = SpatialCrop(
                        roi_center=tuple(center),
                        roi_size=self.spatial_size)  # type: ignore
                results.append(cropper(img))

        return results
Exemplo n.º 9
0
    def test_invert(self):
        """Round-trip a discrete label image through random transforms and `Invert`.

        Checks that every transformed item has spatial shape (100, 100, 100),
        that every inverted item has the original spatial shape
        (100, 101, 107) and holds integer-valued data (nearest
        interpolation), and that the last inverted item differs from the
        file on disk in fewer than 300000 voxels.

        Determinism is reset in a ``finally`` block so that a failing
        assertion does not leave the global seed/determinism settings
        active for subsequent tests.
        """
        set_determinism(seed=0)
        try:
            im_fname = make_nifti_image(create_test_image_3d(101, 100, 107, noise_max=100)[1])  # label image, discrete
            data = [im_fname for _ in range(12)]
            transform = Compose(
                [
                    LoadImage(image_only=True),
                    EnsureChannelFirst(),
                    Orientation("RPS"),
                    Spacing(pixdim=(1.2, 1.01, 0.9), mode="bilinear", dtype=np.float32),
                    RandFlip(prob=0.5, spatial_axis=[1, 2]),
                    RandAxisFlip(prob=0.5),
                    RandRotate90(prob=0, spatial_axes=(1, 2)),
                    RandZoom(prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
                    RandRotate(prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
                    RandAffine(prob=0.5, rotate_range=np.pi, mode="nearest"),
                    ResizeWithPadOrCrop(100),
                    CastToType(dtype=torch.uint8),
                ]
            )

            # num workers = 0 for mac or gpu transforms
            num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2
            dataset = Dataset(data, transform=transform)
            self.assertIsInstance(transform.inverse(dataset[0]), MetaTensor)
            loader = DataLoader(dataset, num_workers=num_workers, batch_size=1)
            inverter = Invert(transform=transform, nearest_interp=True, device="cpu")

            for d in loader:
                d = decollate_batch(d)
                for item in d:
                    # keep a copy taken before inversion for the shape check
                    orig = deepcopy(item)
                    i = inverter(item)
                    self.assertTupleEqual(orig.shape[1:], (100, 100, 100))
                    # check the nearest interpolation mode
                    torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
                    self.assertTupleEqual(i.shape[1:], (100, 101, 107))
            # check labels match (uses the last inverted item from the loop)
            reverted = i.detach().cpu().numpy().astype(np.int32)
            original = LoadImage(image_only=True)(data[-1])
            n_good = np.sum(np.isclose(reverted, original.numpy(), atol=1e-3))
            reverted_name = i.meta["filename_or_obj"]
            original_name = original.meta["filename_or_obj"]
            self.assertEqual(reverted_name, original_name)
            print("invert diff", reverted.size - n_good)
            self.assertTrue((reverted.size - n_good) < 300000, f"diff. {reverted.size - n_good}")
        finally:
            # always restore non-deterministic behaviour, even on failure
            set_determinism(seed=None)
Exemplo n.º 10
0
    def __call__(
        self, data: Mapping[Hashable,
                            np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
        """Crop `self.num_samples` patches from every configured key of `data`.

        Args:
            data: mapping holding the label under `self.label_key`, the
                optional image under `self.image_key`, optional precomputed
                foreground/background indices, and the arrays to crop.

        Returns:
            One dictionary per sample. Each contains the cropped arrays,
            the remaining entries of `data` passed through, and for every
            cropped key a meta dictionary (``f"{key}_{self.meta_key_postfix}"``)
            whose `Key.PATCH_INDEX` entry is the index of the sample.

        Raises:
            TypeError: if `self.spatial_size` is not a tuple after `randomize`.
            AssertionError: if `self.centers` is None after `randomize`.
        """
        d = dict(data)
        label = d[self.label_key]
        image = d[self.image_key] if self.image_key else None
        fg_indices = d.get(
            self.fg_indices_key) if self.fg_indices_key is not None else None
        bg_indices = d.get(
            self.bg_indices_key) if self.bg_indices_key is not None else None

        if self.target_label is not None:
            # binarise the label so only the target class counts as foreground
            label = (label == self.target_label).astype(np.uint8)

        self.randomize(label, fg_indices, bg_indices, image)
        if not isinstance(self.spatial_size, tuple):
            raise TypeError(
                f"Expect spatial_size to be tuple, but got {type(self.spatial_size)}"
            )
        if self.centers is None:
            raise AssertionError
        results: List[Dict[Hashable,
                           np.ndarray]] = [{} for _ in range(self.num_samples)]

        for i, center in enumerate(self.centers):
            for key in self.key_iterator(d):
                img = d[key]
                if np.greater(self.spatial_size, img.shape[1:]).any():
                    # requested patch exceeds the image in some dim:
                    # pad/crop the whole image instead of cropping at center
                    cropper = ResizeWithPadOrCrop(
                        spatial_size=self.spatial_size)
                else:
                    cropper = SpatialCrop(
                        roi_center=tuple(center),
                        roi_size=self.spatial_size)  # type: ignore
                results[i][key] = cropper(img)
            # fill in the extra keys with unmodified data
            for key in set(data.keys()).difference(set(self.keys)):
                results[i][key] = data[key]
            # add `patch_index` to the meta data
            for key in self.key_iterator(d):
                meta_data_key = f"{key}_{self.meta_key_postfix}"
                # Give every sample its own meta dict. The pass-through
                # entry above is the very same object for all samples (and
                # for the caller's input), so writing the index in place
                # would leave every sample reporting the last index.
                if meta_data_key in results[i]:
                    meta = dict(results[i][meta_data_key])
                else:
                    meta = {}
                meta[Key.PATCH_INDEX] = i
                results[i][meta_data_key] = meta  # type: ignore

        return results
 def test_pad_shape(self, input_param, input_shape, expected_shape):
     """Check the output shape of ``ResizeWithPadOrCrop`` on a zero array.

     The transform is applied twice to a fresh zero array of
     ``input_shape``: once with its constructor settings and once with the
     mode overridden to ``"constant"``.
     """
     pad_cropper = ResizeWithPadOrCrop(**input_param)

     # default mode taken from the constructor arguments
     default_out = pad_cropper(np.zeros(input_shape))
     np.testing.assert_allclose(default_out.shape, expected_shape)

     # mode overridden at call time
     constant_out = pad_cropper(np.zeros(input_shape), mode="constant")
     np.testing.assert_allclose(constant_out.shape, expected_shape)