Example #1
def test_biject_to(constraint, shape):
    transform = biject_to(constraint)
    if isinstance(constraint, constraints._Interval):
        assert transform.codomain.upper_bound == constraint.upper_bound
        assert transform.codomain.lower_bound == constraint.lower_bound
    elif isinstance(constraint, constraints._GreaterThan):
        assert transform.codomain.lower_bound == constraint.lower_bound
    if len(shape) < transform.event_dim:
        return
    rng = random.PRNGKey(0)
    x = random.normal(rng, shape)
    y = transform(x)

    # test codomain
    batch_shape = shape if transform.event_dim == 0 else shape[:-1]
    assert_array_equal(transform.codomain(y),
                       np.ones(batch_shape, dtype=np.bool_))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain; currently all domains are constraints.real or constraints.real_vector
    assert_array_equal(transform.domain(z), np.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if constraint is constraints.simplex:
            expected = onp.linalg.slogdet(
                jax.jacobian(transform)(x)[:-1, :])[1]
            inv_expected = onp.linalg.slogdet(
                jax.jacobian(transform.inv)(y)[:, :-1])[1]
        elif constraint is constraints.corr_cholesky:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x), diagonal=-1)  # noqa: E731
            y_tril = matrix_to_tril_vec(y, diagonal=-1)
            inv_vec_transform = lambda x: transform.inv(vec_to_tril_matrix(x, diagonal=-1))  # noqa: E731
            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(
                jax.jacobian(inv_vec_transform)(y_tril))[1]
        elif constraint is constraints.lower_cholesky:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x))  # noqa: E731
            y_tril = matrix_to_tril_vec(y)
            inv_vec_transform = lambda x: transform.inv(vec_to_tril_matrix(x))  # noqa: E731
            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(
                jax.jacobian(inv_vec_transform)(y_tril))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)
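
A note on the simplex branch above: `biject_to(constraints.simplex)` maps an unconstrained vector in R^n to a point on the simplex in R^(n+1), so the raw Jacobian is rectangular and `slogdet` cannot be applied to it directly; slicing off the dependent last row makes it square. A minimal shape check, sketched with numpyro's `biject_to` and `constraints` (the sizes are illustrative):

import jax
from jax import random
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to

t = biject_to(constraints.simplex)
x = random.normal(random.PRNGKey(0), (4,))
print(jax.jacobian(t)(x).shape)  # (5, 4): one more output coordinate than inputs
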
Example #2
    def _inverse(self, y):
        # inverse stick-breaking
        z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
        pad_width = [(0, 0)] * y.ndim
        pad_width[-1] = (1, 0)
        z1m_cumprod_shifted = jnp.pad(z1m_cumprod[..., :-1], pad_width,
                                      mode="constant", constant_values=1.)
        t = matrix_to_tril_vec(y, diagonal=-1) / jnp.sqrt(
            matrix_to_tril_vec(z1m_cumprod_shifted, diagonal=-1))
        # inverse of tanh
        x = jnp.log((1 + t) / (1 - t)) / 2
        return x
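
The last two lines compute arctanh via its logarithmic closed form. As a quick sketch, this agrees with `jnp.arctanh`:

import jax.numpy as jnp

t = jnp.linspace(-0.9, 0.9, 5)
print(jnp.allclose(jnp.log((1 + t) / (1 - t)) / 2, jnp.arctanh(t)))  # True
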
Example #3
def test_log_prob_LKJCholesky_uniform(dimension):
    # When concentration=1, the distribution of correlation matrices is uniform.
    # We will test that fact here.
    d = dist.LKJCholesky(dimension=dimension, concentration=1)
    N = 5
    corr_log_prob = []
    for i in range(N):
        sample = d.sample(random.PRNGKey(i))
        log_prob = d.log_prob(sample)
        sample_tril = matrix_to_tril_vec(sample, diagonal=-1)
        cholesky_to_corr_jac = onp.linalg.slogdet(
            jax.jacobian(_tril_cholesky_to_tril_corr)(sample_tril))[1]
        corr_log_prob.append(log_prob - cholesky_to_corr_jac)

    corr_log_prob = np.array(corr_log_prob)
    # test if they are constant
    assert_allclose(corr_log_prob,
                    np.broadcast_to(corr_log_prob[0], corr_log_prob.shape),
                    rtol=1e-6)

    if dimension == 2:
        # When concentration = 1, LKJ gives a uniform distribution over correlation
        # matrices, so for dimension = 2 the single off-diagonal entry is
        # Uniform(-1, 1) and its density is 0.5.
        # In addition, the jacobian of the cholesky -> corr transform is 1 (hence its
        # log value is 0) because the off-diagonal lower triangular element does not
        # change under the transform.
        # So target_log_prob = log(0.5).
        assert_allclose(corr_log_prob[0], np.log(0.5), rtol=1e-6)
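
For reference, the dimension = 2 arithmetic in the comment above works out in plain Python as follows:

import math

# a 2 x 2 correlation matrix has one free off-diagonal entry, uniform on (-1, 1)
density = 1.0 / (1.0 - (-1.0))  # = 0.5
print(math.log(density))        # -0.693..., i.e. np.log(0.5)
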
Example #4
    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # NB: because the domain and codomain are two spaces with different dimensions,
        # the determinant of the Jacobian is not well-defined. Here we return the
        # `log_abs_det_jacobian` of `x` and the flattened lower triangular part of `y`.

        # stick_breaking_logdet = log(y / r) = log(z_cumprod)  (modulo a right shift)
        z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
        # NB: diagonal=-2 works fine for (2 x 2) matrix, where we get an empty array
        z1m_cumprod_tril = matrix_to_tril_vec(z1m_cumprod, diagonal=-2)
        stick_breaking_logdet = 0.5 * jnp.sum(jnp.log(z1m_cumprod_tril), axis=-1)

        tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.), axis=-1)
        return stick_breaking_logdet + tanh_logdet
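
The `tanh_logdet` term uses the numerically stable identity log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)), summed over the last axis. A quick sketch checking it against the naive form:

import jax.numpy as jnp
from jax.nn import softplus

x = jnp.linspace(-3., 3., 7)
stable = -2 * (x + softplus(-2 * x) - jnp.log(2.))
naive = jnp.log1p(-jnp.tanh(x) ** 2)
print(jnp.allclose(stable, naive, atol=1e-5))  # True
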
Example #5
    def __init__(self,
                 dimension,
                 concentration=1.,
                 sample_method='onion',
                 validate_args=None):
        if dimension < 2:
            raise ValueError("Dimension must be greater than or equal to 2.")
        self.dimension = dimension
        self.concentration = concentration
        batch_shape = np.shape(concentration)
        event_shape = (dimension, dimension)

        # We construct base distributions to generate samples for each method.
        # The purpose of this base distribution is to generate a distribution of
        # correlation matrices which is proportional to `det(M)^{\eta - 1}`.
        # (note that this is not the only way to define a base distribution)
        # With both of the following methods, the marginal distribution of each
        # off-diagonal element of a sampled correlation matrix is
        # Beta(eta + (D-2) / 2, eta + (D-2) / 2) (up to a linear transform: x -> 2x - 1).
        Dm1 = self.dimension - 1
        marginal_concentration = concentration + 0.5 * (self.dimension - 2)
        offset = 0.5 * np.arange(Dm1)
        if sample_method == 'onion':
            # The following construction follows from the algorithm in Section 3.2 of [1]:
            # NB: in [1], the method for case k > 1 can also work for the case k = 1.
            beta_concentration0 = np.expand_dims(marginal_concentration,
                                                 axis=-1) - offset
            beta_concentration1 = offset + 0.5
            self._beta = Beta(beta_concentration1, beta_concentration0)
        elif sample_method == 'cvine':
            # The following construction follows from the algorithm in Section 2.4 of [1]:
            # offset_tril is [0, 1, 1, 2, 2, 2,...] / 2
            offset_tril = matrix_to_tril_vec(
                np.broadcast_to(offset, (Dm1, Dm1)))
            beta_concentration = np.expand_dims(marginal_concentration,
                                                axis=-1) - offset_tril
            self._beta = Beta(beta_concentration, beta_concentration)
        else:
            raise ValueError("`method` should be one of 'cvine' or 'onion'.")
        self.sample_method = sample_method

        super(LKJCholesky, self).__init__(batch_shape=batch_shape,
                                          event_shape=event_shape,
                                          validate_args=validate_args)
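
A minimal usage sketch for this constructor (assuming the usual numpyro import; the concrete sizes are illustrative):

from jax import random
import numpyro.distributions as dist

d = dist.LKJCholesky(dimension=3, concentration=1., sample_method='onion')
L = d.sample(random.PRNGKey(0))  # lower Cholesky factor of a correlation matrix, shape (3, 3)
corr = L @ L.T                   # the correlation matrix itself
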
Example #6
def test_block_neural_arn(input_dim, hidden_factors, residual, batch_shape):
    arn_init, arn = BlockNeuralAutoregressiveNN(input_dim, hidden_factors, residual)

    rng = random.PRNGKey(0)
    input_shape = batch_shape + (input_dim,)
    out_shape, init_params = arn_init(rng, input_shape)
    assert out_shape == input_shape

    x = random.normal(random.PRNGKey(1), input_shape)
    output, logdet = arn(init_params, x)
    assert output.shape == input_shape
    assert logdet.shape == input_shape

    if len(batch_shape) == 1:
        jac = vmap(jacfwd(lambda x: arn(init_params, x)[0]))(x)
    else:
        jac = jacfwd(lambda x: arn(init_params, x)[0])(x)
    assert_allclose(logdet.sum(-1), np.linalg.slogdet(jac)[1], rtol=1e-6)

    # make sure jacobians are lower triangular
    assert onp.sum(onp.abs(onp.triu(jac, k=1))) == 0.0
    assert onp.all(onp.abs(matrix_to_tril_vec(jac)) > 0)
Example #7
def test_flows(flow_class, flow_args, input_dim, batch_shape):
    transform = flow_class(*flow_args)
    x = random.normal(random.PRNGKey(0), batch_shape + (input_dim, ))

    # test inverse is correct
    y = transform(x)
    try:
        inv = transform.inv(y)
        assert_allclose(x, inv, atol=1e-5)
    except NotImplementedError:
        pass

    # test jacobian shape
    actual = transform.log_abs_det_jacobian(x, y)
    assert onp.shape(actual) == batch_shape

    if batch_shape == ():
        # make sure transform.log_abs_det_jacobian is correct
        jac = jacfwd(transform)(x)
        expected = onp.linalg.slogdet(jac)[1]
        assert_allclose(actual, expected, atol=1e-5)

        # make sure jacobian is triangular, first permute jacobian as necessary
        if isinstance(transform, InverseAutoregressiveTransform):
            permuted_jac = onp.zeros(jac.shape)
            _, rng_key_perm = random.split(random.PRNGKey(0))
            perm = random.shuffle(rng_key_perm, onp.arange(input_dim))

            for j in range(input_dim):
                for k in range(input_dim):
                    permuted_jac[j, k] = jac[perm[j], perm[k]]

            jac = permuted_jac

        assert onp.sum(onp.abs(onp.triu(jac, 1))) == 0.0
        assert onp.all(onp.abs(matrix_to_tril_vec(jac)) > 0)
Example #8
    def _inverse(self, y):
        z = matrix_to_tril_vec(y, diagonal=-1)
        diag = _softplus_inv(jnp.diagonal(y, axis1=-2, axis2=-1))
        return jnp.concatenate([z, diag], axis=-1)
Example #9
    def _inverse(self, y):
        z = matrix_to_tril_vec(y, diagonal=-1)
        return jnp.concatenate([z, jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))], axis=-1)
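
For orientation, this `_inverse` undoes a forward map that unpacks a flat vector into a lower-triangular matrix with an exp-transformed (positive) diagonal. A sketch of that forward direction for an n x n output, assuming `vec_to_tril_matrix` is in scope as in the snippets above (the name `_forward_sketch` is illustrative, not the library's):

import jax.numpy as jnp

def _forward_sketch(x, n):
    # the first n*(n-1)/2 entries fill the strict lower triangle,
    # the last n entries are the log-diagonals
    z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
    diag = jnp.exp(x[..., -n:])
    return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n)
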
Example #10
def _tril_cholesky_to_tril_corr(x):
    w = vec_to_tril_matrix(x, diagonal=-1)
    diag = np.sqrt(1 - np.sum(w**2, axis=-1))
    cholesky = w + np.expand_dims(diag, axis=-1) * np.identity(w.shape[-1])
    corr = np.matmul(cholesky, cholesky.T)
    return matrix_to_tril_vec(corr, diagonal=-1)
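
This helper maps the strict lower triangle of a Cholesky factor to the strict lower triangle of the corresponding correlation matrix; Example #3 differentiates it to obtain the change-of-variables jacobian. A usage sketch for the 2 x 2 case (hypothetical input, helpers assumed in scope):

import numpy as onp

x = onp.array([0.3])                   # the single off-diagonal Cholesky entry
print(_tril_cholesky_to_tril_corr(x))  # [0.3]: for dimension 2 the map is the identity
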
Example #11
    def inv(self, y):
        z = matrix_to_tril_vec(y, diagonal=-1)
        return np.concatenate(
            [z, np.log(np.diagonal(y, axis1=-2, axis2=-1))], axis=-1)
Example #12
def test_biject_to(constraint, shape):
    transform = biject_to(constraint)
    if transform.event_dim == 2:
        event_dim = 1  # actual dim of unconstrained domain
    else:
        event_dim = transform.event_dim
    if isinstance(constraint, constraints._Interval):
        assert transform.codomain.upper_bound == constraint.upper_bound
        assert transform.codomain.lower_bound == constraint.lower_bound
    elif isinstance(constraint, constraints._GreaterThan):
        assert transform.codomain.lower_bound == constraint.lower_bound
    if len(shape) < event_dim:
        return
    rng_key = random.PRNGKey(0)
    x = random.normal(rng_key, shape)
    y = transform(x)

    # test codomain
    batch_shape = shape if event_dim == 0 else shape[:-1]
    assert_array_equal(transform.codomain(y), np.ones(batch_shape, dtype=np.bool_))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain; currently all domains are constraints.real or constraints.real_vector
    assert_array_equal(transform.domain(z), np.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape
    if len(shape) == event_dim:
        if constraint is constraints.simplex:
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x)[:-1, :])[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y)[:, :-1])[1]
        elif constraint is constraints.ordered_vector:
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        elif constraint in [constraints.corr_cholesky, constraints.corr_matrix]:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x), diagonal=-1)  # noqa: E731
            y_tril = matrix_to_tril_vec(y, diagonal=-1)

            def inv_vec_transform(y):
                matrix = vec_to_tril_matrix(y, diagonal=-1)
                if constraint is constraints.corr_matrix:
                    # fill the upper triangular part
                    matrix = matrix + np.swapaxes(matrix, -2, -1) + np.identity(matrix.shape[-1])
                return transform.inv(matrix)

            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(inv_vec_transform)(y_tril))[1]
        elif constraint in [constraints.lower_cholesky, constraints.positive_definite]:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x))  # noqa: E731
            y_tril = matrix_to_tril_vec(y)

            def inv_vec_transform(y):
                matrix = vec_to_tril_matrix(y)
                if constraint is constraints.positive_definite:
                    # fill the upper triangular part
                    matrix = matrix + np.swapaxes(matrix, -2, -1) - np.diag(np.diag(matrix))
                return transform.inv(matrix)

            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(inv_vec_transform)(y_tril))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6, rtol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6, rtol=1e-6)
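
Why `event_dim` is adjusted at the top of this test: for matrix-valued constraints the unconstrained input is a flat vector (event_dim 1) even though the codomain is a matrix (event_dim 2). A shape sketch with numpyro's `biject_to` (the sizes are illustrative):

from jax import random
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to

t = biject_to(constraints.lower_cholesky)
x = random.normal(random.PRNGKey(0), (6,))  # 6 = 3 strict-lower + 3 diagonal entries
print(t(x).shape)                           # (3, 3)
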
Example #13
    def _inverse(self, y):
        diag = jnp.diagonal(y, axis1=-2, axis2=-1)
        z = matrix_to_tril_vec(y / diag[..., None], diagonal=-1)
        return jnp.concatenate([z, _softplus_inv(diag)], axis=-1)