def run_inference(images_with_header: List[ImageWithHeader],
                  inference_pipeline: FullImageInferencePipelineBase,
                  config: SegmentationModelBase) -> np.ndarray:
    """
    Runs whole-image inference on a list of image channels, using the given model config and inference pipeline.

    :param images_with_header: The image channels to segment, each with its header. The voxel spacing that is
        handed to the pipeline is read from the header of the first channel.
    :param inference_pipeline: The pipeline whose predict_and_post_process_whole_image method is invoked.
    :param config: The model config. Its dataset_expected_spacing_xyz (if set) is used to validate the input,
        and the config is passed on to the photometric normalization.
    :return: The segmentation stored in the result of the inference pipeline.
    :raises ValueError: If an expected spacing is configured and a channel's spacing is not valid for it.
    """
    # Spacing is only validated when the config specifies an expected spacing.
    if config.dataset_expected_spacing_xyz:
        for channel in images_with_header:
            # The header spacing is reversed before it is compared with the expected XYZ spacing.
            spacing_xyz = reverse_tuple_float3(channel.header.spacing)
            if is_spacing_valid(spacing_xyz, config.dataset_expected_spacing_xyz):
                continue
            raise ValueError(f'Input image has spacing {spacing_xyz} '
                             f'but expected {config.dataset_expected_spacing_xyz}')

    # Photometric normalization, applied to each channel separately.
    normalizer = PhotometricNormalization(config_args=config)
    normalized_channels = np.array([normalizer.transform(channel.image) for channel in images_with_header])

    first_header = images_with_header[0].header
    inference_result = inference_pipeline.predict_and_post_process_whole_image(
        image_channels=normalized_channels,
        voxel_spacing_mm=first_header.spacing)
    return inference_result.segmentation
def test_plot_normalization_result(test_output_dirs: TestOutputDirectories) -> None:
    """
    Tests plotting of before/after histograms in photometric normalization: a small synthetic sample is
    plotted into the test output folder, and the returned files are compared against the expected file names.

    :param test_output_dirs: Provides the root folder into which the plots are written.
    :return:
    """
    size = (3, 3, 3)
    num_voxels = size[0] * size[1] * size[2]
    # Single-channel image whose voxels hold their own running index, counted in (z, y, x) order.
    image = np.arange(num_voxels, dtype=np.float64).reshape((1,) + size)

    # Two label classes; only the centre voxel is set in the class at index 1.
    labels = np.zeros((2,) + size)
    labels[1, 1, 1, 1] = 1

    sample = Sample(image=image,
                    labels=labels,
                    mask=np.ones(size),
                    metadata=DummyPatientMetadata)
    ct_window_config = SegmentationModelBase(norm_method=PhotometricNormalizationMethod.CtWindow,
                                             window=4,
                                             level=13,
                                             should_validate=False)
    output_folder = Path(test_output_dirs.root_dir)
    written_files = plotting.plot_normalization_result(sample,
                                                       PhotometricNormalization(ct_window_config),
                                                       output_folder)
    compare_files(written_files, ["042_slice_001.png", "042_slice_001_contour.png"])
def load_train_and_test_data_channels(patient_ids: List[int],
                                      normalization_fn: PhotometricNormalization) -> List[Sample]:
    """
    Loads the samples for the given patients from the "train_and_test_data" test data folder, and applies
    photometric normalization to the image of each sample.

    :param patient_ids: The IDs of the patients to load. All IDs must pass the check at the top of the function.
    :param normalization_fn: The normalization whose transform method is applied to each image and its mask.
    :return: One Sample per patient ID, in input order, with the normalized image and the mask, labels and
        metadata as they were loaded.
    :raises ValueError: If any of the patient IDs is zero or negative.
    """
    # NOTE(review): the check also rejects 0, whereas the message reads ">= 0" - confirm which one is intended.
    if np.any(np.asarray(patient_ids) <= 0):
        raise ValueError("data_items must be >= 0")

    def channel_file(patient_id, channel):
        # File names follow the pattern id<patient>_<channel>.nii.gz inside the test data folder.
        return full_ml_test_data_path("train_and_test_data") / f"id{patient_id}_{channel}.nii.gz"

    def load_normalized(patient_id):
        source = PatientDatasetSource(
            metadata=PatientMetadata(patient_id=patient_id),
            image_channels=[channel_file(patient_id, channel) for channel in TEST_CHANNEL_IDS],
            mask_channel=channel_file(patient_id, TEST_MASK_ID),
            ground_truth_channels=[channel_file(patient_id, TEST_GT_ID)])
        loaded = io_util.load_images_from_dataset_source(dataset_source=source)
        # Only the image is replaced; mask, labels and metadata are carried over unchanged.
        return Sample(image=normalization_fn.transform(loaded.image, loaded.mask),
                      mask=loaded.mask,
                      labels=loaded.labels,
                      metadata=loaded.metadata)

    return [load_normalized(patient_id) for patient_id in patient_ids]
def main(yaml_file_path: Path) -> None:
    """
    Invoke either by
      * specifying a model, '--model Lung'
      * or specifying dataset and normalization parameters separately: --azure_dataset_id=foo --norm_method=None
    In addition, the arguments '--image_channel' and '--gt_channel' must be specified (see below).

    The function reads the dataset file, loads the selected image/ground truth/mask channels for each patient,
    and writes the normalization plots for each processed patient into the result folder. The command line
    arguments and the string form of the model config are written to the result folder as well.

    :param yaml_file_path: The YAML file that is passed on to get_configs.
    :raises ValueError: If no image channel or no ground truth channel is available, neither from the
        command line arguments nor from the model config.
    """
    config, runner_config, args = get_configs(SegmentationModelBase(should_validate=False), yaml_file_path)
    dataset_config = DatasetConfig(name=config.azure_dataset_id,
                                   local_folder=config.local_dataset,
                                   use_mounting=True)
    local_dataset, mount_context = dataset_config.to_input_dataset_local(workspace=runner_config.get_workspace())
    dataframe = pd.read_csv(local_dataset / DATASET_CSV_FILE_NAME)
    normalizer_config = NormalizeAndVisualizeConfig(**args)
    actual_mask_channel = None if normalizer_config.ignore_mask else config.mask_id

    # Only fall back to the first channel of the model config if it has any channels. Indexing an empty
    # list would raise an IndexError and hide the more helpful error messages below.
    image_channel = normalizer_config.image_channel \
        or (config.image_channels[0] if config.image_channels else None)
    if not image_channel:
        raise ValueError("No image channel selected. Specify a model by name, or use the image_channel argument.")
    gt_channel = normalizer_config.gt_channel \
        or (config.ground_truth_ids[0] if config.ground_truth_ids else None)
    if not gt_channel:
        raise ValueError("No GT channel selected. Specify a model by name, or use the gt_channel argument.")

    dataset_sources = load_dataset_sources(dataframe,
                                           local_dataset_root_folder=local_dataset,
                                           image_channels=[image_channel],
                                           ground_truth_channels=[gt_channel],
                                           mask_channel=actual_mask_channel)
    # Results go into the dataset folder, or into a subfolder of it if one is given.
    result_folder = local_dataset
    if normalizer_config.result_folder is not None:
        result_folder = result_folder / normalizer_config.result_folder
    # parents=True so that a nested result folder like "a/b" can be created in one go.
    result_folder.mkdir(parents=True, exist_ok=True)

    all_patient_ids = list(dataset_sources.keys())
    # only_first == 0 means "process all patients".
    if normalizer_config.only_first == 0:
        patient_ids_to_process = all_patient_ids
    else:
        patient_ids_to_process = all_patient_ids[:normalizer_config.only_first]

    # Keep a record of how the results were generated next to the plots.
    args_file = result_folder / ARGS_TXT
    args_file.write_text(" ".join(sys.argv[1:]))
    config_file = result_folder / "config.txt"
    config_file.write_text(str(config))

    normalizer = PhotometricNormalization(config)
    for patient_id in patient_ids_to_process:
        print(f"Starting to process patient {patient_id}")
        images = load_images_from_dataset_source(dataset_sources[patient_id])
        plotting.plot_normalization_result(images, normalizer, result_folder, result_prefix=image_channel)
def get_full_image_sample_transforms(self) -> ModelTransformsPerExecutionMode:
    """
    Get transforms to perform on full image samples for each model execution mode.
    By default only PhotometricNormalization is performed, and the same transform object is used for
    training, validation and test.

    :return: The transforms for the train, val and test execution modes.
    """
    # Imported locally rather than at module level - presumably to avoid a circular import; confirm.
    from InnerEye.ML.utils.transforms import Compose3D
    from InnerEye.ML.photometric_normalization import PhotometricNormalization

    normalization = PhotometricNormalization(self, use_gpu=False)
    shared_transform = Compose3D(transforms=[normalization])
    return ModelTransformsPerExecutionMode(train=shared_transform,
                                           val=shared_transform,
                                           test=shared_transform)
def plot_normalization_result(loaded_images: Sample,
                              normalizer: PhotometricNormalization,
                              result_folder: Path,
                              result_prefix: str = "",
                              image_range: Optional[TupleFloat2] = None,
                              channel_index: int = 0,
                              class_index: int = 1,
                              contour_file_suffix: str = "") -> List[Path]:
    """
    Creates two PNG plots that summarize the result of photometric normalization of one channel of the
    sample image. The first plot contains pixel value histograms before and after photometric normalization.
    The second plot contains the normalized image, overlaid with contours of the chosen class, at the slice
    that get_largest_z_slice picks for that class.

    :param loaded_images: An instance of Sample with the image, the labels and the mask.
    :param normalizer: The photometric normalization that should be applied.
    :param result_folder: The folder into which the resulting PNG files should be written.
    :param result_prefix: The prefix for all output filenames.
    :param image_range: The image value range that will be mapped to the color map. If None, the full image
        range will be mapped to the colormap.
    :param channel_index: Compute normalization results for this channel.
    :param class_index: When plotting image/contour overlays, use this class.
    :param contour_file_suffix: Use this suffix for the file name that contains the image/contour overlay.
    :return: The paths of the two PNG files that the function writes: the before/after plot first, then
        the image/contour plot.
    """
    # Both plots are made at the slice that is selected from the ground truth of the chosen class.
    ground_truth = loaded_images.labels[class_index, ...]
    largest_gt_slice = get_largest_z_slice(ground_truth)
    filename_stem = f"{result_prefix}{loaded_images.patient_id:03d}_slice_{largest_gt_slice:03d}"

    # Normalization runs on the full image; only the requested channel is plotted.
    original_channel = loaded_images.image[channel_index, ...]
    normalized_all_channels = normalizer.transform(loaded_images.image, loaded_images.mask)
    normalized_image = normalized_all_channels[channel_index, ...]

    # The status message is read after the transform call above, so that it describes that call.
    before_after_plot = plot_before_after_statistics(
        original_channel,
        normalized_image,
        loaded_images.mask,
        z_slice=largest_gt_slice,
        normalizer_status_message=normalizer.status_of_most_recent_call,
        plot_file_name=result_folder / filename_stem)

    contour_file = result_folder / f"{filename_stem}_contour{contour_file_suffix}"
    image_contour_plot = plot_image_and_label_contour(
        image=normalized_image[largest_gt_slice, ...],
        labels=ground_truth[largest_gt_slice, ...],
        contour_arguments={'colors': 'r'},
        image_range=image_range,
        plot_file_name=contour_file)

    return [before_after_plot, image_contour_plot]
def normalize_fn(default_config: SegmentationModelBase) -> PhotometricNormalization:
    """
    Creates a photometric normalization object from the given model config.

    :param default_config: The model config that is passed to the PhotometricNormalization constructor.
    :return: The newly created PhotometricNormalization.
    """
    normalizer = PhotometricNormalization(default_config)
    return normalizer