def _get_posterior(self):
    """Construct the flow-transformed posterior over the latent space.

    The posterior is ``self.get_base_dist()`` pushed through
    ``self.num_flows`` block neural autoregressive transforms. A
    reversing ``PermuteTransform`` is inserted between consecutive
    flows. Every flow except the final one is built with a ``"gated"``
    residual; the final one has no residual.

    Each flow's network is registered via ``numpyro.module`` under the
    name ``"<prefix>_arn__<index>"``.

    Returns:
        A ``dist.TransformedDistribution`` over the base distribution.

    Raises:
        ValueError: if ``self.latent_dim`` equals 1.
    """
    if self.latent_dim == 1:
        raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead')

    # Reversing the coordinate order between flows lets each autoregressive
    # layer condition on the dimensions the previous one could not.
    reversed_order = jnp.arange(self.latent_dim)[::-1]
    last_index = self.num_flows - 1

    transforms = []
    for flow_idx in range(self.num_flows):
        if flow_idx != 0:
            transforms.append(PermuteTransform(reversed_order))

        residual_kind = None if flow_idx == last_index else "gated"
        network = BlockNeuralAutoregressiveNN(
            self.latent_dim, self._hidden_factors, residual_kind
        )
        network_fn = numpyro.module(
            f"{self.prefix}_arn__{flow_idx}", network, (self.latent_dim,)
        )
        transforms.append(BlockNeuralAutoregressiveTransform(network_fn))

    base = self.get_base_dist()
    return dist.TransformedDistribution(base, transforms)
def test_block_neural_arn(input_dim, hidden_factors, residual, batch_shape):
    """Sanity-check a ``BlockNeuralAutoregressiveNN`` forward pass.

    Verifies that:
      * ``init`` reports an output shape equal to the input shape;
      * the apply function returns outputs and per-dimension log-det terms
        of the input shape;
      * the summed log-det terms match ``slogdet`` of the Jacobian;
      * the Jacobian has no entries strictly above the diagonal, and every
        entry from ``matrix_to_tril_vec`` is non-zero.

    NOTE(review): a ``batch_shape`` of length 1 is handled with ``vmap``. Any
    other length is differentiated directly, which only yields a square
    Jacobian for an empty batch shape. TODO confirm callers pass only
    ``()`` or 1-D shapes.
    """
    init_fn, apply_fn = BlockNeuralAutoregressiveNN(input_dim, hidden_factors, residual)
    shape = batch_shape + (input_dim, )

    reported_shape, params = init_fn(random.PRNGKey(0), shape)
    assert reported_shape == shape

    inputs = random.normal(random.PRNGKey(1), shape)
    outputs, log_det_terms = apply_fn(params, inputs)
    assert outputs.shape == shape
    assert log_det_terms.shape == shape

    def forward(z):
        # Only the transformed values are differentiated; the log-det is dropped.
        return apply_fn(params, z)[0]

    jacobian_fn = jacfwd(forward)
    if len(batch_shape) == 1:
        jacobian_fn = vmap(jacobian_fn)
    jacobian = jacobian_fn(inputs)

    expected_log_det = jnp.linalg.slogdet(jacobian)[1]
    assert_allclose(log_det_terms.sum(-1), expected_log_det, rtol=1e-6)

    # Autoregressive structure: the upper triangle must vanish, and every
    # entry extracted by matrix_to_tril_vec must be non-zero.
    upper_mass = np.sum(np.abs(np.triu(jacobian, k=1)))
    assert upper_mass == 0.0
    assert np.all(np.abs(matrix_to_tril_vec(jacobian)) > 0)