Exemple #1
0
def test_natural_normal():
    chol = B.randn(2, 2)
    dist = Normal(B.randn(2, 1), B.reg(chol @ chol.T, diag=1e-1))
    nat = NaturalNormal.from_normal(dist)

    # Test properties.
    assert dist.dtype == nat.dtype
    for name in ["dim", "mean", "var", "m2"]:
        approx(getattr(dist, name), getattr(nat, name))

    # Test sampling.
    state = B.create_random_state(dist.dtype, seed=0)
    state, sample = nat.sample(state, num=1_000_000)
    emp_mean = B.mean(B.dense(sample), axis=1, squeeze=False)
    emp_var = (sample - emp_mean) @ (sample - emp_mean).T / 1_000_000
    approx(dist.mean, emp_mean, rtol=5e-2)
    approx(dist.var, emp_var, rtol=5e-2)

    # Test KL.
    chol = B.randn(2, 2)
    other_dist = Normal(B.randn(2, 1), B.reg(chol @ chol.T, diag=1e-2))
    other_nat = NaturalNormal.from_normal(other_dist)
    approx(dist.kl(other_dist), nat.kl(other_nat))

    # Test log-pdf.
    x = B.randn(2, 1)
    approx(dist.logpdf(x), nat.logpdf(x))
Exemple #2
0
def elbo(lik, p: Normal, q: Normal, num_samples=1):
    return B.mean(lik(q.sample(num_samples))) - q.kl(p)