def test_multivariate_normal_batch_non_lazy(self, cuda=False):
    """Exercise a batch (size 2) MultivariateNormal built from dense tensors.

    Covers covariance access, affine arithmetic, entropy / log-prob values,
    confidence regions, and sample shapes.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    base_mean = torch.tensor([0, 1, 2], dtype=torch.float, device=device)
    base_cov = torch.diag(torch.tensor([1, 0.75, 1.5], device=device))

    mvn = MultivariateNormal(
        mean=base_mean.repeat(2, 1),
        covariance_matrix=base_cov.repeat(2, 1, 1),
        validate_args=True,
    )

    # Covariance is materialized as a dense tensor, but the lazy wrapper is
    # still exposed alongside it.
    self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
    self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
    self.assertTrue(approx_equal(mvn.variance, base_cov.diag().repeat(2, 1)))
    self.assertTrue(
        approx_equal(mvn.scale_tril, torch.diag(base_cov.diag().sqrt()).repeat(2, 1, 1))
    )

    # Affine transforms: a shift moves only the mean; scaling by c multiplies
    # the covariance by c**2.
    shifted = mvn + 1
    self.assertTrue(torch.equal(shifted.mean, mvn.mean + 1))
    self.assertTrue(torch.equal(shifted.covariance_matrix, mvn.covariance_matrix))

    doubled = mvn * 2
    self.assertTrue(torch.equal(doubled.mean, mvn.mean * 2))
    self.assertTrue(torch.equal(doubled.covariance_matrix, mvn.covariance_matrix * 4))

    halved = mvn / 2
    self.assertTrue(torch.equal(halved.mean, mvn.mean / 2))
    self.assertTrue(torch.equal(halved.covariance_matrix, mvn.covariance_matrix / 4))

    # Closed-form reference values for this diagonal covariance.
    self.assertTrue(approx_equal(mvn.entropy(), 4.3157 * torch.ones(2, device=device)))

    logprob = mvn.log_prob(torch.zeros(2, 3, device=device))
    self.assertTrue(approx_equal(logprob, -4.8157 * torch.ones(2, device=device)))

    logprob = mvn.log_prob(torch.zeros(2, 2, 3, device=device))
    self.assertTrue(approx_equal(logprob, -4.8157 * torch.ones(2, 2, device=device)))

    # Default confidence region is mean +/- 2 standard deviations.
    conf_lower, conf_upper = mvn.confidence_region()
    self.assertTrue(approx_equal(conf_lower, mvn.mean - 2 * mvn.stddev))
    self.assertTrue(approx_equal(conf_upper, mvn.mean + 2 * mvn.stddev))

    # Sample shape layout: sample_shape + batch_shape (2,) + event_shape (3,).
    self.assertTrue(mvn.sample().shape == torch.Size([2, 3]))
    self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 2, 3]))
    self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 2, 3]))
def test_multivariate_normal_non_lazy(self, cuda=False):
    """Exercise a non-batch MultivariateNormal built from dense tensors.

    Runs the same checks for both float32 and float64: covariance access,
    affine arithmetic, entropy / log-prob values, confidence regions, and
    sample shapes.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        loc = torch.tensor([0, 1, 2], device=device, dtype=dtype)
        cov = torch.diag(torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype))
        mvn = MultivariateNormal(mean=loc, covariance_matrix=cov, validate_args=True)

        # Covariance is materialized densely; the lazy view is still exposed.
        self.assertTrue(torch.is_tensor(mvn.covariance_matrix))
        self.assertIsInstance(mvn.lazy_covariance_matrix, LazyTensor)
        self.assertTrue(torch.allclose(mvn.variance, torch.diag(cov)))
        self.assertTrue(torch.allclose(mvn.scale_tril, cov.sqrt()))

        # Affine transforms: shifts leave covariance alone; scaling by c
        # multiplies the covariance by c**2.
        shifted = mvn + 1
        self.assertTrue(torch.equal(shifted.mean, mvn.mean + 1))
        self.assertTrue(torch.equal(shifted.covariance_matrix, mvn.covariance_matrix))

        doubled = mvn * 2
        self.assertTrue(torch.equal(doubled.mean, mvn.mean * 2))
        self.assertTrue(torch.equal(doubled.covariance_matrix, mvn.covariance_matrix * 4))

        halved = mvn / 2
        self.assertTrue(torch.equal(halved.mean, mvn.mean / 2))
        self.assertTrue(torch.equal(halved.covariance_matrix, mvn.covariance_matrix / 4))

        # Closed-form reference values for this diagonal covariance.
        self.assertAlmostEqual(mvn.entropy().item(), 4.3157, places=4)
        self.assertAlmostEqual(
            mvn.log_prob(torch.zeros(3, device=device, dtype=dtype)).item(),
            -4.8157,
            places=4,
        )

        logprob = mvn.log_prob(torch.zeros(2, 3, device=device, dtype=dtype))
        expected = torch.tensor([-4.8157, -4.8157], device=device, dtype=dtype)
        self.assertTrue(torch.allclose(logprob, expected))

        # Default confidence region is mean +/- 2 standard deviations.
        conf_lower, conf_upper = mvn.confidence_region()
        self.assertTrue(torch.allclose(conf_lower, mvn.mean - 2 * mvn.stddev))
        self.assertTrue(torch.allclose(conf_upper, mvn.mean + 2 * mvn.stddev))

        # Sample shape layout: sample_shape + event_shape (3,).
        self.assertTrue(mvn.sample().shape == torch.Size([3]))
        self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 3]))
        self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 3]))
def test_natgrad(self, D=5):
    """Check natural-parameter gradients of NaturalVariationalDistribution.

    Builds a reference distribution re-parameterized explicitly by its natural
    parameters (eta1 = mean, eta2 = E[x x^T]) and verifies that backprop
    through both paths accumulates matching gradients.

    Args:
        D: dimensionality of the multivariate normal.
    """
    mu = torch.randn(D)
    cov = torch.randn(D, D).tril_()
    dist = MultivariateNormal(mu, CholLazyTensor(TriangularLazyTensor(cov)))
    sample = dist.sample()

    v_dist = NaturalVariationalDistribution(D)
    v_dist.initialize_variational_distribution(dist)
    mu = v_dist().mean.detach()
    v_dist().log_prob(sample).squeeze().backward()

    # Natural parameters: eta1 = mean, eta2 = mu mu^T + Sigma.
    eta1 = mu.clone().requires_grad_(True)
    eta2 = (mu[:, None] * mu + cov @ cov.t()).requires_grad_(True)
    # FIX: torch.cholesky is deprecated (and removed in recent PyTorch);
    # torch.linalg.cholesky is the drop-in replacement for the default
    # lower-triangular factor.
    L = torch.linalg.cholesky(eta2 - eta1[:, None] * eta1)
    dist2 = MultivariateNormal(eta1, CholLazyTensor(TriangularLazyTensor(L)))
    dist2.log_prob(sample).squeeze().backward()

    # Gradients w.r.t. the explicit natural parameters must match the ones on
    # the variational distribution's natural-parameter tensors.
    assert torch.allclose(v_dist.natural_vec.grad, eta1.grad)
    assert torch.allclose(v_dist.natural_mat.grad, eta2.grad)
def compute_ll_for_block(self, vec, mean, var, cov_mat_root):
    """Log-probability of ``vec`` under N(mean, diag(var + 1e-6) + R R^T).

    Here ``R = cov_mat_root.t()``; all inputs are flattened first, and the
    covariance is represented lazily as a diagonal-plus-low-rank sum.

    Args:
        vec: sample to score (flattened before use).
        mean: mean of the Gaussian (flattened before use).
        var: per-element variance of the diagonal term (flattened before use).
        cov_mat_root: root R^T of the low-rank covariance component.

    Returns:
        The log-probability of ``vec`` under the constructed distribution.
    """
    vec = flatten(vec)
    mean = flatten(mean)
    var = flatten(var)

    cov_mat_lt = RootLazyTensor(cov_mat_root.t())
    # Small jitter keeps the diagonal strictly positive.
    var_lt = DiagLazyTensor(var + 1e-6)
    covar_lt = AddedDiagLazyTensor(var_lt, cov_mat_lt)
    qdist = MultivariateNormal(mean, covar_lt)

    # BUG FIX: the original used `with A and B:` — the `and` expression
    # evaluates to its second operand, so only max_cg_iterations(25) was ever
    # entered and num_trace_samples(1) silently never applied. Enter both
    # context managers with the comma form.
    with gpytorch.settings.num_trace_samples(1), gpytorch.settings.max_cg_iterations(25):
        return qdist.log_prob(vec)
def test_likelihood(self):
    """Compare the model's log_likelihood against an explicit MVN reference.

    The reference treats the targets as drawn from a zero-mean multivariate
    normal with covariance f f^T + noise_std^2 I, where f are the scaled
    network features (presumably the last-layer marginal — matches the
    model's own computation).
    """
    x = 2 * torch.randn(10, 3)
    y = 2 * torch.randn(10, 1)
    model = RaoBDenseNet(x, y, 40, noise_std=0.8)

    device = next(iter(model.parameters())).device
    x = x.to(device)
    y = y.to(device)

    lik1 = model.log_likelihood(x, y, len(x))

    feats = model.net(x) * model.last_layer_std
    noise_cov = model.noise_std**2 * torch.eye(
        x.size(0), dtype=feats.dtype, device=feats.device
    )
    ref_dist = MultivariateNormal(
        torch.zeros_like(y[:, 0]), feats @ feats.t() + noise_cov
    )
    lik2 = ref_dist.log_prob(y.t()).sum()

    assert torch.allclose(lik1, lik2)
def test_added_diag_lt(self, N=10000, p=20, use_cuda=False, seed=1):
    """Compare lazy log-prob of N(1, DD' + diag(A)) against an exact Cholesky
    reference, printing intermediate CG / stochastic-Lanczos diagnostics.

    Args:
        N: dimensionality of the distribution.
        p: rank of the low-rank covariance component D (N x p).
        use_cuda: run on GPU when available.
        seed: RNG seed for reproducibility.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available() and use_cuda:
        print("Using cuda")
        device = torch.device("cuda")
        torch.cuda.manual_seed_all(seed)
    else:
        device = torch.device("cpu")

    D = torch.randn(N, p, device=device)
    # Diagonal entries are kept strictly positive (>= ~0.1).
    A = torch.randn(N, device=device).abs() * 1e-3 + 0.1
    # this is a lazy tensor for DD'
    D_lt = RootLazyTensor(D)
    # this is a lazy tensor for diag(A)
    diag_term = DiagLazyTensor(A)
    # DD' + diag(A)
    lowrank_pdiag_lt = AddedDiagLazyTensor(diag_term, D_lt)

    # z \sim N(0,I), mean = 1
    z = torch.randn(N, device=device)
    mean = torch.ones(N, device=device)
    diff = mean - z

    # Reference quantities computed through the lazy tensor's own routines.
    print(lowrank_pdiag_lt.log_det())
    logdet = lowrank_pdiag_lt.log_det()
    inv_matmul = lowrank_pdiag_lt.inv_matmul(diff.unsqueeze(1)).squeeze(1)
    inv_matmul_quad = torch.dot(diff, inv_matmul)

    # NOTE(review): disabled alternative paths kept for reference.
    """inv_matmul_quad_qld, logdet_qld = lowrank_pdiag_lt.inv_quad_log_det(inv_quad_rhs=diff.unsqueeze(1), log_det = True) """
    """from gpytorch.functions._inv_quad_log_det import InvQuadLogDet iqld_construct = InvQuadLogDet(gpytorch.lazy.lazy_tensor_representation_tree.LazyTensorRepresentationTree(lowrank_pdiag_lt), matrix_shape=lowrank_pdiag_lt.matrix_shape, dtype=lowrank_pdiag_lt.dtype, device=lowrank_pdiag_lt.device, inv_quad=True, log_det=True, preconditioner=lowrank_pdiag_lt._preconditioner()[0], log_det_correction=lowrank_pdiag_lt._preconditioner()[1]) inv_matmul_quad_qld, logdet_qld = iqld_construct(diff.unsqueeze(1))"""

    # Rademacher probe vectors (entries +/-1), normalized column-wise, for
    # stochastic Lanczos quadrature.
    num_random_probes = gpytorch.settings.num_trace_samples.value()
    probe_vectors = torch.empty(
        lowrank_pdiag_lt.matrix_shape[-1],
        num_random_probes,
        dtype=lowrank_pdiag_lt.dtype,
        device=lowrank_pdiag_lt.device,
    )
    probe_vectors.bernoulli_().mul_(2).add_(-1)
    probe_vector_norms = torch.norm(probe_vectors, 2, dim=-2, keepdim=True)
    probe_vectors = probe_vectors.div(probe_vector_norms)
    # diff_norm = diff.norm()
    # diff = diff/diff_norm

    # One batched CG solve: column 0 solves for `diff`, the remaining columns
    # are the probe vectors (with tridiagonalization for SLQ).
    rhs = torch.cat([diff.unsqueeze(1), probe_vectors], dim=1)
    solves, t_mat = gpytorch.utils.linear_cg(
        lowrank_pdiag_lt.matmul,
        rhs,
        n_tridiag=num_random_probes,
        max_iter=gpytorch.settings.max_cg_iterations.value(),
        max_tridiag_iter=gpytorch.settings.max_lanczos_quadrature_iterations.value(),
        preconditioner=lowrank_pdiag_lt._preconditioner()[0],
    )
    # print(solves)
    inv_matmul_qld = solves[:, 0]  # * diff_norm

    # Independent single-RHS CG solve to cross-check the batched solve above.
    diff_solve = gpytorch.utils.linear_cg(
        lowrank_pdiag_lt.matmul,
        diff.unsqueeze(1),
        max_iter=gpytorch.settings.max_cg_iterations.value(),
        preconditioner=lowrank_pdiag_lt._preconditioner()[0],
    )
    print("diff_solve_norm: ", diff_solve.norm())
    print(
        "diff between multiple linear_cg: ",
        (inv_matmul_qld.unsqueeze(1) - diff_solve).norm() / diff_solve.norm(),
    )

    # Stochastic Lanczos quadrature estimate of log det from the CG
    # tridiagonal matrices, plus the preconditioner's log-det correction.
    eigenvalues, eigenvectors = gpytorch.utils.lanczos.lanczos_tridiag_to_diag(t_mat)
    slq = gpytorch.utils.StochasticLQ()
    log_det_term, = slq.evaluate(
        lowrank_pdiag_lt.matrix_shape,
        eigenvalues,
        eigenvectors,
        [lambda x: x.log()],
    )
    logdet_qld = log_det_term + lowrank_pdiag_lt._preconditioner()[1]
    print("Log det difference: ", (logdet - logdet_qld).norm() / logdet.norm())
    print(
        "inv matmul difference: ",
        (inv_matmul - inv_matmul_qld).norm() / inv_matmul_quad.norm(),
    )

    # N(1, DD' + diag(A))
    lazydist = MultivariateNormal(mean, lowrank_pdiag_lt)
    lazy_lprob = lazydist.log_prob(z)
    # exact log probability with Cholesky decomposition
    exact_dist = torch.distributions.MultivariateNormal(
        mean, lowrank_pdiag_lt.evaluate().float())
    exact_lprob = exact_dist.log_prob(z)

    print(lazy_lprob, exact_lprob)
    # Final assertion: lazy and exact log-probs agree to within 1% relative error.
    rel_error = torch.norm(lazy_lprob - exact_lprob) / exact_lprob.norm()
    self.assertLess(rel_error.cpu().item(), 0.01)