Example #1
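All five examples are NumPyro SVI tests and are shown without their imports. A minimal header that makes them runnable as standalone scripts (the module aliases are my assumption; the original listing does not show them):

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import random
from numpy.testing import assert_allclose

import numpyro
import numpyro.distributions as dist
from numpyro import optim
from numpyro.infer import (RenyiELBO, SVI, Trace_ELBO, TraceGraph_ELBO,
                           TraceMeanField_ELBO)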
def test_tracegraph_normal_normal():
    # normal-normal; known covariance
    lam0 = jnp.array([0.1, 0.1])  # precision of prior
    loc0 = jnp.array([0.0, 0.5])  # prior mean
    # known precision of observation noise
    lam = jnp.array([6.0, 4.0])
    data = [
        jnp.array([-0.1, 0.3]),
        jnp.array([0.0, 0.4]),
        jnp.array([0.2, 0.5]),
        jnp.array([0.1, 0.7]),
    ]
    n_data = len(data)
    sum_data = data[0] + data[1] + data[2] + data[3]
    analytic_lam_n = lam0 + n_data * lam
    analytic_log_sig_n = -0.5 * jnp.log(analytic_lam_n)
    analytic_loc_n = sum_data * (lam / analytic_lam_n) + loc0 * (
        lam0 / analytic_lam_n)

    class FakeNormal(dist.Normal):
        # pretend Normal is non-reparameterizable so that TraceGraph_ELBO
        # takes its score-function (REINFORCE) gradient path
        reparametrized_params = []

    def model():
        with numpyro.plate("plate", 2):
            loc_latent = numpyro.sample(
                "loc_latent", FakeNormal(loc0, jnp.power(lam0, -0.5)))
            for i, x in enumerate(data):
                numpyro.sample(
                    "obs_{}".format(i),
                    dist.Normal(loc_latent, jnp.power(lam, -0.5)),
                    obs=x,
                )
        return loc_latent

    def guide():
        loc_q = numpyro.param("loc_q",
                              analytic_loc_n + jnp.array([0.334, 0.334]))
        log_sig_q = numpyro.param(
            "log_sig_q", analytic_log_sig_n + jnp.array([-0.29, -0.29]))
        sig_q = jnp.exp(log_sig_q)
        with numpyro.plate("plate", 2):
            loc_latent = numpyro.sample("loc_latent", FakeNormal(loc_q, sig_q))
        return loc_latent

    adam = optim.Adam(step_size=0.0015, b1=0.97, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), 5000)

    loc_error = jnp.sum(
        jnp.power(analytic_loc_n - svi_result.params["loc_q"], 2.0))
    log_sig_error = jnp.sum(
        jnp.power(analytic_log_sig_n - svi_result.params["log_sig_q"], 2.0))

    assert_allclose(loc_error, 0, atol=0.05)
    assert_allclose(log_sig_error, 0, atol=0.05)
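The analytic_* quantities above come from the standard normal-normal conjugate update: the posterior precision is lam0 + n_data * lam, and the posterior mean is the precision-weighted average of the prior mean and the data. A standalone arithmetic check with the numbers used above (illustrative, not part of the test):

lam0 = jnp.array([0.1, 0.1])
lam = jnp.array([6.0, 4.0])
loc0 = jnp.array([0.0, 0.5])
sum_data = jnp.array([0.2, 1.9])  # componentwise sum of the four observations
lam_n = lam0 + 4 * lam  # posterior precision
loc_n = (sum_data * lam + loc0 * lam0) / lam_n  # posterior mean
print(lam_n, loc_n)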
Example #2
def test_tracegraph_gamma_exponential():
    # exponential-gamma model
    # gamma prior hyperparameter
    alpha0 = 1.0
    # gamma prior hyperparameter
    beta0 = 1.0
    n_data = 2
    data = jnp.array([3.0, 2.0])  # two observations
    alpha_n = alpha0 + n_data  # posterior alpha
    beta_n = beta0 + data.sum()  # posterior beta
    log_alpha_n = jnp.log(alpha_n)
    log_beta_n = jnp.log(beta_n)

    class FakeGamma(dist.Gamma):
        # as in Example #1: disable reparameterization to exercise the
        # score-function estimator
        reparametrized_params = []

    def model():
        lambda_latent = numpyro.sample("lambda_latent",
                                       FakeGamma(alpha0, beta0))
        with numpyro.plate("data", len(data)):
            numpyro.sample("obs", dist.Exponential(lambda_latent), obs=data)
        return lambda_latent

    def guide():
        alpha_q_log = numpyro.param("alpha_q_log", log_alpha_n + 0.17)
        beta_q_log = numpyro.param("beta_q_log", log_beta_n - 0.143)
        alpha_q, beta_q = jnp.exp(alpha_q_log), jnp.exp(beta_q_log)
        numpyro.sample("lambda_latent", FakeGamma(alpha_q, beta_q))
        with numpyro.plate("data", len(data)):
            pass  # empty plate mirrors the model's plate; it has no effect here

    adam = optim.Adam(step_size=0.0007, b1=0.95, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), 8000)

    alpha_error = jnp.sum(
        jnp.power(log_alpha_n - svi_result.params["alpha_q_log"], 2.0))
    beta_error = jnp.sum(
        jnp.power(log_beta_n - svi_result.params["beta_q_log"], 2.0))

    assert_allclose(alpha_error, 0, atol=0.04)
    assert_allclose(beta_error, 0, atol=0.04)
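The guide parameterizes the Gamma guide in log space so that Adam's updates stay unconstrained. An alternative sketch using numpyro.param's constraint argument, under which NumPyro optimizes an unconstrained value and maps it through a bijector (illustrative; the init offsets are arbitrary and this is not what the test does):

    from numpyro.distributions import constraints

    def guide_constrained():
        alpha_q = numpyro.param("alpha_q", alpha_n + 0.5,
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", beta_n - 0.5,
                               constraint=constraints.positive)
        numpyro.sample("lambda_latent", FakeGamma(alpha_q, beta_q))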
Example #3
def test_tracegraph_beta_bernoulli():
    # bernoulli-beta model
    # beta prior hyperparameter
    alpha0 = 1.0
    beta0 = 1.0  # beta prior hyperparameter
    data = jnp.array([0.0, 1.0, 1.0, 1.0])
    n_data = float(len(data))
    data_sum = data.sum()
    alpha_n = alpha0 + data_sum  # posterior alpha
    beta_n = beta0 - data_sum + n_data  # posterior beta
    log_alpha_n = jnp.log(alpha_n)
    log_beta_n = jnp.log(beta_n)

    class FakeBeta(dist.Beta):
        reparametrized_params = []

    def model():
        p_latent = numpyro.sample("p_latent", FakeBeta(alpha0, beta0))
        with numpyro.plate("data", len(data)):
            numpyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    def guide():
        alpha_q_log = numpyro.param("alpha_q_log", log_alpha_n + 0.17)
        beta_q_log = numpyro.param("beta_q_log", log_beta_n - 0.143)
        alpha_q, beta_q = jnp.exp(alpha_q_log), jnp.exp(beta_q_log)
        p_latent = numpyro.sample("p_latent", FakeBeta(alpha_q, beta_q))
        with numpyro.plate("data", len(data)):
            pass
        return p_latent

    adam = optim.Adam(step_size=0.0007, b1=0.95, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), 3000)

    alpha_error = jnp.sum(
        jnp.power(log_alpha_n - svi_result.params["alpha_q_log"], 2.0))
    beta_error = jnp.sum(
        jnp.power(log_beta_n - svi_result.params["beta_q_log"], 2.0))

    assert_allclose(alpha_error, 0, atol=0.03)
    assert_allclose(beta_error, 0, atol=0.04)
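For this data the exact posterior is Beta(alpha0 + data_sum, beta0 + n_data - data_sum) = Beta(4, 2), whose mean is 2/3. A quick comparison one could append after svi.run (illustrative, not part of the test):

    alpha_q = jnp.exp(svi_result.params["alpha_q_log"])
    beta_q = jnp.exp(svi_result.params["beta_q_log"])
    # posterior mean of the fitted Beta guide; should be close to 4/6
    print(alpha_q / (alpha_q + beta_q))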
Example #4
def test_svi_discrete_latent():
    cont_inf_only_cls = [RenyiELBO(), Trace_ELBO(), TraceMeanField_ELBO()]
    mixed_inf_cls = [TraceGraph_ELBO()]

    assert not any(c.can_infer_discrete for c in cont_inf_only_cls)
    assert all(c.can_infer_discrete for c in mixed_inf_cls)

    def model():
        numpyro.sample("x", dist.Bernoulli(0.5))

    def guide():
        probs = numpyro.param("probs", 0.2)
        numpyro.sample("x", dist.Bernoulli(probs))

    for elbo in cont_inf_only_cls:
        svi = SVI(model, guide, optim.Adam(1), elbo)
        s_name = type(elbo).__name__
        w_msg = f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
        with pytest.warns(UserWarning, match=w_msg):
            svi.run(random.PRNGKey(0), 10)
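The complementary check for mixed_inf_cls is sketched below: since TraceGraph_ELBO advertises can_infer_discrete, the same model and guide should run without the warning (an assumption based on the flag asserted above, not an assertion from the original test):

    for elbo in mixed_inf_cls:
        svi = SVI(model, guide, optim.Adam(1), elbo)
        # TraceGraph_ELBO handles discrete latents via score-function
        # gradients, so no UserWarning is expected here
        svi.run(random.PRNGKey(0), 10)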
Example #5
def test_tracegraph_gaussian_chain(num_latents, num_steps, step_size, atol,
                                   difficulty):
    loc0 = 0.2
    data = jnp.array([-0.1, 0.03, 0.2, 0.1])
    n_data = data.shape[0]
    sum_data = data.sum()
    N = num_latents
    lambdas = [jnp.array([1.5 * (k + 1) / N]) for k in range(N + 1)]
    lambda_tilde_posts = [lambdas[0]]
    for k in range(1, N):
        lambda_tilde_k = (lambdas[k] * lambda_tilde_posts[k - 1]) / (
            lambdas[k] + lambda_tilde_posts[k - 1])
        lambda_tilde_posts.append(lambda_tilde_k)
    lambda_posts = [None]  # entry 0 is never used; it just shifts the indexing by 1
    for k in range(1, N):
        lambda_k = lambdas[k] + lambda_tilde_posts[k - 1]
        lambda_posts.append(lambda_k)
    lambda_N_post = (n_data * lambdas[N]) + lambda_tilde_posts[N - 1]
    lambda_posts.append(lambda_N_post)
    target_kappas = [None]
    target_kappas.extend([lambdas[k] / lambda_posts[k] for k in range(1, N)])
    target_mus = [None]
    target_mus.extend([
        loc0 * lambda_tilde_posts[k - 1] / lambda_posts[k]
        for k in range(1, N)
    ])
    target_loc_N = (sum_data * lambdas[N] / lambda_N_post +
                    loc0 * lambda_tilde_posts[N - 1] / lambda_N_post)
    target_mus.append(target_loc_N)
    np.random.seed(0)
    # flag at least one but fewer than 0.4 * N nodes for reparameterized sampling
    while True:
        mask = np.random.binomial(1, 0.3, (N,))
        if 0.5 < mask.sum() < 0.4 * N:
            which_nodes_reparam = mask
            break

    class FakeNormal(dist.Normal):
        reparametrized_params = []

    def model(difficulty=0.0):
        next_mean = loc0
        for k in range(1, N + 1):
            latent_dist = dist.Normal(next_mean,
                                      jnp.power(lambdas[k - 1], -0.5))
            loc_latent = numpyro.sample("loc_latent_{}".format(k), latent_dist)
            next_mean = loc_latent

        loc_N = next_mean
        with numpyro.plate("data", data.shape[0]):
            numpyro.sample("obs",
                           dist.Normal(loc_N, jnp.power(lambdas[N], -0.5)),
                           obs=data)
        return loc_N

    def guide(difficulty=0.0):
        previous_sample = None
        for k in reversed(range(1, N + 1)):
            loc_q = numpyro.param(
                f"loc_q_{k}",
                lambda key: target_mus[k] + difficulty *
                (0.1 * random.normal(key) - 0.53),
            )
            log_sig_q = numpyro.param(
                f"log_sig_q_{k}",
                lambda key: -0.5 * jnp.log(lambda_posts[k]) + difficulty *
                (0.1 * random.normal(key) - 0.53),
            )
            sig_q = jnp.exp(log_sig_q)
            kappa_q = None
            if k != N:
                kappa_q = numpyro.param(
                    f"kappa_q_{k}",
                    lambda key: target_kappas[k] + difficulty *
                    (0.1 * random.normal(key) - 0.53),
                )
            mean_function = loc_q if k == N else kappa_q * previous_sample + loc_q
            node_flagged = which_nodes_reparam[k - 1] == 1.0
            Normal = dist.Normal if node_flagged else FakeNormal
            loc_latent = numpyro.sample(f"loc_latent_{k}",
                                        Normal(mean_function, sig_q))
            previous_sample = loc_latent
        return previous_sample

    adam = optim.Adam(step_size=step_size, b1=0.95, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0),
                         num_steps,
                         difficulty=difficulty)

    kappa_errors, log_sig_errors, loc_errors = [], [], []
    for k in range(1, N + 1):
        if k != N:
            kappa_error = jnp.sum(
                jnp.power(svi_result.params[f"kappa_q_{k}"] - target_kappas[k],
                          2))
            kappa_errors.append(kappa_error)

        loc_errors.append(
            jnp.sum(
                jnp.power(svi_result.params[f"loc_q_{k}"] - target_mus[k], 2)))
        log_sig_error = jnp.sum(
            jnp.power(
                svi_result.params[f"log_sig_q_{k}"] +
                0.5 * jnp.log(lambda_posts[k]), 2))
        log_sig_errors.append(log_sig_error)

    max_errors = (np.max(loc_errors), np.max(log_sig_errors),
                  np.max(kappa_errors))

    for i in range(3):
        assert_allclose(max_errors[i], 0, atol=atol)
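This last test is parametrized; the snippet omits the pytest decorator that supplies num_latents, num_steps, step_size, atol, and difficulty. A hypothetical direct invocation with illustrative values (the actual grid lives in the test suite):

test_tracegraph_gaussian_chain(
    num_latents=3, num_steps=5000, step_size=0.003, atol=0.1, difficulty=0.6
)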