Exemplo n.º 1
0
def test_roc_auc(caplog):
    """Check roc_auc_scorer across gold-label and probability encodings.

    Hard 0/1 gold labels and per-class gold probabilities are each scored
    against two-column probabilities and against single-column (unary)
    probabilities; every combination is expected to give the same AUC.
    A gold vector containing only one class is expected to yield ``nan``.
    """
    caplog.set_level(logging.INFO)

    # Hard labels, one per row, as a column vector.
    hard_golds = np.array([[1], [0], [1], [0], [1], [0]])
    # Per-class gold probabilities; the higher column matches ``hard_golds``.
    soft_golds = np.array(
        [[0.4, 0.6], [0.9, 0.1], [0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.6, 0.4]]
    )
    # Predicted probabilities for classes 0 and 1.
    binary_probs = np.array(
        [[0.2, 0.8], [0.4, 0.6], [0.1, 0.9], [0.3, 0.7], [0.3, 0.7], [0.8, 0.2]]
    )
    # The same positive-class probabilities as a single column.
    unary_probs = np.array([[0.8], [0.6], [0.9], [0.7], [0.7], [0.2]])

    cases = (
        (hard_golds, binary_probs),
        (soft_golds, binary_probs),
        (hard_golds, unary_probs),
        (soft_golds, unary_probs),
    )
    for gold_arr, prob_arr in cases:
        result = roc_auc_scorer(gold_arr, prob_arr, None)
        assert isequal(result, {"roc_auc": 0.9444444444444444})

    # Only one class present in the gold labels: AUC is undefined.
    single_class_golds = np.ones(binary_probs.shape[0], dtype=int)

    result = roc_auc_scorer(single_class_golds, binary_probs, None)
    assert isequal(result, {"roc_auc": float("nan")})
Exemplo n.º 2
0
def test_roc_auc(caplog):
    """Unit test of roc_auc_scorer.

    Uses the module-level fixtures ``GOLDS``, ``PROB_GOLDS``, ``PROBS`` and
    ``UNARY_PROBS`` (defined outside this function). It covers hard vs.
    probabilistic gold labels, two-column vs. unary probabilities, and
    column-vector inputs. It also checks that a gold vector containing a
    single class yields ``nan``.

    Expected values are pinned to the last digit. The unary-probability cases
    differ from the two-column cases in the final digit.
    """
    caplog.set_level(logging.INFO)

    metric_dict = roc_auc_scorer(GOLDS, PROBS, None)

    assert isequal(metric_dict, {"roc_auc": 0.8333333333333333})

    metric_dict = roc_auc_scorer(PROB_GOLDS, PROBS, None)

    assert isequal(metric_dict, {"roc_auc": 0.8333333333333333})

    metric_dict = roc_auc_scorer(GOLDS, UNARY_PROBS, None)

    assert isequal(metric_dict, {"roc_auc": 0.8333333333333334})

    # Same data reshaped into (n, 1) column vectors.
    metric_dict = roc_auc_scorer(
        GOLDS.reshape(GOLDS.shape[0], 1),
        UNARY_PROBS.reshape(UNARY_PROBS.shape[0], 1),
        None,
    )

    assert isequal(metric_dict, {"roc_auc": 0.8333333333333334})

    metric_dict = roc_auc_scorer(PROB_GOLDS, UNARY_PROBS, None)

    assert isequal(metric_dict, {"roc_auc": 0.8333333333333334})

    # Derive the length from the fixture instead of hard-coding 6 elements.
    # If the lengths ever diverged, the scorer would fail for an unrelated
    # reason, and that failure could also come back as ``nan`` and let the
    # assertion below pass spuriously.
    # NOTE(review): this assumes roc_auc_scorer reports errors as nan; confirm.
    all_ones = np.ones(len(PROBS), dtype=int)

    metric_dict = roc_auc_scorer(all_ones, PROBS, None)
    assert isequal(metric_dict, {"roc_auc": float("nan")})