def test_functions_with_invalid_csv(
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Check that the report helpers reject a metrics CSV with duplicated subjects.

    A copy of the bundled test metrics CSV is written to the test output folder
    and one extra row is appended that is intended to duplicate an existing
    subject. Every helper that reads a metrics CSV must then raise a ValueError
    whose message says that subject IDs should be unique. For
    get_correct_and_misclassified_examples this is checked with the invalid
    file in either argument position.

    :param test_output_dirs: Fixture providing a scratch folder (``root_dir``)
        where the invalid CSV is written.
    """
    reports_folder = Path(__file__).parent
    test_metrics_file = reports_folder / "test_metrics_classification.csv"
    val_metrics_file = reports_folder / "val_metrics_classification.csv"
    invalid_metrics_file = Path(
        test_output_dirs.root_dir) / "invalid_metrics_classification.csv"
    shutil.copyfile(test_metrics_file, invalid_metrics_file)
    # Duplicate a subject by appending a second row for it.
    # NOTE(review): assumes the copied CSV ends with a newline, otherwise this
    # row is glued onto the last existing line - TODO confirm.
    with open(invalid_metrics_file, "a") as file:
        file.write(f"{MetricsDict.DEFAULT_HUE_KEY},0,5,1.0,1,-1,Test")

    # `match` is applied with re.search to str(exception value), i.e. the real
    # message. str(ExceptionInfo) is version-dependent and may be a truncated,
    # escaped repr. The expected text contains no regex metacharacters.
    expected_message = "Subject IDs should be unique"

    with pytest.raises(ValueError, match=expected_message):
        get_labels_and_predictions(invalid_metrics_file,
                                   MetricsDict.DEFAULT_HUE_KEY)

    # The duplicate must be detected whichever of the two CSVs contains it.
    with pytest.raises(ValueError, match=expected_message):
        get_correct_and_misclassified_examples(invalid_metrics_file,
                                               test_metrics_file,
                                               MetricsDict.DEFAULT_HUE_KEY)

    with pytest.raises(ValueError, match=expected_message):
        get_correct_and_misclassified_examples(val_metrics_file,
                                               invalid_metrics_file,
                                               MetricsDict.DEFAULT_HUE_KEY)
# Example No. 2
0
def test_get_correct_and_misclassified_examples() -> None:
    """
    Check how subjects in the bundled val/test metrics CSVs are bucketed.

    The val and test metrics CSVs next to this file are passed to
    get_correct_and_misclassified_examples. For each confusion-matrix bucket
    in the result, the test asserts that a known set of subject IDs is present
    in that bucket's patient column. This is a subset check: a bucket may also
    contain other subjects.
    """
    reports_folder = Path(__file__).parent
    test_metrics_file = reports_folder / "test_metrics_classification.csv"
    val_metrics_file = reports_folder / "val_metrics_classification.csv"

    results = get_correct_and_misclassified_examples(
        val_metrics_csv=val_metrics_file, test_metrics_csv=test_metrics_file)

    patient_column = LoggingColumns.Patient.value
    expected_buckets = [
        ("true_positives", results.true_positives, {3, 4, 5}),
        ("true_negatives", results.true_negatives, {6, 7, 8}),
        ("false_positives", results.false_positives, {9, 10, 11}),
        ("false_negatives", results.false_negatives, {0, 1, 2}),
    ]
    for bucket_name, examples, expected_subjects in expected_buckets:
        # Read the column directly rather than via iterrows(), which builds a
        # Series per row and does not preserve column dtypes.
        actual_subjects = set(examples[patient_column])
        missing = expected_subjects - actual_subjects
        assert not missing, \
            f"{bucket_name} is missing subjects {sorted(missing)}"