def fit_advi(model, num_iter, learning_rate=0.01, seed=0):
    """Fit ``model`` with ADVI using a mean-field (diagonal-covariance) Normal guide.

    The variational distribution ("guide" in Pyro terminology) is built
    automatically by ``AutoDiagonalNormal``. Its parameters are optimised
    with Adam for ``num_iter`` steps inside ``lax.scan``.

    Args:
        model: a NumPyro model function.
        num_iter: number of Adam optimisation steps.
        learning_rate: Adam step size (default: 0.01).
        seed: integer seed for the PRNG key used by ``svi.init`` (default: 0).

    Returns:
        An ``ADVIResults`` holding the ``SVI`` object, the guide, the final
        SVI state and the stacked per-step losses produced by ``lax.scan``.
    """
    init_key = random.PRNGKey(seed)
    optimizer = Adam(learning_rate)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optimizer, AutoContinuousELBO())
    initial_state = svi.init(init_key)

    def step(state, _):
        # svi.update returns (new_state, loss); scan carries the state and
        # stacks the losses into an array of length num_iter.
        return svi.update(state)

    final_state, loss_history = lax.scan(step, initial_state, np.zeros(num_iter))
    return ADVIResults(svi=svi, guide=guide, state=final_state, losses=loss_history)
Example #2
0
def train_model(rng,
                rng_suite,
                model,
                guide,
                data,
                batch_size,
                num_data,
                dp_scale,
                num_epochs,
                clipping_threshold=1.,
                learning_rate=1e-3):
    """Train ``model``/``guide`` with differentially private SVI (DPSVI).

    All inputs are passed explicitly. Nothing is read from module globals.

    Args:
        rng: random key/state forwarded to ``_train_model``. It presumably must
            be compatible with ``rng_suite`` (confirm with callers).
        rng_suite: random-number suite passed both to ``DPSVI`` and to
            ``_train_model``.
        model: the model function.
        guide: the guide (variational distribution) function.
        data: training data forwarded to ``_train_model``.
        batch_size: minibatch size forwarded to ``_train_model``.
        num_data: total number of observations; passed to DPSVI as
            ``num_obs_total``.
        dp_scale: DP noise scale passed to ``DPSVI``.
        num_epochs: number of training epochs forwarded to ``_train_model``.
        clipping_threshold: gradient clipping threshold passed to ``DPSVI``
            (default: 1.0).
        learning_rate: Adam step size (default: 1e-3, the value previously
            hard-coded here).

    Returns:
        Whatever ``_train_model`` returns.
    """
    optimizer = Adam(learning_rate)

    svi = DPSVI(model,
                guide,
                optimizer,
                Trace_ELBO(),
                num_obs_total=num_data,
                clipping_threshold=clipping_threshold,
                dp_scale=dp_scale,
                rng_suite=rng_suite)

    return _train_model(rng, rng_suite, svi, data, batch_size, num_data,
                        num_epochs)
Example #3
0
def run_inference(model, inputs, method=None):
    """Draw posterior samples for ``model`` using either NUTS or SVI.

    Args:
        model: a NumPyro model function; its ``__name__`` appears in the logs.
        inputs: dict of keyword arguments for the model. For NUTS they are
            forwarded to ``mcmc.run``; for SVI they are forwarded to the
            ``SVI`` constructor.
        method: ``None`` selects NUTS (300 warmup steps, 5000 samples). Any
            other value selects SVI with an ``AutoDiagonalNormal`` guide,
            2000 Adam steps (step size 0.05) and 1000 draws from the fitted
            guide.

    Returns:
        The posterior samples from ``mcmc.get_samples()`` (NUTS) or
        ``guide.sample_posterior`` (SVI).
    """
    if method is not None:
        # SVI: fit a mean-field Normal guide, then sample from it.
        logger.info('Guide generation...')
        svi_key = random.PRNGKey(0)
        guide = AutoDiagonalNormal(model=model)
        logger.info('Optimizer generation...')
        optimizer = Adam(0.05)
        logger.info('SVI generation...')
        svi = SVI(model, guide, optimizer, AutoContinuousELBO(), **inputs)
        svi_state = svi.init(svi_key)
        logger.info('Scan...')
        svi_state, _ = lax.scan(lambda carry, _: svi.update(carry), svi_state,
                                np.zeros(2000))
        fitted_params = svi.get_params(svi_state)
        posterior = guide.sample_posterior(random.PRNGKey(1), fitted_params,
                                           (1000, ))
        logger.info(r'SVI summary for: {}'.format(model.__name__))
        numpyro.diagnostics.print_summary(posterior,
                                          prob=0.90,
                                          group_by_chain=False)
        return posterior

    # NUTS: 300 warmup iterations followed by 5000 retained samples.
    logger.info('NUTS sampling')
    mcmc = MCMC(NUTS(model), num_warmup=300, num_samples=5000)
    mcmc.run(random.PRNGKey(0), **inputs, extra_fields=('potential_energy', ))
    logger.info(r'MCMC summary for: {}'.format(model.__name__))
    mcmc.print_summary(exclude_deterministic=False)
    return mcmc.get_samples()
Example #4
0
def test_steinvi_smoke(kernel, auto_guide, init_loc_fn, problem):
    """Smoke test: one SteinVI step runs for the given kernel, guide and problem."""
    _, data, model = problem()
    guide = auto_guide(model, init_loc_fn=init_loc_fn)
    stein = SteinVI(model, guide, Adam(1e-1), Trace_ELBO(), kernel)
    num_steps = 1
    stein.run(random.PRNGKey(0), num_steps, *data)
Example #5
0
def test_neutra_reparam_unobserved_model():
    """A NeuTra-reparametrised model can be executed with ``data=None``.

    The IAF guide is only initialised (no optimisation steps). Its initial
    params are used to build the reparametrisation.
    """
    observed = jnp.ones(10, dtype=jnp.int32)
    guide = AutoIAFNormal(dirichlet_categorical)
    svi = SVI(dirichlet_categorical, guide, Adam(1e-3), Trace_ELBO())
    init_state = svi.init(random.PRNGKey(0), observed)
    neutra = NeuTraReparam(guide, svi.get_params(init_state))
    reparam_model = neutra.reparam(dirichlet_categorical)
    # Run the model unobserved; it only needs to execute without error.
    with handlers.seed(rng_seed=0):
        reparam_model(data=None)
Example #6
0
def test_svgd_loss_and_grads():
    """The loss from ``SteinVI._svgd_loss_and_grads`` equals the plain ELBO loss.

    Both are evaluated at the same unconstrained params with the same key.
    The ELBO side constrains the params first via ``svi.constrain_fn``.
    """
    _, data, model = uniform_normal()
    guide = AutoDelta(model)
    elbo = Trace_ELBO()
    unconstrained_params = {
        "alpha_auto_loc": np.array([-1.2]),
        "loc_base_auto_loc": np.array([1.53]),
    }

    stein = SteinVI(model, guide, Adam(0.1), elbo, RBFKernel())
    stein.init(random.PRNGKey(0), *data)

    svi = SVI(model, guide, Adam(0.1), elbo)
    svi.init(random.PRNGKey(0), *data)

    loss_key = random.PRNGKey(1)
    constrained = svi.constrain_fn(unconstrained_params)
    expected_loss = elbo.loss(loss_key, constrained, model, guide, *data)
    stein_loss, _ = stein._svgd_loss_and_grads(loss_key, unconstrained_params,
                                               *data)
    assert expected_loss == stein_loss
Example #7
0
def test_param_size(length, depth, t):
    """``SteinVI._param_size`` counts every element of a nested container of arrays."""

    def wrap(value, levels):
        # Wrap ``value`` in ``levels`` layers of the container type ``t``.
        for _ in range(levels):
            value = t([value])
        return value

    seed = random.PRNGKey(nrandom.randint(0, 10_000))
    # Random array shapes with every dimension >= 1 (Poisson draw + 1).
    sizes = Poisson(5).sample(seed, (length, nrandom.randint(0, 10))) + 1
    total_size = sum(size.prod() for size in sizes)
    uparam = t(
        wrap(np.empty(tuple(size)), nrandom.randint(0, depth))
        for size in sizes)
    stein = SteinVI(id, id, Adam(1.0), Trace_ELBO(), RBFKernel())
    assert stein._param_size(uparam) == total_size, f"Failed for seed {seed}"
Example #8
0
def test_reparam_log_joint(model, kwargs):
    """The NeuTra-reparametrised potential energy matches the original one.

    For a latent ``x`` in the reparametrised space and ``y = transform(x)``,
    the check is ``pe_reparam(x) == pe(y) - log|det J(x, y)|``.
    """
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, Adam(1e-10), Trace_ELBO(), **kwargs)
    fitted_params = svi.get_params(svi.init(random.PRNGKey(0)))
    neutra = NeuTraReparam(guide, fitted_params)
    reparam_model = neutra.reparam(model)

    _, potential_fn, _, _ = initialize_model(
        random.PRNGKey(1), model, model_kwargs=kwargs)
    init_params, reparam_potential_fn, _, _ = initialize_model(
        random.PRNGKey(2), reparam_model, model_kwargs=kwargs)

    unconstrained = init_params[0]
    x = tuple(unconstrained.values())[0]
    reparam_pe = reparam_potential_fn(unconstrained)
    y = neutra.transform(x)
    log_det_jacobian = neutra.transform.log_abs_det_jacobian(x, y)
    original_pe = potential_fn(guide._unpack_latent(y))
    assert_allclose(reparam_pe, original_pe - log_det_jacobian)
Example #9
0
def test_apply_kernel(kernel, particles, particle_info, loss_fn, tparticles,
                      mode, kval):
    """``SteinVI._apply_kernel`` reproduces the reference value ``kval[mode]``.

    In "matrix" mode the reference matrix is first applied to the all-ones
    vector ``v``. Other modes compare against ``kval[mode]`` directly.
    """
    if mode not in kval:
        pytest.skip()
    (dim, ) = tparticles[0].shape
    kernel_obj = kernel(mode=mode)
    kernel_obj.init(random.PRNGKey(0), particles.shape)
    kernel_fn = kernel_obj.compute(particles, particle_info(dim), loss_fn)
    v = np.ones_like(kval[mode])
    stein = SteinVI(id, id, Adam(1.0), Trace_ELBO(), kernel(mode))
    value = stein._apply_kernel(kernel_fn, *tparticles, v)
    expected = np.dot(kval[mode], v) if mode == "matrix" else kval[mode]
    assert_allclose(value, expected, atol=1e-9)
Example #10
0
def main(_args):
    """Fit ``gmm`` with enumerated SVI (``RenyiELBO(-10.)``), showing loss progress.

    ``_args`` is accepted but unused.
    """
    data = generate_data()
    init_rng_key = PRNGKey(1273)

    # Trace the model once to guess its plate nesting depth. Enumeration
    # dims are then placed to the left of all plate dims.
    model_trace = trace(seed(gmm, init_rng_key)).get_trace(data)
    max_plate_nesting = _guess_max_plate_nesting(model_trace)
    first_enum_dim = -max_plate_nesting - 1
    enum_gmm = enum(config_enumerate(gmm), first_enum_dim)

    svi = SVI(enum_gmm, gmm_guide, Adam(0.1), RenyiELBO(-10.))
    svi_state = svi.init(init_rng_key, data)
    jitted_update = jax.jit(svi.update)

    with tqdm.trange(100_000) as progress:
        for _ in progress:
            svi_state, loss = jitted_update(svi_state, data)
            # Second positional argument is tqdm's ``refresh`` flag.
            progress.set_description(f"SVI {loss}", True)
Example #11
0
def test_auto_guide(auto_class, init_loc_fn, num_particles):
    """SteinVI initialises one copy of every auto-guide param per particle.

    A single (non-Stein) guide is traced to discover the param sites and
    their shapes. Each corresponding SteinVI param must then have a leading
    particle axis. Location params must be nonzero and all distinct. Scale
    params must start at 0.1; everything else starts at 0.
    """
    latent_dim = 3

    def model(obs):
        a = numpyro.sample("a", Normal(0, 1))
        return numpyro.sample("obs", Bernoulli(logits=a), obs=obs)

    obs = Bernoulli(0.5).sample(random.PRNGKey(0), (10, latent_dim))

    guide_key, stein_key = random.split(random.PRNGKey(0))
    inner_guide = auto_class(model, init_loc_fn=init_loc_fn())

    with handlers.seed(rng_seed=guide_key), handlers.trace() as inner_guide_tr:
        inner_guide(obs)

    steinvi = SteinVI(
        model,
        auto_class(model, init_loc_fn=init_loc_fn()),
        Adam(1.0),
        Trace_ELBO(),
        RBFKernel(),
        num_particles=num_particles,
    )
    init_params = steinvi.get_params(steinvi.init(stein_key, obs))

    for name, site in inner_guide_tr.items():
        if site.get("type") != "param":
            continue
        assert name in init_params
        init_value = init_params[name]
        expected_shape = (num_particles, *np.shape(site["value"]))
        assert init_value.shape == expected_shape
        if "auto_loc" in name or name == "b":
            # np.alltrue was removed in NumPy 2.0 (and from jax.numpy);
            # np.all is the supported equivalent.
            assert np.all(init_value != np.zeros(expected_shape))
            # Every particle/coordinate must get a distinct initial value.
            assert np.unique(init_value).shape == init_value.reshape(-1).shape
        elif "scale" in name:
            assert_array_approx_equal(init_value,
                                      np.full(expected_shape, 0.1))
        else:
            assert_array_approx_equal(init_value,
                                      np.full(expected_shape, 0.0))
Example #12
0
def fit_advi(model, num_iter, learning_rate=0.01, seed=0):
    """Fit ``model`` with ADVI using a mean-field (diagonal-covariance) Normal guide.

    The variational distribution ("guide" in Pyro terminology) is built
    automatically by ``AutoDiagonalNormal``. Its parameters are optimised
    with Adam for ``num_iter`` steps inside ``lax.scan``.

    Args:
        model: a NumPyro model function.
        num_iter: number of Adam optimisation steps.
        learning_rate: Adam step size (default: 0.01).
        seed: integer seed for the PRNG key used by ``svi.init`` (default: 0).

    Returns:
        An ``ADVIResults`` holding the ``SVI`` object, the guide, the final
        SVI state and the stacked per-step losses produced by ``lax.scan``.
    """
    init_key = random.PRNGKey(seed)
    optimizer = Adam(learning_rate)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optimizer, AutoContinuousELBO())
    initial_state = svi.init(init_key)

    def step(state, _):
        # svi.update returns (new_state, loss); scan carries the state and
        # stacks the losses into an array of length num_iter.
        return svi.update(state)

    final_state, loss_history = lax.scan(step, initial_state, np.zeros(num_iter))
    return ADVIResults(svi=svi, guide=guide, state=final_state, losses=loss_history)
Example #13
0
def test_get_params(kernel, auto_guide, init_loc_fn, problem):
    """SteinVI params share SVI's names and add a leading particle axis."""
    _, data, model = problem()
    guide = auto_guide(model, init_loc_fn=init_loc_fn)
    optim = Adam(1e-1)
    elbo = Trace_ELBO()

    stein = SteinVI(model, guide, optim, elbo, kernel)
    stein_state = stein.init(random.PRNGKey(0), *data)
    stein_params = stein.get_params(stein_state)

    svi = SVI(model, guide, optim, elbo)
    svi_state = svi.init(random.PRNGKey(0), *data)
    svi_params = svi.get_params(svi_state)
    assert svi_params.keys() == stein_params.keys()

    for name, svi_param in svi_params.items():
        # Same shape as stacking num_particles copies of the SVI param.
        expected_shape = (stein.num_particles, *svi_param.shape)
        assert stein_params[name].shape == expected_shape
Example #14
0
def train_model_no_dp(rng,
                      model,
                      guide,
                      data,
                      batch_size,
                      num_data,
                      num_epochs,
                      silent=False,
                      learning_rate=1e-3,
                      **kwargs):
    """Train ``model``/``guide`` with plain SVI (no differential privacy).

    All inputs are passed explicitly. Nothing is read from module globals.

    Args:
        rng: random key/state, converted with
            ``d3p.random.convert_to_jax_rng_key`` before use.
        model: the model function.
        guide: the guide (variational distribution) function.
        data: training data forwarded to ``_train_model``.
        batch_size: minibatch size forwarded to ``_train_model``.
        num_data: total number of observations; passed to SVI as
            ``num_obs_total``.
        num_epochs: number of training epochs forwarded to ``_train_model``.
        silent: forwarded to ``_train_model``.
        learning_rate: Adam step size (default: 1e-3, the value previously
            hard-coded here). Previously a ``learning_rate`` keyword landed
            in ``**kwargs`` and was silently ignored.
        **kwargs: accepted and ignored. This presumably lets callers pass the
            same keyword arguments (e.g. DP settings) as to the DP path.

    Returns:
        Whatever ``_train_model`` returns.
    """
    optimizer = Adam(learning_rate)

    svi = SVI(model, guide, optimizer, Trace_ELBO(), num_obs_total=num_data)

    # Local import: only the non-DP path needs the debug RNG suite.
    import d3p.random.debug
    return _train_model(d3p.random.convert_to_jax_rng_key(rng),
                        d3p.random.debug, svi, data, batch_size, num_data,
                        num_epochs, silent)
Example #15
0
def test_calc_particle_info(num_params, num_particles):
    """``SteinVI._calc_particle_info`` assigns each param a contiguous slice.

    Every param name maps to the same random tuple of arrays. The expected
    (start, end) for the i-th name is therefore ``(i * s, (i + 1) * s)``,
    where ``s`` is the total element count integer-divided by
    ``num_particles``.
    """
    seed = random.PRNGKey(nrandom.randint(0, 10_000))
    sizes = Poisson(5).sample(seed, (100, nrandom.randint(0, 10))) + 1

    uparam = tuple(np.empty(tuple(size)) for size in sizes)
    uparams = dict.fromkeys(
        (string.ascii_lowercase[i] for i in range(num_params)), uparam)

    par_param_size = sum(size.prod() for size in sizes) // num_particles
    starts = par_param_size * np.arange(num_params)
    ends = par_param_size * np.arange(1, num_params + 1)
    names = string.ascii_lowercase[:num_params]
    expected_pinfo = dict(zip(names, zip(starts, ends)))

    stein = SteinVI(id, id, Adam(1.0), Trace_ELBO(), RBFKernel())
    pinfo = stein._calc_particle_info(uparams, num_particles)

    for name, info in pinfo.items():
        assert info == expected_pinfo[name], f"Failed for seed {seed}"
Example #16
0
def test_neals_funnel_smoke():
    """Smoke test: NUTS runs on a NeuTra-reparametrised Neal's funnel.

    The IAF guide is trained for 1000 steps with a negligible step size
    (1e-10). The samples of ``auto_shared_latent`` are then mapped back to
    the original ``x`` and ``y`` sites.
    """
    dim = 10

    guide = AutoIAFNormal(neals_funnel)
    svi = SVI(neals_funnel, guide, Adam(1e-10), Trace_ELBO())
    init_state = svi.init(random.PRNGKey(0), dim)

    def body_fn(_step, state):
        new_state, _loss = svi.update(state, dim)
        return new_state

    final_state = lax.fori_loop(0, 1000, body_fn, init_state)
    neutra = NeuTraReparam(guide, svi.get_params(final_state))
    reparam_model = neutra.reparam(neals_funnel)

    mcmc = MCMC(NUTS(reparam_model), num_warmup=50, num_samples=50)
    mcmc.run(random.PRNGKey(1), dim)
    latent_samples = mcmc.get_samples()['auto_shared_latent']
    transformed_samples = neutra.transform_sample(latent_samples)
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
Example #17
0
# ===================
# Model
# ===================
# Sparse GP model. SparseGP is defined elsewhere; it is presumably a NumPyro
# model function taking (X, y) — confirm against its definition.
sgp_model = SparseGP

# AutoDelta guide: a point-mass variational family, so SVI yields point
# estimates of the latent sites rather than a full posterior.
delta_guide = AutoDelta(SparseGP)

# ===================
# Optimization
# ===================
# NOTE(review): n_epochs is passed to svi.run as its step count. Each step
# receives the full X, y.T below (no minibatching here), so steps and epochs
# coincide.
n_epochs = 1_000
lr = 0.01
optimizer = Adam(step_size=lr)

# ===================
# Training
# ===================
# Fixed seed for reproducibility.
rng_key = random.PRNGKey(42)

# Set up SVI with the standard Trace_ELBO objective.
svi = SVI(sgp_model, delta_guide, optimizer, loss=Trace_ELBO())

# Run SVI. X and y are defined earlier in the file, outside this view.
# NOTE(review): y is passed transposed; presumably SparseGP expects that
# orientation — confirm against its signature.
svi_results = svi.run(rng_key, n_epochs, X, y.T)

# ===================
# Plot Loss