import inspect

from sklearn.metrics import (f1_score, matthews_corrcoef, mean_squared_error,
                             precision_score, r2_score, recall_score)


def binary_classification_metrics(preds, probs, labels, multilabel):
    acc = simple_accuracy(preds, labels).get("acc")
    roc_auc_score = roc_auc(probs=probs, labels=labels, multilabel=multilabel)
    roc_auc_score_weighted = roc_auc(probs=probs,
                                     labels=labels,
                                     average="weighted",
                                     multilabel=multilabel)
    f1macro = f1_score(y_true=labels, y_pred=preds, average="macro")
    f1micro = f1_score(y_true=labels, y_pred=preds, average="micro")

    if not multilabel:
        f1_0 = f1_score(y_true=labels, y_pred=preds, pos_label="0")
        f1_1 = f1_score(y_true=labels, y_pred=preds, pos_label="1")
        mcc = matthews_corrcoef(labels, preds)
    else:
        f1_0, f1_1, mcc = None, None, None
    return {
        "acc": acc,
        "roc_auc": roc_auc_score,
        "roc_auc_weighted": roc_auc_score_weighted,
        "f1_macro": f1macro,
        "f1_micro": f1micro,
        "f1_0": f1_0,
        "f1_1": f1_1,
        "mcc": mcc
    }
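
# The examples on this page rely on two helpers, `simple_accuracy` and `roc_auc`,
# whose definitions are not shown here. Below is a minimal sketch of what they
# might look like, assuming they are thin wrappers around scikit-learn; only the
# names and call signatures are taken from the code above, the rest is an assumption.
import numpy as np
from sklearn.metrics import roc_auc_score


def simple_accuracy(preds, labels):
    # Exact-match accuracy, returned as a dict so callers can use .get("acc").
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    return {"acc": float((preds == labels).mean())}


def roc_auc(probs, labels, average="macro", multilabel=False, multi_class="raise"):
    # Wrapper around sklearn's roc_auc_score; in the plain binary case only the
    # probability of the positive class is forwarded.
    probs = np.asarray(probs)
    if not multilabel and probs.ndim == 2 and probs.shape[1] == 2:
        probs = probs[:, 1]
    return roc_auc_score(y_true=labels, y_score=probs,
                         average=average, multi_class=multi_class)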
Example #2
def mymetrics(preds, labels):
    acc = simple_accuracy(preds, labels)
    f1macro = f1_score(y_true=labels, y_pred=preds, average="macro")
    f1micro = f1_score(y_true=labels, y_pred=preds, average="micro")
    # Log to the attached Azure ML run; `aml_run` is expected to come from the
    # enclosing script.
    try:
        aml_run.log('acc', acc.get('acc'))
        aml_run.log('f1macro', f1macro)
        aml_run.log('f1micro', f1micro)
    except Exception:
        # Don't let logging failures break the evaluation itself.
        pass
    return {"acc": acc, "f1_macro": f1macro, "f1_micro": f1micro}


# Variant of mymetrics for a binary OTHER / OFFENSE labelling task.
def mymetrics(preds, labels):
    acc = simple_accuracy(preds, labels)
    f1other = f1_score(y_true=labels, y_pred=preds, pos_label="OTHER")
    f1offense = f1_score(y_true=labels, y_pred=preds, pos_label="OFFENSE")
    f1macro = f1_score(y_true=labels, y_pred=preds, average="macro")
    f1micro = f1_score(y_true=labels, y_pred=preds, average="micro")
    return {
        "acc": acc,
        "f1_other": f1other,
        "f1_offense": f1offense,
        "f1_macro": f1macro,
        "f1_micro": f1micro
    }
def multiclass_classification_metrics(preds, probs, labels):
    acc = simple_accuracy(preds, labels).get("acc")
    roc_auc_score = roc_auc(probs=probs, labels=labels, multi_class='ovo')
    roc_auc_score_weighted = roc_auc(probs=probs,
                                     labels=labels,
                                     average='weighted',
                                     multi_class='ovo')
    f1macro = f1_score(y_true=labels, y_pred=preds, average="macro")
    f1micro = f1_score(y_true=labels, y_pred=preds, average="micro")
    mcc = matthews_corrcoef(labels, preds)

    return {
        "acc": acc,
        "roc_auc": roc_auc_score,
        "roc_auc_weighted": roc_auc_score_weighted,
        "f1_macro": f1macro,
        "f1_micro": f1micro,
        "mcc": mcc
    }
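
# Hypothetical toy call, assuming the simple_accuracy / roc_auc sketch above;
# the data below is made up purely for illustration.
labels_bin = np.array(["0", "1", "1", "0"])
preds_bin = np.array(["0", "1", "0", "0"])
probs_bin = np.array([[0.9, 0.1],
                      [0.2, 0.8],
                      [0.6, 0.4],
                      [0.7, 0.3]])
print(binary_classification_metrics(preds_bin, probs_bin, labels_bin, multilabel=False))

labels_mc = np.array(["a", "b", "c", "a"])
preds_mc = np.array(["a", "b", "b", "a"])
probs_mc = np.array([[0.8, 0.1, 0.1],
                     [0.1, 0.7, 0.2],
                     [0.2, 0.5, 0.3],
                     [0.6, 0.2, 0.2]])
print(multiclass_classification_metrics(preds_mc, probs_mc, labels_mc))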
Example #5
def compute_metrics(metric, preds, probs, labels, multilabel):
    assert len(preds) == len(labels)
    if metric == "mcc":
        return {"mcc": matthews_corrcoef(labels, preds)}
    elif metric == "acc":
        return simple_accuracy(preds, labels)
    elif metric == "acc_f1":
        return acc_and_f1(preds, labels)
    elif metric == "pear_spear":
        return pearson_and_spearman(preds, labels)
    elif metric == "seq_f1":
        return {"seq_f1": ner_f1_score(labels, preds)}
    elif metric == "f1_macro":
        return f1_macro(preds, labels)
    elif metric == "squad":
        return squad(preds, labels)
    elif metric == "mse":
        return {"mse": mean_squared_error(labels, preds)}
    elif metric == "r2":
        return {"r2": r2_score(labels, preds)}
    elif metric == "top_n_accuracy":
        return {"top_n_accuracy": top_n_accuracy(preds, labels)}
    elif metric == "text_similarity_metric":
        return text_similarity_metric(preds, labels)
    elif metric == "roc_auc":
        return {"roc_auc": roc_auc(probs, labels, multilabel=multilabel)}
    elif metric in registered_metrics:
        metric_func = registered_metrics[metric]

        metric_args = inspect.getfullargspec(metric_func).args
        if "probs" and "multilabel" in metric_args:
            return metric_func(preds, probs, labels, multilabel)
        elif "probs" in metric_args:
            return metric_func(preds, probs, labels)
        elif "multilabel" in metric_args:
            return metric_func(preds, labels, multilabel)
        else:
            return metric_func(preds, labels)
    else:
        raise KeyError(metric)
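
# For metric names that are not handled explicitly above, compute_metrics falls
# back to the `registered_metrics` lookup. A hypothetical registration, assuming
# `registered_metrics` is a plain module-level dict; the metric name and helper
# below are made up for illustration.
from sklearn.metrics import balanced_accuracy_score


def my_balanced_acc(preds, labels):
    # Follows the (preds, labels) convention, so compute_metrics will call it
    # without probs / multilabel.
    return {"balanced_acc": balanced_accuracy_score(y_true=labels, y_pred=preds)}


registered_metrics["balanced_acc"] = my_balanced_acc
# compute_metrics("balanced_acc", preds, probs=None, labels=labels, multilabel=False)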


def evaluation_metrics(preds, labels):
    # `current_info_need`, `bert_model`, `num_epochs` and `condition` are expected
    # to be defined in the enclosing scope (e.g. the experiment loop this function
    # was originally nested in).
    acc = simple_accuracy(preds, labels).get("acc")
    f1other = f1_score(y_true=labels, y_pred=preds, pos_label="Other")
    f1infoneed = f1_score(y_true=labels, y_pred=preds, pos_label=current_info_need)
    recall_infoneed = recall_score(y_true=labels, y_pred=preds, pos_label=current_info_need)
    precision_infoneed = precision_score(y_true=labels, y_pred=preds, pos_label=current_info_need)
    recall_other = recall_score(y_true=labels, y_pred=preds, pos_label="Other")
    precision_other = precision_score(y_true=labels, y_pred=preds, pos_label="Other")
    recall_macro = recall_score(y_true=labels, y_pred=preds, average="macro")
    precision_macro = precision_score(y_true=labels, y_pred=preds, average="macro")
    recall_micro = recall_score(y_true=labels, y_pred=preds, average="micro")
    precision_micro = precision_score(y_true=labels, y_pred=preds, average="micro")
    recall_weighted = recall_score(y_true=labels, y_pred=preds, average="weighted")
    precision_weighted = precision_score(y_true=labels, y_pred=preds, average="weighted")
    f1macro = f1_score(y_true=labels, y_pred=preds, average="macro")
    f1micro = f1_score(y_true=labels, y_pred=preds, average="micro")
    mcc = matthews_corrcoef(labels, preds)
    f1weighted = f1_score(y_true=labels, y_pred=preds, average="weighted")

    return {
        "info_need": current_info_need,
        "model": bert_model,
        "num_epochs": num_epochs,
        "condition": condition,
        "acc": acc,
        "f1_other": f1other,
        "f1_infoneed": f1infoneed,
        "precision_infoneed": precision_infoneed,
        "recall_infoneed": recall_infoneed,
        "recall_other": recall_other,
        "precision_other": precision_other,
        "recall_macro": recall_macro,
        "precision_macro": precision_macro,
        "recall_micro": recall_micro,
        "precision_micro": precision_micro,
        "recall_weighted": recall_weighted,
        "precision_weighted": precision_weighted,
        "f1_weighted": f1weighted,
        "f1_macro": f1macro,
        "f1_micro": f1micro,
        "mcc": mcc
    }
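
# Hypothetical wiring for evaluation_metrics: the free variables it uses come
# from the surrounding script, so something like the following (all values made
# up for illustration) would have to exist before calling it.
current_info_need = "Factual"
bert_model = "bert-base-uncased"
num_epochs = 3
condition = "baseline"

example_labels = ["Other", "Factual", "Factual", "Other"]
example_preds = ["Other", "Factual", "Other", "Other"]
print(evaluation_metrics(example_preds, example_labels))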