def test_cache_root_decomposition(self):
    """Exercise `_cache_root_decomposition` for two-output and one-output models.

    With the Cholesky routine patched to return an identity matrix, the test
    asserts that the batch-covariance extractor is invoked with
    `posterior.mvn` for the two-output model and not invoked at all for the
    single-output model, that the Cholesky routine is called exactly once in
    both cases, and that the patched factor ends up in `acqf._baseline_L`.
    """
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": self.device, "dtype": dtype}

        # --- two-output model (the "mt-mvn" case) ---
        train_x = torch.rand(2, 1, **tkwargs)
        train_y = torch.rand(2, 2, **tkwargs)
        test_x = torch.rand(2, 1, **tkwargs)
        model = SingleTaskGP(train_x, train_y)
        sampler = IIDNormalSampler(1)
        with torch.no_grad():
            posterior = model.posterior(test_x)
        acqf = DummyCachedCholeskyAcqf(
            model=model,
            sampler=sampler,
            objective=GenericMCObjective(lambda Y: Y[..., 0]),
        )
        baseline_L = torch.eye(2, **tkwargs)
        with mock.patch(
                EXTRACT_BATCH_COVAR_PATH, wraps=extract_batch_covar
        ) as mock_extract_batch_covar, mock.patch(
                CHOLESKY_PATH, return_value=baseline_L) as mock_cholesky:
            acqf._cache_root_decomposition(posterior=posterior)
            mock_extract_batch_covar.assert_called_once_with(posterior.mvn)
            mock_cholesky.assert_called_once()

        # --- single-output model (the "mvn" case) ---
        model = SingleTaskGP(train_x, train_y[:, :1])
        with torch.no_grad():
            posterior = model.posterior(test_x)
        with mock.patch(
                EXTRACT_BATCH_COVAR_PATH
        ) as mock_extract_batch_covar, mock.patch(
                CHOLESKY_PATH, return_value=baseline_L) as mock_cholesky:
            acqf._cache_root_decomposition(posterior=posterior)
            mock_extract_batch_covar.assert_not_called()
            mock_cholesky.assert_called_once()

        # the value returned by the patched Cholesky must have been cached
        self.assertTrue(torch.equal(acqf._baseline_L, baseline_L))
Example #2
0
def _warp_outputs(values: Tensor) -> Tensor:
    """
    Apply the power-transform output warping to `values`.

    Non-positive data is warped with yeo-johnson. Strictly positive data is warped with box-cox, and if the
    result has a standard deviation below 0.5 it is warped again with yeo-johnson.
    :param values: tensor of targets (any shape, treated as a single feature)
    :return: float32 CPU tensor of shape (n, 1) holding the warped targets
    """
    # `power_transform` only accepts 2-D (n_samples, n_features) input
    arr = values.detach().cpu().numpy().reshape(-1, 1)
    if arr.min() <= 0:
        return torch.FloatTensor(
            power_transform(arr / arr.std(), method='yeo-johnson'))
    warped = torch.FloatTensor(
        power_transform(arr / arr.std(), method='box-cox'))
    if warped.std() < 0.5:
        arr = warped.numpy()
        warped = torch.FloatTensor(
            power_transform(arr / arr.std(), method='yeo-johnson'))
    return warped


def _flat_numpy(values: Tensor) -> np.ndarray:
    """Return `values` as a flat NumPy array, detached from autograd and moved to the CPU."""
    return values.detach().cpu().numpy().reshape(-1)


def _save_pair_hist(target: Tensor, pred: Tensor, target_label: str,
                    pred_label: str, title: str, path: str) -> None:
    """Save overlaid density histograms of targets and predictions to `path`."""
    plt.hist(_flat_numpy(target),
             bins=100,
             label=target_label,
             alpha=0.5,
             density=True)
    plt.hist(_flat_numpy(pred),
             bins=100,
             label=pred_label,
             alpha=0.5,
             density=True)
    plt.legend()
    plt.title(title)
    plt.savefig(path)
    plt.close()


def _save_single_hist(values: Tensor, title: str, path: str) -> None:
    """Save a density histogram of `values` (per-point squared errors) to `path`."""
    plt.hist(_flat_numpy(values), bins=100, alpha=0.5, density=True)
    plt.title(title)
    plt.savefig(path)
    plt.close()


def _save_obj_vs_error_scatter(y_true: Tensor, err_true: Tensor,
                               y_pred: Tensor, err_pred: Tensor, split: str,
                               path: str) -> None:
    """Save a scatter plot of objective value vs. recon. error, for true and predicted values."""
    y_true_sorted, order_true = torch.sort(y_true)
    y_pred_sorted, order_pred = torch.sort(y_pred)
    plt.scatter(_flat_numpy(y_true_sorted),
                _flat_numpy(err_true[order_true]),
                label='true',
                marker='+')
    plt.scatter(_flat_numpy(y_pred_sorted),
                _flat_numpy(err_pred[order_pred]),
                label='pred',
                marker='*')
    plt.xlabel(f'y {split} targets')
    plt.ylabel(f'recon. error {split} targets')
    plt.title(f'y_{split} vs. error_{split}')
    plt.legend()
    plt.savefig(path)
    plt.close()


def _save_uncertainty_plot(target: Tensor, pred_mean: Tensor,
                           pred_std: Tensor, target_label: str,
                           pred_label: str, title: str, path: str) -> None:
    """Save targets (sorted ascending) together with predicted mean +/- std at the same points."""
    target_sorted, order = torch.sort(target)
    xs = np.arange(len(order))
    mean_sorted = _flat_numpy(pred_mean[order])
    std_sorted = _flat_numpy(pred_std[order])
    plt.scatter(xs,
                _flat_numpy(target_sorted),
                label=target_label,
                marker='+',
                color='C1',
                s=15)
    # `fmt` (not `marker`) so that no line is drawn between consecutive points
    plt.errorbar(xs,
                 mean_sorted,
                 yerr=std_sorted,
                 fmt='*',
                 alpha=0.05,
                 label=pred_label,
                 color='C0',
                 ecolor='C0')
    plt.scatter(xs, mean_sorted, marker='*', alpha=0.2, s=10, color='C0')
    plt.legend()
    plt.title(title)
    plt.savefig(path)
    plt.close()


def gp_fit_test(x_train: Tensor,
                y_train: Tensor,
                error_train: Tensor,
                x_test: Tensor,
                y_test: Tensor,
                error_test: Tensor,
                gp_obj_model: SingleTaskGP,
                gp_error_model: SingleTaskGP,
                tkwargs: Dict[str, Any],
                gp_test_folder: str,
                obj_out_wp: bool = False,
                err_out_wp: bool = False) -> None:
    """
    1) Estimates mean test error between predicted and the true objective function values.
    2) Estimates mean test error between predicted recon. error by the gp_model and the true recon. error of the vae_model.

    Nothing is returned: the per-point squared errors (already divided by the number of points, so that their sum
    is the MSE), the input/target tensors and a set of diagnostic plots are written to `gp_test_folder`, and the
    MSEs are printed.
    :param x_train: normalised points at which the gps were trained
    :param y_train: objective value function corresponding to x_train that were used as targets of `gp_obj_model`
    :param error_train: reconstruction error value at points x_train that were used as targets of `gp_error_model`
    :param x_test: normalised test points
    :param y_test: objective value function corresponding to x_test
    :param error_test: reconstruction error at test points
    :param gp_obj_model: the gp model trained to predict the black box objective function values
    :param gp_error_model: the gp model trained to predict reconstruction error; if `None`, every
                           reconstruction-error computation and plot is skipped
    :param tkwargs: dict of type and device
    :param gp_test_folder: folder to save test results (created if missing)
    :param obj_out_wp: if the `gp_obj_model` was trained with output warping then need to apply the same transform
    :param err_out_wp: if the `gp_error_model` was trained with output warping then need to apply the same transform
    :return: None
    """
    do_robust = gp_error_model is not None
    os.makedirs(gp_test_folder, exist_ok=True)

    def _out(name: str) -> str:
        return os.path.join(gp_test_folder, name)

    gp_obj_model.eval()
    gp_obj_model.to(tkwargs['device'])
    if do_robust:
        gp_error_model.eval()
        gp_error_model.to(tkwargs['device'])

    with torch.no_grad():
        if obj_out_wp:
            y_train = _warp_outputs(y_train)
            y_test = _warp_outputs(y_test)
        y_train = y_train.reshape(-1).to(**tkwargs)
        y_test = y_test.reshape(-1).to(**tkwargs)

        # The posteriors are evaluated once and reused by every metric and plot below.
        obj_post_train = gp_obj_model.posterior(x_train)
        obj_post_test = gp_obj_model.posterior(x_test)
        y_pred_train = obj_post_train.mean.view(-1)
        y_pred_test = obj_post_test.mean.view(-1)
        y_std_train = obj_post_train.variance.view(-1).sqrt()
        y_std_test = obj_post_test.variance.view(-1).sqrt()

        # per-point squared error divided by n, so that `.sum()` is the MSE
        obj_sq_err_train = (y_pred_train - y_train).pow(2).div(len(y_train))
        obj_sq_err_test = (y_pred_test - y_test).pow(2).div(len(y_test))
        torch.save(obj_sq_err_train, _out('gp_obj_val_model_mse_train.npz'))
        torch.save(obj_sq_err_test, _out('gp_obj_val_model_test.npz'))
        print(
            f'GP training fit on objective value: MSE={obj_sq_err_train.sum().item():.5f}'
        )
        print(
            f'GP testing fit on objective value: MSE={obj_sq_err_test.sum().item():.5f}'
        )

        if do_robust:
            if err_out_wp:
                error_train = _warp_outputs(error_train)
                error_test = _warp_outputs(error_test)
            error_train = error_train.reshape(-1).to(**tkwargs)
            error_test = error_test.reshape(-1).to(**tkwargs)

            err_post_train = gp_error_model.posterior(x_train)
            err_post_test = gp_error_model.posterior(x_test)
            err_pred_train = err_post_train.mean.view(-1)
            err_pred_test = err_post_test.mean.view(-1)
            err_std_train = err_post_train.variance.view(-1).sqrt()
            err_std_test = err_post_test.variance.view(-1).sqrt()

            err_sq_err_train = (error_train - err_pred_train).pow(2).div(
                len(error_train))
            err_sq_err_test = (error_test - err_pred_test).pow(2).div(
                len(error_test))
            torch.save(err_sq_err_train, _out('gp_error_model_mse_train.npz'))
            torch.save(err_sq_err_test, _out('gp_error_model_mse_test.npz'))
            print(
                f'GP training fit on reconstruction errors: MSE={err_sq_err_train.sum().item():.5f}'
            )
            print(
                f'GP testing fit on reconstruction errors: MSE={err_sq_err_test.sum().item():.5f}'
            )
            torch.save(error_test, _out('true_rec_err_z.pt'))
            torch.save(error_train, _out('error_train.pt'))

        torch.save(x_train, _out('train_x.pt'))
        torch.save(x_test, _out('test_x.pt'))
        torch.save(y_train, _out('y_train.pt'))
        # `x_test` is also stored under this second name, as before
        torch.save(x_test, _out('X_test.pt'))
        torch.save(y_test, _out('y_test.pt'))

        # y plots
        _save_pair_hist(y_train, y_pred_train, 'y train', 'y pred',
                        'Training set', _out('gp_obj_train.pdf'))
        _save_single_hist(obj_sq_err_train,
                          'MSE of gp_obj_val model on training set',
                          _out('gp_obj_train_mse.pdf'))
        _save_pair_hist(y_test, y_pred_test, 'y true', 'y pred',
                        'Validation set', _out('gp_obj_test.pdf'))
        _save_single_hist(obj_sq_err_test,
                          'MSE of gp_obj_val model on validation set',
                          _out('gp_obj_test_mse.pdf'))

        if do_robust:
            # error plots
            _save_pair_hist(error_train, err_pred_train, 'error train',
                            'error pred', 'Training set',
                            _out('gp_error_train.pdf'))
            _save_single_hist(err_sq_err_train,
                              'MSE of gp_error model on training set',
                              _out('gp_error_train_mse.pdf'))
            _save_pair_hist(error_test, err_pred_test, 'error true',
                            'error pred', 'Validation set',
                            _out('gp_error_test.pdf'))
            _save_single_hist(err_sq_err_test,
                              'MSE of gp_error model on validation set',
                              _out('gp_error_test_mse.pdf'))

            # y-error plots
            _save_obj_vs_error_scatter(y_train, error_train, y_pred_train,
                                       err_pred_train, 'train',
                                       _out('scatter_obj_error_train.pdf'))
            _save_obj_vs_error_scatter(y_test, error_test, y_pred_test,
                                       err_pred_test, 'test',
                                       _out('scatter_obj_error_test.pdf'))

            # error var plots
            _save_uncertainty_plot(
                error_train, err_pred_train, err_std_train, 'err true',
                'err pred', 'error predictions and uncertainty on train set',
                _out('gp_error_train_uncertainty.pdf'))
            _save_uncertainty_plot(
                error_test, err_pred_test, err_std_test, 'err true',
                'err pred', 'error predictions and uncertainty on test set',
                _out('gp_error_test_uncertainty.pdf'))

        # y var plots
        _save_uncertainty_plot(y_train, y_pred_train, y_std_train, 'y true',
                               'y pred',
                               'y predictions and uncertainty on train set',
                               _out('gp_obj_val_train_uncertainty.pdf'))
        _save_uncertainty_plot(y_test, y_pred_test, y_std_test, 'y true',
                               'y pred',
                               'y predictions and uncertainty on test set',
                               _out('gp_obj_val_test_uncertainty.pdf'))
import torch
from botorch.test_functions import Branin
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from botorch.models.transforms import Standardize
from gpytorch.mlls import ExactMarginalLogLikelihood
from parametric_bandit.discrete_KG import DiscreteKGAlg

torch.manual_seed(0)

# Sample n random inputs and evaluate the noisy Branin test function on them.
n = 10
noise_std = 0.1
function = Branin(noise_std=noise_std)
dim = function.dim
train_X = torch.rand(n, dim)
train_Y = function(train_X).unsqueeze(-1)

# Fit a single-task GP (standardized outcomes) by maximizing the exact
# marginal log-likelihood.
gp = SingleTaskGP(train_X, train_Y, outcome_transform=Standardize(m=1))
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_model(mll)

# Posterior mean and covariance at the training inputs.
train_posterior = gp.posterior(train_X)
mu = train_posterior.mean
Sigma = train_posterior.mvn.covariance_matrix

# Build the discrete-KG algorithm from those moments and print its maximizer.
dkg = DiscreteKGAlg(M=n, error=noise_std**2, mu_0=mu, Sigma_0=Sigma)
print(dkg.find_maximizer())
Example #4
0
    def test_cache_root(self):
        """Compare `qNoisyExpectedImprovement` with and without `cache_root`.

        For every combination of training batch shape, number of outputs and
        dtype, the test asserts that:
        * a sampler built with `collapse_batch_dims=False` is rejected with
          `UnsupportedError` when `cache_root=True`;
        * the cached acquisition function goes through
          `sample_cached_cholesky` exactly once per evaluation while the
          uncached one never does, and both yield matching values and input
          gradients (within a dtype-dependent absolute tolerance);
        * zeroing `_baseline_L` produces exactly one `BotorchWarning` on the
          next evaluation.
        Finally it checks that, with a posterior transform, the posterior
        handed to `_cache_root_decomposition` has the transformed mean.
        """
        sample_cached_path = (
            "botorch.acquisition.cached_cholesky.sample_cached_cholesky")
        raw_state_dict = {
            "likelihood.noise_covar.raw_noise":
            torch.tensor([[0.0895], [0.2594]], dtype=torch.float64),
            "mean_module.constant":
            torch.tensor([[-0.4545], [-0.1285]], dtype=torch.float64),
            "covar_module.raw_outputscale":
            torch.tensor([1.4876, 1.4897], dtype=torch.float64),
            "covar_module.base_kernel.raw_lengthscale":
            torch.tensor([[[-0.7202, -0.2868]], [[-0.8794, -1.2877]]],
                         dtype=torch.float64),
        }
        train_batch_shapes = (torch.Size([]), torch.Size([3]))
        eval_batch_shapes = (torch.Size([]), torch.Size([3]),
                             torch.Size([4, 3]))
        # test batched models (e.g. for MCMC)
        for train_batch_shape, m, dtype in product(
                train_batch_shapes, (1, 2), (torch.float, torch.double)):
            is_batched = len(train_batch_shape) > 0
            tkwargs = {"device": self.device, "dtype": dtype}

            # Adapt the reference hyperparameters to this configuration:
            # keep only the first output when m == 1, broadcast over the
            # training batch shape, then cast to the target dtype/device.
            state_dict = {}
            for name, param in deepcopy(raw_state_dict).items():
                if m == 1:
                    param = param[0]
                if is_batched:
                    param = param.unsqueeze(0).expand(
                        *train_batch_shape, *param.shape)
                state_dict[name] = param.to(**tkwargs)

            objective = (GenericMCObjective(lambda Y, X: Y.sum(dim=-1))
                         if m == 2 else None)
            all_close_kwargs = {
                "atol": 1e-1 if dtype == torch.float else 1e-4,
                "rtol": 0.0,
            }

            torch.manual_seed(1234)
            train_X = torch.rand(*train_batch_shape, 3, 2, **tkwargs)
            noise = torch.randn(*train_batch_shape, 3, 2, **tkwargs)
            train_Y = standardize(
                (torch.sin(train_X * 2 * pi) + noise)[..., :m])
            model = SingleTaskGP(train_X, train_Y)
            X_baseline = train_X[0] if is_batched else train_X
            model.load_state_dict(state_dict, strict=False)

            # constructor arguments shared by every acquisition function below
            shared_kwargs = {
                "model": model,
                "X_baseline": X_baseline,
                "objective": objective,
                "prune_baseline": False,
            }

            # test sampler with collapse_batch_dims=False
            sampler = IIDNormalSampler(5, seed=0, collapse_batch_dims=False)
            with self.assertRaises(UnsupportedError):
                qNoisyExpectedImprovement(
                    sampler=sampler, cache_root=True, **shared_kwargs)

            sampler = IIDNormalSampler(5, seed=0)
            torch.manual_seed(0)
            acqf = qNoisyExpectedImprovement(
                sampler=sampler, cache_root=True, **shared_kwargs)

            # the uncached twin starts from the very same base samples
            orig_base_samples = acqf.base_sampler.base_samples.detach().clone()
            sampler2 = IIDNormalSampler(5, seed=0)
            sampler2.base_samples = orig_base_samples
            torch.manual_seed(0)
            acqf_no_cache = qNoisyExpectedImprovement(
                sampler=sampler2, cache_root=False, **shared_kwargs)

            for q, batch_shape in product((1, 3), eval_batch_shapes):
                perturbation = torch.randn(*batch_shape, q, 2, **tkwargs)
                test_X = (0.3 + 0.05 * perturbation).requires_grad_(True)

                # cached path
                with mock.patch(
                        sample_cached_path,
                        wraps=sample_cached_cholesky) as mock_sample_cached:
                    torch.manual_seed(0)
                    val = acqf(test_X)
                    mock_sample_cached.assert_called_once()
                val.sum().backward()
                base_samples = acqf.sampler.base_samples.detach().clone()
                X_grad = test_X.grad.clone()

                # uncached path on an independent copy of the same inputs
                test_X2 = test_X.detach().clone().requires_grad_(True)
                acqf_no_cache.sampler.base_samples = base_samples
                with mock.patch(
                        sample_cached_path,
                        wraps=sample_cached_cholesky) as mock_sample_cached:
                    torch.manual_seed(0)
                    val2 = acqf_no_cache(test_X2)
                mock_sample_cached.assert_not_called()

                self.assertTrue(torch.allclose(val, val2, **all_close_kwargs))
                val2.sum().backward()
                self.assertTrue(
                    torch.allclose(X_grad, test_X2.grad, **all_close_kwargs))

            # test we fall back to standard sampling for
            # ill-conditioned covariances
            acqf._baseline_L = torch.zeros_like(acqf._baseline_L)
            with warnings.catch_warnings(
                    record=True) as ws, settings.debug(True), torch.no_grad():
                acqf(test_X)
            self.assertEqual(len(ws), 1)
            self.assertTrue(issubclass(ws[-1].category, BotorchWarning))

        # test w/ posterior transform
        X_baseline = torch.rand(2, 1)
        model = SingleTaskGP(X_baseline, torch.randn(2, 1))
        pt = ScalarizedPosteriorTransform(weights=torch.tensor([-1]))
        with mock.patch.object(
                qNoisyExpectedImprovement,
                "_cache_root_decomposition",
        ) as mock_cache_root:
            acqf = qNoisyExpectedImprovement(
                model=model,
                X_baseline=X_baseline,
                sampler=IIDNormalSampler(1),
                posterior_transform=pt,
                prune_baseline=False,
                cache_root=True,
            )
            tf_post = model.posterior(X_baseline, posterior_transform=pt)
            cached_posterior = mock_cache_root.call_args[-1]["posterior"]
            self.assertTrue(
                torch.allclose(tf_post.mean, cached_posterior.mean))
        y_new_raw = noisy_nonlin_fcn(x_new_raw, noise_std=noise_std)

        # Plot the model
        fig = plt.figure(figsize=(8, 6))
        gs = fig.add_gridspec(2, 1, height_ratios=[2, 1])
        ax_gp = fig.add_subplot(gs[0])
        ax_acq = fig.add_subplot(gs[1])
        X_test = to.linspace(x_min, x_max, 501)
        X_test_raw = to.linspace(x_min_raw, x_max_raw, 501)

        with to.no_grad():
            # Get the observations
            y_test_raw = noisy_nonlin_fcn(X_test_raw, noise_std=noise_std)

            # Get the posterior
            posterior = gp.posterior(X_test)
            mean = posterior.mean
            mean_raw = mean * y_test_raw.std() + y_test_raw.mean()
            lower, upper = posterior.mvn.confidence_region()
            lower = lower * y_test_raw.std() + y_test_raw.mean()
            upper = upper * y_test_raw.std() + y_test_raw.mean()

            ax_gp.plot(X_test_raw.numpy(),
                       y_test_raw.numpy(),
                       "k--",
                       label="f(x)")

            ax_gp.plot(X_test_raw.numpy(),
                       mean_raw.numpy(),
                       "b-",
                       lw=2,
def render_singletask_gp(
    ax: [plt.Axes, Axes3D, Sequence[plt.Axes]],
    data_x: to.Tensor,
    data_y: to.Tensor,
    idcs_sel: list,
    data_x_min: to.Tensor = None,
    data_x_max: to.Tensor = None,
    x_label: str = '',
    y_label: str = '',
    z_label: str = '',
    min_gp_obsnoise: float = None,
    resolution: int = 201,
    num_stds: int = 2,
    alpha: float = 0.3,
    color: str = None,
    curve_label: str = 'mean',
    heatmap_cmap: colors.Colormap = None,
    show_legend_posterior: bool = True,
    show_legend_std: bool = False,
    show_legend_data: bool = True,
    legend_data_cmap: colors.Colormap = None,
    colorbar_label: str = None,
    title: str = None,
    render3D: bool = True,
) -> plt.Figure:
    """
    Fit a `SingleTaskGP` to the input data and plot the posterior mean and std as well as the data points.

    The inputs are normalized to [0, 1] and the outputs are standardized before fitting; all plotted quantities are
    projected back to the original scale. The kind of plot is inferred from the number of selected input dimensions:

    - 1 selected input dimension: curve of the mean with a shaded std band on a single axis,
    - 2 selected input dimensions and `render3D=True`: surfaces of the mean and the std bounds on a 3D axis,
    - 2 selected input dimensions and `render3D=False`: two heat maps (mean and std) with separate color bars.

    In all cases the argmax of the posterior mean (found via `optimize_acqf` on `PosteriorMean`) is marked.

    .. note::
        If you want to have a tight layout, it is best to pass axes of a figure with `tight_layout=True` or
        `constrained_layout=True`.

    :param ax: axis of the figure to plot on; for the 2-dim heat map plot provide a sequence of 4 axes
               (mean heat map, its color bar, std heat map, its color bar)
    :param data_x: input data of shape num_samples x dim_input, the columns `idcs_sel` are plotted on the x-axis
    :param data_y: 2-dim tensor of output data to process and plot on the y-axis
    :param idcs_sel: selected indices (columns) of the input data
    :param data_x_min: explicit minimum value for the evaluation grid, by default this value is extracted from `data_x`
    :param data_x_max: explicit maximum value for the evaluation grid, by default this value is extracted from `data_x`
    :param x_label: label for x-axis (2-dim plots only)
    :param y_label: label for y-axis (2-dim plots only)
    :param z_label: label for z-axis (3D plot only)
    :param min_gp_obsnoise: set a minimal noise value (normalized) for the GP, if `None` no additional constraint
                            is registered on the GP's noise
    :param resolution: number of samples for the input (corresponds to x-axis resolution of the plot)
    :param num_stds: number of standard deviations to plot around the mean
    :param alpha: transparency (alpha-value) for the std area
    :param color: color (e.g. 'k' for black), `None` invokes the default behavior (1D plot only)
    :param curve_label: label for the mean curve; currently not read by this function
    :param heatmap_cmap: color map forwarded to `render_heatmap()` (2D plot only), `None` to use Pyrado's default
    :param show_legend_posterior: flag if the legend of the axis should be shown (1D plot only)
    :param show_legend_std: flag if a legend entry for the std area should be printed; currently not read by this
                            function
    :param show_legend_data: flag if a legend entry for the individual data points should be printed
    :param legend_data_cmap: color map for the sampled points, default is 'binary'
    :param colorbar_label: label for the color bar (2D plot only)
    :param title: title displayed above the figure; currently not read by this function
    :param render3D: use 3D rendering if possible
    :return: handle to the current figure (as returned by `plt.gcf()`)
    :raises pyrado.ShapeErr: if `data_x` or `data_y` is not 2-dim, or if the heat map plot does not receive 4 axes
    :raises pyrado.ValueErr: if the number of selected input dimensions is neither 1 nor 2
    """
    if data_x.ndim != 2:
        raise pyrado.ShapeErr(
            msg=
            "The GP's input data needs to be of shape num_samples x dim_input!"
        )
    data_x = data_x[:, idcs_sel]  # forget the rest
    dim_x = data_x.shape[1]  # samples are along axis 0

    if data_y.ndim != 2:
        raise pyrado.ShapeErr(given=data_y,
                              expected_match=to.Size([data_x.shape[0], 1]))

    if legend_data_cmap is None:
        legend_data_cmap = plt.get_cmap('binary')

    # Project to normalized input and standardized output
    if data_x_min is None or data_x_max is None:
        data_x_min, data_x_max = to.min(data_x, dim=0)[0], to.max(data_x,
                                                                  dim=0)[0]
    data_y_mean, data_y_std = to.mean(data_y, dim=0), to.std(data_y, dim=0)
    data_x = (data_x - data_x_min) / (data_x_max - data_x_min)
    data_y = (data_y - data_y_mean) / data_y_std

    # Create and fit the GP model
    gp = SingleTaskGP(data_x, data_y)
    if min_gp_obsnoise is not None:
        gp.likelihood.noise_covar.register_constraint(
            'raw_noise', GreaterThan(min_gp_obsnoise))
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    mll.train()
    fit_gpytorch_model(mll)
    print_cbt('Fitted the SingleTaskGP.', 'g')

    # Find the maximizer of the posterior mean on the normalized input domain [0, 1]^dim_x
    argmax_pmean_norm, argmax_pmean_val_stdzed = optimize_acqf(
        acq_function=PosteriorMean(gp),
        bounds=to.stack([to.zeros(dim_x), to.ones(dim_x)]),
        q=1,
        num_restarts=500,
        raw_samples=1000)
    # Project back
    argmax_posterior = argmax_pmean_norm * (data_x_max -
                                            data_x_min) + data_x_min
    argmax_pmean_val = argmax_pmean_val_stdzed * data_y_std + data_y_mean
    print_cbt(
        f'Converged to argmax of the posterior mean: {argmax_posterior.numpy()}',
        'g')

    mll.eval()
    gp.eval()

    if dim_x == 1:
        # Evaluation grid
        x_grid = np.linspace(min(data_x),
                             max(data_x),
                             resolution,
                             endpoint=True).flatten()
        x_grid = to.from_numpy(x_grid)

        # Mean and standard deviation of the surrogate model
        posterior = gp.posterior(x_grid)
        mean = posterior.mean.detach().flatten()
        std = to.sqrt(posterior.variance.detach()).flatten()

        # Project back from normalized input and standardized output
        x_grid = x_grid * (data_x_max - data_x_min) + data_x_min
        data_x = data_x * (data_x_max - data_x_min) + data_x_min
        data_y = data_y * data_y_std + data_y_mean
        mean = mean * data_y_std + data_y_mean
        std *= data_y_std  # double-checked with posterior.mvn.confidence_region()

        # Plot the curve. The band is drawn on the given axis (not on pyplot's current axes) such that it always
        # ends up on the same axis as the mean curve.
        ax.fill_between(x_grid.numpy(),
                        mean.numpy() - num_stds * std.numpy(),
                        mean.numpy() + num_stds * std.numpy(),
                        alpha=alpha,
                        color=color)
        ax.plot(x_grid.numpy(), mean.numpy(), color=color)

        # Plot the queried data points, colored by their query index
        # (the builtin int replaces the np.int alias which was removed from NumPy)
        scat_plot = ax.scatter(data_x.numpy().flatten(),
                               data_y.numpy().flatten(),
                               marker='o',
                               c=np.arange(data_x.shape[0], dtype=int),
                               cmap=legend_data_cmap)

        if show_legend_data:
            scat_legend = ax.legend(
                *scat_plot.legend_elements(fmt='{x:.0f}'),  # integer formatter
                bbox_to_anchor=(0., 1.1, 1., -0.1),
                title='query points',
                ncol=data_x.shape[0],
                loc='upper center',
                mode='expand',
                borderaxespad=0.,
                handletextpad=-0.5)
            ax.add_artist(scat_legend)

        # Plot the argmax of the posterior mean
        ax.axvline(argmax_posterior.item(),
                   c='darkorange',
                   lw=1.5,
                   label='argmax')

        if show_legend_posterior:
            ax.add_artist(ax.legend(loc='lower right'))

    elif dim_x == 2:
        # Create mesh grid matrices from x and y vectors (on the normalized input domain)
        x0_grid = to.linspace(0, 1, resolution)
        x1_grid = to.linspace(0, 1, resolution)
        x0_mesh, x1_mesh = to.meshgrid([x0_grid, x1_grid])
        x0_mesh, x1_mesh = x0_mesh.t(), x1_mesh.t(
        )  # transpose not necessary but makes identical mesh as np.meshgrid

        # Mean and standard deviation of the surrogate model
        x_test = to.stack([
            x0_mesh.reshape(resolution**2, 1),
            x1_mesh.reshape(resolution**2, 1)
        ], -1).squeeze(1)
        posterior = gp.posterior(
            x_test)  # identical to  gp.likelihood(gp(x_test))
        mean = posterior.mean.detach().reshape(resolution, resolution)
        std = to.sqrt(posterior.variance.detach()).reshape(
            resolution, resolution)

        # Project back from normalized input and standardized output
        data_x = data_x * (data_x_max - data_x_min) + data_x_min
        data_y = data_y * data_y_std + data_y_mean
        mean_raw = mean * data_y_std + data_y_mean
        std_raw = std * data_y_std

        if render3D:
            # Project back from normalized input and standardized output (custom for 3D)
            x0_mesh = x0_mesh * (data_x_max[0] - data_x_min[0]) + data_x_min[0]
            x1_mesh = x1_mesh * (data_x_max[1] - data_x_min[1]) + data_x_min[1]
            lower = mean_raw - num_stds * std_raw
            upper = mean_raw + num_stds * std_raw

            # Plot a 2D surface in 3D
            ax.plot_surface(x0_mesh.numpy(), x1_mesh.numpy(), mean_raw.numpy())
            ax.plot_surface(x0_mesh.numpy(),
                            x1_mesh.numpy(),
                            lower.numpy(),
                            color='r',
                            alpha=alpha)
            ax.plot_surface(x0_mesh.numpy(),
                            x1_mesh.numpy(),
                            upper.numpy(),
                            color='r',
                            alpha=alpha)
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)
            ax.set_zlabel(z_label)

            # Plot the queried data points, colored by their query index
            scat_plot = ax.scatter(data_x[:, 0].numpy(),
                                   data_x[:, 1].numpy(),
                                   data_y.numpy(),
                                   marker='o',
                                   c=np.arange(data_x.shape[0], dtype=int),
                                   cmap=legend_data_cmap)

            if show_legend_data:
                scat_legend = ax.legend(
                    *scat_plot.legend_elements(
                        fmt='{x:.0f}'),  # integer formatter
                    bbox_to_anchor=(0.05, 1.1, 0.95, -0.1),
                    loc='upper center',
                    ncol=data_x.shape[0],
                    mode='expand',
                    borderaxespad=0.,
                    handletextpad=-0.5)
                ax.add_artist(scat_legend)

            # Plot the argmax of the posterior mean
            x, y = argmax_posterior[0, 0], argmax_posterior[0, 1]
            ax.scatter(x,
                       y,
                       argmax_pmean_val,
                       c='darkorange',
                       marker='*',
                       s=60)

        else:
            if len(ax) != 4:
                raise pyrado.ShapeErr(
                    msg='Provide 4 axes! 2 heat maps and 2 color bars.')

            # Project back normalized input and standardized output (custom for 2D)
            x0_grid_raw = x0_grid * (data_x_max[0] -
                                     data_x_min[0]) + data_x_min[0]
            x1_grid_raw = x1_grid * (data_x_max[1] -
                                     data_x_min[1]) + data_x_min[1]

            # Plot a 2D image
            df_mean = pd.DataFrame(mean_raw.numpy(),
                                   columns=x0_grid_raw.numpy(),
                                   index=x1_grid_raw.numpy())
            render_heatmap(df_mean,
                           ax_hm=ax[0],
                           ax_cb=ax[1],
                           x_label=x_label,
                           y_label=y_label,
                           annotate=False,
                           fig_canvas_title='Returns',
                           tick_label_prec=2,
                           add_sep_colorbar=True,
                           cmap=heatmap_cmap,
                           colorbar_label=colorbar_label,
                           num_major_ticks_hm=3,
                           num_major_ticks_cb=2,
                           colorbar_orientation='horizontal')

            df_std = pd.DataFrame(std_raw.numpy(),
                                  columns=x0_grid_raw.numpy(),
                                  index=x1_grid_raw.numpy())
            render_heatmap(
                df_std,
                ax_hm=ax[2],
                ax_cb=ax[3],
                x_label=x_label,
                y_label=y_label,
                annotate=False,
                fig_canvas_title='Standard Deviations',
                tick_label_prec=2,
                add_sep_colorbar=True,
                cmap=heatmap_cmap,
                colorbar_label=colorbar_label,
                num_major_ticks_hm=3,
                num_major_ticks_cb=2,
                colorbar_orientation='horizontal',
                norm=colors.Normalize())  # explicitly instantiate a new norm

            # Plot the queried data points on both heat maps, colored by their query index
            for i in [0, 2]:
                scat_plot = ax[i].scatter(data_x[:, 0].numpy(),
                                          data_x[:, 1].numpy(),
                                          marker='o',
                                          s=15,
                                          c=np.arange(data_x.shape[0],
                                                      dtype=int),
                                          cmap=legend_data_cmap)

                if show_legend_data:
                    scat_legend = ax[i].legend(
                        *scat_plot.legend_elements(
                            fmt='{x:.0f}'),  # integer formatter
                        bbox_to_anchor=(0., 1.1, 1., 0.05),
                        loc='upper center',
                        ncol=data_x.shape[0],
                        mode='expand',
                        borderaxespad=0.,
                        handletextpad=-0.5)
                    ax[i].add_artist(scat_legend)

            # Plot the argmax of the posterior mean on both heat maps
            ax[0].scatter(argmax_posterior[0, 0],
                          argmax_posterior[0, 1],
                          c='darkorange',
                          marker='*',
                          s=60)
            ax[2].scatter(argmax_posterior[0, 0],
                          argmax_posterior[0, 1],
                          c='darkorange',
                          marker='*',
                          s=60)

    else:
        raise pyrado.ValueErr(msg='Can only plot 1-dim or 2-dim data!')

    return plt.gcf()
Example #7
0
class Experiment:
    """
    The class for running experiments.

    An instance holds the problem function, the evaluated data (`X`, `Y`), the fitted
    GP model and the optimizers used to pick the next candidate. Every key of
    `attr_list`, as well as every keyword argument passed to `__init__`, is stored as
    an instance attribute.
    """

    # dict of expected attributes and default values
    attr_list = {
        "dim_w": 1,
        "num_fantasies": 10,
        "num_restarts": 20,
        "raw_multiplier": 25,
        "alpha": 0.7,
        "q": 1,
        "num_repetitions": 10,
        "verbose": False,
        "maxiter": 1000,
        "CVaR": False,
        "random_sampling": False,
        "expectation": False,
        "dtype": torch.float32,
        "device": torch.device("cpu"),
        "apx": True,
        "apx_cvar": False,
        "tts_apx_cvar": False,
        "disc": True,
        "tts_frequency": 10,
        "num_inner_restarts": 10,
        "inner_raw_multiplier": 5,
        "weights": None,
        "fix_samples": True,
        "one_shot": False,
        "low_fantasies": 4,
        "random_w": False,
    }

    def __init__(self, function: str, **kwargs) -> None:
        """
        The experiment settings:
        :param function: The problem function to be used.
        :param noise_std: standard deviation of the function evaluation noise.
        :param negate: Forwarded to `function_picker`. Defaults to False.
        :param dim_w: Dimension of the w component.
        :param num_samples: Number of samples of w to be used to evaluate C/VaR.
        :param w_samples: option to explicitly specify the samples. If given,
            num_samples is ignored. One of these is necessary!
        :param num_fantasies: Number of fantasy models to construct in evaluating rhoKG.
        :param num_restarts: Number of random restarts for optimization of rhoKG.
        :param raw_multiplier: Raw_samples = num_restarts * raw_multiplier
        :param alpha: The risk level of C/VaR.
        :param q: Number of parallel solutions to evaluate. Think qKG.
        :param num_repetitions: Number of posterior samples used for E[rho[F]]
        :param verbose: Print more stuff, such as current best value.
        :param maxiter: (Maximum) number of iterations allowed for L-BFGS-B algorithm.
        :param CVaR: If true, use CVaR instead of VaR, i.e. CVaRKG. The default is VaR.
        :param random_sampling: If true, we will use random sampling - no KG.
        :param expectation: If true, we are running BQO optimization.
        :param dtype: The tensor dtype for the experiment
        :param device: The device to use. Defaults to CPU.
        :param apx: If True, the rhoKGapx algorithm is used.
        :param apx_cvar: If True, we use ApxCVaRKG. Overwrites other options!
        :param tts_apx_cvar: If True, we use TTSApxCVaRKG. Overwrites other options!
        :param disc: If True, the optimization of acqf is done with w restricted to
            the set w_samples
        :param tts_frequency: The frequency of two-time-scale optimization.
            If 1, we do normal nested optimization. Default is 10.
        :param num_inner_restarts: Inner restarts for nested optimization
        :param inner_raw_multiplier: raw multiplier for nested optimization
        :param weights: If w_samples are not uniformly distributed, these are the sample
            weights, summing up to 1, i.e. probability mass function.
            A 1-dim tensor of size num_samples
        :param fix_samples: When W is continuous, this determines whether the samples
            are redrawn at each call to rhoKG or fixed to a random realization.
            If w_samples are specified, this gets overwritten.
        :param one_shot: Uses one-shot optimization.
            DO NOT USE unless you know what you're doing.
        :param low_fantasies: see AbsKG.change_num_fantasies for details. This reduces
            the number of fantasies used during raw sample evaluation to reduce the
            computational cost.
        :param random_w: If this is True, the w component of the candidate is fixed to
            a random realization instead of being optimized. This is only for
            presenting a comparison in the paper, and should not be used.
        :raises ValueError: If neither `w_samples` nor `num_samples` is given, or if
            both `apx_cvar` and `tts_apx_cvar` are True.
        """
        if "seed" in kwargs:
            warnings.warn("Seed should be set outside. It will be ignored!")
        self.function = function_picker(
            function,
            noise_std=kwargs.get("noise_std"),
            # kwargs is a dict, so the option must be read with .get();
            # getattr() on the dict would always fall back to the default.
            negate=kwargs.get("negate", False),
        )
        self.dim = self.function.dim
        # read the attributes with default values
        # set the defaults first, then overwrite.
        # this lets us store everything passed with kwargs
        for key, default in self.attr_list.items():
            setattr(self, key, default)
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.dim_x = self.dim - self.dim_w
        if kwargs.get("w_samples") is not None:
            self.w_samples = (
                kwargs["w_samples"]
                .reshape(-1, self.dim_w)
                .to(dtype=self.dtype, device=self.device)
            )
            self.num_samples = self.w_samples.size(0)
        elif "num_samples" in kwargs:
            self.num_samples = kwargs["num_samples"]
            self.w_samples = None
            warnings.warn("w_samples is None and will be randomized at each iteration")
        else:
            raise ValueError("Either num_samples or w_samples must be specified!")
        if self.expectation:
            self.num_repetitions = 0
        if self.weights is not None:
            self.weights = self.weights.to(self.w_samples)
        self.X = torch.empty(0, self.dim).to(dtype=self.dtype, device=self.device)
        self.Y = torch.empty(0, 1).to(dtype=self.dtype, device=self.device)
        self.model = None
        # NOTE(review): when low_fantasies is not passed, this replaces the attr_list
        # default (4) with None - confirm that this is intended.
        self.low_fantasies = kwargs.get("low_fantasies", None)
        if self.apx_cvar and self.tts_apx_cvar:
            raise ValueError(
                "apx_cvar and tts_apx_cvar cannot be true at the same time!"
            )

        if self.tts_apx_cvar:
            inner_optimizer = InnerApxCVaROptimizer
        else:
            inner_optimizer = InnerOptimizer

        self.inner_optimizer = inner_optimizer(
            num_restarts=self.num_inner_restarts,
            raw_multiplier=self.inner_raw_multiplier,
            dim_x=self.dim_x,
            maxiter=self.maxiter,
            inequality_constraints=self.function.inequality_constraints,
            dtype=self.dtype,
            device=self.device,
        )
        if self.apx_cvar:
            optimizer = ApxCVaROptimizer
        elif self.one_shot:
            optimizer = OneShotOptimizer
        else:
            optimizer = Optimizer
        self.optimizer = optimizer(
            num_restarts=self.num_restarts,
            raw_multiplier=self.raw_multiplier,
            num_fantasies=self.num_fantasies,
            dim=self.dim,
            dim_x=self.dim_x,
            q=self.q,
            maxiter=self.maxiter,
            inequality_constraints=self.function.inequality_constraints,
            low_fantasies=self.low_fantasies,
            dtype=self.dtype,
            device=self.device,
        )
        # fixed_samples mirrors w_samples (which may be None) when fix_samples is set
        if self.fix_samples:
            self.fixed_samples = self.w_samples
        else:
            self.fixed_samples = None

        self.passed = False  # error handling
        self.fit_count = 0

    def change_dtype_device(
        self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None
    ) -> None:
        r"""
        This changes the dtype and device of all experiment tensors, and refits the GP
        model. Does nothing if both arguments are None.
        :param dtype: The torch.dtype to use. If None, the current dtype is kept.
        :param device: The device to use. If None, the current device is kept.
        """
        if dtype is None and device is None:
            return None
        dtype = dtype or self.dtype
        device = device or self.device
        # iterate over a snapshot since the attributes are re-assigned in the loop
        for key, value in list(vars(self).items()):
            if isinstance(value, Tensor):
                setattr(self, key, value.to(dtype=dtype, device=device))
        self.dtype = dtype
        self.device = torch.device(device)
        # only refit if there is data to fit on
        if self.X.numel():
            self.fit_gp()

    def initialize_gp(
        self, init_samples: Optional[Tensor] = None, n: Optional[int] = None
    ) -> None:
        """
        Initialize the gp with the given set of samples or number of samples.
        If none given, then defaults to n = 2 dim + 2 random samples.
        :param init_samples: Tensor of samples to initialize with. Overrides n.
        :param n: number of samples to initialize with
        """
        if init_samples is not None:
            self.X = init_samples.reshape(-1, self.dim).to(
                dtype=self.dtype, device=self.device
            )
        else:
            self.X = constrained_rand(
                (n or 2 * self.dim + 2, self.dim),
                self.function.inequality_constraints,
                dtype=self.dtype,
                device=self.device,
            )
        self.Y = self.function(self.X)
        self.fit_gp()

    def fit_gp(self) -> None:
        """
        Re-fits the GP using the most up to date data.

        After fitting, a dummy posterior evaluation is performed. If it raises a
        RuntimeError, small Gaussian noise is added to `Y` and the fit is retried
        (up to 5 retries); afterwards the error is re-raised.
        """
        noise_prior = GammaPrior(1.1, 0.5)
        noise_prior_mode = (noise_prior.concentration - 1) / noise_prior.rate
        likelihood = GaussianLikelihood(
            noise_prior=noise_prior,
            batch_shape=[],
            noise_constraint=GreaterThan(
                0.0001,  # minimum observation noise assumed in the GP model
                transform=None,
                initial_value=noise_prior_mode,
            ),
        )

        self.model = SingleTaskGP(
            self.X, self.Y, likelihood, outcome_transform=Standardize(m=1)
        )
        mll = ExactMarginalLogLikelihood(self.model.likelihood, self.model)
        fit_gpytorch_model(mll)

        # dummy computation to be safe with gp fit
        try:
            dummy = torch.rand(
                (1, self.q, self.dim), dtype=self.dtype, device=self.device
            )
            _ = self.model.posterior(dummy).mean
        except RuntimeError:
            if self.fit_count < 5:
                self.fit_count += 1
                # perturb the observations slightly and try again
                self.Y = self.Y + torch.randn_like(self.Y) * 0.001
                self.fit_gp()
            else:
                raise
        self.fit_count = 0
        self.passed = False

    def current_best(
        self, past_only: bool = False, inner_seed: Optional[int] = None
    ) -> Tuple[Tensor, Tensor]:
        """
        Solve the inner optimization problem to return the current optimum
        :param past_only: If true, maximize over previously evaluated x only.
        :param inner_seed: Used for sampling randomness in InnerRho
        :return: Current best solution and value
        """
        if self.w_samples is None:
            # no fixed samples given, draw a fresh uniform set
            w_samples = torch.rand(
                self.num_samples, self.dim_w, dtype=self.dtype, device=self.device
            )
        else:
            w_samples = self.w_samples
        # forward all instance attributes, except that w_samples is replaced by the
        # (possibly freshly drawn) local set
        inner_rho = InnerRho(
            inner_seed=inner_seed,
            **{key: val for key, val in vars(self).items() if key != "w_samples"},
            w_samples=w_samples
        )
        if past_only:
            past_x = self.X[:, : self.dim_x]
            with torch.no_grad():
                values = inner_rho(past_x)
            best = torch.argmax(values)
            current_best_sol = past_x[best]
            current_best_value = -values[best]
        else:
            current_best_sol, current_best_value = self.optimizer.optimize_inner(
                inner_rho
            )
        if self.verbose:
            print(
                "Current best solution, value: ", current_best_sol, current_best_value
            )
        return current_best_sol, current_best_value

    def one_iteration(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Do a single iteration of the algorithm: find the current best, pick the next
        candidate, evaluate the function there, append the result to the data and
        refit the GP.
        :param kwargs: ignored
        :return: current best solution & value, kg value and candidate (next sample)
        """
        iteration_start = time()
        inner_seed = int(torch.randint(100000, (1,)))
        self.optimizer.new_iteration()
        self.inner_optimizer.new_iteration()
        current_best_sol, current_best_value = self.current_best(
            past_only=self.apx, inner_seed=inner_seed
        )

        if self.random_sampling:
            candidate = constrained_rand(
                (self.q, self.dim),
                self.function.inequality_constraints,
                dtype=self.dtype,
                device=self.device,
            )
            value = torch.tensor([0]).to(candidate)
        else:
            # inner_optimizer is passed explicitly (as its optimize method) to some
            # acquisition functions, so it is excluded from the forwarded attributes
            attrs_wo_inner = {
                key: val for key, val in vars(self).items() if key != "inner_optimizer"
            }
            if self.apx_cvar:
                acqf = ApxCVaRKG(current_best_rho=current_best_value, **vars(self))
            elif self.tts_apx_cvar:
                acqf = TTSApxCVaRKG(
                    current_best_rho=current_best_value,
                    inner_optimizer=self.inner_optimizer.optimize,
                    **attrs_wo_inner
                )
            elif self.one_shot:
                acqf = OneShotrhoKG(
                    current_best_rho=current_best_value,
                    inner_seed=inner_seed,
                    **vars(self)
                )
            elif self.apx:
                acqf = rhoKGapx(
                    current_best_rho=current_best_value,
                    past_x=self.X[:, : self.dim_x],
                    inner_seed=inner_seed,
                    **vars(self)
                )
            else:
                acqf = rhoKG(
                    inner_optimizer=self.inner_optimizer.optimize,
                    current_best_rho=current_best_value,
                    inner_seed=inner_seed,
                    **attrs_wo_inner
                )
            if self.disc:
                candidate, value = self.optimizer.optimize_outer(
                    acqf, self.w_samples, random_w=self.random_w
                )
            else:
                candidate, value = self.optimizer.optimize_outer(
                    acqf, random_w=self.random_w
                )
        candidate = candidate.detach()
        value = value.detach()

        if self.verbose:
            print("Candidate: ", candidate, " KG value: ", value)

        iteration_end = time()
        print("Iteration completed in %s" % (iteration_end - iteration_start))

        if self.one_shot or self.apx_cvar:
            # only the leading q * dim entries of the solution are the candidate point
            candidate_point = candidate[..., : self.q * self.dim].reshape(
                self.q, self.dim
            )
        else:
            candidate_point = candidate.reshape(self.q, self.dim)

        observation = self.function(candidate_point)
        # update the model input data for refitting
        self.X = torch.cat((self.X, candidate_point), dim=0)
        self.Y = torch.cat((self.Y, observation), dim=0)

        # noting that X and Y are updated
        self.passed = True
        # construct and fit the GP
        self.fit_gp()
        # noting that gp fit successfully updated
        self.passed = False

        return current_best_sol, current_best_value, value, candidate_point
Example #8
0
  train_y = torch.cat([train_y, candidate_y])
  model = model.condition_on_observations(X=candidate_x, Y=candidate_y)

  # Train GP...
  mll = ExactMarginalLogLikelihood(model.likelihood, model)
  fit_gpytorch_model(mll)

  # Plotting...
  model.eval()

  fig, ax = plt.subplots(1, 1, figsize=(6, 4))
  plt.title(f"Bayesian Opt. without derivatives, Iteration {it}")
  test_x = torch.linspace(-1, 1, steps=100)

  with torch.no_grad():
    posterior = model.posterior(test_x)
    # these are 2 std devs from mean
    lower, upper = posterior.mvn.confidence_region()

    ax.plot(test_x.cpu().numpy(),
            obj(test_x).cpu().numpy(),
            'r--',
            label="true, noiseless objective")
    ax.plot(train_x.cpu().numpy(), train_y.cpu().numpy(), 'k*', alpha=0.1, label="observations")
    ax.plot(candidate_x.cpu().numpy(), candidate_y.cpu().numpy(), 'r*', label="candidate point")
    ax.plot(test_x.cpu().numpy(), posterior.mean.cpu().numpy(), 'b', label="GP posterior")
    ax.fill_between(test_x.cpu().numpy(), lower.cpu().numpy(), upper.cpu().numpy(), alpha=0.5)

  plt.legend(loc="lower left")
  plt.tight_layout()
  plt.savefig(f"/tmp/vanilla_bo_{it}.jpg")