def create_from_checkpoint(path_to_checkpoint: Path,
                           model_config: SegmentationModelBase,
                           pipeline_id: int = 0) -> Optional[InferencePipeline]:
    """
    Builds an inference pipeline from a model checkpoint stored on disk.

    The checkpoint is loaded via `model_util.load_from_checkpoint_and_adjust`. Every model parameter
    is then passed through `image_util.check_array_range` without an explicit range, so that
    NaN/Infinity values in the weights are reported with the name of the offending parameter.

    :param path_to_checkpoint: The path of the checkpoint file to load.
    :param model_config: Model related configurations.
    :param pipeline_id: Numeric identifier for the pipeline (useful for logging when ensembling)
    :return InferencePipeline: an instantiated inference pipeline, or None if loading did not yield
        both a checkpoint epoch and a model.
    """
    loaded = model_util.load_from_checkpoint_and_adjust(model_config, path_to_checkpoint)
    # Without an epoch or a model there is nothing to run inference with.
    if loaded.checkpoint_epoch is None or loaded.model is None:
        return None

    for param_name, tensor in loaded.model.named_parameters():
        values = tensor.clone().cpu().data.numpy()
        image_util.check_array_range(values, error_prefix=f"Parameter {param_name}")

    return InferencePipeline(
        model=loaded.model,
        model_config=model_config,
        epoch=loaded.checkpoint_epoch,
        pipeline_id=pipeline_id,
    )
def predict(self) -> InferenceBatch:
    """
    Runs the model over all patches of the image and stitches the results into one posterior map
    per class. The stitched posteriors (Class x Z x Y x X, float32) are stored in the 'posteriors'
    component of this batch.

    :return: This batch instance, so that further pipeline actions can be chained.
    """
    config = self.get_configs()

    # Patches for each image channel: Num patches x Channels x Z x Y x X
    all_patches = self._extract_patches_for_image_channels()

    # Feed the patches through the model in chunks of at most `inference_batch_size`
    chunk_size = config.inference_batch_size
    chunk_outputs = []
    for start in range(0, len(all_patches), chunk_size):
        chunk = all_patches[start:start + chunk_size, ...]
        chunk_prediction = self._model_fn(chunk)
        image_util.check_array_range(
            chunk_prediction,
            expected_range=InferencePipeline.MODEL_OUTPUT_POSTERIOR_RANGE,  # type: ignore
            error_prefix="Model predictions for current batch")
        chunk_outputs.append(chunk_prediction)

    # Join the chunks back together: Num patches x Class x Z x Y x X
    patch_predictions = np.concatenate(chunk_outputs, axis=0)

    # One posterior volume per class: Class x Z x Y x X. float32 is used because these arrays can be big.
    image_shape = self.pipeline.get_variable(InferencePipeline.Variables.OutputImageShape)
    patch_stride = self.pipeline.get_variable(InferencePipeline.Variables.Stride)
    stitched = np.zeros(shape=[config.number_of_classes] + list(image_shape), dtype=np.float32)

    for class_index in range(stitched.shape[0]):
        # Stitch the patches of this class into the batch's posterior component ...
        self.load_from_patches(
            patch_predictions[:, class_index, ...],
            stride=patch_stride,
            scan_shape=image_shape,
            data_attr=InferenceBatch.Components.Posteriors.value)
        # ... and copy the result out straight away, so the component buffer can be reused for the next class
        stitched[class_index] = self.get_component(InferenceBatch.Components.Posteriors)

    # Replace the per-class scratch buffer with the full set of stitched posteriors
    self.set_component(component=InferenceBatch.Components.Posteriors, data=stitched)
    return self
def transform(
    self,
    image: Union[np.ndarray, torch.Tensor],
    mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
    patient_id: Optional[int] = None
) -> Union[np.ndarray, torch.Tensor]:
    """
    Applies the photometric normalization selected by `self.norm_method` to the given image.

    The status returned by the MriWindow and TrimmedNorm methods is kept in
    `self.status_of_most_recent_call` (None for all other methods). The normalized image is passed
    through `check_array_range` before it is returned.

    :param image: The image to normalize, as a numpy array or a torch tensor.
    :param mask: Optional mask that is handed to the normalization routines. If not given, a mask of
        all ones with the same shape as the image is used.
    :param patient_id: Optional patient identifier, only used for the debug log message.
    :return: The normalized image.
    :raises ValueError: If a parameter required by the chosen method is missing, or if the
        normalization method is unknown.
    """
    if mask is None:
        # Default mask covers the whole image, in the same array flavour as the image itself
        mask = torch.ones_like(image) if torch.is_tensor(image) else np.ones_like(image)

    self.status_of_most_recent_call = None
    method = self.norm_method

    if method == PhotometricNormalizationMethod.Unchanged:
        normalized = image
    elif method == PhotometricNormalizationMethod.SimpleNorm:
        normalized = simple_norm(image, mask, self.debug_mode)
    elif method == PhotometricNormalizationMethod.MriWindow:
        if self.sharpen is None:
            raise ValueError("The 'sharpen' parameter must be provided.")
        if not isinstance(self.tail, (list, float)):
            raise ValueError(
                "The 'tail' parameter must be provided and set to a float value or a list of float values."
            )
        normalized, call_status = mri_window(image, mask, self.output_range,
                                             self.sharpen, self.tail, self.debug_mode)
        self.status_of_most_recent_call = call_status
    elif method == PhotometricNormalizationMethod.CtWindow:
        if self.level is None:
            raise ValueError("The 'level' parameter must be provided.")
        if self.window is None:
            raise ValueError("The 'window' parameter must be provided.")
        normalized = CTRange.transform(data=image,
                                       output_range=self.output_range,
                                       level=self.level,
                                       window=self.window,
                                       use_gpu=self.use_gpu)
    elif method == PhotometricNormalizationMethod.TrimmedNorm:
        normalized, call_status = normalize_trim(image, mask, self.output_range,
                                                 self.sharpen, self.trim_percentiles,
                                                 self.debug_mode)
        self.status_of_most_recent_call = call_status
    else:
        raise ValueError("Unknown normalization method {}".format(self.norm_method))

    have_status = self.status_of_most_recent_call is not None
    if patient_id is not None and have_status:
        logging.debug(
            f"Photonorm patient {patient_id}: {self.status_of_most_recent_call}"
        )

    check_array_range(normalized, error_prefix="Normalized image")
    return normalized
def predict_whole_image(self,
                        image_channels: np.ndarray,
                        voxel_spacing_mm: TupleFloat3,
                        mask: Optional[np.ndarray] = None,
                        patient_id: int = 0) -> InferencePipeline.Result:
    """
    Performs a single inference pass through the pipeline for the provided image.

    :param image_channels: The input image channels to perform inference on in format: Channels x Z x Y x X.
    :param voxel_spacing_mm: Voxel spacing to use for each dimension in (Z x Y x X) order. It is passed
        through unchanged into the result.
    :param mask: A binary image used to ignore results outside it in format: Z x Y x X. Optional.
    :param patient_id: The identifier of the patient this image belongs to (defaults to 0).
    :return InferenceResult: that contains Segmentation for each of the classes and their posterior probabilities.
    :raises Exception: If `image_channels` is None.
    :raises NotImplementedError: If `image_channels` is not a 4-dimensional array.
    """
    if image_channels is None:
        raise Exception("image_channels cannot be None")
    if image_channels.ndim != 4:
        # The two literals are concatenated by Python, hence the explicit ", " separator.
        raise NotImplementedError("image_channels must be in shape: Channels x Z x Y x X, "
                                  "found image_channels shape: {}".format(image_channels.shape))
    if mask is not None:
        # Compare the trailing three (spatial) dimensions of the 4D image and the 3D mask
        ml_util.check_size_matches(image_channels, mask, 4, 3, [-1, -2, -3])

    self.model.eval()

    # create the dataset for the batch
    batch_dataset = Dataset(index=[patient_id], batch_class=InferenceBatch)

    # setup the pipeline
    pipeline = (batch_dataset.p
                # define pipeline variables
                .init_variables([InferencePipeline.Variables.Model,
                                 InferencePipeline.Variables.ModelConfig,
                                 InferencePipeline.Variables.CropSize,
                                 InferencePipeline.Variables.OutputSize,
                                 InferencePipeline.Variables.OutputImageShape,
                                 InferencePipeline.Variables.Stride])
                # update the variables for the batch actions
                .update_variable(name=InferencePipeline.Variables.Model, value=self.model)
                .update_variable(name=InferencePipeline.Variables.ModelConfig, value=self.model_config)
                # perform cascaded batch actions
                .load(image_channels=image_channels, mask=mask)
                .pre_process()
                .predict()
                .post_process()
                )

    # run the batch through the pipeline
    logging.info(f"Inference pipeline ({self.pipeline_id}), Predicting patient: {patient_id}")
    processed_batch: InferenceBatch = pipeline.next_batch(batch_size=1)
    posteriors = processed_batch.get_component(InferenceBatch.Components.Posteriors)
    image_util.check_array_range(posteriors, error_prefix="Whole image posteriors")

    # prepare pipeline results from the processed batch
    return InferencePipeline.Result(
        patient_id=patient_id,
        segmentation=processed_batch.get_component(InferenceBatch.Components.Segmentation),
        posteriors=posteriors,
        voxel_spacing_mm=voxel_spacing_mm
    )
def validate_and_store_model_parameters(writer: tensorboardX.SummaryWriter,
                                        epoch: int,
                                        model: DeviceAwareModule) -> None:
    """
    Checks every model parameter with `check_array_range` and writes it as a histogram
    to the given TensorBoard writer.

    :param writer: TensorBoard summary writer
    :param epoch: The epoch that these model parameters belong to. Passed to the writer together
        with each histogram.
    :param model: The model from which to extract the parameters.
    """
    for param_name, tensor in model.named_parameters():
        # Work on a CPU copy so the live parameter is left untouched
        values = tensor.clone().cpu().data.numpy()
        check_array_range(values, error_prefix=f"Parameter {param_name}")
        writer.add_histogram(param_name, values, epoch)
def test_check_input_range_with_tolerance() -> None:
    """
    Test `check_array_range` for values that lie only *just* outside the allowed range:
    exceeding the range by more than the tolerance at either end must raise, while values within
    the tolerance are accepted and end up exactly on the range boundaries.
    """
    eps = image_util.VALUE_RANGE_TOLERANCE
    lower, upper = 0.0, 1.0
    bounds = (lower, upper)

    # (factor below the lower bound, factor above the upper bound), in multiples of the tolerance.
    # In each of these cases at least one end exceeds the tolerance.
    failing_factors = [(1.1, 1.1), (0.9, 1.1), (1.1, 0.9)]
    for below, above in failing_factors:
        outside = np.array([lower - below * eps, upper + above * eps])
        with pytest.raises(ValueError):
            image_util.check_array_range(outside, bounds)

    # Both ends are within tolerance: no error, and the array is modified in place
    within_tolerance = np.array([lower - 0.9 * eps, upper + 0.9 * eps])
    image_util.check_array_range(within_tolerance, bounds)
    assert within_tolerance[0] == lower
    assert within_tolerance[1] == upper
def store_posteriors_as_nifti(image: np.ndarray,
                              header: ImageHeader,
                              file_name: PathOrString) -> Path:
    """
    Validates an array of posteriors against DEFAULT_POSTERIOR_VALUE_RANGE and saves it in nifti
    format by delegating to `store_as_scaled_ubyte_nifti`, using the same range as the input range
    for scaling.

    :param image: 3D image in shape: Z x Y x X.
    :param header: Image header for the image
    :param file_name: The name of the file for this image.
    :return: the path to the saved image
    """
    posterior_range = DEFAULT_POSTERIOR_VALUE_RANGE
    # Reject posteriors outside the expected range before anything is written
    check_array_range(image, posterior_range, error_prefix="Posterior")
    saved_path = store_as_scaled_ubyte_nifti(image=image,
                                             header=header,
                                             file_name=file_name,
                                             input_range=posterior_range)
    return saved_path
def create_from_checkpoint(path_to_checkpoint: Path,
                           model_config: SegmentationModelBase,
                           pipeline_id: int = 0) -> Optional[InferencePipeline]:
    """
    Builds an inference pipeline from a model checkpoint stored on disk.

    Depending on `model_config.compute_mean_teacher_model`, either the mean teacher model or the
    regular model is created and loaded from the checkpoint. Every parameter of the loaded model is
    then passed through `image_util.check_array_range` without an explicit range, so that
    NaN/Infinity values in the weights are reported with the name of the offending parameter.

    :param path_to_checkpoint: The path of the checkpoint file to load.
    :param model_config: Model related configurations.
    :param pipeline_id: Numeric identifier for the pipeline (useful for logging when ensembling)
    :return InferencePipeline: an instantiated inference pipeline, or None if the model could not be
        loaded from the checkpoint.
    """
    info = model_util.ModelAndInfo(config=model_config,
                                   model_execution_mode=ModelExecutionMode.TEST,
                                   checkpoint_path=path_to_checkpoint)

    if model_config.compute_mean_teacher_model:
        loaded_ok = info.try_create_mean_teacher_model_load_from_checkpoint_and_adjust()
        network = info.mean_teacher_model
    else:
        loaded_ok = info.try_create_model_load_from_checkpoint_and_adjust()
        network = info.model

    if not loaded_ok:
        return None

    # for mypy, if model has been loaded these will not be None
    assert info.checkpoint_epoch is not None

    for param_name, tensor in network.named_parameters():
        values = tensor.clone().cpu().data.numpy()
        image_util.check_array_range(values, error_prefix=f"Parameter {param_name}")

    return InferencePipeline(
        model=network,
        model_config=model_config,
        epoch=info.checkpoint_epoch,
        pipeline_id=pipeline_id,
    )
def run_inference_on_unet(size: TupleInt3) -> None:
    """
    Runs whole-image inference with a freshly created UNet3D model on a random single-channel input
    image of the given size. Asserts that every posterior and the segmentation have the same size
    as the input image, and that they pass `check_array_range`.

    :param size: The spatial size of the input image.
    """
    foreground = ["tumour_mass", "subtract"]
    num_foreground = len(foreground)
    num_classes = num_foreground + 1  # foreground classes plus one extra class

    config = SegmentationModelBase(
        architecture="UNet3D",
        local_dataset=Path("dummy"),
        feature_channels=[1],
        kernel_size=3,
        largest_connected_component_foreground_classes=foreground,
        posterior_smoothing_mm=(2, 2, 2),
        crop_size=(64, 64, 64),
        # test_crop_size must be larger than `size` for the bug to trigger
        test_crop_size=(80, 80, 80),
        image_channels=["mr"],
        ground_truth_ids=foreground,
        ground_truth_ids_display_names=foreground,
        colours=[(255, 0, 0)] * num_foreground,
        fill_holes=[False] * num_foreground,
        mask_id=None,
        class_weights=[1.0 / num_classes] * num_classes,
        train_batch_size=8,
        inference_batch_size=1,
        inference_stride_size=(40, 40, 40),
        use_mixed_precision=True,
    )

    model = create_lightning_model(config)
    assert isinstance(model, SegmentationLightning)
    pipeline = InferencePipeline(model=model, model_config=config)

    # Single-channel random image: Channels x Z x Y x X
    input_image = np.random.uniform(-1, 1, (1,) + size)
    result = pipeline.predict_and_post_process_whole_image(input_image,
                                                           mask=np.ones(size),
                                                           voxel_spacing_mm=(1, 1, 1))

    # All posteriors and the segmentation must have the size of the input image
    outputs = list(result.posteriors) + [result.segmentation]
    for output in outputs:
        assert output.shape == size
        # Check that all results are not NaN. In particular, if stride size is not adjusted
        # correctly, the results would be partially NaN.
        image_util.check_array_range(output)
def test_check_input_range() -> None:
    """
    Test the `check_array_range` function, in particular for arrays that contain NaN or Inf,
    and check the contents of the error messages it raises.
    """
    clean = np.array([1, 2, 3, 4])
    with_nan = np.array([1, 2, 3, np.nan, np.nan])
    with_inf = np.array([1, 2, 3, np.inf, np.inf])
    with_nan_and_inf = np.array([1, 2, 3, np.nan, np.inf])

    # All values are in the range: no error expected
    image_util.check_array_range(clean, (1, 4))
    # Without a range, only NaN and Inf are checked for, and there are none
    image_util.check_array_range(clean, None)

    # Range smaller than the values present: must fail, and the message must name the interval
    # and the offending values
    with pytest.raises(ValueError) as err:
        image_util.check_array_range(clean, (1, 2))
    assert "within [1, 2]" in err.value.args[0]
    assert "invalid values: 3, 4" in err.value.args[0]

    # Arrays with NaN and/or Inf must never pass, with or without an interval given.
    # Each entry: the token expected in the message, and the arrays that contain that value.
    non_finite_cases = [
        ("inf", [with_inf, with_nan_and_inf]),
        ("nan", [with_nan, with_nan_and_inf]),
    ]
    for token, arrays in non_finite_cases:
        for data in arrays:
            with pytest.raises(ValueError) as err:
                image_util.check_array_range(data)
            assert "finite" in err.value.args[0]
            assert token in err.value.args[0]
            with pytest.raises(ValueError) as err:
                image_util.check_array_range(data, (1, 4))
            assert "within [1, 4]" in err.value.args[0]
            assert token in err.value.args[0]

    # Values outside of the expected range together with NaN and Inf
    with pytest.raises(ValueError) as err:
        image_util.check_array_range(with_nan_and_inf, (1, 2))
    assert "within [1, 2]" in err.value.args[0]
    assert "nan, inf, 3.0" in err.value.args[0]

    # Degenerate interval that consists of a single value
    constant = np.array([2, 2])
    image_util.check_array_range(constant, (2, 2))
    with pytest.raises(ValueError) as err:
        image_util.check_array_range(constant, (3, 3))
    assert "within [3, 3]" in err.value.args[0]
    assert "2" in err.value.args[0]
def predict_whole_image(self,
                        image_channels: np.ndarray,
                        voxel_spacing_mm: TupleFloat3,
                        mask: Optional[np.ndarray] = None,
                        patient_id: int = 0) -> InferencePipeline.Result:
    """
    Performs a single inference pass through the pipeline for the provided image.

    The image is split into patches on a grid, the model is run on each batch of patches, and the
    patch posteriors are aggregated back into a whole-image result. Posteriors are post-processed
    into a segmentation, and any padding added to make the image shape compatible with the model
    is undone before returning.

    :param image_channels: The input image channels to perform inference on in format: Channels x Z x Y x X.
    :param voxel_spacing_mm: Voxel spacing to use for each dimension in (Z x Y x X) order. It is passed
        through unchanged into the result.
    :param mask: A binary image used to ignore results outside it in format: Z x Y x X. Optional.
    :param patient_id: The identifier of the patient this image belongs to (defaults to 0).
    :return InferenceResult: that contains Segmentation for each of the classes and their posterior probabilities.
    :raises Exception: If `image_channels` is None.
    :raises NotImplementedError: If `image_channels` is not a 4-dimensional array.
    """
    if image_channels is None:
        raise Exception("image_channels cannot be None")
    if image_channels.ndim != 4:
        # The two literals are concatenated by Python, hence the explicit ", " separator.
        raise NotImplementedError(
            "image_channels must be in shape: Channels x Z x Y x X, "
            "found image_channels shape: {}".format(image_channels.shape))
    if mask is not None:
        # Compare the trailing three (spatial) dimensions of the 4D image and the 3D mask
        ml_util.check_size_matches(image_channels, mask, 4, 3, [-1, -2, -3])

    self.model.eval()

    image = tio.ScalarImage(tensor=image_channels)
    INPUT = 'input_image'
    MASK = 'mask'

    subject_dict: Dict[str, tio.Image] = {INPUT: image}
    if mask is not None:
        # The label map needs a leading channel dimension
        subject_dict[MASK] = tio.LabelMap(tensor=mask[np.newaxis])
    subject = tio.Subject(subject_dict)

    constraints = self.model.model.crop_size_constraints

    # Make sure the image size is compatible with the model
    multiple_constraints = constraints.multiple_of  # type: ignore
    if multiple_constraints is not None:
        ensure_shape_multiple = tio.EnsureShapeMultiple(multiple_constraints)  # type: ignore
        subject = ensure_shape_multiple(subject)  # type: ignore

    # There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
    # to always fit into image. If test_crop_size is smaller than the image, crop will remain unchanged.
    restrict_patch_size = constraints.restrict_crop_size_to_image  # type: ignore
    effective_patch_size, effective_stride = restrict_patch_size(
        subject.spatial_shape,  # type: ignore
        self.model_config.test_crop_size,
        self.model_config.inference_stride_size)

    patch_overlap = np.array(effective_patch_size) - np.array(effective_stride)
    grid_sampler = tio.inference.GridSampler(
        subject,
        effective_patch_size,
        patch_overlap,
        padding_mode=self.model_config.padding_mode.value,
    )
    batch_size = self.model_config.inference_batch_size
    patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size)  # type: ignore
    aggregator = tio.inference.GridAggregator(grid_sampler)

    logging.debug(
        f"Inference on image size {subject.spatial_shape} will run "
        f"with crop size {effective_patch_size} and stride {effective_stride}"
    )
    for patches_batch in patch_loader:
        input_tensor = patches_batch[INPUT][tio.DATA].float()
        if self.model_config.use_gpu:
            input_tensor = input_tensor.cuda()
        locations = patches_batch[tio.LOCATION]

        # perform the forward pass
        patches_posteriors = self.model(input_tensor).detach()

        # pad posteriors if they are smaller than the input
        input_shape = input_tensor.shape[-3:]
        patches_posteriors_shape = patches_posteriors.shape[-3:]
        if input_shape != patches_posteriors_shape:
            # Per-dimension size difference, in (Z, Y, X) order
            difference = np.array(input_shape) - np.array(patches_posteriors_shape)
            assert not np.any(difference % 2)  # the differences in shape are expected to be even
            # torch.nn.functional.pad expects (before, after) pairs starting from the LAST dimension,
            # i.e. in (X, Y, Z) order, so the differences are reversed before repeating each one
            # for both sides. tolist() turns the numpy integers into plain Python ints.
            padding = tuple(np.repeat(difference[::-1] // 2, 2).tolist())
            patches_posteriors = torch.nn.functional.pad(patches_posteriors, padding)

        # collect the predictions over each of the batches
        aggregator.add_batch(patches_posteriors, locations)

    posteriors = aggregator.get_output_tensor().numpy()
    # Take the mask from the subject (not the function argument), so it has been through the same
    # shape adjustment as the image
    posteriors_mask = None if mask is None else subject[MASK].numpy()[0]
    posteriors, segmentation = self.post_process_posteriors(posteriors, mask=posteriors_mask)

    image_util.check_array_range(posteriors, error_prefix="Whole image posteriors")

    # Make sure the final shape matches the input shape by undoing the padding in EnsureShapeMultiple (if any)
    posteriors_image = tio.ScalarImage(tensor=posteriors, affine=image.affine)
    segmentation_image = tio.LabelMap(tensor=segmentation[np.newaxis], affine=image.affine)
    subject.add_image(posteriors_image, 'posteriors')
    subject.add_image(segmentation_image, 'segmentation')
    # Remove some images to avoid unnecessary computations
    subject.remove_image(INPUT)
    if mask is not None:
        subject.remove_image(MASK)
    subject_original_space = subject.apply_inverse_transform() if subject.applied_transforms else subject
    posteriors = subject_original_space.posteriors.numpy()  # type: ignore
    segmentation = subject_original_space.segmentation.numpy()[0]  # type: ignore

    # prepare pipeline results from the processed batch
    return InferencePipeline.Result(patient_id=patient_id,
                                    segmentation=segmentation,
                                    posteriors=posteriors,
                                    voxel_spacing_mm=voxel_spacing_mm)