Example #1
def apply_mask_to_posteriors(posteriors: NumpyOrTorch,
                             mask: NumpyOrTorch) -> NumpyOrTorch:
    """
    Apply a binary mask to the provided posteriors such that for all voxels outside
    of the mask:
        1) The background class posterior (index == 0) is set to 1.
        2) All other classes posteriors are set to 0.

    :param posteriors: image tensors in shape: Batches (optional) x Classes x Z x Y x X
    :param mask: image tensor in shape: Batches (optional) x Z x Y x X
    :return: posteriors with the mask applied
    """
    ml_util.check_size_matches(posteriors,
                               mask,
                               matching_dimensions=[-1, -2, -3])

    # If the inputs lack a batch dimension, temporarily add one so that the
    # masking below can operate on fixed Batches x Classes x Z x Y x X layouts.
    batch_posteriors = len(posteriors.shape) != 5
    if batch_posteriors:
        posteriors = posteriors[None, ...]
    if len(mask.shape) != 4:
        mask = mask[None, ...]

    if posteriors.shape[0] != mask.shape[0]:
        raise ValueError(
            "posteriors and mask must have the same number of patches, "
            "found posteriors={}, mask={}".format(posteriors.shape,
                                                  mask.shape))

    # Outside the mask, force the background posterior (class 0) to 1 and all others to 0.
    for c in range(posteriors.shape[1]):
        posteriors[:, c, ...][mask == 0] = int(c == 0)

    if batch_posteriors:
        posteriors = posteriors[0]

    return posteriors
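A minimal usage sketch for apply_mask_to_posteriors (not part of the original source; it assumes the function and its ml_util dependency are importable, and uses plain NumPy arrays as the NumpyOrTorch inputs):

import numpy as np

# Hypothetical 2-class posteriors for a 1 x 2 x 2 volume (Classes x Z x Y x X),
# with a mask that excludes exactly one voxel.
posteriors = np.full((2, 1, 2, 2), 0.5)
mask = np.array([[[1, 1], [1, 0]]])  # Z x Y x X

masked = apply_mask_to_posteriors(posteriors=posteriors, mask=mask)
assert masked[0, 0, 1, 1] == 1.0  # background forced to 1 outside the mask
assert masked[1, 0, 1, 1] == 0.0  # foreground forced to 0 outside the mask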
Example #2
def compute_dice_across_patches(segmentation: torch.Tensor,
                                ground_truth: torch.Tensor,
                                allow_multiple_classes_for_each_pixel: bool = False) -> torch.Tensor:
    """
    Computes the Dice scores for all classes across all patches in the arguments.
    :param segmentation: Tensor containing class ids predicted by a model.
    :param ground_truth: One-hot encoded torch tensor containing ground-truth label ids.
    :param allow_multiple_classes_for_each_pixel: If set to False, the ground-truth tensor must
    contain only one foreground label for each pixel.
    :return: A torch tensor of size (Patches, Classes) with the Dice scores. Dice scores are computed for
    all classes, including the background class at index 0.
    """
    check_size_matches(segmentation, ground_truth, 4, 5, [0, -3, -2, -1],
                       arg1_name="segmentation", arg2_name="ground_truth")

    # One-hot encoded ground-truth values should sum up to one for all pixels
    if not allow_multiple_classes_for_each_pixel:
        if not torch.allclose(torch.sum(ground_truth, dim=1).float(),
                              torch.ones(segmentation.shape, device=ground_truth.device).float()):
            raise Exception("Ground-truth one-hot matrix does not sum up to one for all pixels")

    [num_patches, num_classes] = ground_truth.size()[:2]
    # Convert the segmentation (predicted class ids) to one-hot encoding
    one_hot_segmentation = F.one_hot(segmentation, num_classes=num_classes).permute(0, 4, 1, 2, 3)

    # Convert the tensors to bool tensors
    one_hot_segmentation = one_hot_segmentation.bool().view(num_patches, num_classes, -1)
    ground_truth = ground_truth.bool().view(num_patches, num_classes, -1)

    # Intersection: logical AND between segmentation and ground-truth, summed per class
    intersection = 2.0 * torch.sum(one_hot_segmentation & ground_truth, dim=-1).float()
    # Denominator: sum of the two set sizes, with a small epsilon to avoid division by zero
    union = torch.sum(one_hot_segmentation, dim=-1) + torch.sum(ground_truth, dim=-1).float() + 1.0e-6

    return intersection / union
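Per class c and patch p this computes Dice = 2|S ∩ G| / (|S| + |G|). A tiny sanity-check sketch (it assumes the function and its torch / F imports are in scope):

import torch

# One patch, 2 classes, a 1 x 1 x 2 volume. Ground truth is one-hot:
# Patches x Classes x Z x Y x X; segmentation is Patches x Z x Y x X.
segmentation = torch.tensor([[[[0, 1]]]])
ground_truth = torch.tensor([[[[[1, 0]]], [[[0, 1]]]]]).float()

dice = compute_dice_across_patches(segmentation, ground_truth)
# Both classes match perfectly, so Dice is ~1.0 for background and foreground
# (up to the 1e-6 smoothing term in the denominator).
print(dice)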
Example #3
    def __post_init__(self) -> None:
        # make sure all properties are populated
        common_util.check_properties_are_not_none(self)

        # ensure the center crops for the labels and mask are compatible with each other
        ml_util.check_size_matches(arg1=self.mask_center_crop,
                                   arg2=self.labels_center_crop,
                                   matching_dimensions=self._get_matching_dimensions())
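The same validation-on-construction pattern, as a self-contained sketch with plain NumPy checks standing in for common_util / ml_util (all names here are illustrative, not from the repo):

from dataclasses import dataclass
import numpy as np

@dataclass
class CropPair:
    mask_center_crop: np.ndarray
    labels_center_crop: np.ndarray

    def __post_init__(self) -> None:
        # Make sure all properties are populated.
        if self.mask_center_crop is None or self.labels_center_crop is None:
            raise ValueError("All properties must be populated")
        # The spatial dimensions (Z x Y x X) of both crops must agree.
        if self.mask_center_crop.shape[-3:] != self.labels_center_crop.shape[-3:]:
            raise ValueError(f"Crop size mismatch: {self.mask_center_crop.shape} "
                             f"vs {self.labels_center_crop.shape}")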
Example #4
def plot_contours_for_all_classes(sample: Sample,
                                  segmentation: np.ndarray,
                                  foreground_class_names: List[str],
                                  result_folder: Path,
                                  result_prefix: str = "",
                                  image_range: Optional[TupleFloat2] = None,
                                  channel_index: int = 0) -> List[Path]:
    """
    Creates a plot with the image, the ground truth, and the predicted segmentation overlaid. One plot is created
    for each class, each showing the Z slice where the ground truth has the most pixels.
    :param sample: The image sample, with the photonormalized image and the ground truth labels.
    :param segmentation: The predicted segmentation: multi-value, size Z x Y x X.
    :param foreground_class_names: The names of all classes, excluding the background class.
    :param result_folder: The folder into which the resulting plot PNG files should be written.
    :param result_prefix: A string prefix that will be used for all plots.
    :param image_range: The minimum and maximum image values that will be mapped to the color map ranges.
    If None, use the actual min and max values.
    :param channel_index: The index of the image channel that should be plotted.
    :return: The paths to all generated PNG files.
    """
    check_size_matches(sample.labels[0], segmentation)
    num_classes = sample.labels.shape[0]
    if len(foreground_class_names) != num_classes - 1:
        raise ValueError(
            f"Labels tensor indicates {num_classes} classes, but got {len(foreground_class_names)} foreground "
            f"class names: {foreground_class_names}")
    plot_names: List[Path] = []
    image = sample.image[channel_index, ...]
    contour_arguments = [{'colors': 'r'},
                         {'colors': 'b', 'linestyles': 'dashed'}]
    binaries = binaries_from_multi_label_array(segmentation, num_classes)
    for class_index, binary in enumerate(binaries):
        if class_index == 0:
            continue
        ground_truth = sample.labels[class_index, ...]
        largest_gt_slice = get_largest_z_slice(ground_truth)
        labels_at_largest_gt = ground_truth[largest_gt_slice]
        segmentation_at_largest_gt = binary[largest_gt_slice, ...]
        class_name = foreground_class_names[class_index - 1]
        patient_id = sample.patient_id
        if isinstance(patient_id, str):
            patient_id_str = patient_id
        else:
            patient_id_str = f"{patient_id:03d}"
        filename_stem = f"{result_prefix}{patient_id_str}_{class_name}_slice_{largest_gt_slice:03d}"
        plot_file = plot_image_and_label_contour(
            image=image[largest_gt_slice, ...],
            labels=[labels_at_largest_gt, segmentation_at_largest_gt],
            contour_arguments=contour_arguments,
            image_range=image_range,
            plot_file_name=result_folder / filename_stem)

        plot_names.append(plot_file)
    return plot_names
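binaries_from_multi_label_array splits the multi-value segmentation into one binary map per class, which is what the loop above iterates over. A presumed-equivalent sketch (the real helper's body is not shown here):

import numpy as np
from typing import List

def binaries_from_multi_label_array_sketch(segmentation: np.ndarray,
                                           num_classes: int) -> List[np.ndarray]:
    # Presumed behaviour: one binary map per class id, including background.
    return [(segmentation == c).astype(np.uint8) for c in range(num_classes)]

seg = np.array([[[0, 1], [2, 1]]])  # Z x Y x X with class ids 0..2
maps = binaries_from_multi_label_array_sketch(seg, num_classes=3)
assert maps[1].sum() == 2  # two voxels carry class 1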
Example #5
    def __post_init__(self) -> None:
        # make sure all properties are populated
        common_util.check_properties_are_not_none(self)

        ml_util.check_size_matches(arg1=self.image, arg2=self.mask,
                                   matching_dimensions=self._get_matching_dimensions())

        ml_util.check_size_matches(arg1=self.image, arg2=self.labels,
                                   matching_dimensions=self._get_matching_dimensions())
Example #6
    def __post_init__(self) -> None:
        common_util.check_properties_are_not_none(self)

        ml_util.check_size_matches(arg1=self.image, arg2=self.prediction,
                                   dim1=3, dim2=3,
                                   matching_dimensions=[])

        ml_util.check_size_matches(arg1=self.image, arg2=self.labels,
                                   dim1=3, dim2=4,
                                   matching_dimensions=[-1, -2, -3])
Example #7
    def predict_whole_image(self, image_channels: np.ndarray,
                            voxel_spacing_mm: TupleFloat3,
                            mask: Optional[np.ndarray] = None,
                            patient_id: int = 0) -> InferencePipeline.Result:
        """
        Performs a single inference pass through the pipeline for the provided image.
        :param image_channels: The input image channels to perform inference on, in format: Channels x Z x Y x X.
        :param voxel_spacing_mm: Voxel spacing to use for each dimension, in (Z x Y x X) order.
        :param mask: A binary image used to ignore results outside of it, in format: Z x Y x X.
        :param patient_id: The identifier of the patient this image belongs to (defaults to 0 if none provided).
        :return: An InferencePipeline.Result that contains the segmentation for each of the classes and their
        posterior probabilities.
        """
        if image_channels is None:
            raise Exception("image_channels cannot be None")
        if image_channels.ndim != 4:
            raise NotImplementedError("image_channels must be in shape: Channels x Z x Y x X, "
                                      "found image_channels shape: {}".format(image_channels.shape))
        if mask is not None:
            ml_util.check_size_matches(image_channels, mask, 4, 3, [-1, -2, -3])
        self.model.eval()
        # create the dataset for the batch
        batch_dataset = Dataset(index=[patient_id], batch_class=InferenceBatch)
        # set up the pipeline
        pipeline = (batch_dataset.p
                    # define pipeline variables
                    .init_variables([InferencePipeline.Variables.Model,
                                     InferencePipeline.Variables.ModelConfig,
                                     InferencePipeline.Variables.CropSize,
                                     InferencePipeline.Variables.OutputSize,
                                     InferencePipeline.Variables.OutputImageShape,
                                     InferencePipeline.Variables.Stride])
                    # update the variables for the batch actions
                    .update_variable(name=InferencePipeline.Variables.Model, value=self.model)
                    .update_variable(name=InferencePipeline.Variables.ModelConfig, value=self.model_config)
                    # perform cascaded batch actions
                    .load(image_channels=image_channels, mask=mask)
                    .pre_process()
                    .predict()
                    .post_process()
                    )
        # run the batch through the pipeline
        logging.info(f"Inference pipeline ({self.pipeline_id}), Predicting patient: {patient_id}")
        processed_batch: InferenceBatch = pipeline.next_batch(batch_size=1)
        posteriors = processed_batch.get_component(InferenceBatch.Components.Posteriors)
        image_util.check_array_range(posteriors, error_prefix="Whole image posteriors")
        # prepare pipeline results from the processed batch
        return InferencePipeline.Result(
            patient_id=patient_id,
            segmentation=processed_batch.get_component(InferenceBatch.Components.Segmentation),
            posteriors=posteriors,
            voxel_spacing_mm=voxel_spacing_mm
        )
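A hypothetical driver for the method above (not from the repo; it assumes `pipeline` is an already-constructed InferencePipeline whose model takes a single input channel):

import numpy as np

def run_single_inference(pipeline) -> None:
    image_channels = np.random.rand(1, 16, 64, 64).astype(np.float32)  # C x Z x Y x X
    mask = np.ones((16, 64, 64), dtype=np.uint8)                       # Z x Y x X
    result = pipeline.predict_whole_image(image_channels=image_channels,
                                          voxel_spacing_mm=(3.0, 1.0, 1.0),
                                          mask=mask,
                                          patient_id=1)
    print(result.segmentation.shape)  # Z x Y x X, here (16, 64, 64)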
Example #8
        def __init__(self, epoch: int, patient_id: int,
                     segmentation: np.ndarray, posteriors: np.ndarray,
                     voxel_spacing_mm: TupleFloat3):
            """
            :param epoch: The epoch for which inference is being performed.
            :param patient_id: The id of the patient on which inference is being performed.
            :param segmentation: Z x Y x X (argmax over the posteriors in the class dimension)
            :param posteriors: Class x Z x Y x X
            :param voxel_spacing_mm: Voxel spacing to use for each dimension, in (Z x Y x X) order
            """
            self.epoch = epoch
            self.patient_id = patient_id
            self.segmentation = segmentation
            self.posteriors = posteriors
            self.voxel_spacing_mm = voxel_spacing_mm

            if len(self.voxel_spacing_mm) != 3:
                raise ValueError(
                    f"voxel_spacing_mm must have length 3, found: {voxel_spacing_mm}"
                )
            if any(np.array(self.voxel_spacing_mm) <= 0):
                raise ValueError(
                    f"voxel_spacing_mm must have values > 0 in each dimension, found: {voxel_spacing_mm}"
                )

            ml_util.check_size_matches(self.segmentation,
                                       self.posteriors,
                                       dim1=3,
                                       dim2=4,
                                       matching_dimensions=[-3, -2, -1],
                                       arg1_name="segmentation",
                                       arg2_name="posteriors")

            segmentation_value_range = np.unique(self.segmentation)
            if not np.all([
                    x in range(self.posteriors.shape[0])
                    for x in segmentation_value_range
            ]):
                raise Exception(
                    "values in the segmentation map must be in range [0, classes), "
                    "found classes:{}, segmentation range:{}".format(
                        self.posteriors.shape[0], segmentation_value_range))

            self._uncertainty = compute_uncertainty_map_from_posteriors(
                self.posteriors)
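The final line derives an uncertainty map from the posteriors. The helper's body is not shown; a minimal sketch of one common definition (normalized per-voxel entropy, assumed here rather than taken from the repo):

import numpy as np

def uncertainty_from_posteriors_sketch(posteriors: np.ndarray) -> np.ndarray:
    # Per-voxel entropy over the class axis (axis 0), normalized to [0, 1]
    # by the maximum possible entropy log(num_classes). Assumes num_classes > 1.
    eps = 1.0e-12
    num_classes = posteriors.shape[0]
    entropy = -np.sum(posteriors * np.log(posteriors + eps), axis=0)
    return entropy / np.log(num_classes)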
Example #9
def test_check_size() -> None:
    """
    Test `check_size_matches` function.
    """
    a1 = np.zeros((2, 3, 4))
    a2 = np.zeros((5, 2, 3, 4))
    check_size_matches(a1, a1)
    check_size_matches(a1, a2, matching_dimensions=[-3, -2, -1])
    check_size_matches(a1,
                       a2,
                       dim1=3,
                       dim2=4,
                       matching_dimensions=[-3, -2, -1])
    check_size_matches(a1, a1, matching_dimensions=[0, 1])

    def throws(func: Callable[..., None]) -> None:
        with pytest.raises(ValueError) as e:
            func()
        print("Exception message: {}".format(e.value))

    # Can't compare arrays of different dimensions
    throws(lambda: check_size_matches(a1, a2))  # type: ignore
    # a2 has the wrong number of dimensions
    throws(lambda: check_size_matches(a1, a2, dim1=3, dim2=3))  # type: ignore
    # a1 has the wrong number of dimensions
    throws(lambda: check_size_matches(a1, a2, dim1=4, dim2=4))  # type: ignore
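The tests above pin down the contract of check_size_matches. A minimal re-implementation consistent with those tests (an illustrative sketch, not the library code):

from typing import Any, Iterable

def check_size_matches_sketch(arg1: Any, arg2: Any,
                              dim1: int = 0, dim2: int = 0,
                              matching_dimensions: Iterable[int] = (),
                              arg1_name: str = "arg1",
                              arg2_name: str = "arg2") -> None:
    shape1, shape2 = tuple(arg1.shape), tuple(arg2.shape)
    # With explicit dimensionalities, each argument must match its own.
    if dim1 and len(shape1) != dim1:
        raise ValueError(f"{arg1_name} must have {dim1} dimensions, got {shape1}")
    if dim2 and len(shape2) != dim2:
        raise ValueError(f"{arg2_name} must have {dim2} dimensions, got {shape2}")
    dims = list(matching_dimensions)
    if not dims:
        # Without matching_dimensions, the full shapes must be identical.
        if shape1 != shape2:
            raise ValueError(f"Shapes must match: {shape1} vs {shape2}")
        return
    for d in dims:
        if shape1[d] != shape2[d]:
            raise ValueError(f"Dimension {d} differs: {shape1[d]} vs {shape2[d]}")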
Example #10
def calculate_metrics_per_class(segmentation: np.ndarray,
                                ground_truth: np.ndarray,
                                ground_truth_ids: List[str],
                                voxel_spacing: TupleFloat3,
                                patient_id: Optional[int] = None) -> MetricsDict:
    """
    Calculate the dice for all foreground structures (the background class is completely ignored).
    Returns a MetricsDict with metrics for each of the foreground
    structures. Metrics are NaN if both ground truth and prediction are all zero for a class.
    :param ground_truth_ids: The names of all foreground classes.
    :param segmentation: predictions multi-value array with dimensions: [Z x Y x X]
    :param ground_truth: ground truth binary array with dimensions: [C x Z x Y x X]
    :param voxel_spacing: voxel_spacing in 3D Z x Y x X
    :param patient_id: for logging
    """
    number_of_classes = ground_truth.shape[0]
    if len(ground_truth_ids) != (number_of_classes - 1):
        raise ValueError(f"Received {len(ground_truth_ids)} foreground class names, but "
                         f"the label tensor indicates that there are {number_of_classes - 1} classes.")
    binaries = binaries_from_multi_label_array(segmentation, number_of_classes)

    all_classes_are_binary = [is_binary_array(ground_truth[label_id]) for label_id in range(ground_truth.shape[0])]
    if not np.all(all_classes_are_binary):
        raise ValueError("Ground truth values should be 0 or 1")
    overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter()
    hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
    metrics = MetricsDict(hues=ground_truth_ids)
    for i, prediction in enumerate(binaries):
        if i == 0:
            continue
        check_size_matches(prediction, ground_truth[i], arg1_name="prediction", arg2_name="ground_truth")
        if not is_binary_array(prediction):
            raise ValueError("Predictions values should be 0 or 1")
        # simpleitk returns a Dice score of 0 if both ground truth and prediction are all zeros.
        # We want to be able to fish out those cases, and treat them specially later.
        prediction_zero = np.all(prediction == 0)
        gt_zero = np.all(ground_truth[i] == 0)
        dice = mean_surface_distance = hausdorff_distance = math.nan
        if not (prediction_zero and gt_zero):
            prediction_image = sitk.GetImageFromArray(prediction.astype(np.uint8))
            prediction_image.SetSpacing(sitk.VectorDouble(reverse_tuple_float3(voxel_spacing)))
            ground_truth_image = sitk.GetImageFromArray(ground_truth[i].astype(np.uint8))
            ground_truth_image.SetSpacing(sitk.VectorDouble(reverse_tuple_float3(voxel_spacing)))
            overlap_measures_filter.Execute(prediction_image, ground_truth_image)
            dice = overlap_measures_filter.GetDiceCoefficient()
            if prediction_zero or gt_zero:
                hausdorff_distance = mean_surface_distance = math.inf
            else:
                try:
                    hausdorff_distance_filter.Execute(prediction_image, ground_truth_image)
                    hausdorff_distance = hausdorff_distance_filter.GetHausdorffDistance()
                except Exception as e:
                    logging.warning("Cannot calculate Hausdorff distance for "
                                    f"structure {i} of patient {patient_id}: {e}")
                try:
                    mean_surface_distance = surface_distance(prediction_image, ground_truth_image)
                except Exception as e:
                    logging.warning(f"Cannot calculate mean distance for structure {i} of patient {patient_id}: {e}")
            logging.debug(f"Patient {patient_id}, class {i} has Dice score {dice}")

        def add_metric(metric_type: MetricType, value: float) -> None:
            metrics.add_metric(metric_type, value, skip_nan_when_averaging=True, hue=ground_truth_ids[i - 1])

        add_metric(MetricType.DICE, dice)
        add_metric(MetricType.HAUSDORFF_mm, hausdorff_distance)
        add_metric(MetricType.MEAN_SURFACE_DIST_mm, mean_surface_distance)
    return metrics
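For reference, the Dice score that LabelOverlapMeasuresImageFilter reports for binary inputs can be checked by hand (a standalone sketch, independent of SimpleITK):

import numpy as np

prediction = np.array([[[1, 1], [0, 0]]], dtype=np.uint8)    # Z x Y x X
ground_truth = np.array([[[1, 0], [0, 0]]], dtype=np.uint8)

intersection = np.sum(prediction & ground_truth)
dice = 2.0 * intersection / (prediction.sum() + ground_truth.sum())
print(dice)  # 2 * 1 / (2 + 1) = 0.666...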
Example #11
    def __post_init__(self) -> None:
        ml_util.check_size_matches(arg1=self.posteriors,
                                   arg2=self.segmentations,
                                   dim1=5,
                                   dim2=4,
                                   matching_dimensions=[0, -1, -2, -3])
Example #12
    def predict_whole_image(self,
                            image_channels: np.ndarray,
                            voxel_spacing_mm: TupleFloat3,
                            mask: Optional[np.ndarray] = None,
                            patient_id: int = 0) -> InferencePipeline.Result:
        """
        Performs a single inference pass through the pipeline for the provided image
        :param image_channels: The input image channels to perform inference on in format: Channels x Z x Y x X.
        :param voxel_spacing_mm: Voxel spacing to use for each dimension in (Z x Y x X) order
        :param mask: A binary image used to ignore results outside it in format: Z x Y x X.
        :param patient_id: The identifier of the patient this image belongs to (defaults to 0 if None provided).
        :return InferenceResult: that contains Segmentation for each of the classes and their posterior probabilities.
        """
        if image_channels is None:
            raise Exception("image_channels cannot be None")
        if image_channels.ndim != 4:
            raise NotImplementedError(
                "image_channels must be in shape: Channels x Z x Y x X, "
                "found image_channels shape: {}".format(image_channels.shape))
        if mask is not None:
            ml_util.check_size_matches(image_channels, mask, 4, 3,
                                       [-1, -2, -3])
        self.model.eval()

        image = tio.ScalarImage(tensor=image_channels)
        INPUT = 'input_image'
        MASK = 'mask'

        subject_dict: Dict[str, tio.Image] = {INPUT: image}
        if mask is not None:
            subject_dict[MASK] = tio.LabelMap(tensor=mask[np.newaxis])
        subject = tio.Subject(subject_dict)

        constraints = self.model.model.crop_size_constraints

        # Make sure the image size is compatible with the model
        multiple_constraints = constraints.multiple_of  # type: ignore
        if multiple_constraints is not None:
            ensure_shape_multiple = tio.EnsureShapeMultiple(multiple_constraints)  # type: ignore
            subject = ensure_shape_multiple(subject)  # type: ignore

        # There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
        # to always fit into image. If test_crop_size is smaller than the image, crop will remain unchanged.
        restrict_patch_size = constraints.restrict_crop_size_to_image  # type: ignore
        effective_patch_size, effective_stride = restrict_patch_size(
            subject.spatial_shape,  # type: ignore
            self.model_config.test_crop_size,
            self.model_config.inference_stride_size)

        patch_overlap = np.array(effective_patch_size) - np.array(
            effective_stride)
        grid_sampler = tio.inference.GridSampler(
            subject,
            effective_patch_size,
            patch_overlap,
            padding_mode=self.model_config.padding_mode.value,
        )
        batch_size = self.model_config.inference_batch_size
        patch_loader = torch.utils.data.DataLoader(
            grid_sampler, batch_size=batch_size)  # type: ignore
        aggregator = tio.inference.GridAggregator(grid_sampler)

        logging.debug(
            f"Inference on image size {subject.spatial_shape} will run "
            f"with crop size {effective_patch_size} and stride {effective_stride}"
        )
        for patches_batch in patch_loader:
            input_tensor = patches_batch[INPUT][tio.DATA].float()
            if self.model_config.use_gpu:
                input_tensor = input_tensor.cuda()
            locations = patches_batch[tio.LOCATION]
            # perform the forward pass
            patches_posteriors = self.model(input_tensor).detach()
            # pad posteriors if they are smaller than the input
            input_shape = input_tensor.shape[-3:]
            patches_posteriors_shape = patches_posteriors.shape[-3:]
            if input_shape != patches_posteriors_shape:
                difference = np.array(input_shape) - np.array(patches_posteriors_shape)
                # The differences in shape are expected to be even
                assert not np.any(difference % 2)
                padding = tuple(np.repeat(difference // 2, 2))
                patches_posteriors = torch.nn.functional.pad(patches_posteriors, padding)
            # collect the predictions over each of the batches
            aggregator.add_batch(patches_posteriors, locations)
        posteriors = aggregator.get_output_tensor().numpy()
        posteriors_mask = None if mask is None else subject[MASK].numpy()[0]
        posteriors, segmentation = self.post_process_posteriors(
            posteriors, mask=posteriors_mask)

        image_util.check_array_range(posteriors,
                                     error_prefix="Whole image posteriors")

        # Make sure the final shape matches the input shape by undoing the padding in EnsureShapeMultiple (if any)
        posteriors_image = tio.ScalarImage(tensor=posteriors,
                                           affine=image.affine)
        segmentation_image = tio.LabelMap(tensor=segmentation[np.newaxis],
                                          affine=image.affine)
        subject.add_image(posteriors_image, 'posteriors')
        subject.add_image(segmentation_image, 'segmentation')
        # Remove some images to avoid unnecessary computations
        subject.remove_image(INPUT)
        if mask is not None:
            subject.remove_image(MASK)
        subject_original_space = (subject.apply_inverse_transform()
                                  if subject.applied_transforms else subject)
        posteriors = subject_original_space.posteriors.numpy()  # type: ignore
        segmentation = subject_original_space.segmentation.numpy()[0]  # type: ignore

        # prepare pipeline results from the processed batch
        return InferencePipeline.Result(patient_id=patient_id,
                                        segmentation=segmentation,
                                        posteriors=posteriors,
                                        voxel_spacing_mm=voxel_spacing_mm)
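A worked example of the posterior-padding step in the loop above: when the model output is spatially smaller than the input patch, the per-axis difference is assumed to be even and is split symmetrically (standalone sketch):

import numpy as np
import torch

# A 1 x 2 x 16 x 16 x 16 input patch whose posteriors came back as 12^3
# is padded back to 16^3 with 2 voxels on each side of each spatial axis.
input_shape = (16, 16, 16)
patches_posteriors = torch.zeros(1, 2, 12, 12, 12)

difference = np.array(input_shape) - np.array(patches_posteriors.shape[-3:])
assert not np.any(difference % 2)  # differences must be even
padding = tuple(int(p) for p in np.repeat(difference // 2, 2))
padded = torch.nn.functional.pad(patches_posteriors, padding)
print(padded.shape)  # torch.Size([1, 2, 16, 16, 16])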