def apply_mask_to_posteriors(posteriors: NumpyOrTorch, mask: NumpyOrTorch) -> NumpyOrTorch:
    """
    Force every voxel outside a binary mask to be background: wherever the mask is 0, the posterior
    of class 0 (background) becomes 1 and the posteriors of all other classes become 0.

    The assignment is done through slicing views of the input, so the given posteriors are
    modified in place as well as returned.

    :param posteriors: Posteriors of shape [Batches x] Classes x Z x Y x X (batch dimension optional).
    :param mask: Mask of shape [Batches x] Z x Y x X (batch dimension optional).
    :return: The masked posteriors, with the same shape as the input posteriors.
    :raises ValueError: If posteriors and mask (after adding batch dimensions) differ in batch size.
    """
    ml_util.check_size_matches(posteriors, mask, matching_dimensions=[-1, -2, -3])
    # Work on 5D posteriors / 4D mask internally; remember whether a batch dimension was added.
    added_batch_dim = len(posteriors.shape) != 5
    batched_posteriors = posteriors[None, ...] if added_batch_dim else posteriors
    batched_mask = mask if len(mask.shape) == 4 else mask[None, ...]
    if batched_posteriors.shape[0] != batched_mask.shape[0]:
        raise ValueError(
            "posteriors and mask must have the same number of patches, "
            "found posteriors={}, mask={}".format(batched_posteriors.shape, batched_mask.shape))
    outside_mask = batched_mask == 0
    for class_index in range(batched_posteriors.shape[1]):
        # Background (class 0) gets probability 1 outside the mask, every other class gets 0.
        batched_posteriors[:, class_index, ...][outside_mask] = int(class_index == 0)
    return batched_posteriors[0] if added_batch_dim else batched_posteriors
def compute_dice_across_patches(segmentation: torch.Tensor,
                                ground_truth: torch.Tensor,
                                allow_multiple_classes_for_each_pixel: bool = False) -> torch.Tensor:
    """
    Computes the Dice scores for all classes across all patches in the arguments.

    :param segmentation: Tensor of shape Patches x Z x Y x X containing integer class ids predicted by a model.
    :param ground_truth: One-hot encoded tensor of shape Patches x Classes x Z x Y x X containing the
        ground-truth labels.
    :param allow_multiple_classes_for_each_pixel: If set to False, the ground-truth tensor has to sum up to
        one over the class dimension for each pixel (exactly one label per pixel).
    :return: A torch tensor of size (Patches, Classes) with the Dice scores. Dice scores are computed for all
        classes including the background class at index 0. Because of a small epsilon in the denominator, a
        class that is absent from both segmentation and ground truth gets a score of 0 (not NaN).
    :raises Exception: If allow_multiple_classes_for_each_pixel is False and the ground truth does not sum up
        to one over the class dimension for all pixels.
    """
    check_size_matches(segmentation, ground_truth, 4, 5, [0, -3, -2, -1],
                       arg1_name="segmentation", arg2_name="ground_truth")
    # One-hot encoded ground-truth values should sum up to one for all pixels
    if not allow_multiple_classes_for_each_pixel:
        if not torch.allclose(torch.sum(ground_truth, dim=1).float(),
                              torch.ones(segmentation.shape, device=ground_truth.device).float()):
            raise Exception("Ground-truth one-hot matrix does not sum up to one for all pixels")

    num_patches, num_classes = ground_truth.shape[:2]
    # Convert the predicted class ids to a one-hot encoding and move the class dimension to index 1,
    # so that the layout matches the ground truth: Patches x Classes x Z x Y x X
    one_hot_segmentation = F.one_hot(segmentation, num_classes=num_classes).permute(0, 4, 1, 2, 3)
    # Convert to bool and flatten the spatial dimensions. reshape (unlike view) also works
    # when the input tensors are not contiguous in memory.
    one_hot_segmentation = one_hot_segmentation.bool().reshape(num_patches, num_classes, -1)
    ground_truth = ground_truth.bool().reshape(num_patches, num_classes, -1)
    # Dice = 2 * |A and B| / (|A| + |B|), computed per patch and per class
    intersection = 2.0 * torch.sum(one_hot_segmentation & ground_truth, dim=-1).float()
    cardinality_sum = (torch.sum(one_hot_segmentation, dim=-1).float()
                       + torch.sum(ground_truth, dim=-1).float()
                       + 1.0e-6)
    return intersection / cardinality_sum
def __post_init__(self) -> None:
    """
    Validates the object after construction: no property may be None, and the mask and labels
    center crops must agree in the dimensions returned by self._get_matching_dimensions().
    """
    common_util.check_properties_are_not_none(self)
    # The mask and label center crops are used together, so their sizes must be compatible.
    ml_util.check_size_matches(self.mask_center_crop,
                               self.labels_center_crop,
                               matching_dimensions=self._get_matching_dimensions())
def plot_contours_for_all_classes(sample: Sample,
                                  segmentation: np.ndarray,
                                  foreground_class_names: List[str],
                                  result_folder: Path,
                                  result_prefix: str = "",
                                  image_range: Optional[TupleFloat2] = None,
                                  channel_index: int = 0) -> List[Path]:
    """
    Plots, for every foreground class, one image slice with the ground truth and the predicted
    segmentation overlaid as contours. The slice used for a class is the Z slice chosen by
    get_largest_z_slice on that class's ground truth (the slice where the ground truth has most pixels).

    :param sample: The image sample, with the photonormalized image and the ground truth labels
        (labels are indexed by class first).
    :param segmentation: The predicted segmentation: multi-value, size Z x Y x X.
    :param foreground_class_names: The names of all classes, excluding the background class.
    :param result_folder: The folder into which the resulting plot PNG files should be written.
    :param result_prefix: A string prefix that will be used for all plots.
    :param image_range: The minimum and maximum image values that will be mapped to the color map ranges.
        If None, use the actual min and max values.
    :param channel_index: The index of the image channel that should be plotted.
    :return: The paths to all generated PNG files, one per foreground class.
    :raises ValueError: If the number of class names does not match the number of foreground classes
        in the labels.
    """
    check_size_matches(sample.labels[0], segmentation)
    num_classes = sample.labels.shape[0]
    if len(foreground_class_names) != num_classes - 1:
        raise ValueError(
            f"Labels tensor indicates {num_classes} classes, but got {len(foreground_class_names)} foreground "
            f"class names: {foreground_class_names}")

    plot_names: List[Path] = []
    image = sample.image[channel_index, ...]
    # One style per entry of the `labels` list passed below (presumably paired by position):
    # ground truth in red, prediction in dashed blue.
    contour_arguments = [{'colors': 'r'}, {'colors': 'b', 'linestyles': 'dashed'}]
    binaries = binaries_from_multi_label_array(segmentation, num_classes)
    for class_index, predicted_binary in enumerate(binaries):
        # Class 0 is the background and is not plotted.
        if class_index == 0:
            continue
        class_ground_truth = sample.labels[class_index, ...]
        z_slice = get_largest_z_slice(class_ground_truth)
        gt_slice = class_ground_truth[z_slice]
        prediction_slice = predicted_binary[z_slice, ...]
        class_name = foreground_class_names[class_index - 1]
        patient_id = sample.patient_id
        patient_id_str = patient_id if isinstance(patient_id, str) else f"{patient_id:03d}"
        filename_stem = f"{result_prefix}{patient_id_str}_{class_name}_slice_{z_slice:03d}"
        plot_names.append(plot_image_and_label_contour(
            image=image[z_slice, ...],
            labels=[gt_slice, prediction_slice],
            contour_arguments=contour_arguments,
            image_range=image_range,
            plot_file_name=result_folder / filename_stem))
    return plot_names
def __post_init__(self) -> None:
    """
    Validates the object after construction: no property may be None, and both the mask and the
    labels must match the image in the dimensions returned by self._get_matching_dimensions().
    """
    common_util.check_properties_are_not_none(self)
    for other in (self.mask, self.labels):
        ml_util.check_size_matches(self.image, other,
                                   matching_dimensions=self._get_matching_dimensions())
def __post_init__(self) -> None:
    """
    Validates the object after construction: no property may be None, image and prediction must be
    3-dimensional, and labels must be 4-dimensional with the same last three dimensions as the image.
    """
    common_util.check_properties_are_not_none(self)
    # No matching dimensions are requested for the prediction, so presumably only the number of
    # dimensions is checked here.
    ml_util.check_size_matches(self.image, self.prediction, 3, 3, [])
    ml_util.check_size_matches(self.image, self.labels, 3, 4, [-1, -2, -3])
def predict_whole_image(self, image_channels: np.ndarray,
                        voxel_spacing_mm: TupleFloat3,
                        mask: Optional[np.ndarray] = None,
                        patient_id: int = 0) -> InferencePipeline.Result:
    """
    Performs a single inference pass through the pipeline for the provided image.

    :param image_channels: The input image channels to perform inference on in format: Channels x Z x Y x X.
    :param voxel_spacing_mm: Voxel spacing to use for each dimension in (Z x Y x X) order
    :param mask: A binary image used to ignore results outside it in format: Z x Y x X.
    :param patient_id: The identifier of the patient this image belongs to (defaults to 0 if None provided).
    :return InferenceResult: that contains Segmentation for each of the classes and their posterior probabilities.
    :raises Exception: If image_channels is None.
    :raises NotImplementedError: If image_channels is not 4-dimensional.
    """
    if image_channels is None:
        raise Exception("image_channels cannot be None")
    if image_channels.ndim != 4:
        # The two literals are concatenated, so the separator must be part of the first one.
        raise NotImplementedError("image_channels must be in shape: Channels x Z x Y x X, "
                                  "found image_channels shape: {}".format(image_channels.shape))
    if mask is not None:
        ml_util.check_size_matches(image_channels, mask, 4, 3, [-1, -2, -3])
    self.model.eval()
    # create the dataset for the batch
    batch_dataset = Dataset(index=[patient_id], batch_class=InferenceBatch)
    # setup the pipeline
    pipeline = (batch_dataset.p
                # define pipeline variables
                .init_variables([InferencePipeline.Variables.Model,
                                 InferencePipeline.Variables.ModelConfig,
                                 InferencePipeline.Variables.CropSize,
                                 InferencePipeline.Variables.OutputSize,
                                 InferencePipeline.Variables.OutputImageShape,
                                 InferencePipeline.Variables.Stride])
                # update the variables for the batch actions
                .update_variable(name=InferencePipeline.Variables.Model, value=self.model)
                .update_variable(name=InferencePipeline.Variables.ModelConfig, value=self.model_config)
                # perform cascaded batch actions
                .load(image_channels=image_channels, mask=mask)
                .pre_process()
                .predict()
                .post_process()
                )
    # run the batch through the pipeline
    logging.info(f"Inference pipeline ({self.pipeline_id}), Predicting patient: {patient_id}")
    processed_batch: InferenceBatch = pipeline.next_batch(batch_size=1)
    posteriors = processed_batch.get_component(InferenceBatch.Components.Posteriors)
    image_util.check_array_range(posteriors, error_prefix="Whole image posteriors")
    # prepare pipeline results from the processed batch
    return InferencePipeline.Result(
        patient_id=patient_id,
        segmentation=processed_batch.get_component(InferenceBatch.Components.Segmentation),
        posteriors=posteriors,
        voxel_spacing_mm=voxel_spacing_mm
    )
def __init__(self, epoch: int, patient_id: int, segmentation: np.ndarray, posteriors: np.ndarray,
             voxel_spacing_mm: TupleFloat3):
    """
    Stores the inference result for one patient and validates that its parts are consistent.

    :param epoch: The epoch for which inference is being performed.
    :param patient_id: The id of the patient for which inference is being performed.
    :param segmentation: Z x Y x X (argmaxed over the posteriors in the class dimension)
    :param posteriors: Class x Z x Y x X
    :param voxel_spacing_mm: Voxel spacing to use for each dimension in (Z x Y x X) order
    :raises ValueError: If voxel_spacing_mm does not have exactly 3 entries, or has an entry <= 0.
    :raises Exception: If the segmentation contains a value that is not a class index in
        [0, number of classes in posteriors).
    """
    self.epoch = epoch
    self.patient_id = patient_id
    self.segmentation = segmentation
    self.posteriors = posteriors
    self.voxel_spacing_mm = voxel_spacing_mm

    if len(self.voxel_spacing_mm) != 3:
        raise ValueError(f"voxel_spacing_mm must have length 3, found: {voxel_spacing_mm}")
    if any(np.array(self.voxel_spacing_mm) <= 0):
        raise ValueError(f"voxel_spacing_mm must have values > 0 in each dimension, found: {voxel_spacing_mm}")

    ml_util.check_size_matches(self.segmentation, self.posteriors, 3, 4, [-3, -2, -1],
                               arg1_name="segmentation", arg2_name="posteriors")

    # Every distinct value in the segmentation must be a valid class index into the posteriors.
    num_classes = self.posteriors.shape[0]
    valid_class_ids = range(num_classes)
    segmentation_value_range = np.unique(self.segmentation)
    if not np.all([value in valid_class_ids for value in segmentation_value_range]):
        raise Exception(
            "values in the segmentation map must be in range [0, classes), "
            "found classes:{}, segmentation range:{}".format(num_classes, segmentation_value_range))

    self._uncertainty = compute_uncertainty_map_from_posteriors(self.posteriors)
def test_check_size() -> None:
    """
    Test `check_size_matches` function.
    """
    a1 = np.zeros((2, 3, 4))
    a2 = np.zeros((5, 2, 3, 4))
    check_size_matches(a1, a1)
    check_size_matches(a1, a2, matching_dimensions=[-3, -2, -1])
    check_size_matches(a1, a2, dim1=3, dim2=4, matching_dimensions=[-3, -2, -1])
    check_size_matches(a1, a1, matching_dimensions=[0, 1])

    def throws(func: Callable[..., None]) -> None:
        with pytest.raises(ValueError) as e:
            func()
        print("Exception message: {}".format(e.value))

    # Can't compare arrays of different dimension
    throws(lambda: check_size_matches(a1, a2))  # type: ignore
    # a2 has wrong dimension
    throws(lambda: check_size_matches(a1, a2, dim1=3, dim2=3))  # type: ignore
    # a1 has wrong dimension
    throws(lambda: check_size_matches(a1, a2, dim1=4, dim2=4))  # type: ignore
    # Sizes differ in matching dimension [0]: a1.shape[0] == 2, a2.shape[0] == 5.
    # (Previously this line duplicated the assertion above and tested nothing new.)
    throws(lambda: check_size_matches(a1, a2, matching_dimensions=[0]))  # type: ignore
def calculate_metrics_per_class(segmentation: np.ndarray,
                                ground_truth: np.ndarray,
                                ground_truth_ids: List[str],
                                voxel_spacing: TupleFloat3,
                                patient_id: Optional[int] = None) -> MetricsDict:
    """
    Computes Dice score, Hausdorff distance and mean surface distance for every foreground class.
    The background class (index 0) is skipped entirely.

    For each class:
    * if both prediction and ground truth are all zero, all three metrics are NaN;
    * if exactly one of them is all zero, Dice is computed and both distances are infinity;
    * otherwise all three metrics are computed. If a distance computation raises, a warning is
      logged and that distance stays NaN.

    :param segmentation: predictions multi-value array with dimensions: [Z x Y x X]
    :param ground_truth: ground truth binary array with dimensions: [C x Z x Y x X]
    :param ground_truth_ids: The names of all foreground classes (C - 1 entries).
    :param voxel_spacing: voxel_spacing in 3D Z x Y x X
    :param patient_id: Only used in log messages.
    :return: A MetricsDict with the metrics for each foreground class, keyed by class name as hue.
    :raises ValueError: If the number of class names does not match the labels, or if the ground truth
        or the predictions are not binary.
    """
    number_of_classes = ground_truth.shape[0]
    if len(ground_truth_ids) != (number_of_classes - 1):
        raise ValueError(f"Received {len(ground_truth_ids)} foreground class names, but "
                         f"the label tensor indicates that there are {number_of_classes - 1} classes.")
    binaries = binaries_from_multi_label_array(segmentation, number_of_classes)

    binary_flags = [is_binary_array(ground_truth[class_index]) for class_index in range(number_of_classes)]
    if not np.all(binary_flags):
        raise ValueError("Ground truth values should be 0 or 1")

    overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter()
    hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
    metrics = MetricsDict(hues=ground_truth_ids)

    def to_sitk_image(array: np.ndarray) -> sitk.Image:
        # The voxel spacing is given in Z x Y x X order and is reversed before being set on the image.
        image = sitk.GetImageFromArray(array.astype(np.uint8))
        image.SetSpacing(sitk.VectorDouble(reverse_tuple_float3(voxel_spacing)))
        return image

    for class_index, prediction in enumerate(binaries):
        if class_index == 0:
            continue
        class_ground_truth = ground_truth[class_index]
        check_size_matches(prediction, class_ground_truth, arg1_name="prediction", arg2_name="ground_truth")
        if not is_binary_array(prediction):
            raise ValueError("Predictions values should be 0 or 1")
        # SimpleITK reports a Dice score of 0 when prediction and ground truth are both empty.
        # Those cases are reported as NaN instead, so that they can be singled out later.
        prediction_zero = np.all(prediction == 0)
        gt_zero = np.all(class_ground_truth == 0)
        dice = mean_surface_distance = hausdorff_distance = math.nan
        if not (prediction_zero and gt_zero):
            prediction_image = to_sitk_image(prediction)
            ground_truth_image = to_sitk_image(class_ground_truth)
            overlap_measures_filter.Execute(prediction_image, ground_truth_image)
            dice = overlap_measures_filter.GetDiceCoefficient()
            if prediction_zero or gt_zero:
                # Exactly one of the two structures is empty: distances are reported as infinite.
                hausdorff_distance = mean_surface_distance = math.inf
            else:
                try:
                    hausdorff_distance_filter.Execute(prediction_image, ground_truth_image)
                    hausdorff_distance = hausdorff_distance_filter.GetHausdorffDistance()
                except Exception as e:
                    logging.warning("Cannot calculate Hausdorff distance for "
                                    f"structure {class_index} of patient {patient_id}: {e}")
                try:
                    mean_surface_distance = surface_distance(prediction_image, ground_truth_image)
                except Exception as e:
                    logging.warning(f"Cannot calculate mean distance for structure {class_index} "
                                    f"of patient {patient_id}: {e}")
        logging.debug(f"Patient {patient_id}, class {class_index} has Dice score {dice}")

        hue = ground_truth_ids[class_index - 1]
        for metric_type, value in ((MetricType.DICE, dice),
                                   (MetricType.HAUSDORFF_mm, hausdorff_distance),
                                   (MetricType.MEAN_SURFACE_DIST_mm, mean_surface_distance)):
            metrics.add_metric(metric_type, value, skip_nan_when_averaging=True, hue=hue)
    return metrics
def __post_init__(self) -> None:
    """
    Validates that posteriors are 5-dimensional and segmentations 4-dimensional, and that they agree
    in the first dimension and in the last three dimensions (presumably batch and Z x Y x X; confirm).
    """
    ml_util.check_size_matches(self.posteriors, self.segmentations, 5, 4, [0, -1, -2, -3])
def predict_whole_image(self, image_channels: np.ndarray,
                        voxel_spacing_mm: TupleFloat3,
                        mask: Optional[np.ndarray] = None,
                        patient_id: int = 0) -> InferencePipeline.Result:
    """
    Performs a single inference pass through the pipeline for the provided image.

    The image is split into patches by a torchio GridSampler, the model is run on batches of patches,
    and the patch posteriors are stitched back together by a GridAggregator. Any padding added by
    EnsureShapeMultiple is undone at the end, so the outputs match the input's spatial shape.

    :param image_channels: The input image channels to perform inference on in format: Channels x Z x Y x X.
    :param voxel_spacing_mm: Voxel spacing to use for each dimension in (Z x Y x X) order
    :param mask: A binary image used to ignore results outside it in format: Z x Y x X.
    :param patient_id: The identifier of the patient this image belongs to (defaults to 0 if None provided).
    :return InferenceResult: that contains Segmentation for each of the classes and their posterior probabilities.
    :raises Exception: If image_channels is None.
    :raises NotImplementedError: If image_channels is not 4-dimensional.
    """
    if image_channels is None:
        raise Exception("image_channels cannot be None")
    if image_channels.ndim != 4:
        # The two literals are concatenated, so the separator must be part of the first one.
        raise NotImplementedError(
            "image_channels must be in shape: Channels x Z x Y x X, "
            "found image_channels shape: {}".format(image_channels.shape))
    if mask is not None:
        ml_util.check_size_matches(image_channels, mask, 4, 3, [-1, -2, -3])
    self.model.eval()

    image = tio.ScalarImage(tensor=image_channels)
    INPUT = 'input_image'
    MASK = 'mask'

    subject_dict: Dict[str, tio.Image] = {INPUT: image}
    if mask is not None:
        subject_dict[MASK] = tio.LabelMap(tensor=mask[np.newaxis])
    subject = tio.Subject(subject_dict)

    constraints = self.model.model.crop_size_constraints

    # Make sure the image size is compatible with the model
    multiple_constraints = constraints.multiple_of  # type: ignore
    if multiple_constraints is not None:
        ensure_shape_multiple = tio.EnsureShapeMultiple(multiple_constraints)
        subject = ensure_shape_multiple(subject)  # type: ignore

    # There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
    # to always fit into image. If test_crop_size is smaller than the image, crop will remain unchanged.
    restrict_patch_size = constraints.restrict_crop_size_to_image  # type: ignore
    effective_patch_size, effective_stride = restrict_patch_size(
        subject.spatial_shape,  # type: ignore
        self.model_config.test_crop_size,
        self.model_config.inference_stride_size)

    patch_overlap = np.array(effective_patch_size) - np.array(effective_stride)
    grid_sampler = tio.inference.GridSampler(
        subject,
        effective_patch_size,
        patch_overlap,
        padding_mode=self.model_config.padding_mode.value,
    )
    batch_size = self.model_config.inference_batch_size
    patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size)  # type: ignore
    aggregator = tio.inference.GridAggregator(grid_sampler)

    logging.debug(
        f"Inference on image size {subject.spatial_shape} will run "
        f"with crop size {effective_patch_size} and stride {effective_stride}")
    for patches_batch in patch_loader:
        input_tensor = patches_batch[INPUT][tio.DATA].float()
        if self.model_config.use_gpu:
            input_tensor = input_tensor.cuda()
        locations = patches_batch[tio.LOCATION]
        # perform the forward pass
        patches_posteriors = self.model(input_tensor).detach()
        # pad posteriors if they are smaller than the input
        input_shape = input_tensor.shape[-3:]
        patches_posteriors_shape = patches_posteriors.shape[-3:]
        if input_shape != patches_posteriors_shape:
            difference = np.array(input_shape) - np.array(patches_posteriors_shape)
            assert not np.any(difference % 2)  # the differences in shape are expected to be even
            # torch.nn.functional.pad takes (left, right) pairs starting from the LAST dimension,
            # i.e. (X_left, X_right, Y_left, Y_right, Z_left, Z_right). `difference` is in (Z, Y, X)
            # order, so it must be reversed before being repeated into pairs.
            padding = tuple(int(p) for p in np.repeat(difference[::-1] // 2, 2))
            patches_posteriors = torch.nn.functional.pad(patches_posteriors, padding)
        # collect the predictions over each of the batches
        aggregator.add_batch(patches_posteriors, locations)
    posteriors = aggregator.get_output_tensor().numpy()
    posteriors_mask = None if mask is None else subject[MASK].numpy()[0]
    posteriors, segmentation = self.post_process_posteriors(posteriors, mask=posteriors_mask)

    image_util.check_array_range(posteriors, error_prefix="Whole image posteriors")

    # Make sure the final shape matches the input shape by undoing the padding in EnsureShapeMultiple (if any)
    posteriors_image = tio.ScalarImage(tensor=posteriors, affine=image.affine)
    segmentation_image = tio.LabelMap(tensor=segmentation[np.newaxis], affine=image.affine)
    subject.add_image(posteriors_image, 'posteriors')
    subject.add_image(segmentation_image, 'segmentation')
    # Remove some images to avoid unnecessary computations
    subject.remove_image(INPUT)
    if mask is not None:
        subject.remove_image(MASK)
    subject_original_space = subject.apply_inverse_transform() if subject.applied_transforms else subject
    posteriors = subject_original_space.posteriors.numpy()  # type: ignore
    segmentation = subject_original_space.segmentation.numpy()[0]  # type: ignore

    # prepare pipeline results from the processed batch
    return InferencePipeline.Result(patient_id=patient_id,
                                    segmentation=segmentation,
                                    posteriors=posteriors,
                                    voxel_spacing_mm=voxel_spacing_mm)