示例#1
0
    def _hmc_next(step_size, inverse_mass_matrix, vv_state, model_args,
                  model_kwargs, rng_key):
        if potential_fn_gen:
            nonlocal vv_update
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
            _, vv_update = velocity_verlet(pe_fn, kinetic_fn)

        num_steps = _get_num_steps(step_size, trajectory_len)
        vv_state_new = fori_loop(
            0, num_steps,
            lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
            vv_state)
        energy_old = vv_state.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state.r)
        energy_new = vv_state_new.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state_new.r)
        delta_energy = energy_new - energy_old
        delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
        accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
        diverging = delta_energy > max_delta_energy
        transition = random.bernoulli(rng_key, accept_prob)
        vv_state, energy = cond(transition, (vv_state_new, energy_new),
                                lambda args: args, (vv_state, energy_old),
                                lambda args: args)
        return vv_state, energy, num_steps, accept_prob, diverging
示例#2
0
    def _hmc_next(step_size, inverse_mass_matrix, vv_state, model_args,
                  model_kwargs, rng_key, trajectory_length):
        if potential_fn_gen:
            nonlocal vv_update, forward_mode_ad
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
            _, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)

        # no need to spend too many steps if the state z has 0 size (i.e. z is empty)
        if len(inverse_mass_matrix) == 0:
            num_steps = 1
        else:
            num_steps = _get_num_steps(step_size, trajectory_length)
        # makes sure trajectory length is constant, rather than step_size * num_steps
        step_size = trajectory_length / num_steps
        vv_state_new = fori_loop(
            0, num_steps,
            lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
            vv_state)
        energy_old = vv_state.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state.r)
        energy_new = vv_state_new.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state_new.r)
        delta_energy = energy_new - energy_old
        delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf,
                                 delta_energy)
        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
        diverging = delta_energy > max_delta_energy
        transition = random.bernoulli(rng_key, accept_prob)
        vv_state, energy = cond(transition, (vv_state_new, energy_new),
                                identity, (vv_state, energy_old), identity)
        return vv_state, energy, num_steps, accept_prob, diverging
示例#3
0
 def get_final_state(model, step_size, num_steps, q_i, p_i):
     # Integrate (q_i, p_i) forward by `num_steps` velocity Verlet updates
     # and return the final (position, momentum) pair.
     init_fn, update_fn = velocity_verlet(model.potential_fn, model.kinetic_fn)

     def step(_, state):
         return update_fn(step_size, args.m_inv, state)

     final_state = fori_loop(0, num_steps, step, init_fn(q_i, p_i))
     # The integrator state has four fields; only the first two are needed.
     q_f, p_f, _, _ = final_state
     return (q_f, p_f)
示例#4
0
def test_beta_bernoulli(elbo):
    # Fit a Beta guide to 10 Bernoulli observations with 8 successes.
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        p = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(p), obs=data)

    def guide(data):
        a = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        b = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(a, b))

    optimizer = optim.Adam(0.05)
    svi = SVI(model, guide, optimizer, elbo)
    state = svi.init(random.PRNGKey(1), data)
    # The raw optimizer value of alpha_q is expected to be 0, presumably the
    # unconstrained log(1.0) for the positive constraint.
    assert_allclose(optimizer.get_params(state.optim_state)['alpha_q'], 0.)

    def train_step(_, current):
        new_state, _ = svi.update(current, data)
        return new_state

    params = svi.get_params(fori_loop(0, 2000, train_step, state))
    posterior_mean = params['alpha_q'] / (params['alpha_q'] + params['beta_q'])
    assert_allclose(posterior_mean, 0.8, atol=0.05, rtol=0.05)
示例#5
0
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites, pe):
        """Run one random-scan sweep of discrete updates over ``gibbs_sites``.

        Each flattened discrete coordinate is visited once, in a random order.
        ``proposal_fn`` proposes a new value, which is accepted iff
        ``u < exp(log_accept_ratio)`` for ``u ~ Uniform(0, 1)``.

        :param rng_key: PRNG key for this sweep.
        :param dict gibbs_sites: current values of the discrete sites.
        :param hmc_sites: current HMC-site values, held fixed (passed as ``z_hmc``).
        :param pe: potential energy at the current state.
        :return: ``(gibbs_sites, pe)`` after the sweep.
        """
        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        # Draw the visiting order from its own key so that it is independent of
        # the key threaded through the update loop below.
        rng_key, rng_permute = random.split(rng_key)
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        # Loop-invariant: the potential as a function of the discrete sites only.
        discrete_potential_fn = partial(potential_fn, z_hmc=hmc_sites)

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key, z, pe, potential_fn=discrete_potential_fn,
                idx=idx, support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(random.exponential(rng_accept) > -log_accept_ratio,
                         (z_new, pe_new), identity,
                         (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, pe)
        _, gibbs_sites, pe = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites, pe
示例#6
0
def main(args):
    # Synthetic data: 100 standard-normal draws shifted by +3.0.
    data = random.normal(PRNGKey(0), shape=(100, )) + 3.0

    # Variational inference on the model/guide pair, optimized with Adam and a
    # 100-particle Trace_ELBO estimator.
    svi = SVI(model, guide, optim.Adam(args.learning_rate),
              Trace_ELBO(num_particles=100))
    state = svi.init(PRNGKey(0), data)

    def train_step(_, current):
        new_state, _loss = svi.update(current, data)
        return new_state

    state = fori_loop(0, args.num_steps, train_step, state)

    # Print every fitted variational parameter of the guide.
    params = svi.get_params(state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # The model is conjugate, so the exact posterior is known: the variational
    # location should end up close to 3.0.
    assert jnp.abs(params["guide_loc"] - 3.0) < 0.1
示例#7
0
def test_uniform_normal():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        # `loc` is a Uniform(0, 1) draw pushed through AffineTransform(0, alpha),
        # reparameterized via TransformReparam.
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            base = dist.Uniform(0, 1)
            shifted = dist.TransformedDistribution(
                base, transforms.AffineTransform(0, alpha))
            loc = numpyro.sample('loc', shifted)
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    optimizer = optim.Adam(0.01)
    init_key = random.PRNGKey(1)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    state = svi.init(init_key, data)

    def train_step(_, current):
        new_state, _loss = svi.update(current, data)
        return new_state

    params = svi.get_params(fori_loop(0, 1000, train_step, state))
    assert_allclose(guide.median(params)['loc'], true_coef, rtol=0.05)
    # test .quantile method: row 1 is the 0.5 quantile
    quantiles = guide.quantiles(params, [0.2, 0.5])
    assert_allclose(quantiles['loc'][1], true_coef, rtol=0.1)
示例#8
0
    def run(self,
            rng_key,
            num_steps,
            *args,
            return_last=True,
            progbar=True,
            **kwargs):
        """Run ``num_steps`` SVGD updates starting from ``self.init``.

        :param rng_key: PRNG key passed to ``self.init``.
        :param int num_steps: number of ``self.update`` calls.
        :param args: positional arguments forwarded to ``init`` and ``update``.
        :param bool return_last: return only the final loss instead of the
            array of all losses.
        :param bool progbar: if True, drive a jitted step from a Python loop with
            a tqdm progress bar; otherwise run everything in one ``fori_loop``.
        :param kwargs: keyword arguments forwarded to ``init`` and ``update``.
        :return: ``(svgd_state, loss)``, where ``loss`` is the last loss or the
            array of per-step losses.
        """
        def bodyfn(i, info):
            svgd_state, losses = info
            svgd_state, loss = self.update(svgd_state, *args, **kwargs)
            losses = ops.index_update(losses, i, loss)
            return svgd_state, losses

        svgd_state = self.init(rng_key, *args, **kwargs)
        losses = np.empty((num_steps, ))
        if not progbar:
            svgd_state, losses = fori_loop(0, num_steps, bodyfn,
                                           (svgd_state, losses))
        else:
            # Wrap once, outside the loop, so every iteration reuses one
            # jitted callable.
            jit_bodyfn = jax.jit(bodyfn)
            with tqdm.trange(num_steps) as t:
                for i in t:
                    svgd_state, losses = jit_bodyfn(i, (svgd_state, losses))
                    # Iterating over `t` already advances the bar; an extra
                    # t.update() here would count every step twice.
                    t.set_description('SVGD {:.5}'.format(losses[i]),
                                      refresh=False)
        loss_res = losses[-1] if return_last else losses
        return svgd_state, loss_res
示例#9
0
def test_beta_bernoulli(auto_class):
    # Two Bernoulli columns: 8/10 and 4/10 successes.
    data = jnp.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        p = numpyro.sample("beta", dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample("obs", dist.Bernoulli(p), obs=data)

    optimizer = optim.Adam(0.01)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    state = svi.init(random.PRNGKey(1), data)

    def train_step(_, current):
        new_state, _loss = svi.update(current, data)
        return new_state

    params = svi.get_params(fori_loop(0, 3000, train_step, state))
    # Beta(1, 1) prior with a Bernoulli likelihood: the exact posterior mean is
    # (successes + 1) / (trials + 2).
    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(1000,)
    )
    assert_allclose(jnp.mean(posterior_samples["beta"], 0), true_coefs, atol=0.05)

    def check_predictive(predictive):
        samples = predictive(random.PRNGKey(1), None)
        assert samples["obs"].shape == (1000, 2)

    # Predictive can be instantiated from posterior samples...
    check_predictive(Predictive(model, posterior_samples=posterior_samples))
    # ... or from the guide + params
    check_predictive(Predictive(model, guide=guide, params=params, num_samples=1000))
示例#10
0
def test_beta_bernoulli(auto_class):
    # Two Bernoulli columns: 8/10 and 4/10 successes.
    data = np.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        p = numpyro.sample('beta', dist.Beta(np.ones(2), np.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(p), obs=data)

    optimizer = optim.Adam(0.01)
    guide = auto_class(model)
    svi = SVI(model, guide, elbo, optimizer)
    # This SVI API takes model and guide arguments as separate tuples.
    svi_args = dict(model_args=(data, ), guide_args=(data, ))
    state = svi.init(random.PRNGKey(1), **svi_args)

    def train_step(_, current):
        new_state, _loss = svi.update(current, **svi_args)
        return new_state

    params = svi.get_params(fori_loop(0, 2000, train_step, state))
    # Conjugate Beta(1, 1) prior: exact posterior mean per column.
    true_coefs = (np.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params,
                                               sample_shape=(1000, ))
    assert_allclose(np.mean(posterior_samples['beta'], 0), true_coefs,
                    atol=0.04)
示例#11
0
def test_logistic_regression(auto_class):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    true_logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=true_logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        return numpyro.sample('obs',
                              dist.Bernoulli(logits=jnp.sum(coefs * data, axis=-1)),
                              obs=labels)

    optimizer = optim.Adam(0.01)
    init_key = random.PRNGKey(1)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, optimizer, ELBO())
    state = svi.init(init_key, data, labels)

    def train_step(_, current):
        new_state, _loss = svi.update(current, data, labels)
        return new_state

    params = svi.get_params(fori_loop(0, 2000, train_step, state))
    # median/quantiles are only checked for guides other than IAF/BNAF.
    if auto_class not in (AutoIAFNormal, AutoBNAFNormal):
        assert_allclose(guide.median(params)['coefs'], true_coefs, rtol=0.1)
        # test .quantile method: row 1 is the 0.5 quantile
        quantiles = guide.quantiles(params, [0.2, 0.5])
        assert_allclose(quantiles['coefs'][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['coefs'], 0), true_coefs, rtol=0.1)
示例#12
0
def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx,
                             support_size):
    """Gibbs proposal for a single coordinate of the flattened discrete state.

    :param rng_key: PRNG key; an updated key is returned as the first output.
    :param z_discrete: pytree of discrete values (flattened with ``ravel_pytree``).
    :param pe: potential energy at ``z_discrete``.
    :param potential_fn: forwarded to ``_discrete_gibbs_proposal_body_fn``.
    :param idx: current index of the flattened ``z_discrete`` to update.
    :param support_size: support size of the value at ``idx``.
    :return: ``(rng_key, z_new, pe_new, log_accept_ratio)``; the ratio is
        always 0, i.e. the Gibbs draw is always accepted.
    """
    flat_z, unravel = ravel_pytree(z_discrete)
    # The proposals cannot be vmapped because `support_size` is a traced value
    # that may differ between discrete variables. Instead, loop over the
    # proposals and sample from the conditional categorical distribution with
    # an online scheme.
    step = partial(_discrete_gibbs_proposal_body_fn, flat_z, unravel, pe,
                   potential_fn, idx)
    # Carry: (key, proposal, its potential energy, running log weight sum).
    carry = (rng_key, z_discrete, pe, jnp.array(0.0))
    rng_key, z_new, pe_new, _ = fori_loop(0, support_size - 1, step, carry)
    return rng_key, z_new, pe_new, jnp.array(0.0)
示例#13
0
def test_uniform_normal():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))

    def model(data):
        upper = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, upper))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    optimizer = optim.Adam(0.01)
    init_key = random.PRNGKey(1)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, elbo, optimizer)
    # This SVI API takes model and guide arguments as separate tuples.
    svi_args = dict(model_args=(data, ), guide_args=(data, ))
    state = svi.init(init_key, **svi_args)

    def train_step(_, current):
        new_state, _loss = svi.update(current, **svi_args)
        return new_state

    params = svi.get_params(fori_loop(0, 1000, train_step, state))
    assert_allclose(guide.median(params)['loc'], true_coef, rtol=0.05)
    # test .quantile method: row 1 is the 0.5 quantile
    quantiles = guide.quantiles(params, [0.2, 0.5])
    assert_allclose(quantiles['loc'][1], true_coef, rtol=0.1)
示例#14
0
def _discrete_modified_gibbs_proposal(rng_key,
                                      z_discrete,
                                      pe,
                                      potential_fn,
                                      idx,
                                      support_size,
                                      stay_prob=0.0):
    """Modified Gibbs proposal for one coordinate of the flattened discrete state.

    Unlike the plain Gibbs proposal, the current value gets weight 0 in the
    categorical draw, so an MH log-acceptance ratio is returned to correct for
    this. With probability ``stay_prob`` the current value is kept instead of
    the drawn proposal.

    :param rng_key: PRNG key; an updated key is returned as the first output.
    :param z_discrete: pytree of discrete values (flattened with ``ravel_pytree``).
    :param pe: potential energy at ``z_discrete``.
    :param potential_fn: forwarded to ``_discrete_gibbs_proposal_body_fn``.
    :param idx: index of the flattened ``z_discrete`` to update.
    :param support_size: support size of the value at ``idx``.
    :param float stay_prob: probability in [0, 1) of keeping the current value.
    :return: ``(rng_key, z_new, pe_new, log_accept_ratio)``.
    :raises AssertionError: if ``stay_prob`` is not a float in [0, 1).
    """
    assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
    body_fn = partial(
        _discrete_gibbs_proposal_body_fn,
        z_discrete_flat,
        unravel_fn,
        pe,
        potential_fn,
        idx,
    )
    # Same as the plain Gibbs step, except that the running log weight sum
    # starts at -inf, i.e. the current value is given weight 0.
    init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
    rng_key, z_new, pe_new, log_weight_sum = fori_loop(0, support_size - 1,
                                                       body_fn, init_val)
    rng_key, rng_stay = random.split(rng_key)
    # With probability stay_prob, fall back to the current (z, pe).
    z_new, pe_new = cond(
        random.bernoulli(rng_stay, stay_prob),
        (z_discrete, pe),
        identity,
        (z_new, pe_new),
        identity,
    )
    # MH correction: (1 - P(z)) / (1 - P(z_new)), where
    #   1 - P(z)     ~ weight_sum
    #   1 - P(z_new) ~ 1 + weight_sum - z_new_weight, with z_new_weight = exp(pe - pe_new).
    # expm1(pe - pe_new) == z_new_weight - 1, which keeps this form compact. If
    # the "stay" branch was taken, pe_new == pe and the ratio is log(1) = 0.
    log_accept_ratio = log_weight_sum - jnp.log(
        jnp.exp(log_weight_sum) - jnp.expm1(pe - pe_new))
    return rng_key, z_new, pe_new, log_accept_ratio
示例#15
0
def main(args):
    # Synthetic data: 100 standard-normal draws shifted by +3.0.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Build the functional SVI init/update pair on top of the Adam optimizer.
    opt_init, opt_update, get_params = optimizers.adam(args.learning_rate)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    train_key, init_key = random.split(PRNGKey(0))
    opt_state, _ = svi_init(init_key, model_args=(data,))

    def train_step(step, carry):
        # Carry is (optimizer state, PRNG key); svi_update also returns the loss.
        state, key = carry
        _loss, state, key = svi_update(step, key, state, model_args=(data,))
        return state, key

    opt_state, _ = fori_loop(0, args.num_steps, train_step, (opt_state, train_key))

    # Print every fitted variational parameter of the guide.
    params = get_params(opt_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # The model is conjugate, so the exact posterior is known: the variational
    # location should end up close to 3.0.
    assert np.abs(params["guide_loc"] - 3.0) < 0.1
示例#16
0
 def get_cov(x):
     # Welford covariance estimate over rows x[0] .. x[1999]; wc_init(3) is
     # presumably the per-row dimension (TODO confirm against welford_covariance).
     init_fn, update_fn, final_fn = welford_covariance(diagonal=diagonal)

     def accumulate(i, state):
         return update_fn(x[i], state)

     state = fori_loop(0, 2000, accumulate, init_fn(3))
     cov, cov_inv_sqrt = final_fn(state, regularize=regularize)
     return cov, cov_inv_sqrt
示例#17
0
def test_mnist_data_load():
    init, fetch = load_dataset(MNIST, batch_size=128, split='train')
    num_batches, idx = init()

    def add_batch_mean(i, running_total):
        images, _ = fetch(i, idx)
        return running_total + jnp.sum(images) / images.size

    # Average, over all training batches, of each batch's mean pixel value.
    total = fori_loop(0, num_batches, add_batch_mean, jnp.float32(0.))
    assert total / num_batches < 0.15
示例#18
0
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        """Run one random-scan sweep of discrete updates over ``gibbs_sites``.

        :param rng_key: PRNG key for this sweep.
        :param dict gibbs_sites: current values of the discrete sites.
        :param dict hmc_sites: current values of the HMC sites; they are mapped
            to unconstrained space and held fixed during the sweep.
        :return: the updated ``gibbs_sites``.
        """
        # convert to unconstrained values
        z_hmc = {
            k: biject_to(prototype_trace[k]["fn"].support).inv(v)
            for k, v in hmc_sites.items()
            if k in prototype_trace and prototype_trace[k]["type"] == "sample"
        }
        # Enumerate whenever some discrete site is not part of this Gibbs update.
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        wrapped_model = _wrap_model(model)
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model = enum(config_enumerate(wrapped_model),
                                 -max_plate_nesting - 1)

        def potential_fn(z_discrete):
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(wrapped_model,
                                    model_args,
                                    model_kwargs_,
                                    z_hmc,
                                    enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree(
            {k: support_sizes[k]
             for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        # Draw the visiting order from its own key so that it is independent of
        # the key threaded through the update loop below.
        rng_key, rng_permute = random.split(rng_key)
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key,
                z,
                pe,
                potential_fn=potential_fn,
                idx=idx,
                support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(
                random.exponential(rng_accept) > -log_accept_ratio,
                (z_new, pe_new), identity, (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites
示例#19
0
def test_logistic_regression(auto_class, Elbo):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    true_logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=true_logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample("coefs",
                               dist.Normal(0, 1).expand([dim]).to_event())
        linear = jnp.sum(coefs * data, axis=-1)
        logits = numpyro.deterministic("logits", linear)
        with numpyro.plate("N", len(data)):
            return numpyro.sample("obs",
                                  dist.Bernoulli(logits=logits),
                                  obs=labels)

    optimizer = optim.Adam(0.01)
    init_key = random.PRNGKey(1)
    guide = auto_class(model, init_loc_fn=init_strategy)
    svi = SVI(model, guide, optimizer, Elbo())
    svi_state = svi.init(init_key, data, labels)

    # Smoke test: from the same state, the analytic-KL (mean-field) loss should
    # differ noticeably from the plain Trace_ELBO loss.
    if auto_class is AutoNormal and Elbo is TraceMeanField_ELBO:
        _, mean_field_loss = svi.update(svi_state, data, labels)
        svi.loss = Trace_ELBO()
        _, elbo_loss = svi.update(svi_state, data, labels)
        svi.loss = TraceMeanField_ELBO()
        assert abs(mean_field_loss - elbo_loss) > 0.5

    def train_step(_, state):
        new_state, _loss = svi.update(state, data, labels)
        return new_state

    params = svi.get_params(fori_loop(0, 2000, train_step, svi_state))
    if auto_class not in (AutoDAIS, AutoIAFNormal, AutoBNAFNormal):
        assert_allclose(guide.median(params)["coefs"], true_coefs, rtol=0.1)
        # test .quantile method: row 1 is the 0.5 quantile
        if auto_class is not AutoDelta:
            quantiles = guide.quantiles(params, [0.2, 0.5])
            assert_allclose(quantiles["coefs"][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1),
                                               params,
                                               sample_shape=(1000, ))
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(posterior_samples["coefs"], 0),
                    expected_coefs,
                    rtol=0.1)
示例#20
0
def test_beta_bernoulli(auto_class):
    # Two Bernoulli columns: 8/10 and 4/10 successes.
    data = jnp.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T
    N = len(data)

    def model(data):
        p = numpyro.sample("beta",
                           dist.Beta(jnp.ones(2), jnp.ones(2)).to_event())
        with numpyro.plate("N", N):
            numpyro.sample("obs", dist.Bernoulli(p).to_event(1), obs=data)

    optimizer = optim.Adam(0.01)
    guide_kwargs = {"init_loc_fn": init_strategy}
    if auto_class == AutoDAIS:
        # AutoDAIS is exercised with a Cholesky base distribution.
        guide_kwargs["base_dist"] = "cholesky"
    guide = auto_class(model, **guide_kwargs)
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def train_step(_, state):
        new_state, _loss = svi.update(state, data)
        return new_state

    params = svi.get_params(fori_loop(0, 3000, train_step, svi_state))

    # Conjugate Beta(1, 1) prior: the exact posterior mean is
    # (successes + 1) / (trials + 2).
    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1),
                                               params,
                                               sample_shape=(1000, ))
    assert_allclose(jnp.mean(posterior_samples["beta"], 0), true_coefs,
                    atol=0.05)

    # Guides supporting .quantiles return one row per requested quantile.
    if auto_class not in [AutoDAIS, AutoDelta, AutoIAFNormal, AutoBNAFNormal]:
        assert guide.quantiles(params, [0.2, 0.5, 0.8])["beta"].shape == (3, 2)

    def check_predictive(predictive):
        samples = predictive(random.PRNGKey(1), None)
        assert samples["obs"].shape == (1000, N, 2)

    # Predictive can be instantiated from posterior samples...
    check_predictive(Predictive(model, posterior_samples=posterior_samples))
    # ... or from the guide + params
    check_predictive(Predictive(model,
                                guide=guide,
                                params=params,
                                num_samples=1000))
示例#21
0
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        """Run one random-scan sweep of discrete updates over ``gibbs_sites``.

        :param rng_key: PRNG key for this sweep.
        :param dict gibbs_sites: current values of the discrete sites.
        :param hmc_sites: current HMC-site values, held fixed during the sweep.
        :return: the updated ``gibbs_sites``.
        """
        z_hmc = hmc_sites
        # Enumerate whenever some discrete site is not part of this Gibbs update.
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model_ = enum(config_enumerate(wrapped_model),
                                  -max_plate_nesting - 1)
        else:
            wrapped_model_ = wrapped_model

        def potential_fn(z_discrete):
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(wrapped_model_,
                                    model_args,
                                    model_kwargs_,
                                    z_hmc,
                                    enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree(
            {k: support_sizes[k]
             for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        # Draw the visiting order from its own key so that it is independent of
        # the key threaded through the update loop below.
        rng_key, rng_permute = random.split(rng_key)
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key,
                z,
                pe,
                potential_fn=potential_fn,
                idx=idx,
                support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(
                random.exponential(rng_accept) > -log_accept_ratio,
                (z_new, pe_new), identity, (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites
示例#22
0
 def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):
     # One HMC transition: integrate, then Metropolis accept/reject.
     # Returns (vv_state, num_steps, accept_prob).
     def leapfrog(_, state):
         return vv_update(step_size, inverse_mass_matrix, state)

     def total_energy(state):
         return state.potential_energy + kinetic_fn(inverse_mass_matrix, state.r)

     num_steps = _get_num_steps(step_size, trajectory_len)
     proposal = fori_loop(0, num_steps, leapfrog, vv_state)
     current_energy = total_energy(vv_state)
     proposal_energy = total_energy(proposal)
     # NaN energy changes are mapped to +inf, i.e. always rejected.
     energy_diff = proposal_energy - current_energy
     energy_diff = np.where(np.isnan(energy_diff), np.inf, energy_diff)
     accept_prob = np.clip(np.exp(-energy_diff), a_max=1.0)
     accepted = random.bernoulli(rng, accept_prob)
     next_state = cond(accepted,
                       proposal, lambda s: s,
                       vv_state, lambda s: s)
     return next_state, num_steps, accept_prob
示例#23
0
    def init_kernel(init_samples,
                    num_warmup_steps,
                    step_size=1.0,
                    num_steps=None,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    diag_mass=True,
                    target_accept_prob=0.8,
                    run_warmup=True,
                    rng=PRNGKey(0)):
        step_size = float(step_size)
        nonlocal trajectory_length, momentum_generator, wa_update

        if num_steps is None:
            trajectory_length = 2 * math.pi
        else:
            trajectory_length = num_steps * step_size

        z = init_samples
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size, potential_fn,
                                     kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup_steps,
            find_reasonable_step_size=find_reasonable_ss,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            diag_mass=diag_mass,
            target_accept_prob=target_accept_prob)

        rng_hmc, rng_wa = random.split(rng)
        wa_state = wa_init(z, rng_wa, mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.inverse_mass_matrix, rng)
        vv_state = vv_init(z, r)
        hmc_state = HMCState(vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, 0, 0.,
                             wa_state.step_size, wa_state.inverse_mass_matrix,
                             rng_hmc)

        if run_warmup:
            hmc_state, _ = fori_loop(0, num_warmup_steps, warmup_update,
                                     (hmc_state, wa_state))
            return hmc_state
        else:
            return hmc_state, wa_state, warmup_update
示例#24
0
def test_laplace_approximation_warning():
    # Four latent coefficients (a plus three entries of b), but only three
    # observations with a very small noise scale. This presumably leaves the
    # posterior Hessian ill-conditioned, which should trigger the warning.
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3
        numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)

    x = random.normal(random.PRNGKey(0), (3,))
    y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), x=x, y=y)

    def train_step(_, state):
        return svi.update(state)[0]

    final_state = fori_loop(0, 10000, train_step, svi.init(random.PRNGKey(0)))
    params = svi.get_params(final_state)
    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), params)
示例#25
0
    def _inverse(self, y):
        """
        :param numpy.ndarray y: the output of the transform to be inverted
        """

        # NOTE: Inversion is an expensive operation that scales in the dimension
        # of the input: x is re-solved y.shape[-1] times from the network's
        # (mean, log_scale). Presumably self.arn is autoregressive, so each pass
        # fixes one more coordinate.
        def refine(_, x_curr):
            mean, log_scale = self.arn(x_curr)
            clamped_log_scale = _clamp_preserve_gradients(
                log_scale,
                min=self.log_scale_min_clip,
                max=self.log_scale_max_clip)
            return (y - mean) * jnp.exp(-clamped_log_scale)

        return fori_loop(0, y.shape[-1], refine, jnp.zeros(y.shape))
示例#26
0
def test_logistic_regression(auto_class):
    """Fit a logistic regression with ``auto_class`` and check coefficient recovery.

    Labels are simulated from known coefficients ``[1, 2, 3]``. After 1000 SVI
    steps, the posterior mean of ``coefs`` (from ``sample_posterior``) must be within
    10% of the truth. For guides other than ``AutoIAFNormal``, the median and the
    0.5 quantile are checked as well.
    """
    num_points, dim = 3000, 3
    features = random.normal(random.PRNGKey(0), (num_points, dim))
    true_coefs = np.arange(1., dim + 1.)
    labels = dist.Bernoulli(logits=np.sum(true_coefs * features, axis=-1)).sample(
        random.PRNGKey(1))

    def model(data, labels):
        coefs = sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        return sample('obs',
                      dist.Bernoulli(logits=np.sum(coefs * data, axis=-1)),
                      obs=labels)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    key_guide, key_init, key_train = random.split(random.PRNGKey(1), 3)
    guide = auto_class(key_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    # model and guide take the same (features, labels) arguments
    shared_args = (features, labels)
    opt_state, _ = svi_init(key_init, model_args=shared_args, guide_args=shared_args)

    def train_step(step, carry):
        state, key = carry
        _, state, key = svi_update(step, key, state,
                                   model_args=shared_args, guide_args=shared_args)
        return state, key

    opt_state, _ = fori_loop(0, 1000, train_step, (opt_state, key_train))

    if auto_class is not AutoIAFNormal:
        assert_allclose(guide.median(opt_state)['coefs'], true_coefs, rtol=0.1)
        # .quantiles: index 1 picks the 0.5 quantile (assumes results are stacked
        # along the first axis in request order)
        quantiles = guide.quantiles(opt_state, [0.2, 0.5])
        assert_allclose(quantiles['coefs'][1], true_coefs, rtol=0.1)

    # .sample_posterior: the Monte Carlo mean should match the true coefficients
    draws = guide.sample_posterior(random.PRNGKey(1), opt_state, sample_shape=(1000, ))
    assert_allclose(np.mean(draws['coefs'], 0), true_coefs, rtol=0.1)
示例#27
0
def test_dynamic_constraints():
    """A guide param whose constraint depends on another param must be learnable.

    The guide bounds ``loc`` by ``interval(0, alpha)``, where ``alpha`` is itself a
    learned ``unit_interval`` param. The test requires ``loc`` to reach about 0.9
    (``atol=0.05``). That is only possible if the model's own
    ``interval(0, 0.5)`` bound on ``loc`` is ignored and ``alpha`` grows past 0.9.
    """
    true_coef = 0.9
    observations = true_coef + random.normal(random.PRNGKey(0), (1000, ))

    def model(data):
        # The interval(0, 0.5) constraint declared here is expected to have no
        # effect; the assertion below depends on the guide's constraint being used.
        sample('obs',
               dist.Normal(param('loc', 0., constraint=constraints.interval(0, 0.5)), 0.1),
               obs=data)

    def guide():
        upper = param('alpha', 0.5, constraint=constraints.unit_interval)
        param('loc', 0, constraint=constraints.interval(0, upper))

    opt_init, opt_update, get_params = optimizers.adam(0.05)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    key_init, key_train = random.split(random.PRNGKey(1))
    opt_state, constrain_fn = svi_init(key_init, model_args=(observations, ))

    def train_step(step, carry):
        state, key = carry
        _, state, key = svi_update(step, key, state, model_args=(observations, ))
        return state, key

    opt_state, key = fori_loop(0, 300, train_step, (opt_state, key_train))
    fitted = get_param(opt_state, model, guide, get_params, constrain_fn, key,
                       guide_args=())
    assert_allclose(fitted['loc'], true_coef, atol=0.05)
示例#28
0
    def sample(self, state, model_args, model_kwargs):
        """
        Run one Mixed HMC transition over the discrete and continuous latent sites.

        The continuous sites are moved by ``self.inner_kernel`` in short sub-trajectories.
        These are interleaved with ``self._num_discrete_updates`` single-site discrete
        proposals. The combined proposal is then accepted or rejected by a single
        Metropolis-Hastings step. The "Algo 1, line N" comments index the steps of the
        reference algorithm this kernel follows (presumably Algorithm 1 of Zhou (2020),
        "Mixed Hamiltonian Monte Carlo for Mixed Discrete and Continuous Variables";
        confirm).

        :param state: the current sampler state (presumably a ``MixedHMCState``).
            ``state.z`` holds all latent sites, ``state.hmc_state.z`` the continuous
            ones, and ``state.rng_key`` the PRNG key.
        :param tuple model_args: positional arguments forwarded to the model potential.
        :param dict model_kwargs: keyword arguments forwarded to the model potential;
            ``None`` is treated as ``{}``.
        :return: ``MixedHMCState(z, hmc_state, rng_key, accept_prob)`` for the next step.
        """
        model_kwargs = {} if model_kwargs is None else model_kwargs
        num_discretes = self._support_sizes_flat.shape[0]

        def potential_fn(z_gibbs, z_hmc):
            # potential energy of the continuous sites `z_hmc` given the discrete site
            # values `z_gibbs` (passed to the model as `_gibbs_sites`)
            return self.inner_kernel._potential_fn_gen(*model_args,
                                                       _gibbs_sites=z_gibbs,
                                                       **model_kwargs)(z_hmc)

        def update_discrete(idx, rng_key, hmc_state, z_discrete, ke_discrete,
                            delta_pe_sum):
            # Propose a move of the `idx`-th discrete site. Returns the updated rng key,
            # HMC state, discrete values, discrete kinetic energies and accumulated
            # potential-energy change.
            # Algo 1, line 19: get a new discrete proposal
            (
                rng_key,
                z_discrete_new,
                pe_new,
                log_accept_ratio,
            ) = self._discrete_proposal_fn(
                rng_key,
                z_discrete,
                hmc_state.potential_energy,
                partial(potential_fn, z_hmc=hmc_state.z),
                idx,
                self._support_sizes_flat[idx],
            )
            # Algo 1, line 20: depending on reject or refract, we will update
            # the discrete variable and its corresponding kinetic energy. In case of
            # refract, we will need to update the potential energy and its grad w.r.t. hmc_state.z
            ke_discrete_i_new = ke_discrete[idx] + log_accept_ratio
            grad_ = jacfwd if self.inner_kernel._forward_mode_differentiation else grad
            # true branch: take the proposal and recompute the continuous gradient under
            # the new discrete values; false branch: keep the current values unchanged
            z_discrete, pe, ke_discrete_i, z_grad = lax.cond(
                ke_discrete_i_new > 0,
                (z_discrete_new, pe_new, ke_discrete_i_new),
                lambda vals: vals + (grad_(partial(potential_fn, vals[0]))
                                     (hmc_state.z), ),
                (
                    z_discrete,
                    hmc_state.potential_energy,
                    ke_discrete[idx],
                    hmc_state.z_grad,
                ),
                identity,
            )

            # potential-energy jumps caused by discrete moves are accumulated here and
            # subtracted from the total energy difference in the final MH correction
            delta_pe_sum = delta_pe_sum + pe - hmc_state.potential_energy
            ke_discrete = ops.index_update(ke_discrete, idx, ke_discrete_i)
            hmc_state = hmc_state._replace(potential_energy=pe, z_grad=z_grad)
            return rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum

        def update_continuous(hmc_state, z_discrete):
            # run one inner-kernel sub-trajectory with the discrete sites fixed
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            hmc_state_new = self.inner_kernel.sample(hmc_state, model_args,
                                                     model_kwargs_)

            # each time a sub-trajectory is performed, we need to reset i and adapt_state
            # (we will only update them at the end of HMCGibbs step)
            # For `num_steps`, we will record its cumulative sum for diagnostics
            hmc_state = hmc_state_new._replace(
                i=hmc_state.i,
                adapt_state=hmc_state.adapt_state,
                num_steps=hmc_state.num_steps + hmc_state_new.num_steps,
            )
            return hmc_state

        def body_fn(i, vals):
            (
                rng_key,
                hmc_state,
                z_discrete,
                ke_discrete,
                delta_pe_sum,
                arrival_times,
            ) = vals
            # the discrete site with the earliest arrival time is updated next
            idx = jnp.argmin(arrival_times)
            # NB: length of each sub-trajectory is scaled from the current min(arrival_times)
            # (see the note at total_time below)
            trajectory_length = arrival_times[idx] * time_unit
            arrival_times = arrival_times - arrival_times[idx]
            arrival_times = ops.index_update(arrival_times, idx, 1.0)

            # this is a trick, so that in a sub-trajectory of HMC, we always accept the new proposal
            pe = jnp.inf
            hmc_state = hmc_state._replace(trajectory_length=trajectory_length,
                                           potential_energy=pe)
            # Algo 1, line 7: perform a sub-trajectory
            hmc_state = update_continuous(hmc_state, z_discrete)
            # Algo 1, line 8: perform a discrete update
            rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum = update_discrete(
                idx, rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum)
            return (
                rng_key,
                hmc_state,
                z_discrete,
                ke_discrete,
                delta_pe_sum,
                arrival_times,
            )

        # discrete sites are those in `state.z` that the HMC state does not track
        z_discrete = {
            k: v
            for k, v in state.z.items() if k not in state.hmc_state.z
        }
        rng_key, rng_ke, rng_time, rng_r, rng_accept = random.split(
            state.rng_key, 5)
        # Algo 1, line 2: sample discrete kinetic energy
        ke_discrete = random.exponential(rng_ke, (num_discretes, ))
        # Algo 1, line 4 and 5: sample the initial amount of time that each discrete site visits
        # the point 0/1. The logic in GetStepSizesNSteps(...) is more complicated but does
        # the same job: the sub-trajectory length eta_t * M_t is the lag between two arrival time.
        arrival_times = random.uniform(rng_time, (num_discretes, ))
        # compute the amount of time to make `num_discrete_updates` discrete updates
        total_time = (self._num_discrete_updates -
                      1) // num_discretes + jnp.sort(arrival_times)[
                          (self._num_discrete_updates - 1) % num_discretes]
        # NB: total_time can be different from the HMC trajectory length, so we need to scale
        # the time unit so that total_time * time_unit = hmc_trajectory_length
        time_unit = state.hmc_state.trajectory_length / total_time

        # Algo 1, line 2: sample hmc momentum
        r = momentum_generator(state.hmc_state.r,
                               state.hmc_state.adapt_state.mass_matrix_sqrt,
                               rng_r)
        hmc_state = state.hmc_state._replace(r=r, num_steps=0)
        hmc_ke = euclidean_kinetic_energy(
            hmc_state.adapt_state.inverse_mass_matrix, r)
        # Algo 1, line 10: compute the initial energy
        energy_old = hmc_ke + hmc_state.potential_energy

        # Algo 1, line 3: set initial values
        delta_pe_sum = 0.0
        init_val = (
            rng_key,
            hmc_state,
            z_discrete,
            ke_discrete,
            delta_pe_sum,
            arrival_times,
        )
        # Algo 1, line 6-9: perform the update loop
        rng_key, hmc_state_new, z_discrete_new, _, delta_pe_sum, _ = fori_loop(
            0, self._num_discrete_updates, body_fn, init_val)
        # Algo 1, line 10: compute the proposal energy
        hmc_ke = euclidean_kinetic_energy(
            hmc_state.adapt_state.inverse_mass_matrix, hmc_state_new.r)
        energy_new = hmc_ke + hmc_state_new.potential_energy
        # Algo 1, line 11: perform MH correction
        delta_energy = energy_new - energy_old - delta_pe_sum
        delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf,
                                 delta_energy)
        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)

        # record the correct new num_steps
        hmc_state = hmc_state._replace(num_steps=hmc_state_new.num_steps)
        # reset the trajectory length
        hmc_state_new = hmc_state_new._replace(
            trajectory_length=hmc_state.trajectory_length)
        # Draw the accept/reject decision from the dedicated `rng_accept` subkey.
        # `rng_key` is returned in the new state and split again on the next call,
        # so consuming it here as well would reuse a PRNG key.
        hmc_state, z_discrete = cond(
            random.bernoulli(rng_accept, accept_prob),
            (hmc_state_new, z_discrete_new),
            identity,
            (hmc_state, z_discrete),
            identity,
        )

        # perform hmc adapting (similar to the implementation in hmc)
        adapt_state = cond(
            hmc_state.i < self._num_warmup,
            (hmc_state.i, accept_prob, (hmc_state.z, ), hmc_state.adapt_state),
            lambda args: self._wa_update(*args),
            hmc_state.adapt_state,
            identity,
        )

        itr = hmc_state.i + 1
        # running-mean denominator: counts warmup iterations during warmup, then
        # restarts counting once sampling begins
        n = jnp.where(hmc_state.i < self._num_warmup, itr,
                      itr - self._num_warmup)
        mean_accept_prob_prev = state.hmc_state.mean_accept_prob
        mean_accept_prob = (mean_accept_prob_prev +
                            (accept_prob - mean_accept_prob_prev) / n)
        hmc_state = hmc_state._replace(
            i=itr,
            accept_prob=accept_prob,
            mean_accept_prob=mean_accept_prob,
            adapt_state=adapt_state,
        )

        z = {**z_discrete, **hmc_state.z}
        return MixedHMCState(z, hmc_state, rng_key, accept_prob)
示例#29
0
    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2*math.pi,
                    max_tree_depth=10,
                    run_warmup=True,
                    progbar=True,
                    rng=PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup_steps: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if mass matrix is dense or
            diagonal (default when ``dense_mass=False``)
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Default to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
            `init_kernel` returns an initial :data:`HMCState` that can be used to
            generate samples using MCMC. Else, returns the arguments and callable
            that does the initial adaptation.
        :param bool progbar: Whether to enable progress bar updates. Defaults to
            ``True``.
        :param bool heuristic_step_size: If ``True``, a coarse grained adjustment of
            step size is done at the beginning of each adaptation window to achieve
            `target_acceptance_prob`.
        :param jax.random.PRNGKey rng: random key to be used as the source of
            randomness.
        """
        step_size = float(step_size)
        nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
        wa_steps = num_warmup
        trajectory_len = float(trajectory_length)
        max_treedepth = max_tree_depth
        z = init_params
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size,
                                     potential_fn, kinetic_fn,
                                     momentum_generator)

        wa_init, wa_update = warmup_adapter(num_warmup,
                                            adapt_step_size=adapt_step_size,
                                            adapt_mass_matrix=adapt_mass_matrix,
                                            dense_mass=dense_mass,
                                            target_accept_prob=target_accept_prob,
                                            find_reasonable_step_size=find_reasonable_ss)

        rng_hmc, rng_wa = random.split(rng)
        wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
        vv_state = vv_init(z, r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0.,
                             wa_state, rng_hmc)

        if run_warmup and num_warmup > 0:
            # JIT if progress bar updates not required
            if not progbar:
                hmc_state = fori_loop(0, num_warmup, lambda *args: sample_kernel(args[1]), hmc_state)
            else:
                with tqdm.trange(num_warmup, desc='warmup') as t:
                    for i in t:
                        hmc_state = sample_kernel(hmc_state)
                        t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False)
        return hmc_state