예제 #1
0
def test_str_repr_del(
    sample_labels,
    sample_predictions,
    population_size,
    population_flagged_count,
    confidence_level,
    exact_precision,
):
    """Check that str and repr agree, then drop the objects with ``del``.

    Applies to both the ClassificationConfidenceIntervals instance and the
    first interval object returned by its ``get_cis`` method.
    """
    intervals = ClassificationConfidenceIntervals(
        sample_labels=sample_labels,
        sample_predictions=sample_predictions,
        population_size=population_size,
        population_flagged_count=population_flagged_count,
        confidence_level=confidence_level,
        exact_precision=exact_precision,
    )
    # Four-way unpack so the test still fails if get_cis changes arity.
    pos_rate_cis, _, _, _ = intervals.get_cis()

    str_text = intervals.__str__()
    repr_text = intervals.__repr__()
    assert str_text == repr_text
    del intervals

    pos_str = pos_rate_cis.__str__()
    pos_repr = pos_rate_cis.__repr__()
    assert pos_str == pos_repr
    del pos_rate_cis
예제 #2
0
def test_get_cis(
    sample_labels,
    sample_predictions,
    population_size,
    population_flagged_count,
    confidence_level,
    exact_precision,
    plot_filename,
):
    """Unit test for get_cis method in ClassificationConfidenceIntervals class.

    Calls ``get_cis`` with ``n_iters=100`` and the fixture-supplied
    ``plot_filename``. Then checks that it yields four interval objects,
    matching the pos_rate / ppv / npv / recall unpacking used elsewhere,
    and that each exposes a ``tnorm_ci`` attribute.
    """
    classification_confidence_intervals = ClassificationConfidenceIntervals(
        sample_labels=sample_labels,
        sample_predictions=sample_predictions,
        population_size=population_size,
        population_flagged_count=population_flagged_count,
        confidence_level=confidence_level,
        exact_precision=exact_precision,
    )
    # NOTE(review): n_iters=100 presumably keeps the resampling cheap for a
    # unit test — confirm against get_cis defaults.
    cis = classification_confidence_intervals.get_cis(
        n_iters=100, plot_filename=plot_filename
    )

    # Unpacking enforces that exactly four interval objects are returned.
    pos_rate_cis, ppv_cis, npv_cis, recall_cis = cis
    for interval in (pos_rate_cis, ppv_cis, npv_cis, recall_cis):
        assert hasattr(interval, "tnorm_ci")
예제 #3
0
def test_bad_get_cis(
    sample_labels,
    sample_predictions,
    population_size,
    population_flagged_count,
    confidence_level,
    exact_precision,
    n_iters,
    plot_filename,
):
    """Check that get_cis raises when given the parametrized bad arguments.

    The instance itself is built outside the ``pytest.raises`` block, so
    only the ``get_cis`` call is expected to fail.
    """
    model = ClassificationConfidenceIntervals(
        sample_labels=sample_labels,
        sample_predictions=sample_predictions,
        population_size=population_size,
        population_flagged_count=population_flagged_count,
        confidence_level=confidence_level,
        exact_precision=exact_precision,
    )
    bad_call_kwargs = {"n_iters": n_iters, "plot_filename": plot_filename}
    with pytest.raises(Exception):
        model.get_cis(**bad_call_kwargs)
예제 #4
0
def test_overwrite(
    sample_labels,
    sample_predictions,
    population_size,
    population_flagged_count,
    confidence_level,
    exact_precision,
):
    """Check the read_only_properties wrapper.

    Assigning a new value to ``tnorm_ci`` on the first interval object
    returned by ``get_cis`` must raise.
    """
    ctor_kwargs = dict(
        sample_labels=sample_labels,
        sample_predictions=sample_predictions,
        population_size=population_size,
        population_flagged_count=population_flagged_count,
        confidence_level=confidence_level,
        exact_precision=exact_precision,
    )
    intervals = ClassificationConfidenceIntervals(**ctor_kwargs)

    # Only the first interval object is exercised; the rest are ignored.
    pos_rate_cis, _, _, _ = intervals.get_cis()

    with pytest.raises(Exception):
        pos_rate_cis.tnorm_ci = (0.5, 0.8)
예제 #5
0
def test_init(
    sample_labels,
    sample_predictions,
    population_size,
    population_flagged_count,
    confidence_level,
    exact_precision,
):
    """Check that ClassificationConfidenceIntervals builds from valid fixture inputs."""
    init_kwargs = {
        "sample_labels": sample_labels,
        "sample_predictions": sample_predictions,
        "population_size": population_size,
        "population_flagged_count": population_flagged_count,
        "confidence_level": confidence_level,
        "exact_precision": exact_precision,
    }
    # Construction must simply not raise; the instance is discarded.
    ClassificationConfidenceIntervals(**init_kwargs)
예제 #6
0
def test_bad_init(
    sample_labels,
    sample_predictions,
    population_size,
    population_flagged_count,
    confidence_level,
    exact_precision,
):
    """Check that ClassificationConfidenceIntervals rejects the parametrized bad inputs."""
    arguments = dict(
        sample_labels=sample_labels,
        sample_predictions=sample_predictions,
        population_size=population_size,
        population_flagged_count=population_flagged_count,
        confidence_level=confidence_level,
        exact_precision=exact_precision,
    )
    with pytest.raises(Exception):
        ClassificationConfidenceIntervals(**arguments)
예제 #7
0
def test_iris():
    """Run ClassificationConfidenceIntervals one-vs-rest on each iris class.

    Predictions come from ``make_prediction`` on the iris dataset with
    ``binary=False``. Labels are integer-encoded and expanded to one
    indicator column per class. Each class's (true, predicted) column pair
    is then passed to ``ClassificationConfidenceIntervals`` with a fixed
    population size and flagged count. The test checks that ``get_cis``
    succeeds for every class and that one ``tnorm_ci`` per metric is
    collected for each class.
    """
    iris = datasets.load_iris()
    y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)

    # Fit on every label seen in either array: fitting on y_true alone makes
    # transform(y_pred) raise if a class is predicted but never observed.
    le = LabelEncoder()
    le.fit(np.union1d(y_true, y_pred))
    y_true = le.transform(y_true)
    y_pred = le.transform(y_pred)
    n_classes = len(le.classes_)

    # Pin the one-hot categories to all encoded classes so both matrices have
    # the same columns in the same order, whichever classes each array holds.
    enc = OneHotEncoder(categories=[np.arange(n_classes)])
    y_true_matrix = enc.fit_transform(y_true.reshape(-1, 1)).toarray()
    y_pred_matrix = enc.transform(y_pred.reshape(-1, 1)).toarray()

    confidence_level = 0.95
    pos_rate_cis, ppv_cis, npv_cis, recall_cis = [], [], [], []
    for j in range(n_classes):
        class_ci = ClassificationConfidenceIntervals(
            sample_labels=y_true_matrix[:, j],
            sample_predictions=y_pred_matrix[:, j],
            population_size=1000000,
            population_flagged_count=50000,
            confidence_level=confidence_level,
        ).get_cis()
        pos_rate_cis.append(class_ci[0].tnorm_ci)
        ppv_cis.append(class_ci[1].tnorm_ci)
        npv_cis.append(class_ci[2].tnorm_ci)
        recall_cis.append(class_ci[3].tnorm_ci)

    # One interval per class must have been collected for each metric.
    for collected in (pos_rate_cis, ppv_cis, npv_cis, recall_cis):
        assert len(collected) == n_classes