예제 #1
0
    def __call__(self, data):
        """Crop every selected key around the foreground of ``self.source_key``.

        The foreground bounding box of ``d[self.source_key]`` is computed with
        ``generate_spatial_bounding_box`` (using ``select_fn``,
        ``channel_indices``, ``margin`` and ``allow_smaller``). If the box is
        smaller than ``self.spatial_size`` in every dimension, a crop of
        ``spatial_size`` centred on the box is taken; otherwise the box itself
        is cropped. The region actually cropped and the shapes before/after
        cropping are written into each key's metadata dictionary.

        Args:
            data: mapping holding the arrays to crop and their metadata
                dictionaries (looked up via ``meta_keys`` or
                ``f"{key}_{meta_key_postfix}"``).

        Returns:
            A shallow copy of ``data`` with the selected keys cropped.
        """
        d = dict(data)
        box_start, box_end = generate_spatial_bounding_box(
            d[self.source_key], self.select_fn, self.channel_indices,
            self.margin, self.allow_smaller)

        center = list(
            np.mean([box_start, box_end], axis=0).astype(int, copy=False))
        current_size = list(
            np.subtract(box_end, box_start).astype(int, copy=False))

        if np.all(np.less(current_size, self.spatial_size)):
            cropper = SpatialCrop(roi_center=center,
                                  roi_size=self.spatial_size)
        else:
            cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)

        # Record the region the cropper really uses, in BOTH branches: the
        # SpatialCrop constructor may correct the requested coordinates, and
        # the metadata must describe the data that was actually cropped.
        box_start = np.array([s.start for s in cropper.slices])
        box_end = np.array([s.stop for s in cropper.slices])

        for key, meta_key, meta_key_postfix in self.key_iterator(
                d, self.meta_keys, self.meta_key_postfix):
            meta_key = meta_key or f"{key}_{meta_key_postfix}"
            d[meta_key][self.start_coord_key] = box_start
            d[meta_key][self.end_coord_key] = box_end
            d[meta_key][self.original_shape_key] = d[key].shape

            image = cropper(d[key])
            d[meta_key][self.cropped_shape_key] = image.shape
            d[key] = image
        return d
예제 #2
0
    def __call__(self, data):
        """Crop every selected key around the foreground of ``self.source_key``.

        The foreground bounding box is computed with
        ``generate_spatial_bounding_box``. If it is smaller than
        ``self.spatial_size`` in every dimension, a ``spatial_size`` crop
        centred on the box is used; otherwise the box itself is cropped. The
        cropper's ROI and the shapes before/after cropping are stored in
        ``d[f"{key}_{self.meta_key_postfix}"]``.

        Args:
            data: mapping holding the arrays to crop and their metadata dicts.

        Returns:
            A shallow copy of ``data`` with the selected keys cropped.
        """
        d = dict(data)
        box_start, box_end = generate_spatial_bounding_box(
            d[self.source_key], self.select_fn, self.channel_indices, self.margin
        )

        center = list(np.mean([box_start, box_end], axis=0).astype(int))
        current_size = list(np.subtract(box_end, box_start).astype(int))

        if np.all(np.less(current_size, self.spatial_size)):
            cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size)
        else:
            cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
        # Use the ROI the cropper actually applies in both branches, so the
        # stored coordinates match the cropped data even if the constructor
        # adjusted the requested box.
        box_start = cropper.roi_start
        box_end = cropper.roi_end

        for key in self.key_iterator(d):
            meta_key = f"{key}_{self.meta_key_postfix}"
            d[meta_key][self.start_coord_key] = box_start
            d[meta_key][self.end_coord_key] = box_end
            d[meta_key][self.original_shape_key] = d[key].shape

            image = cropper(d[key])
            d[meta_key][self.cropped_shape_key] = image.shape
            d[key] = image
        return d
예제 #3
0
    def __call__(self, data):
        """Crop all selected keys to a region around the guidance clicks.

        The bounding box of the positive and negative clicks in
        ``d[self.guidance]`` is computed with ``self.bounding_box``. When the
        box is smaller than ``spatial_size`` in every dimension, a
        ``spatial_size`` crop centred on the box is used (for 3D data the
        first-axis centre is set to ``spatial_size[0] // 2``); otherwise the box
        itself is cropped. The same crop is applied to every key, the crop
        coordinates and shapes are written to each key's metadata, and the
        clicks are shifted so they are relative to the crop start.

        Args:
            data: mapping with the arrays for the selected keys, their metadata
                dicts and ``data[self.guidance] == [pos_clicks, neg_clicks]``.

        Returns:
            A shallow copy of ``data`` with cropped arrays and shifted
            guidance, or the copy unchanged when ``self.first_key`` finds no key.

        Raises:
            RuntimeError: if the selected keys do not share one spatial shape.
        """
        d: Dict = dict(data)
        first_key: Union[Hashable, List] = self.first_key(d)
        if first_key == []:
            return d

        guidance = d[self.guidance]
        # Spatial shape only (channel axis dropped).
        original_spatial_shape = d[first_key].shape[1:]
        box_start, box_end = self.bounding_box(
            np.array(guidance[0] + guidance[1]), original_spatial_shape)
        center = list(
            np.mean([box_start, box_end], axis=0).astype(int, copy=False))

        box_size = list(
            np.subtract(box_end, box_start).astype(int, copy=False))
        # Keep the trailing entries matching the box rank; ``list`` lets the
        # concatenation below work when spatial_size is given as a tuple.
        spatial_size = list(self.spatial_size[-len(box_size):])

        if len(spatial_size) < len(box_size):
            # If the data is in 3D and spatial_size is specified as 2D [256,256]
            # then we take all slices. ``original_spatial_shape`` already
            # excludes the channel axis, so the missing leading sizes are its
            # first ``diff`` entries.
            diff = len(box_size) - len(spatial_size)
            spatial_size = list(original_spatial_shape[:diff]) + spatial_size

        if np.all(np.less(box_size, spatial_size)):
            if len(center) == 3:
                # 3D Deepgrow: set center to be middle of the depth dimension (D)
                center[0] = spatial_size[0] // 2
            cropper = SpatialCrop(roi_center=center, roi_size=spatial_size)
        else:
            cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)

        # update bounding box in case it was corrected by the SpatialCrop constructor
        box_start = np.array([s.start for s in cropper.slices])
        box_end = np.array([s.stop for s in cropper.slices])
        for key, meta_key, meta_key_postfix in self.key_iterator(
                d, self.meta_keys, self.meta_key_postfix):
            if not np.array_equal(d[key].shape[1:], original_spatial_shape):
                raise RuntimeError(
                    "All the image specified in keys should have same spatial shape"
                )
            meta_key = meta_key or f"{key}_{meta_key_postfix}"
            d[meta_key][self.start_coord_key] = box_start
            d[meta_key][self.end_coord_key] = box_end
            d[meta_key][self.original_shape_key] = d[key].shape

            image = cropper(d[key])
            d[meta_key][self.cropped_shape_key] = image.shape
            d[key] = image

        # Express the clicks in the coordinate frame of the cropped data.
        pos_clicks, neg_clicks = guidance[0], guidance[1]
        pos = np.subtract(pos_clicks,
                          box_start).tolist() if len(pos_clicks) else []
        neg = np.subtract(neg_clicks,
                          box_start).tolist() if len(neg_clicks) else []

        d[self.guidance] = [pos, neg]
        return d
예제 #4
0
    def __call__(self, data):
        """Preprocess one sample: crop to foreground, resample, normalize.

        In training mode the image and label are cropped together with
        ``self.crop_foreg``. Otherwise only the image is cropped to the box
        from ``generate_spatial_bounding_box`` and the crop geometry is stored
        under ``"original_shape"``, ``"bbox"`` and ``"crop_shape"``. If the
        spacing read from ``d["image_meta_dict"]["pixdim"][1:4]`` differs from
        ``self.target_spacing``, the image (and, in training, the label) is
        resampled. Finally the image is clipped to ``[self.low, self.high]`` and
        standardized with ``self.mean``/``self.std``, or passed through
        ``self.normalize_intensity`` when both bounds are 0.

        Args:
            data: mapping with ``"image"`` and ``"image_meta_dict"``, plus
                ``"label"`` when ``"label"`` is in ``self.keys``.

        Returns:
            A shallow copy of ``data`` with the processed ``"image"`` (and
            ``"label"``) plus ``"resample_flag"`` and ``"anisotrophy_flag"``.

        Raises:
            ValueError: if ``self.training`` is true but ``"label"`` is not in
                ``self.keys`` (the training path crops and resamples a label).
        """
        if self.training and "label" not in self.keys:
            # Without this check ``label`` would be unbound below and an
            # obscure UnboundLocalError would surface instead.
            raise ValueError(
                "'label' must be in keys when the transform is in training mode.")

        # load data
        d = dict(data)
        image = d["image"]
        image_spacings = d["image_meta_dict"]["pixdim"][1:4].tolist()

        if "label" in self.keys:
            label = d["label"]
            # NOTE: clamps negative label values in place, i.e. the input
            # array is modified.
            label[label < 0] = 0

        if self.training:
            # only task 04 does not be impacted
            cropped_data = self.crop_foreg({"image": image, "label": label})
            image, label = cropped_data["image"], cropped_data["label"]
        else:
            d["original_shape"] = np.array(image.shape[1:])
            box_start, box_end = generate_spatial_bounding_box(image)
            image = SpatialCrop(roi_start=box_start, roi_end=box_end)(image)
            d["bbox"] = np.vstack([box_start, box_end])
            d["crop_shape"] = np.array(image.shape[1:])

        original_shape = image.shape[1:]
        # calculate shape
        resample_flag = False
        anisotrophy_flag = False
        if self.target_spacing != image_spacings:
            # resample
            resample_flag = True
            resample_shape = self.calculate_new_shape(image_spacings,
                                                      original_shape)
            anisotrophy_flag = self.check_anisotrophy(image_spacings)
            image = resample_image(image, resample_shape, anisotrophy_flag)
            if self.training:
                label = resample_label(label, resample_shape, anisotrophy_flag)

        d["resample_flag"] = resample_flag
        d["anisotrophy_flag"] = anisotrophy_flag
        # clip image for CT dataset
        if self.low != 0 or self.high != 0:
            image = np.clip(image, self.low, self.high)
            image = (image - self.mean) / self.std
        else:
            image = self.normalize_intensity(image.copy())

        d["image"] = image

        if "label" in self.keys:
            d["label"] = label

        return d
예제 #5
0
    def __call__(self, data):
        """Crop all keys in ``self.keys`` to a region around the guidance clicks.

        The bounding box of the clicks in ``d[self.guidance]`` is computed with
        ``self.bounding_box``. When the box is smaller than ``spatial_size`` in
        every dimension, a ``spatial_size`` crop centred on the box is used
        (for 3D data the first-axis centre is set to ``spatial_size[0] // 2``);
        otherwise the box itself is cropped. The cropper's ROI and the shapes
        are stored in each key's metadata and the clicks are shifted so they
        are relative to the crop start.

        Args:
            data: mapping with the arrays for ``self.keys``, their metadata
                dicts and ``data[self.guidance] == [pos_clicks, neg_clicks]``.

        Returns:
            A shallow copy of ``data`` with cropped arrays and shifted guidance.

        Raises:
            RuntimeError: if the keys do not share one spatial shape.
        """
        d: Dict = dict(data)
        guidance = d[self.guidance]
        # Spatial shape only (channel axis dropped).
        original_spatial_shape = d[self.keys[0]].shape[1:]
        box_start, box_end = self.bounding_box(
            np.array(guidance[0] + guidance[1]), original_spatial_shape)
        center = list(np.mean([box_start, box_end], axis=0).astype(int))

        box_size = list(np.subtract(box_end, box_start).astype(int))
        # ``list`` lets the concatenation below work for a tuple spatial_size.
        spatial_size = list(self.spatial_size[-len(box_size):])

        if len(spatial_size) < len(box_size):
            # If the data is in 3D and spatial_size is specified as 2D [256,256]
            # then we take all slices. ``original_spatial_shape`` already
            # excludes the channel axis, so the missing leading sizes are its
            # first ``diff`` entries.
            diff = len(box_size) - len(spatial_size)
            spatial_size = list(original_spatial_shape[:diff]) + spatial_size

        if np.all(np.less(box_size, spatial_size)):
            if len(center) == 3:
                # 3D Deepgrow: set center to be middle of the depth dimension (D)
                center[0] = spatial_size[0] // 2
            cropper = SpatialCrop(roi_center=center, roi_size=spatial_size)
        else:
            cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
        # Record the ROI the cropper actually applies.
        box_start, box_end = cropper.roi_start, cropper.roi_end

        for key in self.keys:
            if not np.array_equal(d[key].shape[1:], original_spatial_shape):
                raise RuntimeError(
                    "All the image specified in keys should have same spatial shape"
                )
            meta_key = f"{key}_{self.meta_key_postfix}"
            d[meta_key][self.start_coord_key] = box_start
            d[meta_key][self.end_coord_key] = box_end
            d[meta_key][self.original_shape_key] = d[key].shape

            image = cropper(d[key])
            d[meta_key][self.cropped_shape_key] = image.shape
            d[key] = image

        # Express the clicks in the coordinate frame of the cropped data.
        pos_clicks, neg_clicks = guidance[0], guidance[1]
        pos = np.subtract(pos_clicks,
                          box_start).tolist() if len(pos_clicks) else []
        neg = np.subtract(neg_clicks,
                          box_start).tolist() if len(neg_clicks) else []

        d[self.guidance] = [pos, neg]
        return d
예제 #6
0
 def test_tensor_shape(self, input_param, input_shape, expected_shape):
     """SpatialCrop on a random 0/1 torch tensor yields ``expected_shape``."""
     target_device = "cuda" if torch.cuda.is_available() else "cpu"
     tensor = torch.randint(0, 2, size=input_shape, device=target_device)
     crop = SpatialCrop(**input_param)
     cropped = crop(tensor)
     self.assertTupleEqual(cropped.shape, expected_shape)
예제 #7
0
    def __call__(self, data):
        """Crop each key in ``self.keys`` to a region around the guidance clicks.

        For every key, the bounding box of the clicks in ``data[self.guidance]``
        is computed with ``self.bounding_box``. The target size is
        ``data[self.spatial_size_key]`` if present, else ``self.spatial_size``.
        When the box is smaller than that size in every dimension, a centred
        crop of that size is used (for 3D the first-axis centre is set to
        ``spatial_size[0] // 2``); otherwise the box itself is cropped. The
        clicks are then shifted by the start of the last key's crop.

        Args:
            data: mapping with the arrays for ``self.keys``, their metadata
                dicts and ``data[self.guidance] == [pos_clicks, neg_clicks]``.

        Returns:
            A shallow copy of ``data`` (the input mapping itself is no longer
            modified) with cropped arrays and shifted guidance.
        """
        d = dict(data)
        guidance = d[self.guidance]
        box_start = None
        for key in self.keys:
            box_start, box_end = self.bounding_box(
                np.array(guidance[0] + guidance[1]), d[key].shape[1:])
            center = np.mean([box_start, box_end], axis=0).astype(int).tolist()
            spatial_size = d.get(self.spatial_size_key, self.spatial_size)

            current_size = np.absolute(np.subtract(
                box_start, box_end)).astype(int).tolist()
            # ``list`` lets the concatenation below work for a tuple size.
            spatial_size = list(spatial_size[-len(current_size):])
            if len(spatial_size) < len(current_size):
                # 3D data with a 2D spatial_size (e.g. [256, 256]): include all
                # slices. ``d[key].shape`` still has the channel axis first, so
                # the missing leading spatial sizes start at index 1.
                diff = len(current_size) - len(spatial_size)
                spatial_size = list(d[key].shape[1:(1 + diff)]) + spatial_size

            if np.all(np.less(current_size, spatial_size)):
                if len(center) == 3:
                    # 3D: centre the crop on the middle of the first spatial axis.
                    center[0] = spatial_size[0] // 2
                cropper = SpatialCrop(roi_center=center, roi_size=spatial_size)
            else:
                cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
            # Record the ROI the cropper actually applies.
            box_start, box_end = cropper.roi_start, cropper.roi_end

            meta_key = f"{key}_{self.meta_key_postfix}"
            d[meta_key][self.start_coord_key] = box_start
            d[meta_key][self.end_coord_key] = box_end
            d[meta_key][self.original_shape_key] = d[key].shape

            image = cropper(d[key])
            d[meta_key][self.cropped_shape_key] = image.shape
            d[key] = image

        # Express the clicks relative to the (last) crop start.
        pos_clicks, neg_clicks = guidance[0], guidance[1]
        pos = np.subtract(pos_clicks,
                          box_start).tolist() if len(pos_clicks) else []
        neg = np.subtract(neg_clicks,
                          box_start).tolist() if len(neg_clicks) else []

        d[self.guidance] = [pos, neg]
        return d
예제 #8
0
    def __call__(
        self,
        img: np.ndarray,
        label: Optional[np.ndarray] = None,
        image: Optional[np.ndarray] = None,
        fg_indices: Optional[np.ndarray] = None,
        bg_indices: Optional[np.ndarray] = None,
    ) -> List[np.ndarray]:
        """
        Args:
            img: input data to crop samples from based on the pos/neg ratio of `label` and `image`.
                Assumes `img` is a channel-first array.
            label: the label image that is used for finding foreground/background, if None, use `self.label`.
            image: optional image data to help select valid area, can be same as `img` or another image array.
                use ``label == 0 & image > image_threshold`` to select the negative sample(background) center.
                so the crop center will only exist on valid image area. if None, use `self.image`.
            fg_indices: foreground indices to randomly select crop centers,
                need to provide `fg_indices` and `bg_indices` together.
            bg_indices: background indices to randomly select crop centers,
                need to provide `fg_indices` and `bg_indices` together.

        Returns:
            One crop of ``img`` per centre chosen by ``self.randomize`` (empty if
            no centres were set). When ``self.spatial_size`` exceeds the image in
            any spatial dimension, ``ResizeWithPadOrCrop`` is used instead of a
            centred ``SpatialCrop``.

        Raises:
            ValueError: if no label is passed and ``self.label`` is None.

        """
        if label is None:
            label = self.label
        if label is None:
            raise ValueError("label should be provided.")
        if image is None:
            image = self.image

        # Binarize against ``target_label`` BEFORE deriving the fg/bg indices;
        # otherwise the indices are computed from the raw label and
        # ``target_label`` does not influence which centres are sampled.
        if self.target_label is not None:
            label = (label == self.target_label).astype(np.uint8)

        if fg_indices is None or bg_indices is None:
            if self.fg_indices is not None and self.bg_indices is not None:
                fg_indices = self.fg_indices
                bg_indices = self.bg_indices
            else:
                fg_indices, bg_indices = map_binary_to_indices(
                    label, image, self.image_threshold)

        self.randomize(label, fg_indices, bg_indices, image)
        results: List[np.ndarray] = []
        if self.centers is not None:
            # Whether the ROI is larger than the image does not depend on the
            # centre, so decide once.
            roi_exceeds_image = np.any(
                np.greater(self.spatial_size, img.shape[1:]))
            for center in self.centers:
                if roi_exceeds_image:
                    cropper = ResizeWithPadOrCrop(
                        spatial_size=self.spatial_size)
                else:
                    # SpatialCrop expects ``roi_size``; the previous
                    # ``spatial_size=`` keyword is not one of its parameters.
                    cropper = SpatialCrop(
                        roi_center=tuple(center),
                        roi_size=self.spatial_size)  # type: ignore
                results.append(cropper(img))

        return results
예제 #9
0
    def __call__(
        self,
        img: np.ndarray,
        msk: Optional[np.ndarray] = None,
        center: Optional[tuple] = None,
        z_axis: Optional[int] = None,
    ):
        """
        Apply the transform to `img`, assuming `img` is channel-first and
        slicing doesn't apply to the channel dim.

        Args:
            img: channel-first image to slice.
            msk: mask used to locate the centre (voxels ``> 0`` count); falls
                back to ``self.mask_data`` when None.
            center: explicit crop centre; when None it is computed from the
                mask with ``self.get_center_pos``.
            z_axis: axis moved to the front in ``"single"``/``"parallel"``
                mode; defaults to ``self.z_axis``.

        Returns:
            In ``"single"``/``"parallel"`` mode: the crop of size
            ``self.get_new_spatial_size(z_axis)`` with the channel axis squeezed
            and ``z_axis`` moved to the front. Otherwise: an array of shape
            ``(3, *self.roi_size)``; entry ``k`` is a crop through ``center``
            that is one voxel thick along spatial axis ``k``.

        Raises:
            ValueError: if neither ``msk`` nor ``self.mask_data`` is set, or the
                mask is multi-channel with a channel count different from ``img``.
        """
        if msk is not None:
            mask_data_ = msk > 0
        elif self.mask_data is not None:
            mask_data_ = self.mask_data > 0
        else:
            raise ValueError("Unknown mask_data.")
        mask_data_ = np.asarray(mask_data_)

        if mask_data_.shape[0] != 1 and mask_data_.shape[0] != img.shape[0]:
            raise ValueError(
                "When mask_data is not single channel, mask_data channels must match img, "
                f"got img={img.shape[0]} mask_data={mask_data_.shape[0]}.")

        z_axis_ = z_axis if z_axis is not None else self.z_axis

        if center is None:
            center = self.get_center_pos(mask_data_, z_axis_)

        if self.crop_mode in ["single", "parallel"]:
            size_ = [int(s) for s in self.get_new_spatial_size(z_axis_)]
            slice_ = SpatialCrop(roi_center=center, roi_size=size_)(img)
            # A crop near the border may come out smaller than requested; pad
            # it back. Compare as tuples: ``shape[1:]`` is a tuple and ``size_``
            # a list, and ``tuple != list`` is always True.
            if tuple(slice_.shape[1:]) != tuple(size_):
                slice_ = ResizeWithPadOrCrop(spatial_size=size_)(slice_)

            return np.moveaxis(slice_.squeeze(0), z_axis_, 0)

        # ``tuple`` so a list-valued roi_size also works here.
        cross_slices = np.zeros(shape=(3, ) + tuple(self.roi_size))
        for k in range(3):
            size_ = np.insert(self.roi_size, k, 1)
            slice_ = SpatialCrop(roi_center=center, roi_size=size_)(img)
            if tuple(slice_.shape[1:]) != tuple(size_):
                slice_ = ResizeWithPadOrCrop(spatial_size=size_)(slice_)
            cross_slices[k] = slice_.squeeze()
        return cross_slices
예제 #10
0
    def __call__(
            self, data: Mapping[Hashable,
                                np.ndarray]) -> Dict[Hashable, np.ndarray]:
        """Crop each selected key along its third spatial axis.

        The kept range along spatial axis 2 is
        ``[int(Z * relative_z_roi[1]), Z - int(Z * relative_z_roi[0]))``, where
        ``Z`` is that axis' length. The first two spatial axes are kept whole.

        Args:
            data: mapping of channel-first arrays with three spatial axes.

        Returns:
            A shallow copy of ``data`` with the selected keys cropped.
        """
        out = dict(data)
        for name in self.key_iterator(out):
            spatial_shape = out[name].shape[1:]
            depth = spatial_shape[2]
            lower = int(depth * self.relative_z_roi[1])
            upper = depth - int(depth * self.relative_z_roi[0])
            start = np.array([0, 0, lower])
            stop = np.array([spatial_shape[0], spatial_shape[1], upper])
            out[name] = SpatialCrop(roi_start=start, roi_end=stop)(out[name])
        return out
예제 #11
0
    def __call__(
        self, data: Mapping[Hashable,
                            np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
        """Sample ``self.num_samples`` crops centred by pos/neg label ratio.

        Args:
            data: mapping holding ``self.label_key``, optionally
                ``self.image_key`` and precomputed fg/bg indices, and the arrays
                to crop for the selected keys.

        Returns:
            A list of ``self.num_samples`` dicts. Selected keys hold the crop
            for the corresponding centre (or ``ResizeWithPadOrCrop`` output when
            ``self.spatial_size`` exceeds the image). Other keys are passed
            through unmodified. Each selected key's metadata dict
            (``f"{key}_{self.meta_key_postfix}"``) is a per-sample copy carrying
            ``Key.PATCH_INDEX``.

        Raises:
            TypeError: if ``self.spatial_size`` is not a tuple.
            AssertionError: if ``self.randomize`` left ``self.centers`` as None.
        """
        d = dict(data)
        label = d[self.label_key]
        image = d[self.image_key] if self.image_key else None
        fg_indices = d.get(
            self.fg_indices_key) if self.fg_indices_key is not None else None
        bg_indices = d.get(
            self.bg_indices_key) if self.bg_indices_key is not None else None

        if self.target_label is not None:
            label = (label == self.target_label).astype(np.uint8)

        self.randomize(label, fg_indices, bg_indices, image)
        if not isinstance(self.spatial_size, tuple):
            raise TypeError(
                f"Expect spatial_size to be tuple, but got {type(self.spatial_size)}"
            )
        if self.centers is None:
            raise AssertionError("self.centers is None after randomize().")
        results: List[Dict[Hashable,
                           np.ndarray]] = [{} for _ in range(self.num_samples)]
        extra_keys = set(data.keys()).difference(set(self.keys))

        for i, center in enumerate(self.centers):
            for key in self.key_iterator(d):
                img = d[key]
                if np.greater(self.spatial_size, img.shape[1:]).any():
                    cropper = ResizeWithPadOrCrop(
                        spatial_size=self.spatial_size)
                else:
                    cropper = SpatialCrop(
                        roi_center=tuple(center),
                        roi_size=self.spatial_size)  # type: ignore
                results[i][key] = cropper(img)
            # fill in the extra keys with unmodified data
            for key in extra_keys:
                results[i][key] = data[key]
            # add `patch_index` to the meta data
            for key in self.key_iterator(d):
                meta_data_key = f"{key}_{self.meta_key_postfix}"
                # Copy before writing: the extra keys above are shared by
                # reference between all samples and the input, so writing in
                # place would give every sample the last index and mutate the
                # caller's metadata.
                meta = dict(results[i].get(meta_data_key, {}))
                meta[Key.PATCH_INDEX] = i
                results[i][meta_data_key] = meta  # type: ignore

        return results
예제 #12
0
 def test_shape(self, input_param, input_shape, expected_shape):
     """SpatialCrop keeps the input's type and device, produces
     ``expected_shape`` and gives identical values for every combination of
     image container and ROI-parameter container."""
     source = np.random.randint(0, 2, size=input_shape)
     outputs = []
     for to_image in TEST_NDARRAYS:
         for to_param in TEST_NDARRAYS + (None, ):
             if to_param is None:
                 params = dict(input_param)
             else:
                 params = {
                     name: value if name == "roi_slices" else to_param(value)
                     for name, value in input_param.items()
                 }
             converted = to_image(source)
             cropped = SpatialCrop(**params)(converted)
             self.assertEqual(type(converted), type(cropped))
             if isinstance(cropped, torch.Tensor):
                 self.assertEqual(cropped.device, converted.device)
             self.assertTupleEqual(cropped.shape, expected_shape)
             outputs.append(cropped)
             if len(outputs) > 1:
                 assert_allclose(outputs[0], outputs[-1], type_test=False)
예제 #13
0
    def __call__(
        self, data: Mapping[Hashable,
                            np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
        """Sample ``self.num_samples`` 2.5D crops from the selected keys.

        Crop centres come from ``self.randomize`` (label, optional image and
        optional precomputed fg/bg indices). In ``"single"``/``"parallel"``
        mode each sample is a ``SpatialCrop`` of size
        ``self.get_new_spatial_size()`` with the channel axis squeezed and
        ``self.z_axis`` moved to the front. Otherwise it is a
        ``(3, *self.spatial_size)`` stack of one-voxel-thick crops, one per
        spatial axis.

        Args:
            data: mapping holding ``self.label_key``, optionally
                ``self.image_key`` and fg/bg indices, and the arrays to crop.

        Returns:
            ``self.num_samples`` dicts; keys not in ``self.keys`` are passed
            through by reference.

        Raises:
            TypeError: if ``self.spatial_size`` is not a tuple.
            AssertionError: if ``self.randomize`` left ``self.centers`` as None.
        """
        d = dict(data)
        label = d[self.label_key]
        image = d[self.image_key] if self.image_key else None
        fg_indices = d.get(
            self.fg_indices_key) if self.fg_indices_key is not None else None
        bg_indices = d.get(
            self.bg_indices_key) if self.bg_indices_key is not None else None

        self.randomize(label, fg_indices, bg_indices, image)
        # Explicit checks rather than ``assert`` so they still run under
        # ``python -O``.
        if not isinstance(self.spatial_size, tuple):
            raise TypeError(
                f"Expect spatial_size to be tuple, but got {type(self.spatial_size)}"
            )
        if self.centers is None:
            raise AssertionError("self.centers is None after randomize().")
        results: List[Dict[Hashable, np.ndarray]] = [
            dict() for _ in range(self.num_samples)
        ]
        for key in data:
            if key in self.keys:
                img = d[key]
                for i, center in enumerate(self.centers):
                    if self.crop_mode in ["single", "parallel"]:
                        size_ = self.get_new_spatial_size()
                        slice_ = SpatialCrop(roi_center=tuple(center),
                                             roi_size=size_)(img)
                        results[i][key] = np.moveaxis(slice_.squeeze(0),
                                                      self.z_axis, 0)
                    else:
                        cross_slices = np.zeros(shape=(3, ) +
                                                self.spatial_size)
                        for k in range(3):
                            # One-voxel-thick crop along spatial axis k.
                            size_ = np.insert(self.spatial_size, k, 1)
                            slice_ = SpatialCrop(roi_center=tuple(center),
                                                 roi_size=size_)(img)
                            cross_slices[k] = slice_.squeeze()
                        results[i][key] = cross_slices
            else:
                for i in range(self.num_samples):
                    results[i][key] = data[key]

        return results
예제 #14
0
 def test_shape(self, input_param, input_shape, expected_shape):
     """Cropping a random 0/1 numpy array yields ``expected_shape``."""
     array = np.random.randint(0, 2, size=input_shape)
     crop = SpatialCrop(**input_param)
     out_shape = crop(array).shape
     self.assertTupleEqual(out_shape, expected_shape)
예제 #15
0
 def test_shape(self, input_param, input_data, expected_shape):
     """SpatialCrop applied to ``input_data`` produces ``expected_shape``."""
     crop = SpatialCrop(**input_param)
     out_shape = crop(input_data).shape
     self.assertTupleEqual(out_shape, expected_shape)
예제 #16
0
 def test_error(self, input_param):
     """Constructing SpatialCrop from invalid ``input_param`` raises ValueError."""
     self.assertRaises(ValueError, SpatialCrop, **input_param)