Example #1
0
def test_min_value_entropy_search_chooses_same_as_probability_of_improvement() -> None:
    """
    When based on a single max-value sample, MES should choose the same point that probability of
    improvement would when calculated with the max-value as its threshold
    (See :cite:`wang2017max`).
    """
    model = GaussianProcess([branin], [tfp.math.psd_kernels.MaternFiveHalves()])

    # 11 x 11 grid over [0, 1] x [0, 1], flattened into rows of 2-d query points.
    axis = tf.cast(tf.linspace(0.0, 1.0, 11), dtype=tf.float64)
    grid = tf.stack(tf.meshgrid(axis, axis, indexing="ij"), axis=-1)
    xs = tf.reshape(grid, (-1, 2))

    gumbel_sample = tf.constant([1.0], dtype=tf.float64)
    mes_evals = min_value_entropy_search(model, gumbel_sample, xs)

    # Reference values: standard-normal cdf of the gap between the sample and the
    # predictive mean, scaled by the predictive standard deviation.
    mean, variance = model.predict(xs)
    threshold = tf.cast(gumbel_sample, dtype=mean.dtype)
    standardized_gap = (threshold - mean) / tf.sqrt(variance)
    standard_normal = tfp.distributions.Normal(
        tf.cast(0, dtype=mean.dtype),
        tf.cast(1, dtype=mean.dtype),
    )
    pi_evals = standard_normal.cdf(standardized_gap)

    # Both criteria must peak at the same grid index.
    npt.assert_array_equal(tf.argmax(mes_evals), tf.argmax(pi_evals))
Example #2
0
def test_min_value_entropy_search_raises_for_invalid_batch_size(
        at: TensorType) -> None:
    """
    Evaluating the acquisition function at ``at`` must raise one of
    ``TF_DEBUGGING_ERROR_TYPES``.

    NOTE(review): ``at`` is presumably supplied by a parametrization that is not visible
    here, with query points whose batch size is invalid (per the test name) -- confirm.
    """
    model = QuadraticMeanAndRBFKernel()
    gumbel_samples = tf.constant([[1.0], [2.0]])
    acquisition = min_value_entropy_search(model, gumbel_samples)

    # Only the evaluation itself is expected to fail, not the construction above.
    with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
        acquisition(at)
Example #3
0
def test_min_value_entropy_search_returns_correct_shape() -> None:
    """
    Evaluating the acquisition function on five query points must give a result of
    shape ``[5, 1]``.
    """
    # Five evenly spaced points from -10 to 10; linspace over [[.]] bounds adds a leading axis.
    points = tf.linspace([[-10.0]], [[10.0]], 5)

    acquisition = min_value_entropy_search(
        QuadraticMeanAndRBFKernel(),
        tf.constant([1.0]),
    )
    values = acquisition(points)

    expected_shape = tf.constant([5, 1])
    npt.assert_array_equal(values.shape, expected_shape)
Example #4
0
def test_min_value_entropy_search_raises_for_gumbel_samples_with_invalid_shape(
    samples: TensorType,
) -> None:
    """
    Building the acquisition function from ``samples`` must raise ``ValueError``.

    NOTE(review): ``samples`` is presumably supplied by a parametrization that is not
    visible here, with Gumbel samples of invalid shape (per the test name) -- confirm.
    """
    # Construct the model outside the ``raises`` block so that a ValueError coming from
    # the model itself surfaces as a test error instead of satisfying the expectation.
    model = QuadraticMeanAndRBFKernel()

    with pytest.raises(ValueError):
        min_value_entropy_search(model, samples)