Example 1
import functools

import pytest
from sklearn import metrics as sk_metrics

# `metrics` (the streaming metric implementations under test) and
# `generate_test_cases` are assumed to be provided by the surrounding test
# module; their imports are not shown in this excerpt.

# Each entry pairs a streaming metric with its scikit-learn reference.
TEST_CASES = [
    (metrics.FBeta(beta=.5), functools.partial(sk_metrics.fbeta_score,
                                               beta=.5)),
    (metrics.MacroFBeta(beta=.5),
     functools.partial(sk_metrics.fbeta_score, beta=.5, average='macro')),
    (metrics.MicroFBeta(beta=.5),
     functools.partial(sk_metrics.fbeta_score, beta=.5, average='micro')),
    (metrics.WeightedFBeta(beta=.5),
     functools.partial(sk_metrics.fbeta_score, beta=.5, average='weighted')),
    (metrics.F1(), sk_metrics.f1_score),
    (metrics.MacroF1(), functools.partial(sk_metrics.f1_score,
                                          average='macro')),
    (metrics.MicroF1(), functools.partial(sk_metrics.f1_score,
                                          average='micro')),
    (metrics.WeightedF1(),
     functools.partial(sk_metrics.f1_score, average='weighted')),
    (metrics.MCC(), sk_metrics.matthews_corrcoef),
    (metrics.MAE(), sk_metrics.mean_absolute_error),
    (metrics.MSE(), sk_metrics.mean_squared_error),
]


@pytest.mark.parametrize('metric, sk_metric', TEST_CASES)
@pytest.mark.filterwarnings('ignore::RuntimeWarning')
@pytest.mark.filterwarnings(
    'ignore::sklearn.exceptions.UndefinedMetricWarning')
def test_metric(metric, sk_metric):

    # Check str works
    str(metric)

    for y_true, y_pred, sample_weights in generate_test_cases(metric=metric,
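The excerpt stops inside the loop over the generated cases. As a rough, illustrative sketch of what such a check usually amounts to (not the original test body), the streaming metric can be updated one sample at a time and its running value compared with the batch scikit-learn result on the prefix seen so far. The helper name check_against_sklearn and the get() accessor are assumptions made here; only update() actually appears in the excerpts (see Example 2 below).

import math


def check_against_sklearn(metric, sk_metric, y_true, y_pred, tol=1e-10):
    # Illustrative sketch only: assumes the streaming metric exposes
    # update() and get(), mirroring the update() call used in Example 2.
    for i, (yt, yp) in enumerate(zip(y_true, y_pred)):
        metric.update(yt, yp)                       # feed one sample
        prefix_true, prefix_pred = y_true[:i + 1], y_pred[:i + 1]
        assert math.isclose(metric.get(),           # running value (assumed API)
                            sk_metric(prefix_true, prefix_pred),
                            abs_tol=tol)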
Example 2
# Each case: (streaming metric, scikit-learn reference, y_true, y_pred).
@pytest.mark.parametrize(
    'metric, sk_metric, y_true, y_pred',
    [
     (metrics.MacroF1(),
      functools.partial(sk_metrics.f1_score, average='macro'),
      [0, 1, 2, 2, 2], [0, 0, 2, 2, 1]),
     (metrics.MicroF1(),
      functools.partial(sk_metrics.f1_score, average='micro'),
      [0, 1, 2, 2, 2], [0, 0, 2, 2, 1]),
     (metrics.LogLoss(), sk_metrics.log_loss,
      [True, False, False, True], [0.9, 0.1, 0.2, 0.65]),
     (metrics.CrossEntropy(),
      functools.partial(sk_metrics.log_loss, labels=[0, 1, 2]),
      [0, 1, 2, 2],
      [[0.29450637, 0.34216758, 0.36332605],
       [0.21290077, 0.32728332, 0.45981591],
       [0.42860913, 0.33380113, 0.23758974],
       [0.44941979, 0.32962558, 0.22095463]]),
     (metrics.MCC(), sk_metrics.matthews_corrcoef,
      [True, True, True, False], [True, False, True, True])])
@pytest.mark.filterwarnings('ignore::RuntimeWarning')
@pytest.mark.filterwarnings(
    'ignore::sklearn.exceptions.UndefinedMetricWarning')
def test_metric(metric, sk_metric, y_true, y_pred):

    for i, (yt, yp) in enumerate(zip(y_true, y_pred)):

        # Probability vectors (the CrossEntropy case) are fed to the
        # streaming metric as a {class_index: probability} dict.
        if isinstance(yp, list):
            yp = dict(enumerate(yp))

        metric.update(yt, yp)
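Note the dict(enumerate(...)) conversion above: for the CrossEntropy case the streaming metric receives each prediction as a mapping from class index to probability, whereas sk_metrics.log_loss consumes the same numbers as rows of a probability matrix. A tiny illustration with one row taken from the parametrization above:

row = [0.29450637, 0.34216758, 0.36332605]
print(dict(enumerate(row)))  # -> {0: 0.29450637, 1: 0.34216758, 2: 0.36332605}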