import torch
from torch.autograd.functional import jvp

from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import CholLazyTensor, TriangularLazyTensor
from gpytorch.variational import NaturalVariationalDistribution, TrilNaturalVariationalDistribution


def create_lazy_tensor(self):
    chol = torch.tensor(
        [[3, 0, 0, 0, 0], [-1, 2, 0, 0, 0], [1, 4, 1, 0, 0], [0, 2, 3, 2, 0], [-4, -2, 1, 3, 4]],
        dtype=torch.float,
        requires_grad=True,
    )
    return CholLazyTensor(TriangularLazyTensor(chol))

def test_natgrad(self, D=5):
    mu = torch.randn(D)
    cov = torch.randn(D, D).tril_()
    dist = MultivariateNormal(mu, CholLazyTensor(TriangularLazyTensor(cov)))
    sample = dist.sample()

    v_dist = NaturalVariationalDistribution(D)
    v_dist.initialize_variational_distribution(dist)
    mu = v_dist().mean.detach()

    v_dist().log_prob(sample).squeeze().backward()

    # Reference: the same log-prob, parameterized directly by the expectation
    # parameters eta1 = mu and eta2 = mu mu^T + Sigma.
    eta1 = mu.clone().requires_grad_(True)
    eta2 = (mu[:, None] * mu + cov @ cov.t()).requires_grad_(True)
    L = torch.linalg.cholesky(eta2 - eta1[:, None] * eta1)
    dist2 = MultivariateNormal(eta1, CholLazyTensor(TriangularLazyTensor(L)))
    dist2.log_prob(sample).squeeze().backward()

    # Gradients w.r.t. the natural parameters should equal the ordinary gradients
    # w.r.t. the expectation parameters.
    assert torch.allclose(v_dist.natural_vec.grad, eta1.grad)
    assert torch.allclose(v_dist.natural_mat.grad, eta2.grad)

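# A minimal sketch (the helper below is illustrative only, not part of GPyTorch or of these
# tests) of the two Gaussian parameterizations compared in the test above: the standard
# natural parameters (theta1, theta2) and the expectation parameters (eta1, eta2) that the
# test uses to build the reference gradients.
def _natural_and_expectation_params(mu, Sigma):
    # natural parameters: theta1 = Sigma^{-1} mu, theta2 = -1/2 Sigma^{-1}
    theta1 = torch.linalg.solve(Sigma, mu)
    theta2 = -0.5 * torch.inverse(Sigma)
    # expectation parameters: eta1 = E[x] = mu, eta2 = E[x x^T] = mu mu^T + Sigma
    eta1 = mu
    eta2 = mu[:, None] * mu + Sigma
    return (theta1, theta2), (eta1, eta2)
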
def test_invertible_init(self, D=5):
    mu = torch.randn(D)
    cov = torch.randn(D, D).tril_()
    dist = MultivariateNormal(mu, CholLazyTensor(TriangularLazyTensor(cov)))

    v_dist = TrilNaturalVariationalDistribution(D, mean_init_std=0.0)
    v_dist.initialize_variational_distribution(dist)

    out_dist = v_dist()

    assert torch.allclose(out_dist.mean, dist.mean)
    assert torch.allclose(out_dist.covariance_matrix, dist.covariance_matrix)

def create_lazy_tensor(self):
    chol = torch.tensor(
        [
            [[3, 0, 0, 0, 0], [-1, 2, 0, 0, 0], [1, 4, 1, 0, 0], [0, 2, 3, 2, 0], [-4, -2, 1, 3, 4]],
            [[2, 0, 0, 0, 0], [3, 1, 0, 0, 0], [-2, 3, 2, 0, 0], [-2, 1, -1, 3, 0], [-4, -4, 5, 2, 3]],
        ],
        dtype=torch.float,
    )
    chol.add_(torch.eye(5).unsqueeze(0))
    chol.requires_grad_(True)
    return CholLazyTensor(TriangularLazyTensor(chol))

def test_natgrad(self, D=5):
    mu = torch.randn(D)
    cov = torch.randn(D, D)
    cov = cov @ cov.t()
    dist = MultivariateNormal(mu, CholLazyTensor(TriangularLazyTensor(torch.linalg.cholesky(cov))))
    sample = dist.sample()

    v_dist = TrilNaturalVariationalDistribution(D, mean_init_std=0.0)
    v_dist.initialize_variational_distribution(dist)
    v_dist().log_prob(sample).squeeze().backward()
    dout_dnat1 = v_dist.natural_vec.grad
    dout_dnat2 = v_dist.natural_tril_mat.grad

    # mean_init_std=0.0 because we need to ensure both have the same distribution
    v_dist_ref = NaturalVariationalDistribution(D, mean_init_std=0.0)
    v_dist_ref.initialize_variational_distribution(dist)
    v_dist_ref().log_prob(sample).squeeze().backward()
    dout_dnat1_noforward_ref = v_dist_ref.natural_vec.grad
    dout_dnat2_noforward_ref = v_dist_ref.natural_mat.grad

    def f(natural_vec, natural_tril_mat):
        "Transform natural_tril_mat to L"
        Sigma = torch.inverse(-2 * natural_tril_mat)
        mu = natural_vec
        return mu, torch.linalg.cholesky(Sigma).inverse().tril()

    (mu_ref, natural_tril_mat_ref), (dout_dmu_ref, dout_dnat2_ref) = jvp(
        f,
        (v_dist_ref.natural_vec.detach(), v_dist_ref.natural_mat.detach()),
        (dout_dnat1_noforward_ref, dout_dnat2_noforward_ref),
    )

    assert torch.allclose(natural_tril_mat_ref, v_dist.natural_tril_mat), "Sigma transformation"
    assert torch.allclose(dout_dnat2_ref, dout_dnat2), "Sigma gradient"
    assert torch.allclose(mu_ref, v_dist.natural_vec), "mu transformation"
    assert torch.allclose(dout_dmu_ref, dout_dnat1), "mu gradient"

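# Note on the check above: torch.autograd.functional.jvp returns f(inputs) together with the
# Jacobian-vector product J_f(inputs) @ v. Here it pushes the reference gradients from the
# NaturalVariationalDistribution parameterization forward through the parameter map `f`
# (natural_mat -> tril(chol(Sigma)^{-1})), and the assertions compare the results against the
# parameters and gradients produced directly by TrilNaturalVariationalDistribution.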