def marginal_log_likelihood(self, output, target):
    """
    Returns the marginal log likelihood of the data

    Args:
    - output: (GaussianRandomVariable) - the output of the model
    - target: (Variable) - target
    """
    mean, covar = output.representation()

    # Exact inference
    if self.exact_inference:
        return gpytorch.exact_gp_marginal_log_likelihood(covar, target - mean)

    # Approximate inference
    else:
        # Get inducing points
        if not hasattr(self, 'train_inputs'):
            raise RuntimeError('Must condition on data.')

        train_x = self.train_inputs[0]
        if hasattr(self, 'inducing_points'):
            inducing_points = Variable(self.inducing_points)
        else:
            inducing_points = train_x

        chol_var_covar = self.chol_variational_covar.triu()

        # Negate each row with a negative diagonal (the Cholesky decomposition
        # of a matrix requires that the diagonal elements be positive).
        inside = chol_var_covar.diag().sign().unsqueeze(1).expand_as(chol_var_covar).triu()
        chol_var_covar = chol_var_covar.mul(inside)

        _, train_covar = output.representation()
        inducing_output = super(GPModel, self).__call__(inducing_points)
        inducing_mean, inducing_covar = inducing_output.representation()

        train_covar = gpytorch.add_jitter(train_covar)
        log_likelihood = gpytorch.monte_carlo_log_likelihood(
            self.likelihood.log_probability, target,
            self.variational_mean, chol_var_covar, train_covar)

        inducing_covar = gpytorch.add_jitter(inducing_covar)
        kl_divergence = gpytorch.mvn_kl_divergence(
            self.variational_mean, chol_var_covar, inducing_mean, inducing_covar)

        res = log_likelihood.squeeze() - kl_divergence
        return res

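# The approximate branch above assembles the usual variational lower bound
#   log p(y) >= E_q[log p(y | f)] - KL(q || p),
# with the expectation estimated by Monte Carlo and the KL taken against the
# prior at the inducing points.
#
# A minimal usage sketch (hedged: `model`, `train_x`, `train_y` and the Adam
# settings are assumptions, not part of the snippet above); the returned value
# is negated and minimized with a standard PyTorch optimizer.
def fit(model, train_x, train_y, n_iter=100, lr=0.1):
    import torch.optim as optim
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for _ in range(n_iter):
        optimizer.zero_grad()
        output = model(train_x)
        # Maximize the marginal log likelihood by minimizing its negative
        loss = -model.marginal_log_likelihood(output, train_y)
        loss.backward()
        optimizer.step()
    return model
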
def mvn_kl_divergence(self):
    mean_diffs = self.inducing_output.mean() - self.variational_mean
    chol_variational_covar = self.chol_variational_covar

    if chol_variational_covar.ndimension() == 2:
        matrix_diag = chol_variational_covar.diag()
    elif chol_variational_covar.ndimension() == 3:
        batch_size, diag_size, _ = chol_variational_covar.size()

        batch_index = chol_variational_covar.data.new(batch_size).long()
        torch.arange(0, batch_size, out=batch_index)
        batch_index = batch_index.unsqueeze(1).repeat(1, diag_size).view(-1)

        diag_index = chol_variational_covar.data.new(diag_size).long()
        torch.arange(0, diag_size, out=diag_index)
        diag_index = diag_index.unsqueeze(1).repeat(batch_size, 1).view(-1)

        matrix_diag = chol_variational_covar[batch_index, diag_index, diag_index].view(
            batch_size, diag_size)
    else:
        raise RuntimeError('Invalid number of variational covar dimensions')

    logdet_variational_covar = matrix_diag.log().sum() * 2
    trace_logdet_quad_form = gpytorch.trace_logdet_quad_form(
        mean_diffs, self.chol_variational_covar,
        gpytorch.add_jitter(self.inducing_output.covar()))

    # Compute the KL Divergence.
    res = 0.5 * (trace_logdet_quad_form - logdet_variational_covar - len(mean_diffs))
    return res

def kl_divergence(self):
    prior_mean = self.prior_dist.mean()
    prior_covar = self.prior_dist.covar()
    variational_mean = self.variational_dist.mean()
    variational_covar = self.variational_dist.covar()
    if not isinstance(variational_covar, CholLazyVariable):
        raise RuntimeError('The variational covar for an MVN distribution '
                           'should be a CholLazyVariable')
    chol_variational_covar = variational_covar.lhs

    mean_diffs = prior_mean - variational_mean

    if chol_variational_covar.ndimension() == 2:
        matrix_diag = chol_variational_covar.diag()
    elif chol_variational_covar.ndimension() == 3:
        batch_size, diag_size, _ = chol_variational_covar.size()

        batch_index = chol_variational_covar.data.new(batch_size).long()
        torch.arange(0, batch_size, out=batch_index)
        batch_index = batch_index.unsqueeze(1).repeat(1, diag_size).view(-1)

        diag_index = chol_variational_covar.data.new(diag_size).long()
        torch.arange(0, diag_size, out=diag_index)
        diag_index = diag_index.unsqueeze(1).repeat(batch_size, 1).view(-1)

        matrix_diag = chol_variational_covar[batch_index, diag_index, diag_index].view(
            batch_size, diag_size)
    else:
        raise RuntimeError('Invalid number of variational covar dimensions')

    logdet_variational_covar = matrix_diag.log().sum() * 2
    trace_logdet_quad_form = gpytorch.trace_logdet_quad_form(
        mean_diffs, chol_variational_covar, gpytorch.add_jitter(prior_covar))

    # Compute the KL Divergence.
    res = 0.5 * (trace_logdet_quad_form - logdet_variational_covar - len(mean_diffs))
    return res

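# For reference, both KL routines above assemble the closed-form KL between
# multivariate Gaussians q = N(m_q, S_q) with S_q = L L^T and p = N(m_p, S_p):
#   KL(q || p) = 0.5 * ( tr(S_p^{-1} S_q) + (m_p - m_q)^T S_p^{-1} (m_p - m_q)
#                        - k + log det S_p - log det S_q ),
# where gpytorch.trace_logdet_quad_form is expected to supply the trace, the
# quadratic term and log det S_p, and logdet_variational_covar supplies
# log det S_q from the Cholesky diagonal.
#
# A minimal dense-tensor sketch of the same quantity (the function name and the
# dense inputs are hypothetical; the real code works with lazy/structured
# matrices instead):
import torch

def dense_gaussian_kl(q_mean, q_chol, p_mean, p_covar):
    k = q_mean.size(-1)
    p_chol = torch.cholesky(p_covar)
    q_covar = q_chol.mm(q_chol.t())
    mean_diff = (p_mean - q_mean).unsqueeze(-1)
    # tr(S_p^{-1} S_q)
    trace_term = torch.cholesky_solve(q_covar, p_chol).diag().sum()
    # (m_p - m_q)^T S_p^{-1} (m_p - m_q)
    quad_term = mean_diff.t().mm(torch.cholesky_solve(mean_diff, p_chol)).squeeze()
    logdet_p = 2 * p_chol.diag().log().sum()
    logdet_q = 2 * q_chol.diag().abs().log().sum()
    return 0.5 * (trace_term + quad_term - k + logdet_p - logdet_q)
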
def marginal_log_likelihood(self, output, train_y, num_samples=10):
    chol_var_covar = self.chol_variational_covar.triu()

    # Negate each row with a negative diagonal (the Cholesky decomposition
    # of a matrix requires that the diagonal elements be positive).
    chol_var_covar = chol_var_covar.mul(
        chol_var_covar.diag().sign().unsqueeze(1).expand_as(chol_var_covar).triu())

    _, train_covar = output.representation()
    inducing_output = self.forward(*self.inducing_points)
    inducing_mean = inducing_output.mean()

    train_covar = gpytorch.add_jitter(train_covar)
    log_likelihood = gpytorch.monte_carlo_log_likelihood(
        self.prior_model.likelihood.log_probability, train_y,
        self.variational_mean, chol_var_covar, train_covar, num_samples)

    kl_divergence = gpytorch.mvn_kl_divergence(
        self.variational_mean, chol_var_covar, inducing_mean, train_covar, num_samples)

    return log_likelihood.squeeze() - kl_divergence

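# Note on the Monte Carlo term above: the expected log likelihood under
# q(f) = N(\mu, L L^T) is estimated from num_samples draws,
#   E_q[log p(y | f)] ~= (1/S) \sum_{s=1}^{S} log p(y | f^(s)),
# with f^(s) typically obtained by the reparameterisation f^(s) = \mu + L eps^(s),
# eps^(s) ~ N(0, I). (This is an assumption about what
# gpytorch.monte_carlo_log_likelihood does internally; only its call signature
# is visible here.)
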
def forward(self, x, full_cov=True):
    h = x.mm(self.freq / self.lengthscales)
    h = torch.sqrt(self.kernel.outputscale) / math.sqrt(self.n_units) * torch.cat(
        [torch.cos(h), torch.sin(h)], -1)
    f_mean = h.mm(self.w_mean)
    if self.residual:
        f_mean += self.mlp(x)
    f_mean = f_mean.squeeze(-1)
    w_cov_tril = softplus_tril(self.w_cov_raw)
    f_cov_half = h.mm(w_cov_tril)
    if full_cov:
        f_cov = f_cov_half.mm(f_cov_half.t())
        f_cov = gpytorch.add_jitter(f_cov)
        if self.mvn:
            f_dist = gpytorch.distributions.MultivariateNormal(f_mean, f_cov)
            return f_dist
        else:
            return f_mean, f_cov
    else:
        f_var = f_cov_half.pow(2).sum(-1)
        f_var += 1e-6
        return f_mean, f_var

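# Random-feature view of the layer above: with frequencies Omega (self.freq),
# lengthscales l and m = self.n_units,
#   phi(x) = sqrt(\sigma_f^2 / m) * [cos(x Omega / l), sin(x Omega / l)],
# so k(x, x') ~= phi(x) phi(x')^T (the random Fourier feature approximation of
# a stationary kernel), and f = phi(x) w with w ~ N(w_mean, L L^T) gives the
# mean h.mm(self.w_mean) and covariance (h L)(h L)^T computed above.
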
def forward(self, x, full_cov=True):
    batch_size = x.size(0)
    if self.periodic:
        x = torch.cat([x, torch.sin(self.periodic_fc(x))], -1)
    h = self.layers(x)
    f_mean = h.squeeze(-1)
    grad_w_means = map(partial(jacobian, y=f_mean), self.means)
    f_cov_half = [(grad_w_mean * w_std).reshape(batch_size, -1)
                  for (grad_w_mean, w_std) in zip(grad_w_means, self.stds)]
    if full_cov:
        f_cov = sum([i.mm(i.t()) for i in f_cov_half])
        f_cov = gpytorch.add_jitter(f_cov)
        if self.mvn:
            return MultivariateNormal(f_mean, f_cov)
        else:
            return f_mean, f_cov
    else:
        f_var = sum([torch.sum(i.pow(2), -1) for i in f_cov_half])
        f_var += 1e-6
        return f_mean, f_var

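# The covariance above comes from linearising the network output around the
# weight means: f(x; w) ~= f(x; m) + J(x) (w - m), with J(x) the Jacobian of
# f_mean with respect to each weight group. For independent Gaussian weights
# w_l ~ N(m_l, diag(s_l^2)) this gives
#   Cov[f] ~= \sum_l J_l diag(s_l^2) J_l^T = \sum_l (J_l * s_l)(J_l * s_l)^T,
# which is exactly sum(i.mm(i.t()) for i in f_cov_half).
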
def forward(self, *inputs, **params):
    if not self.training:
        inducing_point_vars = [inducing_pt for inducing_pt in self.inducing_points]
        full_inputs = [
            torch.cat([inducing_point_var, input])
            for inducing_point_var, input in zip(inducing_point_vars, inputs)
        ]
    else:
        full_inputs = inputs

    gaussian_rv_output = self.prior_model.forward(*full_inputs, **params)
    full_mean, full_covar = gaussian_rv_output.representation()

    if not self.training:
        # Get mean/covar components
        n = self.num_inducing
        test_mean = full_mean[n:]
        induc_induc_covar = full_covar[:n, :n]
        induc_test_covar = full_covar[:n, n:]
        test_induc_covar = full_covar[n:, :n]
        test_test_covar = full_covar[n:, n:]

        # Calculate posterior components
        if not hasattr(self, 'alpha'):
            self.alpha = gpytorch.variational_posterior_alpha(
                induc_induc_covar, self.variational_mean)
        test_mean = gpytorch.variational_posterior_mean(test_induc_covar, self.alpha)
        test_covar = gpytorch.variational_posterior_covar(
            test_induc_covar, induc_test_covar, self.chol_variational_covar,
            test_test_covar, induc_induc_covar)
        output = GaussianRandomVariable(test_mean, test_covar)
        return output
    else:
        full_covar = gpytorch.add_jitter(full_covar)
        f_prior = GaussianRandomVariable(full_mean, full_covar)
        return f_prior

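# Assuming the variational_posterior_* helpers implement the standard sparse-GP
# predictive equations (an inference from their arguments, not shown here), the
# eval-mode branch above corresponds to
#   \mu_*    = K_{*u} K_{uu}^{-1} m,
#   \Sigma_* = K_{**} - K_{*u} K_{uu}^{-1} (K_{uu} - S) K_{uu}^{-1} K_{u*},
# where m is the variational mean, S = L L^T the variational covariance
# (L = chol_variational_covar), and u indexes the inducing points.
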
def forward(self, x, full_cov=True):
    w_cov_tril = softplus_tril(self.w_cov_raw)
    h = self.layers(x)
    f_mean = h.mm(self.w_mean).squeeze(-1)
    f_cov_half = h.mm(w_cov_tril)
    if full_cov:
        f_cov = f_cov_half.mm(f_cov_half.t())
        f_cov = gpytorch.add_jitter(f_cov)
        if self.mvn:
            return gpytorch.distributions.MultivariateNormal(f_mean, f_cov)
        else:
            return f_mean, f_cov
    else:
        hw_cov = f_cov_half.mm(w_cov_tril.t())
        f_var = torch.sum(hw_cov * h, -1)
        f_var += 1e-6
        return f_mean, f_var

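# Design note on the full_cov=False branch above: with H = h and S_w = L L^T,
#   Var[f_i] = diag(H S_w H^T)_i = \sum_j (H S_w)_{ij} H_{ij},
# which is what torch.sum(hw_cov * h, -1) computes without forming the full
# n x n covariance; f_cov_half.pow(2).sum(-1), as used in the other forward
# passes above, is an equivalent alternative.
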
def gpnet(args, dataloader, test_x, prior_gp):
    N = len(dataloader.dataset)
    x_dim = 1
    prior_gp.train()

    if args.net == 'tangent':
        kernel = prior_gp.covar_module
        bnn_prev = FirstOrder([x_dim] + [args.n_hidden] * args.n_layer, mvn=False)
        bnn = FirstOrder([x_dim] + [args.n_hidden] * args.n_layer, mvn=True)
    elif args.net == 'deep':
        kernel = prior_gp.covar_module
        bnn_prev = DeepKernel([x_dim] + [args.n_hidden] * args.n_layer, mvn=False)
        bnn = DeepKernel([x_dim] + [args.n_hidden] * args.n_layer, mvn=True)
    elif args.net == 'rf':
        kernel = ScaleKernel(RBFKernel())
        kernel_prev = ScaleKernel(RBFKernel())
        bnn_prev = RFExpansion(x_dim, args.n_hidden, kernel_prev, mvn=False,
                               fix_ls=args.fix_rf_ls, residual=args.residual)
        bnn = RFExpansion(x_dim, args.n_hidden, kernel,
                          fix_ls=args.fix_rf_ls, residual=args.residual)
        bnn_prev.load_state_dict(bnn.state_dict())
    else:
        raise NotImplementedError('Unknown inference net')

    bnn = bnn.to(args.device)
    bnn_prev = bnn_prev.to(args.device)
    prior_gp = prior_gp.to(args.device)

    infer_gpnet_optimizer = optim.Adam(bnn.parameters(), lr=args.learning_rate)
    hyper_opt_optimizer = optim.Adam(prior_gp.parameters(), lr=args.hyper_rate)

    x_min, x_max = dataloader.dataset.range
    bnn.train()
    bnn_prev.train()
    prior_gp.train()

    mb = master_bar(range(1, args.n_iters + 1))
    for t in mb:
        # Annealing coefficient for the hyperparameter/inference updates
        beta = args.beta0 * 1. / (1. + args.gamma * math.sqrt(t - 1))
        dl_bar = progress_bar(dataloader, parent=mb)
        for x, y in dl_bar:
            observed_size = x.size(0)
            x, y = x.to(args.device), y.to(args.device)
            x_star = torch.Tensor(args.measurement_size, x_dim).uniform_(
                x_min, x_max).to(args.device)
            # [Batch + Measurement Points x x_dims]
            xx = torch.cat([x, x_star], 0)

            infer_gpnet_optimizer.zero_grad()
            hyper_opt_optimizer.zero_grad()

            # Inference net
            # Eq.(6) Prior p(f): \mu_1 = 0, \Sigma_1
            mean_prior = torch.zeros(observed_size).to(args.device)
            K_prior = kernel(xx, xx).add_jitter(1e-6)

            # q_{\gamma_t}(f_M, f_n) = Normal(\mu_2, \Sigma_2 | x_n, x_m)
            qff_mean_prev, K_prox = bnn_prev(xx)

            # Eq.(8) adapted prior: p(f)^\beta x q(f)^{1 - \beta}
            mean_adapt, K_adapt = product_gaussians(
                mu1=mean_prior, sigma1=K_prior,
                mu2=qff_mean_prev, sigma2=K_prox, beta=beta)

            # Eq.(8)
            (mean_n, mean_m), (Knn, Knm, Kmm) = split_gaussian(
                mean_adapt, K_adapt, observed_size)

            # Eq.(2) K_{D,D} + noise / (N \beta_t)
            Ky = Knn + (torch.eye(observed_size).to(args.device)
                        * prior_gp.likelihood.noise / (N / observed_size * beta))
            Ky_tril = torch.cholesky(Ky)

            # Eq.(2)
            mean_target = Knm.t().mm(cholesky_solve(y - mean_n, Ky_tril)) + mean_m
            mean_target = mean_target.squeeze(-1)
            K_target = gpytorch.add_jitter(
                Kmm - Knm.t().mm(cholesky_solve(Knm, Ky_tril)), 1e-6)

            # \hat{q}_{t+1}(f_M)
            target_pf_star = MultivariateNormal(mean_target, K_target)
            # q_\gamma(f_M)
            qf_star = bnn(x_star)

            # Eq.(11)
            kl_obj = kl_div(qf_star, target_pf_star).sum()
            kl_obj.backward(retain_graph=True)
            infer_gpnet_optimizer.step()

            # Hyperparameter update
            (mean_n_prior, _), (Kn_prior, _, _) = split_gaussian(
                mean_prior, K_prior, observed_size)
            pf = MultivariateNormal(mean_n_prior, Kn_prior)
            (qf_prev_mean, _), (Kn_prox, _, _) = split_gaussian(
                qff_mean_prev, K_prox, observed_size)
            qf_prev = MultivariateNormal(qf_prev_mean, Kn_prox)
            hyper_obj = -(prior_gp.likelihood.expected_log_prob(
                y.squeeze(-1), qf_prev) - kl_div(qf_prev, pf))
            hyper_obj.backward(retain_graph=True)
            hyper_opt_optimizer.step()

            mb.child.comment = "kl_obj = {:.3f}, obs_var={:.3f}".format(
                kl_obj.item(), prior_gp.likelihood.noise.item())

        # Update q_{\gamma_t} to q_{\gamma_{t+1}}
        bnn_prev.load_state_dict(bnn.state_dict())
        if args.net == 'rf':
            kernel_prev.load_state_dict(kernel.state_dict())

        if t % 50 == 0:
            mb.write("Iter {}/{}, kl_obj = {:.4f}, noise = {:.4f}".format(
                t, args.n_iters, kl_obj.item(), prior_gp.likelihood.noise.item()))

    test_x = test_x.to(args.device)
    test_stats = evaluate(bnn, prior_gp.likelihood, test_x, args.net == 'tangent')
    return test_stats

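# Assuming product_gaussians implements the beta-weighted product of Gaussian
# densities in natural parameters (an assumption; only its call signature is
# visible above), the adapted prior p(f)^\beta q(f)^{1-\beta} is again Gaussian
# with
#   \Sigma^{-1} = \beta \Sigma_1^{-1} + (1 - \beta) \Sigma_2^{-1},
#   \mu = \Sigma (\beta \Sigma_1^{-1} \mu_1 + (1 - \beta) \Sigma_2^{-1} \mu_2),
# and split_gaussian then reads off the blocks over the minibatch points (n)
# and the measurement points (M).
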
def __call__(self, *args, **kwargs):
    output = None

    # Posterior mode
    if self.posterior:
        train_xs = self.train_inputs
        train_y = self.train_target
        if all([torch.equal(train_x.data, input.data)
                for train_x, input in zip(train_xs, args)]):
            logging.warning('The input matches the stored training data. '
                            'Did you forget to call model.train()?')

        # Exact inference
        if self.exact_inference:
            n_train = len(train_xs[0])
            full_inputs = [torch.cat([train_x, input])
                           for train_x, input in zip(train_xs, args)]
            full_output = super(GPModel, self).__call__(*full_inputs, **kwargs)
            full_mean, full_covar = full_output.representation()

            train_mean = full_mean[:n_train]
            test_mean = full_mean[n_train:]
            train_train_covar = gpytorch.add_diag(
                full_covar[:n_train, :n_train], self.likelihood.log_noise.exp())
            train_test_covar = full_covar[:n_train, n_train:]
            test_train_covar = full_covar[n_train:, :n_train]
            test_test_covar = full_covar[n_train:, n_train:]

            # Calculate posterior components
            if not self.has_computed_alpha[0]:
                alpha_strategy = gpytorch.posterior_strategy(train_train_covar)
                alpha = alpha_strategy.exact_posterior_alpha(train_mean, train_y)
                self.alpha.copy_(alpha.data)
                self.has_computed_alpha.fill_(1)
            else:
                alpha = Variable(self.alpha)
            mean_strategy = gpytorch.posterior_strategy(test_train_covar)
            test_mean = mean_strategy.exact_posterior_mean(test_mean, alpha)
            covar_strategy = gpytorch.posterior_strategy(train_train_covar)
            test_covar = covar_strategy.exact_posterior_covar(
                test_train_covar, train_test_covar, test_test_covar)
            output = GaussianRandomVariable(test_mean, test_covar)

        # Approximate inference
        else:
            # Ensure variational parameters have been initialized
            if not self.variational_mean.numel():
                raise RuntimeError('Variational parameters have not been initialized. '
                                   'Condition on data.')

            # Get inducing points
            if hasattr(self, 'inducing_points'):
                inducing_points = Variable(self.inducing_points)
            else:
                inducing_points = train_xs[0]

            n_induc = len(inducing_points)
            full_input = torch.cat([inducing_points, args[0]])
            full_output = super(GPModel, self).__call__(full_input, **kwargs)
            full_mean, full_covar = full_output.representation()

            test_mean = full_mean[n_induc:]
            induc_induc_covar = full_covar[:n_induc, :n_induc]
            induc_test_covar = full_covar[:n_induc, n_induc:]
            test_induc_covar = full_covar[n_induc:, :n_induc]
            test_test_covar = full_covar[n_induc:, n_induc:]

            # Calculate posterior components
            if not self.has_computed_alpha[0]:
                alpha_strategy = gpytorch.posterior_strategy(induc_induc_covar)
                alpha = alpha_strategy.variational_posterior_alpha(self.variational_mean)
                self.alpha.copy_(alpha.data)
                self.has_computed_alpha.fill_(1)
            else:
                alpha = Variable(self.alpha)
            mean_strategy = gpytorch.posterior_strategy(test_induc_covar)
            test_mean = mean_strategy.variational_posterior_mean(alpha)
            covar_strategy = gpytorch.posterior_strategy(test_induc_covar)
            test_covar = covar_strategy.variational_posterior_covar(
                induc_test_covar, self.chol_variational_covar,
                test_test_covar, induc_induc_covar)
            output = GaussianRandomVariable(test_mean, test_covar)

    # Training or prior mode
    else:
        output = super(GPModel, self).__call__(*args, **kwargs)
        # Add some jitter
        if not self.exact_inference:
            mean, covar = output.representation()
            covar = gpytorch.add_jitter(covar)
            output = GaussianRandomVariable(mean, covar)
        if self.conditioning:
            # Reset alpha cache
            _, covar = output.representation()
            self.has_computed_alpha.fill_(0)
            self.alpha.resize_(gpytorch.posterior_strategy(covar).alpha_size())

    # Now go through the likelihood
    if (isinstance(output, Variable) or isinstance(output, RandomVariable)
            or isinstance(output, LazyVariable)):
        output = (output, )
    return self.likelihood(*output)

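# For reference, the quantities named in the exact-inference branch above
# correspond to the textbook GP posterior at the test inputs,
#   \mu_*    = m_* + K_{*x} (K_{xx} + \sigma^2 I)^{-1} (y - m_x),
#   \Sigma_* = K_{**} - K_{*x} (K_{xx} + \sigma^2 I)^{-1} K_{x*},
# with alpha caching (K_{xx} + \sigma^2 I)^{-1} (y - m_x) so that repeated
# predictions reuse it. The approximate branch performs the analogous
# computation with the inducing points in place of the training inputs.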