Exemplo n.º 1
0
def _dim_two_gp(mean_shift: Tuple[float, float] = (0.0, 0.0)) -> GaussianProcess:
    """
    Build a :class:`GaussianProcess` from two mean functions and two kernels.

    The mean functions are, in order, ``branin`` offset by ``mean_shift[0]`` and ``quadratic``
    offset by ``mean_shift[1]``. The kernels are, in the same order, a Matern 5/2 kernel
    (amplitude 2.3, length scale 0.5, both cast to ``tf.float64``) and ``rbf()``.

    :param mean_shift: Constant offsets added to the outputs of the two mean functions.
    :return: The constructed Gaussian process.
    """

    def shifted_branin(x):
        return mean_shift[0] + branin(x)

    def shifted_quadratic(x):
        return mean_shift[1] + quadratic(x)

    amplitude = tf.cast(2.3, tf.float64)
    length_scale = tf.cast(0.5, tf.float64)
    kernels = [
        tfp.math.psd_kernels.MaternFiveHalves(amplitude=amplitude, length_scale=length_scale),
        rbf(),
    ]

    return GaussianProcess([shifted_branin, shifted_quadratic], kernels)
Exemplo n.º 2
0
def test_batch_monte_carlo_expected_improvement_raises_for_model_with_wrong_event_shape(
) -> None:
    """
    Check that ``prepare_acquisition_function`` raises one of ``TF_DEBUGGING_ERROR_TYPES``
    when the model is a :class:`GaussianProcess` built from two mean functions and two
    kernels.
    """
    acquisition_builder = BatchMonteCarloExpectedImprovement(100)
    dataset = mk_dataset([(0.0, 0.0)], [(0.0, 0.0)])

    kernels = [
        tfp.math.psd_kernels.MaternFiveHalves(
            amplitude=tf.cast(2.3, tf.float64),
            length_scale=tf.cast(0.5, tf.float64),
        ),
        rbf(),
    ]
    mean_functions = [lambda x: branin(x), lambda x: quadratic(x)]
    two_mean_model = GaussianProcess(mean_functions, kernels)

    with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
        acquisition_builder.prepare_acquisition_function(dataset, two_mean_model)
Exemplo n.º 3
0
def test_expected_constrained_improvement_is_less_for_constrained_points() -> None:
    """
    Check that the expected constrained improvement at ``-1.0`` is strictly less than at
    ``1.0`` when the constraint acquisition function is ``1`` for ``x >= 0`` and ``0``
    otherwise.
    """

    def two_global_minima(x: tf.Tensor) -> tf.Tensor:
        # Objective x^4 / 4 - x^2 / 2, evaluated term by term.
        quartic_term = x ** 4 / 4
        quadratic_term = x ** 2 / 2
        return quartic_term - quadratic_term

    class _Constraint(AcquisitionFunctionBuilder):
        """Builder whose acquisition function ignores the supplied datasets and models."""

        def prepare_acquisition_function(
            self, datasets: Mapping[str, Dataset], models: Mapping[str, ProbabilisticModel]
        ) -> AcquisitionFunction:
            def non_negative_indicator(x):
                # Boolean mask ``x >= 0`` cast back to the dtype of ``x``.
                return tf.cast(x >= 0, x.dtype)

            return non_negative_indicator

    query_points = tf.constant([[-2.0], [0.0], [1.2]])
    observations = two_global_minima(query_points)
    datasets_by_tag = {"foo": Dataset(query_points, observations)}
    models_by_tag = {"foo": GaussianProcess([two_global_minima], [rbf()])}

    eci_builder = ExpectedConstrainedImprovement("foo", _Constraint())
    eci = eci_builder.prepare_acquisition_function(datasets_by_tag, models_by_tag)

    value_at_negative = eci(tf.constant([-1.0]))
    value_at_positive = eci(tf.constant([1.0]))
    npt.assert_array_less(value_at_negative, value_at_positive)
Exemplo n.º 4
0
 def __init__(self):
     """Initialise the parent class with the mean function ``x -> 2 * x`` and ``[rbf()]``."""

     def doubled(x):
         return 2 * x

     super().__init__([doubled], [rbf()])
Exemplo n.º 5
0
 def __init__(self):
     """Initialise the parent class with the mean function ``x -> exp(-x)`` and ``[rbf()]``."""

     def exponential_decay(x):
         return tf.exp(-x)

     super().__init__([exponential_decay], [rbf()])