예제 #1
0
def test_optimizer_copy(acq_func):
    """Check that ``Optimizer.copy`` reproduces the estimator and observed data.

    The source optimizer is run for three iterations so that it holds some
    observed points (``Xi``) and objective values (``yi``). The copy must
    expose identical ``Xi``/``yi``. When ``acq_func`` contains ``"ps"``, the
    copied ``base_estimator`` must be a ``MultiOutputRegressor`` whose inner
    ``estimator`` is not itself a ``MultiOutputRegressor`` (i.e. not wrapped
    multiple times). Otherwise, it must not be a ``MultiOutputRegressor``
    at all.
    """
    # TODO: Refactor - Use PyTest
    is_ps_variant = "ps" in acq_func

    opt = Optimizer(
        [(-2.0, 2.0)],
        ExtraTreesRegressor(random_state=2),
        acq_func=acq_func,
        n_initial_points=1,
        acq_optimizer="sampling",
    )

    # Three iterations give the optimizer some points and objective values.
    # The "ps" acquisition variants are driven by bench1_with_time instead of
    # bench1.
    objective = bench1_with_time if is_ps_variant else bench1
    opt.run(objective, n_iter=3)

    opt_copy = opt.copy()
    estimator_of_copy = opt_copy.base_estimator

    if not is_ps_variant:
        assert not isinstance(estimator_of_copy, MultiOutputRegressor)
    else:
        assert isinstance(estimator_of_copy, MultiOutputRegressor)
        # The wrapper must be applied once, never nested.
        assert not isinstance(estimator_of_copy.estimator, MultiOutputRegressor)

    assert_array_equal(opt_copy.Xi, opt.Xi)
    assert_array_equal(opt_copy.yi, opt.yi)
예제 #2
0
def test_optimizer_base_estimator_string_smoke(base_estimator):
    """Smoke test: an ``Optimizer`` built with ``base_estimator`` runs 3 iterations.

    ``base_estimator`` is supplied by the caller. Per the test name, it is
    presumably a string identifier (TODO confirm against the
    parametrization). The test only checks that constructing the optimizer
    and running it on a 1-D quadratic does not raise.
    """

    def squared_first_coordinate(x):
        return x[0] ** 2

    opt = Optimizer(
        [(-2.0, 2.0)],
        base_estimator=base_estimator,
        n_initial_points=1,
        acq_func="EI",
    )
    opt.run(func=squared_first_coordinate, n_iter=3)