Example #1
    def __call__(self, data):
        d = dict(data)
        box_start, box_end = generate_spatial_bounding_box(
            d[self.source_key], self.select_fn, self.channel_indices, self.margin
        )

        center = list(np.mean([box_start, box_end], axis=0).astype(int))
        current_size = list(np.subtract(box_end, box_start).astype(int))

        if np.all(np.less(current_size, self.spatial_size)):
            cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size)
            box_start = cropper.roi_start
            box_end = cropper.roi_end
        else:
            cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)

        for key in self.key_iterator(d):
            meta_key = f"{key}_{self.meta_key_postfix}"
            d[meta_key][self.start_coord_key] = box_start
            d[meta_key][self.end_coord_key] = box_end
            d[meta_key][self.original_shape_key] = d[key].shape

            image = cropper(d[key])
            d[meta_key][self.cropped_shape_key] = image.shape
            d[key] = image
        return d
Example #2
    def __call__(self, data):
        d = dict(data)
        box_start, box_end = generate_spatial_bounding_box(
            d[self.source_key], self.select_fn, self.channel_indices,
            self.margin, self.allow_smaller)

        center = list(
            np.mean([box_start, box_end], axis=0).astype(int, copy=False))
        current_size = list(
            np.subtract(box_end, box_start).astype(int, copy=False))

        if np.all(np.less(current_size, self.spatial_size)):
            cropper = SpatialCrop(roi_center=center,
                                  roi_size=self.spatial_size)
            box_start = np.array([s.start for s in cropper.slices])
            box_end = np.array([s.stop for s in cropper.slices])
        else:
            cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)

        for key, meta_key, meta_key_postfix in self.key_iterator(
                d, self.meta_keys, self.meta_key_postfix):
            meta_key = meta_key or f"{key}_{meta_key_postfix}"
            d[meta_key][self.start_coord_key] = box_start
            d[meta_key][self.end_coord_key] = box_end
            d[meta_key][self.original_shape_key] = d[key].shape

            image = cropper(d[key])
            d[meta_key][self.cropped_shape_key] = image.shape
            d[key] = image
        return d
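Examples #1 and #2 show the same class across two MONAI API revisions (the second adds allow_smaller and explicit meta_keys); both hand-roll the crop-and-record logic. For comparison, here is a minimal sketch of the basic foreground-crop step they build on, using MONAI's stock CropForegroundd transform. The shapes and keys below are illustrative assumptions, not taken from the examples.

import numpy as np
from monai.transforms import CropForegroundd

# a 1-channel 8x8 image whose central 4x4 block is foreground
image = np.zeros((1, 8, 8))
image[0, 2:6, 2:6] = 1.0
data = {"image": image, "label": np.zeros((1, 8, 8))}

# crop every listed key with the bounding box computed from data["image"]
crop = CropForegroundd(keys=["image", "label"], source_key="image")
cropped = crop(data)
# cropped["image"].shape is expected to be (1, 4, 4)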
Example #3
    def __call__(self, data):
        d = dict(data)
        box_start, box_end = generate_spatial_bounding_box(
            data[self.source_key], self.select_fn, self.channel_indexes, self.margin
        )
        cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
        for key in self.keys:
            d[key] = cropper(d[key])
        return d
Example #4
    def __call__(self, img: np.ndarray) -> np.ndarray:
        """
        Apply the transform to `img`, assuming `img` is channel-first and
        slicing doesn't change the channel dim.
        """
        box_start, box_end = generate_spatial_bounding_box(
            img, self.select_fn, self.channel_indices, self.margin
        )
        cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
        return cropper(img)
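As a point of reference for the array-style transform above, here is a minimal sketch of calling generate_spatial_bounding_box directly on a channel-first array; the coordinates in the comments are the expected results, not captured output.

import numpy as np
from monai.transforms import SpatialCrop
from monai.transforms.utils import generate_spatial_bounding_box

img = np.zeros((1, 6, 6))      # channel-first: (C, H, W)
img[0, 2:4, 1:5] = 1.0         # foreground block

start, end = generate_spatial_bounding_box(img)  # default select_fn keeps x > 0
# start is expected to be [2, 1] and end [4, 5] (end is exclusive),
# so the crop below should have spatial shape (2, 4)
cropped = SpatialCrop(roi_start=start, roi_end=end)(img)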
Example #5
File: dictionary.py (Project: owkin/MONAI)
    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
        d = dict(data)
        box_start, box_end = generate_spatial_bounding_box(
            d[self.source_key], self.select_fn, self.channel_indices, self.margin
        )
        cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
        for key in self.keys:
            d[key] = cropper(d[key])
        return d
Example #6
File: array.py (Project: slohani-ai/MONAI)
    def __call__(self, img: np.ndarray) -> np.ndarray:
        """
        See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`.
        """
        bbox = []

        for channel in range(img.shape[0]):
            start_, end_ = generate_spatial_bounding_box(img, select_fn=self.select_fn, channel_indices=channel)
            bbox.append([i for k in zip(start_, end_) for i in k])

        return np.stack(bbox, axis=0)
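The transform above returns one row per channel, with the start/end coordinates interleaved dimension by dimension by zip(start_, end_). A purely illustrative sketch of the resulting layout for a 2-channel 2D input (the numbers are made up):

import numpy as np

# suppose the per-channel boxes were:
#   channel 0: start (1, 2), end (5, 7)
#   channel 1: start (0, 3), end (4, 6)
bbox = np.stack([
    [1, 5, 2, 7],   # channel 0: [start_0, end_0, start_1, end_1]
    [0, 4, 3, 6],   # channel 1
], axis=0)
# final shape: (num_channels, 2 * spatial_dims)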
Example #7
    def get_center_pos(self, mask_data, z_axis):
        if self.center_mode == "center":
            starts, ends = generate_spatial_bounding_box(
                mask_data, lambda x: x > 0)
            return tuple((st + ed) // 2 for st, ed in zip(starts, ends))
        elif self.center_mode == "maximum":
            axes = np.delete(np.arange(3), z_axis)
            mask_data_ = mask_data.squeeze()
            z_index = np.argmax(np.count_nonzero(mask_data_, axis=tuple(axes)))
            if z_index == 0 and self.crop_mode == "parallel":
                z_index = (self.n_slices - 1) // 2
            elif (z_index == mask_data_.shape[z_axis] - 1
                  and self.crop_mode == "parallel"):
                z_index -= (self.n_slices - 1) // 2

            starts, ends = generate_spatial_bounding_box(
                np.take(mask_data_, z_index, z_axis)[np.newaxis, ...],
                lambda x: x > 0)
            centers = [(st + ed) // 2 for st, ed in zip(starts, ends)]
            centers.insert(z_axis, z_index)
            return tuple(centers)
Example #8
    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
        d = dict(data)
        box_start, box_end = generate_spatial_bounding_box(
            d[self.source_key], self.select_fn, self.channel_indices, self.margin
        )
        d[self.start_coord_key] = np.asarray(box_start)
        d[self.end_coord_key] = np.asarray(box_end)
        cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
        for key in self.key_iterator(d):
            self.push_transform(d, key, extra_info={"box_start": box_start, "box_end": box_end})
            d[key] = cropper(d[key])
        return d
Example #9
    def __call__(self, data):
        # load data
        d = dict(data)
        image = d["image"]
        image_spacings = d["image_meta_dict"]["pixdim"][1:4].tolist()

        if "label" in self.keys:
            label = d["label"]
            label[label < 0] = 0

        if self.training:
            # only task 04 is not impacted by this step
            cropped_data = self.crop_foreg({"image": image, "label": label})
            image, label = cropped_data["image"], cropped_data["label"]
        else:
            d["original_shape"] = np.array(image.shape[1:])
            box_start, box_end = generate_spatial_bounding_box(image)
            image = SpatialCrop(roi_start=box_start, roi_end=box_end)(image)
            d["bbox"] = np.vstack([box_start, box_end])
            d["crop_shape"] = np.array(image.shape[1:])

        original_shape = image.shape[1:]
        # calculate shape
        resample_flag = False
        anisotrophy_flag = False
        if self.target_spacing != image_spacings:
            # resample
            resample_flag = True
            resample_shape = self.calculate_new_shape(image_spacings,
                                                      original_shape)
            anisotrophy_flag = self.check_anisotrophy(image_spacings)
            image = resample_image(image, resample_shape, anisotrophy_flag)
            if self.training:
                label = resample_label(label, resample_shape, anisotrophy_flag)

        d["resample_flag"] = resample_flag
        d["anisotrophy_flag"] = anisotrophy_flag
        # clip image for CT dataset
        if self.low != 0 or self.high != 0:
            image = np.clip(image, self.low, self.high)
            image = (image - self.mean) / self.std
        else:
            image = self.normalize_intensity(image.copy())

        d["image"] = image

        if "label" in self.keys:
            d["label"] = label

        return d
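The inference branch records "original_shape", "bbox", and "crop_shape" so that results computed on the cropped volume can later be placed back into the original grid. A minimal sketch of that restoration step, with assumed shapes and values standing in for the stored entries:

import numpy as np

original_shape = (96, 96, 96)                  # from d["original_shape"]
bbox = np.array([[10, 12, 8], [74, 80, 72]])   # from d["bbox"] = [box_start, box_end]
pred = np.ones((1, *(bbox[1] - bbox[0])))      # a prediction on the cropped grid

restored = np.zeros((1, *original_shape), dtype=pred.dtype)
spatial = tuple(slice(s, e) for s, e in zip(bbox[0], bbox[1]))
restored[(slice(None), *spatial)] = pred       # paste the crop back in place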
Example #10
    def get_center_pos_(self, mask_data):
        axes = np.delete(np.arange(3), self.z_axis)
        starts, ends = generate_spatial_bounding_box(mask_data,
                                                     lambda x: x > 0)
        z_start, z_end = (
            starts[self.z_axis] + (self.n_slices - 1) // 2,
            ends[self.z_axis] - (self.n_slices - 1) // 2,
        )
        centers = []
        for z in np.arange(z_start, z_end):
            center = [(st + ed) // 2 for st, ed in zip(
                np.array(starts)[axes],
                np.array(ends)[axes])]
            center.insert(self.z_axis, z)
            centers.append(tuple(center))

        return centers
Example #11
    def compute_bounding_box(self, img: np.ndarray):
        """
        Compute the start and end points of the bounding box to crop,
        and adjust them so that the cropped spatial size is divisible by `k`.

        """
        box_start, box_end = generate_spatial_bounding_box(
            img, self.select_fn, self.channel_indices, self.margin)
        box_start_ = np.asarray(box_start, dtype=np.int16)
        box_end_ = np.asarray(box_end, dtype=np.int16)
        orig_spatial_size = box_end_ - box_start_
        # make the spatial size divisible by `k`
        spatial_size = np.asarray(
            compute_divisible_spatial_size(spatial_shape=orig_spatial_size,
                                           k=self.k_divisible))
        # update box_start and box_end
        box_start_ = box_start_ - np.floor_divide(
            np.asarray(spatial_size) - orig_spatial_size, 2)
        box_end_ = box_start_ + spatial_size
        return box_start_, box_end_
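The adjustment above grows the box symmetrically so its spatial size becomes a multiple of k_divisible. A short sketch of the arithmetic, assuming compute_divisible_spatial_size rounds each dimension up to the next multiple of k; the sizes are illustrative.

import numpy as np
from monai.transforms.utils import compute_divisible_spatial_size

orig_size = np.array([5, 9])   # box_end_ - box_start_ per spatial dimension
new_size = np.asarray(compute_divisible_spatial_size(spatial_shape=orig_size, k=4))
# new_size is expected to be [8, 12]
shift = np.floor_divide(new_size - orig_size, 2)   # [1, 1]
# the box start moves back by `shift`; the new end is start + new_size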
Example #12
def get_mask_edges(
    seg_pred: Union[np.ndarray, torch.Tensor],
    seg_gt: Union[np.ndarray, torch.Tensor],
    label_idx: int = 1,
    crop: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Apply binary erosion and XOR the result with the input to get the edges. This
    function is helpful for further calculating metrics such as Average Surface
    Distance and Hausdorff Distance.
    The input images can be binary or labelfield images. If labelfield images
    are supplied, they are converted to binary images using `label_idx`.

    `scipy`'s binary erosion is used to calculate the edges of the binary
    labelfield.

    To improve computational efficiency, the images are cropped to their
    foreground before the edges are computed, unless ``crop=False`` is
    specified.

    We require that images are the same size, and assume that they occupy the
    same space (spacing, orientation, etc.).

    Args:
        seg_pred: the predicted binary or labelfield image.
        seg_gt: the actual binary or labelfield image.
        label_idx: for labelfield images, convert to binary with
            `seg_pred = seg_pred == label_idx`.
        crop: crop the input images and keep only the foregrounds. To keep
            the two inputs the same shape, the bounding box is computed from
            ``(seg_pred | seg_gt)``, the union of the two images.
            Defaults to ``True``.
    """

    # Get both labelfields as np arrays
    if isinstance(seg_pred, torch.Tensor):
        seg_pred = seg_pred.detach().cpu().numpy()
    if isinstance(seg_gt, torch.Tensor):
        seg_gt = seg_gt.detach().cpu().numpy()

    if seg_pred.shape != seg_gt.shape:
        raise ValueError(
            f"seg_pred and seg_gt should have same shapes, got {seg_pred.shape} and {seg_gt.shape}."
        )

    # If not binary images, convert them
    if seg_pred.dtype != bool:
        seg_pred = seg_pred == label_idx
    if seg_gt.dtype != bool:
        seg_gt = seg_gt == label_idx

    if crop:
        if not np.any(seg_pred | seg_gt):
            return np.zeros_like(seg_pred), np.zeros_like(seg_gt)

        seg_pred, seg_gt = np.expand_dims(seg_pred,
                                          0), np.expand_dims(seg_gt, 0)
        box_start, box_end = generate_spatial_bounding_box(
            np.asarray(seg_pred | seg_gt))
        cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
        seg_pred, seg_gt = np.squeeze(cropper(seg_pred)), np.squeeze(
            cropper(seg_gt))

    # Do binary erosion and use XOR to get edges
    edges_pred = binary_erosion(seg_pred) ^ seg_pred
    edges_gt = binary_erosion(seg_gt) ^ seg_gt

    return edges_pred, edges_gt
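A minimal usage sketch for get_mask_edges on two small binary masks; the shapes and values are illustrative, and the import path below is where the function usually lives in MONAI.

import numpy as np
from monai.metrics.utils import get_mask_edges

pred = np.zeros((8, 8), dtype=bool)
pred[2:6, 2:6] = True
gt = np.zeros((8, 8), dtype=bool)
gt[3:7, 3:7] = True

edges_pred, edges_gt = get_mask_edges(pred, gt)
# both results are boolean edge maps, cropped to the union bounding box by default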
Example #13
File: array.py (Project: wentaozhu/MONAI)
    def __call__(self, img):
        box_start, box_end = generate_spatial_bounding_box(
            img, self.select_fn, self.channel_indexes, self.margin
        )
        cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
        return cropper(img)