Пример #1
0
def test_gaussian_model(kernel_cls, D=2, num_warmup=5000, num_samples=5000):
    """Check HMCGibbs on a ``2 * D``-dimensional Gaussian split into two blocks.

    The joint vector ``(x0, x1) ~ MVN(0, cov)`` is written as
    ``x0 ~ MVN(0, cov00)`` and ``x1 | x0 ~ MVN(cov10 cov00^-1 x0,
    cov11 - cov10 cov00^-1 cov01)``. The kernel built from ``kernel_cls``
    updates ``x1``, while a closed-form Gibbs step redraws ``x0`` from its
    exact conditional ``x0 | x1``. The empirical mean and standard deviation
    of each block are then compared against the marginals of the joint.
    """
    np.random.seed(0)
    # Random symmetric positive-definite covariance for the joint vector.
    raw = np.random.randn(4 * D * D).reshape((2 * D, 2 * D))
    cov = jnp.matmul(jnp.transpose(raw), raw) + 0.25 * jnp.eye(2 * D)

    cov00, cov01 = cov[:D, :D], cov[:D, D:]
    cov10, cov11 = cov[D:, :D], cov[D:, D:]

    # Regression matrices: E[x0 | x1] = gain0 @ x1 and E[x1 | x0] = gain1 @ x0.
    gain0 = jnp.matmul(cov01, inv(cov11))
    gain1 = jnp.matmul(cov10, inv(cov00))

    # Schur complements, i.e. the conditional covariance of each block.
    cond_cov0 = cov00 - jnp.matmul(gain0, cov10)
    cond_cov1 = cov11 - jnp.matmul(gain1, cov01)

    def gaussian_gibbs_fn(rng_key, hmc_sites, gibbs_sites):
        # Exact draw of x0 from p(x0 | x1); the current gibbs_sites are not needed.
        conditional = dist.MultivariateNormal(
            loc=jnp.matmul(gain0, hmc_sites["x1"]),
            covariance_matrix=cond_cov0,
        )
        return {"x0": conditional.sample(rng_key)}

    def model():
        x0 = numpyro.sample(
            "x0",
            dist.MultivariateNormal(loc=jnp.zeros(D), covariance_matrix=cov00),
        )
        numpyro.sample(
            "x1",
            dist.MultivariateNormal(
                loc=jnp.matmul(gain1, x0), covariance_matrix=cond_cov1
            ),
        )

    mcmc = MCMC(
        HMCGibbs(
            kernel_cls(model, dense_mass=True),
            gibbs_fn=gaussian_gibbs_fn,
            gibbs_sites=["x0"],
        ),
        num_warmup=num_warmup,
        num_samples=num_samples,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(0))

    samples = mcmc.get_samples()
    x0_draws = samples["x0"]
    x1_draws = samples["x1"]

    # Both blocks are marginally zero-mean with variances diag(cov00) / diag(cov11).
    assert_allclose(np.mean(x0_draws, axis=0), np.zeros(D), atol=0.2)
    assert_allclose(np.mean(x1_draws, axis=0), np.zeros(D), atol=0.2)

    assert_allclose(np.std(x0_draws, axis=0), np.sqrt(np.diagonal(cov00)), rtol=0.05)
    assert_allclose(np.std(x1_draws, axis=0), np.sqrt(np.diagonal(cov11)), rtol=0.1)
Пример #2
0
def DP_inv_pd(v):
    """Dispatch on the rank of ``v``.

    A 1-D input is forwarded to ``vector_DP_inv_pd``. Any other rank is
    passed to ``linalg.inv``, and the negated transpose of the inverse,
    ``-inv(v).T``, is returned.
    """
    if np.ndim(v) == 1:
        return vector_DP_inv_pd(v)
    return -linalg.inv(v).T
Пример #3
0
        def inverse_fun(params, inputs, **kwargs):
            """Undo the LU-parameterized linear map ``x -> x @ W``.

            ``params`` is the raw ``(L, U, S)`` triple. Only the strictly lower
            triangle of ``L`` and the strictly upper triangle of ``U`` are
            used, and ``W = P @ (tril(L, -1) + identity) @ (triu(U, 1) + diag(S))``.
            Returns ``inputs @ inv(W)`` and, for every row of ``inputs``, the
            log-det Jacobian ``-sum(log|S|)``.
            """
            lower_raw, upper_raw, diag_s = params
            unit_lower = np.tril(lower_raw, -1) + identity
            upper = np.triu(upper_raw, 1) + np.diag(diag_s)
            weight = P @ unit_lower @ upper

            # NOTE(review): presumably P is a permutation and `identity` the
            # identity matrix (both from the enclosing scope), so |det W| is
            # carried entirely by S; confirm there.
            log_abs_diag_total = np.log(np.abs(diag_s)).sum()
            log_det_jacobian = np.full(inputs.shape[:1], -log_abs_diag_total)
            return inputs @ linalg.inv(weight), log_det_jacobian
Пример #4
0
    def init_fun(rng, input_dim, **kwargs):
        """Create a fixed (parameter-free) invertible linear layer.

        An ``input_dim x input_dim`` matrix is drawn once from the
        ``orthogonal()`` initializer. Its inverse and log|det| are precomputed,
        and an empty parameter tuple is returned together with the forward and
        inverse functions. Both functions ignore ``params`` and use the
        precomputed values instead.
        """
        weight = orthogonal()(rng, (input_dim, input_dim))
        weight_inv = linalg.inv(weight)
        # slogdet yields (sign, log|det|); only the magnitude enters the Jacobian.
        _, log_abs_det = np.linalg.slogdet(weight)

        def direct_fun(params, inputs, **kwargs):
            # Forward map x -> x @ W with a constant per-row log-det.
            return inputs @ weight, np.full(inputs.shape[:1], log_abs_det)

        def inverse_fun(params, inputs, **kwargs):
            # Inverse map y -> y @ W^-1; its log-det is the negation.
            return inputs @ weight_inv, np.full(inputs.shape[:1], -log_abs_det)

        return (), direct_fun, inverse_fun