Example #1
    def _setup_meta_train_step(self, mean_module_str, covar_module_str,
                               mean_nn_layers, kernel_nn_layers, cov_type):
        assert mean_module_str in ['NN', 'constant']
        assert covar_module_str in ['NN', 'SE']
        """ random gp model """
        self.random_gp = RandomGPMeta(size_in=self.input_dim,
                                      prior_factor=1.0,
                                      weight_prior_std=self.weight_prior_std,
                                      bias_prior_std=self.bias_prior_std,
                                      covar_module_str=covar_module_str,
                                      mean_module_str=mean_module_str,
                                      mean_nn_layers=mean_nn_layers,
                                      kernel_nn_layers=kernel_nn_layers)

        param_shapes_dict = self.random_gp.parameter_shapes()
        """ variational posterior """
        self.hyper_posterior = RandomGPPosterior(param_shapes_dict,
                                                 cov_type=cov_type)

        def _tile_data_tuple(task_dict, tile_size):
            x_data, y_data = task_dict['train_x'], task_dict['train_y']
            x_data = x_data.view(torch.Size((1, )) + x_data.shape).repeat(
                tile_size, 1, 1)
            y_data = y_data.view(torch.Size((1, )) + y_data.shape).repeat(
                tile_size, 1)
            return x_data, y_data

        def _hyper_kl(prior_param_sample):
            return torch.mean(
                self.hyper_posterior.log_prob(prior_param_sample) -
                self.random_gp.hyper_prior.log_prob(prior_param_sample))

        def _task_pac_bounds(task_dicts,
                             prior_param_sample,
                             task_kl_weight=1.0,
                             meta_kl_weight=1.0):

            fn = self.random_gp.get_forward_fn(prior_param_sample)

            kl_outer = meta_kl_weight * _hyper_kl(prior_param_sample)

            task_pac_bounds = []
            for task_dict in task_dicts:
                posterior = task_dict["gp_model"](task_dict["train_x"])

                # likelihood
                avg_ll = torch.mean(
                    self.likelihood.expected_log_prob(task_dict["train_y"],
                                                      posterior))

                # task complexity
                x_data_tiled, y_data_tiled = _tile_data_tuple(
                    task_dict, self.svi_batch_size)
                gp, _ = fn(x_data_tiled, None, prior=True)
                prior = gp(x_data_tiled)

                kl_inner = task_kl_weight * torch.mean(
                    _kl_divergence_safe(
                        posterior.expand((self.svi_batch_size, )), prior))

                m = torch.tensor(task_dict["train_y"].shape[0],
                                 dtype=torch.float32)
                n = torch.tensor(self.n_tasks, dtype=torch.float32)
                task_complexity = torch.sqrt(
                    (kl_outer + kl_inner + math.log(2.) + torch.log(m) +
                     torch.log(n) - torch.log(self.delta)) / (2 * (m - 1)))

                diagnostics_dict = {
                    'avg_ll': avg_ll.item(),
                    'kl_outer_weighted': kl_outer.item(),
                    'kl_inner_weighted': kl_inner.item(),
                }

                task_pac_bound = -avg_ll + task_complexity
                task_pac_bounds.append(task_pac_bound)
            return task_pac_bounds, diagnostics_dict

        def _meta_complexity(prior_param_sample, meta_kl_weight=1.0):
            outer_kl = _hyper_kl(prior_param_sample)
            n = torch.tensor(self.n_tasks, dtype=torch.float32)
            # meta-level complexity term of the PAC bound
            return torch.sqrt(
                (meta_kl_weight * outer_kl + math.log(2.) + torch.log(n) -
                 torch.log(self.delta)) / (2 * (n - 1)))

        def _meta_train_pac_bound(task_dicts):
            param_sample = self.hyper_posterior.rsample(
                sample_shape=(self.svi_batch_size, ))

            task_pac_bounds, diagnostics_dict = _task_pac_bounds(
                task_dicts,
                param_sample,
                task_kl_weight=self.task_kl_weight,
                meta_kl_weight=self.meta_kl_weight)
            meta_complexity = _meta_complexity(
                param_sample, meta_kl_weight=self.meta_kl_weight)

            pac_bound = torch.mean(
                torch.stack(task_pac_bounds)) + meta_complexity
            return pac_bound, diagnostics_dict

        self._task_pac_bounds = _task_pac_bounds
        self._meta_train_pac_bound = _meta_train_pac_bound
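
The objective assembled above combines, per task, the negative expected log-likelihood with a complexity term of the form sqrt((KL_outer + KL_inner + log(2mn/delta)) / (2(m - 1))). Below is a minimal, self-contained sketch of that per-task formula with plain tensors; the function name `pac_bound_sketch` and its arguments are illustrative only, not part of the codebase.

import math
import torch

def pac_bound_sketch(avg_ll, kl_outer, kl_inner, m, n, delta=0.1):
    """Illustrative per-task PAC bound: -avg_ll + complexity term.

    avg_ll:   expected log-likelihood under the task posterior
    kl_outer: (weighted) KL between hyper-posterior and hyper-prior
    kl_inner: (weighted) KL between task posterior and sampled GP prior
    m: number of samples in the task, n: number of meta-training tasks
    """
    m = torch.tensor(float(m))
    n = torch.tensor(float(n))
    delta = torch.tensor(float(delta))
    complexity = torch.sqrt(
        (kl_outer + kl_inner + math.log(2.) + torch.log(m) + torch.log(n)
         - torch.log(delta)) / (2 * (m - 1)))
    return -avg_ll + complexity

# e.g. pac_bound_sketch(torch.tensor(-1.2), torch.tensor(5.0), torch.tensor(2.0), m=20, n=10)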
Example #2
class GPRegressionMetaLearnedPAC(RegressionModelMetaLearned):
    def __init__(self,
                 meta_train_data,
                 num_iter_fit=40000,
                 feature_dim=1,
                 weight_prior_std=0.5,
                 bias_prior_std=3.0,
                 delta=0.1,
                 task_kl_weight=1.0,
                 meta_kl_weight=1.0,
                 posterior_lr_multiplier=1.0,
                 covar_module='SE',
                 mean_module='zero',
                 mean_nn_layers=(32, 32),
                 kernel_nn_layers=(32, 32),
                 optimizer='Adam',
                 lr=1e-3,
                 lr_decay=1.0,
                 svi_batch_size=5,
                 cov_type='diag',
                 task_batch_size=4,
                 normalize_data=True,
                 random_seed=None):
        """
        PACOH-VI: Variational Inference on the PAC-optimal hyper-posterior with Gaussian family.
        Meta-Learns a distribution over GP-priors.

        Args:
            meta_train_data: list of tuples of ndarrays[(train_x_1, train_t_1), ..., (train_x_n, train_t_n)]
            num_iter_fit: (int) number of gradient steps for fitting the parameters
            feature_dim: (int) output dimensionality of NN feature map for kernel function
            prior_factor: (float) weighting of the hyper-prior (--> meta-regularization parameter)
            weight_prior_std (float): std of Gaussian hyper-prior on weights
            bias_prior_std (float): std of Gaussian hyper-prior on biases
            covar_module: (gpytorch.mean.Kernel) optional kernel module, default: RBF kernel
            mean_module: (gpytorch.mean.Mean) optional mean module, default: ZeroMean
            mean_nn_layers: (tuple) hidden layer sizes of mean NN
            kernel_nn_layers: (tuple) hidden layer sizes of kernel NN
            optimizer: (str) type of optimizer to use - must be either 'Adam' or 'SGD'
            lr: (float) learning rate for prior parameters
            lr_decay: (float) lr rate decay multiplier applied after every 1000 steps
            kernel (std): SVGD kernel, either 'RBF' or 'IMQ'
            bandwidth (float): bandwidth of kernel, if None the bandwidth is chosen via heuristic
            num_particles: (int) number particles to approximate the hyper-posterior
            task_batch_size: (int) mini-batch size of tasks for estimating gradients
            normalize_data: (bool) whether the data should be normalized
            random_seed: (int) seed for pytorch
        """
        super().__init__(normalize_data, random_seed)

        assert mean_module in ['NN', 'constant', 'zero'] or isinstance(
            mean_module, gpytorch.means.Mean)
        assert covar_module in ['NN', 'SE'] or isinstance(
            covar_module, gpytorch.kernels.Kernel)
        assert optimizer in ['Adam', 'SGD']

        self.num_iter_fit, self.feature_dim = num_iter_fit, feature_dim
        self.task_kl_weight, self.meta_kl_weight = task_kl_weight, meta_kl_weight
        self.weight_prior_std, self.bias_prior_std = weight_prior_std, bias_prior_std
        self.svi_batch_size = svi_batch_size
        self.lr = lr
        self.n_tasks = len(meta_train_data)
        self.delta = torch.tensor(delta, dtype=torch.float32)
        if task_batch_size < 1:
            self.task_batch_size = len(meta_train_data)
        else:
            self.task_batch_size = min(task_batch_size, len(meta_train_data))

        # Check that data all has the same size
        self._check_meta_data_shapes(meta_train_data)
        self._compute_normalization_stats(meta_train_data)
        """ --- Setup model & inference --- """
        self.meta_train_params = []

        self._setup_meta_train_step(mean_module, covar_module, mean_nn_layers,
                                    kernel_nn_layers, cov_type)
        self.meta_train_params.append({
            'params':
            self.hyper_posterior.parameters(),
            'lr':
            lr
        })

        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()
        self.meta_train_params.append({
            'params': self.likelihood.parameters(),
            'lr': lr
        })

        # Setup components that are different across tasks
        self.task_dicts, posterior_params = self._setup_task_dicts(
            meta_train_data)
        self.meta_train_params.append({
            'params': posterior_params,
            'lr': posterior_lr_multiplier * lr
        })

        self._setup_optimizer(optimizer, lr, lr_decay)

        self.fitted = False

    def meta_fit(self,
                 valid_tuples=None,
                 verbose=True,
                 log_period=500,
                 eval_period=5000,
                 n_iter=None):
        """
        fits the variational hyper-posterior by minimizing the negative ELBO

        Args:
            valid_tuples: list of valid tuples, i.e. [(test_context_x_1, test_context_t_1, test_x_1, test_t_1), ...]
            verbose: (boolean) whether to print training progress
            log_period (int) number of steps after which to print stats
            n_iter: (int) number of gradient descent iterations
        """
        assert eval_period % log_period == 0, "eval_period should be multiple of log_period"
        assert (valid_tuples is None) or (all(
            [len(valid_tuple) == 4 for valid_tuple in valid_tuples]))

        t = time.time()

        if n_iter is None:
            n_iter = self.num_iter_fit

        for itr in range(1, n_iter + 1):

            task_dict_batch = self.rds_numpy.choice(self.task_dicts,
                                                    size=self.task_batch_size)
            self.optimizer.zero_grad()
            loss, diagnostics_dict = self._meta_train_pac_bound(
                task_dict_batch)
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            # print training stats
            if verbose and (itr == 1 or itr % log_period == 0):
                duration = time.time() - t
                t = time.time()

                message = 'Iter %d/%d - Loss: %.6f - Time %.2f sec - ' % (
                    itr, self.num_iter_fit, loss.item(), duration)

                # if validation data is provided  -> compute the valid log-likelihood
                if valid_tuples is not None and itr % eval_period == 0 and itr > 0:
                    valid_ll, valid_rmse, calibr_err = self.eval_datasets(
                        valid_tuples)
                    message += ' - Valid-LL: %.3f - Valid-RMSE: %.3f - Calib-Err %.3f' % (
                        valid_ll, valid_rmse, calibr_err)

                # add diagnostics
                message += ' - '.join([
                    '%s: %.4f' % (key, value)
                    for key, value in diagnostics_dict.items()
                ])
                self.logger.info(message)

        self.fitted = True
        return loss.item(), diagnostics_dict

    def predict(self,
                context_x,
                context_y,
                test_x,
                n_iter_meta_test=3000,
                return_density=False):
        """
        computes the predictive distribution of the targets p(t|test_x, test_context_x, context_y)

        Args:
            context_x: (ndarray) context input data for which to compute the posterior
            context_y: (ndarray) context targets for which to compute the posterior
            test_x: (ndarray) query input data of shape (n_samples, ndim_x)
                        n_posterior_samples: (int) number of samples from posterior to average over
            mode: (std) either of ['Bayes' , 'MAP']
            return_density: (bool) whether to return result as mean and std ndarray or as MultivariateNormal pytorch object

        Returns:
            (pred_mean, pred_std) predicted mean and standard deviation corresponding to p(t|test_x, test_context_x, context_y)
        """
        context_x, context_y = _handle_input_dimensionality(
            context_x, context_y)
        test_x = _handle_input_dimensionality(test_x)
        assert test_x.shape[1] == context_x.shape[1]

        # meta-test training / inference
        task_dict = self._meta_test_inference([(context_x, context_y)],
                                              verbose=True,
                                              log_period=500,
                                              n_iter=n_iter_meta_test)[0]

        with torch.no_grad():
            # meta-test evaluation
            test_x = self._normalize_data(X=test_x, Y=None)
            test_x = torch.from_numpy(test_x).float().to(device)

            gp_model = task_dict["gp_model"]
            gp_model.eval()
            pred_dist = self.likelihood(gp_model(test_x))
            pred_dist = AffineTransformedDistribution(
                pred_dist,
                normalization_mean=self.y_mean,
                normalization_std=self.y_std)
            if return_density:
                return pred_dist
            else:
                pred_mean = pred_dist.mean.cpu().numpy()
                pred_std = pred_dist.stddev.cpu().numpy()
                return pred_mean, pred_std

    def eval_datasets(self, test_tuples, n_iter_meta_test=3000, **kwargs):
        """
        Performs meta-testing on multiple tasks / datasets.
        Computes the average test log likelihood, the rmse and the calibration error over multiple test datasets

        Args:
            test_tuples: list of test set tuples, i.e. [(test_context_x_1, test_context_y_1, test_x_1, test_y_1), ...]

        Returns: (avg_log_likelihood, rmse, calibr_error)

        """
        assert (all([len(valid_tuple) == 4 for valid_tuple in test_tuples]))

        # meta-test training / inference
        context_tuples = [test_tuple[:2] for test_tuple in test_tuples]
        task_dicts = self._meta_test_inference(context_tuples,
                                               verbose=True,
                                               log_period=500,
                                               n_iter=n_iter_meta_test)

        # meta-test evaluation
        ll_list, rmse_list, calibr_err_list = [], [], []
        for task_dict, test_tuple in zip(task_dicts, test_tuples):
            # data prep
            _, _, test_x, test_y = test_tuple
            test_x_tensor = torch.from_numpy(
                self._normalize_data(X=test_x, Y=None)).float().to(device)
            test_y_tensor = torch.from_numpy(test_y).float().flatten().to(
                device)

            # get predictive dist
            gp_model = task_dict["gp_model"]
            gp_model.eval()
            pred_dist = self.likelihood(gp_model(test_x_tensor))
            pred_dist = AffineTransformedDistribution(
                pred_dist,
                normalization_mean=self.y_mean,
                normalization_std=self.y_std)

            # compute eval metrics
            ll_list.append(
                torch.mean(
                    pred_dist.log_prob(test_y_tensor) /
                    test_y_tensor.shape[0]).cpu().item())
            rmse_list.append(
                torch.mean(torch.pow(pred_dist.mean - test_y_tensor,
                                     2)).sqrt().cpu().item())
            pred_dist_vect = self._vectorize_pred_dist(pred_dist)
            calibr_err_list.append(
                self._calib_error(pred_dist_vect, test_y_tensor).cpu().item())

        return np.mean(ll_list), np.mean(rmse_list), np.mean(calibr_err_list)

    def state_dict(self):
        state_dict = {
            'optimizer': self.optimizer.state_dict(),
            'model': self.task_dicts[0]['gp_model'].state_dict()
        }
        for task_dict in self.task_dicts:
            for key, tensor in task_dict['gp_model'].state_dict().items():
                assert torch.all(state_dict['model'][key] == tensor).item()
        return state_dict

    def load_state_dict(self, state_dict):
        for task_dict in self.task_dicts:
            task_dict['gp_model'].load_state_dict(state_dict['model'])
        self.optimizer.load_state_dict(state_dict['optimizer'])

    def _setup_task_dicts(self, train_data_tuples):
        task_dicts, parameters = [], []

        for train_x, train_y in train_data_tuples:
            task_dict = OrderedDict()

            # a) prepare data
            x_tensor, y_tensor = self._prepare_data_per_task(train_x, train_y)
            task_dict['train_x'], task_dict['train_y'] = x_tensor, y_tensor
            task_dict['gp_model'] = LearnedGPRegressionModelApproximate(
                x_tensor,
                y_tensor,
                self.likelihood,
                mean_module=gpytorch.means.ZeroMean(),
                covar_module=gpytorch.kernels.RBFKernel())
            parameters.extend(task_dict['gp_model'].variational_parameters())
            task_dicts.append(task_dict)

        return task_dicts, parameters

    def _meta_test_inference(self,
                             context_tuples,
                             n_iter=3000,
                             lr=1e-2,
                             log_period=100,
                             verbose=False):
        n_tasks = len(context_tuples)
        task_dicts, posterior_params = self._setup_task_dicts(context_tuples)

        optimizer = torch.optim.Adam(posterior_params, lr=lr)

        t = time.time()
        for itr in range(n_iter):
            optimizer.zero_grad()
            param_sample = self.hyper_posterior.rsample(
                sample_shape=(self.svi_batch_size, ))
            task_pac_bounds, diagnostics_dict = self._task_pac_bounds(
                task_dicts,
                param_sample,
                task_kl_weight=self.task_kl_weight,
                meta_kl_weight=self.meta_kl_weight)
            loss = torch.sum(torch.stack(task_pac_bounds))
            loss.backward()
            optimizer.step()

            if itr % log_period == 0 and verbose:
                duration = time.time() - t
                t = time.time()
                message = '\t Meta-Test Iter %d/%d - Loss: %.6f - Time %.2f sec - ' % (
                    itr, n_iter, loss.item() / n_tasks, duration)
                # add diagnostics
                message += ' - '.join([
                    '%s: %.4f' % (key, value)
                    for key, value in diagnostics_dict.items()
                ])
                self.logger.info(message)

        return task_dicts

    def _setup_meta_train_step(self, mean_module_str, covar_module_str,
                               mean_nn_layers, kernel_nn_layers, cov_type):
        assert mean_module_str in ['NN', 'constant']
        assert covar_module_str in ['NN', 'SE']
        """ random gp model """
        self.random_gp = RandomGPMeta(size_in=self.input_dim,
                                      prior_factor=1.0,
                                      weight_prior_std=self.weight_prior_std,
                                      bias_prior_std=self.bias_prior_std,
                                      covar_module_str=covar_module_str,
                                      mean_module_str=mean_module_str,
                                      mean_nn_layers=mean_nn_layers,
                                      kernel_nn_layers=kernel_nn_layers)

        param_shapes_dict = self.random_gp.parameter_shapes()
        """ variational posterior """
        self.hyper_posterior = RandomGPPosterior(param_shapes_dict,
                                                 cov_type=cov_type)

        def _tile_data_tuple(task_dict, tile_size):
            x_data, y_data = task_dict['train_x'], task_dict['train_y']
            x_data = x_data.view(torch.Size((1, )) + x_data.shape).repeat(
                tile_size, 1, 1)
            y_data = y_data.view(torch.Size((1, )) + y_data.shape).repeat(
                tile_size, 1)
            return x_data, y_data

        def _hyper_kl(prior_param_sample):
            return torch.mean(
                self.hyper_posterior.log_prob(prior_param_sample) -
                self.random_gp.hyper_prior.log_prob(prior_param_sample))

        def _task_pac_bounds(task_dicts,
                             prior_param_sample,
                             task_kl_weight=1.0,
                             meta_kl_weight=1.0):

            fn = self.random_gp.get_forward_fn(prior_param_sample)

            kl_outer = meta_kl_weight * _hyper_kl(prior_param_sample)

            task_pac_bounds = []
            for task_dict in task_dicts:
                posterior = task_dict["gp_model"](task_dict["train_x"])

                # likelihood
                avg_ll = torch.mean(
                    self.likelihood.expected_log_prob(task_dict["train_y"],
                                                      posterior))

                # task complexity
                x_data_tiled, y_data_tiled = _tile_data_tuple(
                    task_dict, self.svi_batch_size)
                gp, _ = fn(x_data_tiled, None, prior=True)
                prior = gp(x_data_tiled)

                kl_inner = task_kl_weight * torch.mean(
                    _kl_divergence_safe(
                        posterior.expand((self.svi_batch_size, )), prior))

                m = torch.tensor(task_dict["train_y"].shape[0],
                                 dtype=torch.float32)
                n = torch.tensor(self.n_tasks, dtype=torch.float32)
                task_complexity = torch.sqrt(
                    (kl_outer + kl_inner + math.log(2.) + torch.log(m) +
                     torch.log(n) - torch.log(self.delta)) / (2 * (m - 1)))

                diagnostics_dict = {
                    'avg_ll': avg_ll.item(),
                    'kl_outer_weighted': kl_outer.item(),
                    'kl_inner_weighted': kl_inner.item(),
                }

                task_pac_bound = -avg_ll + task_complexity
                task_pac_bounds.append(task_pac_bound)
            return task_pac_bounds, diagnostics_dict

        def _meta_complexity(prior_param_sample, meta_kl_weight=1.0):
            outer_kl = _hyper_kl(prior_param_sample)
            n = torch.tensor(self.n_tasks, dtype=torch.float32)
            # meta-level complexity term of the PAC bound
            return torch.sqrt(
                (meta_kl_weight * outer_kl + math.log(2.) + torch.log(n) -
                 torch.log(self.delta)) / (2 * (n - 1)))

        def _meta_train_pac_bound(task_dicts):
            param_sample = self.hyper_posterior.rsample(
                sample_shape=(self.svi_batch_size, ))

            task_pac_bounds, diagnostics_dict = _task_pac_bounds(
                task_dicts,
                param_sample,
                task_kl_weight=self.task_kl_weight,
                meta_kl_weight=self.meta_kl_weight)
            meta_complexity = _meta_complexity(
                param_sample, meta_kl_weight=self.meta_kl_weight)

            pac_bound = torch.mean(
                torch.stack(task_pac_bounds)) + meta_complexity
            return pac_bound, diagnostics_dict

        self._task_pac_bounds = _task_pac_bounds
        self._meta_train_pac_bound = _meta_train_pac_bound

    def _setup_optimizer(self, optimizer, lr, lr_decay):
        if optimizer == 'Adam':
            self.optimizer = torch.optim.Adam(self.meta_train_params, lr=lr)
        elif optimizer == 'SGD':
            self.optimizer = torch.optim.SGD(self.meta_train_params, lr=lr)
        else:
            raise NotImplementedError('Optimizer must be Adam or SGD')

        if lr_decay < 1.0:
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                                1000,
                                                                gamma=lr_decay)
        else:
            self.lr_scheduler = DummyLRScheduler()

    def _vectorize_pred_dist(self, pred_dist):
        # converts a multivariate gaussian into a vectorized univariate gaussian
        return torch.distributions.Normal(pred_dist.mean, pred_dist.stddev)

    def prior_mean(self, x, n_hyperposterior_samples=1000):
        x = (x - self.x_mean) / self.x_std
        assert x.ndim == 1 or (x.ndim == 2 and x.shape[-1] == 1)
        x_data_tiled = np.tile(x.reshape(1, x.shape[0], 1),
                               (n_hyperposterior_samples, 1, 1))
        x_data_tiled = torch.tensor(x_data_tiled, dtype=torch.float32)

        with torch.no_grad():
            param_sample = self.hyper_posterior.rsample(
                sample_shape=(n_hyperposterior_samples, ))
            fn = self.random_gp.get_forward_fn(param_sample)
            gp, _ = fn(x_data_tiled, None, prior=True)
            prior = gp(x_data_tiled)
            mean = torch.mean(prior.mean,
                              axis=0).numpy() * self.y_std + self.y_mean
            # torch.mean(gp.learned_mean(x_data_tiled), axis=0).numpy() * self.y_std + self.y_mean
        return mean
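
A minimal usage sketch of GPRegressionMetaLearnedPAC on synthetic sine tasks. The helper `make_task`, the data, and the iteration counts are illustrative only; `mean_module='constant'` is passed explicitly because `_setup_meta_train_step` only accepts 'NN' or 'constant'.

import numpy as np

def make_task(rds, n_samples=20):
    # hypothetical task: noisy sine with a random phase shift
    x = rds.uniform(-4, 4, size=(n_samples, 1))
    y = np.sin(x[:, 0] + rds.uniform(0, np.pi)) + 0.1 * rds.normal(size=n_samples)
    return x, y

rds = np.random.RandomState(0)
meta_train_data = [make_task(rds) for _ in range(10)]

model = GPRegressionMetaLearnedPAC(meta_train_data, num_iter_fit=2000,
                                   mean_module='constant', covar_module='SE',
                                   svi_batch_size=5, task_batch_size=4,
                                   random_seed=0)
model.meta_fit(log_period=500, eval_period=1000)

# meta-test: condition on a context set, then predict at query inputs
context_x, context_y = make_task(rds)
test_x = np.linspace(-4, 4, 100).reshape(-1, 1)
pred_mean, pred_std = model.predict(context_x, context_y, test_x,
                                    n_iter_meta_test=500)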
Example #3
    def _setup_model_inference(self, mean_module_str, covar_module_str, mean_nn_layers, kernel_nn_layers, cov_type):
        assert mean_module_str in ['NN', 'constant']
        assert covar_module_str in ['NN', 'SE']

        """ random gp model """
        self.random_gp = RandomGPMeta(size_in=self.input_dim, prior_factor=self.prior_factor,
                                  weight_prior_std=self.weight_prior_std, bias_prior_std=self.bias_prior_std,
                                  covar_module_str=covar_module_str, mean_module_str=mean_module_str,
                                  mean_nn_layers=mean_nn_layers, kernel_nn_layers=kernel_nn_layers)

        param_shapes_dict = self.random_gp.parameter_shapes()

        """ variational posterior """
        self.posterior = RandomGPPosterior(param_shapes_dict, cov_type=cov_type)

        def _tile_data_tuples(tasks_dicts, tile_size):
            train_data_tuples_tiled = []
            for task_dict in tasks_dicts:
                x_data, y_data = task_dict['train_x'], task_dict['train_y']
                x_data = x_data.view(torch.Size((1,)) + x_data.shape).repeat(tile_size, 1, 1)
                y_data = y_data.view(torch.Size((1,)) + y_data.shape).repeat(tile_size, 1)
                train_data_tuples_tiled.append((x_data, y_data))
            return train_data_tuples_tiled

        """ define negative ELBO """
        def get_neg_elbo(tasks_dicts):
            # tile data to svi_batch_shape
            data_tuples_tiled = _tile_data_tuples(tasks_dicts, self.svi_batch_size)

            param_sample = self.posterior.rsample(sample_shape=(self.svi_batch_size,))
            elbo = self.random_gp.log_prob(param_sample, data_tuples_tiled) - self.prior_factor * self.posterior.log_prob(param_sample)

            assert elbo.ndim == 1 and elbo.shape[0] == self.svi_batch_size
            return - torch.mean(elbo)

        self.get_neg_elbo = get_neg_elbo

        """ define predictive dist """
        def get_pred_dist(x_context, y_context, x_valid, n_post_samples=100):
            with torch.no_grad():
                x_context = x_context.view(torch.Size((1,)) + x_context.shape).repeat(n_post_samples, 1, 1)
                y_context = y_context.view(torch.Size((1,)) + y_context.shape).repeat(n_post_samples, 1)
                x_valid = x_valid.view(torch.Size((1,)) + x_valid.shape).repeat(n_post_samples, 1, 1)

                param_sample = self.posterior.sample(sample_shape=(n_post_samples,))
                gp_fn = self.random_gp.get_forward_fn(param_sample)
                gp, likelihood = gp_fn(x_context, y_context, train=False)
                pred_dist = likelihood(gp(x_valid))
            return pred_dist

        def get_pred_dist_map(x_context, y_context, x_valid):
            with torch.no_grad():
                x_context = x_context.view(torch.Size((1,)) + x_context.shape).repeat(1, 1, 1)
                y_context = y_context.view(torch.Size((1,)) + y_context.shape).repeat(1, 1)
                x_valid = x_valid.view(torch.Size((1,)) + x_valid.shape).repeat(1, 1, 1)
                param = self.posterior.mode
                param = param.view(torch.Size((1,)) + param.shape).repeat(1, 1)

                gp_fn = self.random_gp.get_forward_fn(param)
                gp, likelihood = gp_fn(x_context, y_context, train=False)
                pred_dist = likelihood(gp(x_valid))
            return MultivariateNormal(pred_dist.loc, pred_dist.covariance_matrix[0])


        self.get_pred_dist = get_pred_dist
        self.get_pred_dist_map = get_pred_dist_map
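
get_neg_elbo above estimates the negative ELBO by Monte-Carlo over svi_batch_size reparameterized samples of the GP-prior parameters. The toy snippet below is a generic reparameterized ELBO estimator in the same spirit (likelihood plus prior log-density minus a weighted posterior log-density, averaged over samples); it is not the decomposition used by RandomGPMeta.log_prob, and every name and value in it is illustrative only.

import torch
from torch.distributions import Normal

# toy data: observations from a Gaussian with unknown mean theta
y = torch.tensor([0.8, 1.1, 0.9])
prior = Normal(0.0, 1.0)

# variational posterior q(theta) = N(mu, softplus(rho)^2)
mu = torch.zeros(1, requires_grad=True)
rho = torch.zeros(1, requires_grad=True)

def toy_neg_elbo(svi_batch_size=10, prior_factor=1.0):
    q = Normal(mu, torch.nn.functional.softplus(rho))
    theta = q.rsample((svi_batch_size,))                   # reparameterized samples, shape (S, 1)
    log_lik = Normal(theta, 1.0).log_prob(y).sum(dim=-1)   # log p(y | theta), shape (S,)
    log_prior = prior.log_prob(theta).squeeze(-1)          # log p(theta), shape (S,)
    log_q = q.log_prob(theta).squeeze(-1)                  # log q(theta), shape (S,)
    elbo = log_lik + log_prior - prior_factor * log_q
    return -elbo.mean()

loss = toy_neg_elbo()
loss.backward()  # gradients flow to mu and rho through the reparameterized samples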
Example #4
class GPRegressionMetaLearnedVI(RegressionModelMetaLearned):

    def __init__(self, meta_train_data, num_iter_fit=10000, feature_dim=1,
                 prior_factor=0.01, weight_prior_std=0.5, bias_prior_std=3.0,
                 covar_module='NN', mean_module='NN', mean_nn_layers=(32, 32), kernel_nn_layers=(32, 32),
                 optimizer='Adam', lr=1e-3, lr_decay=1.0, svi_batch_size=10, cov_type='diag',
                 task_batch_size=-1, normalize_data=True, random_seed=None):
        """
        PACOH-VI: Variational Inference on the PAC-optimal hyper-posterior with Gaussian family.
        Meta-Learns a distribution over GP-priors.

        Args:
            meta_train_data: list of tuples of ndarrays[(train_x_1, train_t_1), ..., (train_x_n, train_t_n)]
            num_iter_fit: (int) number of gradient steps for fitting the parameters
            feature_dim: (int) output dimensionality of NN feature map for kernel function
            prior_factor: (float) weighting of the hyper-prior (--> meta-regularization parameter)
            weight_prior_std (float): std of Gaussian hyper-prior on weights
            bias_prior_std (float): std of Gaussian hyper-prior on biases
            covar_module: (gpytorch.mean.Kernel) optional kernel module, default: RBF kernel
            mean_module: (gpytorch.mean.Mean) optional mean module, default: ZeroMean
            mean_nn_layers: (tuple) hidden layer sizes of mean NN
            kernel_nn_layers: (tuple) hidden layer sizes of kernel NN
            optimizer: (str) type of optimizer to use - must be either 'Adam' or 'SGD'
            lr: (float) learning rate for prior parameters
            lr_decay: (float) lr rate decay multiplier applied after every 1000 steps
            kernel (std): SVGD kernel, either 'RBF' or 'IMQ'
            bandwidth (float): bandwidth of kernel, if None the bandwidth is chosen via heuristic
            num_particles: (int) number particles to approximate the hyper-posterior
            task_batch_size: (int) mini-batch size of tasks for estimating gradients
            normalize_data: (bool) whether the data should be normalized
            random_seed: (int) seed for pytorch
        """
        super().__init__(normalize_data, random_seed)

        assert mean_module in ['NN', 'constant', 'zero'] or isinstance(mean_module, gpytorch.means.Mean)
        assert covar_module in ['NN', 'SE'] or isinstance(covar_module, gpytorch.kernels.Kernel)
        assert optimizer in ['Adam', 'SGD']

        self.num_iter_fit, self.prior_factor, self.feature_dim = num_iter_fit, prior_factor, feature_dim
        self.weight_prior_std, self.bias_prior_std = weight_prior_std, bias_prior_std
        self.svi_batch_size = svi_batch_size
        if task_batch_size < 1:
            self.task_batch_size = len(meta_train_data)
        else:
            self.task_batch_size = min(task_batch_size, len(meta_train_data))

        # Check that data all has the same size
        self._check_meta_data_shapes(meta_train_data)
        self._compute_normalization_stats(meta_train_data)

        """ --- Setup model & inference --- """
        self._setup_model_inference(mean_module, covar_module, mean_nn_layers, kernel_nn_layers,
                                    cov_type)

        self._setup_optimizer(optimizer, lr, lr_decay)

        # Setup components that are different across tasks
        self.task_dicts = []

        for train_x, train_y in meta_train_data:
            task_dict = {}

            # a) prepare data
            x_tensor, y_tensor = self._prepare_data_per_task(train_x, train_y)
            task_dict['train_x'], task_dict['train_y'] = x_tensor, y_tensor
            self.task_dicts.append(task_dict)

        self.fitted = False


    def meta_fit(self, valid_tuples=None, verbose=True, log_period=500, n_iter=None):

        """
        fits the variational hyper-posterior by minimizing the negative ELBO

        Args:
            valid_tuples: list of valid tuples, i.e. [(test_context_x_1, test_context_t_1, test_x_1, test_t_1), ...]
            verbose: (boolean) whether to print training progress
            log_period (int) number of steps after which to print stats
            n_iter: (int) number of gradient descent iterations
        """

        assert (valid_tuples is None) or (all([len(valid_tuple) == 4 for valid_tuple in valid_tuples]))

        t = time.time()

        if n_iter is None:
            n_iter = self.num_iter_fit

        for itr in range(1, n_iter + 1):

            task_dict_batch = self.rds_numpy.choice(self.task_dicts, size=self.task_batch_size)
            self.optimizer.zero_grad()
            loss = self.get_neg_elbo(task_dict_batch)
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            # print training stats
            if itr == 1 or itr % log_period == 0:
                duration = time.time() - t
                t = time.time()

                message = 'Iter %d/%d - Loss: %.6f - Time %.2f sec' % (itr, self.num_iter_fit, loss.item(), duration)

                # if validation data is provided  -> compute the valid log-likelihood
                if valid_tuples is not None:
                    valid_ll, valid_rmse, calibr_err = self.eval_datasets(valid_tuples)
                    message += ' - Valid-LL: %.3f - Valid-RMSE: %.3f - Calib-Err %.3f' % (valid_ll, valid_rmse, calibr_err)

                if verbose:
                    self.logger.info(message)

        self.fitted = True
        return loss.item()

    def predict(self, context_x, context_y, test_x, n_posterior_samples=100, mode='Bayes', return_density=False):
        """
        computes the predictive distribution of the targets p(t|test_x, test_context_x, context_y)

        Args:
            context_x: (ndarray) context input data for which to compute the posterior
            context_y: (ndarray) context targets for which to compute the posterior
            test_x: (ndarray) query input data of shape (n_samples, ndim_x)
            n_posterior_samples: (int) number of samples from the hyper-posterior to average over
            mode: (str) either of ['Bayes', 'MAP']
            return_density: (bool) whether to return result as mean and std ndarray or as MultivariateNormal pytorch object

        Returns:
            (pred_mean, pred_std) predicted mean and standard deviation corresponding to p(t|test_x, test_context_x, context_y)
        """
        assert mode in ['bayes', 'Bayes', 'MAP', 'map']

        context_x, context_y = _handle_input_dimensionality(context_x, context_y)
        test_x = _handle_input_dimensionality(test_x)
        assert test_x.shape[1] == context_x.shape[1]

        # normalize data and convert to tensor
        context_x, context_y = self._prepare_data_per_task(context_x, context_y)

        test_x = self._normalize_data(X=test_x, Y=None)
        test_x = torch.from_numpy(test_x).float().to(device)

        with torch.no_grad():

            if mode == 'Bayes' or mode == 'bayes':
                pred_dist = self.get_pred_dist(context_x, context_y, test_x, n_post_samples=n_posterior_samples)
                pred_dist = AffineTransformedDistribution(pred_dist, normalization_mean=self.y_mean,
                                                      normalization_std=self.y_std)

                pred_dist = EqualWeightedMixtureDist(pred_dist, batched=True)
            else:
                pred_dist = self.get_pred_dist_map(context_x, context_y, test_x)
                pred_dist = AffineTransformedDistribution(pred_dist, normalization_mean=self.y_mean,
                                                      normalization_std=self.y_std)
            if return_density:
                return pred_dist
            else:
                pred_mean = pred_dist.mean.cpu().numpy()
                pred_std = pred_dist.stddev.cpu().numpy()
                return pred_mean, pred_std

    def state_dict(self):
        state_dict = {
            'optimizer': self.optimizer.state_dict(),
            'model': self.task_dicts[0]['model'].state_dict()
        }
        for task_dict in self.task_dicts:
            for key, tensor in task_dict['model'].state_dict().items():
                assert torch.all(state_dict['model'][key] == tensor).item()
        return state_dict

    def load_state_dict(self, state_dict):
        for task_dict in self.task_dicts:
            task_dict['model'].load_state_dict(state_dict['model'])
        self.optimizer.load_state_dict(state_dict['optimizer'])

    def _setup_model_inference(self, mean_module_str, covar_module_str, mean_nn_layers, kernel_nn_layers, cov_type):
        assert mean_module_str in ['NN', 'constant']
        assert covar_module_str in ['NN', 'SE']

        """ random gp model """
        self.random_gp = RandomGPMeta(size_in=self.input_dim, prior_factor=self.prior_factor,
                                  weight_prior_std=self.weight_prior_std, bias_prior_std=self.bias_prior_std,
                                  covar_module_str=covar_module_str, mean_module_str=mean_module_str,
                                  mean_nn_layers=mean_nn_layers, kernel_nn_layers=kernel_nn_layers)

        param_shapes_dict = self.random_gp.parameter_shapes()

        """ variational posterior """
        self.posterior = RandomGPPosterior(param_shapes_dict, cov_type=cov_type)

        def _tile_data_tuples(tasks_dicts, tile_size):
            train_data_tuples_tiled = []
            for task_dict in tasks_dicts:
                x_data, y_data = task_dict['train_x'], task_dict['train_y']
                x_data = x_data.view(torch.Size((1,)) + x_data.shape).repeat(tile_size, 1, 1)
                y_data = y_data.view(torch.Size((1,)) + y_data.shape).repeat(tile_size, 1)
                train_data_tuples_tiled.append((x_data, y_data))
            return train_data_tuples_tiled

        """ define negative ELBO """
        def get_neg_elbo(tasks_dicts):
            # tile data to svi_batch_shape
            data_tuples_tiled = _tile_data_tuples(tasks_dicts, self.svi_batch_size)

            param_sample = self.posterior.rsample(sample_shape=(self.svi_batch_size,))
            elbo = self.random_gp.log_prob(param_sample, data_tuples_tiled) - self.prior_factor * self.posterior.log_prob(param_sample)

            assert elbo.ndim == 1 and elbo.shape[0] == self.svi_batch_size
            return - torch.mean(elbo)

        self.get_neg_elbo = get_neg_elbo

        """ define predictive dist """
        def get_pred_dist(x_context, y_context, x_valid, n_post_samples=100):
            with torch.no_grad():
                x_context = x_context.view(torch.Size((1,)) + x_context.shape).repeat(n_post_samples, 1, 1)
                y_context = y_context.view(torch.Size((1,)) + y_context.shape).repeat(n_post_samples, 1)
                x_valid = x_valid.view(torch.Size((1,)) + x_valid.shape).repeat(n_post_samples, 1, 1)

                param_sample = self.posterior.sample(sample_shape=(n_post_samples,))
                gp_fn = self.random_gp.get_forward_fn(param_sample)
                gp, likelihood = gp_fn(x_context, y_context, train=False)
                pred_dist = likelihood(gp(x_valid))
            return pred_dist

        def get_pred_dist_map(x_context, y_context, x_valid):
            with torch.no_grad():
                x_context = x_context.view(torch.Size((1,)) + x_context.shape).repeat(1, 1, 1)
                y_context = y_context.view(torch.Size((1,)) + y_context.shape).repeat(1, 1)
                x_valid = x_valid.view(torch.Size((1,)) + x_valid.shape).repeat(1, 1, 1)
                param = self.posterior.mode
                param = param.view(torch.Size((1,)) + param.shape).repeat(1, 1)

                gp_fn = self.random_gp.get_forward_fn(param)
                gp, likelihood = gp_fn(x_context, y_context, train=False)
                pred_dist = likelihood(gp(x_valid))
            return MultivariateNormal(pred_dist.loc, pred_dist.covariance_matrix[0])


        self.get_pred_dist = get_pred_dist
        self.get_pred_dist_map = get_pred_dist_map

    def _setup_optimizer(self, optimizer, lr, lr_decay):
        if optimizer == 'Adam':
            self.optimizer = torch.optim.Adam(self.posterior.parameters(), lr=lr)
        elif optimizer == 'SGD':
            self.optimizer = torch.optim.SGD(self.posterior.parameters(), lr=lr)
        else:
            raise NotImplementedError('Optimizer must be Adam or SGD')

        if lr_decay < 1.0:
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1000, gamma=lr_decay)
        else:
            self.lr_scheduler = DummyLRScheduler()

    def _vectorize_pred_dist(self, pred_dist):
        multiv_normal_batched = pred_dist.dists
        normal_batched = torch.distributions.Normal(multiv_normal_batched.mean, multiv_normal_batched.stddev)
        return EqualWeightedMixtureDist(normal_batched, batched=True, num_dists=multiv_normal_batched.batch_shape[0])
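
A minimal usage sketch of GPRegressionMetaLearnedVI; the synthetic tasks, iteration counts and hyper-parameter values are illustrative only.

import numpy as np

rds = np.random.RandomState(1)
meta_train_data = []
for _ in range(10):
    x = rds.uniform(-4, 4, size=(20, 1))
    meta_train_data.append((x, np.sin(2 * x[:, 0]) + 0.1 * rds.normal(size=20)))

model = GPRegressionMetaLearnedVI(meta_train_data, num_iter_fit=2000,
                                  svi_batch_size=10, cov_type='diag',
                                  random_seed=1)
model.meta_fit(log_period=500)

context_x, context_y = meta_train_data[0]
test_x = np.linspace(-4, 4, 100).reshape(-1, 1)

# Bayesian model averaging over hyper-posterior samples
pred_mean, pred_std = model.predict(context_x, context_y, test_x,
                                    n_posterior_samples=50, mode='Bayes')
# point prediction based on the mode of the hyper-posterior
map_mean, map_std = model.predict(context_x, context_y, test_x, mode='MAP')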
Example #5
    def _setup_model_inference(self, mean_module_str, covar_module_str,
                               mean_nn_layers, kernel_nn_layers, kernel,
                               bandwidth, optimizer, lr, lr_decay):
        assert mean_module_str in ['NN', 'constant']
        assert covar_module_str in ['NN', 'SE']
        """ random gp model """
        self.random_gp = RandomGPMeta(size_in=self.input_dim,
                                      prior_factor=self.prior_factor,
                                      weight_prior_std=self.weight_prior_std,
                                      bias_prior_std=self.bias_prior_std,
                                      covar_module_str=covar_module_str,
                                      mean_module_str=mean_module_str,
                                      mean_nn_layers=mean_nn_layers,
                                      kernel_nn_layers=kernel_nn_layers)
        """ Setup SVGD inference"""

        if kernel == 'RBF':
            kernel = RBF_Kernel(bandwidth=bandwidth)
        elif kernel == 'IMQ':
            kernel = IMQSteinKernel(bandwidth=bandwidth)
        else:
            raise NotImplementedError("SVGD kernel must be either 'RBF' or 'IMQ'")

        # sample initial particle locations from prior
        self.particles = self.random_gp.sample_params_from_prior(
            shape=(self.num_particles, ))

        self._setup_optimizer(optimizer, lr, lr_decay)

        self.svgd = SVGD(self.random_gp, kernel, optimizer=self.optimizer)
        """ define svgd step """
        def svgd_step(tasks_dicts):
            # tile data to svi_batch_shape
            train_data_tuples_tiled = []
            for task_dict in tasks_dicts:
                x_data, y_data = task_dict['train_x'], task_dict['train_y']
                x_data = x_data.view(torch.Size((1, )) + x_data.shape).repeat(
                    self.num_particles, 1, 1)
                y_data = y_data.view(torch.Size((1, )) + y_data.shape).repeat(
                    self.num_particles, 1)
                train_data_tuples_tiled.append((x_data, y_data))

            self.svgd.step(self.particles, train_data_tuples_tiled)

        """ define predictive dist """

        def get_pred_dist(x_context, y_context, x_valid):
            with torch.no_grad():
                x_context = x_context.view(
                    torch.Size((1, )) + x_context.shape).repeat(
                        self.num_particles, 1, 1)
                y_context = y_context.view(
                    torch.Size((1, )) + y_context.shape).repeat(
                        self.num_particles, 1)
                x_valid = x_valid.view(torch.Size((1, )) +
                                       x_valid.shape).repeat(
                                           self.num_particles, 1, 1)

                gp_fn = self.random_gp.get_forward_fn(self.particles)
                gp, likelihood = gp_fn(x_context, y_context, train=False)
                pred_dist = likelihood(gp(x_valid))
            return pred_dist

        self.svgd_step = svgd_step
        self.get_pred_dist = get_pred_dist
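
The SVGD class used above is provided by the surrounding codebase. As a reference for what a single SVGD update does, here is a generic, self-contained sketch of one step with an RBF kernel on a toy 1-D target; it is not the repository's SVGD implementation, and all names in it are illustrative.

import torch

def svgd_step_sketch(particles, log_prob_fn, step_size=1e-2, bandwidth=1.0):
    """One generic SVGD update for particles of shape (n_particles, dim)."""
    particles = particles.detach().requires_grad_(True)
    log_p = log_prob_fn(particles).sum()
    grad_log_p = torch.autograd.grad(log_p, particles)[0]       # (n, d)

    diff = particles.unsqueeze(1) - particles.unsqueeze(0)      # diff[j, i] = x_j - x_i, (n, n, d)
    sq_dist = (diff ** 2).sum(-1)                                # (n, n)
    k = torch.exp(-sq_dist / (2 * bandwidth ** 2))               # RBF kernel matrix k(x_j, x_i)
    grad_k = -(k.unsqueeze(-1) * diff) / bandwidth ** 2          # d k(x_j, x_i) / d x_j

    # phi(x_i) = 1/n * sum_j [ k(x_j, x_i) * grad_j log p(x_j) + grad_j k(x_j, x_i) ]
    phi = (k @ grad_log_p + grad_k.sum(0)) / particles.shape[0]
    return (particles + step_size * phi).detach()

# toy usage: push particles towards a standard normal target
particles = torch.randn(10, 1) * 3
for _ in range(100):
    particles = svgd_step_sketch(particles, lambda x: -0.5 * (x ** 2).sum(-1))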
Example #6
class GPRegressionMetaLearnedSVGD(RegressionModelMetaLearned):
    def __init__(self,
                 meta_train_data,
                 num_iter_fit=10000,
                 feature_dim=1,
                 prior_factor=0.01,
                 weight_prior_std=0.5,
                 bias_prior_std=3.0,
                 covar_module='NN',
                 mean_module='NN',
                 mean_nn_layers=(32, 32),
                 kernel_nn_layers=(32, 32),
                 optimizer='Adam',
                 lr=1e-3,
                 lr_decay=1.0,
                 kernel='RBF',
                 bandwidth=None,
                 num_particles=10,
                 task_batch_size=-1,
                 normalize_data=True,
                 random_seed=None):
        """
        PACOH-SVGD: Stein Variational Gradient Descent on PAC-optimal hyper-posterior.
        Meta-learns a set of GP priors (i.e. mean and kernel function)

        Args:
            meta_train_data: list of tuples of ndarrays[(train_x_1, train_t_1), ..., (train_x_n, train_t_n)]
            num_iter_fit: (int) number of gradient steps for fitting the parameters
            feature_dim: (int) output dimensionality of NN feature map for kernel function
            prior_factor: (float) weighting of the hyper-prior (--> meta-regularization parameter)
            weight_prior_std (float): std of Gaussian hyper-prior on weights
            bias_prior_std (float): std of Gaussian hyper-prior on biases
            covar_module: (str or gpytorch.kernels.Kernel) either 'NN', 'SE' or a custom kernel module, default: 'NN'
            mean_module: (str or gpytorch.means.Mean) either 'NN', 'constant', 'zero' or a custom mean module, default: 'NN'
            mean_nn_layers: (tuple) hidden layer sizes of mean NN
            kernel_nn_layers: (tuple) hidden layer sizes of kernel NN
            optimizer: (str) type of optimizer to use - must be either 'Adam' or 'SGD'
            lr: (float) learning rate for prior parameters
            lr_decay: (float) learning rate decay multiplier applied after every 1000 steps
            kernel: (str) SVGD kernel, either 'RBF' or 'IMQ'
            bandwidth: (float) bandwidth of the SVGD kernel; if None the bandwidth is chosen via heuristic
            num_particles: (int) number of particles to approximate the hyper-posterior
            task_batch_size: (int) mini-batch size of tasks for estimating gradients
            normalize_data: (bool) whether the data should be normalized
            random_seed: (int) seed for pytorch
        """
        super().__init__(normalize_data, random_seed)

        assert mean_module in ['NN', 'constant', 'zero'] or isinstance(
            mean_module, gpytorch.means.Mean)
        assert covar_module in ['NN', 'SE'] or isinstance(
            covar_module, gpytorch.kernels.Kernel)
        assert optimizer in ['Adam', 'SGD']

        self.num_iter_fit, self.prior_factor, self.feature_dim = num_iter_fit, prior_factor, feature_dim
        self.weight_prior_std, self.bias_prior_std = weight_prior_std, bias_prior_std
        self.num_particles = num_particles
        if task_batch_size < 1:
            self.task_batch_size = len(meta_train_data)
        else:
            self.task_batch_size = min(task_batch_size, len(meta_train_data))

        # Check that data all has the same size
        self._check_meta_data_shapes(meta_train_data)
        self._compute_normalization_stats(meta_train_data)
        """ --- Setup model & inference --- """
        self._setup_model_inference(mean_module, covar_module, mean_nn_layers,
                                    kernel_nn_layers, kernel, bandwidth,
                                    optimizer, lr, lr_decay)

        # Setup components that are different across tasks
        self.task_dicts = []

        for train_x, train_y in meta_train_data:
            task_dict = {}

            # a) prepare data
            x_tensor, y_tensor = self._prepare_data_per_task(train_x, train_y)
            task_dict['train_x'], task_dict['train_y'] = x_tensor, y_tensor
            self.task_dicts.append(task_dict)

        self.fitted = False

    def meta_fit(self,
                 valid_tuples=None,
                 verbose=True,
                 log_period=500,
                 n_iter=None):
        """
        fits the hyper-posterior particles with SVGD

        Args:
            valid_tuples: list of valid tuples, i.e. [(test_context_x_1, test_context_t_1, test_x_1, test_t_1), ...]
            verbose: (boolean) whether to print training progress
            log_period (int) number of steps after which to print stats
            n_iter: (int) number of gradient descent iterations
        """

        assert (valid_tuples is None) or (all(
            [len(valid_tuple) == 4 for valid_tuple in valid_tuples]))

        t = time.time()

        if n_iter is None:
            n_iter = self.num_iter_fit

        for itr in range(1, n_iter + 1):

            task_dict_batch = self.rds_numpy.choice(self.task_dicts,
                                                    size=self.task_batch_size)
            self.svgd_step(task_dict_batch)
            self.lr_scheduler.step()

            # print training stats
            if itr == 1 or itr % log_period == 0:
                duration = time.time() - t
                t = time.time()

                message = 'Iter %d/%d - Time %.2f sec' % (
                    itr, self.num_iter_fit, duration)

                # if validation data is provided  -> compute the valid log-likelihood
                if valid_tuples is not None:
                    valid_ll, valid_rmse, calibr_err = self.eval_datasets(
                        valid_tuples)
                    message += ' - Valid-LL: %.3f - Valid-RMSE: %.3f - Calib-Err %.3f' % (
                        valid_ll, valid_rmse, calibr_err)

                if verbose:
                    self.logger.info(message)

        self.fitted = True

    def predict(self, context_x, context_y, test_x, return_density=False):
        """
        Performs posterior inference (target training) with (context_x, context_y) as training data and then
        computes the predictive distribution of the targets p(y|test_x, test_context_x, context_y) in the test points

        Args:
            context_x: (ndarray) context input data for which to compute the posterior
            context_y: (ndarray) context targets for which to compute the posterior
            test_x: (ndarray) query input data of shape (n_samples, ndim_x)
            return_density: (bool) whether to return result as mean and std ndarray or as MultivariateNormal pytorch object

        Returns:
            (pred_mean, pred_std) predicted mean and standard deviation corresponding to p(t|test_x, test_context_x, context_y)
        """

        context_x, context_y = _handle_input_dimensionality(
            context_x, context_y)
        test_x = _handle_input_dimensionality(test_x)
        assert test_x.shape[1] == context_x.shape[1]

        # normalize data and convert to tensor
        context_x, context_y = self._prepare_data_per_task(
            context_x, context_y)

        test_x = self._normalize_data(X=test_x, Y=None)
        test_x = torch.from_numpy(test_x).float().to(device)

        with torch.no_grad():
            pred_dist = self.get_pred_dist(context_x, context_y, test_x)
            pred_dist = AffineTransformedDistribution(
                pred_dist,
                normalization_mean=self.y_mean,
                normalization_std=self.y_std)
            pred_dist = EqualWeightedMixtureDist(pred_dist, batched=True)

            if return_density:
                return pred_dist
            else:
                pred_mean = pred_dist.mean.cpu().numpy()
                pred_std = pred_dist.stddev.cpu().numpy()
                return pred_mean, pred_std

    def _setup_model_inference(self, mean_module_str, covar_module_str,
                               mean_nn_layers, kernel_nn_layers, kernel,
                               bandwidth, optimizer, lr, lr_decay):
        assert mean_module_str in ['NN', 'constant']
        assert covar_module_str in ['NN', 'SE']
        """ random gp model """
        self.random_gp = RandomGPMeta(size_in=self.input_dim,
                                      prior_factor=self.prior_factor,
                                      weight_prior_std=self.weight_prior_std,
                                      bias_prior_std=self.bias_prior_std,
                                      covar_module_str=covar_module_str,
                                      mean_module_str=mean_module_str,
                                      mean_nn_layers=mean_nn_layers,
                                      kernel_nn_layers=kernel_nn_layers)
        """ Setup SVGD inference"""

        if kernel == 'RBF':
            kernel = RBF_Kernel(bandwidth=bandwidth)
        elif kernel == 'IMQ':
            kernel = IMQSteinKernel(bandwidth=bandwidth)
        else:
            raise NotImplementedError("SVGD kernel must be either 'RBF' or 'IMQ'")

        # sample initial particle locations from prior
        self.particles = self.random_gp.sample_params_from_prior(
            shape=(self.num_particles, ))

        self._setup_optimizer(optimizer, lr, lr_decay)

        self.svgd = SVGD(self.random_gp, kernel, optimizer=self.optimizer)
        """ define svgd step """
        def svgd_step(tasks_dicts):
            # tile data to svi_batch_shape
            train_data_tuples_tiled = []
            for task_dict in tasks_dicts:
                x_data, y_data = task_dict['train_x'], task_dict['train_y']
                x_data = x_data.view(torch.Size((1, )) + x_data.shape).repeat(
                    self.num_particles, 1, 1)
                y_data = y_data.view(torch.Size((1, )) + y_data.shape).repeat(
                    self.num_particles, 1)
                train_data_tuples_tiled.append((x_data, y_data))

            self.svgd.step(self.particles, train_data_tuples_tiled)

        """ define predictive dist """

        def get_pred_dist(x_context, y_context, x_valid):
            with torch.no_grad():
                x_context = x_context.view(
                    torch.Size((1, )) + x_context.shape).repeat(
                        self.num_particles, 1, 1)
                y_context = y_context.view(
                    torch.Size((1, )) + y_context.shape).repeat(
                        self.num_particles, 1)
                x_valid = x_valid.view(torch.Size((1, )) +
                                       x_valid.shape).repeat(
                                           self.num_particles, 1, 1)

                gp_fn = self.random_gp.get_forward_fn(self.particles)
                gp, likelihood = gp_fn(x_context, y_context, train=False)
                pred_dist = likelihood(gp(x_valid))
            return pred_dist

        self.svgd_step = svgd_step
        self.get_pred_dist = get_pred_dist

    def _setup_optimizer(self, optimizer, lr, lr_decay):
        assert hasattr(
            self, 'particles'
        ), "SVGD must be initialized before setting up optimizer"

        if optimizer == 'Adam':
            self.optimizer = torch.optim.Adam([self.particles], lr=lr)
        elif optimizer == 'SGD':
            self.optimizer = torch.optim.SGD([self.particles], lr=lr)
        else:
            raise NotImplementedError('Optimizer must be Adam or SGD')

        if lr_decay < 1.0:
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                                1000,
                                                                gamma=lr_decay)
        else:
            self.lr_scheduler = DummyLRScheduler()

    def _vectorize_pred_dist(self, pred_dist):
        multiv_normal_batched = pred_dist.dists
        normal_batched = torch.distributions.Normal(
            multiv_normal_batched.mean, multiv_normal_batched.stddev)
        return EqualWeightedMixtureDist(
            normal_batched,
            batched=True,
            num_dists=multiv_normal_batched.batch_shape[0])
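
A minimal usage sketch of GPRegressionMetaLearnedSVGD, again with illustrative synthetic tasks and small iteration counts.

import numpy as np

rds = np.random.RandomState(2)
meta_train_data = []
for _ in range(10):
    x = rds.uniform(-4, 4, size=(20, 1))
    meta_train_data.append((x, np.sin(2 * x[:, 0]) + 0.1 * rds.normal(size=20)))

model = GPRegressionMetaLearnedSVGD(meta_train_data, num_iter_fit=2000,
                                    kernel='RBF', num_particles=10,
                                    task_batch_size=4, random_seed=2)
model.meta_fit(log_period=500)

# predictions average over the particle set (an equal-weighted mixture of GP posteriors)
context_x, context_y = meta_train_data[0]
test_x = np.linspace(-4, 4, 100).reshape(-1, 1)
pred_mean, pred_std = model.predict(context_x, context_y, test_x)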