class GPRegressionMetaLearnedPAC(RegressionModelMetaLearned):

    def __init__(self, meta_train_data, num_iter_fit=40000, feature_dim=1, weight_prior_std=0.5,
                 bias_prior_std=3.0, delta=0.1, task_kl_weight=1.0, meta_kl_weight=1.0,
                 posterior_lr_multiplier=1.0, covar_module='SE', mean_module='zero',
                 mean_nn_layers=(32, 32), kernel_nn_layers=(32, 32), optimizer='Adam', lr=1e-3,
                 lr_decay=1.0, svi_batch_size=5, cov_type='diag', task_batch_size=4,
                 normalize_data=True, random_seed=None):
        """
        PACOH-PAC: Meta-learns a Gaussian variational hyper-posterior over GP priors by minimizing
        a PAC-Bayesian meta-learning bound.

        Args:
            meta_train_data: list of tuples of ndarrays [(train_x_1, train_t_1), ..., (train_x_n, train_t_n)]
            num_iter_fit: (int) number of gradient steps for fitting the parameters
            feature_dim: (int) output dimensionality of NN feature map for kernel function
            weight_prior_std: (float) std of Gaussian hyper-prior on weights
            bias_prior_std: (float) std of Gaussian hyper-prior on biases
            delta: (float) confidence parameter of the PAC-Bayesian bound
            task_kl_weight: (float) weight of the task-level KL term in the bound
            meta_kl_weight: (float) weight of the meta-level KL term in the bound
            posterior_lr_multiplier: (float) learning-rate multiplier for the task posterior parameters
            covar_module: (gpytorch.kernels.Kernel or str) kernel module or one of ['NN', 'SE']
            mean_module: (gpytorch.means.Mean or str) mean module or one of ['NN', 'constant', 'zero']
            mean_nn_layers: (tuple) hidden layer sizes of mean NN
            kernel_nn_layers: (tuple) hidden layer sizes of kernel NN
            optimizer: (str) type of optimizer to use - must be either 'Adam' or 'SGD'
            lr: (float) learning rate for prior parameters
            lr_decay: (float) learning-rate decay multiplier applied after every 1000 steps
            svi_batch_size: (int) number of hyper-posterior samples used to estimate gradients
            cov_type: (str) covariance structure of the variational hyper-posterior (e.g. 'diag')
            task_batch_size: (int) mini-batch size of tasks for estimating gradients
            normalize_data: (bool) whether the data should be normalized
            random_seed: (int) seed for pytorch
        """
        super().__init__(normalize_data, random_seed)

        assert mean_module in ['NN', 'constant', 'zero'] or isinstance(mean_module, gpytorch.means.Mean)
        assert covar_module in ['NN', 'SE'] or isinstance(covar_module, gpytorch.kernels.Kernel)
        assert optimizer in ['Adam', 'SGD']

        self.num_iter_fit, self.feature_dim = num_iter_fit, feature_dim
        self.task_kl_weight, self.meta_kl_weight = task_kl_weight, meta_kl_weight
        self.weight_prior_std, self.bias_prior_std = weight_prior_std, bias_prior_std
        self.svi_batch_size = svi_batch_size
        self.lr = lr
        self.n_tasks = len(meta_train_data)
        self.delta = torch.tensor(delta, dtype=torch.float32)

        if task_batch_size < 1:
            self.task_batch_size = len(meta_train_data)
        else:
            self.task_batch_size = min(task_batch_size, len(meta_train_data))

        # Check that data all has the same size
        self._check_meta_data_shapes(meta_train_data)
        self._compute_normalization_stats(meta_train_data)

        """ --- Setup model & inference --- """
        self.meta_train_params = []

        self._setup_meta_train_step(mean_module, covar_module, mean_nn_layers, kernel_nn_layers, cov_type)
        self.meta_train_params.append({'params': self.hyper_posterior.parameters(), 'lr': lr})

        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()
        self.meta_train_params.append({'params': self.likelihood.parameters(), 'lr': lr})

        # Setup components that are different across tasks
        self.task_dicts, posterior_params = self._setup_task_dicts(meta_train_data)
        self.meta_train_params.append({'params': posterior_params, 'lr': posterior_lr_multiplier * lr})

        self._setup_optimizer(optimizer, lr, lr_decay)

        self.fitted = False
    def meta_fit(self, valid_tuples=None, verbose=True, log_period=500, eval_period=5000, n_iter=None):
        """
        fits the variational hyper-posterior and the task posteriors by minimizing the
        PAC-Bayesian meta-learning bound

        Args:
            valid_tuples: list of valid tuples, i.e. [(test_context_x_1, test_context_t_1, test_x_1, test_t_1), ...]
            verbose: (boolean) whether to print training progress
            log_period: (int) number of steps after which to print stats
            eval_period: (int) number of steps after which to evaluate on the validation data
                         (must be a multiple of log_period)
            n_iter: (int) number of gradient descent iterations
        """
        assert eval_period % log_period == 0, "eval_period should be multiple of log_period"
        assert (valid_tuples is None) or (all([len(valid_tuple) == 4 for valid_tuple in valid_tuples]))

        t = time.time()

        if n_iter is None:
            n_iter = self.num_iter_fit

        for itr in range(1, n_iter + 1):
            task_dict_batch = self.rds_numpy.choice(self.task_dicts, size=self.task_batch_size)

            self.optimizer.zero_grad()
            loss, diagnostics_dict = self._meta_train_pac_bound(task_dict_batch)
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            # print training stats
            if verbose and (itr == 1 or itr % log_period == 0):
                duration = time.time() - t
                t = time.time()
                message = 'Iter %d/%d - Loss: %.6f - Time %.2f sec - ' % (itr, self.num_iter_fit,
                                                                          loss.item(), duration)

                # if validation data is provided -> compute the valid log-likelihood
                if valid_tuples is not None and itr % eval_period == 0 and itr > 0:
                    valid_ll, valid_rmse, calibr_err = self.eval_datasets(valid_tuples)
                    message += ' - Valid-LL: %.3f - Valid-RMSE: %.3f - Calib-Err %.3f' % (valid_ll, valid_rmse,
                                                                                          calibr_err)

                # add diagnostics
                message += ' - '.join(['%s: %.4f' % (key, value) for key, value in diagnostics_dict.items()])
                self.logger.info(message)

        self.fitted = True
        return loss.item(), diagnostics_dict

    def predict(self, context_x, context_y, test_x, n_iter_meta_test=3000, return_density=False):
        """
        computes the predictive distribution of the targets p(t|test_x, test_context_x, context_y)

        Args:
            context_x: (ndarray) context input data for which to compute the posterior
            context_y: (ndarray) context targets for which to compute the posterior
            test_x: (ndarray) query input data of shape (n_samples, ndim_x)
            n_iter_meta_test: (int) number of gradient steps for fitting the task posterior at meta-test time
            return_density: (bool) whether to return result as mean and std ndarray or as MultivariateNormal pytorch object

        Returns:
            (pred_mean, pred_std) predicted mean and standard deviation corresponding to p(t|test_x, test_context_x, context_y)
        """
        context_x, context_y = _handle_input_dimensionality(context_x, context_y)
        test_x = _handle_input_dimensionality(test_x)
        assert test_x.shape[1] == context_x.shape[1]

        # meta-test training / inference
        task_dict = self._meta_test_inference([(context_x, context_y)], verbose=True, log_period=500,
                                              n_iter=n_iter_meta_test)[0]

        with torch.no_grad():
            # meta-test evaluation
            test_x = self._normalize_data(X=test_x, Y=None)
            test_x = torch.from_numpy(test_x).float().to(device)

            gp_model = task_dict["gp_model"]
            gp_model.eval()
            pred_dist = self.likelihood(gp_model(test_x))
            pred_dist = AffineTransformedDistribution(pred_dist, normalization_mean=self.y_mean,
                                                      normalization_std=self.y_std)

            if return_density:
                return pred_dist
            else:
                pred_mean = pred_dist.mean.cpu().numpy()
                pred_std = pred_dist.stddev.cpu().numpy()
                return pred_mean, pred_std
    def eval_datasets(self, test_tuples, n_iter_meta_test=3000, **kwargs):
        """
        Performs meta-testing on multiple tasks / datasets.
        Computes the average test log-likelihood, the RMSE and the calibration error over multiple test datasets

        Args:
            test_tuples: list of test set tuples, i.e. [(test_context_x_1, test_context_y_1, test_x_1, test_y_1), ...]

        Returns: (avg_log_likelihood, rmse, calibr_error)
        """
        assert (all([len(valid_tuple) == 4 for valid_tuple in test_tuples]))

        # meta-test training / inference
        context_tuples = [test_tuple[:2] for test_tuple in test_tuples]
        task_dicts = self._meta_test_inference(context_tuples, verbose=True, log_period=500,
                                               n_iter=n_iter_meta_test)

        # meta-test evaluation
        ll_list, rmse_list, calibr_err_list = [], [], []
        for task_dict, test_tuple in zip(task_dicts, test_tuples):
            # data prep
            _, _, test_x, test_y = test_tuple
            test_x_tensor = torch.from_numpy(self._normalize_data(X=test_x, Y=None)).float().to(device)
            test_y_tensor = torch.from_numpy(test_y).float().flatten().to(device)

            # get predictive dist
            gp_model = task_dict["gp_model"]
            gp_model.eval()
            pred_dist = self.likelihood(gp_model(test_x_tensor))
            pred_dist = AffineTransformedDistribution(pred_dist, normalization_mean=self.y_mean,
                                                      normalization_std=self.y_std)

            # compute eval metrics
            ll_list.append(torch.mean(pred_dist.log_prob(test_y_tensor) / test_y_tensor.shape[0]).cpu().item())
            rmse_list.append(torch.mean(torch.pow(pred_dist.mean - test_y_tensor, 2)).sqrt().cpu().item())
            pred_dist_vect = self._vectorize_pred_dist(pred_dist)
            calibr_err_list.append(self._calib_error(pred_dist_vect, test_y_tensor).cpu().item())

        return np.mean(ll_list), np.mean(rmse_list), np.mean(calibr_err_list)

    def state_dict(self):
        state_dict = {
            'optimizer': self.optimizer.state_dict(),
            'model': self.task_dicts[0]['gp_model'].state_dict()
        }
        for task_dict in self.task_dicts:
            for key, tensor in task_dict['gp_model'].state_dict().items():
                assert torch.all(state_dict['model'][key] == tensor).item()
        return state_dict

    def load_state_dict(self, state_dict):
        for task_dict in self.task_dicts:
            task_dict['gp_model'].load_state_dict(state_dict['model'])
        self.optimizer.load_state_dict(state_dict['optimizer'])

    def _setup_task_dicts(self, train_data_tuples):
        task_dicts, parameters = [], []

        for train_x, train_y in train_data_tuples:
            task_dict = OrderedDict()

            # a) prepare data
            x_tensor, y_tensor = self._prepare_data_per_task(train_x, train_y)
            task_dict['train_x'], task_dict['train_y'] = x_tensor, y_tensor

            task_dict['gp_model'] = LearnedGPRegressionModelApproximate(x_tensor, y_tensor, self.likelihood,
                                                                        mean_module=gpytorch.means.ZeroMean(),
                                                                        covar_module=gpytorch.kernels.RBFKernel())
            parameters.extend(task_dict['gp_model'].variational_parameters())
            task_dicts.append(task_dict)

        return task_dicts, parameters

    def _meta_test_inference(self, context_tuples, n_iter=3000, lr=1e-2, log_period=100, verbose=False):
        n_tasks = len(context_tuples)
        task_dicts, posterior_params = self._setup_task_dicts(context_tuples)

        optimizer = torch.optim.Adam(posterior_params, lr=lr)

        t = time.time()
        for itr in range(n_iter):
            optimizer.zero_grad()
            param_sample = self.hyper_posterior.rsample(sample_shape=(self.svi_batch_size,))
            task_pac_bounds, diagnostics_dict = self._task_pac_bounds(task_dicts, param_sample,
                                                                      task_kl_weight=self.task_kl_weight,
                                                                      meta_kl_weight=self.meta_kl_weight)
            loss = torch.sum(torch.stack(task_pac_bounds))
            loss.backward()
            optimizer.step()

            if itr % log_period == 0 and verbose:
                duration = time.time() - t
                t = time.time()
                message = '\t Meta-Test Iter %d/%d - Loss: %.6f - Time %.2f sec - ' % (itr, n_iter,
                                                                                       loss.item() / n_tasks,
                                                                                       duration)

                # add diagnostics
                message += ' - '.join(['%s: %.4f' % (key, value) for key, value in diagnostics_dict.items()])
                self.logger.info(message)

        return task_dicts
    def _setup_meta_train_step(self, mean_module_str, covar_module_str, mean_nn_layers, kernel_nn_layers,
                               cov_type):
        assert mean_module_str in ['NN', 'constant']
        assert covar_module_str in ['NN', 'SE']

        """ random gp model """
        self.random_gp = RandomGPMeta(size_in=self.input_dim, prior_factor=1.0,
                                      weight_prior_std=self.weight_prior_std,
                                      bias_prior_std=self.bias_prior_std,
                                      covar_module_str=covar_module_str,
                                      mean_module_str=mean_module_str,
                                      mean_nn_layers=mean_nn_layers,
                                      kernel_nn_layers=kernel_nn_layers)

        param_shapes_dict = self.random_gp.parameter_shapes()

        """ variational posterior """
        self.hyper_posterior = RandomGPPosterior(param_shapes_dict, cov_type=cov_type)

        def _tile_data_tuple(task_dict, tile_size):
            x_data, y_data = task_dict['train_x'], task_dict['train_y']
            x_data = x_data.view(torch.Size((1,)) + x_data.shape).repeat(tile_size, 1, 1)
            y_data = y_data.view(torch.Size((1,)) + y_data.shape).repeat(tile_size, 1)
            return x_data, y_data

        def _hyper_kl(prior_param_sample):
            return torch.mean(self.hyper_posterior.log_prob(prior_param_sample)
                              - self.random_gp.hyper_prior.log_prob(prior_param_sample))

        def _task_pac_bounds(task_dicts, prior_param_sample, task_kl_weight=1.0, meta_kl_weight=1.0):
            fn = self.random_gp.get_forward_fn(prior_param_sample)
            kl_outer = meta_kl_weight * _hyper_kl(prior_param_sample)

            task_pac_bounds = []
            for task_dict in task_dicts:
                posterior = task_dict["gp_model"](task_dict["train_x"])

                # likelihood
                avg_ll = torch.mean(self.likelihood.expected_log_prob(task_dict["train_y"], posterior))

                # task complexity
                x_data_tiled, y_data_tiled = _tile_data_tuple(task_dict, self.svi_batch_size)
                gp, _ = fn(x_data_tiled, None, prior=True)
                prior = gp(x_data_tiled)
                kl_inner = task_kl_weight * torch.mean(
                    _kl_divergence_safe(posterior.expand((self.svi_batch_size,)), prior))

                m = torch.tensor(task_dict["train_y"].shape[0], dtype=torch.float32)
                n = torch.tensor(self.n_tasks, dtype=torch.float32)
                task_complexity = torch.sqrt((kl_outer + kl_inner + math.log(2.) + torch.log(m) + torch.log(n)
                                              - torch.log(self.delta)) / (2 * (m - 1)))

                diagnostics_dict = {
                    'avg_ll': avg_ll.item(),
                    'kl_outer_weighted': kl_outer.item(),
                    'kl_inner_weighted': kl_inner.item(),
                }

                task_pac_bound = -avg_ll + task_complexity
                task_pac_bounds.append(task_pac_bound)

            return task_pac_bounds, diagnostics_dict
        def _meta_complexity(prior_param_sample, meta_kl_weight=1.0):
            outer_kl = _hyper_kl(prior_param_sample)
            n = torch.tensor(self.n_tasks, dtype=torch.float32)
            # the full numerator is divided by 2 * (n - 1), mirroring the task-level complexity term
            return torch.sqrt((meta_kl_weight * outer_kl + math.log(2.) + torch.log(n)
                               - torch.log(self.delta)) / (2 * (n - 1)))

        def _meta_train_pac_bound(task_dicts):
            param_sample = self.hyper_posterior.rsample(sample_shape=(self.svi_batch_size,))
            task_pac_bounds, diagnostics_dict = _task_pac_bounds(task_dicts, param_sample,
                                                                 task_kl_weight=self.task_kl_weight,
                                                                 meta_kl_weight=self.meta_kl_weight)
            meta_complexity = _meta_complexity(param_sample, meta_kl_weight=self.meta_kl_weight)
            pac_bound = torch.mean(torch.stack(task_pac_bounds)) + meta_complexity
            return pac_bound, diagnostics_dict

        self._task_pac_bounds = _task_pac_bounds
        self._meta_train_pac_bound = _meta_train_pac_bound

    def _setup_optimizer(self, optimizer, lr, lr_decay):
        if optimizer == 'Adam':
            self.optimizer = torch.optim.Adam(self.meta_train_params, lr=lr)
        elif optimizer == 'SGD':
            self.optimizer = torch.optim.SGD(self.meta_train_params, lr=lr)
        else:
            raise NotImplementedError('Optimizer must be Adam or SGD')

        if lr_decay < 1.0:
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1000, gamma=lr_decay)
        else:
            self.lr_scheduler = DummyLRScheduler()

    def _vectorize_pred_dist(self, pred_dist):
        # converts a multivariate gaussian into a vectorized univariate gaussian
        return torch.distributions.Normal(pred_dist.mean, pred_dist.stddev)

    def prior_mean(self, x, n_hyperposterior_samples=1000):
        x = (x - self.x_mean) / self.x_std
        assert x.ndim == 1 or (x.ndim == 2 and x.shape[-1] == 1)
        x_data_tiled = np.tile(x.reshape(1, x.shape[0], 1), (n_hyperposterior_samples, 1, 1))
        x_data_tiled = torch.tensor(x_data_tiled, dtype=torch.float32)

        with torch.no_grad():
            param_sample = self.hyper_posterior.rsample(sample_shape=(n_hyperposterior_samples,))
            fn = self.random_gp.get_forward_fn(param_sample)
            gp, _ = fn(x_data_tiled, None, prior=True)
            prior = gp(x_data_tiled)
            mean = torch.mean(prior.mean, axis=0).numpy() * self.y_std + self.y_mean
            # torch.mean(gp.learned_mean(x_data_tiled), axis=0).numpy() * self.y_std + self.y_mean
        return mean
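# ------------------------------------------------------------------------------------------
# Usage sketch for GPRegressionMetaLearnedPAC (illustrative only, not part of the library API):
# the synthetic sine tasks, iteration counts and hyper-parameter values below are assumptions
# chosen so the example runs quickly, not recommended settings. It relies on this module's
# imports (np, torch, ...) and is never called from library code.
# ------------------------------------------------------------------------------------------
def _example_pac_usage(n_tasks=4, n_samples_per_task=20):
    rng = np.random.RandomState(0)

    # build a small meta-training set of noisy sine tasks
    meta_train_data = []
    for _ in range(n_tasks):
        x = rng.uniform(-3, 3, size=(n_samples_per_task, 1))
        y = np.sin(x[:, 0]) + 0.1 * rng.normal(size=n_samples_per_task)
        meta_train_data.append((x, y))

    # meta-train the hyper-posterior by minimizing the PAC-Bayesian bound
    # (mean_module='constant' because _setup_meta_train_step only accepts 'NN' or 'constant')
    model = GPRegressionMetaLearnedPAC(meta_train_data, num_iter_fit=200, svi_batch_size=3,
                                       task_batch_size=2, mean_module='constant', random_seed=25)
    model.meta_fit(log_period=100)

    # meta-test: condition on one task's data and predict on a dense grid
    context_x, context_y = meta_train_data[0]
    test_x = np.linspace(-3, 3, 50).reshape(-1, 1)
    pred_mean, pred_std = model.predict(context_x, context_y, test_x, n_iter_meta_test=200)
    return pred_mean, pred_std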
class GPRegressionMetaLearnedVI(RegressionModelMetaLearned):

    def __init__(self, meta_train_data, num_iter_fit=10000, feature_dim=1, prior_factor=0.01,
                 weight_prior_std=0.5, bias_prior_std=3.0, covar_module='NN', mean_module='NN',
                 mean_nn_layers=(32, 32), kernel_nn_layers=(32, 32), optimizer='Adam', lr=1e-3,
                 lr_decay=1.0, svi_batch_size=10, cov_type='diag', task_batch_size=-1,
                 normalize_data=True, random_seed=None):
        """
        PACOH-VI: Variational Inference on the PAC-optimal hyper-posterior with Gaussian family.
        Meta-learns a distribution over GP priors.

        Args:
            meta_train_data: list of tuples of ndarrays [(train_x_1, train_t_1), ..., (train_x_n, train_t_n)]
            num_iter_fit: (int) number of gradient steps for fitting the parameters
            feature_dim: (int) output dimensionality of NN feature map for kernel function
            prior_factor: (float) weighting of the hyper-prior (--> meta-regularization parameter)
            weight_prior_std: (float) std of Gaussian hyper-prior on weights
            bias_prior_std: (float) std of Gaussian hyper-prior on biases
            covar_module: (gpytorch.kernels.Kernel or str) kernel module or one of ['NN', 'SE']
            mean_module: (gpytorch.means.Mean or str) mean module or one of ['NN', 'constant', 'zero']
            mean_nn_layers: (tuple) hidden layer sizes of mean NN
            kernel_nn_layers: (tuple) hidden layer sizes of kernel NN
            optimizer: (str) type of optimizer to use - must be either 'Adam' or 'SGD'
            lr: (float) learning rate for prior parameters
            lr_decay: (float) learning-rate decay multiplier applied after every 1000 steps
            svi_batch_size: (int) number of hyper-posterior samples used to estimate the ELBO gradients
            cov_type: (str) covariance structure of the variational hyper-posterior (e.g. 'diag')
            task_batch_size: (int) mini-batch size of tasks for estimating gradients
            normalize_data: (bool) whether the data should be normalized
            random_seed: (int) seed for pytorch
        """
        super().__init__(normalize_data, random_seed)

        assert mean_module in ['NN', 'constant', 'zero'] or isinstance(mean_module, gpytorch.means.Mean)
        assert covar_module in ['NN', 'SE'] or isinstance(covar_module, gpytorch.kernels.Kernel)
        assert optimizer in ['Adam', 'SGD']

        self.num_iter_fit, self.prior_factor, self.feature_dim = num_iter_fit, prior_factor, feature_dim
        self.weight_prior_std, self.bias_prior_std = weight_prior_std, bias_prior_std
        self.svi_batch_size = svi_batch_size

        if task_batch_size < 1:
            self.task_batch_size = len(meta_train_data)
        else:
            self.task_batch_size = min(task_batch_size, len(meta_train_data))

        # Check that data all has the same size
        self._check_meta_data_shapes(meta_train_data)
        self._compute_normalization_stats(meta_train_data)

        """ --- Setup model & inference --- """
        self._setup_model_inference(mean_module, covar_module, mean_nn_layers, kernel_nn_layers, cov_type)

        self._setup_optimizer(optimizer, lr, lr_decay)

        # Setup components that are different across tasks
        self.task_dicts = []

        for train_x, train_y in meta_train_data:
            task_dict = {}

            # a) prepare data
            x_tensor, y_tensor = self._prepare_data_per_task(train_x, train_y)
            task_dict['train_x'], task_dict['train_y'] = x_tensor, y_tensor
            self.task_dicts.append(task_dict)

        self.fitted = False
    def meta_fit(self, valid_tuples=None, verbose=True, log_period=500, n_iter=None):
        """
        fits the variational hyper-posterior by minimizing the negative ELBO

        Args:
            valid_tuples: list of valid tuples, i.e. [(test_context_x_1, test_context_t_1, test_x_1, test_t_1), ...]
            verbose: (boolean) whether to print training progress
            log_period: (int) number of steps after which to print stats
            n_iter: (int) number of gradient descent iterations
        """
        assert (valid_tuples is None) or (all([len(valid_tuple) == 4 for valid_tuple in valid_tuples]))

        t = time.time()

        if n_iter is None:
            n_iter = self.num_iter_fit

        for itr in range(1, n_iter + 1):
            task_dict_batch = self.rds_numpy.choice(self.task_dicts, size=self.task_batch_size)

            self.optimizer.zero_grad()
            loss = self.get_neg_elbo(task_dict_batch)
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            # print training stats
            if itr == 1 or itr % log_period == 0:
                duration = time.time() - t
                t = time.time()
                message = 'Iter %d/%d - Loss: %.6f - Time %.2f sec' % (itr, self.num_iter_fit,
                                                                       loss.item(), duration)

                # if validation data is provided -> compute the valid log-likelihood
                if valid_tuples is not None:
                    valid_ll, valid_rmse, calibr_err = self.eval_datasets(valid_tuples)
                    message += ' - Valid-LL: %.3f - Valid-RMSE: %.3f - Calib-Err %.3f' % (valid_ll, valid_rmse,
                                                                                          calibr_err)

                if verbose:
                    self.logger.info(message)

        self.fitted = True
        return loss.item()

    def predict(self, context_x, context_y, test_x, n_posterior_samples=100, mode='Bayes', return_density=False):
        """
        computes the predictive distribution of the targets p(t|test_x, test_context_x, context_y)

        Args:
            context_x: (ndarray) context input data for which to compute the posterior
            context_y: (ndarray) context targets for which to compute the posterior
            test_x: (ndarray) query input data of shape (n_samples, ndim_x)
            n_posterior_samples: (int) number of samples from posterior to average over
            mode: (str) either of ['Bayes', 'MAP']
            return_density: (bool) whether to return result as mean and std ndarray or as MultivariateNormal pytorch object

        Returns:
            (pred_mean, pred_std) predicted mean and standard deviation corresponding to p(t|test_x, test_context_x, context_y)
        """
        assert mode in ['bayes', 'Bayes', 'MAP', 'map']
        context_x, context_y = _handle_input_dimensionality(context_x, context_y)
        test_x = _handle_input_dimensionality(test_x)
        assert test_x.shape[1] == context_x.shape[1]

        # normalize data and convert to tensor
        context_x, context_y = self._prepare_data_per_task(context_x, context_y)

        test_x = self._normalize_data(X=test_x, Y=None)
        test_x = torch.from_numpy(test_x).float().to(device)

        with torch.no_grad():
            if mode == 'Bayes' or mode == 'bayes':
                pred_dist = self.get_pred_dist(context_x, context_y, test_x, n_post_samples=n_posterior_samples)
                pred_dist = AffineTransformedDistribution(pred_dist, normalization_mean=self.y_mean,
                                                          normalization_std=self.y_std)
                pred_dist = EqualWeightedMixtureDist(pred_dist, batched=True)
            else:
                pred_dist = self.get_pred_dist_map(context_x, context_y, test_x)
                pred_dist = AffineTransformedDistribution(pred_dist, normalization_mean=self.y_mean,
                                                          normalization_std=self.y_std)
            if return_density:
                return pred_dist
            else:
                pred_mean = pred_dist.mean.cpu().numpy()
                pred_std = pred_dist.stddev.cpu().numpy()
                return pred_mean, pred_std

    def state_dict(self):
        state_dict = {
            'optimizer': self.optimizer.state_dict(),
            'model': self.task_dicts[0]['model'].state_dict()
        }
        for task_dict in self.task_dicts:
            for key, tensor in task_dict['model'].state_dict().items():
                assert torch.all(state_dict['model'][key] == tensor).item()
        return state_dict

    def load_state_dict(self, state_dict):
        for task_dict in self.task_dicts:
            task_dict['model'].load_state_dict(state_dict['model'])
        self.optimizer.load_state_dict(state_dict['optimizer'])
    def _setup_model_inference(self, mean_module_str, covar_module_str, mean_nn_layers, kernel_nn_layers,
                               cov_type):
        assert mean_module_str in ['NN', 'constant']
        assert covar_module_str in ['NN', 'SE']

        """ random gp model """
        self.random_gp = RandomGPMeta(size_in=self.input_dim, prior_factor=self.prior_factor,
                                      weight_prior_std=self.weight_prior_std,
                                      bias_prior_std=self.bias_prior_std,
                                      covar_module_str=covar_module_str,
                                      mean_module_str=mean_module_str,
                                      mean_nn_layers=mean_nn_layers,
                                      kernel_nn_layers=kernel_nn_layers)

        param_shapes_dict = self.random_gp.parameter_shapes()

        """ variational posterior """
        self.posterior = RandomGPPosterior(param_shapes_dict, cov_type=cov_type)

        def _tile_data_tuples(tasks_dicts, tile_size):
            train_data_tuples_tiled = []
            for task_dict in tasks_dicts:
                x_data, y_data = task_dict['train_x'], task_dict['train_y']
                x_data = x_data.view(torch.Size((1,)) + x_data.shape).repeat(tile_size, 1, 1)
                y_data = y_data.view(torch.Size((1,)) + y_data.shape).repeat(tile_size, 1)
                train_data_tuples_tiled.append((x_data, y_data))
            return train_data_tuples_tiled

        """ define negative ELBO """
        def get_neg_elbo(tasks_dicts):
            # tile data to svi_batch_shape
            data_tuples_tiled = _tile_data_tuples(tasks_dicts, self.svi_batch_size)

            param_sample = self.posterior.rsample(sample_shape=(self.svi_batch_size,))
            elbo = self.random_gp.log_prob(param_sample, data_tuples_tiled) \
                   - self.prior_factor * self.posterior.log_prob(param_sample)

            assert elbo.ndim == 1 and elbo.shape[0] == self.svi_batch_size
            return - torch.mean(elbo)

        self.get_neg_elbo = get_neg_elbo

        """ define predictive dist """
        def get_pred_dist(x_context, y_context, x_valid, n_post_samples=100):
            with torch.no_grad():
                x_context = x_context.view(torch.Size((1,)) + x_context.shape).repeat(n_post_samples, 1, 1)
                y_context = y_context.view(torch.Size((1,)) + y_context.shape).repeat(n_post_samples, 1)
                x_valid = x_valid.view(torch.Size((1,)) + x_valid.shape).repeat(n_post_samples, 1, 1)

                param_sample = self.posterior.sample(sample_shape=(n_post_samples,))
                gp_fn = self.random_gp.get_forward_fn(param_sample)
                gp, likelihood = gp_fn(x_context, y_context, train=False)
                pred_dist = likelihood(gp(x_valid))
            return pred_dist

        def get_pred_dist_map(x_context, y_context, x_valid):
            with torch.no_grad():
                x_context = x_context.view(torch.Size((1,)) + x_context.shape).repeat(1, 1, 1)
                y_context = y_context.view(torch.Size((1,)) + y_context.shape).repeat(1, 1)
                x_valid = x_valid.view(torch.Size((1,)) + x_valid.shape).repeat(1, 1, 1)

                param = self.posterior.mode
                param = param.view(torch.Size((1,)) + param.shape).repeat(1, 1)
                gp_fn = self.random_gp.get_forward_fn(param)
                gp, likelihood = gp_fn(x_context, y_context, train=False)
                pred_dist = likelihood(gp(x_valid))
            return MultivariateNormal(pred_dist.loc, pred_dist.covariance_matrix[0])

        self.get_pred_dist = get_pred_dist
        self.get_pred_dist_map = get_pred_dist_map

    def _setup_optimizer(self, optimizer, lr, lr_decay):
        if optimizer == 'Adam':
            self.optimizer = torch.optim.Adam(self.posterior.parameters(), lr=lr)
        elif optimizer == 'SGD':
            self.optimizer = torch.optim.SGD(self.posterior.parameters(), lr=lr)
        else:
            raise NotImplementedError('Optimizer must be Adam or SGD')

        if lr_decay < 1.0:
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1000, gamma=lr_decay)
        else:
            self.lr_scheduler = DummyLRScheduler()

    def _vectorize_pred_dist(self, pred_dist):
        multiv_normal_batched = pred_dist.dists
        normal_batched = torch.distributions.Normal(multiv_normal_batched.mean, multiv_normal_batched.stddev)
        return EqualWeightedMixtureDist(normal_batched, batched=True,
                                        num_dists=multiv_normal_batched.batch_shape[0])
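# ------------------------------------------------------------------------------------------
# Minimal runnable sketch of the VI variant (illustrative only): it assumes this module's
# imports resolve and uses tiny synthetic sine tasks and iteration counts so it finishes
# quickly; the hyper-parameter values are assumptions, not recommended settings.
# ------------------------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.RandomState(1)

    # synthetic meta-training tasks: noisy sines with task-specific phase shifts
    meta_train_data = []
    for _ in range(8):
        phase = rng.uniform(-1, 1)
        x = rng.uniform(-3, 3, size=(20, 1))
        y = np.sin(2 * x[:, 0] + phase) + 0.1 * rng.normal(size=20)
        meta_train_data.append((x, y))

    vi_model = GPRegressionMetaLearnedVI(meta_train_data, num_iter_fit=500, svi_batch_size=5,
                                         task_batch_size=4, random_seed=28)
    vi_model.meta_fit(log_period=100)

    # condition on one task's data and predict on a dense grid
    context_x, context_y = meta_train_data[0]
    test_x = np.linspace(-3, 3, 100).reshape(-1, 1)
    pred_mean, pred_std = vi_model.predict(context_x, context_y, test_x, n_posterior_samples=20)
    print('predictive mean/std shapes:', pred_mean.shape, pred_std.shape)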