def test_bald_gpu_seg(segmentation_task):
    """Compare BALDGPUWrapper with BALD on the segmentation task.

    The wrapper's uncertainties are computed first. The torch RNG is then
    reseeded to the same value, and plain BALD is applied to the model's
    generator-based predictions. The two results must agree elementwise
    within a 1e-5 tolerance.
    """
    # Fix the torch RNG state for the wrapper pass.
    torch.manual_seed(1337)
    model, test_set = segmentation_task
    gpu_wrapper = BALDGPUWrapper(model, reduction='sum')
    # NOTE(review): positional args (4, 10, False, 4) are forwarded unchanged;
    # their meaning is defined by predict_on_dataset -- confirm there.
    gpu_scores = gpu_wrapper.predict_on_dataset(test_set, 4, 10, False, 4)
    # One result per dataset item along the first axis.
    assert gpu_scores.shape[0] == len(test_set)

    reference = BALD(reduction='sum')
    # Restart from the same RNG state so both passes see identical draws.
    torch.manual_seed(1337)
    prediction_stream = model.predict_on_dataset_generator(test_set, 4, 10, False, 4)
    reference_scores = reference.get_uncertainties_generator(prediction_stream)
    assert np.allclose(gpu_scores, reference_scores, rtol=1e-5, atol=1e-5)
def test_bald(distributions, reduction):
    """Check BALD on the ``distributions`` fixture for a given ``reduction``.

    Verifies that:
      * batch ``get_uncertainties`` and chunked
        ``get_uncertainties_generator`` produce matching values;
      * calling BALD on the full input and on 2-item chunks both yield
        ``[1, 2, 0]``;
      * ``BALD(0.99, ...)`` yields a result different from ``[1, 2, 0]``.
    """
    # Seed numpy so the final (randomized) check is reproducible.
    np.random.seed(1338)
    bald = BALD(reduction=reduction)
    marg = bald(distributions)
    str_marg = bald(chunks(distributions, 2))

    # Batch and chunked (streaming) uncertainty computations must agree.
    assert np.allclose(
        bald.get_uncertainties(distributions),
        bald.get_uncertainties_generator(chunks(distributions, 2)),
    )
    assert np.all(marg == [1, 2, 0]), "BALD is not right {}".format(marg)
    # Report the streaming result here (previously printed the batch `marg`).
    assert np.all(str_marg == [1, 2, 0]), "StreamingBALD is not right {}".format(str_marg)

    # NOTE(review): 0.99 is passed as BALD's first positional argument;
    # presumably a shuffle proportion that perturbs the output -- confirm.
    bald = BALD(0.99, reduction=reduction)
    marg = bald(distributions)
    # Unlikely, but not 100% sure
    assert np.any(marg != [1, 2, 0]), "BALD(0.99) did not change the result {}".format(marg)