コード例 #1
0
ファイル: test_rule.py プロジェクト: johnamcleod/trieste
def test_thompson_sampling_raises_for_invalid_models_keys(
    datasets: dict[str, Dataset], models: dict[str, ProbabilisticModel]
) -> None:
    """Check that ``ThompsonSampling.acquire`` raises ``ValueError`` for these inputs.

    ``datasets`` and ``models`` are injected by pytest; the fixtures/parametrization
    are not shown here. Presumably they carry keys the rule rejects, as the test
    name suggests; confirm against the decorator/fixtures.
    """
    domain = Box([-1], [1])
    # Positional args presumably (num_search_space_samples, num_query_points); confirm.
    sampler = ThompsonSampling(100, 10)
    with pytest.raises(ValueError):
        sampler.acquire(domain, datasets, models)
コード例 #2
0
def test_thompson_sampling_raises_for_invalid_models_keys(
    datasets: Dict[str, Dataset], models: Dict[str, ProbabilisticModel]
) -> None:
    """Check that ``ThompsonSampling.acquire`` raises ``ValueError`` for these inputs.

    Variant written against the older ``one_dimensional_range`` search-space helper.
    ``datasets`` and ``models`` are injected by pytest (fixtures/parametrization not
    shown); presumably they carry keys the rule rejects. Confirm against the caller.
    """
    domain = one_dimensional_range(-1, 1)
    # Positional args presumably (num_search_space_samples, num_query_points); confirm.
    sampler = ThompsonSampling(100, 10)
    with pytest.raises(ValueError):
        sampler.acquire(domain, datasets, models)
コード例 #3
0

# NOTE(review): `random_seed` is a project decorator not shown here; presumably it
# fixes RNG seeds so this stochastic test is reproducible. Confirm.
@random_seed
@pytest.mark.parametrize(
    "num_steps, acquisition_rule",
    [
        (20, EfficientGlobalOptimization()),
        (
            15,
            # Batch variant: Monte Carlo EI (500 samples), proposing 2 query points per step.
            EfficientGlobalOptimization(
                BatchMonteCarloExpectedImprovement(sample_size=500).using(OBJECTIVE),
                num_query_points=2,
            ),
        ),
        (15, TrustRegion()),
        # Presumably (num_search_space_samples=500, num_query_points=3); confirm.
        (17, ThompsonSampling(500, 3)),
    ],
)
def test_optimizer_finds_minima_of_the_branin_function(
    num_steps: int, acquisition_rule: AcquisitionRule
) -> None:
    """Run each parametrized acquisition rule on the Branin problem.

    Per its name, the test is intended to check that the optimizer finds the Branin
    minima within ``num_steps`` steps using ``acquisition_rule``.

    NOTE(review): as visible here, the body stops after sampling the initial query
    points. ``build_model`` and ``initial_query_points`` are never used, the
    optimizer is never run and nothing is asserted. This looks like a truncated
    excerpt; confirm against the full source before relying on it.
    """
    # Presumably lower bounds [0, 0] and upper bounds [1, 1], i.e. the unit square.
    search_space = Box([0, 0], [1, 1])

    def build_model(data: Dataset) -> GaussianProcessRegression:
        """Wrap a GPflow GPR with a Matern-5/2 kernel fitted to ``data``."""
        # Kernel variance starts at the empirical variance of the observations;
        # lengthscales start at 0.2 in each of the two input dimensions.
        variance = tf.math.reduce_variance(data.observations)
        kernel = gpflow.kernels.Matern52(variance, tf.constant([0.2, 0.2], tf.float64))
        gpr = gpflow.models.GPR((data.query_points, data.observations), kernel, noise_variance=1e-5)
        # Freeze the likelihood so the small noise variance (1e-5) is not trained.
        gpflow.utilities.set_trainable(gpr.likelihood, False)
        return GaussianProcessRegression(gpr)

    # Five initial points drawn from the search space.
    initial_query_points = search_space.sample(5)
コード例 #4
0
ファイル: test_rule.py プロジェクト: johnamcleod/trieste
def test_thompson_sampling_raises_for_no_points(
    num_search_space_samples, num_query_points
) -> None:
    """Constructing ``ThompsonSampling`` with these arguments must raise ``ValueError``.

    Both arguments are injected by pytest; the parametrization is not visible here.
    Presumably they include zero or negative sample/query-point counts. Confirm.
    """
    ctor_args = (num_search_space_samples, num_query_points)
    with pytest.raises(ValueError):
        ThompsonSampling(*ctor_args)