def rethinking_model(B, M, K):
    """Regress ``K`` on ``B`` and ``M`` while imputing the missing ``B`` values.

    Entries of ``B`` that are NaN are treated as missing. Each missing entry
    is replaced by one element of the latent site ``"B_impute"``; the completed
    vector is then modelled jointly with ``M`` as a bivariate normal (site
    ``"MB"``) and used as a predictor for ``K`` (site ``"K"``).

    Args:
        B: 1-D array of predictor values with NaN marking missing entries.
            It must hold concrete values (not a traced array), because the
            number of NaNs fixes the shape of the ``"B_impute"`` site.
        M: 1-D array with the same length as ``B``.
        K: Observed response, passed as ``obs`` to the ``"K"`` site.
    """
    # priors
    a = numpyro.sample("a", dist.Normal(0, 0.5))
    muB = numpyro.sample("muB", dist.Normal(0, 0.5))
    muM = numpyro.sample("muM", dist.Normal(0, 0.5))
    bB = numpyro.sample("bB", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    Rho_BM = numpyro.sample("Rho_BM", dist.LKJ(2, 2))
    Sigma_BM = numpyro.sample("Sigma_BM", dist.Exponential(1).expand([2]))

    # define B_merge as mix of observed and imputed values
    missing_idx = np.nonzero(np.isnan(B))[0]
    # mask(False) asks NumPyro to leave this site's own Normal(0, 1)
    # log-density out of the joint; the imputed values are scored by the
    # multivariate normal on "MB" below.
    B_impute = numpyro.sample(
        "B_impute",
        dist.Normal(0, 1).expand([int(missing_idx.size)]).mask(False))
    # `jax.ops.index_update` is no longer available in current JAX releases;
    # `.at[...].set(...)` is the supported out-of-place indexed update.
    B_merge = jnp.asarray(B).at[missing_idx].set(B_impute)

    # M and B correlation
    MB = jnp.stack([M, B_merge], axis=1)
    # covariance = diag(Sigma_BM) @ Rho_BM @ diag(Sigma_BM), written elementwise
    cov = jnp.outer(Sigma_BM, Sigma_BM) * Rho_BM
    numpyro.sample("MB",
                   dist.MultivariateNormal(jnp.stack([muM, muB]), cov),
                   obs=MB)

    # K as function of B and M
    mu = a + bB * B_merge + bM * M
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)
# Example no. 2
# 0
def linreg_imputation_model(X, y):
    """Linear regression of ``y`` on ``X`` with imputation of NaN features.

    Every column of ``X`` that contains NaNs gets a latent site
    ``"X_impute_<column index>"`` with one element per missing entry. The
    completed feature matrix is modelled as multivariate normal (site
    ``"X_merged"``) and then used as the design matrix for ``y``.

    Args:
        X: 2-D array indexed as ``X[:, i]`` per feature, with NaN marking
            missing entries. It must hold concrete values (not a traced
            array), because the NaN counts fix the shapes of the imputation
            sites.
        y: Observed response, passed as ``obs`` to the ``"y"`` site.
    """
    ndims = X.shape[1]
    a = numpyro.sample("a", dist.Normal(0, 0.5))

    beta = numpyro.sample("beta", dist.Normal(0, 0.5).expand([ndims]))
    sigma_y = numpyro.sample("sigma_y", dist.Exponential(1))

    # X_merged holds, per feature, the observed column with imputed values
    # written into the positions that were missing.
    X_merged = []

    for i in range(ndims):  # for every feature
        column = X[:, i]
        missed_idx = np.nonzero(np.isnan(column))[0]

        if missed_idx.size == 0:
            # no missing values: the column is just the observed data.
            X_merged.append(column)
            continue

        # each NaN is associated with one imputed variable. mask(False) asks
        # NumPyro to leave the Normal(0, 1) log-density of this site out of
        # the joint; the values are scored by "X_merged" below.
        imputed = numpyro.sample(
            f"X_impute_{i}",
            dist.Normal(0, 1).expand([int(missed_idx.size)]).mask(False))

        # merge the observed data with the imputed values.
        # `jax.ops.index_update` is no longer available in current JAX
        # releases; `.at[...].set(...)` is the supported indexed update.
        X_merged.append(jnp.asarray(column).at[missed_idx].set(imputed))

    merged_X = jnp.stack(X_merged).T

    # LKJ is the distribution to model correlation matrices.
    rho = numpyro.sample("rho", dist.LKJ(ndims, 2))  # correlation matrix
    sigma_x = numpyro.sample("sigma_x", dist.Exponential(1).expand([ndims]))
    covariance_x = jnp.outer(sigma_x, sigma_x) * rho  # covariance matrix
    mu_x = numpyro.sample("mu_x", dist.Normal(0, 0.5).expand([ndims]))

    numpyro.sample("X_merged",
                   dist.MultivariateNormal(mu_x, covariance_x),
                   obs=merged_X)

    mu_y = a + merged_X @ beta

    numpyro.sample("y", dist.Normal(mu_y, sigma_y), obs=y)
# Example no. 3
# 0
def linreg_model(X, y):
    """Linear regression of ``y`` on ``X`` with a joint Gaussian model of ``X``.

    The feature matrix is itself observed under a multivariate normal whose
    covariance is assembled from an LKJ correlation matrix and per-feature
    scales; the response is normal around ``a + X @ beta``.

    Args:
        X: 2-D feature array; ``X.shape[1]`` sets the number of coefficients.
        y: Observed response, passed as ``obs`` to the ``"y"`` site.
    """
    n_features = X.shape[1]

    # Regression parameters: intercept, coefficients and noise scale.
    intercept = numpyro.sample("a", dist.Normal(0, 0.5))
    coefs = numpyro.sample(
        "beta", dist.Normal(0, 0.5).expand([n_features]))
    noise_scale = numpyro.sample("sigma_y", dist.Exponential(1))

    # Feature distribution: correlation matrix (LKJ), scales and means.
    corr = numpyro.sample("rho", dist.LKJ(n_features, 2))
    scales = numpyro.sample(
        "sigma_x", dist.Exponential(1).expand([n_features]))
    feature_cov = corr * jnp.outer(scales, scales)
    feature_mean = numpyro.sample(
        "mu_x", dist.Normal(0, 0.5).expand([n_features]))

    feature_dist = dist.MultivariateNormal(feature_mean, feature_cov)
    numpyro.sample("X", feature_dist, obs=X)

    # Response likelihood centred on the linear predictor.
    linear_predictor = intercept + jnp.matmul(X, coefs)
    response_dist = dist.Normal(linear_predictor, noise_scale)
    numpyro.sample("y", response_dist, obs=y)
# Example no. 4
# 0
# Plot kernel density estimates of the off-diagonal correlation R[0, 1] of
# 2x2 LKJ samples for several concentration values, then save and show the
# figure.
import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt  # `plt` is used below but was never imported
import numpyro
import numpyro.distributions as dist
from jax import random, vmap

import pyprobml_utils as pml

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

eta_list = [1, 2, 4]
colors = ['r', 'k', 'b']
fig, ax = plt.subplots()
for eta, color in zip(eta_list, colors):
    # NOTE: the same PRNGKey(0) is used for every eta, so the curves are all
    # driven by the same random seed and differ only in the concentration.
    R = dist.LKJ(dimension=2,
                 concentration=eta).sample(random.PRNGKey(0), (int(1e4), ))
    # R holds one 2x2 correlation matrix per draw; [0, 1] is the correlation.
    az.plot_kde(R[:, 0, 1],
                label=f"eta={eta}",
                plot_kwargs={'color': color})
plt.legend()
ax.set_xlabel('correlation')
ax.set_ylabel('density')
ax.set_ylim(0, 1.2)
ax.set_xlim(-1.1, 1.1)
pml.savefig('LKJ_1d_correlation.pdf', dpi=300)
plt.show()