示例#1
0
def test_accuracy_runner():
    """Check that the runner reports the accuracy from ``accuracy_function``.

    ``MockAlgo`` predicts ``1.0`` for the first fifth of the rows and ``0.0``
    for the rest. Against the labels of the ``'zeros'`` dataset (presumably
    all zero, as the 0.80 assertion implies), exactly 20% of the predictions
    are wrong, so the expected accuracy is 0.80.
    """
    class MockAlgo:
        def fit(self, X, y):
            # Nothing to learn: predictions are fixed by construction.
            return

        def predict(self, X):
            nr = X.shape[0]
            res = np.zeros(nr)
            # Mark the first 20% of rows as 1.0 so they mismatch zero labels.
            res[: nr // 5] = 1.0
            return res

    pair = algorithms.AlgorithmPair(
        MockAlgo,
        MockAlgo,
        shared_args={},
        name="Mock",
        accuracy_function=metrics.accuracy_score,
    )

    runner = AccuracyComparisonRunner([20], [5],
                                      dataset_name='zeros',
                                      test_fraction=0.20)
    results = runner.run(pair)[0]

    assert results["cuml_acc"] == pytest.approx(0.80)
示例#2
0
def test_fil_input_types(input_type):
    """Run the FIL algorithm pair through the runner with ``input_type`` data.

    Args:
        input_type: forwarded to ``AccuracyComparisonRunner``; presumably
            supplied by a pytest parametrization outside this view — TODO
            confirm.
    """
    # Check the optional dependency before doing any work, so a missing
    # xgboost is reported as an expected failure rather than an error from
    # the algorithm lookup.
    if not has_xgboost():
        pytest.xfail("xgboost is not installed")

    pair = algorithms.algorithm_by_name('FIL')

    runner = AccuracyComparisonRunner(
        [20], [5], dataset_name='classification', test_fraction=0.5,
        input_type=input_type)
    # run_cpu=False presumably skips the CPU reference implementation, so
    # only the cuML accuracy is checked.
    results = runner.run(pair, run_cpu=False)[0]
    assert results["cuml_acc"] is not None
示例#3
0
def test_real_algos_runner(algo_name):
    """Smoke-test the accuracy runner on a real, named algorithm pair.

    Args:
        algo_name: name passed to ``algorithms.algorithm_by_name``;
            presumably supplied by a pytest parametrization outside this
            view — TODO confirm.
    """
    # UMAP and FIL rely on optional packages. Check them before looking up
    # the algorithm so a missing package yields an expected failure instead
    # of an error.
    missing_optional_dep = (
        (algo_name == 'UMAP' and not has_umap())
        or (algo_name == 'FIL' and not has_xgboost())
    )
    if missing_optional_dep:
        pytest.xfail(f"optional dependency for {algo_name} is not installed")

    pair = algorithms.algorithm_by_name(algo_name)

    runner = AccuracyComparisonRunner([20], [5],
                                      dataset_name='classification',
                                      test_fraction=0.20)
    results = runner.run(pair)[0]
    assert results["cuml_acc"] is not None
示例#4
0
def test_multi_reps():
    """Verify that ``n_reps`` controls how many times each side is fitted.

    The same counting class serves as both halves of the algorithm pair, so
    every ``fit`` call from either side increments one shared class-level
    counter.
    """
    n_reps = 4

    class CountingAlgo:
        # Class attribute, so all instances share a single count.
        tot_reps = 0

        def fit(self, X, y):
            CountingAlgo.tot_reps += 1

    pair = algorithms.AlgorithmPair(
        CountingAlgo,
        CountingAlgo,
        shared_args={},
        name="Counting",
    )

    runner = AccuracyComparisonRunner(
        [20],
        [5],
        dataset_name='zeros',
        test_fraction=0.20,
        n_reps=n_reps,
    )
    runner.run(pair)

    # Each repetition fits both the CPU and the cuML implementation.
    expected_fits = 2 * n_reps
    assert CountingAlgo.tot_reps == expected_fits