예제 #1
0
def main():
    """Run a 1-D Bayesian-optimization demo and plot each iteration.

    Starting from four hand-picked inputs, this runs ``num_iter`` iterations
    of ``bo.BO`` on the search box [-6, 6] with the acquisition named by
    ``str_acq``.  In each iteration it:

    - predicts the GP posterior on a 400-point test grid;
    - evaluates AEI on that grid;
    - appends the proposed point to the training set;
    - hands the results to the two ``utils_plotting`` helpers together with
      ``PATH_SAVE`` (presumably the output location for the figures).

    Uses names defined elsewhere in the file: ``np``, ``bo``, ``gp``,
    ``acquisition``, ``utils_plotting``, ``fun_target`` and ``PATH_SAVE``.
    """
    str_acq = 'aei'
    num_iter = 10
    X_train = np.array([
        [-5],
        [-1],
        [1],
        [2],
    ])
    num_init = X_train.shape[0]
    model_bo = bo.BO(np.array([[-6., 6.]]), str_acq=str_acq)
    X_test = np.reshape(np.linspace(-6, 6, 400), (400, 1))

    # The test grid never changes, so the true objective on it is computed
    # once here instead of twice per iteration.
    Y_test_truth = fun_target(X_test)
    # Evaluate the objective once per training set and reuse the result.
    # model_bo.optimize and gp.predict_test_ then see exactly the same
    # observations, even if fun_target is noisy or expensive.
    Y_train = fun_target(X_train)

    for ind_ in range(1, num_iter + 1):
        next_x, dict_info = model_bo.optimize(X_train,
                                              Y_train,
                                              str_initial_method_ao='uniform')
        cov_X_X = dict_info['cov_X_X']
        inv_cov_X_X = dict_info['inv_cov_X_X']
        hyps = dict_info['hyps']

        # The third return value (full posterior covariance) is not needed.
        mu_test, sigma_test, _ = gp.predict_test_(
            X_train, Y_train, X_test, cov_X_X, inv_cov_X_X, hyps)
        # NOTE: the plotted acquisition is hard-wired to AEI; update this
        # call if str_acq is changed.
        acq_test = acquisition.aei(mu_test.flatten(), sigma_test.flatten(),
                                   Y_train, hyps['noise'])
        acq_test = np.expand_dims(acq_test, axis=1)

        X_train = np.vstack((X_train, next_x))
        Y_train = fun_target(X_train)

        str_postfix = 'bo_{}_{}'.format(str_acq, ind_)
        utils_plotting.plot_bo_step(X_train,
                                    Y_train,
                                    X_test,
                                    Y_test_truth,
                                    mu_test,
                                    sigma_test,
                                    path_save=PATH_SAVE,
                                    str_postfix=str_postfix,
                                    int_init=num_init)
        utils_plotting.plot_bo_step_acq(X_train,
                                        Y_train,
                                        X_test,
                                        Y_test_truth,
                                        mu_test,
                                        sigma_test,
                                        acq_test,
                                        path_save=PATH_SAVE,
                                        str_postfix=str_postfix,
                                        int_init=num_init)
예제 #2
0
def choose_fun_acquisition(
    str_acq: str,
    noise: constants.TYPING_UNION_FLOAT_NONE = None
) -> constants.TYPING_CALLABLE:
    """
    It chooses and returns an acquisition function.

    'pi', 'ei' and 'ucb' return the corresponding function from the
    acquisition module unchanged.  'aei', 'pure_exploit' and 'pure_explore'
    return a wrapper that accepts ``(pred_mean, pred_std, Y_train)`` and
    forwards only the arguments the underlying function needs.

    :param str_acq: the name of acquisition function.
    :type str_acq: str.
    :param noise: noise level forwarded to :func:`acquisition.aei`; it must
        not be None when ``str_acq`` is 'aei' and is ignored otherwise.
    :type noise: float, optional

    :returns: acquisition function.
    :rtype: callable

    :raises: AssertionError, NotImplementedError

    """

    assert isinstance(str_acq, str)
    assert isinstance(noise, (float, constants.TYPE_NONE))
    assert str_acq in constants.ALLOWED_BO_ACQ

    if str_acq == 'pi':
        fun_acquisition = acquisition.pi
    elif str_acq == 'ei':
        fun_acquisition = acquisition.ei
    elif str_acq == 'ucb':
        fun_acquisition = acquisition.ucb
    elif str_acq == 'aei':
        # AEI needs the noise level; fail at selection time, not at call time.
        assert noise is not None

        fun_acquisition = lambda pred_mean, pred_std, Y_train: acquisition.aei(
            pred_mean, pred_std, Y_train, noise)
    elif str_acq == 'pure_exploit':
        fun_acquisition = lambda pred_mean, pred_std, Y_train: acquisition.pure_exploit(
            pred_mean)
    elif str_acq == 'pure_explore':
        fun_acquisition = lambda pred_mean, pred_std, Y_train: acquisition.pure_explore(
            pred_std)
    else:
        # Reachable only if constants.ALLOWED_BO_ACQ lists a name that is not
        # handled above.
        raise NotImplementedError(
            'choose_fun_acquisition: allowed str_acq, but it is not implemented.'
        )

    return fun_acquisition
예제 #3
0
파일: bo.py 프로젝트: cltk9090/bayeso
def _choose_fun_acquisition(str_acq, hyps):
    """
    It chooses and returns an acquisition function.

    :param str_acq: the name of acquisition function.
    :type str_acq: str.
    :param hyps: dictionary of hyperparameters for acquisition function;
        it must contain the key 'noise' when ``str_acq`` is 'aei'.
    :type hyps: dict.

    :returns: acquisition function.
    :rtype: function

    :raises: AssertionError, KeyError

    """

    assert isinstance(str_acq, str)
    assert isinstance(hyps, dict)
    assert str_acq in constants.ALLOWED_BO_ACQ

    if str_acq == 'pi':
        fun_acquisition = acquisition.pi
    elif str_acq == 'ei':
        fun_acquisition = acquisition.ei
    elif str_acq == 'ucb':
        fun_acquisition = acquisition.ucb
    elif str_acq == 'aei':
        # Read the noise level now rather than inside the lambda.  A missing
        # key then fails here, and later mutation of the caller's ``hyps``
        # cannot silently change the returned acquisition function.
        noise = hyps['noise']
        fun_acquisition = lambda pred_mean, pred_std, Y_train: acquisition.aei(
            pred_mean, pred_std, Y_train, noise)
    elif str_acq == 'pure_exploit':
        fun_acquisition = acquisition.pure_exploit
    elif str_acq == 'pure_explore':
        fun_acquisition = acquisition.pure_explore
    else:
        # Reachable only if constants.ALLOWED_BO_ACQ lists a name that is not
        # handled above.
        raise NotImplementedError(
            '_choose_fun_acquisition: allowed str_acq, but it is not implemented.'
        )
    return fun_acquisition
예제 #4
0
def _choose_fun_acquisition(str_acq, hyps):
    """
    It chooses and returns an acquisition function.

    'pi', 'ei', 'ucb', 'pure_exploit' and 'pure_explore' return the
    corresponding function from the acquisition module unchanged.  'aei'
    returns a wrapper taking ``(pred_mean, pred_std, Y_train)`` that supplies
    ``hyps['noise']`` as the extra argument.

    :param str_acq: the name of acquisition function.
    :type str_acq: str.
    :param hyps: dictionary of hyperparameters for acquisition function;
        it must contain the key 'noise' when ``str_acq`` is 'aei'.
    :type hyps: dict.

    :returns: acquisition function.
    :rtype: function

    :raises: AssertionError, KeyError

    """
    assert isinstance(str_acq, str)
    assert isinstance(hyps, dict)
    assert str_acq in constants.ALLOWED_BO_ACQ

    if str_acq == 'pi':
        fun_acquisition = acquisition.pi
    elif str_acq == 'ei':
        fun_acquisition = acquisition.ei
    elif str_acq == 'ucb':
        fun_acquisition = acquisition.ucb
    elif str_acq == 'aei':
        # Bind the noise level eagerly so a missing key is reported here.
        # The closure also stays independent of later changes to ``hyps``.
        noise = hyps['noise']
        fun_acquisition = lambda pred_mean, pred_std, Y_train: acquisition.aei(
            pred_mean, pred_std, Y_train, noise)
    elif str_acq == 'pure_exploit':
        fun_acquisition = acquisition.pure_exploit
    elif str_acq == 'pure_explore':
        fun_acquisition = acquisition.pure_explore
    else:
        # Reachable only if constants.ALLOWED_BO_ACQ lists a name that is not
        # handled above.
        raise NotImplementedError(
            '_choose_fun_acquisition: allowed str_acq, but it is not implemented.'
        )
    return fun_acquisition
예제 #5
0
def test_aei():
    """Check that ``aei`` rejects malformed inputs and matches reference values."""
    # (positional args, keyword args) for calls whose input checks must fail.
    rejected_calls = [
        (('abc', np.ones(10), np.zeros((5, 1)), 1.0), {}),
        ((np.ones(10), 'abc', np.zeros((5, 1)), 1.0), {}),
        ((np.ones(10), np.ones(10), 'abc', 1.0), {}),
        ((np.ones(10), np.ones(10), np.zeros((5, 1)), 'abc'), {}),
        ((np.ones(10), np.ones(10), np.zeros((5, 1)), 1.0), {'jitter': 1}),
        ((np.ones(5), np.ones(10), np.zeros((5, 1)), 1.0), {}),
        ((np.ones(10), np.ones(10), np.zeros(5), 1.0), {}),
        ((np.ones(10), np.ones((10, 1)), np.zeros((5, 1)), 1.0), {}),
        ((np.ones((10, 1)), np.ones(10), np.zeros((5, 1)), 1.0), {}),
    ]
    for call_args, call_kwargs in rejected_calls:
        with pytest.raises(AssertionError):
            package_target.aei(*call_args, **call_kwargs)

    computed = package_target.aei(np.arange(0, 10), np.ones(10), np.zeros((5, 1)), 1.0)
    expected = np.array([
        1.16847489e-01, 2.44025364e-02, 2.48686922e-03, 1.11930407e-04,
        2.09279771e-06, 1.56585558e-08, 4.57958958e-11, 5.15587486e-14,
        2.21142019e-17, 3.58729395e-21,
    ])
    assert np.all(np.abs(computed - expected) < TEST_EPSILON)