Example No. 1
def test_invalid_number_of_cv_files() -> None:
    """
    Test that an error is raised if the expected number of cross validation folds
    is not equal to the number of result files provided.
    """
    files, plotting_config = load_result_files_for_classification()
    plotting_config.number_of_cross_validation_splits = 4
    print(f"Writing aggregated metrics to {plotting_config.outputs_directory}")
    root_folder = Path(plotting_config.outputs_directory)
    with pytest.raises(ValueError):
        plot_cross_validation_from_files(
            OfflineCrossvalConfigAndFiles(config=plotting_config, files=files),
            root_folder=root_folder)
Example No. 2
def test_check_result_file_counts() -> None:
    """
    More tests on the function that checks the number of files of each ModelExecutionMode.
    """
    val_files, plotting_config = load_result_files_for_classification()
    # This test assumes that the loaded val_files all have mode Val
    assert all(file.execution_mode == ModelExecutionMode.VAL for file in val_files)
    plotting_config.number_of_cross_validation_splits = len(val_files)
    # Check that when just the Val files are present, the check does not throw
    config_and_files1 = OfflineCrossvalConfigAndFiles(config=plotting_config, files=val_files)
    check_result_file_counts(config_and_files1)
    # Check that when we add the same number of Test files, the check does not throw
    test_files = [RunResultFiles(execution_mode=ModelExecutionMode.TEST,
                                 metrics_file=file.metrics_file,
                                 dataset_csv_file=file.dataset_csv_file,
                                 run_recovery_id=file.run_recovery_id,
                                 split_index=file.split_index) for file in val_files]
    config_and_files2 = OfflineCrossvalConfigAndFiles(config=plotting_config, files=val_files + test_files)
    check_result_file_counts(config_and_files2)
    # Check that when we have the same number of files as the number of splits, but they are from a mixture
    # of modes, the check does throw
    config_and_files3 = OfflineCrossvalConfigAndFiles(config=plotting_config, files=val_files[:1] + test_files[1:])
    with pytest.raises(ValueError):
        check_result_file_counts(config_and_files3)
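For reference, below is a minimal sketch of the per-mode count check that this test exercises. It assumes check_result_file_counts groups result files by execution_mode and requires each mode present to contribute exactly one file per cross validation split; the name check_counts_sketch and its body are illustrative, not code from the repository.

from collections import Counter

def check_counts_sketch(config_and_files: OfflineCrossvalConfigAndFiles) -> None:
    # Illustrative only: count result files per execution mode and require every mode
    # that is present to supply exactly one file per cross validation split.
    expected = config_and_files.config.number_of_cross_validation_splits
    counts = Counter(file.execution_mode for file in config_and_files.files)
    for mode, count in counts.items():
        if count != expected:
            raise ValueError(f"Expected {expected} files for mode {mode}, got {count}")

Under this logic, config_and_files3 above raises because the files are split across Val and Test, so each mode contributes fewer files than there are splits, matching the pytest.raises expectation.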
Example No. 3
def test_aggregate_files_with_prediction_target(test_output_dirs: TestOutputDirectories) -> None:
    """
    For multi-week RNNs that predict at multiple sequence points: Test that the dataframes
    including the prediction_target column can be aggregated.
    """
    plotting_config = PlotCrossValidationConfig(
        run_recovery_id="foo",
        epoch=1,
        model_category=ModelCategory.Classification
    )
    files = create_run_result_file_list(plotting_config, "multi_label_sequence_in_crossval")

    root_folder = Path(test_output_dirs.root_dir)
    print(f"Writing result files to {root_folder}")
    plot_cross_validation_from_files(OfflineCrossvalConfigAndFiles(config=plotting_config, files=files),
                                     root_folder=root_folder)
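To make the scenario concrete, the per-split dataframes being aggregated here carry a prediction_target column identifying the sequence position the RNN predicts at. The small dataframe below is a made-up illustration of that shape, not data from the test fixture.

import pandas as pd

# Hypothetical per-split metrics dataframe: prediction_target distinguishes the
# sequence positions, so aggregation has to treat it as a grouping key.
example_metrics = pd.DataFrame({
    "prediction_target": ["01", "02", "01", "02"],
    "model_output": [0.7, 0.4, 0.6, 0.3],
    "label": [1.0, 0.0, 1.0, 0.0],
    "subject": [1, 1, 2, 2],
})
print(example_metrics.groupby("prediction_target").size())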
Example No. 4
def _test_result_aggregation_for_classification(
        files: List[RunResultFiles],
        plotting_config: PlotCrossValidationConfig,
        expected_aggregate_metrics: List[str],
        expected_epochs: Set[int]) -> None:
    """
    Test how metrics are aggregated for cross validation runs on classification models.
    """
    print(f"Writing aggregated metrics to {plotting_config.outputs_directory}")
    root_folder = plotting_config.outputs_directory
    plot_cross_validation_from_files(
        OfflineCrossvalConfigAndFiles(config=plotting_config, files=files),
        root_folder=root_folder)
    aggregates_file = root_folder / METRICS_AGGREGATES_FILE
    actual_aggregates = aggregates_file.read_text().splitlines()
    header_line = "prediction_target,area_under_roc_curve,area_under_pr_curve,accuracy_at_optimal_threshold," \
                  "false_positive_rate_at_optimal_threshold,false_negative_rate_at_optimal_threshold," \
                  "optimal_threshold,cross_entropy,accuracy_at_threshold_05,subject_count,data_split,epoch"
    expected_aggregate_metrics = [header_line] + expected_aggregate_metrics
    assert len(actual_aggregates) == len(expected_aggregate_metrics), "Number of lines in aggregated metrics file"
    for i, (actual, expected) in enumerate(zip(actual_aggregates, expected_aggregate_metrics)):
        assert actual == expected, f"Mismatch in aggregate metrics at index {i}"
    per_subject_metrics = pd.read_csv(root_folder / FULL_METRICS_DATAFRAME_FILE)
    assert LoggingColumns.Label.value in per_subject_metrics
    assert set(per_subject_metrics[LoggingColumns.Label.value].unique()) == {0.0, 1.0}
    assert LoggingColumns.ModelOutput.value in per_subject_metrics
    assert LoggingColumns.Patient.value in per_subject_metrics
    assert len(per_subject_metrics[LoggingColumns.Patient.value].unique()) == 356
    assert LoggingColumns.Epoch.value in per_subject_metrics
    assert set(per_subject_metrics[LoggingColumns.Epoch.value].unique()) == expected_epochs
    assert LoggingColumns.CrossValidationSplitIndex.value in per_subject_metrics
    assert set(per_subject_metrics[LoggingColumns.CrossValidationSplitIndex.value].unique()) == {0, 1}
    assert LoggingColumns.DataSplit.value in per_subject_metrics
    assert list(per_subject_metrics[LoggingColumns.DataSplit.value].unique()) == ["Val"]
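The helper above is parameterized rather than being a test itself: a caller supplies the expected aggregate CSV lines and the set of epochs. A hypothetical invocation could look like the following, where the aggregate line and epoch set are placeholder values for illustration, not actual results from the test data.

def test_result_aggregation_example() -> None:
    # Hypothetical caller of _test_result_aggregation_for_classification; the expected
    # aggregate line below is a placeholder row, not a real metrics result.
    files, plotting_config = load_result_files_for_classification()
    placeholder_line = "Default,0.80,0.70,0.75,0.20,0.25,0.50,0.60,0.70,100,Val,1"
    _test_result_aggregation_for_classification(
        files=files,
        plotting_config=plotting_config,
        expected_aggregate_metrics=[placeholder_line],
        expected_epochs={1})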