Example #1
def full_image_dataset(default_config: SegmentationModelBase,
                       normalize_fn: Callable) -> FullImageDataset:
    df = default_config.get_dataset_splits()
    return FullImageDataset(
        args=default_config,
        full_image_sample_transforms=Compose3D([normalize_fn]),  # type: ignore
        data_frame=df.train)
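A minimal usage sketch, assuming this helper is registered as a pytest fixture (as its shape suggests): the test below exercises the returned dataset through get_samples_at_index, which also appears in Example #2. The test name and assertions are illustrative assumptions, not part of the original code.

def test_full_image_dataset_returns_training_samples(full_image_dataset: FullImageDataset) -> None:
    # The training split should contain at least one subject.
    assert len(full_image_dataset) > 0
    # get_samples_at_index returns the samples stored at the given index;
    # the first one should be usable directly.
    sample = full_image_dataset.get_samples_at_index(index=0)[0]
    assert sample is not None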
Example #2
def visualize_random_crops_for_dataset(
        config: SegmentationModelBase,
        output_folder: Optional[Path] = None) -> None:
    """
    For segmentation models only: This function generates visualizations of the effect of sampling random patches
    for training. Visualizations are stored both in Nifti format and as 3 PNG thumbnail files in the output folder.
    :param config: The model configuration.
    :param output_folder: The folder in which the visualizations should be written. If not provided, a subfolder
    "patch_sampling" in the model's default output folder is used.
    """
    dataset_splits = config.get_dataset_splits()
    # Load a sample using the full image data loader
    full_image_dataset = FullImageDataset(config, dataset_splits.train)
    output_folder = output_folder or config.outputs_folder / PATCH_SAMPLING_FOLDER
    count = min(config.show_patch_sampling, len(full_image_dataset))
    for sample_index in range(count):
        sample = full_image_dataset.get_samples_at_index(index=sample_index)[0]
        visualize_random_crops(sample, config, output_folder=output_folder)
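A short driver sketch showing how this function might be invoked; MySegmentationConfig and the output path are hypothetical placeholders for a concrete SegmentationModelBase subclass and are not part of the original code.

from pathlib import Path

def run_patch_sampling_visualization() -> None:
    # Hypothetical config: replace with a concrete SegmentationModelBase subclass.
    config = MySegmentationConfig()
    # Writes Nifti volumes and PNG thumbnails for up to config.show_patch_sampling training samples.
    visualize_random_crops_for_dataset(config, output_folder=Path("outputs") / "patch_sampling")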
Example #3
def cropping_dataset(default_config: SegmentationModelBase,
                     normalize_fn: Callable) -> CroppingDataset:
    df = default_config.get_dataset_splits()
    return CroppingDataset(args=default_config, data_frame=df.train)
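A sketch of how the cropping dataset might be fed into a standard PyTorch DataLoader for training; the batch size and worker count are illustrative assumptions, and it presumes CroppingDataset behaves as a torch Dataset that yields one randomly cropped patch per item.

from torch.utils.data import DataLoader

def make_cropping_loader(dataset: CroppingDataset) -> DataLoader:
    # Each item is assumed to be a random crop, so a small batch size already
    # provides variety; shuffling reorders the underlying subjects.
    return DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)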
Example #4
def segmentation_model_test_epoch(
        config: SegmentationModelBase,
        data_split: ModelExecutionMode,
        test_epoch: int,
        results_folder: Path,
        epoch_and_split: str,
        run_recovery: Optional[RunRecovery] = None) -> Optional[List[float]]:
    """
    The main testing loop for a given epoch. It loads the model and datasets, then proceeds to test the model.
    Returns a list with an entry for each image in the dataset. The entry is the average Dice score,
    where the average is taken across all non-background structures in the image.
    :param config: The arguments which specify all required information.
    :param data_split: Indicates whether the model is evaluated on the train, validation, or test set.
    :param test_epoch: The last trained epoch of the model.
    :param results_folder: The folder in which to store the results.
    :param epoch_and_split: A string that should uniquely identify the epoch and the data split (train/val/test).
    :param run_recovery: Run recovery data if applicable.
    :raises TypeError: If the arguments are of the wrong type.
    :raises ValueError: When there are issues loading the model.
    :return: A list with the mean Dice score (across all structures apart from background) for each image.
    """
    ml_util.set_random_seed(config.get_effective_random_seed(),
                            "Model Training")
    results_folder = Path(results_folder)
    results_folder.mkdir(exist_ok=True)

    test_dataframe = config.get_dataset_splits()[data_split]
    test_csv_path = results_folder / STORED_CSV_FILE_NAMES[data_split]
    test_dataframe.to_csv(path_or_buf=test_csv_path, index=False)
    logging.info("Results directory: {}".format(results_folder))
    logging.info(
        f"Starting evaluation of model {config.model_name} on {epoch_and_split}"
    )

    # Write the dataset id and ground truth ids into the results folder
    store_run_information(results_folder, config.azure_dataset_id,
                          config.ground_truth_ids, config.image_channels)

    ds = config.get_torch_dataset_for_inference(data_split)

    inference_pipeline = create_inference_pipeline(config=config,
                                                   epoch=test_epoch,
                                                   run_recovery=run_recovery)

    if inference_pipeline is None:
        # This will happen if there is no checkpoint for the given epoch, in either the recovered run (if any) or
        # the current one.
        return None

    # for mypy
    assert isinstance(inference_pipeline, FullImageInferencePipelineBase)

    # Deploy the trained model on a set of images and store output arrays.
    for sample_index, sample in enumerate(ds, 1):
        logging.info(f"Predicting for image {sample_index} of {len(ds)}...")
        sample = Sample.from_dict(sample=sample)
        inference_result = inference_pipeline.predict_and_post_process_whole_image(
            image_channels=sample.image,
            mask=sample.mask,
            patient_id=sample.patient_id,
            voxel_spacing_mm=sample.metadata.image_header.spacing)
        store_inference_results(inference_result=inference_result,
                                config=config,
                                results_folder=results_folder,
                                image_header=sample.metadata.image_header)

    # Evaluate model generated segmentation maps.
    num_workers = min(cpu_count(), len(ds))
    with Pool(processes=num_workers) as pool:
        pool_outputs = pool.map(
            partial(evaluate_model_predictions,
                    config=config,
                    dataset=ds,
                    results_folder=results_folder), range(len(ds)))

    average_dice = list()
    metrics_writer = MetricsPerPatientWriter()
    for (patient_metadata, metrics_for_patient) in pool_outputs:
        # Add the Dice score for the foreground classes, stored in the default hue
        metrics.add_average_foreground_dice(metrics_for_patient)
        average_dice.append(
            metrics_for_patient.get_single_metric(MetricType.DICE))
        # Structure names do not include the background class (index 0)
        for structure_name in config.ground_truth_ids:
            dice_for_struct = metrics_for_patient.get_single_metric(
                MetricType.DICE, hue=structure_name)
            hd_for_struct = metrics_for_patient.get_single_metric(
                MetricType.HAUSDORFF_mm, hue=structure_name)
            md_for_struct = metrics_for_patient.get_single_metric(
                MetricType.MEAN_SURFACE_DIST_mm, hue=structure_name)
            metrics_writer.add(patient=str(patient_metadata.patient_id),
                               structure=structure_name,
                               dice=dice_for_struct,
                               hausdorff_distance_mm=hd_for_struct,
                               mean_distance_mm=md_for_struct)

    metrics_writer.to_csv(results_folder / METRICS_FILE_NAME)
    metrics_writer.save_aggregates_to_csv(results_folder /
                                          METRICS_AGGREGATES_FILE)
    if config.is_plotting_enabled:
        plt.figure()
        boxplot_per_structure(metrics_writer.to_data_frame(),
                              column_name=MetricsFileColumns.DiceNumeric.value,
                              title=f"Dice score for {epoch_and_split}")
        # The box plot file will be written to the output directory. AzureML will pick that up, and display
        # on the run overview page, without having to log to the run context.
        plotting.resize_and_save(5, 4, results_folder / BOXPLOT_FILE)
        plt.close()
    logging.info(
        f"Finished evaluation of model {config.model_name} on {epoch_and_split}"
    )

    return average_dice
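A small follow-up sketch for consuming the return value: it logs an overall mean of the per-image foreground Dice scores and handles the None case that signals a missing checkpoint. The function name and log messages are assumptions, not part of the original code.

import logging
from statistics import mean
from typing import List, Optional

def summarize_test_epoch(average_dice: Optional[List[float]]) -> None:
    if average_dice is None:
        # segmentation_model_test_epoch returns None when no checkpoint exists for the epoch.
        logging.warning("No inference results were produced; skipping summary.")
        return
    logging.info(f"Evaluated {len(average_dice)} images; "
                 f"mean foreground Dice across images: {mean(average_dice):.3f}")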