コード例 #1
0
ファイル: heuristic_test.py プロジェクト: xyuan/baal
def test_random(distributions):
    """Check that the ``Random`` heuristic is stochastic and honours ``threshold``.

    First, ten pairs of back-to-back calls of an unthresholded ``Random`` on the
    same ``distributions`` must not all compare ``np.allclose``-equal. Second,
    with ``threshold=0.1`` the heuristic's output, used to index
    ``distributions``, must pick out at least one entry that is ``<= 0.1``.
    """
    np.random.seed(1337)  # fixed seed keeps the outcome reproducible

    heuristic = Random()
    pair_matches = []
    for _ in range(10):
        first = heuristic(distributions)
        second = heuristic(distributions)
        pair_matches.append(np.allclose(first, second))
    assert not np.all(pair_matches)

    thresholded = Random(threshold=0.1)
    selected = thresholded(distributions)
    assert np.any(distributions[selected] <= 0.1)
コード例 #2
0
def test_heuristics_reorder_list():
    """Check ``reorder_indices`` on streamed (chunked) uncertainty scores.

    The predictions are given as a list of chunks of different lengths
    (1, 2 and 2 samples, five in total). Once flattened, the scores are
    ``[0.98, 0.87, 0.68, 0.96, 0.54]`` for sample indices 0..4.

    - ``BALD``, ``Variance`` and ``Entropy`` must return those indices sorted
      by descending score: ``[0, 3, 1, 2, 4]``.
    - ``Margin`` and ``Certainty`` must return them sorted by ascending score:
      ``[4, 2, 1, 3, 0]``.
    - ``Random`` cannot be pinned to an order; it is only checked for
      returning five indices.
    """
    streaming_prediction = [np.array([0.98]), np.array([0.87, 0.68]),
                            np.array([0.96, 0.54])]

    descending = [0, 3, 1, 2, 4]
    ascending = [4, 2, 1, 3, 0]
    expected_ranks = [
        (BALD, descending),
        (Variance, descending),
        (Entropy, descending),
        (Margin, ascending),
        (Certainty, ascending),
    ]

    for heuristic_cls, expected in expected_ranks:
        # A fresh instance per heuristic, as in separate test stanzas.
        ranks = heuristic_cls().reorder_indices(streaming_prediction)
        # assert_array_equal also checks the shape and reports mismatching
        # positions, unlike np.all(ranks == [...]).
        np.testing.assert_array_equal(
            ranks, expected,
            err_msg="reorder list for {} is not right {}".format(heuristic_cls.__name__, ranks))

    ranks = Random().reorder_indices(streaming_prediction)
    # The order is random, so only the number of returned indices is checked.
    assert ranks.size == 5, "reorder list for Random is not right {}".format(ranks)
コード例 #3
0
def test_random(distributions):
    """Check that repeated draws from the ``Random`` heuristic are not always equal.

    Ten pairs of consecutive calls on the same ``distributions`` are compared
    with ``np.allclose``; the test fails only if every pair matches.
    """
    np.random.seed(1337)  # deterministic seed so the outcome is reproducible
    heuristic = Random()

    def _pair_matches():
        # Two back-to-back calls on identical input.
        return np.allclose(heuristic(distributions), heuristic(distributions))

    all_equals = np.all([_pair_matches() for _ in range(10)])
    assert not all_equals
コード例 #4
0
    variance_firstchunk = np.array([0.76])
    variance_secondchunk = np.array([0.63, 0.48])
    streaming_prediction = [[bald_firstchunk, variance_firstchunk],
                            [bald_secondchunk, variance_secondchunk]]

    heuristics = CombineHeuristics([BALD(), Variance()],
                                   weights=[0.5, 0.5],
                                   reduction='mean')
    ranks = heuristics.reorder_indices(streaming_prediction)
    assert np.all(
        ranks == [0, 1, 2]), "Combine Heuristics is not right {}".format(ranks)


@pytest.mark.parametrize("heur", [
    Random(),
    BALD(reduction='sum'),
    Entropy(reduction='sum'),
    Variance(reduction='sum')
])
@pytest.mark.parametrize("n_batch", [1, 10, 20])
def test_heuristics_works_with_generator(heur, n_batch):
    """Check that a heuristic accepts a generator of prediction batches.

    ``n_batch`` random batches of shape ``(32, 3, 32, 32, 10)`` are streamed
    through ``heur``. The leading dimension of the output must equal the
    total number of samples, ``n_batch * 32``.
    """
    batch_size = 32
    batch_shape = (batch_size, 3, 32, 32, 10)

    # A generator expression (not a list), so the heuristic receives a lazy stream.
    batches = (np.random.randn(*batch_shape) for _ in range(n_batch))

    scores = heur(batches)
    assert scores.shape[0] == batch_size * n_batch