Example #1
0
def test_gumbel_samples_are_minima() -> None:
    """Check that Gumbel-sampled minimum values lie below the model's predicted mean.

    The grid is the dataset's three query points (all zeros) plus 100 random
    points from the unit square. ``GumbelSampler(5, model)`` (presumably 5
    samples — TODO confirm the argument meaning) samples on that grid, and every
    sample must be strictly smaller than the smallest predicted mean on the
    same grid.
    """
    dataset = Dataset(
        tf.zeros([3, 2], dtype=tf.float64), tf.ones([3, 2], dtype=tf.float64)
    )
    search_space = Box([0, 0], [1, 1])
    model = QuadraticMeanAndRBFKernel()
    gumbel_sampler = GumbelSampler(5, model)

    # Include the observed points so the grid always contains them.
    query_points = tf.concat([dataset.query_points, search_space.sample(100)], 0)
    gumbel_samples = gumbel_sampler.sample(query_points)

    # The sampler's minima are taken over this whole grid, so bound them by the
    # predicted mean over the same grid rather than over the dataset points only.
    fmean, _ = model.predict(query_points)
    # Use TF reductions: builtin max/min iterate a tensor's first axis and depend
    # on the truthiness of one-element tensors.
    assert tf.reduce_max(gumbel_samples) < tf.reduce_min(fmean)
Example #2
0
def test_min_value_entropy_search_builder_gumbel_samples(mocked_mves) -> None:
    """Check the Gumbel samples that ``MinValueEntropySearch`` passes on when preparing.

    ``mocked_mves`` is a fixture defined outside this view. It presumably patches
    the acquisition function the builder constructs (TODO confirm). The test
    checks two things:

    - the mock is called exactly once;
    - its second positional argument, taken here to be the Gumbel samples, lies
      strictly below the model's predicted mean on a grid.

    That grid is the dataset points plus a fresh sample of the builder's search space.
    """
    dataset = Dataset(
        tf.zeros([3, 2], dtype=tf.float64), tf.ones([3, 2], dtype=tf.float64)
    )
    search_space = Box([0, 0], [1, 1])
    builder = MinValueEntropySearch(search_space)
    model = QuadraticMeanAndRBFKernel()
    builder.prepare_acquisition_function(dataset, model)
    mocked_mves.assert_called_once()

    # check that the Gumbel samples look sensible
    gumbel_samples = mocked_mves.call_args[0][1]
    # This grid is re-sampled, not the builder's own. Including the dataset points
    # keeps the zeros in it in both grids.
    query_points = builder._search_space.sample(num_samples=builder._grid_size)
    query_points = tf.concat([dataset.query_points, query_points], 0)
    fmean, _ = model.predict(query_points)
    # Use TF reductions: builtin max/min iterate a tensor's first axis and depend
    # on the truthiness of one-element tensors.
    assert tf.reduce_max(gumbel_samples) < tf.reduce_min(fmean)
Example #3
0
def test_augmented_expected_improvement_builder_builds_expected_improvement_times_augmentation(
    observation_noise: float,
) -> None:
    """Check that augmented EI equals plain EI scaled by a noise-dependent factor.

    The expected factor at each test point is
    ``1 - sqrt(observation_noise) / sqrt(observation_noise + variance)``, where
    ``variance`` is the model's predictive variance there.

    ``observation_noise`` is presumably supplied by a pytest parametrization
    outside this view (TODO confirm).
    """
    observed_x = tf.constant([[-2.0], [-1.0], [0.0], [1.0], [2.0]])
    observed_y = tf.constant([[4.1], [0.9], [0.1], [1.1], [3.9]])
    dataset = Dataset(observed_x, observed_y)

    model = QuadraticMeanAndRBFKernel(noise_variance=observation_noise)
    augmented_ei = AugmentedExpectedImprovement().prepare_acquisition_function(
        dataset, model
    )

    # start/stop have shape [1, 1], so the 100 points form a [100, 1, 1] tensor.
    xs = tf.linspace([[-10.0]], [[10.0]], 100)
    plain_ei = ExpectedImprovement().prepare_acquisition_function(dataset, model)
    expected_ei = plain_ei(xs)

    # Drop the size-1 axis -2 so the points passed to predict are [100, 1].
    _, variance = model.predict(tf.squeeze(xs, -2))
    noise_std = tf.math.sqrt(observation_noise)
    total_std = tf.math.sqrt(observation_noise + variance)
    augmentation = 1.0 - noise_std / total_std

    npt.assert_allclose(augmented_ei(xs), expected_ei * augmentation)