コード例 #1
0
def test_bald_gpu_seg(segmentation_task):
    """BALDGPUWrapper output must match BALD applied to the model's generator.

    Both paths are started from the same torch seed. Their stochastic
    predictions are therefore expected to line up, and the resulting
    uncertainty scores are compared with a small absolute/relative tolerance.
    """
    net, dataset = segmentation_task
    seed = 1337

    # Path 1: the GPU wrapper computes BALD scores directly.
    torch.manual_seed(seed)
    gpu_wrapper = BALDGPUWrapper(net, reduction='sum')
    gpu_scores = gpu_wrapper.predict_on_dataset(dataset, 4, 10, False, 4)
    # One score per dataset item.
    assert gpu_scores.shape[0] == len(dataset)

    # Path 2: plain BALD over the model's prediction generator.
    reference = BALD(reduction='sum')
    # Reseed so this run sees the same RNG stream as the wrapper run.
    torch.manual_seed(seed)
    predictions = net.predict_on_dataset_generator(dataset, 4, 10, False, 4)
    reference_scores = reference.get_uncertainties_generator(predictions)

    assert np.allclose(gpu_scores, reference_scores, rtol=1e-5, atol=1e-5)
コード例 #2
0
def test_bald(distributions, reduction):
    """Check BALD ranking, its streaming variant, and the effect of a non-default arg.

    ``distributions`` and ``reduction`` come from pytest fixtures or
    parametrization defined elsewhere in this test module.
    """
    # Fixed seed keeps the probabilistic assertion at the end reproducible.
    np.random.seed(1338)

    bald = BALD(reduction=reduction)
    marg = bald(distributions)
    # Same data fed as a stream of chunks (``chunks`` helper, size 2).
    str_marg = bald(chunks(distributions, 2))

    # Batch and generator code paths must produce the same uncertainties.
    assert np.allclose(
        bald.get_uncertainties(distributions),
        bald.get_uncertainties_generator(chunks(distributions, 2)),
    ), "BALD uncertainties differ between batch and generator paths"

    assert np.all(marg == [1, 2, 0]), "BALD is not right {}".format(marg)
    # Report the streaming result itself (not the batch one) on failure.
    assert np.all(str_marg == [1, 2, 0]), "StreamingBALD is not right {}".format(
        str_marg
    )

    # NOTE(review): 0.99 is presumably a shuffle proportion; confirm against
    # BALD's signature. The ranking is then expected to differ from [1, 2, 0].
    bald = BALD(0.99, reduction=reduction)
    marg = bald(distributions)

    # Unlikely, but not 100% sure: this is probabilistic, hence the seed above.
    assert np.any(marg != [1, 2, 0])