import time

import matplotlib.pyplot as plt
import numpy as np
import pytest
from numpy.testing import assert_allclose

import jax.numpy as jnp
from jax import grad, hessian, jacrev, random, vmap

import numpyro
import numpyro.distributions as dist
from numpyro.infer import (
    HMC,
    HMCECS,
    MCMC,
    NUTS,
    SA,
    SVI,
    Trace_ELBO,
    autoguide,
    init_to_value,
)
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.infer.reparam import NeuTraReparam
from numpyro.infer.util import log_density


# Example parametrization (values are illustrative); HMCECS accepts HMC or
# NUTS as its inner kernel.
@pytest.mark.parametrize("kernel_cls", [HMC, NUTS])
@pytest.mark.parametrize("num_block", [1, 2, 50])
@pytest.mark.parametrize("subsample_size", [50, 5000])
def test_hmcecs_normal_normal(kernel_cls, num_block, subsample_size):
    true_loc = jnp.array([0.3, 0.1, 0.9])
    num_warmup, num_samples = 200, 200
    data = true_loc + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        random.PRNGKey(1), (10000,)
    )

    def model(data, subsample_size):
        mean = numpyro.sample("mean", dist.Normal().expand((3,)).to_event(1))
        with numpyro.plate(
            "batch", data.shape[0], dim=-2, subsample_size=subsample_size
        ):
            sub_data = numpyro.subsample(data, 0)
            numpyro.sample("obs", dist.Normal(mean, 1), obs=sub_data)

    ref_params = {
        "mean": true_loc + dist.Normal(true_loc, 5e-2).sample(random.PRNGKey(0))
    }
    # Second-order Taylor proxy around the reference parameters.
    proxy_fn = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn, num_blocks=num_block)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

    mcmc.run(random.PRNGKey(0), data, subsample_size)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples["mean"], axis=0), true_loc, atol=0.1)
    assert len(samples["mean"]) == num_samples

@pytest.mark.parametrize("kernel_cls", [HMC, NUTS])  # example parametrization
def test_estimate_likelihood(kernel_cls):
    data_key, tr_key, sub_key, rng_key = random.split(random.PRNGKey(0), 4)
    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = 0.1
    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        data_key, (10_000,)
    )
    n, _ = data.shape
    num_warmup = 200
    num_samples = 200
    num_blocks = 20

    def model(data):
        mean = numpyro.sample(
            "mean", dist.Normal(ref_params, jnp.ones_like(ref_params))
        )
        with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-2) as idx:
            numpyro.sample("obs", dist.Normal(mean, sigma), obs=data[idx])

    proxy_fn = HMCECS.taylor_proxy({"mean": ref_params})
    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn, num_blocks=num_blocks)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

    mcmc.run(random.PRNGKey(0), data, extra_fields=["hmc_state.potential_energy"])

    pes = mcmc.get_extra_fields()["hmc_state.potential_energy"]
    samples = mcmc.get_samples()
    # Evaluate the full-data log density at each sample by forcing the plate
    # to use all n indices instead of a subsample.
    pes_full = vmap(
        lambda sample: log_density(
            model, (data,), {}, {**sample, **{"N": jnp.arange(n)}}
        )[0]
    )(samples)

    # exp(-pes - pes_full) is the ratio of the subsample-estimated likelihood
    # exp(-pes) to the full-data likelihood exp(pes_full); a good proxy keeps
    # its variance small.
    assert jnp.var(jnp.exp(-pes - pes_full)) < 1.0

def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
    svi_key, mcmc_key = random.split(hmcecs_key)

    # Find reference parameters for the second-order Taylor expansion used to
    # estimate the likelihood (taylor_proxy).
    optimizer = numpyro.optim.Adam(step_size=1e-3)
    guide = autoguide.AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    svi_result = svi.run(svi_key, args.num_svi_steps, data, obs, args.subsample_size)
    params, losses = svi_result.params, svi_result.losses
    ref_params = {"theta": params["theta_auto_loc"]}

    # The Taylor proxy estimates the log likelihood (ll) around ref_params by
    #   taylor_expansion(ll, theta_curr)
    #     + sum_{i in subsample} [ ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr) ]
    proxy = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=proxy)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)

    mcmc.run(mcmc_key, data, obs, args.subsample_size)
    mcmc.print_summary()
    return losses, mcmc.get_samples()

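# Illustrative sketch (not NumPyro's internal implementation): the difference
# estimator described in the comment above, for a toy 1-D Gaussian likelihood.
# `taylor_proxy_sketch`, `theta_ref`, `theta_curr`, and `toy_data` are
# hypothetical names introduced here; the (n/m) scaling of the subsample
# correction, which the comment above omits, keeps the correction on the
# scale of the full-data sum.
def taylor_proxy_sketch():
    n, m = 1000, 100  # dataset size and subsample size
    toy_data = dist.Normal(0.0, 1.0).sample(random.PRNGKey(0), (n,))
    theta_ref, theta_curr = 0.1, 0.3
    idx = random.randint(random.PRNGKey(1), (m,), 0, n)

    def ll_i(theta, x):
        # per-datapoint log likelihood
        return dist.Normal(theta, 1.0).log_prob(x)

    def taylor_i(theta, x):
        # second-order Taylor expansion of each ll_i around theta_ref
        d = theta - theta_ref
        g = vmap(lambda xi: grad(ll_i)(theta_ref, xi))(x)
        h = vmap(lambda xi: grad(grad(ll_i))(theta_ref, xi))(x)
        return ll_i(theta_ref, x) + g * d + 0.5 * h * d**2

    # full-data Taylor term plus the subsample correction
    full_taylor = taylor_i(theta_curr, toy_data).sum()
    correction = (n / m) * (
        ll_i(theta_curr, toy_data[idx]) - taylor_i(theta_curr, toy_data[idx])
    ).sum()
    estimate = full_taylor + correction
    exact = ll_i(theta_curr, toy_data).sum()
    # estimate tracks exact when theta_curr stays near theta_ref
    return estimate, exact
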
@pytest.mark.parametrize("subsample_size", [5, 10])  # example values
def test_taylor_proxy_norm(subsample_size):
    data_key, tr_key, rng_key = random.split(random.PRNGKey(0), 3)
    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = 0.1
    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        data_key, (100,)
    )
    n, _ = data.shape

    def model(data, subsample_size):
        mean = numpyro.sample(
            "mean", dist.Normal(ref_params, jnp.ones_like(ref_params))
        )
        with numpyro.plate(
            "data", data.shape[0], subsample_size=subsample_size, dim=-2
        ) as idx:
            numpyro.sample("obs", dist.Normal(mean, sigma), obs=data[idx])

    def log_prob_fn(params):
        return vmap(dist.Normal(params, sigma).log_prob)(data).sum(-1)

    log_prob = log_prob_fn(ref_params)
    log_norm_jac = jacrev(log_prob_fn)(ref_params)
    log_norm_hessian = hessian(log_prob_fn)(ref_params)

    tr = numpyro.handlers.trace(numpyro.handlers.seed(model, tr_key)).get_trace(
        data, subsample_size
    )
    plate_sizes = {"data": (n, subsample_size)}

    proxy_constructor = HMCECS.taylor_proxy({"mean": ref_params})
    proxy_fn, gibbs_init, gibbs_update = proxy_constructor(
        tr, plate_sizes, model, (data, subsample_size), {}
    )

    def taylor_expand_2nd_order(idx, pos):
        return (
            log_prob[idx]
            + log_norm_jac[idx] @ pos
            + 0.5 * (pos @ log_norm_hessian[idx]) @ pos
        )

    def taylor_expand_2nd_order_sum(pos):
        return (
            log_prob.sum()
            + log_norm_jac.sum(0) @ pos
            + 0.5 * pos @ log_norm_hessian.sum(0) @ pos
        )

    for _ in range(5):
        split_key, perturb_key, rng_key = random.split(rng_key, 3)
        perturb_params = ref_params + dist.Normal(0.1, 0.1).sample(
            perturb_key, ref_params.shape
        )
        subsample_idx = random.randint(rng_key, (subsample_size,), 0, n)
        gibbs_site = {"data": subsample_idx}
        proxy_state = gibbs_init(None, gibbs_site)
        # The proxy takes the current latent values keyed by the sample site
        # name ("mean"), not by the plate name.
        actual_proxy_sum, actual_proxy_sub = proxy_fn(
            {"mean": perturb_params}, ["data"], proxy_state
        )
        assert_allclose(
            actual_proxy_sub["data"],
            taylor_expand_2nd_order(subsample_idx, perturb_params - ref_params),
            rtol=1e-5,
        )
        assert_allclose(
            actual_proxy_sum["data"],
            taylor_expand_2nd_order_sum(perturb_params - ref_params),
            rtol=1e-5,
        )

def test_hmcecs_multiple_plates():
    true_loc = jnp.array([0.3, 0.1, 0.9])
    num_warmup, num_samples = 2, 2
    data = true_loc + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        random.PRNGKey(1), (1000,)
    )

    def model(data):
        mean = numpyro.sample("mean", dist.Normal().expand((3,)).to_event(1))
        with numpyro.plate("batch", data.shape[0], dim=-2, subsample_size=10):
            sub_data = numpyro.subsample(data, 0)
            with numpyro.plate("dim", 3):
                numpyro.sample("obs", dist.Normal(mean, 1), obs=sub_data)

    ref_params = {
        "mean": true_loc + dist.Normal(true_loc, 5e-2).sample(random.PRNGKey(0))
    }
    proxy_fn = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(NUTS(model), proxy=proxy_fn)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), data)

def benchmark_hmc(args, features, labels):
    rng_key = random.PRNGKey(1)
    start = time.time()
    # MAP estimate from
    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
    ref_params = {
        "coefs": jnp.array(
            [
                2.03420663e00, -3.53567265e-02, -1.49223924e-01, -3.07049364e-01,
                -1.00028366e-01, -1.46827862e-01, -1.64167881e-01, -4.20344204e-01,
                9.47479829e-02, -1.12681836e-02, 2.64442056e-01, -1.22087866e-01,
                -6.00568838e-02, -3.79419506e-01, -1.06668741e-01, -2.97053963e-01,
                -2.05253899e-01, -4.69537191e-02, -2.78072730e-02, -1.43250525e-01,
                -6.77954629e-02, -4.34899796e-03, 5.90927452e-02, 7.23133609e-02,
                1.38526391e-02, -1.24497898e-01, -1.50733739e-02, -2.68872194e-02,
                -1.80925727e-02, 3.47936489e-02, 4.03552800e-02, -9.98773426e-03,
                6.20188080e-02, 1.15002751e-01, 1.32145107e-01, 2.69109547e-01,
                2.45785132e-01, 1.19035013e-01, -2.59744357e-02, 9.94279515e-04,
                3.39266285e-02, -1.44057125e-02, -6.95222765e-02, -7.52013028e-02,
                1.21171586e-01, 2.29205526e-02, 1.47308692e-01, -8.34354162e-02,
                -9.34122875e-02, -2.97472421e-02, -3.03937674e-01, -1.70958012e-01,
                -1.59496680e-01, -1.88516974e-01, -1.20889175e00,
            ]
        )
    }
    if args.algo == "HMC":
        step_size = jnp.sqrt(0.5 / features.shape[0])
        trajectory_length = step_size * args.num_steps
        kernel = HMC(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            adapt_step_size=False,
            dense_mass=args.dense_mass,
        )
        subsample_size = None
    elif args.algo == "NUTS":
        kernel = NUTS(model, dense_mass=args.dense_mass)
        subsample_size = None
    elif args.algo == "HMCECS":
        subsample_size = 1000
        inner_kernel = NUTS(
            model,
            init_strategy=init_to_value(values=ref_params),
            dense_mass=args.dense_mass,
        )
        # note: with num_blocks=100 and subsample_size=1000, each MCMC step
        # updates 10 indices, so it takes 50,000 MCMC steps to cycle through
        # the whole dataset
        kernel = HMCECS(
            inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(ref_params)
        )
    elif args.algo == "SA":
        # NB: this kernel requires large num_warmup and num_samples,
        # and running on GPU is much faster than on CPU
        kernel = SA(
            model, adapt_state_size=1000, init_strategy=init_to_value(values=ref_params)
        )
        subsample_size = None
    elif args.algo == "FlowHMCECS":
        subsample_size = 1000
        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
        params, losses = svi_result.params, svi_result.losses
        plt.plot(losses)
        plt.show()

        neutra = NeuTraReparam(guide, params)
        neutra_model = neutra.reparam(model)
        neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)}
        # no need to adapt the mass matrix if the flow does a good job
        inner_kernel = NUTS(
            neutra_model,
            init_strategy=init_to_value(values=neutra_ref_params),
            adapt_mass_matrix=False,
        )
        kernel = HMCECS(
            inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(neutra_ref_params)
        )
    else:
        raise ValueError(
            "Invalid algorithm; choose one of 'HMC', 'NUTS', 'HMCECS', 'SA',"
            " or 'FlowHMCECS'."
        )

    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",))
    print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
    mcmc.print_summary(exclude_deterministic=False)
    print("\nMCMC elapsed time:", time.time() - start)
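
# Hypothetical driver sketch: benchmark_hmc expects an `args` namespace plus
# the (features, labels) arrays from the surrounding script. The flag names
# below mirror the attributes used above (algo, num_warmup, num_samples,
# num_steps, dense_mass) but are assumptions, not the script's actual CLI,
# and `load_dataset_somehow` stands in for the real data-loading code.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="HMCECS benchmark sketch")
    parser.add_argument(
        "--algo",
        default="HMCECS",
        choices=["HMC", "NUTS", "HMCECS", "SA", "FlowHMCECS"],
    )
    parser.add_argument("--num-warmup", type=int, default=500)
    parser.add_argument("--num-samples", type=int, default=500)
    parser.add_argument("--num-steps", type=int, default=10)
    parser.add_argument("--dense-mass", action="store_true")
    args = parser.parse_args()

    features, labels = load_dataset_somehow()  # placeholder, defined elsewhere
    benchmark_hmc(args, features, labels)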