def test_iaf():
    """Check ``AutoIAFNormal.sample_posterior`` and ``get_transform`` against a manual rebuild.

    The guide is initialised through SVI. The same flow stack is then rebuilt by
    hand from the resulting ``params`` and applied to:

    * the base standard-normal draw that ``sample_posterior`` is expected to use;
    * a fixed input ``x``.

    Both results must agree with the guide's own outputs.
    """
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        offset = numpyro.sample("offset", dist.Uniform(-1, 1))
        logits = offset + jnp.sum(coefs * data, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    # The flattened latent vector holds the `dim` coefs plus the scalar offset.
    latent_dim = dim + 1
    x = random.normal(random.PRNGKey(0), (latent_dim,))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide._unpack_latent(guide.get_transform(params)(x))

    # Rebuild the guide's flow stack by hand from the learned parameters.
    # Consecutive IAF layers are separated by a coordinate reversal, and the
    # result is finally unpacked into named latent sites.
    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(transforms.PermuteTransform(jnp.arange(latent_dim)[::-1]))
        # Only the apply function is needed; its weights are taken from `params`.
        _, arn_apply = AutoregressiveNN(
            latent_dim,
            [latent_dim, latent_dim],
            permutation=jnp.arange(latent_dim),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity,
        )
        arn = partial(arn_apply, params["auto_arn__{}$params".format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(guide._unpack_latent)

    transform = transforms.ComposeTransform(flows)
    # NOTE(review): assumes `sample_posterior` draws its base sample with the
    # second key of `random.split(rng_key)`; confirm against the guide.
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(
        dist.Normal(jnp.zeros(latent_dim), 1).sample(rng_key_sample))
    expected_output = transform(x)
    assert_allclose(actual_sample["coefs"], expected_sample["coefs"])
    # `offset` has Uniform(-1, 1) support. The manually transformed value is
    # mapped into that interval via biject_to before comparing it with the
    # guide's sample.
    assert_allclose(
        actual_sample["offset"],
        transforms.biject_to(constraints.interval(-1, 1))(
            expected_sample["offset"]),
    )
    check_eq(actual_output, expected_output)
def _get_posterior(self):
    """Build the IAF posterior: the base distribution pushed through a flow stack.

    The stack holds ``self.num_flows`` inverse-autoregressive transforms. Each
    one is driven by an ``AutoregressiveNN`` registered with ``numpyro.module``
    under the name ``'<prefix>_arn__<i>'``. Consecutive transforms are separated
    by a ``PermuteTransform`` that reverses the latent coordinates.

    :raises ValueError: if ``self.latent_dim == 1``.
    :return: a ``dist.TransformedDistribution`` over ``self.get_base_dist()``.
    """
    size = self.latent_dim
    if size == 1:
        raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead')

    # Default architecture: two hidden layers, each as wide as the latent space.
    if self._hidden_dims is None:
        layers = [size, size]
    else:
        layers = self._hidden_dims

    # These orderings are the same for every flow, so build them once.
    identity_order = jnp.arange(size)
    reversed_order = identity_order[::-1]

    flow_stack = []
    for idx in range(self.num_flows):
        if idx:
            flow_stack.append(PermuteTransform(reversed_order))
        net = AutoregressiveNN(size, layers,
                               permutation=identity_order,
                               skip_connections=self._skip_connections,
                               nonlinearity=self._nonlinearity)
        registered = numpyro.module('{}_arn__{}'.format(self.prefix, idx), net, (size,))
        flow_stack.append(InverseAutoregressiveTransform(registered))
    return dist.TransformedDistribution(self.get_base_dist(), flow_stack)