def test_svgd_loss_and_grads():
    # The SteinVI (SVGD) loss for the given particle params should equal the
    # ELBO loss computed through plain SVI with the same params.
    true_coefs, data, model = uniform_normal()
    guide = AutoDelta(model)
    loss = Trace_ELBO()
    stein_uparams = {
        "alpha_auto_loc": np.array([-1.2]),
        "loc_base_auto_loc": np.array([1.53]),
    }
    stein = SteinVI(model, guide, Adam(0.1), loss, RBFKernel())
    stein.init(random.PRNGKey(0), *data)
    svi = SVI(model, guide, Adam(0.1), loss)
    svi.init(random.PRNGKey(0), *data)
    expected_loss = loss.loss(
        random.PRNGKey(1), svi.constrain_fn(stein_uparams), model, guide, *data
    )
    stein_loss, stein_grad = stein._svgd_loss_and_grads(
        random.PRNGKey(1), stein_uparams, *data
    )
    assert expected_loss == stein_loss
# Parametrization values below are assumed for illustration; the test only
# requires a positive scale and a boolean subsample flag.
@pytest.mark.parametrize("scale", [1.0, 2.0])
@pytest.mark.parametrize("subsample", [False, True])
def test_subsample_gradient(scale, subsample):
    # Losses and gradients accumulated over minibatches should match the
    # full-data ELBO and the analytic gradients, up to Monte Carlo error.
    data = jnp.array([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    precision = 0.06 * scale

    def model(subsample):
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size) as ind:
                x = data[ind]
                z = numpyro.sample("z", dist.Normal(0, 1))
                numpyro.sample("x", dist.Normal(z, 1), obs=x)

    def guide(subsample):
        scale = numpyro.param("scale", 1.0)
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size):
                loc = numpyro.param("loc", jnp.zeros(len(data)), event_dim=0)
                numpyro.sample("z", dist.Normal(loc, scale))

    if scale != 1.0:
        model = handlers.scale(model, scale=scale)
        guide = handlers.scale(guide, scale=scale)

    num_particles = 50000
    optimizer = optim.Adam(0.1)
    elbo = Trace_ELBO(num_particles=num_particles)
    svi = SVI(model, guide, optimizer, loss=elbo)
    svi_state = svi.init(random.PRNGKey(0), None)
    params = svi.optim.get_params(svi_state.optim_state)
    normalizer = 2 if subsample else 1
    if subsample_size == 1:
        # Sum the loss and gradients over the two possible minibatches.
        subsample = jnp.array([0])
        loss1, grads1 = value_and_grad(
            lambda x: svi.loss.loss(
                svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample
            )
        )(params)
        subsample = jnp.array([1])
        loss2, grads2 = value_and_grad(
            lambda x: svi.loss.loss(
                svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample
            )
        )(params)
        grads = tree_multimap(lambda *vals: vals[0] + vals[1], grads1, grads2)
        loss = loss1 + loss2
    else:
        subsample = jnp.array([0, 1])
        loss, grads = value_and_grad(
            lambda x: svi.loss.loss(
                svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample
            )
        )(params)
    actual_loss = loss / normalizer
    expected_loss, _ = value_and_grad(
        lambda x: svi.loss.loss(
            svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, None
        )
    )(params)
    assert_allclose(actual_loss, expected_loss, rtol=precision, atol=precision)
    actual_grads = {name: grad / normalizer for name, grad in grads.items()}
    expected_grads = {
        "loc": scale * jnp.array([0.5, -2.0]),
        "scale": scale * jnp.array([2.0]),
    }
    assert actual_grads.keys() == expected_grads.keys()
    for name in expected_grads:
        assert_allclose(
            actual_grads[name], expected_grads[name], rtol=precision, atol=precision
        )