import numpy as np
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist


def rethinking_model(B, M, K):
    # priors
    a = numpyro.sample("a", dist.Normal(0, 0.5))
    muB = numpyro.sample("muB", dist.Normal(0, 0.5))
    muM = numpyro.sample("muM", dist.Normal(0, 0.5))
    bB = numpyro.sample("bB", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    Rho_BM = numpyro.sample("Rho_BM", dist.LKJ(2, 2))
    Sigma_BM = numpyro.sample("Sigma_BM", dist.Exponential(1).expand([2]))

    # define B_merge as a mix of observed and imputed values
    B_impute = numpyro.sample(
        "B_impute", dist.Normal(0, 1).expand([int(np.isnan(B).sum())]).mask(False))
    B_merge = jnp.asarray(B).at[np.nonzero(np.isnan(B))[0]].set(B_impute)

    # M and B correlation
    MB = jnp.stack([M, B_merge], axis=1)
    cov = jnp.outer(Sigma_BM, Sigma_BM) * Rho_BM
    numpyro.sample("MB", dist.MultivariateNormal(jnp.stack([muM, muB]), cov), obs=MB)

    # K as a function of B and M
    mu = a + bB * B_merge + bM * M
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)
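A minimal usage sketch, not part of the original listing: it fits rethinking_model with NUTS on hypothetical synthetic stand-ins for the standardized predictors, with missing values injected into B. All data names and sizes below are assumptions for illustration.

# Hypothetical synthetic data; a real example would use standardized observed variables.
import numpy as np
from jax import random
from numpyro.infer import MCMC, NUTS

rng = np.random.default_rng(0)
n = 30
M = rng.normal(size=n)                            # synthetic predictor
B = 0.5 * M + rng.normal(scale=0.5, size=n)       # synthetic predictor, correlated with M
K = 0.6 * B - 0.3 * M + rng.normal(scale=0.3, size=n)
B[rng.choice(n, size=5, replace=False)] = np.nan  # introduce missingness in B

mcmc = MCMC(NUTS(rethinking_model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), B, M, K)
mcmc.print_summary()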
def linreg_imputation_model(X, y):
    ndims = X.shape[1]
    a = numpyro.sample("a", dist.Normal(0, 0.5))
    beta = numpyro.sample("beta", dist.Normal(0, 0.5).expand([ndims]))
    sigma_y = numpyro.sample("sigma_y", dist.Exponential(1))

    # X_impute holds the imputed values for each feature as a list;
    # X_merged is the observed data with imputed values filled in at the missing entries.
    X_impute = [None] * ndims
    X_merged = [None] * ndims
    for i in range(ndims):  # for every feature
        no_of_missed = int(np.isnan(X[:, i]).sum())
        if no_of_missed != 0:
            # each NaN value is associated with an imputed variable with a standard normal prior
            X_impute[i] = numpyro.sample(
                "X_impute_{}".format(i),
                dist.Normal(0, 1).expand([no_of_missed]).mask(False))
            # merge the observed data with the imputed values
            missed_idx = np.nonzero(np.isnan(X[:, i]))[0]
            X_merged[i] = jnp.asarray(X[:, i]).at[missed_idx].set(X_impute[i])
        else:
            # if there are no missing values, it is just the observed data
            X_merged[i] = X[:, i]
    merged_X = jnp.stack(X_merged).T

    # LKJ is the distribution used to model correlation matrices
    rho = numpyro.sample("rho", dist.LKJ(ndims, 2))  # correlation matrix
    sigma_x = numpyro.sample("sigma_x", dist.Exponential(1).expand([ndims]))
    covariance_x = jnp.outer(sigma_x, sigma_x) * rho  # covariance matrix
    mu_x = numpyro.sample("mu_x", dist.Normal(0, 0.5).expand([ndims]))
    numpyro.sample("X_merged", dist.MultivariateNormal(mu_x, covariance_x), obs=merged_X)

    mu_y = a + merged_X @ beta
    numpyro.sample("y", dist.Normal(mu_y, sigma_y), obs=y)
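A usage sketch for the general imputation model, again with hypothetical synthetic data (X, y and all sizes are assumptions): NaNs are placed in one feature, and the posterior for the corresponding X_impute_i site contains the imputed values.

import numpy as np
from jax import random
from numpyro.infer import MCMC, NUTS

rng = np.random.default_rng(1)
N, D = 50, 3
X = rng.normal(size=(N, D))
y = X @ np.array([0.5, -0.3, 0.2]) + rng.normal(scale=0.3, size=N)
X[rng.choice(N, size=8, replace=False), 0] = np.nan   # missing values in feature 0

mcmc = MCMC(NUTS(linreg_imputation_model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), X, y)
posterior = mcmc.get_samples()
print(posterior["X_impute_0"].shape)          # (num_samples, number of NaNs in feature 0)
print(posterior["X_impute_0"].mean(axis=0))   # posterior-mean imputations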
def linreg_model(X, y):
    ndims = X.shape[1]
    a = numpyro.sample("a", dist.Normal(0, 0.5))
    beta = numpyro.sample("beta", dist.Normal(0, 0.5).expand([ndims]))
    sigma_y = numpyro.sample("sigma_y", dist.Exponential(1))

    # LKJ is the distribution used to model correlation matrices
    rho = numpyro.sample("rho", dist.LKJ(ndims, 2))  # correlation matrix
    sigma_x = numpyro.sample("sigma_x", dist.Exponential(1).expand([ndims]))
    covariance_x = jnp.outer(sigma_x, sigma_x) * rho  # covariance matrix
    mu_x = numpyro.sample("mu_x", dist.Normal(0, 0.5).expand([ndims]))
    numpyro.sample("X", dist.MultivariateNormal(mu_x, covariance_x), obs=X)

    mu_y = a + X @ beta
    numpyro.sample("y", dist.Normal(mu_y, sigma_y), obs=y)
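For comparison, a complete-case sketch (an assumption, not part of the original listing): drop the rows containing NaNs and fit linreg_model as a baseline, reusing the hypothetical X and y from the previous sketch.

import numpy as np
from jax import random
from numpyro.infer import MCMC, NUTS

keep = ~np.isnan(X).any(axis=1)  # complete cases only
mcmc_cc = MCMC(NUTS(linreg_model), num_warmup=500, num_samples=500)
mcmc_cc.run(random.PRNGKey(0), X[keep], y[keep])
mcmc_cc.print_summary()          # compare the "beta" estimates with the imputation run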
import jax.numpy as jnp
from jax import random, vmap

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

import numpyro
import numpyro.distributions as dist

import arviz as az
import matplotlib.pyplot as plt
import pyprobml_utils as pml

eta_list = [1, 2, 4]
colors = ['r', 'k', 'b']
fig, ax = plt.subplots()
for i, eta in enumerate(eta_list):
    R = dist.LKJ(dimension=2, concentration=eta).sample(random.PRNGKey(0), (int(1e4),))
    az.plot_kde(R[:, 0, 1], label=f"eta={eta}", plot_kwargs={'color': colors[i]})
plt.legend()
ax.set_xlabel('correlation')
ax.set_ylabel('density')
ax.set_ylim(0, 1.2)
ax.set_xlim(-1.1, 1.1)
pml.savefig('LKJ_1d_correlation.pdf', dpi=300)
plt.show()