Example #1
 def randomize(self, img_size):
     self._size = fall_back_tuple(self.roi_size, img_size)
     if self.random_size:
         self._size = [self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size))]
     if self.random_center:
         valid_size = get_valid_patch_size(img_size, self._size)
         self._slices = ensure_tuple(slice(None)) + get_random_patch(img_size, valid_size, self.R)
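All of these examples rely on the same fall-back rule: components of the requested size that are None or non-positive are replaced by the corresponding component of the default (typically the image size). A minimal, self-contained sketch of that rule (the helper name below is illustrative, not the MONAI function itself):

def fall_back_sketch(user_provided, default):
    # Keep a user value only when it is a positive number; otherwise fall back
    # to the matching component of `default`.
    return tuple(
        u if (u is not None and u > 0) else d
        for u, d in zip(user_provided, default)
    )

print(fall_back_sketch((32, -1), (64, 64)))  # -> (32, 64)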
Example #2
 def __call__(self, img):
     """
     Apply the transform to `img`, assuming `img` is channel-first and
     slicing doesn't apply to the channel dim.
     """
     self.roi_size = fall_back_tuple(self.roi_size, img.shape[1:])
     center = [i // 2 for i in img.shape[1:]]
     cropper = SpatialCrop(roi_center=center, roi_size=self.roi_size)
     return cropper(img)
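The centre used above is simply the integer midpoint of each spatial dimension. As an illustration of the roi_center/roi_size convention (a sketch, not the SpatialCrop implementation):

import numpy as np

def center_crop_sketch(img, roi_size):
    # Crop a channel-first array around the spatial midpoints; the channel dim is untouched.
    center = [dim // 2 for dim in img.shape[1:]]
    slices = [slice(None)]
    for c, r in zip(center, roi_size):
        start = max(c - r // 2, 0)
        slices.append(slice(start, start + r))
    return img[tuple(slices)]

print(center_crop_sketch(np.zeros((1, 64, 64)), (32, 32)).shape)  # -> (1, 32, 32)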
Example #3
 def _determine_data_pad_width(self, data_shape):
     self.spatial_size = fall_back_tuple(self.spatial_size, data_shape)
     if self.method == Method.SYMMETRIC:
         pad_width = list()
         for i in range(len(self.spatial_size)):
             width = max(self.spatial_size[i] - data_shape[i], 0)
             pad_width.append((width // 2, width - (width // 2)))
         return pad_width
     else:
         return [(0, max(self.spatial_size[i] - data_shape[i], 0)) for i in range(len(self.spatial_size))]
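For the symmetric branch above, a shortfall of `width` voxels is split into `width // 2` before and the remainder after, so an odd shortfall puts the extra voxel at the end. A standalone check of that arithmetic (plain Python, no MONAI types assumed):

def symmetric_pad_width(spatial_size, data_shape):
    pad_width = []
    for target, current in zip(spatial_size, data_shape):
        width = max(target - current, 0)  # never pad negatively
        pad_width.append((width // 2, width - width // 2))
    return pad_width

print(symmetric_pad_width((11, 9), (6, 9)))  # -> [(2, 3), (0, 0)]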
Example #4
    def __call__(self, data):
        d = dict(data)
        self.randomize()

        sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:])
        if self.rand_affine.do_transform:
            grid = self.rand_affine.rand_affine_grid(spatial_size=sp_size)
        else:
            grid = create_grid(spatial_size=sp_size)

        for idx, key in enumerate(self.keys):
            d[key] = self.rand_affine.resampler(d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx])
        return d
Example #5
    def __call__(self, data):
        d = dict(data)
        sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:])

        self.randomize(grid_size=sp_size)
        grid = create_grid(spatial_size=sp_size)
        if self.rand_3d_elastic.do_transform:
            device = self.rand_3d_elastic.device
            grid = torch.tensor(grid).to(device)
            gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device)
            offset = torch.tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0)
            grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude
            grid = self.rand_3d_elastic.rand_affine_grid(grid=grid)

        for idx, key in enumerate(self.keys):
            d[key] = self.rand_3d_elastic.resampler(
                d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx]
            )
        return d
Example #6
    def __call__(self, img, mode: Optional[Union[NumpyPadMode, str]] = None):
        """
        Args:
            img: data to be transformed, assuming `img` is channel-first
                and padding doesn't apply to the channel dim.
            mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
                ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
                One of the listed string values or a user supplied function. Defaults to ``self.mode``.
                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
        """
        spatial_shape = img.shape[1:]
        k = fall_back_tuple(self.k, (1,) * len(spatial_shape))
        new_size = []
        for k_d, dim in zip(k, spatial_shape):
            new_dim = int(np.ceil(dim / k_d) * k_d) if k_d > 0 else dim
            new_size.append(new_dim)

        spatial_pad = SpatialPad(spatial_size=new_size, method=Method.SYMMETRIC, mode=mode or self.mode)
        return spatial_pad(img)
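The `new_size` loop above rounds each spatial dimension up to the nearest multiple of the corresponding entry of `k` (non-positive entries leave the dimension unchanged). A quick standalone check of that arithmetic:

import math

def round_up_to_multiple(spatial_shape, k):
    # Each dimension is padded up to the next multiple of k_d; k_d <= 0 keeps it as-is.
    return [
        int(math.ceil(dim / k_d) * k_d) if k_d > 0 else dim
        for k_d, dim in zip(k, spatial_shape)
    ]

print(round_up_to_multiple((30, 65), (16, 16)))  # -> [32, 80]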
Example #7
    def __call__(self, data):
        d = dict(data)

        sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:])
        self.randomize(spatial_size=sp_size)

        if self.rand_2d_elastic.do_transform:
            grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size)
            grid = self.rand_2d_elastic.rand_affine_grid(grid=grid)
            grid = _torch_interp(
                input=grid.unsqueeze(0),
                scale_factor=list(self.rand_2d_elastic.deform_grid.spacing),
                mode=InterpolateMode.BICUBIC.value,
                align_corners=False,
            )
            grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])
        else:
            grid = create_grid(spatial_size=sp_size)

        for idx, key in enumerate(self.keys):
            d[key] = self.rand_2d_elastic.resampler(
                d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx]
            )
        return d
Example #8
def generate_pos_neg_label_crop_centers(
    label: np.ndarray,
    spatial_size,
    num_samples: int,
    pos_ratio: float,
    image: Optional[np.ndarray] = None,
    image_threshold: float = 0.0,
    rand_state: np.random.RandomState = np.random,
):
    """Generate valid sample locations based on image with option for specifying foreground ratio
    Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W]

    Args:
        label (numpy.ndarray): use the label data to get the foreground/background information.
        spatial_size (sequence of int): spatial size of the ROIs to be sampled.
        num_samples: total sample centers to be generated.
        pos_ratio: ratio of total locations generated that have center being foreground.
        image (numpy.ndarray): if image is not None, use ``label = 0 & image > image_threshold``
            to select the background, so the crop centers only fall within the valid image area.
        image_threshold: if `image` is provided, use ``image > image_threshold`` to
            determine the valid image content area.
        rand_state (np.random.RandomState): numpy RandomState object to align with other modules.

    Raises:
        ValueError: no sampling location available.

    """
    max_size = label.shape[1:]
    spatial_size = fall_back_tuple(spatial_size, default=max_size)
    if not (np.subtract(max_size, spatial_size) >= 0).all():
        raise ValueError("proposed roi is larger than image itself.")

    # Select subregion to assure valid roi
    valid_start = np.floor_divide(spatial_size, 2)
    valid_end = np.subtract(max_size + np.array(1), spatial_size / np.array(2)).astype(np.uint16)  # add 1 for random
    # int generation to have full range on upper side, but subtract unfloored size/2 to prevent rounded range
    # from being too high
    for i in range(len(valid_start)):  # need this because np.random.randint does not work with same start and end
        if valid_start[i] == valid_end[i]:
            valid_end[i] += 1

    def _correct_centers(center_ori, valid_start, valid_end):
        for i, c in enumerate(center_ori):
            center_i = c
            if c < valid_start[i]:
                center_i = valid_start[i]
            if c >= valid_end[i]:
                center_i = valid_end[i] - 1
            center_ori[i] = center_i
        return center_ori

    centers = []
    # Prepare fg/bg indices
    if label.shape[0] > 1:
        label = label[1:]  # for One-Hot format data, remove the background channel
    label_flat = np.any(label, axis=0).ravel()  # in case label has multiple dimensions
    fg_indices = np.nonzero(label_flat)[0]
    if image is not None:
        img_flat = np.any(image > image_threshold, axis=0).ravel()
        bg_indices = np.nonzero(np.logical_and(img_flat, ~label_flat))[0]
    else:
        bg_indices = np.nonzero(~label_flat)[0]

    if not len(fg_indices) or not len(bg_indices):
        if not len(fg_indices) and not len(bg_indices):
            raise ValueError("no sampling location available.")
        warnings.warn(
            f"N foreground {len(fg_indices)}, N  background {len(bg_indices)},"
            "unable to generate class balanced samples."
        )
        pos_ratio = 0 if not len(fg_indices) else 1

    for _ in range(num_samples):
        indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices
        random_int = rand_state.randint(len(indices_to_use))
        center = np.unravel_index(indices_to_use[random_int], label.shape)
        center = center[1:]
        # shift center to range of valid centers
        center_ori = [c for c in center]
        centers.append(_correct_centers(center_ori, valid_start, valid_end))

    return centers
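The valid_start/valid_end bookkeeping above keeps every sampled ROI entirely inside the volume: along each axis a centre may range from `roi // 2` up to (but not including) `image_size + 1 - roi / 2`. A small illustration of that range (the helper name is hypothetical):

import numpy as np

def valid_center_range(image_size, roi_size):
    # Mirrors the valid_start / valid_end computation above.
    valid_start = np.floor_divide(roi_size, 2)
    valid_end = np.subtract(np.add(image_size, 1), np.divide(roi_size, 2)).astype(int)
    return valid_start, valid_end

start, end = valid_center_range((64, 64, 64), (32, 32, 32))
print(start, end)  # -> [16 16 16] [49 49 49]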
Example #9
def sliding_window_inference(
        inputs: Union[torch.Tensor, tuple],
        roi_size,
        sw_batch_size: int,
        predictor: Callable,
        overlap: float = 0.25,
        mode: Union[BlendMode, str] = BlendMode.CONSTANT,
        padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
        cval=0,
        uncertainty_flag=False):
    """
    Sliding window inference on `inputs` with `predictor`.

    When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D])
        roi_size (list, tuple): the spatial window size for inferences.
            When a component of `roi_size` is None or non-positive, the corresponding
            input dimension is used instead. For example, `roi_size=(32, -1)` will be
            adapted to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor `patch_data` in shape NCHW[D], `predictor(patch_data)`
            should return a prediction with the same spatial shape and batch_size, i.e. NMHW[D];
            where HW[D] represents the patch spatial size, M is the number of output channels, N is `sw_batch_size`.
        overlap: Amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant``": gives equal weight to all predictions.
            - ``"gaussian``": gives less weight to predictions on edges of windows.

        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode when ``roi_size`` is larger than inputs. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/nn.functional.html#pad
        cval: fill value for 'constant' padding mode. Default: 0
        uncertainty_flag: if True, draw multiple stochastic samples from the predicted
            (logit, sigma) pair and return the stitched sample histograms together with
            the stitched sigma map, instead of a single output volume.

    Raises:
        NotImplementedError: inputs must have batch_size=1.

    Note:
        - input must be channel-first and have a batch dim, support both spatial 2D and 3D.
        - currently only supports `inputs` with batch_size=1.
    """
    assert 0 <= overlap < 1, "overlap must be >= 0 and < 1."

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    inputs_type = type(inputs)
    if inputs_type == tuple:
        phys_inputs = inputs[1]
        inputs = inputs[0]
    num_spatial_dims = len(inputs.shape) - 2
    image_size_ = list(inputs.shape[2:])
    batch_size = inputs.shape[0]

    # TODO: Enable batch sizes > 1 in future
    if batch_size > 1:
        raise NotImplementedError("inputs must have batch_size=1.")

    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(
        max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs,
                   pad=pad_size,
                   mode=PytorchPadMode(padding_mode).value,
                   value=cval)

    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims,
                                       overlap)

    # Store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    # print(f'The slices are {slices}')

    slice_batches = []
    for slice_index in range(0, len(slices), sw_batch_size):
        slice_index_range = range(
            slice_index, min(slice_index + sw_batch_size, len(slices)))
        input_slices = []
        for curr_index in slice_index_range:
            curr_slice = slices[curr_index]
            if len(curr_slice) == 3:
                input_slices.append(inputs[0, :, curr_slice[0], curr_slice[1],
                                           curr_slice[2]])
            else:
                input_slices.append(inputs[0, :, curr_slice[0], curr_slice[1]])
        slice_batches.append(torch.stack(input_slices))

    # Perform predictions
    if not uncertainty_flag:
        # No uncertainty, so only one prediction, so proceed normally
        output_rois = list()
        for data in slice_batches:
            if inputs_type == tuple:
                seg_prob, _ = predictor(data, phys_inputs)  # batched patch segmentation
            else:
                seg_prob, _ = predictor(data)  # batched patch segmentation
            output_rois.append(seg_prob)

        # stitching output image
        output_classes = output_rois[0].shape[1]
        output_shape = [batch_size, output_classes] + list(image_size)

        # Create importance map
        importance_map = compute_importance_map(get_valid_patch_size(
            image_size, roi_size),
                                                mode=mode,
                                                device=inputs.device)

        # allocate memory to store the full output and the count for overlapping parts
        output_image = torch.zeros(output_shape,
                                   dtype=torch.float32,
                                   device=inputs.device)
        count_map = torch.zeros(output_shape,
                                dtype=torch.float32,
                                device=inputs.device)

        for window_id, slice_index in enumerate(
                range(0, len(slices), sw_batch_size)):
            slice_index_range = range(
                slice_index, min(slice_index + sw_batch_size, len(slices)))

            # store the result in the proper location of the full output. Apply weights from importance map.
            for curr_index in slice_index_range:
                curr_slice = slices[curr_index]
                if len(curr_slice) == 3:
                    # print(output_image.shape, curr_slice, importance_map.shape, output_rois[window_id].shape)
                    output_image[0, :, curr_slice[0], curr_slice[1],
                                 curr_slice[2]] += (
                                     importance_map *
                                     output_rois[window_id][curr_index -
                                                            slice_index, :])
                    count_map[0, :, curr_slice[0], curr_slice[1],
                              curr_slice[2]] += importance_map
                else:
                    output_image[0, :, curr_slice[0], curr_slice[1]] += (
                        importance_map *
                        output_rois[window_id][curr_index - slice_index, :])
                    count_map[0, :, curr_slice[0],
                              curr_slice[1]] += importance_map

        # account for any overlapping sections
        output_image /= count_map

        if num_spatial_dims == 3:
            return output_image[..., pad_size[4]:image_size_[0] + pad_size[4],
                                pad_size[2]:image_size_[1] + pad_size[2],
                                pad_size[0]:image_size_[2] + pad_size[0], ]
        return output_image[..., pad_size[2]:image_size_[0] + pad_size[2],
                            pad_size[0]:image_size_[1] + pad_size[0]]  # 2D
    else:
        # Decide on number of histogram samples
        num_hist_samples = 20
        overall_stochastic_logits_hist = torch.empty(
            (1, 2, 181, 217, 181, num_hist_samples))
        overall_true_seg_net_out_hist = torch.empty(
            (1, 2, 181, 217, 181, num_hist_samples))

        output_rois = list()
        unc_output_rois = list()
        # Have uncertainty, therefore have MANY outputs, but only have ONE pass through network
        for data in slice_batches:
            if inputs_type == tuple:
                seg_prob, unc_prob, _ = predictor(data, phys_inputs)  # batched patch segmentation
            else:
                seg_prob, unc_prob, _ = predictor(data)  # batched patch segmentation
            output_rois.append(seg_prob)
            unc_output_rois.append(unc_prob)
        # Get shape of logits
        logits_shape = list(seg_prob.shape)
        # Now want an array of randomly normally distributed samples size of logits x num samples
        # logits_shape.append(num_hist_samples)
        inf_ax = torch.distributions.Normal(
            torch.tensor(0.0).to(device=torch.device("cuda:0")),
            torch.tensor(1.0).to(device=torch.device("cuda:0")))
        # inf_noise_array = torch.empty(logits_shape).normal_(mean=0, std=1)
        # Loop through samples

        for infpass in range(num_hist_samples):
            true_output_rois = list()
            true_unc_output_rois = list()
            # print(f'The lengths of rois are {len(output_rois)}, {len(unc_output_rois)}')
            for roi, unc_roi in zip(output_rois, unc_output_rois):
                # output_rois = list()
                # unc_output_rois = list()
                # Repeat steps above to get more samples
                # noise_sample = inf_noise_array[..., infpass]
                stochastic_logits = roi + unc_roi * inf_ax.sample(
                    logits_shape)  # noise_sample
                # print(f'The sigma mean is {torch.mean(unc_roi)}, logits mean is {torch.mean(roi)}')
                # print(
                #     f'The logits, sigma, ax sizes are: {roi.shape}, {unc_roi.shape}, {inf_ax.sample(logits_shape).shape}')
                # print(
                #     f'A little ax check: {inf_ax.sample(logits_shape)[0, 0, 0, 0, 0]}, {inf_ax.sample(logits_shape)[0, 1, 0, 0, 0]}')
                true_seg_net_out = torch.softmax(stochastic_logits, dim=1)
                # print(f'The stochastic logits shapes are {stochastic_logits.shape}, {true_seg_net_out.shape}')
                true_output_rois.append(true_seg_net_out)
                true_unc_output_rois.append(stochastic_logits)

            # stitching output image
            # print(f'The true output rois tensor shapes are {true_output_rois[0].shape}')
            output_classes = true_output_rois[0].shape[1]
            output_shape = [batch_size, output_classes] + list(image_size)

            # Create importance map
            importance_map = compute_importance_map(get_valid_patch_size(
                image_size, roi_size),
                                                    mode=mode,
                                                    device=inputs.device)

            # allocate memory to store the full output and the count for overlapping parts
            output_image = torch.zeros(output_shape,
                                       dtype=torch.float32,
                                       device=inputs.device)
            count_map = torch.zeros(output_shape,
                                    dtype=torch.float32,
                                    device=inputs.device)

            # slic_index, zero to len(slices) in increments of sw_batch_size
            for window_id, slice_index in enumerate(
                    range(0, len(slices), sw_batch_size)):
                slice_index_range = range(
                    slice_index, min(slice_index + sw_batch_size, len(slices)))

                # store the result in the proper location of the full output. Apply weights from importance map.
                for curr_index in slice_index_range:
                    curr_slice = slices[curr_index]
                    if len(curr_slice) == 3:
                        # print(output_image.shape, curr_slice, importance_map.shape, true_output_rois[window_id].shape)
                        output_image[0, :, curr_slice[0], curr_slice[1],
                                     curr_slice[2]] += (
                                         importance_map *
                                         true_output_rois[window_id][
                                             curr_index - slice_index, :])
                        count_map[0, :, curr_slice[0], curr_slice[1],
                                  curr_slice[2]] += importance_map
                    else:
                        output_image[0, :, curr_slice[0], curr_slice[1]] += (
                            importance_map *
                            true_output_rois[window_id][curr_index -
                                                        slice_index, :])
                        count_map[0, :, curr_slice[0],
                                  curr_slice[1]] += importance_map

            # account for any overlapping sections
            output_image /= count_map

            if num_spatial_dims == 3:
                output_image = output_image[..., pad_size[4]:image_size_[0] +
                                            pad_size[4],
                                            pad_size[2]:image_size_[1] +
                                            pad_size[2],
                                            pad_size[0]:image_size_[2] +
                                            pad_size[0], ]
                overall_true_seg_net_out_hist[..., infpass] = output_image
            else:
                output_image = output_image[..., pad_size[2]:image_size_[0] +
                                            pad_size[2],
                                            pad_size[0]:image_size_[1] +
                                            pad_size[0]]  # 2D
                overall_true_seg_net_out_hist[..., infpass] = output_image

            # Uncertainty part
            # stitching output image
            output_classes = true_unc_output_rois[0].shape[1]
            output_shape = [batch_size, output_classes] + list(image_size)

            # Create importance map
            importance_map = compute_importance_map(get_valid_patch_size(
                image_size, roi_size),
                                                    mode=mode,
                                                    device=inputs.device)

            # allocate memory to store the full output and the count for overlapping parts
            unc_output = torch.zeros(output_shape,
                                     dtype=torch.float32,
                                     device=inputs.device)
            count_map = torch.zeros(output_shape,
                                    dtype=torch.float32,
                                    device=inputs.device)

            for window_id, slice_index in enumerate(
                    range(0, len(slices), sw_batch_size)):
                slice_index_range = range(
                    slice_index, min(slice_index + sw_batch_size, len(slices)))

                # store the result in the proper location of the full output. Apply weights from importance map.
                for curr_index in slice_index_range:
                    curr_slice = slices[curr_index]
                    if len(curr_slice) == 3:
                        unc_output[0, :, curr_slice[0], curr_slice[1],
                                   curr_slice[2]] += (
                                       importance_map *
                                       true_unc_output_rois[window_id][
                                           curr_index - slice_index, :])
                        count_map[0, :, curr_slice[0], curr_slice[1],
                                  curr_slice[2]] += importance_map
                    else:
                        unc_output[0, :, curr_slice[0], curr_slice[1]] += (
                            importance_map *
                            true_unc_output_rois[window_id][curr_index -
                                                            slice_index, :])
                        count_map[0, :, curr_slice[0],
                                  curr_slice[1]] += importance_map

            # account for any overlapping sections
            unc_output /= count_map

            if num_spatial_dims == 3:
                unc_output = unc_output[..., pad_size[4]:image_size_[0] +
                                        pad_size[4],
                                        pad_size[2]:image_size_[1] +
                                        pad_size[2],
                                        pad_size[0]:image_size_[2] +
                                        pad_size[0], ]
                overall_stochastic_logits_hist[..., infpass] = unc_output
            else:
                unc_output = unc_output[..., pad_size[2]:image_size_[0] +
                                        pad_size[2],
                                        pad_size[0]:image_size_[1] +
                                        pad_size[0]]  # 2D
                overall_stochastic_logits_hist[..., infpass] = unc_output

        # Sigma part
        # stitching output image
        output_classes = unc_output_rois[0].shape[1]
        output_shape = [batch_size, output_classes] + list(image_size)

        # Create importance map
        importance_map = compute_importance_map(get_valid_patch_size(
            image_size, roi_size),
                                                mode=mode,
                                                device=inputs.device)

        # allocate memory to store the full output and the count for overlapping parts
        sigma_output = torch.zeros(output_shape,
                                   dtype=torch.float32,
                                   device=inputs.device)
        count_map = torch.zeros(output_shape,
                                dtype=torch.float32,
                                device=inputs.device)
        for window_id, slice_index in enumerate(
                range(0, len(slices), sw_batch_size)):
            slice_index_range = range(
                slice_index, min(slice_index + sw_batch_size, len(slices)))

            # store the result in the proper location of the full output. Apply weights from importance map.
            for curr_index in slice_index_range:
                curr_slice = slices[curr_index]
                if len(curr_slice) == 3:
                    sigma_output[0, :, curr_slice[0], curr_slice[1],
                                 curr_slice[2]] += (
                                     importance_map *
                                     unc_output_rois[window_id][curr_index -
                                                                slice_index, :]
                                 )
                    count_map[0, :, curr_slice[0], curr_slice[1],
                              curr_slice[2]] += importance_map
                else:
                    sigma_output[0, :, curr_slice[0], curr_slice[1]] += (
                        importance_map *
                        unc_output_rois[window_id][curr_index -
                                                   slice_index, :])
                    count_map[0, :, curr_slice[0],
                              curr_slice[1]] += importance_map

        # account for any overlapping sections
        sigma_output /= count_map

        if num_spatial_dims == 3:
            sigma_output = sigma_output[..., pad_size[4]:image_size_[0] +
                                        pad_size[4],
                                        pad_size[2]:image_size_[1] +
                                        pad_size[2],
                                        pad_size[0]:image_size_[2] +
                                        pad_size[0], ]
        else:
            sigma_output = sigma_output[..., pad_size[2]:image_size_[0] +
                                        pad_size[2],
                                        pad_size[0]:image_size_[1] +
                                        pad_size[0]]  # 2D
        return overall_true_seg_net_out_hist, overall_stochastic_logits_hist, sigma_output
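A hedged usage sketch for the variant above, with the default `uncertainty_flag=False`. The toy predictor and shapes are illustrative, and the snippet assumes the helpers referenced inside the function (`fall_back_tuple`, `dense_patch_slices`, `compute_importance_map`, `_get_scan_interval`) are importable alongside it:

import torch

def dummy_predictor(patch):
    # Matches the two-value unpacking in the non-uncertainty branch: (segmentation, extra).
    return torch.softmax(torch.randn(patch.shape[0], 2, *patch.shape[2:]), dim=1), None

image = torch.rand(1, 1, 64, 64, 64)  # N=1, C=1, three spatial dims
seg = sliding_window_inference(
    image, roi_size=(32, -1, -1), sw_batch_size=4, predictor=dummy_predictor, overlap=0.25
)
print(seg.shape)  # expected (1, 2, 64, 64, 64); the -1 entries fall back to the image size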
Example #10
 def randomize(self, label, image):
     self.spatial_size = fall_back_tuple(self.spatial_size,
                                         default=label.shape[1:])
     self.centers = generate_pos_neg_label_crop_centers(
         label, self.spatial_size, self.num_samples, self.pos_ratio, image,
         self.image_threshold, self.R)