Example no. 1
def _get_posterior(self):
    # BNAF needs at least two latent dimensions; a 1-d latent is better
    # served by a simple diagonal-normal guide.
    if self.latent_dim == 1:
        raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead')
    flows = []
    for i in range(self.num_flows):
        # Reverse the dimension ordering between consecutive flows so that
        # every dimension can eventually interact with every other one.
        if i > 0:
            flows.append(PermuteTransform(jnp.arange(self.latent_dim)[::-1]))
        # Use gated residual connections on all but the final flow.
        residual = "gated" if i < (self.num_flows - 1) else None
        arn = BlockNeuralAutoregressiveNN(self.latent_dim, self._hidden_factors, residual)
        # Register the network with numpyro so its parameters are optimized.
        arnn = numpyro.module('{}_arn__{}'.format(self.prefix, i), arn, (self.latent_dim,))
        flows.append(BlockNeuralAutoregressiveTransform(arnn))
    return dist.TransformedDistribution(self.get_base_dist(), flows)
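
This method appears to be the posterior constructor of numpyro's AutoBNAFNormal autoguide. Below is a minimal usage sketch under that assumption: it drives such a guide end to end with SVI on a toy two-dimensional model. The model, optimizer settings, and step count are illustrative choices, not taken from the original source.

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal

# Toy two-dimensional model (hypothetical, for illustration only).
def model():
    numpyro.sample('z', dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1))

# The guide assembles its flow-based posterior via _get_posterior() above.
guide = AutoBNAFNormal(model, num_flows=2, hidden_factors=[8, 8])
svi = SVI(model, guide, numpyro.optim.Adam(1e-3), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
for _ in range(100):
    svi_state, loss = svi.update(svi_state)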
Example no. 2
import numpy as np
from numpy.testing import assert_allclose

import jax.numpy as jnp
from jax import jacfwd, random, vmap

from numpyro.distributions.util import matrix_to_tril_vec
from numpyro.nn import BlockNeuralAutoregressiveNN


def test_block_neural_arn(input_dim, hidden_factors, residual, batch_shape):
    # The factory returns an (init_fn, apply_fn) pair in the stax style.
    arn_init, arn = BlockNeuralAutoregressiveNN(input_dim, hidden_factors,
                                                residual)

    rng = random.PRNGKey(0)
    input_shape = batch_shape + (input_dim,)
    out_shape, init_params = arn_init(rng, input_shape)
    assert out_shape == input_shape

    # Applying the network yields the transformed output and the
    # log-determinant terms, both with the same shape as the input.
    x = random.normal(random.PRNGKey(1), input_shape)
    output, logdet = arn(init_params, x)
    assert output.shape == input_shape
    assert logdet.shape == input_shape

    # Check the analytic log-determinant against one computed from the full
    # Jacobian; vmap over the leading axis when there is a batch dimension.
    if len(batch_shape) == 1:
        jac = vmap(jacfwd(lambda x: arn(init_params, x)[0]))(x)
    else:
        jac = jacfwd(lambda x: arn(init_params, x)[0])(x)
    assert_allclose(logdet.sum(-1), jnp.linalg.slogdet(jac)[1], rtol=1e-6)

    # Make sure the Jacobians are lower triangular with a nonzero diagonal,
    # as required for an invertible autoregressive transform.
    assert np.sum(np.abs(np.triu(jac, k=1))) == 0.0
    assert np.all(np.abs(matrix_to_tril_vec(jac)) > 0)
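
In numpyro's test suite this function is driven by pytest parametrization; the direct call below is a sketch with illustrative argument values (my choice, not from the original listing) that exercises the batched vmap branch.

# Hypothetical parameter values; any combination accepted by
# BlockNeuralAutoregressiveNN would work here.
test_block_neural_arn(input_dim=5, hidden_factors=[2, 3], residual='gated', batch_shape=(4,))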