def _setUp(self, double=False):
    """Build a toy 1-d regression fixture and fit two small GPs on it.

    Creates 10 evenly spaced inputs on [0, 1] (shape ``(10, 1)``) with targets
    ``sin(2 * pi * x) + NOISE``. It then fits a ``SingleTaskGP`` and a
    ``FixedNoiseGP`` (given the constant observation variance ``0.1 ** 2``).
    Each fit runs for at most 5 optimizer iterations.

    Args:
        double: if True every tensor is created in ``torch.double``,
            otherwise in ``torch.float``.

    Sets on ``self``: ``train_x``, ``train_y``, ``train_yvar``, ``bounds``,
    ``model_st``, ``mll_st``, ``model_fn`` and ``mll_fn``. Reads
    ``self.device``.
    """
    dtype = torch.double if double else torch.float
    train_x = torch.linspace(0, 1, 10, device=self.device, dtype=dtype).unsqueeze(-1)
    train_y = torch.sin(train_x * (2 * math.pi))
    # The variance must share the data's dtype. Without an explicit dtype it
    # was always float32 and mismatched the inputs when ``double=True``.
    train_yvar = torch.tensor(0.1**2, device=self.device, dtype=dtype)
    noise = torch.tensor(NOISE, device=self.device, dtype=dtype)
    self.train_x = train_x
    self.train_y = train_y + noise
    self.train_yvar = train_yvar
    self.bounds = torch.tensor([[0.0], [1.0]], device=self.device, dtype=dtype)

    model_st = SingleTaskGP(self.train_x, self.train_y)
    self.model_st = model_st.to(device=self.device, dtype=dtype)
    self.mll_st = ExactMarginalLogLikelihood(self.model_st.likelihood, self.model_st)
    # A 5-iteration fit usually stops early; silence the resulting warning.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=OptimizationWarning)
        self.mll_st = fit_gpytorch_model(self.mll_st, options={"maxiter": 5}, max_retries=1)

    model_fn = FixedNoiseGP(self.train_x, self.train_y, self.train_yvar.expand_as(self.train_y))
    self.model_fn = model_fn.to(device=self.device, dtype=dtype)
    self.mll_fn = ExactMarginalLogLikelihood(self.model_fn.likelihood, self.model_fn)
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=OptimizationWarning)
        self.mll_fn = fit_gpytorch_model(self.mll_fn, options={"maxiter": 5}, max_retries=1)
def _setUp(self, double=False, cuda=False, expand=False):
    """Fit a ``SingleTaskGP`` on a noisy sine fixture and store it on ``self``.

    Args:
        double: build tensors in ``torch.double`` instead of ``torch.float``.
        cuda: place tensors on ``cuda`` instead of ``cpu``.
        expand: duplicate the single input column so ``train_x`` has two
            columns; ``initial_conditions`` then gets a matching second entry.

    Sets ``train_x``, ``train_y``, ``initial_conditions``, ``f_best``,
    ``model`` and ``mll`` on ``self``. The GP is fitted for a single optimizer
    iteration with ``OptimizationWarning`` suppressed.
    """
    tkw = {
        "device": torch.device("cuda" if cuda else "cpu"),
        "dtype": torch.double if double else torch.float,
    }

    grid = torch.linspace(0, 1, 10, **tkw).unsqueeze(-1)
    # Targets are computed from the single-column grid, before any expansion.
    targets = torch.sin(grid * (2 * math.pi)) + torch.tensor(NOISE, **tkw)

    if expand:
        grid = grid.expand(-1, 2)
        start_points = [[0.5, 1.0]]
    else:
        start_points = [[0.5]]

    self.train_x = grid
    self.train_y = targets
    self.initial_conditions = torch.tensor(start_points, **tkw)
    self.f_best = targets.max().item()

    self.model = SingleTaskGP(grid, targets).to(**tkw)
    self.mll = ExactMarginalLogLikelihood(self.model.likelihood, self.model)
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=OptimizationWarning)
        self.mll = fit_gpytorch_model(self.mll, options={"maxiter": 1}, max_retries=1)
def _setUp(self, double=False, cuda=False):
    """Build a toy regression fixture with 1-d targets and fit two small GPs.

    Creates 10 evenly spaced inputs on [0, 1] (shape ``(10, 1)``) and targets
    ``sin(2 * pi * x) + NOISE`` squeezed to shape ``(10,)``. It then fits a
    ``SingleTaskGP`` and a ``FixedNoiseGP`` (constant observation variance
    ``0.1 ** 2``), each for at most 5 optimizer iterations.

    Args:
        double: build tensors in ``torch.double`` instead of ``torch.float``.
        cuda: place tensors on ``cuda`` instead of ``cpu``.

    Sets on ``self``: ``train_x``, ``train_y``, ``train_yvar``, ``bounds``,
    ``model_st``, ``mll_st``, ``model_fn`` and ``mll_fn``.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    dtype = torch.double if double else torch.float
    train_x = torch.linspace(0, 1, 10, device=device, dtype=dtype).unsqueeze(-1)
    train_y = torch.sin(train_x * (2 * math.pi)).squeeze(-1)
    # The variance must share the data's dtype. Without an explicit dtype it
    # was always float32 and mismatched the inputs when ``double=True``.
    train_yvar = torch.tensor(0.1 ** 2, device=device, dtype=dtype)
    noise = torch.tensor(NOISE, device=device, dtype=dtype)
    self.train_x = train_x
    self.train_y = train_y + noise
    self.train_yvar = train_yvar
    self.bounds = torch.tensor([[0.0], [1.0]], device=device, dtype=dtype)

    model_st = SingleTaskGP(self.train_x, self.train_y)
    self.model_st = model_st.to(device=device, dtype=dtype)
    self.mll_st = ExactMarginalLogLikelihood(
        self.model_st.likelihood, self.model_st
    )
    self.mll_st = fit_gpytorch_model(self.mll_st, options={"maxiter": 5})

    model_fn = FixedNoiseGP(
        self.train_x, self.train_y, self.train_yvar.expand_as(self.train_y)
    )
    self.model_fn = model_fn.to(device=device, dtype=dtype)
    self.mll_fn = ExactMarginalLogLikelihood(
        self.model_fn.likelihood, self.model_fn
    )
    self.mll_fn = fit_gpytorch_model(self.mll_fn, options={"maxiter": 5})
def initialize_model(x0, y0, n=5):
    """Sample ``n`` random latent points, score them and build an unfitted GP.

    Points are drawn uniformly in ``[-1, 1]^latent_dim``. Unless
    ``args.inf_norm`` is set, they are passed through ``latent_proj(.,
    args.eps)`` (presumably a projection onto a norm ball — confirm). They are
    then scored with ``obj_func(., x0, y0)``. When ``args.standardize`` is set,
    the scores are standardized before the model is built.

    Relies on the module-level names ``latent_dim``, ``device``, ``args``,
    ``latent_proj`` and ``obj_func``.

    Returns:
        ``(train_x, train_obj, mll, model, best_observed_value, mean, std)``.
        ``mean`` and ``std`` are the statistics of the raw (unstandardized)
        scores. ``best_observed_value`` is the max of the (possibly
        standardized) scores. ``mll`` is the ``ExactMarginalLogLikelihood`` of
        ``model``; neither is fitted here.
    """
    samples = 2 * torch.rand(n, latent_dim, device=device).float() - 1
    if not args.inf_norm:
        samples = latent_proj(samples, args.eps)

    scores = obj_func(samples, x0, y0)
    raw_mean = scores.mean()
    raw_std = scores.std()
    if args.standardize:
        # Reuse the statistics computed above instead of recomputing them.
        scores = (scores - raw_mean) / raw_std
    best_value = scores.max().item()

    # A single GP for the objective; targets get a trailing output dim.
    gp = SingleTaskGP(train_X=samples, train_Y=scores[:, None])
    gp = gp.to(samples)
    marginal_ll = ExactMarginalLogLikelihood(gp.likelihood, gp).to(samples)
    return samples, scores, marginal_ll, gp, best_value, raw_mean, raw_std
def _setUp(self, double=False, cuda=False, expand=False):
    """Fit a ``SingleTaskGP`` on a noisy sine fixture with 1-d targets.

    Args:
        double: build tensors in ``torch.double`` instead of ``torch.float``.
        cuda: place tensors on ``cuda`` instead of ``cpu``.
        expand: duplicate the single input column so ``train_x`` has two
            columns; ``initial_conditions`` then gets a matching second entry.

    Sets ``train_x``, ``train_y`` (shape ``(10,)``), ``initial_conditions``,
    ``f_best``, ``model`` and ``mll`` on ``self``. The GP is fitted for a
    single optimizer iteration.
    """
    dev = torch.device("cuda" if cuda else "cpu")
    precision = torch.double if double else torch.float

    grid = torch.linspace(0, 1, 10, device=dev, dtype=precision).unsqueeze(-1)
    # Drop the trailing singleton dim so the targets are 1-d.
    clean = torch.sin(grid * (2 * math.pi)).squeeze(-1)
    self.train_y = clean + torch.tensor(NOISE, device=dev, dtype=precision)

    if expand:
        self.train_x = grid.expand(-1, 2)
        start_points = [[0.5, 1.0]]
    else:
        self.train_x = grid
        start_points = [[0.5]]
    self.initial_conditions = torch.tensor(start_points, device=dev, dtype=precision)
    self.f_best = self.train_y.max().item()

    self.model = SingleTaskGP(self.train_x, self.train_y).to(device=dev, dtype=precision)
    self.mll = ExactMarginalLogLikelihood(self.model.likelihood, self.model)
    self.mll = fit_gpytorch_model(self.mll, options={"maxiter": 1})
def _warp_outputs(values: Tensor) -> Tensor:
    """Re-apply the power-transform output warping used when fitting a GP.

    The values are divided by their standard deviation and then transformed:
    with Yeo-Johnson if any value is <= 0, otherwise with Box-Cox. If the
    Box-Cox result still has std < 0.5, it is Yeo-Johnson transformed once
    more.

    Args:
        values: targets of any shape. They are flattened to a single column
            because ``power_transform`` only accepts 2-D arrays.

    Returns:
        A CPU float32 tensor of shape ``(values.numel(), 1)``.
    """
    values_np = values.detach().cpu().numpy().reshape(-1, 1)
    if values_np.min() <= 0:
        return torch.FloatTensor(
            power_transform(values_np / values_np.std(), method='yeo-johnson'))
    warped = torch.FloatTensor(
        power_transform(values_np / values_np.std(), method='box-cox'))
    if warped.std() < 0.5:
        warped_np = warped.numpy()
        warped = torch.FloatTensor(
            power_transform(warped_np / warped_np.std(), method='yeo-johnson'))
    return warped


def _save_hist_pair(true_vals: Tensor, pred_vals: Tensor, true_label: str,
                    pred_label: str, title: str, path: str) -> None:
    """Overlay density histograms of targets and GP predictions and save to ``path``."""
    plt.hist(true_vals.detach().cpu().numpy(), bins=100, label=true_label,
             alpha=0.5, density=True)
    plt.hist(pred_vals.detach().cpu().numpy(), bins=100, label=pred_label,
             alpha=0.5, density=True)
    plt.legend()
    plt.title(title)
    plt.savefig(path)
    plt.close()


def _save_hist(values: Tensor, title: str, path: str) -> None:
    """Save a density histogram of ``values`` (per-point squared errors) to ``path``."""
    plt.hist(values.detach().cpu().numpy(), bins=100, alpha=0.5, density=True)
    plt.title(title)
    plt.savefig(path)
    plt.close()


def _save_obj_vs_error_scatter(y_true: Tensor, err_true: Tensor,
                               y_pred: Tensor, err_pred: Tensor, split: str,
                               path: str) -> None:
    """Scatter objective vs. reconstruction error, for targets and for GP predictions."""
    y_true_sorted, idx_true = torch.sort(y_true)
    y_pred_sorted, idx_pred = torch.sort(y_pred)
    plt.scatter(y_true_sorted.detach().cpu().numpy(),
                err_true[idx_true].detach().cpu().numpy(),
                label='true', marker='+')
    plt.scatter(y_pred_sorted.detach().cpu().numpy(),
                err_pred[idx_pred].detach().cpu().numpy(),
                label='pred', marker='*')
    plt.xlabel(f'y {split} targets')
    plt.ylabel(f'recon. error {split} targets')
    plt.title(f'y_{split} vs. error_{split}')
    plt.legend()
    plt.savefig(path)
    plt.close()


def _save_uncertainty_plot(true_vals: Tensor, pred_mean: Tensor,
                           pred_std: Tensor, true_label: str, pred_label: str,
                           title: str, path: str) -> None:
    """Plot targets sorted ascending next to the GP mean +/- one std at the same points."""
    true_sorted, order = torch.sort(true_vals)
    positions = np.arange(len(order))
    mean_sorted = pred_mean[order].detach().cpu().numpy().flatten()
    std_sorted = pred_std[order].detach().cpu().numpy().flatten()
    plt.scatter(positions, true_sorted.detach().cpu().numpy(), label=true_label,
                marker='+', color='C1', s=15)
    # fmt='*' draws markers only; marker='*' alone would also join the points.
    plt.errorbar(positions, mean_sorted, yerr=std_sorted, fmt='*', alpha=0.05,
                 label=pred_label, color='C0', ecolor='C0')
    plt.scatter(positions, mean_sorted, marker='*', alpha=0.2, s=10, color='C0')
    plt.legend()
    plt.title(title)
    plt.savefig(path)
    plt.close()


def gp_fit_test(x_train: Tensor, y_train: Tensor, error_train: Tensor,
                x_test: Tensor, y_test: Tensor, error_test: Tensor,
                gp_obj_model: SingleTaskGP, gp_error_model: SingleTaskGP,
                tkwargs: Dict[str, Any], gp_test_folder: str,
                obj_out_wp: bool = False, err_out_wp: bool = False) -> None:
    """
    Evaluate how well the objective GP (and optional error GP) fit their data.

    1) Estimates the mean squared error between the objective GP's posterior
       mean and the true objective values, on the train and the test set.
    2) If ``gp_error_model`` is not None, estimates the mean squared error
       between the error GP's posterior mean and the true reconstruction
       errors of the vae_model.
    Per-point squared errors (already divided by the number of points), the
    targets and the inputs are written with ``torch.save`` into
    ``gp_test_folder``. The summed MSEs are printed, and diagnostic histograms
    and scatter plots are saved there as PDFs.

    :param x_train: normalised points at which the gps were trained
    :param y_train: objective function values at x_train that were used as
        targets of `gp_obj_model`
    :param error_train: reconstruction error values at x_train that were used
        as targets of `gp_error_model`
    :param x_test: normalised test points
    :param y_test: objective function values at x_test
    :param error_test: reconstruction error at the test points
    :param gp_obj_model: the gp model trained to predict the black box
        objective function values
    :param gp_error_model: the gp model trained to predict reconstruction
        error, or None to skip every error-model evaluation
    :param tkwargs: dict with the `device` (and dtype) used for evaluation
    :param gp_test_folder: folder in which to save test results (created,
        together with missing parents, if needed)
    :param obj_out_wp: if the `gp_obj_model` was trained with output warping,
        the same transform is applied to the y targets first
    :param err_out_wp: if the `gp_error_model` was trained with output warping,
        the same transform is applied to the error targets first
    :return: None. The MSEs (Sum_i||true_i - pred_i||^2 / n_points) are printed
        and saved rather than returned.
    """
    do_robust = gp_error_model is not None
    os.makedirs(gp_test_folder, exist_ok=True)

    gp_obj_model.eval()
    gp_obj_model.to(tkwargs['device'])
    if do_robust:
        gp_error_model.eval()
        gp_error_model.to(tkwargs['device'])

    with torch.no_grad():
        # ---- objective GP metrics ----
        if obj_out_wp:
            y_train = _warp_outputs(y_train)
            y_test = _warp_outputs(y_test)
        # Flatten so the targets line up element-wise with the flattened means.
        y_train = y_train.reshape(-1).to(**tkwargs)
        y_test = y_test.reshape(-1).to(**tkwargs)

        # Query each posterior once; it is reused by the metrics and by every plot.
        obj_post_train = gp_obj_model.posterior(x_train)
        obj_post_test = gp_obj_model.posterior(x_test)
        obj_mean_train = obj_post_train.mean.view(-1)
        obj_mean_test = obj_post_test.mean.view(-1)
        obj_std_train = obj_post_train.variance.view(-1).sqrt()
        obj_std_test = obj_post_test.variance.view(-1).sqrt()

        gp_obj_val_model_mse_train = (obj_mean_train - y_train).pow(2).div(len(y_train))
        gp_obj_val_model_mse_test = (obj_mean_test - y_test).pow(2).div(len(y_test))
        torch.save(gp_obj_val_model_mse_train,
                   os.path.join(gp_test_folder, 'gp_obj_val_model_mse_train.npz'))
        torch.save(gp_obj_val_model_mse_test,
                   os.path.join(gp_test_folder, 'gp_obj_val_model_test.npz'))
        print(
            f'GP training fit on objective value: MSE={gp_obj_val_model_mse_train.sum().item():.5f}'
        )
        print(
            f'GP testing fit on objective value: MSE={gp_obj_val_model_mse_test.sum().item():.5f}'
        )

        # ---- reconstruction-error GP metrics ----
        if do_robust:
            if err_out_wp:
                error_train = _warp_outputs(error_train)
                error_test = _warp_outputs(error_test)
            error_train = error_train.reshape(-1).to(**tkwargs)
            error_test = error_test.reshape(-1).to(**tkwargs)

            err_post_train = gp_error_model.posterior(x_train)
            err_post_test = gp_error_model.posterior(x_test)
            err_mean_train = err_post_train.mean.view(-1)
            err_mean_test = err_post_test.mean.view(-1)
            err_std_train = err_post_train.variance.view(-1).sqrt()
            err_std_test = err_post_test.variance.view(-1).sqrt()

            gp_error_model_mse_train = (error_train - err_mean_train).pow(2).div(
                len(error_train))
            gp_error_model_mse_test = (error_test - err_mean_test).pow(2).div(
                len(error_test))
            torch.save(gp_error_model_mse_train,
                       os.path.join(gp_test_folder, 'gp_error_model_mse_train.npz'))
            torch.save(gp_error_model_mse_test,
                       os.path.join(gp_test_folder, 'gp_error_model_mse_test.npz'))
            print(
                f'GP training fit on reconstruction errors: MSE={gp_error_model_mse_train.sum().item():.5f}'
            )
            print(
                f'GP testing fit on reconstruction errors: MSE={gp_error_model_mse_test.sum().item():.5f}'
            )

        torch.save(error_test, os.path.join(gp_test_folder, "true_rec_err_z.pt"))
        torch.save(error_train, os.path.join(gp_test_folder, "error_train.pt"))
        torch.save(x_train, os.path.join(gp_test_folder, "train_x.pt"))
        torch.save(x_test, os.path.join(gp_test_folder, "test_x.pt"))
        torch.save(y_train, os.path.join(gp_test_folder, "y_train.pt"))
        torch.save(x_test, os.path.join(gp_test_folder, "X_test.pt"))
        torch.save(y_test, os.path.join(gp_test_folder, "y_test.pt"))

    # ---- objective GP plots ----
    _save_hist_pair(y_train, obj_mean_train, 'y train', 'y pred', 'Training set',
                    os.path.join(gp_test_folder, 'gp_obj_train.pdf'))
    _save_hist(gp_obj_val_model_mse_train,
               'MSE of gp_obj_val model on training set',
               os.path.join(gp_test_folder, 'gp_obj_train_mse.pdf'))
    _save_hist_pair(y_test, obj_mean_test, 'y true', 'y pred', 'Validation set',
                    os.path.join(gp_test_folder, 'gp_obj_test.pdf'))
    _save_hist(gp_obj_val_model_mse_test,
               'MSE of gp_obj_val model on validation set',
               os.path.join(gp_test_folder, 'gp_obj_test_mse.pdf'))

    if do_robust:
        # ---- error GP plots ----
        _save_hist_pair(error_train, err_mean_train, 'error train', 'error pred',
                        'Training set',
                        os.path.join(gp_test_folder, 'gp_error_train.pdf'))
        _save_hist(gp_error_model_mse_train,
                   'MSE of gp_error model on training set',
                   os.path.join(gp_test_folder, 'gp_error_train_mse.pdf'))
        _save_hist_pair(error_test, err_mean_test, 'error true', 'error pred',
                        'Validation set',
                        os.path.join(gp_test_folder, 'gp_error_test.pdf'))
        _save_hist(gp_error_model_mse_test,
                   'MSE of gp_error model on validation set',
                   os.path.join(gp_test_folder, 'gp_error_test_mse.pdf'))

        # ---- y vs. error scatter plots ----
        _save_obj_vs_error_scatter(
            y_train, error_train, obj_mean_train, err_mean_train, 'train',
            os.path.join(gp_test_folder, 'scatter_obj_error_train.pdf'))
        _save_obj_vs_error_scatter(
            y_test, error_test, obj_mean_test, err_mean_test, 'test',
            os.path.join(gp_test_folder, 'scatter_obj_error_test.pdf'))

        # ---- error uncertainty plots ----
        _save_uncertainty_plot(
            error_train, err_mean_train, err_std_train, 'err true', 'err pred',
            'error predictions and uncertainty on train set',
            os.path.join(gp_test_folder, 'gp_error_train_uncertainty.pdf'))
        _save_uncertainty_plot(
            error_test, err_mean_test, err_std_test, 'err true', 'err pred',
            'error predictions and uncertainty on test set',
            os.path.join(gp_test_folder, 'gp_error_test_uncertainty.pdf'))

    # ---- y uncertainty plots (only need the objective GP) ----
    _save_uncertainty_plot(
        y_train, obj_mean_train, obj_std_train, 'y true', 'y pred',
        'y predictions and uncertainty on train set',
        os.path.join(gp_test_folder, 'gp_obj_val_train_uncertainty.pdf'))
    _save_uncertainty_plot(
        y_test, obj_mean_test, obj_std_test, 'y true', 'y pred',
        'y predictions and uncertainty on test set',
        os.path.join(gp_test_folder, 'gp_obj_val_test_uncertainty.pdf'))
def gp_torch_train(train_x: Tensor, train_y: Tensor, n_inducing_points: int,
                   tkwargs: Dict[str, Any], init, scale: bool, covar_name: str,
                   gp_file: Optional[str], save_file: str, input_wp: bool,
                   outcome_transform: Optional[OutcomeTransform] = None,
                   options: Optional[Dict[str, Any]] = None) -> SingleTaskGP:
    """Build an inducing-point ``SingleTaskGP``, fit it, save it and return it.

    If ``init`` is truthy, hyperparameters are initialised from the data:
      * inducing points are the centres of a ``MiniBatchKMeans`` fit on
        ``train_x``;
      * the output scale is ``train_y.var()`` (only when ``scale``);
      * each input dimension's lengthscale is the median pairwise distance
        along that dimension, clamped to at least 0.01.
    Otherwise the model is built with placeholder values and its state dict
    is loaded from ``gp_file``.

    In both cases the model is then optimised with ``fit_gpytorch_torch`` using
    ``options``. Its ``state_dict`` is saved to ``save_file``, and the model
    is returned in eval mode.

    Args:
        train_x: training inputs, indexed as ``(n_points, n_dims)``.
        train_y: training targets; must have ``ndim > 1``.
        n_inducing_points: number of inducing points (k-means clusters).
        tkwargs: keyword arguments for ``model.to`` (device / dtype).
        init: initialise hyperparameters from data instead of loading.
        scale: passed to ``query_covar``; also enables the output-scale init.
        covar_name: name of the base kernel, resolved by ``query_covar``.
        gp_file: state-dict file to load from when ``init`` is falsy.
        save_file: where the fitted ``state_dict`` is written.
        input_wp: if True, add a ``CustomWarp`` input transform over all
            dimensions and enable value clipping during fitting.
        outcome_transform: optional outcome transform for the model.
        options: optimiser options forwarded to ``fit_gpytorch_torch``.

    Raises:
        AssertionError: if ``train_y.ndim <= 1``, if neither ``gp_file`` nor
            ``init`` is given, or if ``init`` is set with
            ``n_inducing_points <= 0``.
    """
    assert train_y.ndim > 1, train_y.shape
    assert gp_file or init, (gp_file, init)
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    n_dims = train_x.shape[1]

    if init:
        print("Initialize GP hparams...")
        print("Doing Kmeans init...")
        assert n_inducing_points > 0, n_inducing_points
        kmeans = MiniBatchKMeans(n_clusters=n_inducing_points,
                                 batch_size=min(10000, train_x.shape[0]),
                                 n_init=25)
        start_time = time.time()
        kmeans.fit(train_x.cpu().numpy())
        end_time = time.time()
        print(f"K means took {end_time - start_time:.1f}s to finish...")
        inducing_points = torch.from_numpy(kmeans.cluster_centers_.copy())

        output_scale = train_y.var().item() if scale else None
        lscales = torch.empty(1, n_dims)
        for i in range(n_dims):
            lscales[0, i] = torch.pdist(train_x[:, i].view(
                -1, 1)).median().clamp(min=0.01)
    else:
        # Placeholders only: load_state_dict below overwrites all of them.
        output_scale = 1
        lscales = torch.ones(n_dims)
        inducing_points = torch.empty(n_inducing_points, n_dims)

    base_covar_module = query_covar(covar_name=covar_name,
                                    scale=scale,
                                    outputscale=output_scale,
                                    lscales=lscales)
    covar_module = InducingPointKernel(base_covar_module,
                                       inducing_points=inducing_points,
                                       likelihood=likelihood)

    input_warp_tf = None
    if input_wp:
        # Use the same warp class on both paths, so a model saved by the init
        # path is reloaded with the same input transform (the load path
        # previously built a plain ``Warp``).
        input_warp_tf = CustomWarp(
            indices=list(range(train_x.shape[-1])),
            # use a prior with median at 1.
            # when a=1 and b=1, the Kumaraswamy CDF is the identity function
            concentration1_prior=LogNormalPrior(0.0, 0.75**0.5),
            concentration0_prior=LogNormalPrior(0.0, 0.75**0.5),
        )

    model = SingleTaskGP(train_x,
                         train_y,
                         covar_module=covar_module,
                         likelihood=likelihood,
                         input_transform=input_warp_tf,
                         outcome_transform=outcome_transform)

    if not init:
        print("Loading GP from file")
        # map_location='cpu' lets a state dict saved from a GPU run load on a
        # CPU-only host; the model is moved to ``tkwargs`` below anyway.
        state_dict = torch.load(gp_file, map_location='cpu')
        model.load_state_dict(state_dict)

    print("GP regression")
    start_time = time.time()
    model.to(**tkwargs)
    model.train()

    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    # Exact marginal log likelihood, so approx_mll is False.
    fit_gpytorch_torch(mll,
                       options=options,
                       approx_mll=False,
                       clip_by_value=bool(input_wp),
                       clip_value=10.0)
    end_time = time.time()
    print(f"Regression took {end_time - start_time:.1f}s to finish...")

    print("Save GP model...")
    torch.save(model.state_dict(), save_file)
    print("Done training of GP.")
    model.eval()
    return model