Example 1
def test_mutable_state(stable_update, num_particles, elbo):
    def model():
        x = numpyro.sample("x", dist.Normal(-1, 1))
        numpyro_mutable("x1p", x + 1)

    def guide():
        loc = numpyro.param("loc", 0.0)
        p = numpyro_mutable("loc1p", {"value": None})
        # we can modify the content of `p` if it is a dict
        p["value"] = loc + 2
        numpyro.sample("x", dist.Normal(loc, 0.1))

    svi = SVI(model, guide, optim.Adam(0.1), elbo(num_particles=num_particles))
    if num_particles > 1:
        with pytest.raises(ValueError, match="mutable state"):
            svi_result = svi.run(random.PRNGKey(0),
                                 1000,
                                 stable_update=stable_update)
        return
    svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_update)
    params = svi_result.params
    mutable_state = svi_result.state.mutable_state
    assert set(mutable_state) == {"x1p", "loc1p"}
    assert_allclose(mutable_state["loc1p"]["value"],
                    params["loc"] + 2,
                    atol=0.1)
    # here, the initial loc has value 0., hence x1p will have init value near 1
    # it won't be updated during SVI run because it is not a mutable state
    assert_allclose(mutable_state["x1p"], 1.0, atol=0.2)
Example 2
def test_run_with_small_num_steps(num_steps):
    def model():
        pass

    def guide():
        pass

    svi = SVI(model, guide, optim.Adam(1), Trace_ELBO())
    svi.run(random.PRNGKey(0), num_steps)
Example 3
def test_plate_inconsistent(size, dim):
    def model():
        with numpyro.plate("a", 10, dim=-1):
            numpyro.sample("x", dist.Normal(0, 1))
        with numpyro.plate("a", size, dim=dim):
            numpyro.sample("y", dist.Normal(0, 1))

    guide = AutoDelta(model)
    svi = SVI(model, guide, numpyro.optim.Adam(step_size=0.1), Trace_ELBO())
    with pytest.raises(AssertionError, match="has inconsistent dim or size"):
        svi.run(random.PRNGKey(0), 10)
Example 4
def test_svi_discrete_latent():
    def model():
        numpyro.sample("x", dist.Bernoulli(0.5))

    def guide():
        probs = numpyro.param("probs", 0.2)
        numpyro.sample("x", dist.Bernoulli(probs))

    svi = SVI(model, guide, optim.Adam(1), Trace_ELBO())
    with pytest.warns(UserWarning,
                      match="SVI does not support models with discrete"):
        svi.run(random.PRNGKey(0), 10)
Example 5
    def find_map(
        self,
        num_steps: int = 10000,
        handlers: Optional[list] = None,
        reparam: Union[str, hdl.reparam] = "auto",
        svi_kwargs: Optional[dict] = None,
    ):
        """EXPERIMENTAL: find MAP.

        Args:
            num_steps (int): [description]. Defaults to 10000.
            handlers (list, optional): [description]. Defaults to None.
            reparam (str, or numpyro.handlers.reparam): [description]. Defaults to 'auto'.
            svi_kwargs (dict): [description]. Defaults to {}.
        """
        model = self._add_handlers_to_model(handlers=handlers, reparam=reparam)

        guide = numpyro.infer.autoguide.AutoDelta(model)

        svi_kwargs = dict(svi_kwargs or {})  # copy so the pops below don't mutate the caller's dict
        optim = svi_kwargs.pop("optim", numpyro.optim.Minimize())
        loss = svi_kwargs.pop("loss", numpyro.infer.Trace_ELBO())
        map_svi = SVI(model, guide, optim, loss=loss, **svi_kwargs)

        rng_key, self._rng_key = random.split(self._rng_key)
        map_result = map_svi.run(rng_key,
                                 num_steps,
                                 self.n,
                                 nu=self.nu,
                                 nu_err=self.nu_err)

        self._map_loss = map_result.losses
        self._map_guide = map_svi.guide
        self._map_params = map_result.params
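
The AutoDelta-plus-Minimize pattern above is the standard way to get a MAP point estimate with SVI. A minimal self-contained sketch of the same pattern; the toy model and data here are illustrative, not from the source:

import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta

def toy_model(y=None):
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("y", dist.Normal(mu, 1.0), obs=y)

guide = AutoDelta(toy_model)
svi = SVI(toy_model, guide, numpyro.optim.Minimize(), Trace_ELBO())
# Minimize performs a full BFGS solve inside a single update step,
# so one step is enough here.
svi_result = svi.run(random.PRNGKey(0), 1, y=2.0, progress_bar=False)
print(svi_result.params)  # contains "mu_auto_loc", the MAP estimate of mu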
Example 6
def fit_svi(model,
            n_draws=1000,
            autoguide=AutoLaplaceApproximation,
            loss=Trace_ELBO(),
            optim=optim.Adam(step_size=.00001),
            num_warmup=2000,
            use_gpu=False,
            num_chains=1,
            progress_bar=False,
            sampler=None,
            **kwargs):
    select_device(use_gpu, num_chains)
    guide = autoguide(model)
    svi = SVI(model=model, guide=guide, loss=loss, optim=optim, **kwargs)
    svi_result = svi.run(jax.random.PRNGKey(0),
                         num_steps=num_warmup,
                         stable_update=True,
                         progress_bar=progress_bar)
    # draw posterior samples from the fitted guide
    post = guide.sample_posterior(jax.random.PRNGKey(1),
                                  params=svi_result.params,
                                  sample_shape=(1, n_draws))
    # alternative via Predictive:
    # predictive = Predictive(guide, params=svi_result.params, num_samples=n_draws)
    # post = predictive(jax.random.PRNGKey(1), **kwargs)

    # older manual training loop, equivalent to svi.run above:
    # init_state = svi.init(jax.random.PRNGKey(0))
    # state, loss = lax.scan(lambda x, i: svi.update(x), init_state, jnp.zeros(num_warmup))
    # svi_params = svi.get_params(state)
    # post = guide.sample_posterior(jax.random.PRNGKey(1), svi_params, (1, n_draws))

    trace = az.from_dict(post)
    return trace, post
Example 7
def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
    svi_key, mcmc_key = random.split(hmcecs_key)

    # find reference parameters for the second-order Taylor expansion used to
    # estimate the likelihood (taylor_proxy)
    optimizer = numpyro.optim.Adam(step_size=1e-3)
    guide = autoguide.AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    svi_result = svi.run(svi_key, args.num_svi_steps, data, obs,
                         args.subsample_size)
    params, losses = svi_result.params, svi_result.losses
    ref_params = {"theta": params["theta_auto_loc"]}

    # the Taylor proxy estimates the log likelihood (ll) around ref_params as
    #   taylor_expansion(ll, theta_curr)
    #     + sum_{i in subsample} [ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr)]
    proxy = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=proxy)
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples)

    mcmc.run(mcmc_key, data, obs, args.subsample_size)
    mcmc.print_summary()
    return losses, mcmc.get_samples()
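
Spelling out the comment above: with reference point $\theta_0$ = ref_params, the proxy estimates the full-data log likelihood $\ell(\theta) = \sum_{i=1}^{N} \ell_i(\theta)$ as

$$\hat{\ell}(\theta) = \ell_{\mathrm{taylor}}(\theta) + \sum_{i \in \mathcal{S}} \bigl[\ell_i(\theta) - \ell_{i,\mathrm{taylor}}(\theta)\bigr],$$

with both expansions taken to second order around $\theta_0$; any subsample rescaling applied inside the implementation is omitted here.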
Example 8
    def fit(self, X, Y, rng_key, n_step):
        self.X_train = X

        # store moments of training y (to normalize)
        self.y_mean = jnp.mean(Y)
        self.y_std = jnp.std(Y)

        # normalize y
        Y = (Y - self.y_mean) / self.y_std

        # setup optimizer and SVI
        optim = numpyro.optim.Adam(step_size=0.005, b1=0.5)

        svi = SVI(
            model,
            guide=AutoDelta(model),
            optim=optim,
            loss=Trace_ELBO(),
            X=X,
            Y=Y,
        )

        params, _ = svi.run(rng_key, n_step)

        # get kernel parameters from guide with proper names
        self.kernel_params = svi.guide.median(params)

        # store cholesky factor of prior covariance
        self.L = linalg.cho_factor(self.kernel(X, X, **self.kernel_params))

        # store inverted prior covariance multiplied by y
        self.alpha = linalg.cho_solve(self.L, Y)

        return self.kernel_params
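
A matching predict method would reuse the Cholesky factor and alpha stored by fit; a minimal sketch under the same attribute names (this method is assumed, not shown in the source):

    def predict(self, X_test):
        # cross-covariance between test inputs and training inputs
        K_star = self.kernel(X_test, self.X_train, **self.kernel_params)
        # GP posterior mean in normalized space: K_* @ K^{-1} y
        mean = K_star @ self.alpha
        # undo the y-normalization applied in fit
        return mean * self.y_std + self.y_mean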
Example 9
def test_autoguide_deterministic(auto_class):
    def model(y=None):
        n = y.size if y is not None else 1

        mu = numpyro.sample("mu", dist.Normal(0, 5))
        sigma = numpyro.param("sigma", 1, constraint=constraints.positive)

        y = numpyro.sample("y", dist.Normal(mu, sigma).expand((n,)), obs=y)
        numpyro.deterministic("z", (y - mu) / sigma)

    mu, sigma = 2, 3
    y = mu + sigma * random.normal(random.PRNGKey(0), shape=(300,))
    y_train = y[:200]
    y_test = y[200:]

    guide = auto_class(model)
    optimiser = numpyro.optim.Adam(step_size=0.01)
    svi = SVI(model, guide, optimiser, Trace_ELBO())

    params, losses = svi.run(random.PRNGKey(0), num_steps=500, y=y_train)
    posterior_samples = guide.sample_posterior(
        random.PRNGKey(0), params, sample_shape=(1000,)
    )

    predictive = Predictive(model, posterior_samples, params=params)
    predictive_samples = predictive(random.PRNGKey(0), y_test)

    assert predictive_samples["y"].shape == (1000, 100)
    assert predictive_samples["z"].shape == (1000, 100)
    assert_allclose(
        (predictive_samples["y"] - posterior_samples["mu"][..., None])
        / params["sigma"],
        predictive_samples["z"],
        atol=0.05,
    )
Example 10
def test_run(progress_bar):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param(
            "alpha_q", lambda key: random.normal(key), constraint=constraints.positive
        )
        beta_q = numpyro.param(
            "beta_q",
            lambda key: random.exponential(key),
            constraint=constraints.positive,
        )
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.05), Trace_ELBO())
    params, losses = svi.run(random.PRNGKey(1), 1000, data, progress_bar=progress_bar)
    assert losses.shape == (1000,)
    assert_allclose(
        params["alpha_q"] / (params["alpha_q"] + params["beta_q"]),
        0.8,
        atol=0.05,
        rtol=0.05,
    )
Example 11
def run_svi(rng_key, X, Y, guide_family="AutoDiagonalNormal", K=8):
    assert guide_family in ["AutoDiagonalNormal", "AutoDAIS"]

    if guide_family == "AutoDAIS":
        guide = autoguide.AutoDAIS(model, K=K, eta_init=0.02, eta_max=0.5)
        step_size = 5e-4
    elif guide_family == "AutoDiagonalNormal":
        guide = autoguide.AutoDiagonalNormal(model)
        step_size = 3e-3

    optimizer = numpyro.optim.Adam(step_size=step_size)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    svi_result = svi.run(rng_key, args.num_svi_steps, X, Y)
    params = svi_result.params

    final_elbo = -Trace_ELBO(num_particles=1000).loss(rng_key, params, model,
                                                      guide, X, Y)

    guide_name = guide_family
    if guide_family == "AutoDAIS":
        guide_name += "-{}".format(K)

    print("[{}] final elbo: {:.2f}".format(guide_name, final_elbo))

    return guide.sample_posterior(random.PRNGKey(1),
                                  params,
                                  sample_shape=(args.num_samples, ))
Example 12
def test_predictive_with_guide():
    data = jnp.array([1] * 8 + [0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        with numpyro.plate("plate", 10):
            numpyro.deterministic("beta_sq", f**2)
            numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q",
                                1.0,
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(1), 3000, data)
    params = svi_result.params
    predictive = Predictive(model,
                            guide=guide,
                            params=params,
                            num_samples=1000)(random.PRNGKey(2), data=None)
    assert predictive["beta_sq"].shape == (1000, )
    obs_pred = predictive["obs"].astype(np.float32)
    assert_allclose(jnp.mean(obs_pred), 0.8, atol=0.05)
Example 13
def test_tracegraph_normal_normal():
    # normal-normal; known covariance
    lam0 = jnp.array([0.1, 0.1])  # precision of prior
    loc0 = jnp.array([0.0, 0.5])  # prior mean
    # known precision of observation noise
    lam = jnp.array([6.0, 4.0])
    data = []
    data.append(jnp.array([-0.1, 0.3]))
    data.append(jnp.array([0.0, 0.4]))
    data.append(jnp.array([0.2, 0.5]))
    data.append(jnp.array([0.1, 0.7]))
    n_data = len(data)
    sum_data = data[0] + data[1] + data[2] + data[3]
    analytic_lam_n = lam0 + n_data * lam
    analytic_log_sig_n = -0.5 * jnp.log(analytic_lam_n)
    analytic_loc_n = sum_data * (lam / analytic_lam_n) + loc0 * (
        lam0 / analytic_lam_n)

    class FakeNormal(dist.Normal):
        reparametrized_params = []

    def model():
        with numpyro.plate("plate", 2):
            loc_latent = numpyro.sample(
                "loc_latent", FakeNormal(loc0, jnp.power(lam0, -0.5)))
            for i, x in enumerate(data):
                numpyro.sample(
                    "obs_{}".format(i),
                    dist.Normal(loc_latent, jnp.power(lam, -0.5)),
                    obs=x,
                )
        return loc_latent

    def guide():
        loc_q = numpyro.param("loc_q",
                              analytic_loc_n + jnp.array([0.334, 0.334]))
        log_sig_q = numpyro.param(
            "log_sig_q", analytic_log_sig_n + jnp.array([-0.29, -0.29]))
        sig_q = jnp.exp(log_sig_q)
        with numpyro.plate("plate", 2):
            loc_latent = numpyro.sample("loc_latent", FakeNormal(loc_q, sig_q))
        return loc_latent

    adam = optim.Adam(step_size=0.0015, b1=0.97, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), 5000)

    loc_error = jnp.sum(
        jnp.power(analytic_loc_n - svi_result.params["loc_q"], 2.0))
    log_sig_error = jnp.sum(
        jnp.power(analytic_log_sig_n - svi_result.params["log_sig_q"], 2.0))

    assert_allclose(loc_error, 0, atol=0.05)
    assert_allclose(log_sig_error, 0, atol=0.05)
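
The analytic targets are the conjugate Normal-Normal update: for a prior $\mathcal{N}(\mu_0, \lambda_0^{-1})$ and $n$ observations with known precision $\lambda$,

$$\lambda_n = \lambda_0 + n\lambda, \qquad \mu_n = \frac{\lambda \sum_i x_i + \lambda_0 \mu_0}{\lambda_n}, \qquad \log\sigma_n = -\tfrac{1}{2}\log\lambda_n,$$

which is exactly what analytic_lam_n, analytic_loc_n, and analytic_log_sig_n compute above.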
Example 14
def test_svi_discrete_latent():
    cont_inf_only_cls = [RenyiELBO(), Trace_ELBO(), TraceMeanField_ELBO()]
    mixed_inf_cls = [TraceGraph_ELBO()]

    assert not any([c.can_infer_discrete for c in cont_inf_only_cls])
    assert all([c.can_infer_discrete for c in mixed_inf_cls])

    def model():
        numpyro.sample("x", dist.Bernoulli(0.5))

    def guide():
        probs = numpyro.param("probs", 0.2)
        numpyro.sample("x", dist.Bernoulli(probs))

    for elbo in cont_inf_only_cls:
        svi = SVI(model, guide, optim.Adam(1), elbo)
        s_name = type(elbo).__name__
        w_msg = f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
        with pytest.warns(UserWarning, match=w_msg):
            svi.run(random.PRNGKey(0), 10)
Example 15
def run_svi(model, guide_family, args, X, Y):
    if guide_family == "AutoDelta":
        guide = autoguide.AutoDelta(model)
    elif guide_family == "AutoDiagonalNormal":
        guide = autoguide.AutoDiagonalNormal(model)

    optimizer = numpyro.optim.Adam(0.001)
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    svi_results = svi.run(PRNGKey(1), args.maxiter, X=X, Y=Y)
    params = svi_results.params

    return params, guide
Example 16
def test_stable_run(stable_run):
    def model():
        var = numpyro.sample("var", dist.Exponential(1))
        numpyro.sample("obs", dist.Normal(0, jnp.sqrt(var)), obs=0.0)

    def guide():
        loc = numpyro.param("loc", 0.0)
        numpyro.sample("var", dist.Normal(loc, 10))

    svi = SVI(model, guide, optim.Adam(1), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_run)
    assert jnp.isfinite(svi_result.params["loc"]) == stable_run
Example 17
def test_subsample_model_with_deterministic():
    def model():
        x = numpyro.sample("x", dist.Normal(0, 1))
        numpyro.deterministic("x2", x * 2)
        with numpyro.plate("N", 10, subsample_size=5):
            numpyro.sample("obs", dist.Normal(x, 1), obs=jnp.ones(5))

    guide = AutoNormal(model)
    svi = SVI(model, guide, optim.Adam(1.0), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), 10)
    samples = guide.sample_posterior(random.PRNGKey(1), svi_result.params)
    assert "x2" in samples
Example 18
def test_laplace_approximation_custom_hessian():
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10))
        mu = a + b * x
        numpyro.sample("y", dist.Normal(mu, 1), obs=y)

    x = random.normal(random.PRNGKey(0), (100, ))
    y = 1 + 2 * x
    guide = AutoLaplaceApproximation(
        model, hessian_fn=lambda f, x: jacobian(jacobian(f))(x))
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), x=x, y=y)
    svi_result = svi.run(random.PRNGKey(0), 10000, progress_bar=False)
    guide.get_transform(svi_result.params)
Example 19
def test_pickle_autoguide(guide_class):
    x = np.random.poisson(1.0, size=(100,))

    guide = guide_class(poisson_regression)
    optim = numpyro.optim.Adam(1e-2)
    svi = SVI(poisson_regression, guide, optim, numpyro.infer.Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(1), 3, x, len(x))
    pickled_guide = pickle.loads(pickle.dumps(guide))

    predictive = Predictive(
        poisson_regression,
        guide=pickled_guide,
        params=svi_result.params,
        num_samples=1,
        return_sites=["param", "x"],
    )
    samples = predictive(random.PRNGKey(1), None, 1)
    assert set(samples.keys()) == {"param", "x"}
Example 20
def test_tracegraph_beta_bernoulli():
    # bernoulli-beta model
    # beta prior hyperparameter
    alpha0 = 1.0
    beta0 = 1.0  # beta prior hyperparameter
    data = jnp.array([0.0, 1.0, 1.0, 1.0])
    n_data = float(len(data))
    data_sum = data.sum()
    alpha_n = alpha0 + data_sum  # posterior alpha
    beta_n = beta0 - data_sum + n_data  # posterior beta
    log_alpha_n = jnp.log(alpha_n)
    log_beta_n = jnp.log(beta_n)

    class FakeBeta(dist.Beta):
        reparametrized_params = []

    def model():
        p_latent = numpyro.sample("p_latent", FakeBeta(alpha0, beta0))
        with numpyro.plate("data", len(data)):
            numpyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    def guide():
        alpha_q_log = numpyro.param("alpha_q_log", log_alpha_n + 0.17)
        beta_q_log = numpyro.param("beta_q_log", log_beta_n - 0.143)
        alpha_q, beta_q = jnp.exp(alpha_q_log), jnp.exp(beta_q_log)
        p_latent = numpyro.sample("p_latent", FakeBeta(alpha_q, beta_q))
        with numpyro.plate("data", len(data)):
            pass
        return p_latent

    adam = optim.Adam(step_size=0.0007, b1=0.95, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), 3000)

    alpha_error = jnp.sum(
        jnp.power(log_alpha_n - svi_result.params["alpha_q_log"], 2.0))
    beta_error = jnp.sum(
        jnp.power(log_beta_n - svi_result.params["beta_q_log"], 2.0))

    assert_allclose(alpha_error, 0, atol=0.03)
    assert_allclose(beta_error, 0, atol=0.04)
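
The targets here are the conjugate Beta-Bernoulli update, $\alpha_n = \alpha_0 + \sum_i x_i$ and $\beta_n = \beta_0 + (n - \sum_i x_i)$, matching alpha_n and beta_n above.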
Example 21
def test_tracegraph_gamma_exponential():
    # exponential-gamma model
    # gamma prior hyperparameter
    alpha0 = 1.0
    # gamma prior hyperparameter
    beta0 = 1.0
    n_data = 2
    data = jnp.array([3.0, 2.0])  # two observations
    alpha_n = alpha0 + n_data  # posterior alpha
    beta_n = beta0 + data.sum()  # posterior beta
    log_alpha_n = jnp.log(alpha_n)
    log_beta_n = jnp.log(beta_n)

    class FakeGamma(dist.Gamma):
        reparametrized_params = []

    def model():
        lambda_latent = numpyro.sample("lambda_latent",
                                       FakeGamma(alpha0, beta0))
        with numpyro.plate("data", len(data)):
            numpyro.sample("obs", dist.Exponential(lambda_latent), obs=data)
        return lambda_latent

    def guide():
        alpha_q_log = numpyro.param("alpha_q_log", log_alpha_n + 0.17)
        beta_q_log = numpyro.param("beta_q_log", log_beta_n - 0.143)
        alpha_q, beta_q = jnp.exp(alpha_q_log), jnp.exp(beta_q_log)
        numpyro.sample("lambda_latent", FakeGamma(alpha_q, beta_q))
        with numpyro.plate("data", len(data)):
            pass

    adam = optim.Adam(step_size=0.0007, b1=0.95, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), 8000)

    alpha_error = jnp.sum(
        jnp.power(log_alpha_n - svi_result.params["alpha_q_log"], 2.0))
    beta_error = jnp.sum(
        jnp.power(log_beta_n - svi_result.params["beta_q_log"], 2.0))

    assert_allclose(alpha_error, 0, atol=0.04)
    assert_allclose(beta_error, 0, atol=0.04)
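
Likewise, the Gamma-Exponential conjugate update gives $\alpha_n = \alpha_0 + n$ and $\beta_n = \beta_0 + \sum_i x_i$, as computed above.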
Example 22
def run_inference(docs, args):
    rng_key = random.PRNGKey(0)
    docs = device_put(docs)

    hyperparams = dict(
        vocab_size=docs.shape[1],
        num_topics=args.num_topics,
        hidden=args.hidden,
        dropout_rate=args.dropout_rate,
        batch_size=args.batch_size,
    )

    optimizer = numpyro.optim.Adam(args.learning_rate)
    svi = SVI(model, guide, optimizer, loss=TraceMeanField_ELBO())

    return svi.run(
        rng_key,
        args.num_steps,
        docs,
        hyperparams,
        is_training=True,
        progress_bar=not args.disable_progbar,
        nn_framework=args.nn_framework,
    )
Example 23
    pml.savefig(f'multicollinear_sum_post_{method}.pdf')
    plt.show()


# Laplace fit

m6_1 = AutoLaplaceApproximation(model)
svi = SVI(model,
          m6_1,
          optim.Adam(0.1),
          Trace_ELBO(),
          leg_left=df.leg_left.values,
          leg_right=df.leg_right.values,
          height=df.height.values,
          br_positive=False)
p6_1, losses = svi.run(random.PRNGKey(0), 2000)
post_laplace = m6_1.sample_posterior(random.PRNGKey(1), p6_1, (1000, ))

analyze_post(post_laplace, 'laplace')

# MCMC fit
# code from p298 (code 9.28) of rethinking2
#https://fehiepsi.github.io/rethinking-numpyro/09-markov-chain-monte-carlo.html

kernel = NUTS(
    model,
    init_strategy=init_to_value(values={
        "a": 10.0,
        "bl": 0.0,
        "br": 0.1,
        "sigma": 1.0
Example 24
def benchmark_hmc(args, features, labels):
    rng_key = random.PRNGKey(1)
    start = time.time()
    # reference parameters taken from a MAP estimate; see
    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
    ref_params = {
        "coefs":
        jnp.array([
            +2.03420663e00,
            -3.53567265e-02,
            -1.49223924e-01,
            -3.07049364e-01,
            -1.00028366e-01,
            -1.46827862e-01,
            -1.64167881e-01,
            -4.20344204e-01,
            +9.47479829e-02,
            -1.12681836e-02,
            +2.64442056e-01,
            -1.22087866e-01,
            -6.00568838e-02,
            -3.79419506e-01,
            -1.06668741e-01,
            -2.97053963e-01,
            -2.05253899e-01,
            -4.69537191e-02,
            -2.78072730e-02,
            -1.43250525e-01,
            -6.77954629e-02,
            -4.34899796e-03,
            +5.90927452e-02,
            +7.23133609e-02,
            +1.38526391e-02,
            -1.24497898e-01,
            -1.50733739e-02,
            -2.68872194e-02,
            -1.80925727e-02,
            +3.47936489e-02,
            +4.03552800e-02,
            -9.98773426e-03,
            +6.20188080e-02,
            +1.15002751e-01,
            +1.32145107e-01,
            +2.69109547e-01,
            +2.45785132e-01,
            +1.19035013e-01,
            -2.59744357e-02,
            +9.94279515e-04,
            +3.39266285e-02,
            -1.44057125e-02,
            -6.95222765e-02,
            -7.52013028e-02,
            +1.21171586e-01,
            +2.29205526e-02,
            +1.47308692e-01,
            -8.34354162e-02,
            -9.34122875e-02,
            -2.97472421e-02,
            -3.03937674e-01,
            -1.70958012e-01,
            -1.59496680e-01,
            -1.88516974e-01,
            -1.20889175e00,
        ])
    }
    if args.algo == "HMC":
        step_size = jnp.sqrt(0.5 / features.shape[0])
        trajectory_length = step_size * args.num_steps
        kernel = HMC(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            adapt_step_size=False,
            dense_mass=args.dense_mass,
        )
        subsample_size = None
    elif args.algo == "NUTS":
        kernel = NUTS(model, dense_mass=args.dense_mass)
        subsample_size = None
    elif args.algo == "HMCECS":
        subsample_size = 1000
        inner_kernel = NUTS(
            model,
            init_strategy=init_to_value(values=ref_params),
            dense_mass=args.dense_mass,
        )
        # note: with num_blocks=100 and subsample_size=1000, each MCMC step
        # updates 10 indices, so it takes 50,000 MCMC steps to cycle through
        # the whole dataset
        kernel = HMCECS(inner_kernel,
                        num_blocks=100,
                        proxy=HMCECS.taylor_proxy(ref_params))
    elif args.algo == "SA":
        # NB: this kernel requires large num_warmup and num_samples
        # and running on GPU is much faster than on CPU
        kernel = SA(model,
                    adapt_state_size=1000,
                    init_strategy=init_to_value(values=ref_params))
        subsample_size = None
    elif args.algo == "FlowHMCECS":
        subsample_size = 1000
        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
        params, losses = svi_result.params, svi_result.losses
        plt.plot(losses)
        plt.show()

        neutra = NeuTraReparam(guide, params)
        neutra_model = neutra.reparam(model)
        neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)}
        # no need to adapt mass matrix if the flow does a good job
        inner_kernel = NUTS(
            neutra_model,
            init_strategy=init_to_value(values=neutra_ref_params),
            adapt_mass_matrix=False,
        )
        kernel = HMCECS(inner_kernel,
                        num_blocks=100,
                        proxy=HMCECS.taylor_proxy(neutra_ref_params))
    else:
        raise ValueError(
            "Invalid algorithm; expected one of 'HMC', 'NUTS', 'HMCECS', 'SA', "
            "or 'FlowHMCECS'.")
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples)
    mcmc.run(rng_key,
             features,
             labels,
             subsample_size,
             extra_fields=("accept_prob", ))
    print("Mean accept prob:",
          jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
    mcmc.print_summary(exclude_deterministic=False)
    print("\nMCMC elapsed time:", time.time() - start)
Example 25
# ===================
n_epochs = 1_000
lr = 0.01
optimizer = Adam(step_size=lr)

# ===================
# Training
# ===================
# reproducibility
rng_key = random.PRNGKey(42)

# setup svi
svi = SVI(sgp_model, delta_guide, optimizer, loss=Trace_ELBO())

# run svi
svi_results = svi.run(rng_key, n_epochs, X, y.T)

# ===================
# Plot Loss
# ===================
fig, ax = plt.subplots(ncols=1, figsize=(6, 4))
ax.plot(svi_results.losses)
ax.set(title="Loss", xlabel="Iterations", ylabel="Negative Log-Likelihood")
plt.tight_layout()
wandb.log({f"loss": [wandb.Image(plt)]})

wandb.log({f"nll_loss": np.array(svi_results.losses[-1])})
learned_params = delta_guide.median(svi_results.params)
learned_params["x_u"] = svi_results.params["x_u"]

# =================
Example 26
def model(M, A, D=None):
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


m5_3 = AutoLaplaceApproximation(model)
svi = SVI(model,
          m5_3,
          optim.Adam(1),
          Trace_ELBO(),
          M=d.M.values,
          A=d.A.values,
          D=d.D.values)
p5_3, losses = svi.run(random.PRNGKey(0), 1000)
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (1000, ))

# Posterior

param_names = {'a', 'bA', 'bM', 'sigma'}
for p in param_names:
    print(f'posterior for {p}')
    print_summary(post[p], 0.95, False)

# PPC

# call predictive without specifying new data
# so it uses original data
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (int(1e4), ))
post_pred = Predictive(m5_3.model, post)(random.PRNGKey(2),
Example 27
def test_tracegraph_gaussian_chain(num_latents, num_steps, step_size, atol,
                                   difficulty):
    loc0 = 0.2
    data = jnp.array([-0.1, 0.03, 0.2, 0.1])
    n_data = data.shape[0]
    sum_data = data.sum()
    N = num_latents
    lambdas = [1.5 * (k + 1) / N for k in range(N + 1)]
    lambdas = list(map(lambda x: jnp.array([x]), lambdas))
    lambda_tilde_posts = [lambdas[0]]
    for k in range(1, N):
        lambda_tilde_k = (lambdas[k] * lambda_tilde_posts[k - 1]) / (
            lambdas[k] + lambda_tilde_posts[k - 1])
        lambda_tilde_posts.append(lambda_tilde_k)
    lambda_posts = [
        None
    ]  # this is never used (just a way of shifting the indexing by 1)
    for k in range(1, N):
        lambda_k = lambdas[k] + lambda_tilde_posts[k - 1]
        lambda_posts.append(lambda_k)
    lambda_N_post = (n_data * lambdas[N]) + lambda_tilde_posts[N - 1]
    lambda_posts.append(lambda_N_post)
    target_kappas = [None]
    target_kappas.extend([lambdas[k] / lambda_posts[k] for k in range(1, N)])
    target_mus = [None]
    target_mus.extend([
        loc0 * lambda_tilde_posts[k - 1] / lambda_posts[k]
        for k in range(1, N)
    ])
    target_loc_N = (sum_data * lambdas[N] / lambda_N_post +
                    loc0 * lambda_tilde_posts[N - 1] / lambda_N_post)
    target_mus.append(target_loc_N)
    np.random.seed(0)
    while True:
        mask = np.random.binomial(1, 0.3, (N, ))
        if mask.sum() < 0.4 * N and mask.sum() > 0.5:
            which_nodes_reparam = mask
            break

    class FakeNormal(dist.Normal):
        reparametrized_params = []

    def model(difficulty=0.0):
        next_mean = loc0
        for k in range(1, N + 1):
            latent_dist = dist.Normal(next_mean,
                                      jnp.power(lambdas[k - 1], -0.5))
            loc_latent = numpyro.sample("loc_latent_{}".format(k), latent_dist)
            next_mean = loc_latent

        loc_N = next_mean
        with numpyro.plate("data", data.shape[0]):
            numpyro.sample("obs",
                           dist.Normal(loc_N, jnp.power(lambdas[N], -0.5)),
                           obs=data)
        return loc_N

    def guide(difficulty=0.0):
        previous_sample = None
        for k in reversed(range(1, N + 1)):
            loc_q = numpyro.param(
                f"loc_q_{k}",
                lambda key: target_mus[k] + difficulty *
                (0.1 * random.normal(key) - 0.53),
            )
            log_sig_q = numpyro.param(
                f"log_sig_q_{k}",
                lambda key: -0.5 * jnp.log(lambda_posts[k]) + difficulty *
                (0.1 * random.normal(key) - 0.53),
            )
            sig_q = jnp.exp(log_sig_q)
            kappa_q = None
            if k != N:
                kappa_q = numpyro.param(
                    "kappa_q_%d" % k,
                    lambda key: target_kappas[k] + difficulty *
                    (0.1 * random.normal(key) - 0.53),
                )
            mean_function = loc_q if k == N else kappa_q * previous_sample + loc_q
            node_flagged = True if which_nodes_reparam[k - 1] == 1.0 else False
            Normal = dist.Normal if node_flagged else FakeNormal
            loc_latent = numpyro.sample(f"loc_latent_{k}",
                                        Normal(mean_function, sig_q))
            previous_sample = loc_latent
        return previous_sample

    adam = optim.Adam(step_size=step_size, b1=0.95, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0),
                         num_steps,
                         difficulty=difficulty)

    kappa_errors, log_sig_errors, loc_errors = [], [], []
    for k in range(1, N + 1):
        if k != N:
            kappa_error = jnp.sum(
                jnp.power(svi_result.params[f"kappa_q_{k}"] - target_kappas[k],
                          2))
            kappa_errors.append(kappa_error)

        loc_errors.append(
            jnp.sum(
                jnp.power(svi_result.params[f"loc_q_{k}"] - target_mus[k], 2)))
        log_sig_error = jnp.sum(
            jnp.power(
                svi_result.params[f"log_sig_q_{k}"] +
                0.5 * jnp.log(lambda_posts[k]), 2))
        log_sig_errors.append(log_sig_error)

    max_errors = (np.max(loc_errors), np.max(log_sig_errors),
                  np.max(kappa_errors))

    for i in range(3):
        assert_allclose(max_errors[i], 0, atol=atol)
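
For reference, the precision bookkeeping at the top of the test is the usual chain-of-Normals marginalization: variances add along the chain, so the marginal prior precision of latent $k$ satisfies $1/\tilde{\lambda}_k = 1/\lambda_k + 1/\tilde{\lambda}_{k-1}$ (the harmonic form computed in the loop), and the posterior precision combines it with the precision contributed by the next node, $\lambda_k^{\mathrm{post}} = \lambda_k + \tilde{\lambda}_{k-1}$.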
Example 28
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(
        nuts_kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    guide = AutoBNAFNormal(
        dual_moon_model,
        hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), Trace_ELBO())

    print("Start training guide...")
    svi_result = svi.run(random.PRNGKey(1), args.num_iters)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(2),
        svi_result.params,
        sample_shape=(args.num_samples, ))['x'].copy()

    print("\nStart NeuTra HMC...")
    neutra = NeuTraReparam(guide, svi_result.params)
    neutra_model = neutra.reparam(dual_moon_model)
    nuts_kernel = NUTS(neutra_model)
    mcmc = MCMC(
        nuts_kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(random.PRNGKey(3))
    mcmc.print_summary()
    zs = mcmc.get_samples(group_by_chain=True)["auto_shared_latent"]
    print("Transform samples into unwarped space...")
    samples = neutra.transform_sample(zs)
    print_summary(samples)
    zs = zs.reshape(-1, 2)
    samples = samples['x'].reshape(-1, 2).copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(jnp.zeros(2),
                                     1.).sample(random.PRNGKey(4), (1000, ))
    guide_trans_samples = neutra.transform_sample(guide_base_samples)['x']

    x1 = jnp.linspace(-3, 3, 100)
    x2 = jnp.linspace(-3, 3, 100)
    X1, X2 = jnp.meshgrid(x1, x2)
    P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    ax1.plot(svi_result.losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nAutoBNAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0],
                    guide_base_samples[:, 1],
                    ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)'
    )

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0],
                vanilla_samples[:, 1],
                n_levels=30,
                ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0],
             vanilla_samples[-50:, 1],
             'bo-',
             alpha=0.5)
    ax4.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nvanilla HMC sampler')

    sns.scatterplot(zs[:, 0],
                    zs[:, 1],
                    ax=ax5,
                    hue=samples[:, 0] < 0.,
                    s=30,
                    alpha=0.5,
                    edgecolor="none")
    ax5.set(xlim=[-5, 5],
            ylim=[-5, 5],
            xlabel='x0',
            ylabel='x1',
            title='Samples from the\nwarped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nNeuTra HMC sampler')

    plt.savefig("neutra.pdf")
Example 29
def test_cond():
    def model():
        def true_fun(_):
            x = numpyro.sample("x", dist.Normal(4.0))
            numpyro.deterministic("z", x - 4.0)

        def false_fun(_):
            x = numpyro.sample("x", dist.Normal(0.0))
            numpyro.deterministic("z", x)

        cluster = numpyro.sample("cluster", dist.Normal())
        cond(cluster > 0, true_fun, false_fun, None)

    def guide():
        m1 = numpyro.param("m1", 2.0)
        s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive)
        m2 = numpyro.param("m2", 2.0)
        s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive)

        def true_fun(_):
            numpyro.sample("x", dist.Normal(m1, s1))

        def false_fun(_):
            numpyro.sample("x", dist.Normal(m2, s2))

        cluster = numpyro.sample("cluster", dist.Normal())
        cond(cluster > 0, true_fun, false_fun, None)

    svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100))
    svi_result = svi.run(random.PRNGKey(0), num_steps=2500)
    params = svi_result.params

    predictive = Predictive(
        model,
        guide=guide,
        params=params,
        num_samples=1000,
        return_sites=["cluster", "x", "z"],
    )
    result = predictive(random.PRNGKey(0))

    assert result["cluster"].shape == (1000,)
    assert result["x"].shape == (1000,)
    assert result["z"].shape == (1000,)

    mcmc = MCMC(
        NUTS(model),
        num_warmup=500,
        num_samples=2500,
        num_chains=4,
        chain_method="sequential",
    )
    mcmc.run(random.PRNGKey(0))

    x = mcmc.get_samples()["x"]
    assert x.shape == (10_000,)
    assert_allclose(
        [x[x > 2.0].mean(), x[x > 2.0].std(), x[x < 2.0].mean(), x[x < 2.0].std()],
        [4.01, 0.965, -0.01, 0.965],
        atol=0.1,
    )
    assert_allclose([x.mean(), x.std()], [2.0, jnp.sqrt(5.0)], atol=0.5)
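
The final moment check follows from the equal-weight mixture of $\mathcal{N}(4,1)$ and $\mathcal{N}(0,1)$: $\mathbb{E}[x] = 2$ and $\mathrm{Var}[x] = \tfrac{1}{2}(1+16) + \tfrac{1}{2}(1+0) - 2^2 = 5$, hence x.std() is approximately $\sqrt{5}$.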
Example 30
    plt.show()

# Laplace fit

m6_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_1,
    optim.Adam(0.1),
    Trace_ELBO(),
    leg_left=df.leg_left.values,
    leg_right=df.leg_right.values,
    height=df.height.values,
    br_positive=False
)
svi_run = svi.run(random.PRNGKey(0), 2000)
p6_1 = svi_run.params
losses = svi_run.losses
post_laplace = m6_1.sample_posterior(random.PRNGKey(1), p6_1, (1000,))

analyze_post(post_laplace, 'laplace')


# MCMC fit
# code from p298 (code 9.28) of rethinking2
#https://fehiepsi.github.io/rethinking-numpyro/09-markov-chain-monte-carlo.html


kernel = NUTS(
    model,
    init_strategy=init_to_value(values={"a": 10.0, "bl": 0.0, "br": 0.1, "sigma": 1.0}),