# Example #1
# 0
def test_metrics_file(test_output_dirs: TestOutputDirectories) -> None:
    """
    Test if metrics files with Dice scores are written as expected.

    Builds a small set of per-patient metrics by hand and checks two things:
    the per-patient CSV has the expected (sorted) contents, and the aggregates
    CSV matches the reference file exactly. It then renders Dice box plots
    for 2 and for 6 structures.
    """
    folder = test_output_dirs.make_sub_dir("test_metrics_file")

    def new_file(suffix: str) -> str:
        """Return the path of file `suffix` in the test folder, deleting any leftover copy first."""
        file = os.path.join(folder, suffix)
        if os.path.exists(file):
            os.remove(file)
        return file

    d = MetricsPerPatientWriter()
    p1 = "Patient1"
    p2 = "Patient2"
    p3 = "Patient3"
    liver = "liver"
    kidney = "kidney"
    # Ordering for test data (sorting is by structure, then by ascending Dice):
    # For "liver", patient 1 has the lowest score (0.4) and should be first, followed by patient 2 (0.8)
    # and then patient 1's second entry (1.0).
    # For "kidney", patient 3 has the lowest score (0.4) and should be first, followed by patient 2 (0.7).
    # Arguments to add: patient, structure, Dice, Hausdorff distance (mm), mean distance (mm).
    d.add(p1, liver, 1.0, 1.0, 0.5)
    d.add(p1, liver, 0.4, 1.0, 0.4)
    d.add(p2, liver, 0.8, 1.0, 0.3)
    d.add(p2, kidney, 0.7, 1.0, 0.2)
    d.add(p3, kidney, 0.4, 1.0, 0.1)
    metrics_file = new_file("metrics_file.csv")
    d.to_csv(Path(metrics_file))
    # Sorting should be first by structure name alphabetically, then Dice with lowest scores first.
    assert_file_contents(
        metrics_file,
        "Patient,Structure,Dice,HausdorffDistance_mm,MeanDistance_mm\n"
        "Patient3,kidney,0.400,1.000,0.100\n"
        "Patient2,kidney,0.700,1.000,0.200\n"
        "Patient1,liver,0.400,1.000,0.400\n"
        "Patient2,liver,0.800,1.000,0.300\n"
        "Patient1,liver,1.000,1.000,0.500\n")
    aggregates_file = new_file(METRICS_AGGREGATES_FILE)
    d.save_aggregates_to_csv(Path(aggregates_file))
    # The aggregates file must match the reference copy under the test data folder exactly.
    assert_file_contents_match_exactly(
        Path(aggregates_file),
        full_ml_test_data_path() / METRICS_AGGREGATES_FILE)
    boxplot_per_structure(d.to_data_frame(),
                          column_name=MetricsFileColumns.DiceNumeric.value,
                          title="Dice score")
    boxplot1 = new_file("boxplot_2class.png")
    resize_and_save(5, 4, boxplot1)
    plt.clf()
    # Add four more structures for patient 1, giving 6 structures in total for the second plot.
    d.add(p1, "lung", 0.5, 2.0, 1.0)
    d.add(p1, "foo", 0.9, 2.0, 1.0)
    d.add(p1, "bar", 0.9, 2.0, 1.0)
    d.add(p1, "baz", 0.9, 2.0, 1.0)
    boxplot_per_structure(d.to_data_frame(),
                          column_name=MetricsFileColumns.DiceNumeric.value,
                          title="Dice score")
    boxplot2 = new_file("boxplot_6class.png")
    resize_and_save(5, 4, boxplot2)
    # Clear the figure again, as after the first plot, so no plot state leaks into other tests.
    plt.clf()
# Example #2
# 0
def segmentation_model_test_epoch(
        config: SegmentationModelBase,
        data_split: ModelExecutionMode,
        test_epoch: int,
        results_folder: Path,
        epoch_and_split: str,
        run_recovery: Optional[RunRecovery] = None) -> Optional[List[float]]:
    """
    The main testing loop for a given epoch. It loads the model and datasets, then proceeds to test the model.

    Returns a list with an entry for each image in the dataset. The entry is the average Dice score,
    where the average is taken across all non-background structures in the image.

    Side effects: writes the following into results_folder:
    - the dataset split as CSV
    - the run information
    - the inference results for every image
    - the per-patient metrics and aggregate metrics CSV files
    - a Dice box plot, if plotting is enabled in the config

    :param config: The arguments which specify all required information.
    :param data_split: Is the model evaluated on train, test, or validation set?
    :param test_epoch: The last trained epoch of the model.
    :param results_folder: The folder where to store the results. It is created (including parents) if missing.
    :param epoch_and_split: A string that should uniquely identify the epoch and the data split (train/val/test).
    :param run_recovery: Run recovery data if applicable.
    :raises TypeError: If the arguments are of the wrong type (raised by the helpers called here).
    :raises ValueError: When there are issues loading the model (raised by the helpers called here).
    :return: A list with the mean dice score (across all structures apart from background) for each image,
        or None if no inference pipeline could be created for test_epoch (no checkpoint for that epoch).
    """
    ml_util.set_random_seed(config.get_effective_random_seed(), "Model Testing")
    results_folder = Path(results_folder)
    # parents=True so that a nested results folder can be created in one step.
    results_folder.mkdir(parents=True, exist_ok=True)

    test_dataframe = config.get_dataset_splits()[data_split]
    test_csv_path = results_folder / STORED_CSV_FILE_NAMES[data_split]
    test_dataframe.to_csv(path_or_buf=test_csv_path, index=False)
    logging.info(f"Results directory: {results_folder}")
    logging.info(f"Starting evaluation of model {config.model_name} on {epoch_and_split}")

    # Write the dataset id and ground truth ids into the results folder
    store_run_information(results_folder, config.azure_dataset_id,
                          config.ground_truth_ids, config.image_channels)

    ds = config.get_torch_dataset_for_inference(data_split)

    inference_pipeline = create_inference_pipeline(config=config,
                                                   epoch=test_epoch,
                                                   run_recovery=run_recovery)

    if inference_pipeline is None:
        # This will happen if there is no checkpoint for the given epoch, in either the recovered run (if any) or
        # the current one.
        return None

    # for mypy
    assert isinstance(inference_pipeline, FullImageInferencePipelineBase)

    num_samples = len(ds)

    # Deploy the trained model on a set of images and store output arrays.
    for sample_index, sample_dict in enumerate(ds, 1):
        logging.info(f"Predicting for image {sample_index} of {num_samples}...")
        sample = Sample.from_dict(sample=sample_dict)
        inference_result = inference_pipeline.predict_and_post_process_whole_image(
            image_channels=sample.image,
            mask=sample.mask,
            patient_id=sample.patient_id,
            voxel_spacing_mm=sample.metadata.image_header.spacing)
        store_inference_results(inference_result=inference_result,
                                config=config,
                                results_folder=results_folder,
                                image_header=sample.metadata.image_header)

    # Evaluate model generated segmentation maps.
    # Pool rejects processes=0, so always request at least one worker; otherwise an empty dataset would crash here.
    num_workers = max(1, min(cpu_count(), num_samples))
    with Pool(processes=num_workers) as pool:
        pool_outputs = pool.map(
            partial(evaluate_model_predictions,
                    config=config,
                    dataset=ds,
                    results_folder=results_folder), range(num_samples))

    average_dice: List[float] = []
    metrics_writer = MetricsPerPatientWriter()
    for patient_metadata, metrics_for_patient in pool_outputs:
        # Add the Dice score for the foreground classes, stored in the default hue
        metrics.add_average_foreground_dice(metrics_for_patient)
        average_dice.append(metrics_for_patient.get_single_metric(MetricType.DICE))
        # Structure names does not include the background class (index 0)
        for structure_name in config.ground_truth_ids:
            dice_for_struct = metrics_for_patient.get_single_metric(
                MetricType.DICE, hue=structure_name)
            hd_for_struct = metrics_for_patient.get_single_metric(
                MetricType.HAUSDORFF_mm, hue=structure_name)
            md_for_struct = metrics_for_patient.get_single_metric(
                MetricType.MEAN_SURFACE_DIST_mm, hue=structure_name)
            metrics_writer.add(patient=str(patient_metadata.patient_id),
                               structure=structure_name,
                               dice=dice_for_struct,
                               hausdorff_distance_mm=hd_for_struct,
                               mean_distance_mm=md_for_struct)

    metrics_writer.to_csv(results_folder / METRICS_FILE_NAME)
    metrics_writer.save_aggregates_to_csv(results_folder / METRICS_AGGREGATES_FILE)
    if config.is_plotting_enabled:
        plt.figure()
        try:
            boxplot_per_structure(metrics_writer.to_data_frame(),
                                  column_name=MetricsFileColumns.DiceNumeric.value,
                                  title=f"Dice score for {epoch_and_split}")
            # The box plot file will be written to the output directory. AzureML will pick that up, and display
            # on the run overview page, without having to log to the run context.
            plotting.resize_and_save(5, 4, results_folder / BOXPLOT_FILE)
        finally:
            # Close the figure even if plotting fails, so that open figures do not accumulate across epochs.
            plt.close()
    logging.info(f"Finished evaluation of model {config.model_name} on {epoch_and_split}")

    return average_dice