def fit_advi(model, num_iter, learning_rate=0.01, seed=0):
    """Fit a model with Automatic Differentiation Variational Inference (ADVI).

    The variational family is a diagonal-covariance (mean-field) Normal,
    constructed automatically from the model.

    Args:
        model: a NumPyro model function
        num_iter: number of Adam gradient-descent steps to run
        learning_rate: Adam step size (default: 0.01)
        seed: PRNG seed (default: 0)

    Returns:
        ADVIResults holding the SVI object, guide, final state and losses.
    """
    key = random.PRNGKey(seed)
    optimizer = Adam(learning_rate)
    # The "guide" (Pyro terminology) is the auto-built variational distribution.
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optimizer, AutoContinuousELBO())
    state = svi.init(key)
    # One scan step per SVI update; the scanned zeros only fix the length.
    final_state, losses = lax.scan(lambda s, _: svi.update(s), state,
                                   np.zeros(num_iter))
    return ADVIResults(svi=svi, guide=guide, state=final_state, losses=losses)
def train_model(rng, rng_suite, model, guide, data, batch_size, num_data,
                dp_scale, num_epochs, clipping_threshold=1.):
    """Train a model with differentially-private SVI (DPSVI).

    Builds a DPSVI instance (Adam(1e-3) optimizer, Trace_ELBO loss) and
    delegates the actual optimization loop to ``_train_model``.
    """
    dp_svi = DPSVI(
        model, guide, Adam(1e-3), Trace_ELBO(),
        num_obs_total=num_data,
        clipping_threshold=clipping_threshold,
        dp_scale=dp_scale,
        rng_suite=rng_suite,
    )
    return _train_model(rng, rng_suite, dp_svi, data, batch_size, num_data,
                        num_epochs)
def run_inference(model, inputs, method=None):
    """Run posterior inference for ``model``.

    ``method is None`` selects NUTS/MCMC; any other value selects SVI with
    an auto-diagonal-normal guide. Returns the posterior samples dict.
    """
    if method is None:
        # --- NUTS/MCMC path ---
        num_samples = 5000
        logger.info('NUTS sampling')
        mcmc = MCMC(NUTS(model), num_warmup=300, num_samples=num_samples)
        mcmc.run(random.PRNGKey(0), **inputs,
                 extra_fields=('potential_energy', ))
        logger.info(r'MCMC summary for: {}'.format(model.__name__))
        mcmc.print_summary(exclude_deterministic=False)
        return mcmc.get_samples()

    # --- SVI path ---
    logger.info('Guide generation...')
    rng_key = random.PRNGKey(0)
    guide = AutoDiagonalNormal(model=model)
    logger.info('Optimizer generation...')
    optim = Adam(0.05)
    logger.info('SVI generation...')
    svi = SVI(model, guide, optim, AutoContinuousELBO(), **inputs)
    init_state = svi.init(rng_key)
    logger.info('Scan...')
    # 2000 fixed-length optimization steps driven by lax.scan.
    state, loss = lax.scan(lambda s, _: svi.update(s), init_state,
                           np.zeros(2000))
    params = svi.get_params(state)
    samples = guide.sample_posterior(random.PRNGKey(1), params, (1000, ))
    logger.info(r'SVI summary for: {}'.format(model.__name__))
    numpyro.diagnostics.print_summary(samples, prob=0.90, group_by_chain=False)
    return samples
def test_steinvi_smoke(kernel, auto_guide, init_loc_fn, problem):
    # Smoke test: a single SteinVI run step on the problem must not raise.
    true_coefs, data, model = problem()
    guide = auto_guide(model, init_loc_fn=init_loc_fn)
    stein = SteinVI(model, guide, Adam(1e-1), Trace_ELBO(), kernel)
    stein.run(random.PRNGKey(0), 1, *data)
def test_neutra_reparam_unobserved_model():
    # NeuTra reparameterization must still work when the data site is
    # left unobserved (data=None) at call time.
    model = dirichlet_categorical
    data = jnp.ones(10, dtype=jnp.int32)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, Adam(1e-3), Trace_ELBO())
    params = svi.get_params(svi.init(random.PRNGKey(0), data))
    reparam_model = NeuTraReparam(guide, params).reparam(model)
    with handlers.seed(rng_seed=0):
        reparam_model(data=None)
def test_svgd_loss_and_grads():
    # SteinVI's SVGD loss must agree exactly with the plain ELBO loss
    # evaluated at the same (unconstrained) particle parameters.
    true_coefs, data, model = uniform_normal()
    guide = AutoDelta(model)
    loss = Trace_ELBO()
    stein_uparams = {
        "alpha_auto_loc": np.array([-1.2]),
        "loc_base_auto_loc": np.array([1.53]),
    }
    stein = SteinVI(model, guide, Adam(0.1), loss, RBFKernel())
    stein.init(random.PRNGKey(0), *data)
    svi = SVI(model, guide, Adam(0.1), loss)
    svi.init(random.PRNGKey(0), *data)
    expected_loss = loss.loss(random.PRNGKey(1),
                              svi.constrain_fn(stein_uparams), model, guide,
                              *data)
    stein_loss, stein_grad = stein._svgd_loss_and_grads(
        random.PRNGKey(1), stein_uparams, *data)
    assert expected_loss == stein_loss
def test_param_size(length, depth, t):
    # _param_size must count every scalar inside arbitrarily nested
    # containers of type ``t``.
    def nest(value, levels):
        # Wrap ``value`` in ``levels`` layers of the container type ``t``.
        return value if levels == 0 else nest(t([value]), levels - 1)

    seed = random.PRNGKey(nrandom.randint(0, 10_000))
    sizes = Poisson(5).sample(seed, (length, nrandom.randint(0, 10))) + 1
    total_size = sum(size.prod() for size in sizes)
    uparam = t(
        nest(np.empty(tuple(size)), nrandom.randint(0, depth))
        for size in sizes)
    stein = SteinVI(id, id, Adam(1.0), Trace_ELBO(), RBFKernel())
    assert stein._param_size(uparam) == total_size, f"Failed for seed {seed}"
def test_reparam_log_joint(model, kwargs):
    # Change-of-variables check: the potential energy of the NeuTra-
    # reparameterized model must equal the original model's potential energy
    # minus the log-|det Jacobian| of the guide's transform.
    guide = AutoIAFNormal(model)
    # Tiny learning rate: we only need initialized params, not a trained guide.
    svi = SVI(model, guide, Adam(1e-10), Trace_ELBO(), **kwargs)
    svi_state = svi.init(random.PRNGKey(0))
    params = svi.get_params(svi_state)
    neutra = NeuTraReparam(guide, params)
    reparam_model = neutra.reparam(model)
    # Potential-energy functions for the original and reparameterized models.
    _, pe_fn, _, _ = initialize_model(random.PRNGKey(1), model, model_kwargs=kwargs)
    init_params, pe_fn_neutra, _, _ = initialize_model(random.PRNGKey(2), reparam_model, model_kwargs=kwargs)
    # The reparameterized model exposes a single latent site; take its value.
    latent_x = list(init_params[0].values())[0]
    pe_transformed = pe_fn_neutra(init_params[0])
    # Map the base latent through the guide's flow and track the Jacobian.
    latent_y = neutra.transform(latent_x)
    log_det_jacobian = neutra.transform.log_abs_det_jacobian(latent_x, latent_y)
    pe = pe_fn(guide._unpack_latent(latent_y))
    assert_allclose(pe_transformed, pe - log_det_jacobian)
def test_apply_kernel(kernel, particles, particle_info, loss_fn, tparticles,
                      mode, kval):
    # _apply_kernel must reproduce the precomputed kernel value for the mode.
    if mode not in kval:
        pytest.skip()
    (d, ) = tparticles[0].shape
    kernel_fn = kernel(mode=mode)
    kernel_fn.init(random.PRNGKey(0), particles.shape)
    compute = kernel_fn.compute(particles, particle_info(d), loss_fn)
    v = np.ones_like(kval[mode])
    stein = SteinVI(id, id, Adam(1.0), Trace_ELBO(), kernel(mode))
    value = stein._apply_kernel(compute, *tparticles, v)
    expected = copy(kval)
    if mode == "matrix":
        # Matrix-valued kernels act on v by matrix-vector product.
        expected[mode] = np.dot(expected[mode], v)
    assert_allclose(value, expected[mode], atol=1e-9)
def main(_args):
    """Fit the Gaussian mixture model with enumeration-aware SVI."""
    data = generate_data()
    init_rng_key = PRNGKey(1273)
    # (MCMC alternative kept for reference:)
    # nuts = NUTS(gmm)
    # mcmc = MCMC(nuts, 100, 1000)
    # mcmc.print_summary()
    seeded_gmm = seed(gmm, init_rng_key)
    model_trace = trace(seeded_gmm).get_trace(data)
    max_plate_nesting = _guess_max_plate_nesting(model_trace)
    # Enumerate discrete latents on a dim just beyond the deepest plate.
    enum_gmm = enum(config_enumerate(gmm), -max_plate_nesting - 1)
    svi = SVI(enum_gmm, gmm_guide, Adam(0.1), RenyiELBO(-10.))
    svi_state = svi.init(init_rng_key, data)
    upd_fun = jax.jit(svi.update)
    with tqdm.trange(100_000) as pbar:
        for _ in pbar:
            svi_state, loss = upd_fun(svi_state, data)
            pbar.set_description(f"SVI {loss}", True)
def test_auto_guide(auto_class, init_loc_fn, num_particles):
    """SteinVI must replicate every auto-guide parameter once per particle.

    Each parameter traced from a plain auto-guide must appear in the SteinVI
    params with a leading ``num_particles`` axis; location params must be
    randomly initialized (non-zero, all values distinct), scale params start
    at 0.1, and everything else at 0.
    """
    latent_dim = 3

    def model(obs):
        a = numpyro.sample("a", Normal(0, 1))
        return numpyro.sample("obs", Bernoulli(logits=a), obs=obs)

    obs = Bernoulli(0.5).sample(random.PRNGKey(0), (10, latent_dim))
    rng_key = random.PRNGKey(0)
    guide_key, stein_key = random.split(rng_key)
    # Trace a standalone guide to learn the expected parameter sites.
    inner_guide = auto_class(model, init_loc_fn=init_loc_fn())
    with handlers.seed(rng_seed=guide_key), handlers.trace() as inner_guide_tr:
        inner_guide(obs)
    steinvi = SteinVI(
        model,
        auto_class(model, init_loc_fn=init_loc_fn()),
        Adam(1.0),
        Trace_ELBO(),
        RBFKernel(),
        num_particles=num_particles,
    )
    state = steinvi.init(stein_key, obs)
    init_params = steinvi.get_params(state)
    for name, site in inner_guide_tr.items():
        if site.get("type") == "param":
            assert name in init_params
            inner_param = site
            init_value = init_params[name]
            expected_shape = (num_particles, *np.shape(inner_param["value"]))
            assert init_value.shape == expected_shape
            if "auto_loc" in name or name == "b":
                # np.alltrue was deprecated in NumPy 1.25 and removed in
                # NumPy 2.0; np.all is the drop-in equivalent.
                assert np.all(init_value != np.zeros(expected_shape))
                assert np.unique(init_value).shape == init_value.reshape(
                    -1).shape
            elif "scale" in name:
                assert_array_approx_equal(init_value,
                                          np.full(expected_shape, 0.1))
            else:
                assert_array_approx_equal(init_value,
                                          np.full(expected_shape, 0.0))
def fit_advi(model, num_iter, learning_rate=0.01, seed=0):
    """Run ADVI with a mean-field (diagonal-covariance) Normal guide.

    Args:
        model: NumPyro model function.
        num_iter: number of Adam optimization steps.
        learning_rate: Adam step size.
        seed: PRNG seed.

    Returns:
        ADVIResults with the SVI object, guide, final state and loss trace.
    """
    guide = AutoDiagonalNormal(model)  # auto-built variational distribution
    svi = SVI(model, guide, Adam(learning_rate), AutoContinuousELBO())
    initial_state = svi.init(random.PRNGKey(seed))

    def step(state, _):
        # svi.update returns (new_state, loss) — exactly scan's contract.
        return svi.update(state)

    last_state, losses = lax.scan(step, initial_state, np.zeros(num_iter))
    return ADVIResults(svi=svi, guide=guide, state=last_state, losses=losses)
def test_get_params(kernel, auto_guide, init_loc_fn, problem):
    # SteinVI params must mirror SVI params, each with an extra leading
    # particle axis of length num_particles.
    _, data, model = problem()
    guide = auto_guide(model, init_loc_fn=init_loc_fn)
    optim = Adam(1e-1)
    elbo = Trace_ELBO()
    stein = SteinVI(model, guide, optim, elbo, kernel)
    stein_params = stein.get_params(stein.init(random.PRNGKey(0), *data))
    svi = SVI(model, guide, optim, elbo)
    svi_params = svi.get_params(svi.init(random.PRNGKey(0), *data))
    assert svi_params.keys() == stein_params.keys()
    for name, svi_param in svi_params.items():
        tiled = np.repeat(svi_param[None, ...], stein.num_particles, axis=0)
        assert stein_params[name].shape == tiled.shape
def train_model_no_dp(rng, model, guide, data, batch_size, num_data,
                      num_epochs, silent=False, **kwargs):
    """Train a model with plain (non-private) SVI.

    This deliberately bypasses differential privacy; it serves as the
    non-DP baseline for the DPSVI training routine.
    """
    svi = SVI(model, guide, Adam(1e-3), Trace_ELBO(), num_obs_total=num_data)
    import d3p.random.debug
    # The debug RNG suite stands in for the DP-secure one.
    jax_rng = d3p.random.convert_to_jax_rng_key(rng)
    return _train_model(jax_rng, d3p.random.debug, svi, data, batch_size,
                        num_data, num_epochs, silent)
def test_calc_particle_info(num_params, num_particles):
    # _calc_particle_info must assign each named parameter a contiguous
    # (start, end) slice of the flattened per-particle vector.
    seed = random.PRNGKey(nrandom.randint(0, 10_000))
    sizes = Poisson(5).sample(seed, (100, nrandom.randint(0, 10))) + 1
    uparam = tuple(np.empty(tuple(size)) for size in sizes)
    uparams = {string.ascii_lowercase[i]: uparam for i in range(num_params)}
    per_param = sum(size.prod() for size in sizes) // num_particles
    starts = per_param * np.arange(num_params)
    ends = per_param * np.arange(1, num_params + 1)
    expected_pinfo = dict(
        zip(string.ascii_lowercase[:num_params], zip(starts, ends)))
    stein = SteinVI(id, id, Adam(1.0), Trace_ELBO(), RBFKernel())
    pinfo = stein._calc_particle_info(uparams, num_particles)
    for key in pinfo:
        assert pinfo[key] == expected_pinfo[key], f"Failed for seed {seed}"
def test_neals_funnel_smoke():
    # End-to-end smoke test: train an IAF guide on Neal's funnel, then run
    # NUTS on the NeuTra-reparameterized model and map samples back.
    dim = 10
    guide = AutoIAFNormal(neals_funnel)
    svi = SVI(neals_funnel, guide, Adam(1e-10), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0), dim)

    def body_fn(_, state):
        new_state, _loss = svi.update(state, dim)
        return new_state

    svi_state = lax.fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    neutra = NeuTraReparam(guide, params)
    model = neutra.reparam(neals_funnel)
    mcmc = MCMC(NUTS(model), num_warmup=50, num_samples=50)
    mcmc.run(random.PRNGKey(1), dim)
    samples = mcmc.get_samples()
    transformed_samples = neutra.transform_sample(
        samples['auto_shared_latent'])
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
# =================== # Model # =================== # GP model sgp_model = SparseGP # delta guide - basically deterministic delta_guide = AutoDelta(SparseGP) # =================== # Optimization # =================== n_epochs = 1_000 lr = 0.01 optimizer = Adam(step_size=lr) # =================== # Training # =================== # reproducibility rng_key = random.PRNGKey(42) # setup svi svi = SVI(sgp_model, delta_guide, optimizer, loss=Trace_ELBO()) # run svi svi_results = svi.run(rng_key, n_epochs, X, y.T) # =================== # Plot Loss