def test_unpack_transform():
    """Round-trip an array through ``UnpackTransform`` and its inverse.

    The forward transform wraps the array in a dict under ``'key'``; the
    inverse must recover the original array.
    """
    def unpack_fn(arr):
        return {'key': arr}

    transform = transforms.UnpackTransform(unpack_fn)
    x = np.ones(3)

    packed = transform(x)
    restored = transform.inv(packed)

    assert_allclose(packed['key'], x)
    assert_allclose(restored, x)
def test_iaf():
    """Check ``AutoIAFNormal.sample_posterior`` and ``get_transform``.

    Both methods are compared against an IAF flow rebuilt by hand from the
    learned guide params:
      - ``guide.num_flows`` ``InverseAutoregressiveTransform`` layers;
      - a reversing ``PermuteTransform`` between consecutive layers;
      - a final ``UnpackTransform`` that maps the flat latent vector to
        named sample sites.
    """
    # Test for the substitute logic behind the exposed methods
    # `sample_posterior` and `get_transform`.
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        offset = numpyro.sample('offset', dist.Uniform(-1, 1))
        logits = offset + np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    # Flat latent point of size dim + 1: `dim` coefficients plus the scalar offset.
    x = random.normal(random.PRNGKey(0), (dim + 1, ))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide.get_transform(params)(x)

    # Rebuild the flow manually from the learned per-flow ARN params.
    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(transforms.PermuteTransform(np.arange(dim + 1)[::-1]))
        # Only the apply function is needed; the weights come from `params`.
        _, arn_apply = AutoregressiveNN(dim + 1, [dim + 1, dim + 1],
                                        permutation=np.arange(dim + 1),
                                        skip_connections=guide._skip_connections,
                                        nonlinearity=guide._nonlinearity)
        arn = partial(arn_apply, params['auto_arn__{}$params'.format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(transforms.UnpackTransform(guide._unpack_latent))

    transform = transforms.ComposeTransform(flows)
    # NOTE(review): this split presumably mirrors the key handling inside
    # `sample_posterior`; the test relies on that guide internal.
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(
        dist.Normal(np.zeros(dim + 1), 1).sample(rng_key_sample))
    expected_output = transform(x)

    assert_allclose(actual_sample['coefs'], expected_sample['coefs'])
    # The manual transform stays in unconstrained space, so map `offset`
    # into Uniform(-1, 1)'s support before comparing.
    assert_allclose(
        actual_sample['offset'],
        transforms.biject_to(constraints.interval(-1, 1))(expected_sample['offset']))
    check_eq(actual_output, expected_output)