Example #1
    def marginal_log_likelihood(self, output, target):
        """
        Returns the marginal log likelihood of the data

        Args:
        - output: (GaussianRandomVariable) - the output of the model
        - target: (Variable) - the training targets
        """
        mean, covar = output.representation()

        # Exact inference
        if self.exact_inference:
            return gpytorch.exact_gp_marginal_log_likelihood(
                covar, target - mean)

        # Approximate inference
        else:
            # Get inducing points
            if not hasattr(self, 'train_inputs'):
                raise RuntimeError('Must condition on data.')

            train_x = self.train_inputs[0]
            if hasattr(self, 'inducing_points'):
                inducing_points = Variable(self.inducing_points)
            else:
                inducing_points = train_x

            chol_var_covar = self.chol_variational_covar.triu()
            # Negate each row whose diagonal entry is negative (a Cholesky
            # factor must have a positive diagonal).
            diag_signs = chol_var_covar.diag().sign().unsqueeze(1).expand_as(
                chol_var_covar).triu()
            chol_var_covar = chol_var_covar.mul(diag_signs)

            _, train_covar = output.representation()
            inducing_output = super(GPModel, self).__call__(inducing_points)
            inducing_mean, inducing_covar = inducing_output.representation()

            train_covar = gpytorch.add_jitter(train_covar)
            log_likelihood = gpytorch.monte_carlo_log_likelihood(
                self.likelihood.log_probability, target, self.variational_mean,
                chol_var_covar, train_covar)

            inducing_covar = gpytorch.add_jitter(inducing_covar)
            kl_divergence = gpytorch.mvn_kl_divergence(self.variational_mean,
                                                       chol_var_covar,
                                                       inducing_mean,
                                                       inducing_covar)

            res = log_likelihood.squeeze() - kl_divergence
            return res
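The value returned on the approximate path is the standard variational evidence lower bound (ELBO). As a point of reference (the notation below is mine, not from the snippet), with q(f) the variational predictive distribution and q(u) the variational distribution over inducing values:

    \log p(\mathbf{y}) \;\ge\; \mathbb{E}_{q(\mathbf{f})}\big[\log p(\mathbf{y} \mid \mathbf{f})\big] \;-\; \mathrm{KL}\big(q(\mathbf{u}) \,\|\, p(\mathbf{u})\big)

`gpytorch.monte_carlo_log_likelihood` estimates the expectation term by sampling, `gpytorch.mvn_kl_divergence` supplies the KL term, and `log_likelihood.squeeze() - kl_divergence` assembles the bound.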
Example #2
    def mvn_kl_divergence(self):
        mean_diffs = self.inducing_output.mean() - self.variational_mean
        chol_variational_covar = self.chol_variational_covar

        if chol_variational_covar.ndimension() == 2:
            matrix_diag = chol_variational_covar.diag()
        elif chol_variational_covar.ndimension() == 3:
            batch_size, diag_size, _ = chol_variational_covar.size()
            batch_index = chol_variational_covar.data.new(batch_size).long()
            torch.arange(0, batch_size, out=batch_index)
            batch_index = batch_index.unsqueeze(1).repeat(1,
                                                          diag_size).view(-1)
            diag_index = chol_variational_covar.data.new(diag_size).long()
            torch.arange(0, diag_size, out=diag_index)
            diag_index = diag_index.unsqueeze(1).repeat(batch_size, 1).view(-1)
            matrix_diag = chol_variational_covar[batch_index, diag_index,
                                                 diag_index].view(
                                                     batch_size, diag_size)
        else:
            raise RuntimeError(
                'Invalid number of variational covar dimensions')

        logdet_variational_covar = matrix_diag.log().sum() * 2
        trace_logdet_quad_form = gpytorch.trace_logdet_quad_form(
            mean_diffs, self.chol_variational_covar,
            gpytorch.add_jitter(self.inducing_output.covar()))

        # Compute the KL Divergence.
        res = 0.5 * (trace_logdet_quad_form - logdet_variational_covar -
                     len(mean_diffs))
        return res
Example #3
    def kl_divergence(self):
        prior_mean = self.prior_dist.mean()
        prior_covar = self.prior_dist.covar()
        variational_mean = self.variational_dist.mean()
        variational_covar = self.variational_dist.covar()
        if not isinstance(variational_covar, CholLazyVariable):
            raise RuntimeError('The variational covar for an MVN distribution should be a CholLazyVariable')
        chol_variational_covar = variational_covar.lhs

        mean_diffs = prior_mean - variational_mean

        if chol_variational_covar.ndimension() == 2:
            matrix_diag = chol_variational_covar.diag()
        elif chol_variational_covar.ndimension() == 3:
            batch_size, diag_size, _ = chol_variational_covar.size()
            batch_index = chol_variational_covar.data.new(batch_size).long()
            torch.arange(0, batch_size, out=batch_index)
            batch_index = batch_index.unsqueeze(1).repeat(1, diag_size).view(-1)
            diag_index = chol_variational_covar.data.new(diag_size).long()
            torch.arange(0, diag_size, out=diag_index)
            diag_index = diag_index.unsqueeze(1).repeat(batch_size, 1).view(-1)
            matrix_diag = chol_variational_covar[batch_index, diag_index, diag_index].view(batch_size, diag_size)
        else:
            raise RuntimeError('Invalid number of variational covar dimensions')

        logdet_variational_covar = matrix_diag.log().sum() * 2
        trace_logdet_quad_form = gpytorch.trace_logdet_quad_form(mean_diffs, chol_variational_covar,
                                                                 gpytorch.add_jitter(prior_covar))

        # Compute the KL Divergence.
        res = 0.5 * (trace_logdet_quad_form - logdet_variational_covar - len(mean_diffs))
        return res
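Examples #2 and #3 are two revisions of the same closed-form KL divergence between multivariate Gaussians. For reference (notation mine), with variational distribution q = N(m, LL^T) and prior p = N(\mu, K) in k dimensions:

    \mathrm{KL}(q \,\|\, p) \;=\; \tfrac{1}{2}\Big[\operatorname{tr}\big(K^{-1} L L^{\top}\big) + (\boldsymbol{\mu} - \mathbf{m})^{\top} K^{-1} (\boldsymbol{\mu} - \mathbf{m}) - k + \log\det K - \log\det\big(L L^{\top}\big)\Big]

`trace_logdet_quad_form` appears to bundle the trace, quadratic, and \log\det K terms; `logdet_variational_covar` = 2 \sum_i \log L_{ii} supplies \log\det(LL^\top), and `len(mean_diffs)` is the dimension k.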
Example #4
    def marginal_log_likelihood(self, output, train_y, num_samples=10):
        chol_var_covar = self.chol_variational_covar.triu()

        # Negate each row whose diagonal entry is negative (a Cholesky
        # factor must have a positive diagonal).
        chol_var_covar = chol_var_covar.mul(
            chol_var_covar.diag().sign().unsqueeze(1).expand_as(
                chol_var_covar).triu())

        _, train_covar = output.representation()
        inducing_output = self.forward(*self.inducing_points)
        inducing_mean = inducing_output.mean()

        train_covar = gpytorch.add_jitter(train_covar)

        log_likelihood = gpytorch.monte_carlo_log_likelihood(
            self.prior_model.likelihood.log_probability, train_y,
            self.variational_mean, chol_var_covar, train_covar, num_samples)

        kl_divergence = gpytorch.mvn_kl_divergence(self.variational_mean,
                                                   chol_var_covar,
                                                   inducing_mean, train_covar,
                                                   num_samples)

        return log_likelihood.squeeze() - kl_divergence
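`num_samples` presumably controls a reparameterized Monte Carlo estimate of the expected log likelihood (standard form, not taken from the snippet):

    \mathbb{E}_{q}\big[\log p(\mathbf{y} \mid \mathbf{f})\big] \;\approx\; \frac{1}{S}\sum_{s=1}^{S} \log p\big(\mathbf{y} \mid \mathbf{f}^{(s)}\big), \qquad \mathbf{f}^{(s)} = \mathbf{m} + L\,\boldsymbol{\epsilon}^{(s)}, \quad \boldsymbol{\epsilon}^{(s)} \sim \mathcal{N}(\mathbf{0}, I)

with m the variational mean and L the sign-corrected Cholesky factor built above.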
Example #5
    def forward(self, x, full_cov=True):
        h = x.mm(self.freq / self.lengthscales)
        h = torch.sqrt(self.kernel.outputscale) / math.sqrt(
            self.n_units) * torch.cat(
                [torch.cos(h), torch.sin(h)], -1)

        f_mean = h.mm(self.w_mean)
        if self.residual:
            f_mean += self.mlp(x)
        f_mean = f_mean.squeeze(-1)

        w_cov_tril = softplus_tril(self.w_cov_raw)
        f_cov_half = h.mm(w_cov_tril)

        if full_cov:
            f_cov = f_cov_half.mm(f_cov_half.t())
            f_cov = gpytorch.add_jitter(f_cov)
            if self.mvn:
                f_dist = gpytorch.distributions.MultivariateNormal(
                    f_mean, f_cov)
                return f_dist
            else:
                return f_mean, f_cov
        else:
            f_var = f_cov_half.pow(2).sum(-1)
            f_var += 1e-6
            return f_mean, f_var
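Example #5's feature map is a random Fourier feature expansion: with frequency matrix W (`self.freq`), lengthscales \ell, output scale \sigma^2, and m = `n_units`,

    \phi(\mathbf{x}) \;=\; \sqrt{\tfrac{\sigma^{2}}{m}}\,\big[\cos(\mathbf{x}^{\top} W/\ell),\; \sin(\mathbf{x}^{\top} W/\ell)\big], \qquad \phi(\mathbf{x})^{\top}\phi(\mathbf{x}') \;\longrightarrow\; k(\mathbf{x}, \mathbf{x}') \quad (m \to \infty)

`h` holds \phi(x) row-wise, so the Gaussian readout defined by `w_mean` and `w_cov_tril` yields an approximate GP with kernel k.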
Example #6
    def forward(self, x, full_cov=True):
        batch_size = x.size(0)
        if self.periodic:
            x = torch.cat([x, torch.sin(self.periodic_fc(x))], -1)

        h = self.layers(x)

        f_mean = h.squeeze(-1)

        grad_w_means = map(partial(jacobian, y=f_mean), self.means)

        f_cov_half = [(grad_w_mean * w_std).reshape(batch_size, -1)
                      for (grad_w_mean, w_std) in zip(grad_w_means, self.stds)]

        if full_cov:
            f_cov = sum([i.mm(i.t()) for i in f_cov_half])
            f_cov = gpytorch.add_jitter(f_cov)
            if self.mvn:
                return MultivariateNormal(f_mean, f_cov)
            else:
                return f_mean, f_cov
        else:
            f_var = sum([torch.sum(i.pow(2), -1) for i in f_cov_half])
            f_var += 1e-6
            return f_mean, f_var
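Example #6 reads as a first-order (tangent) approximation: the network output is linearized in the weights around their means, with factorized Gaussian weights w_l ~ N(\mu_l, diag(\sigma_l^2)) (my reading of `self.means` and `self.stds`):

    f(\mathbf{x}) \;\approx\; f_{\mu}(\mathbf{x}) + \sum_{l} J_{l}(\mathbf{x})\,(\mathbf{w}_{l} - \boldsymbol{\mu}_{l}), \qquad \operatorname{Cov}[\mathbf{f}] \;=\; \sum_{l} J_{l} \operatorname{diag}(\boldsymbol{\sigma}_{l}^{2})\, J_{l}^{\top}

Each entry of `f_cov_half` stores J_l diag(\sigma_l), so summing the outer products `i.mm(i.t())` recovers this covariance, and the `full_cov=False` branch takes just its diagonal.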
Example #7
    def forward(self, *inputs, **params):
        if not self.training:
            inducing_point_vars = list(self.inducing_points)
            full_inputs = [
                torch.cat([inducing_point_var, input])
                for inducing_point_var, input in zip(inducing_point_vars,
                                                     inputs)
            ]
        else:
            full_inputs = inputs

        gaussian_rv_output = self.prior_model.forward(*full_inputs, **params)
        full_mean, full_covar = gaussian_rv_output.representation()

        if not self.training:
            # Get mean/covar components
            n = self.num_inducing
            induc_induc_covar = full_covar[:n, :n]
            induc_test_covar = full_covar[:n, n:]
            test_induc_covar = full_covar[n:, :n]
            test_test_covar = full_covar[n:, n:]

            # Calculate posterior components
            if not hasattr(self, 'alpha'):
                self.alpha = gpytorch.variational_posterior_alpha(
                    induc_induc_covar, self.variational_mean)
            test_mean = gpytorch.variational_posterior_mean(
                test_induc_covar, self.alpha)
            test_covar = gpytorch.variational_posterior_covar(
                test_induc_covar, induc_test_covar,
                self.chol_variational_covar, test_test_covar,
                induc_induc_covar)
            output = GaussianRandomVariable(test_mean, test_covar)
            return output

        else:
            full_covar = gpytorch.add_jitter(full_covar)
            f_prior = GaussianRandomVariable(full_mean, full_covar)
            return f_prior
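The eval-mode branch assembles the usual sparse variational GP predictive distribution. In standard notation (mine, not the snippet's), with q(u) = N(m, S) over the inducing values and S = LL^T from `chol_variational_covar`:

    \boldsymbol{\mu}_{*} \;=\; K_{*u} K_{uu}^{-1} \mathbf{m}, \qquad \Sigma_{*} \;=\; K_{**} - K_{*u} K_{uu}^{-1}\big(K_{uu} - S\big) K_{uu}^{-1} K_{u*}

The cached `alpha` presumably stores K_{uu}^{-1} m, so after the first call the predictive mean costs one matrix-vector product per test batch.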
Example #8
    def forward(self, x, full_cov=True):
        w_cov_tril = softplus_tril(self.w_cov_raw)

        h = self.layers(x)

        f_mean = h.mm(self.w_mean).squeeze(-1)

        f_cov_half = h.mm(w_cov_tril)

        if full_cov:
            f_cov = f_cov_half.mm(f_cov_half.t())
            f_cov = gpytorch.add_jitter(f_cov)

            if self.mvn:
                return gpytorch.distributions.MultivariateNormal(f_mean, f_cov)
            else:
                return f_mean, f_cov
        else:
            hw_cov = f_cov_half.mm(w_cov_tril.t())
            f_var = torch.sum(hw_cov * h, -1)
            f_var += 1e-6
            return f_mean, f_var
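With feature matrix H = `h` and readout weights w ~ N(\mu_w, LL^T), both branches compute exact moments of the linear readout f = Hw:

    \mathbb{E}[\mathbf{f}] = H\boldsymbol{\mu}_{w}, \qquad \operatorname{Cov}[\mathbf{f}] = H L L^{\top} H^{\top}, \qquad \operatorname{Var}[f_{i}] = \big(H L L^{\top} H^{\top}\big)_{ii}

so `full_cov=False` returns only the diagonal and never materializes the n-by-n covariance.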
Example #9
def gpnet(args, dataloader, test_x, prior_gp):
    N = len(dataloader.dataset)
    x_dim = 1
    prior_gp.train()

    if args.net == 'tangent':
        kernel = prior_gp.covar_module
        bnn_prev = FirstOrder([x_dim] + [args.n_hidden] * args.n_layer,
                              mvn=False)
        bnn = FirstOrder([x_dim] + [args.n_hidden] * args.n_layer, mvn=True)
    elif args.net == 'deep':
        kernel = prior_gp.covar_module
        bnn_prev = DeepKernel([x_dim] + [args.n_hidden] * args.n_layer,
                              mvn=False)
        bnn = DeepKernel([x_dim] + [args.n_hidden] * args.n_layer, mvn=True)
    elif args.net == 'rf':
        kernel = ScaleKernel(RBFKernel())
        kernel_prev = ScaleKernel(RBFKernel())
        bnn_prev = RFExpansion(x_dim,
                               args.n_hidden,
                               kernel_prev,
                               mvn=False,
                               fix_ls=args.fix_rf_ls,
                               residual=args.residual)
        bnn = RFExpansion(x_dim,
                          args.n_hidden,
                          kernel,
                          fix_ls=args.fix_rf_ls,
                          residual=args.residual)
        bnn_prev.load_state_dict(bnn.state_dict())
    else:
        raise NotImplementedError('Unknown inference net')
    bnn = bnn.to(args.device)
    bnn_prev = bnn_prev.to(args.device)
    prior_gp = prior_gp.to(args.device)

    infer_gpnet_optimizer = optim.Adam(bnn.parameters(), lr=args.learning_rate)
    hyper_opt_optimizer = optim.Adam(prior_gp.parameters(), lr=args.hyper_rate)

    x_min, x_max = dataloader.dataset.range

    bnn.train()
    bnn_prev.train()
    prior_gp.train()

    mb = master_bar(range(1, args.n_iters + 1))

    for t in mb:
        # Anneal the step size beta over outer iterations
        beta = args.beta0 / (1. + args.gamma * math.sqrt(t - 1))
        dl_bar = progress_bar(dataloader, parent=mb)
        for x, y in dl_bar:
            observed_size = x.size(0)
            x, y = x.to(args.device), y.to(args.device)
            x_star = torch.Tensor(args.measurement_size,
                                  x_dim).uniform_(x_min, x_max).to(args.device)
            # [Batch + Measurement Points x x_dims]
            xx = torch.cat([x, x_star], 0)

            infer_gpnet_optimizer.zero_grad()
            hyper_opt_optimizer.zero_grad()

            # inference net
            # Eq.(6) Prior p(f)
            # \mu_1=0, \Sigma_1
            mean_prior = torch.zeros(observed_size).to(args.device)
            K_prior = kernel(xx, xx).add_jitter(1e-6)

            # q_{\gamma_t}(f_M, f_n) = Normal(mu_2, sigma_2|x_n, x_m)
            # \mu_2, \Sigma_2
            qff_mean_prev, K_prox = bnn_prev(xx)

            # Eq.(8) adapt prior; p(f)^\beta x q(f)^{1 - \beta}
            mean_adapt, K_adapt = product_gaussians(mu1=mean_prior,
                                                    sigma1=K_prior,
                                                    mu2=qff_mean_prev,
                                                    sigma2=K_prox,
                                                    beta=beta)

            # Eq.(8)
            (mean_n, mean_m), (Knn, Knm, Kmm) = split_gaussian(
                mean_adapt, K_adapt, observed_size)

            # Eq.(2) K_{D,D} + noise / (N\beta_t)
            Ky = Knn + torch.eye(observed_size).to(
                args.device) * prior_gp.likelihood.noise / (N / observed_size *
                                                            beta)
            Ky_tril = torch.cholesky(Ky)

            # Eq.(2)
            mean_target = Knm.t().mm(cholesky_solve(y - mean_n,
                                                    Ky_tril)) + mean_m
            mean_target = mean_target.squeeze(-1)
            K_target = gpytorch.add_jitter(
                Kmm - Knm.t().mm(cholesky_solve(Knm, Ky_tril)), 1e-6)
            # \hat{q}_{t+1} (f_M)
            target_pf_star = MultivariateNormal(mean_target, K_target)

            # q_\gamma (f_M)
            qf_star = bnn(x_star)

            # Eq. (11)
            kl_obj = kl_div(qf_star, target_pf_star).sum()

            kl_obj.backward(retain_graph=True)
            infer_gpnet_optimizer.step()

            # Hyperparameter update
            (mean_n_prior, _), (Kn_prior, _, _) = split_gaussian(
                mean_prior, K_prior, observed_size)
            pf = MultivariateNormal(mean_n_prior, Kn_prior)

            (qf_prev_mean, _), (Kn_prox, _, _) = split_gaussian(
                qff_mean_prev, K_prox, observed_size)
            qf_prev = MultivariateNormal(qf_prev_mean, Kn_prox)

            hyper_obj = -(prior_gp.likelihood.expected_log_prob(
                y.squeeze(-1), qf_prev) - kl_div(qf_prev, pf))
            hyper_obj.backward(retain_graph=True)
            hyper_opt_optimizer.step()

            mb.child.comment = "kl_obj = {:.3f}, obs_var={:.3f}".format(
                kl_obj.item(), prior_gp.likelihood.noise.item())

        # update q_{\gamma_t} to q_{\gamma_{t+1}}
        bnn_prev.load_state_dict(bnn.state_dict())
        if args.net == 'rf':
            kernel_prev.load_state_dict(kernel.state_dict())
        if t % 50 == 0:
            mb.write("Iter {}/{}, kl_obj = {:.4f}, noise = {:.4f}".format(
                t, args.n_iters, kl_obj.item(),
                prior_gp.likelihood.noise.item()))

    test_x = test_x.to(args.device)
    test_stats = evaluate(bnn, prior_gp.likelihood, test_x,
                          args.net == 'tangent')
    return test_stats
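The prior-adaptation step (Eq. (8) in the comments) tempers the GP prior with the previous variational fit. For Gaussians this product has a closed form, which is presumably what `product_gaussians` implements:

    p(\mathbf{f})^{\beta}\, q(\mathbf{f})^{1-\beta} \;\propto\; \mathcal{N}(\mathbf{f};\, \boldsymbol{\mu}, \Sigma), \qquad \Sigma^{-1} = \beta\,\Sigma_{1}^{-1} + (1-\beta)\,\Sigma_{2}^{-1}, \qquad \boldsymbol{\mu} = \Sigma\big(\beta\,\Sigma_{1}^{-1}\boldsymbol{\mu}_{1} + (1-\beta)\,\Sigma_{2}^{-1}\boldsymbol{\mu}_{2}\big)

with (\mu_1, \Sigma_1) the prior moments and (\mu_2, \Sigma_2) those of q_{\gamma_t} evaluated by `bnn_prev`.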
Example #10
    def __call__(self, *args, **kwargs):
        output = None

        # Posterior mode
        if self.posterior:
            train_xs = self.train_inputs
            train_y = self.train_target
            if all([
                    torch.equal(train_x.data, input.data)
                    for train_x, input in zip(train_xs, args)
            ]):
                logging.warning('The input matches the stored training data. '
                                'Did you forget to call model.train()?')

            # Exact inference
            if self.exact_inference:
                n_train = len(train_xs[0])
                full_inputs = [
                    torch.cat([train_x, input])
                    for train_x, input in zip(train_xs, args)
                ]
                full_output = super(GPModel,
                                    self).__call__(*full_inputs, **kwargs)
                full_mean, full_covar = full_output.representation()

                train_mean = full_mean[:n_train]
                test_mean = full_mean[n_train:]
                train_train_covar = gpytorch.add_diag(
                    full_covar[:n_train, :n_train],
                    self.likelihood.log_noise.exp())
                train_test_covar = full_covar[:n_train, n_train:]
                test_train_covar = full_covar[n_train:, :n_train]
                test_test_covar = full_covar[n_train:, n_train:]

                # Calculate posterior components
                if not self.has_computed_alpha[0]:
                    alpha_strategy = gpytorch.posterior_strategy(
                        train_train_covar)
                    alpha = alpha_strategy.exact_posterior_alpha(
                        train_mean, train_y)
                    self.alpha.copy_(alpha.data)
                    self.has_computed_alpha.fill_(1)
                else:
                    alpha = Variable(self.alpha)
                mean_strategy = gpytorch.posterior_strategy(test_train_covar)
                test_mean = mean_strategy.exact_posterior_mean(
                    test_mean, alpha)
                covar_strategy = gpytorch.posterior_strategy(train_train_covar)
                test_covar = covar_strategy.exact_posterior_covar(
                    test_train_covar, train_test_covar, test_test_covar)
                output = GaussianRandomVariable(test_mean, test_covar)

            # Approximate inference
            else:
                # Ensure variational parameters have been initialized
                if not self.variational_mean.numel():
                    raise RuntimeError(
                        'Variational parameters have not been initialized. '
                        'Condition on data.')

                # Get inducing points
                if hasattr(self, 'inducing_points'):
                    inducing_points = Variable(self.inducing_points)
                else:
                    inducing_points = train_xs[0]

                n_induc = len(inducing_points)
                full_input = torch.cat([inducing_points, args[0]])
                full_output = super(GPModel,
                                    self).__call__(full_input, **kwargs)
                full_mean, full_covar = full_output.representation()

                test_mean = full_mean[n_induc:]
                induc_induc_covar = full_covar[:n_induc, :n_induc]
                induc_test_covar = full_covar[:n_induc, n_induc:]
                test_induc_covar = full_covar[n_induc:, :n_induc]
                test_test_covar = full_covar[n_induc:, n_induc:]

                # Calculate posterior components
                if not self.has_computed_alpha[0]:
                    alpha_strategy = gpytorch.posterior_strategy(
                        induc_induc_covar)
                    alpha = alpha_strategy.variational_posterior_alpha(
                        self.variational_mean)
                    self.alpha.copy_(alpha.data)
                    self.has_computed_alpha.fill_(1)
                else:
                    alpha = Variable(self.alpha)
                mean_strategy = gpytorch.posterior_strategy(test_induc_covar)
                test_mean = mean_strategy.variational_posterior_mean(alpha)
                covar_strategy = gpytorch.posterior_strategy(test_induc_covar)
                test_covar = covar_strategy.variational_posterior_covar(
                    induc_test_covar, self.chol_variational_covar,
                    test_test_covar, induc_induc_covar)
                output = GaussianRandomVariable(test_mean, test_covar)

        # Training or Prior mode
        else:
            output = super(GPModel, self).__call__(*args, **kwargs)
            # Add some jitter
            if not self.exact_inference:
                mean, covar = output.representation()
                covar = gpytorch.add_jitter(covar)
                output = GaussianRandomVariable(mean, covar)

            if self.conditioning:
                # Reset alpha cache
                _, covar = output.representation()
                self.has_computed_alpha.fill_(0)
                self.alpha.resize_(
                    gpytorch.posterior_strategy(covar).alpha_size())

        # Now go through the likelihood
        if isinstance(output, (Variable, RandomVariable, LazyVariable)):
            output = (output, )
        return self.likelihood(*output)
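On the exact-inference path the cached quantities correspond to the standard GP regression posterior (notation mine), with noise variance \sigma^2 = \exp(\texttt{log\_noise}):

    \boldsymbol{\alpha} = \big(K_{nn} + \sigma^{2} I\big)^{-1}\big(\mathbf{y} - \mathbf{m}_{n}\big), \qquad \boldsymbol{\mu}_{*} = \mathbf{m}_{*} + K_{*n}\boldsymbol{\alpha}, \qquad \Sigma_{*} = K_{**} - K_{*n}\big(K_{nn} + \sigma^{2} I\big)^{-1} K_{n*}

Caching alpha is what makes repeated posterior calls cheap: the mean needs only one matrix-vector product, and only the covariance requires additional solves.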