def test_unpack_transform():
    x = np.ones(3)
    unpack_fn = lambda x: {'key': x}  # noqa: E731
    transform = transforms.UnpackTransform(unpack_fn)
    y = transform(x)
    z = transform.inv(y)
    assert_allclose(y['key'], x)
    assert_allclose(z, x)
def test_iaf():
    # test for substitute logic for exposed methods `sample_posterior` and `get_transforms`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim),
                                                    np.ones(dim)))
        offset = numpyro.sample('offset', dist.Uniform(-1, 1))
        logits = offset + np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1, ))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide.get_transform(params)(x)

    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(transforms.PermuteTransform(np.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(
            dim + 1, [dim + 1, dim + 1],
            permutation=np.arange(dim + 1),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity)
        arn = partial(arn_apply, params['auto_arn__{}$params'.format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(transforms.UnpackTransform(guide._unpack_latent))

    transform = transforms.ComposeTransform(flows)
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(
        dist.Normal(np.zeros(dim + 1), 1).sample(rng_key_sample))
    expected_output = transform(x)
    assert_allclose(actual_sample['coefs'], expected_sample['coefs'])
    assert_allclose(
        actual_sample['offset'],
        transforms.biject_to(constraints.interval(-1, 1))(
            expected_sample['offset']))
    check_eq(actual_output, expected_output)