def test_mutual_information_shapes(model):
    """Check the array shapes returned by ``MutualInformation``.

    Three two-dimensional points are evaluated; the acquisition values must
    have shape ``(3, 1)`` and the gradients shape ``(3, 2)``.
    """
    acquisition = MutualInformation(model)
    points = np.array([[-1, 1], [0, 0], [-2, 0.1]])
    expected_value_shape = (3, 1)
    expected_gradient_shape = (3, 2)

    # Value only.
    values = acquisition.evaluate(points)
    assert values.shape == expected_value_shape

    # Value together with gradient: element 0 is the value, element 1 the
    # gradient with respect to the inputs.
    value_and_gradient = acquisition.evaluate_with_gradients(points)
    assert value_and_gradient[0].shape == expected_value_shape
    assert value_and_gradient[1].shape == expected_gradient_shape
def mutual_information(model_test_list_fixture):
    """Return a ``MutualInformation`` built from ``model_test_list_fixture``.

    NOTE(review): presumably registered as a pytest fixture by a decorator
    outside this view -- confirm against the full file.
    """
    acquisition = MutualInformation(model_test_list_fixture)
    return acquisition
def mutual_information_acquisition(vanilla_bq_model):
    """Return a ``MutualInformation`` built from ``vanilla_bq_model``.

    NOTE(review): presumably registered as a pytest fixture by a decorator
    outside this view -- confirm against the full file.
    """
    acquisition = MutualInformation(vanilla_bq_model)
    return acquisition
def test_mutual_information_gradients(model):
    """Run ``_check_grad`` on ``MutualInformation`` at one 2-D point.

    NOTE(review): ``_check_grad`` is defined elsewhere; presumably it compares
    the analytic gradient with a numerical estimate -- confirm.
    """
    acquisition = MutualInformation(model)
    # A single test location, shaped (1, 2).
    point = np.array([[-2.5, 1.5]])
    _check_grad(acquisition, point)