Esempio n. 1
0
    def test_invertible_init(self, D=5):
        """Round-trip a random Gaussian through the tril natural distribution.

        A ``MultivariateNormal`` of dimension ``D`` is built from a random mean
        and a random lower-triangular factor. A
        ``TrilNaturalVariationalDistribution`` is initialized from it, and the
        distribution it returns when called must have the same mean and
        covariance matrix (up to ``torch.allclose`` default tolerances).

        Args:
            D: Dimension of the random Gaussian (defaults to 5).
        """
        target_mean = torch.randn(D)
        # Zero the strict upper triangle in place to get a triangular factor.
        tril_factor = torch.randn(D, D).tril_()
        target_covar = CholLazyTensor(TriangularLazyTensor(tril_factor))
        target = MultivariateNormal(target_mean, target_covar)

        # NOTE(review): mean_init_std=0.0 presumably disables initialization
        # noise so that the round trip can be exact -- confirm.
        variational = TrilNaturalVariationalDistribution(D, mean_init_std=0.0)
        variational.initialize_variational_distribution(target)

        recovered = variational()

        # Same order as before: the mean is checked first, then the covariance.
        for moment in ("mean", "covariance_matrix"):
            assert torch.allclose(getattr(recovered, moment), getattr(target, moment))
Esempio n. 2
0
    def test_natgrad(self, D=5):
        """Check tril-parameterized gradients against the reference distribution.

        Both a ``TrilNaturalVariationalDistribution`` and a
        ``NaturalVariationalDistribution`` are initialized from the same random
        Gaussian, and the log-probability of one shared sample is
        back-propagated through each. The reference parameters and their
        gradients are then pushed through a change of variables with ``jvp``
        and compared with the parameters and gradients of the tril variant.

        Args:
            D: Dimension of the random Gaussian (defaults to 5).
        """
        mean = torch.randn(D)
        root = torch.randn(D, D)
        # root @ root^T, so that the Cholesky factorization below is defined.
        covar = root @ root.t()
        chol_factor = torch.linalg.cholesky(covar)
        dist = MultivariateNormal(
            mean, CholLazyTensor(TriangularLazyTensor(chol_factor))
        )
        sample = dist.sample()

        # Distribution under test: gradients land on natural_vec and
        # natural_tril_mat after the backward pass.
        tril_dist = TrilNaturalVariationalDistribution(D, mean_init_std=0.0)
        tril_dist.initialize_variational_distribution(dist)
        tril_log_prob = tril_dist().log_prob(sample)
        tril_log_prob.squeeze().backward()
        grad_vec = tril_dist.natural_vec.grad
        grad_tril = tril_dist.natural_tril_mat.grad

        # Reference distribution. mean_init_std is 0.0 here as well because
        # both objects have to represent the same distribution.
        ref_dist = NaturalVariationalDistribution(D, mean_init_std=0.0)
        ref_dist.initialize_variational_distribution(dist)
        ref_log_prob = ref_dist().log_prob(sample)
        ref_log_prob.squeeze().backward()
        ref_grad_vec = ref_dist.natural_vec.grad
        ref_grad_mat = ref_dist.natural_mat.grad

        def to_tril_params(nat_vec, nat_mat):
            """Map the reference parameters onto the tril parameterization.

            The vector is returned unchanged; the matrix becomes the lower
            triangle of the inverse Cholesky factor of ``inverse(-2 * nat_mat)``.
            """
            covariance = torch.inverse(-2 * nat_mat)
            chol_inverse = torch.linalg.cholesky(covariance).inverse()
            return nat_vec, chol_inverse.tril()

        # The reference gradients are used as the tangent vectors of the jvp.
        primals = (ref_dist.natural_vec.detach(), ref_dist.natural_mat.detach())
        tangents = (ref_grad_vec, ref_grad_mat)
        transformed, pushed_grads = jvp(to_tril_params, primals, tangents)
        vec_ref, tril_ref = transformed
        grad_vec_ref, grad_tril_ref = pushed_grads

        values_match = torch.allclose(tril_ref, tril_dist.natural_tril_mat)
        assert values_match, "Sigma transformation"
        assert torch.allclose(grad_tril_ref, grad_tril), "Sigma gradient"

        assert torch.allclose(vec_ref, tril_dist.natural_vec), "mu transformation"
        assert torch.allclose(grad_vec_ref, grad_vec), "mu gradient"