Example No. 1
    def validate_inference_stride_size(inference_stride_size: Optional[TupleInt3],
                                       output_size: Optional[TupleInt3]) -> None:
        """
        Checks that the patch stride size is positive and no larger than the output patch size, to ensure
        that posterior predictions are obtained for all pixels.
        """
        if inference_stride_size is not None:
            if any_smaller_or_equal_than(inference_stride_size, 0):
                raise ValueError("inference_stride_size must be > 0 in all dimensions, found: {}"
                                 .format(inference_stride_size))

            if output_size is not None:
                if any_pairwise_larger(inference_stride_size, output_size):
                    raise ValueError("inference_stride_size must be <= output_size in all dimensions"
                                     "Found: output_size={}, inference_stride_size={}"
                                     .format(output_size, inference_stride_size))
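All of these examples call element-wise comparison helpers whose implementations are not shown on this page. A minimal sketch of what they could look like, assuming plain integer tuples as inputs (the real implementations may differ):

from typing import Sequence

def any_smaller_or_equal_than(values: Sequence[int], threshold: int) -> bool:
    # True if any element is smaller than or equal to the scalar threshold.
    return any(v <= threshold for v in values)

def any_pairwise_larger(values1: Sequence[int], values2: Sequence[int]) -> bool:
    # True if any element of values1 is larger than the matching element of values2.
    return any(v1 > v2 for v1, v2 in zip(values1, values2))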
Example No. 2
def pad_images_for_inference(
        images: np.ndarray,
        crop_size: TupleInt3,
        output_size: Optional[TupleInt3],
        padding_mode: PaddingMode = PaddingMode.Zero) -> np.ndarray:
    """
    Pad the original image to ensure that the size of the model output matches the size of the original image.
    Padding is needed to allow the patches on the corners of the image to be handled correctly, as the model response
    for each patch will only cover the center of the input voxels for that patch. Hence, a padding of size
    ceil((crop_size - output_size) / 2) around the original image is needed to ensure that the output size of the model
    is the same as the original image size.

    :param images: the image(s) to be padded, in shape: Z x Y x X or batched in shape: Batches x Z x Y x X.
    :param crop_size: the shape of the patches that will be taken from this image.
    :param output_size: the shape of the response for each patch from the model.
    :param padding_mode: a valid numpy padding mode.
    :return: padded copy of the original image.
    """
    def create_padding_vector() -> Tuple[TupleInt2, TupleInt2, TupleInt2]:
        """
        Creates the padding vector.
        """
        diff = np.subtract(crop_size, output_size)
        pad: np.ndarray = np.ceil(diff / 2.0).astype(int)
        return ((pad[0], diff[0] - pad[0]),
                (pad[1], diff[1] - pad[1]),
                (pad[2], diff[2] - pad[2]))

    if images is None:
        raise Exception("Image must not be none")

    if output_size is None:
        raise Exception("Output size must not be none")

    if len(images.shape) not in (3, 4):
        raise Exception("Image must be either 3 dimensions (Z x Y x X) or "
                        "batched into 4 dimensions (Batches x Z x Y x X)")

    if any_pairwise_larger(output_size, crop_size):
        raise Exception(
            "crop_size must be >= output_size, found crop_size:{}, output_size:{}"
            .format(crop_size, output_size))

    return _pad_images(images=images,
                       padding_vector=create_padding_vector(),
                       padding_mode=padding_mode)
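To make the padding arithmetic concrete, here is a small worked example; the crop and output sizes are illustrative, not taken from the code above:

import numpy as np

crop_size = (64, 64, 64)      # model input patch size (illustrative)
output_size = (36, 36, 36)    # model output patch size (illustrative)
diff = np.subtract(crop_size, output_size)  # array([28, 28, 28])
pad = np.ceil(diff / 2.0).astype(int)       # array([14, 14, 14])
# Each axis gets ceil(diff / 2) voxels before and diff - ceil(diff / 2) after,
# so the padded image is exactly diff voxels larger along every axis.
padding_vector = tuple((int(pad[i]), int(diff[i] - pad[i])) for i in range(3))
print(padding_vector)  # ((14, 14), (14, 14), (14, 14))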
Example No. 3
    def adjust_after_mixed_precision_and_parallel(self, model: Any) -> None:
        """
        Updates the model config parameters (e.g. output patch size). If the testing patch stride size is
        unset, it is set to the output patch size.
        """
        self._train_output_size = model.get_output_shape(
            input_shape=self.crop_size)
        self._test_output_size = model.get_output_shape(
            input_shape=self.test_crop_size)
        if self.inference_stride_size is None:
            self.inference_stride_size = self._test_output_size
        else:
            if any_pairwise_larger(self.inference_stride_size,
                                   self._test_output_size):
                raise ValueError(
                    "The inference stride size must be no larger than the model's output size in each "
                    "dimension. Inference stride was set to {}, the model outputs {} in test mode."
                    .format(self.inference_stride_size,
                            self._test_output_size))
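A self-contained sketch of the defaulting rule this method applies; StrideDemo below is a hypothetical stand-in, not the real model config class:

from typing import Optional, Tuple

class StrideDemo:  # hypothetical stand-in for the real config class
    def __init__(self, inference_stride_size: Optional[Tuple[int, int, int]] = None):
        self.inference_stride_size = inference_stride_size

    def adjust(self, test_output_size: Tuple[int, int, int]) -> None:
        # An unset stride falls back to the model's test-time output size,
        # so adjacent patches tile the image without gaps.
        if self.inference_stride_size is None:
            self.inference_stride_size = test_output_size

demo = StrideDemo()
demo.adjust(test_output_size=(36, 36, 36))
print(demo.inference_stride_size)  # (36, 36, 36)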
Example No. 4
def slicers_for_random_crop(
        sample: Sample,
        crop_size: TupleInt3,
        class_weights: Optional[List[float]] = None) -> Tuple[List[slice], np.ndarray]:
    """
    Computes array slicers that produce random crops of the given crop_size.
    The selection of the center is dependent on background probability.
    By default, it does not center on background.

    :param sample: A set of Image channels, ground truth labels and mask to randomly crop.
    :param crop_size: The size of the crop expressed as a list of 3 ints, one per spatial dimension.
    :param class_weights: A weighting vector with values in [0, 1] that influences which class the center crop
                          voxel belongs to (must sum to 1); a uniform distribution is assumed if none is provided.
    :return: Tuple element 1: The slicers that convert the input image to the chosen crop. Tuple element 2: The
    indices of the center point of the crop.
    :raises ValueError: If there are shape mismatches among the arguments or if the crop size is larger than the image.
    """
    shape = sample.image.shape[1:]

    if any_pairwise_larger(crop_size, shape):
        raise ValueError(
            "The crop_size across each dimension should be greater than zero and less than or equal "
            "to the current value (crop_size: {}, spatial shape: {})".format(
                crop_size, shape))

    # Sample a center pixel location for patch extraction.
    center = random_select_patch_center(sample, class_weights)

    # Verify and fix overflow for each dimension
    left = []
    for i in range(3):
        margin_left = crop_size[i] // 2
        margin_right = crop_size[i] - margin_left
        left_index = center[i] - margin_left
        right_index = center[i] + margin_right
        if right_index > shape[i]:
            left_index = left_index - (right_index - shape[i])
        if left_index < 0:
            left_index = 0
        left.append(left_index)

    return [slice(left[x], left[x] + crop_size[x])
            for x in range(3)], center
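To see how the overflow correction keeps the crop inside the image, here is a worked single-axis example with illustrative numbers (not taken from the code above):

# Axis length 50, crop size 32, sampled center at 40 (all illustrative).
shape_i, crop_i, center_i = 50, 32, 40
margin_left = crop_i // 2                # 16
margin_right = crop_i - margin_left      # 16
left_index = center_i - margin_left      # 24
right_index = center_i + margin_right    # 56, overshoots the axis length
if right_index > shape_i:
    left_index -= right_index - shape_i  # shift window left: 24 - 6 = 18
left_index = max(left_index, 0)          # clamp at the lower edge
print(slice(left_index, left_index + crop_i))  # slice(18, 50, None)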
Example No. 5
    def __init__(self,
                 multiple_of: Optional[IntOrTuple3] = None,
                 minimum_size: Optional[IntOrTuple3] = None,
                 num_dimensions: int = 3):
        """
        :param multiple_of: Stores minimum size and other conditions that a training crop size must satisfy.
        :param minimum_size: Training crops must have a size that is a multiple of this value, along each dimension.
        For example, if set to (1, 16, 16), the crop size has to be a multiple of 16 along X and Y, and a
        multiple of 1 (i.e., any number) along the Z dimension.
        :param num_dimensions: Training crops must have a size that is at least this value.
        """
        self.multiple_of = multiple_of
        self.minimum_size = minimum_size
        self.num_dimensions = num_dimensions

        def make_tuple3(o: Optional[IntOrTuple3]) -> Optional[TupleInt3]:
            # "type ignore" directives below are because mypy is not clever enough
            if o is None:
                return None
            if isinstance(o, int):
                # noinspection PyTypeChecker
                return (o, ) * self.num_dimensions  # type: ignore
            if len(o) != self.num_dimensions:  # type: ignore
                raise ValueError(
                    "Object must have length {}, but got: {}".format(
                        self.num_dimensions, o))
            return o  # type: ignore

        self.multiple_of = make_tuple3(self.multiple_of)
        self.minimum_size = make_tuple3(self.minimum_size)
        if self.minimum_size is None:
            self.minimum_size = self.multiple_of
        else:
            if self.multiple_of is not None and any_pairwise_larger(
                    self.multiple_of, self.minimum_size):
                raise ValueError(
                    f"Invalid arguments: The minimum size must be at least as large as the multiple_of. "
                    f"minimum_size: {self.minimum_size}, multiple_of: {self.multiple_of}"
                )
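Assuming the constructor above belongs to a crop-size constraint class (called CropSizeConstraints below purely for illustration), an int argument is broadcast to all dimensions, a tuple must match num_dimensions, and minimum_size falls back to multiple_of when omitted:

# Hypothetical usage; CropSizeConstraints is an assumed name for the class above.
constraints = CropSizeConstraints(multiple_of=16)
print(constraints.multiple_of)   # (16, 16, 16) - the int is broadcast to 3 dimensions
print(constraints.minimum_size)  # (16, 16, 16) - defaults to multiple_of when not given

# CropSizeConstraints(multiple_of=(1, 16))  # would raise ValueError: wrong length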
Example No. 6
    def validate(self) -> None:
        """
        Validates the parameters stored in the present object.
        """
        super().validate()
        check_is_any_of("Architecture", self.architecture,
                        vars(ModelArchitectureConfig).keys())

        def len_or_zero(lst: Optional[List[Any]]) -> int:
            return 0 if lst is None else len(lst)

        if self.kernel_size % 2 == 0:
            raise ValueError(
                "The kernel size must be an odd number (kernel_size: {})".
                format(self.kernel_size))

        if self.architecture != ModelArchitectureConfig.UNet3D:
            if any_pairwise_larger(self.center_size, self.crop_size):
                raise ValueError(
                    "Each center_size should be less than or equal to the crop_size "
                    "(center_size: {}, crop_size: {}".format(
                        self.center_size, self.crop_size))
        else:
            if self.crop_size != self.center_size:
                raise ValueError(
                    "For UNet3D, the center size of each dimension should be equal to the crop size "
                    "(center_size: {}, crop_size: {}".format(
                        self.center_size, self.crop_size))

        self.validate_inference_stride_size(self.inference_stride_size,
                                            self.get_output_size())

        # check to make sure there is no overlap between image and ground-truth channels
        image_gt_intersect = np.intersect1d(self.image_channels,
                                            self.ground_truth_ids)
        if len(image_gt_intersect) != 0:
            raise ValueError(
                "Channels: {} were found in both image_channels, and ground_truth_ids"
                .format(image_gt_intersect))

        valid_norm_methods = [
            method.value for method in PhotometricNormalizationMethod
        ]
        check_is_any_of("norm_method", self.norm_method.value,
                        valid_norm_methods)

        if (len(self.trim_percentiles) < 2
                or self.trim_percentiles[0] >= self.trim_percentiles[1]):
            raise ValueError(
                "trim_percentiles should contain lower and upper percentile thresholds, but got: {}"
                .format(self.trim_percentiles))

        if len_or_zero(self.class_weights) != (
                len_or_zero(self.ground_truth_ids) + 1):
            raise ValueError(
                "class_weights needs to be equal to number of ground_truth_ids + 1"
            )
        if self.class_weights is None:
            raise ValueError("class_weights must be set.")
        SegmentationModelBase.validate_class_weights(self.class_weights)
        if self.ground_truth_ids is None:
            raise ValueError("ground_truth_ids is None")
        if len(self.ground_truth_ids_display_names) != len(
                self.ground_truth_ids):
            raise ValueError(
                "len(ground_truth_ids_display_names)!=len(ground_truth_ids)")
        if len(self.ground_truth_ids_display_names) != len(self.colours):
            raise ValueError(
                "len(ground_truth_ids_display_names)!=len(colours)")
        if len(self.ground_truth_ids_display_names) != len(self.fill_holes):
            raise ValueError(
                "len(ground_truth_ids_display_names)!=len(fill_holes)")
        if self.mean_teacher_alpha is not None:
            raise ValueError(
                "Mean teacher model is currently only supported for ScalarModels."
                "Please reset mean_teacher_alpha to None.")