def test_iaf():
    """Check AutoIAFNormal's exposed ``sample_posterior`` and ``get_transform``.

    The guide's results are compared against an equivalent transform that is
    rebuilt by hand from the guide's own trained-at-init parameters: one
    ``InverseAutoregressiveTransform`` per flow, with a reversing
    ``PermuteTransform`` inserted between consecutive flows, followed by the
    guide's ``_unpack_latent``.
    """
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample(
            "coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim))
        )
        offset = numpyro.sample("offset", dist.Uniform(-1, 1))
        logits = offset + jnp.sum(coefs * data, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    # The flat latent has dim + 1 entries: `dim` coefficients plus the
    # scalar offset.
    x = random.normal(random.PRNGKey(0), (dim + 1,))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide._unpack_latent(guide.get_transform(params)(x))

    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            # Reverse the latent ordering between consecutive flows.
            flows.append(
                transforms.PermuteTransform(jnp.arange(dim + 1)[::-1])
            )
        # Only the apply function is needed; the parameters come from the
        # guide's params, so the init function is discarded.
        _, arn_apply = AutoregressiveNN(
            dim + 1,
            [dim + 1, dim + 1],
            permutation=jnp.arange(dim + 1),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity,
        )
        arn = partial(arn_apply, params["auto_arn__{}$params".format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(guide._unpack_latent)

    transform = transforms.ComposeTransform(flows)
    # Reproduce the key handling this test expects sample_posterior to use
    # internally: base noise is drawn from the second key of the split.
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(
        dist.Normal(jnp.zeros(dim + 1), 1).sample(rng_key_sample)
    )
    expected_output = transform(x)

    assert_allclose(actual_sample["coefs"], expected_sample["coefs"])
    # `offset` has interval(-1, 1) support, so sample_posterior's value is
    # compared after mapping the hand-built (unconstrained) value into it.
    assert_allclose(
        actual_sample["offset"],
        transforms.biject_to(constraints.interval(-1, 1))(
            expected_sample["offset"]
        ),
    )
    check_eq(actual_output, expected_output)
def _make_iaf(input_dim, hidden_dims, rng_key):
    """Return an IAF transform driven by a freshly initialized AutoregressiveNN.

    The network is built with ``param_dims=[1, 1]`` and initialized from
    ``rng_key`` for inputs of shape ``(input_dim,)``. Its apply function is
    bound to those initial parameters and wrapped in an
    ``InverseAutoregressiveTransform``.
    """
    init_fn, apply_fn = AutoregressiveNN(
        input_dim, hidden_dims, param_dims=[1, 1]
    )
    _, net_params = init_fn(rng_key, (input_dim,))
    bound_net = partial(apply_fn, net_params)
    return InverseAutoregressiveTransform(bound_net)


@pytest.mark.parametrize(
    "ts",
    [
        [transforms.PowerTransform(0.7), transforms.AffineTransform(2.0, 3.0)],
        [transforms.ExpTransform()],
        [
            transforms.ComposeTransform(
                [transforms.AffineTransform(-2, 3), transforms.ExpTransform()]
            ),
            transforms.PowerTransform(3.0),
        ],
        [
            _make_iaf(5, hidden_dims=[10], rng_key=random.PRNGKey(0)),
            transforms.PermuteTransform(np.arange(5)[::-1]),
            _make_iaf(5, hidden_dims=[10], rng_key=random.PRNGKey(1)),
        ],
    ],
)
def test_compose_transform_with_intermediates(ts):
    """Check that ComposeTransform gives the same results with or without reused intermediates.

    Both the forward output and the log-abs-det-Jacobian obtained through
    ``call_with_intermediates`` must match the plain (no-intermediates) path.
    """
    composed = transforms.ComposeTransform(ts)
    inputs = random.normal(random.PRNGKey(2), (7, 5))
    outputs, cached = composed.call_with_intermediates(inputs)
    log_jac = composed.log_abs_det_jacobian(inputs, outputs, cached)
    assert_allclose(outputs, composed(inputs))
    assert_allclose(log_jac, composed.log_abs_det_jacobian(inputs, outputs))


def test_unpack_transform():
    x = np.ones(3)
    unpack_fn = lambda x: {'key': x}  # noqa: E731
    transform = transforms.UnpackTransform(unpack_fn)