Example #1
import pandas as pd
from pathlib import Path
from sklearn.metrics import roc_curve

# LoggingColumns, MetricsDict and Results come from the surrounding codebase.


def get_correct_and_misclassified_examples(val_metrics_csv: Path,
                                           test_metrics_csv: Path) -> Results:
    """
    Given the paths to the metrics files for the validation and test sets, get a list of true positives,
    false positives, false negatives and true negatives.
    The threshold for classification is obtained by looking at the validation file, and applied to the test set to get
    label predictions.
    """
    df_val = pd.read_csv(val_metrics_csv)

    if not df_val[LoggingColumns.Patient.value].is_unique:
        raise ValueError(
            f"Subject IDs should be unique, but found duplicate entries "
            f"in column {LoggingColumns.Patient.value} in the csv file.")

    # Derive the operating threshold from the ROC curve on the validation set.
    fpr, tpr, thresholds = roc_curve(df_val[LoggingColumns.Label.value],
                                     df_val[LoggingColumns.ModelOutput.value])
    optimal_idx = MetricsDict.get_optimal_idx(fpr=fpr, tpr=tpr)
    optimal_threshold = thresholds[optimal_idx]

    df_test = pd.read_csv(test_metrics_csv)

    if not df_test[LoggingColumns.Patient.value].is_unique:
        raise ValueError(
            f"Subject IDs should be unique, but found duplicate entries "
            f"in column {LoggingColumns.Patient.value} in the csv file.")

    df_test["predicted"] = df_test.apply(lambda x: int(x[
        LoggingColumns.ModelOutput.value] >= optimal_threshold),
                                         axis=1)

    # Partition the test set by predicted label vs. ground truth.
    true_positives = df_test[(df_test["predicted"] == 1)
                             & (df_test[LoggingColumns.Label.value] == 1)]
    false_positives = df_test[(df_test["predicted"] == 1)
                              & (df_test[LoggingColumns.Label.value] == 0)]
    false_negatives = df_test[(df_test["predicted"] == 0)
                              & (df_test[LoggingColumns.Label.value] == 1)]
    true_negatives = df_test[(df_test["predicted"] == 0)
                             & (df_test[LoggingColumns.Label.value] == 0)]

    return Results(true_positives=true_positives,
                   true_negatives=true_negatives,
                   false_positives=false_positives,
                   false_negatives=false_negatives)
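
A minimal usage sketch (the CSV paths below are placeholders; the files are assumed to contain the
subject, label and model-output columns that the function reads):

val_csv = Path("outputs/val_metrics.csv")
test_csv = Path("outputs/test_metrics.csv")

results = get_correct_and_misclassified_examples(val_metrics_csv=val_csv,
                                                 test_metrics_csv=test_csv)
# Each field of Results holds the matching test-set rows, e.g. the missed positive cases:
print(results.false_negatives[LoggingColumns.Patient.value].tolist())
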
Example #2
from pathlib import Path
from sklearn.metrics import auc, precision_recall_curve, recall_score, roc_auc_score, roc_curve

# ReportedMetrics, MetricsDict, get_results and binary_classification_accuracy come from the
# surrounding codebase.


def get_metric(val_metrics_csv: Path, test_metrics_csv: Path,
               metric: ReportedMetrics) -> float:
    """
    Given a csv file, read the predicted values and ground truth labels and return the specified metric.
    """
    results_val = get_results(val_metrics_csv)
    # Derive the operating threshold from the ROC curve on the validation set.
    fpr, tpr, thresholds = roc_curve(results_val.labels,
                                     results_val.model_outputs)
    optimal_idx = MetricsDict.get_optimal_idx(fpr=fpr, tpr=tpr)
    optimal_threshold = thresholds[optimal_idx]

    if metric is ReportedMetrics.OptimalThreshold:
        return optimal_threshold

    results_test = get_results(test_metrics_csv)

    if metric is ReportedMetrics.AUC_ROC:
        return roc_auc_score(results_test.labels, results_test.model_outputs)
    elif metric is ReportedMetrics.AUC_PR:
        precision, recall, _ = precision_recall_curve(
            results_test.labels, results_test.model_outputs)
        return auc(recall, precision)
    elif metric is ReportedMetrics.Accuracy:
        return binary_classification_accuracy(
            model_output=results_test.model_outputs,
            label=results_test.labels,
            threshold=optimal_threshold)
    elif metric is ReportedMetrics.FalsePositiveRate:
        # recall_score with pos_label=0 is the specificity (true negative rate), so FPR = 1 - TNR.
        tnr = recall_score(results_test.labels,
                           results_test.model_outputs >= optimal_threshold,
                           pos_label=0)
        return 1 - tnr
    elif metric is ReportedMetrics.FalseNegativeRate:
        # FNR = 1 - sensitivity (recall of the positive class).
        return 1 - recall_score(
            results_test.labels,
            results_test.model_outputs >= optimal_threshold)
    else:
        raise ValueError(f"Unknown metric: {metric}")