Example #1
0
    def __call__(
        self,
        label: np.ndarray,
        image: Optional[np.ndarray] = None,
        output_shape: Optional[Sequence[int]] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute foreground and background indices of ``label``, optionally unravelled to N-d coordinates.

        Args:
            label: input data to compute foreground and background indices.
            image: if image is not None, use ``label == 0 & image > image_threshold``
                to define background. so the output items will not map to all the voxels in the label.
            output_shape: expected shape of output indices. if None, use `self.output_shape` instead.

        Returns:
            ``(fg_indices, bg_indices)``. If neither ``output_shape`` nor ``self.output_shape`` is set,
            these are exactly what ``map_binary_to_indices`` returns. Otherwise each is converted to an
            integer array of shape ``(num_indices, len(output_shape))``; an empty index set yields an
            empty array of that width instead of raising.
        """

        def _unravel(indices: np.ndarray, shape: Sequence[int]) -> np.ndarray:
            # Vectorised unravel + stack on the last axis gives the same (N, ndim) array as stacking
            # per-index tuples, but also handles N == 0 (np.stack of an empty list raises ValueError).
            flat = np.asarray(indices)
            if flat.size == 0:
                return np.zeros((0, len(shape)), dtype=np.intp)
            return np.stack(np.unravel_index(flat, shape), axis=-1)

        if output_shape is None:
            output_shape = self.output_shape
        fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold)
        if output_shape is not None:
            fg_indices = _unravel(fg_indices, output_shape)
            bg_indices = _unravel(bg_indices, output_shape)

        return fg_indices, bg_indices
Example #2
0
    def __call__(
        self,
        img: np.ndarray,
        label: Optional[np.ndarray] = None,
        image: Optional[np.ndarray] = None,
        fg_indices: Optional[np.ndarray] = None,
        bg_indices: Optional[np.ndarray] = None,
    ) -> List[np.ndarray]:
        """
        Crop patches from ``img`` around centers sampled from the foreground/background of ``label``.

        Args:
            img: input data to crop samples from based on the pos/neg ratio of `label` and `image`.
                Assumes `img` is a channel-first array.
            label: the label image that is used for finding foreground/background, if None, use `self.label`.
                If ``self.target_label`` is set, the label is first binarized to ``label == target_label``.
            image: optional image data to help select valid area, can be same as `img` or another image array.
                use ``label == 0 & image > image_threshold`` to select the negative sample(background) center.
                so the crop center will only exist on valid image area. if None, use `self.image`.
            fg_indices: foreground indices to randomly select crop centers,
                need to provide `fg_indices` and `bg_indices` together.
            bg_indices: background indices to randomly select crop centers,
                need to provide `fg_indices` and `bg_indices` together.

        Returns:
            One cropped array per entry of ``self.centers`` (as set by ``self.randomize``);
            an empty list when ``self.centers`` is None.

        Raises:
            ValueError: if neither ``label`` nor ``self.label`` is available.
        """
        if label is None:
            label = self.label
        if label is None:
            raise ValueError("label should be provided.")
        if image is None:
            image = self.image
        # Binarize before any indices are derived so the fg/bg indices (and therefore the sampled
        # centers) reflect only ``target_label``; otherwise they would be computed from the raw label.
        if self.target_label is not None:
            label = (label == self.target_label).astype(np.uint8)
        if fg_indices is None or bg_indices is None:
            if self.fg_indices is not None and self.bg_indices is not None:
                fg_indices = self.fg_indices
                bg_indices = self.bg_indices
            else:
                fg_indices, bg_indices = map_binary_to_indices(
                    label, image, self.image_threshold)

        self.randomize(label, fg_indices, bg_indices, image)
        results: List[np.ndarray] = []
        if self.centers is None:
            return results
        # The same for every center, so compute once: if the ROI exceeds the image in any spatial
        # dimension, resize/pad the whole image instead of cropping around a center.
        roi_exceeds_image = bool(np.any(np.greater(self.spatial_size, img.shape[1:])))
        for center in self.centers:
            if roi_exceeds_image:
                cropper = ResizeWithPadOrCrop(
                    spatial_size=self.spatial_size)
            else:
                cropper = SpatialCrop(
                    roi_center=tuple(center),
                    spatial_size=self.spatial_size)  # type: ignore
            results.append(cropper(img))

        return results
Example #3
0
    def randomize(
        self,
        label: np.ndarray,
        fg_indices: Optional[np.ndarray] = None,
        bg_indices: Optional[np.ndarray] = None,
        image: Optional[np.ndarray] = None,
    ) -> None:
        """
        Sample ``self.num_samples`` crop centers, jitter each by a random offset, and clamp the result.

        Centers come from ``generate_pos_neg_label_crop_centers``. Each coordinate is then shifted by a
        random signed offset whose magnitude is drawn from ``[0, offset * size // 2)`` when
        ``0 < self.offset <= 1``, from ``[0, self.offset)`` when ``self.offset > 1``, and is 0 otherwise.
        The shifted coordinate is clamped into the valid center range so the ROI stays inside the label.

        Args:
            label: label array; its spatial shape is taken as ``label.shape[1:]``
                (presumably channel-first -- confirm with callers).
            fg_indices: precomputed foreground indices; only used together with ``bg_indices``.
            bg_indices: precomputed background indices; only used together with ``fg_indices``.
            image: optional image forwarded to ``map_binary_to_indices`` when indices are not given.

        Sets ``self.spatial_size``, ``self.centers`` and ``self.offset_centers``. If the ROI is larger
        than the label in any spatial dimension, ``self.centers`` becomes ``[None] * self.num_samples``
        and nothing else is sampled.
        """
        spatial_shape = label.shape[1:]
        self.spatial_size = fall_back_tuple(self.spatial_size,
                                            default=spatial_shape)
        if np.greater(self.spatial_size, spatial_shape).any():
            self.centers = [None] * self.num_samples
            return

        if fg_indices is None or bg_indices is None:
            fg_indices_, bg_indices_ = map_binary_to_indices(
                label, image, self.image_threshold)
        else:
            fg_indices_ = fg_indices
            bg_indices_ = bg_indices
        self.centers = generate_pos_neg_label_crop_centers(
            self.spatial_size,
            self.num_samples,
            self.pos_ratio,
            spatial_shape,
            fg_indices_,
            bg_indices_,
            self.R,
        )

        # Valid center range per axis is [valid_start, valid_end): +1 gives the full range on the
        # upper side, subtracting the unfloored size / 2 keeps the rounded bound from being too high.
        valid_start = np.floor_divide(self.spatial_size, 2)
        valid_end = np.floor(
            np.add(spatial_shape, 1) - np.divide(self.spatial_size, 2)).astype(np.int64)
        # Never allow an empty range on any axis.
        valid_end = np.maximum(valid_end, valid_start + 1)

        self.offset_centers = []
        for center in self.centers:
            if 0 < self.offset <= 1:
                # randint(high) raises for high <= 0 (assuming self.R is a NumPy RandomState), which
                # happens when offset * sz // 2 floors to 0; max(1, ...) turns that into a 0 offset.
                offset = [
                    self.R.randint(max(1, self.offset * sz // 2)) *
                    self.R.choice([1, -1]) for sz in self.spatial_size
                ]
            elif self.offset > 1:
                offset = [
                    self.R.randint(self.offset) * self.R.choice([1, -1])
                    for _ in self.spatial_size
                ]
            else:
                offset = [0] * len(self.spatial_size)
            # Clamp the jittered coordinate so a crop around it does not run past the label border.
            self.offset_centers.append([
                int(min(max(int(c + b), lo), hi - 1))
                for c, b, lo, hi in zip(center, offset, valid_start, valid_end)
            ])
        self.centers = self.offset_centers
Example #4
0
 def randomize(
     self,
     label: np.ndarray,
     fg_indices: Optional[np.ndarray] = None,
     bg_indices: Optional[np.ndarray] = None,
     image: Optional[np.ndarray] = None,
 ) -> None:
     """
     Resolve ``self.spatial_size`` and sample ``self.num_samples`` crop centers into ``self.centers``.

     Args:
         label: label array; its spatial shape is taken as ``label.shape[1:]``.
         fg_indices: precomputed foreground indices; used only if ``bg_indices`` is also given.
         bg_indices: precomputed background indices; used only if ``fg_indices`` is also given.
         image: optional image passed to ``map_binary_to_indices`` when the indices must be computed.
     """
     spatial_shape = label.shape[1:]
     self.spatial_size = fall_back_tuple(self.spatial_size, default=spatial_shape)

     # Reuse caller-supplied indices only when both halves are present.
     both_given = fg_indices is not None and bg_indices is not None
     if both_given:
         fg, bg = fg_indices, bg_indices
     else:
         fg, bg = map_binary_to_indices(label, image, self.image_threshold)

     self.centers = generate_pos_neg_label_crop_centers(
         self.spatial_size,
         self.num_samples,
         self.pos_ratio,
         spatial_shape,
         fg,
         bg,
         self.R,
     )
Example #5
0
    def randomize(
        self,
        label: np.ndarray,
        fg_indices: Optional[np.ndarray] = None,
        bg_indices: Optional[np.ndarray] = None,
        image: Optional[np.ndarray] = None,
    ) -> None:
        """
        Sample ``self.num_samples`` crop centers, jitter each randomly, and clamp them to a valid range.

        Each coordinate of a center from ``generate_pos_neg_label_crop_centers`` is shifted by a random
        signed offset whose magnitude is drawn from ``[0, offset * size // 2)`` when
        ``0 < self.offset <= 1``, from ``[0, self.offset)`` when ``self.offset > 1``, and is 0 otherwise.
        It is then clamped into ``[valid_start, valid_end - 1]``.

        Args:
            label: label array; its spatial shape is taken as ``label.shape[1:]``
                (presumably channel-first -- confirm with callers).
            fg_indices: precomputed foreground indices; only used together with ``bg_indices``.
            bg_indices: precomputed background indices; only used together with ``fg_indices``.
            image: optional image forwarded to ``map_binary_to_indices`` when indices are not given.

        Sets ``self.spatial_size``, ``self.centers`` and ``self.offset_centers``. If the ROI is larger
        than the label in any spatial dimension, ``self.centers`` becomes ``[None] * self.num_samples``.
        """
        self.spatial_size = fall_back_tuple(self.spatial_size,
                                            default=label.shape[1:])

        if np.greater(self.spatial_size, label.shape[1:]).any():
            self.centers = [None] * self.num_samples
            return

        # Select subregion to assure valid roi: centers must lie in [valid_start, valid_end).
        valid_start = np.floor_divide(self.spatial_size, 2)
        # Add 1 for full range on the upper side, but subtract the unfloored size / 2 so the rounded
        # bound is not too high. int64 rather than uint16 avoids wraparound on very large dimensions.
        valid_end = np.floor(
            np.add(label.shape[1:], 1) - np.divide(self.spatial_size, 2)).astype(np.int64)
        # Keep each range non-empty (a bounded randint also needs start < end).
        valid_end = np.where(valid_end == valid_start, valid_end + 1, valid_end)

        def _correct_centers(center_ori: List[int],
                             valid_start: np.ndarray,
                             valid_end: np.ndarray) -> List[int]:
            # Clamp each coordinate into [valid_start, valid_end - 1]; return plain ints throughout.
            return [
                int(min(max(c, lo), hi - 1))
                for c, lo, hi in zip(center_ori, valid_start, valid_end)
            ]

        if fg_indices is None or bg_indices is None:
            fg_indices_, bg_indices_ = map_binary_to_indices(
                label, image, self.image_threshold)
        else:
            fg_indices_ = fg_indices
            bg_indices_ = bg_indices
        self.centers = generate_pos_neg_label_crop_centers(
            self.spatial_size,
            self.num_samples,
            self.pos_ratio,
            label.shape[1:],
            fg_indices_,
            bg_indices_,
            self.R,
        )
        self.offset_centers = []
        for center in self.centers:
            if 0 < self.offset <= 1:
                # randint(high) raises for high <= 0 (assuming self.R is a NumPy RandomState), which
                # happens when offset * sz // 2 floors to 0; max(1, ...) turns that into a 0 offset.
                offset = [
                    self.R.randint(max(1, self.offset * sz // 2)) *
                    self.R.choice([1, -1]) for sz in self.spatial_size
                ]
            elif self.offset > 1:
                offset = [
                    self.R.randint(self.offset) * self.R.choice([1, -1])
                    for _ in self.spatial_size
                ]
            else:
                offset = [0] * len(self.spatial_size)
            offset_centers = [int(c + b) for c, b in zip(center, offset)]
            self.offset_centers.append(
                _correct_centers(offset_centers, valid_start, valid_end))
        self.centers = self.offset_centers