Пример #1
0
def test_iaf(input_dim, hidden_dims):
    """Check an InverseAutoregressiveTransform built on an AutoregressiveNN.

    Verifies that ``iaf.inv`` undoes ``iaf``, that the Jacobian is lower
    triangular once reordered by the network's permutation, and that
    ``iaf.log_abs_det_jacobian`` equals the log of the Jacobian diagonal.

    Args:
        input_dim: dimensionality of the transformed variable.
        hidden_dims: hidden layer sizes of the AutoregressiveNN conditioner.
    """
    arn = AutoregressiveNN(input_dim, hidden_dims, param_dims=[1, 1])

    rng = random.PRNGKey(0)
    batch_size = 4
    input_shape = (batch_size, input_dim)
    _, init_params = arn.init_fun(rng, input_shape)

    iaf = InverseAutoregressiveTransform(arn, init_params)

    # test inverse is correct on a batch of inputs
    x = onp.random.rand(*input_shape)
    y = iaf(x)
    inv = iaf.inv(y)
    assert_allclose(x, inv, atol=1e-5)

    # test jacobian on a single (unbatched) input
    x = onp.random.rand(*input_shape[-1:])
    # Recompute y for this x: the batched y above belongs to a different input
    # and must not be paired with this x in log_abs_det_jacobian below.
    y = iaf(x)
    jac = onp.asarray(jacfwd(iaf)(x))

    # Reorder rows and columns by the network's autoregressive ordering:
    # permuted_jac[..., j, k] == jac[..., perm[j], perm[k]].
    perm = onp.asarray(arn.permutation)
    permuted_jac = jac[..., perm[:, None], perm]

    # make sure jacobian is lower triangular (strict upper part is exactly zero)
    assert onp.sum(onp.abs(onp.triu(permuted_jac, 1))) == 0.00

    # make sure iaf.log_abs_det_jacobian matches log of the Jacobian diagonal
    ldj = iaf.log_abs_det_jacobian(x, y)
    assert_allclose(ldj, onp.sum(onp.log(onp.diag(permuted_jac))), atol=1e-5)
Пример #2
0
def _make_iaf_args(input_dim, hidden_dims):
    """Build constructor arguments for an IAF transform under test.

    An AutoregressiveNN with two single-dimensional outputs (``param_dims=[1, 1]``)
    and a random input permutation is initialised from ``random.PRNGKey(0)``.

    Args:
        input_dim: size of the variable the transform acts on.
        hidden_dims: hidden layer widths for the AutoregressiveNN.

    Returns:
        A 1-tuple ``(conditioner,)`` where ``conditioner`` is the network's apply
        function with its initial parameters bound via ``partial`` (presumably
        unpacked by the caller as positional constructor args).
    """
    _, rng_key_perm = random.split(random.PRNGKey(0))
    # random.permutation (as in test_auto_reg_nn) replaces the deprecated
    # random.shuffle; for a 1-D array both produce the same ordering.
    perm = random.permutation(rng_key_perm, onp.arange(input_dim))
    arn_init, arn = AutoregressiveNN(input_dim,
                                     hidden_dims,
                                     param_dims=[1, 1],
                                     permutation=perm)
    _, init_params = arn_init(random.PRNGKey(0), (input_dim, ))
    # Trailing comma is intentional: a 1-tuple is returned, not the bare callable.
    return partial(arn, init_params),
Пример #3
0
def _make_iaf_args(input_dim, hidden_dims):
    """Build constructor arguments for an IAF transform under test.

    An Elu-activated AutoregressiveNN with two single-dimensional outputs
    (``param_dims=[1, 1]``) and a random input permutation is initialised from
    ``random.PRNGKey(0)``.

    Args:
        input_dim: size of the variable the transform acts on.
        hidden_dims: hidden layer widths for the AutoregressiveNN.

    Returns:
        A 1-tuple ``(conditioner,)`` where ``conditioner`` is the network's apply
        function with its initial parameters bound via ``partial`` (presumably
        unpacked by the caller as positional constructor args).
    """
    _, rng_perm = random.split(random.PRNGKey(0))
    # random.permutation replaces the deprecated random.shuffle; for a 1-D
    # array both produce the same ordering.
    perm = random.permutation(rng_perm, onp.arange(input_dim))
    # we use Elu nonlinearity because the default one, Relu, masks out negative hidden values,
    # which in turn create some zero entries in the lower triangular part of Jacobian.
    arn_init, arn = AutoregressiveNN(input_dim,
                                     hidden_dims,
                                     param_dims=[1, 1],
                                     permutation=perm,
                                     nonlinearity=stax.Elu)
    _, init_params = arn_init(random.PRNGKey(0), (input_dim, ))
    # Trailing comma is intentional: a 1-tuple is returned, not the bare callable.
    return partial(arn, init_params),
Пример #4
0
def test_auto_reg_nn(input_dim, hidden_dims, param_dims, skip_connections):
    """Check AutoregressiveNN output shapes and its autoregressive masking.

    For each supported ``param_dims`` layout the batched output shapes are
    asserted. Then the Jacobian of the (first) output with respect to a single
    input is reordered by the permutation and required to be strictly lower
    triangular.

    Args:
        input_dim: dimensionality of the network input.
        hidden_dims: hidden layer widths.
        param_dims: output parameter layout; one of [1], [1, 1], [2], [2, 3].
        skip_connections: forwarded to AutoregressiveNN.

    Raises:
        ValueError: if ``param_dims`` is not one of the layouts handled here.
    """
    rng_key, rng_key_perm = random.split(random.PRNGKey(0))
    perm = random.permutation(rng_key_perm, onp.arange(input_dim))
    arn_init, arn = AutoregressiveNN(input_dim,
                                     hidden_dims,
                                     param_dims=param_dims,
                                     skip_connections=skip_connections,
                                     permutation=perm)

    batch_size = 4
    input_shape = (batch_size, input_dim)
    _, init_params = arn_init(rng_key, input_shape)

    output = arn(init_params, onp.random.rand(*input_shape))

    if param_dims == [1]:
        assert output.shape == (batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x))(onp.random.rand(input_dim))
    elif param_dims == [1, 1]:
        assert output[0].shape == (batch_size, input_dim)
        assert output[1].shape == (batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x)[0])(
            onp.random.rand(input_dim))
    elif param_dims == [2]:
        assert output.shape == (2, batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x))(onp.random.rand(input_dim))
    elif param_dims == [2, 3]:
        assert output[0].shape == (2, batch_size, input_dim)
        assert output[1].shape == (3, batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x)[0])(
            onp.random.rand(input_dim))
    else:
        # Without this branch `jac` would be unbound below and the test would
        # fail with a confusing NameError instead of naming the bad layout.
        raise ValueError(
            "unsupported param_dims for this test: {}".format(param_dims))

    # Reorder rows and columns by the autoregressive ordering:
    # permuted_jac[..., j, k] == jac[..., perm[j], perm[k]].
    perm_np = onp.asarray(perm)
    permuted_jac = onp.asarray(jac)[..., perm_np[:, None], perm_np]

    # make sure jacobians are strictly lower triangular (diagonal included in
    # the zero check); onp.triu acts on the last two axes of stacked arrays
    assert onp.sum(onp.abs(onp.triu(permuted_jac))) == 0.0
Пример #5
0
def _make_iaf(input_dim, hidden_dims, rng):
    """Construct an InverseAutoregressiveTransform around a freshly initialised ARN.

    Args:
        input_dim: size of the variable the transform acts on.
        hidden_dims: hidden layer widths for the AutoregressiveNN.
        rng: random key passed to the network initialiser.

    Returns:
        An ``InverseAutoregressiveTransform`` whose conditioner is the network's
        apply function with the initial parameters bound via ``partial``.
    """
    init_fn, apply_fn = AutoregressiveNN(input_dim, hidden_dims, param_dims=[1, 1])
    _out_shape, params = init_fn(rng, (input_dim, ))
    conditioner = partial(apply_fn, params)
    return InverseAutoregressiveTransform(conditioner)
Пример #6
0
def _make_iaf_args(input_dim, hidden_dims):
    """Return an AutoregressiveNN apply function together with its initial parameters.

    The network has two single-dimensional outputs (``param_dims=[1, 1]``). Its
    parameters are initialised from ``random.PRNGKey(0)`` for an unbatched input
    of shape ``(input_dim,)``.

    Args:
        input_dim: size of the network input.
        hidden_dims: hidden layer widths for the AutoregressiveNN.

    Returns:
        A pair ``(apply_fn, params)``.
    """
    init_fn, apply_fn = AutoregressiveNN(input_dim, hidden_dims, param_dims=[1, 1])
    key = random.PRNGKey(0)
    _out_shape, params = init_fn(key, (input_dim, ))
    return apply_fn, params