def test_tracegraph_normal_normal():
    """Normal-normal conjugate model with known observation precision.

    Uses a fake non-reparameterized Normal so that TraceGraph_ELBO's
    score-function (REINFORCE-style) estimator path is exercised, then
    checks that SVI recovers the analytic posterior mean and scale.
    """
    lam0 = jnp.array([0.1, 0.1])  # precision of prior
    loc0 = jnp.array([0.0, 0.5])  # prior mean
    # known precision of observation noise
    lam = jnp.array([6.0, 4.0])
    data = [
        jnp.array([-0.1, 0.3]),
        jnp.array([0.0, 0.4]),
        jnp.array([0.2, 0.5]),
        jnp.array([0.1, 0.7]),
    ]
    n_data = len(data)
    # sum over all observations (was hard-coded to exactly four terms)
    sum_data = sum(data)
    # analytic posterior of the conjugate normal-normal pair
    analytic_lam_n = lam0 + n_data * lam
    analytic_log_sig_n = -0.5 * jnp.log(analytic_lam_n)
    analytic_loc_n = sum_data * (lam / analytic_lam_n) + loc0 * (
        lam0 / analytic_lam_n
    )

    class FakeNormal(dist.Normal):
        # pretend Normal has no reparameterized sampler so the ELBO must
        # fall back to score-function gradients
        reparametrized_params = []

    def model():
        with numpyro.plate("plate", 2):
            loc_latent = numpyro.sample(
                "loc_latent", FakeNormal(loc0, jnp.power(lam0, -0.5))
            )
            for i, x in enumerate(data):
                numpyro.sample(
                    "obs_{}".format(i),
                    dist.Normal(loc_latent, jnp.power(lam, -0.5)),
                    obs=x,
                )
            return loc_latent

    def guide():
        # initialize the variational parameters deliberately off-target
        loc_q = numpyro.param("loc_q", analytic_loc_n + jnp.array([0.334, 0.334]))
        log_sig_q = numpyro.param(
            "log_sig_q", analytic_log_sig_n + jnp.array([-0.29, -0.29])
        )
        sig_q = jnp.exp(log_sig_q)
        with numpyro.plate("plate", 2):
            loc_latent = numpyro.sample("loc_latent", FakeNormal(loc_q, sig_q))
            return loc_latent

    adam = optim.Adam(step_size=0.0015, b1=0.97, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), 5000)

    # squared distance of the fitted parameters from the analytic answers
    loc_error = jnp.sum(jnp.power(analytic_loc_n - svi_result.params["loc_q"], 2.0))
    log_sig_error = jnp.sum(
        jnp.power(analytic_log_sig_n - svi_result.params["log_sig_q"], 2.0)
    )

    assert_allclose(loc_error, 0, atol=0.05)
    assert_allclose(log_sig_error, 0, atol=0.05)
def test_tracegraph_gamma_exponential():
    """Exponential likelihood under a conjugate Gamma prior.

    A fake non-reparameterized Gamma guide forces TraceGraph_ELBO onto its
    score-function estimator; SVI should still land on the exact
    Gamma(alpha_n, beta_n) posterior.
    """
    alpha0, beta0 = 1.0, 1.0  # gamma prior hyperparameters
    n_data = 2
    data = jnp.array([3.0, 2.0])  # two observations
    # conjugate update: Gamma(alpha0 + n, beta0 + sum of observations)
    alpha_n = alpha0 + n_data
    beta_n = beta0 + data.sum()
    log_alpha_n = jnp.log(alpha_n)
    log_beta_n = jnp.log(beta_n)

    class FakeGamma(dist.Gamma):
        # masquerade as non-reparameterizable
        reparametrized_params = []

    def model():
        lambda_latent = numpyro.sample("lambda_latent", FakeGamma(alpha0, beta0))
        with numpyro.plate("data", len(data)):
            numpyro.sample("obs", dist.Exponential(lambda_latent), obs=data)
        return lambda_latent

    def guide():
        # work in log space so the positivity constraint holds for free
        alpha_q_log = numpyro.param("alpha_q_log", log_alpha_n + 0.17)
        beta_q_log = numpyro.param("beta_q_log", log_beta_n - 0.143)
        alpha_q = jnp.exp(alpha_q_log)
        beta_q = jnp.exp(beta_q_log)
        numpyro.sample("lambda_latent", FakeGamma(alpha_q, beta_q))
        # replay the model's plate so the traces line up
        with numpyro.plate("data", len(data)):
            pass

    optimizer = optim.Adam(step_size=0.0007, b1=0.95, b2=0.999)
    svi = SVI(model, guide, optimizer, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), 8000)

    alpha_error = jnp.sum(
        jnp.power(log_alpha_n - svi_result.params["alpha_q_log"], 2.0)
    )
    beta_error = jnp.sum(
        jnp.power(log_beta_n - svi_result.params["beta_q_log"], 2.0)
    )
    assert_allclose(alpha_error, 0, atol=0.04)
    assert_allclose(beta_error, 0, atol=0.04)
def test_tracegraph_beta_bernoulli():
    """Bernoulli observations under a conjugate Beta prior.

    The guide's Beta is flagged as non-reparameterizable so TraceGraph_ELBO
    must use score-function gradients; the fitted log-parameters should
    match the analytic Beta(alpha_n, beta_n) posterior.
    """
    alpha0, beta0 = 1.0, 1.0  # beta prior hyperparameters
    data = jnp.array([0.0, 1.0, 1.0, 1.0])
    n_data = float(len(data))
    data_sum = data.sum()
    # conjugate update: successes add to alpha, failures add to beta
    alpha_n = alpha0 + data_sum
    beta_n = beta0 - data_sum + n_data
    log_alpha_n = jnp.log(alpha_n)
    log_beta_n = jnp.log(beta_n)

    class FakeBeta(dist.Beta):
        # masquerade as non-reparameterizable
        reparametrized_params = []

    def model():
        p_latent = numpyro.sample("p_latent", FakeBeta(alpha0, beta0))
        with numpyro.plate("data", len(data)):
            numpyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    def guide():
        # optimize in log space; initial values are slightly perturbed
        alpha_q_log = numpyro.param("alpha_q_log", log_alpha_n + 0.17)
        beta_q_log = numpyro.param("beta_q_log", log_beta_n - 0.143)
        alpha_q = jnp.exp(alpha_q_log)
        beta_q = jnp.exp(beta_q_log)
        p_latent = numpyro.sample("p_latent", FakeBeta(alpha_q, beta_q))
        # mirror the model's plate structure
        with numpyro.plate("data", len(data)):
            pass
        return p_latent

    optimizer = optim.Adam(step_size=0.0007, b1=0.95, b2=0.999)
    svi = SVI(model, guide, optimizer, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), 3000)

    alpha_error = jnp.sum(
        jnp.power(log_alpha_n - svi_result.params["alpha_q_log"], 2.0)
    )
    beta_error = jnp.sum(
        jnp.power(log_beta_n - svi_result.params["beta_q_log"], 2.0)
    )
    assert_allclose(alpha_error, 0, atol=0.03)
    assert_allclose(beta_error, 0, atol=0.04)
def test_svi_discrete_latent():
    """Discrete-latent support flags and warnings.

    Continuous-inference-only losses must report ``can_infer_discrete`` as
    False and emit a UserWarning when run on a model with a discrete latent;
    TraceGraph_ELBO advertises discrete support.
    """
    cont_inf_only_cls = [RenyiELBO(), Trace_ELBO(), TraceMeanField_ELBO()]
    mixed_inf_cls = [TraceGraph_ELBO()]

    # generator expressions: no need to materialize throwaway lists
    assert not any(c.can_infer_discrete for c in cont_inf_only_cls)
    assert all(c.can_infer_discrete for c in mixed_inf_cls)

    def model():
        numpyro.sample("x", dist.Bernoulli(0.5))

    def guide():
        probs = numpyro.param("probs", 0.2)
        numpyro.sample("x", dist.Bernoulli(probs))

    for elbo in cont_inf_only_cls:
        svi = SVI(model, guide, optim.Adam(1), elbo)
        s_name = type(elbo).__name__
        # this message doubles as the pytest.warns match pattern
        w_msg = f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
        with pytest.warns(UserWarning, match=w_msg):
            svi.run(random.PRNGKey(0), 10)
def test_tracegraph_gaussian_chain(num_latents, num_steps, step_size, atol, difficulty):
    """Chain of N Gaussian latents feeding a Gaussian likelihood.

    Computes the analytic posterior of the chain by forward/backward
    message passing, then checks that SVI with TraceGraph_ELBO recovers it
    when a random subset of guide nodes is reparameterized and the rest use
    score-function gradients.

    NOTE(review): the arguments (num_latents, num_steps, step_size, atol,
    difficulty) are presumably supplied by a ``pytest.mark.parametrize``
    decorator outside this view — confirm against the full file.
    """
    loc0 = 0.2
    data = jnp.array([-0.1, 0.03, 0.2, 0.1])
    n_data = data.shape[0]
    sum_data = data.sum()
    N = num_latents
    # precisions along the chain; lambdas[N] is the observation precision
    lambdas = [1.5 * (k + 1) / N for k in range(N + 1)]
    lambdas = list(map(lambda x: jnp.array([x]), lambdas))
    # forward pass: accumulated precision of the chain up to node k
    lambda_tilde_posts = [lambdas[0]]
    for k in range(1, N):
        lambda_tilde_k = (lambdas[k] * lambda_tilde_posts[k - 1]) / (
            lambdas[k] + lambda_tilde_posts[k - 1]
        )
        lambda_tilde_posts.append(lambda_tilde_k)
    lambda_posts = [
        None
    ]  # this is never used (just a way of shifting the indexing by 1)
    for k in range(1, N):
        lambda_k = lambdas[k] + lambda_tilde_posts[k - 1]
        lambda_posts.append(lambda_k)
    # node N additionally receives n_data observations
    lambda_N_post = (n_data * lambdas[N]) + lambda_tilde_posts[N - 1]
    lambda_posts.append(lambda_N_post)
    # analytic regression coefficients and offsets for the guide's
    # autoregressive mean functions (index 0 unused, as above)
    target_kappas = [None]
    target_kappas.extend([lambdas[k] / lambda_posts[k] for k in range(1, N)])
    target_mus = [None]
    target_mus.extend(
        [loc0 * lambda_tilde_posts[k - 1] / lambda_posts[k] for k in range(1, N)]
    )
    target_loc_N = (
        sum_data * lambdas[N] / lambda_N_post
        + loc0 * lambda_tilde_posts[N - 1] / lambda_N_post
    )
    target_mus.append(target_loc_N)
    # rejection-sample a mask so that some — but not too many — nodes are
    # reparameterized (sum strictly between 0.5 and 0.4 * N)
    np.random.seed(0)
    while True:
        mask = np.random.binomial(1, 0.3, (N,))
        if mask.sum() < 0.4 * N and mask.sum() > 0.5:
            which_nodes_reparam = mask
            break

    class FakeNormal(dist.Normal):
        # pretend Normal is non-reparameterizable so masked-off nodes use
        # score-function gradients
        reparametrized_params = []

    def model(difficulty=0.0):
        # Markov chain: each latent is centered on the previous one
        next_mean = loc0
        for k in range(1, N + 1):
            latent_dist = dist.Normal(next_mean, jnp.power(lambdas[k - 1], -0.5))
            loc_latent = numpyro.sample("loc_latent_{}".format(k), latent_dist)
            next_mean = loc_latent
        loc_N = next_mean
        with numpyro.plate("data", data.shape[0]):
            numpyro.sample(
                "obs", dist.Normal(loc_N, jnp.power(lambdas[N], -0.5)), obs=data
            )
        return loc_N

    def guide(difficulty=0.0):
        # traverse the chain backwards; each node's mean regresses on the
        # previously sampled (downstream) node via kappa_q
        previous_sample = None
        for k in reversed(range(1, N + 1)):
            # init functions perturb the analytic targets; `difficulty`
            # scales how far from the answer optimization starts
            loc_q = numpyro.param(
                f"loc_q_{k}",
                lambda key: target_mus[k]
                + difficulty * (0.1 * random.normal(key) - 0.53),
            )
            log_sig_q = numpyro.param(
                f"log_sig_q_{k}",
                lambda key: -0.5 * jnp.log(lambda_posts[k])
                + difficulty * (0.1 * random.normal(key) - 0.53),
            )
            sig_q = jnp.exp(log_sig_q)
            kappa_q = None
            if k != N:
                kappa_q = numpyro.param(
                    "kappa_q_%d" % k,
                    lambda key: target_kappas[k]
                    + difficulty * (0.1 * random.normal(key) - 0.53),
                )
            # node N has no downstream sample to condition on
            mean_function = loc_q if k == N else kappa_q * previous_sample + loc_q
            node_flagged = True if which_nodes_reparam[k - 1] == 1.0 else False
            Normal = dist.Normal if node_flagged else FakeNormal
            loc_latent = numpyro.sample(f"loc_latent_{k}", Normal(mean_function, sig_q))
            previous_sample = loc_latent
        return previous_sample

    adam = optim.Adam(step_size=step_size, b1=0.95, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), num_steps, difficulty=difficulty)

    # collect per-node squared errors against the analytic targets
    kappa_errors, log_sig_errors, loc_errors = [], [], []
    for k in range(1, N + 1):
        if k != N:
            kappa_error = jnp.sum(
                jnp.power(svi_result.params[f"kappa_q_{k}"] - target_kappas[k], 2)
            )
            kappa_errors.append(kappa_error)
        loc_errors.append(
            jnp.sum(jnp.power(svi_result.params[f"loc_q_{k}"] - target_mus[k], 2))
        )
        log_sig_error = jnp.sum(
            jnp.power(
                svi_result.params[f"log_sig_q_{k}"] + 0.5 * jnp.log(lambda_posts[k]), 2
            )
        )
        log_sig_errors.append(log_sig_error)

    max_errors = (np.max(loc_errors), np.max(log_sig_errors), np.max(kappa_errors))

    for i in range(3):
        assert_allclose(max_errors[i], 0, atol=atol)