def test_mutual_information_shapes(model):
    """Check the output shapes of MutualInformation for a batch of three 2-D points.

    ``evaluate`` is expected to return an array of shape (3, 1).
    ``evaluate_with_gradients`` is expected to return an indexable result.
    Its first entry should have shape (3, 1) and its second shape (3, 2).
    These are presumably the values and their gradients with respect to the
    inputs.
    """
    acquisition = MutualInformation(model)
    points = np.array([[-1, 1], [0, 0], [-2, 0.1]])

    # Plain evaluation: expected shape is (n_points, 1).
    values = acquisition.evaluate(points)
    assert values.shape == (3, 1)

    # Evaluation with gradients. Index rather than unpack so the check does
    # not depend on the exact length of the returned container.
    outcome = acquisition.evaluate_with_gradients(points)
    assert outcome[0].shape == (3, 1)
    assert outcome[1].shape == (3, 2)
# Example #2
# 0
def mutual_information(model_test_list_fixture):
    """Wrap the model supplied by ``model_test_list_fixture`` in a MutualInformation acquisition.

    NOTE(review): this looks intended as a pytest fixture, but no decorator
    is visible here. Confirm it is registered as one.
    """
    acquisition = MutualInformation(model_test_list_fixture)
    return acquisition
# Example #3
# 0
def mutual_information_acquisition(vanilla_bq_model):
    """Wrap ``vanilla_bq_model`` in a MutualInformation acquisition and return it.

    NOTE(review): this looks intended as a pytest fixture, but no decorator
    is visible here. Confirm it is registered as one.
    """
    acquisition = MutualInformation(vanilla_bq_model)
    return acquisition
def test_mutual_information_gradients(model):
    """Run the ``_check_grad`` helper on MutualInformation at a single 2-D point.

    ``_check_grad`` is defined elsewhere. It presumably compares the analytic
    gradient against a numerical estimate; confirm against its definition.
    """
    _check_grad(MutualInformation(model), np.array([[-2.5, 1.5]]))