Example #1
    def create_lazy_tensor(self):
        # Lower-triangular Cholesky factor; CholLazyTensor represents chol @ chol.t()
        chol = torch.tensor(
            [[3, 0, 0, 0, 0], [-1, 2, 0, 0, 0], [1, 4, 1, 0, 0],
             [0, 2, 3, 2, 0], [-4, -2, 1, 3, 4]],
            dtype=torch.float,
            requires_grad=True,
        )
        return CholLazyTensor(TriangularLazyTensor(chol))
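Each example here is a method lifted from a test class, so the imports are not shown. A minimal self-contained sketch of what this factory builds (assuming the older gpytorch API, where these classes live in gpytorch.lazy and lazy tensors expose evaluate()):

    import torch
    from gpytorch.lazy import CholLazyTensor, TriangularLazyTensor

    chol = torch.tensor([[3.0, 0.0], [-1.0, 2.0]], requires_grad=True)
    lazy_cov = CholLazyTensor(TriangularLazyTensor(chol))
    # The lazy tensor represents chol @ chol.t() without materializing it
    assert torch.allclose(lazy_cov.evaluate(), chol @ chol.t())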
Example #2
    def test_natgrad(self, D=5):
        mu = torch.randn(D)
        cov = torch.randn(D, D).tril_()  # random lower-triangular Cholesky factor
        dist = MultivariateNormal(mu, CholLazyTensor(TriangularLazyTensor(cov)))
        sample = dist.sample()

        v_dist = NaturalVariationalDistribution(D)
        v_dist.initialize_variational_distribution(dist)
        mu = v_dist().mean.detach()  # use v_dist's own mean (initialization may perturb dist.mean)

        v_dist().log_prob(sample).squeeze().backward()

        # Expectation parameters of v_dist's Gaussian: E[x] and E[x x^T]
        eta1 = mu.clone().requires_grad_(True)
        eta2 = (mu[:, None] * mu + cov @ cov.t()).requires_grad_(True)
        L = torch.linalg.cholesky(eta2 - eta1[:, None] * eta1)
        dist2 = MultivariateNormal(eta1, CholLazyTensor(TriangularLazyTensor(L)))
        dist2.log_prob(sample).squeeze().backward()

        # .grad of the natural parameters should match the ordinary gradients
        # taken with respect to the expectation parameters
        assert torch.allclose(v_dist.natural_vec.grad, eta1.grad)
        assert torch.allclose(v_dist.natural_mat.grad, eta2.grad)
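As I read this test: NaturalVariationalDistribution stores the Gaussian's natural parameters, but its backward pass is arranged so that .grad on those parameters equals the ordinary gradient with respect to the expectation parameters built above (eta1 = E[x], eta2 = E[x x^T]), which is the update direction used in natural-gradient descent. For reference, a sketch of the two parameterizations (the helper names are ours, not gpytorch's):

    import torch

    def expectation_params(mu, Sigma):
        # m1 = E[x], m2 = E[x x^T] = Sigma + mu mu^T
        return mu, Sigma + mu[:, None] * mu

    def natural_params(mu, Sigma):
        # theta1 = Sigma^{-1} mu, theta2 = -1/2 * Sigma^{-1}
        P = torch.inverse(Sigma)
        return P @ mu, -0.5 * P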
Example #3
    def test_invertible_init(self, D=5):
        mu = torch.randn(D)
        cov = torch.randn(D, D).tril_()  # random lower-triangular Cholesky factor
        dist = MultivariateNormal(mu, CholLazyTensor(TriangularLazyTensor(cov)))

        # mean_init_std=0.0 disables the random perturbation of the initial
        # mean, so initialization should reproduce dist exactly
        v_dist = TrilNaturalVariationalDistribution(D, mean_init_std=0.0)
        v_dist.initialize_variational_distribution(dist)

        out_dist = v_dist()

        assert torch.allclose(out_dist.mean, dist.mean)
        assert torch.allclose(out_dist.covariance_matrix, dist.covariance_matrix)
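The same round trip should also hold for NaturalVariationalDistribution, up to the numerical error introduced by the extra matrix inversions; a sketch reusing the names from the test above:

    v_dist_nat = NaturalVariationalDistribution(D, mean_init_std=0.0)
    v_dist_nat.initialize_variational_distribution(dist)
    assert torch.allclose(v_dist_nat().mean, dist.mean, atol=1e-5)
    assert torch.allclose(v_dist_nat().covariance_matrix, dist.covariance_matrix, atol=1e-5)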
Example #4
    def create_lazy_tensor(self):
        # Batch of two 5x5 lower-triangular Cholesky factors
        chol = torch.tensor(
            [
                [[3, 0, 0, 0, 0], [-1, 2, 0, 0, 0], [1, 4, 1, 0, 0],
                 [0, 2, 3, 2, 0], [-4, -2, 1, 3, 4]],
                [[2, 0, 0, 0, 0], [3, 1, 0, 0, 0], [-2, 3, 2, 0, 0],
                 [-2, 1, -1, 3, 0], [-4, -4, 5, 2, 3]],
            ],
            dtype=torch.float,
        )
        # Strengthen the diagonal before tracking gradients (in-place ops are
        # not allowed on leaf tensors that already require grad)
        chol.add_(torch.eye(5).unsqueeze(0))
        chol.requires_grad_(True)
        return CholLazyTensor(TriangularLazyTensor(chol))
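The only changes from Example #1 are the leading batch dimension and the diagonal boost; a shape-level sketch (same assumed imports as the note under Example #1):

    base = torch.randn(2, 5, 5).tril(-1) + torch.eye(5)  # strictly lower noise, unit diagonal
    batch_lt = CholLazyTensor(TriangularLazyTensor(base))
    # One lazy 5x5 positive-definite matrix per batch element
    assert batch_lt.shape == torch.Size([2, 5, 5])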
Example #5
    def test_natgrad(self, D=5):
        mu = torch.randn(D)
        cov = torch.randn(D, D)
        cov = cov @ cov.t()  # random positive-definite covariance (almost surely full rank)
        dist = MultivariateNormal(
            mu,
            CholLazyTensor(TriangularLazyTensor(torch.linalg.cholesky(cov))))
        sample = dist.sample()

        v_dist = TrilNaturalVariationalDistribution(D, mean_init_std=0.0)
        v_dist.initialize_variational_distribution(dist)
        v_dist().log_prob(sample).squeeze().backward()
        dout_dnat1 = v_dist.natural_vec.grad
        dout_dnat2 = v_dist.natural_tril_mat.grad

        # mean_init_std=0.0 so both variational distributions start from
        # exactly the same Gaussian
        v_dist_ref = NaturalVariationalDistribution(D, mean_init_std=0.0)
        v_dist_ref.initialize_variational_distribution(dist)
        v_dist_ref().log_prob(sample).squeeze().backward()
        dout_dnat1_noforward_ref = v_dist_ref.natural_vec.grad
        dout_dnat2_noforward_ref = v_dist_ref.natural_mat.grad

        def f(natural_vec, natural_mat):
            "Map the natural parameters to the (mu, L) tril parameterization"
            # natural_mat encodes -1/2 * Sigma^{-1}, so Sigma = (-2 * natural_mat)^{-1}
            Sigma = torch.inverse(-2 * natural_mat)
            mu = natural_vec
            return mu, torch.linalg.cholesky(Sigma).inverse().tril()

        # Forward-mode push of the reference gradients through f, mapping them
        # into the tril-natural coordinates for comparison
        (mu_ref, natural_tril_mat_ref), (dout_dmu_ref, dout_dnat2_ref) = jvp(
            f,
            (v_dist_ref.natural_vec.detach(), v_dist_ref.natural_mat.detach()),
            (dout_dnat1_noforward_ref, dout_dnat2_noforward_ref),
        )

        assert torch.allclose(natural_tril_mat_ref,
                              v_dist.natural_tril_mat), "Sigma transformation"
        assert torch.allclose(dout_dnat2_ref, dout_dnat2), "Sigma gradient"

        assert torch.allclose(mu_ref, v_dist.natural_vec), "mu transformation"
        assert torch.allclose(dout_dmu_ref, dout_dnat1), "mu gradient"
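Why jvp here, as I understand it: both variational classes customize their backward so that .grad holds natural-gradient-style quantities rather than plain chain-rule gradients, so the two sets of gradients cannot be compared directly; forward-mode jvp pushes the reference gradients through f into the tril-natural coordinates used by TrilNaturalVariationalDistribution. The snippet assumes jvp comes from torch.autograd.functional; a minimal reminder of its contract:

    import torch
    from torch.autograd.functional import jvp

    # jvp(f, inputs, v) returns (f(inputs), J_f(inputs) @ v): the function
    # value and the directional derivative along v
    out, tangent = jvp(lambda x: x ** 2, (torch.tensor(3.0),), (torch.tensor(1.0),))
    assert out.item() == 9.0 and tangent.item() == 6.0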