Example #1
def test_biject_to(constraint, shape):
    transform = biject_to(constraint)
    if isinstance(constraint, constraints._Interval):
        assert transform.codomain.upper_bound == constraint.upper_bound
        assert transform.codomain.lower_bound == constraint.lower_bound
    elif isinstance(constraint, constraints._GreaterThan):
        assert transform.codomain.lower_bound == constraint.lower_bound
    if len(shape) < transform.event_dim:
        return
    rng = random.PRNGKey(0)
    x = random.normal(rng, shape)
    y = transform(x)

    # test codomain
    batch_shape = shape if transform.event_dim == 0 else shape[:-1]
    assert_array_equal(transform.codomain(y),
                       np.ones(batch_shape, dtype=np.bool_))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain; currently every domain is constraints.real or constraints.real_vector
    assert_array_equal(transform.domain(z), np.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if constraint is constraints.simplex:
            expected = onp.linalg.slogdet(
                jax.jacobian(transform)(x)[:-1, :])[1]
            inv_expected = onp.linalg.slogdet(
                jax.jacobian(transform.inv)(y)[:, :-1])[1]
        elif constraint is constraints.corr_cholesky:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x), diagonal=-1)  # noqa: E731
            y_tril = matrix_to_tril_vec(y, diagonal=-1)
            inv_vec_transform = lambda x: transform.inv(vec_to_tril_matrix(x, diagonal=-1))  # noqa: E731
            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(
                jax.jacobian(inv_vec_transform)(y_tril))[1]
        elif constraint is constraints.lower_cholesky:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x))  # noqa: E731
            y_tril = matrix_to_tril_vec(y)
            inv_vec_transform = lambda x: transform.inv(vec_to_tril_matrix(x))  # noqa: E731
            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(
                jax.jacobian(inv_vec_transform)(y_tril))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)
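What the test exercises, as a minimal standalone sketch (module paths assumed from recent NumPyro; older releases exposed `biject_to` from `numpyro.distributions.constraints`):

import jax.numpy as jnp
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to

t = biject_to(constraints.positive)            # bijection from the real line onto (0, inf)
x = jnp.array([-1.0, 0.0, 2.0])
y = t(x)
assert bool(jnp.all(y > 0))                    # codomain check, as in the test above
assert jnp.allclose(t.inv(y), x, atol=1e-6)    # round-trip check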
Example #2
    def _onion(self, key, size):
        key_beta, key_normal = random.split(key)
        # Generate the w term from Algorithm 3.2 of [1].
        beta_sample = self._beta.sample(key_beta, size)
        # The following Normal distribution is used to create a uniform distribution on
        # a hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
        normal_sample = random.normal(key_normal,
                                      shape=size + self.batch_shape +
                                      (self.dimension *
                                       (self.dimension - 1) // 2, ))
        normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
        u_hypersphere = normal_sample / np.linalg.norm(
            normal_sample, axis=-1, keepdims=True)
        w = np.expand_dims(np.sqrt(beta_sample), axis=-1) * u_hypersphere

        # put w into the off-diagonal triangular part
        cholesky = ops.index_add(
            np.zeros(size + self.batch_shape + self.event_shape),
            ops.index[..., 1:, :-1], w)
        # correct the diagonal
        # NB: we clip due to numerical precision
        diag = np.sqrt(np.clip(1 - np.sum(cholesky**2, axis=-1), a_min=0.))
        cholesky = cholesky + np.expand_dims(diag, axis=-1) * np.identity(
            self.dimension)
        return cholesky
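A hedged check of the normalization step above: packing d * (d - 1) / 2 normal draws into triangular rows and normalizing each row yields points on unit hyperspheres (here d = 4, so n = d - 1 = 3 rows; `vec_to_tril_matrix` assumed from `numpyro.distributions.util`):

import jax.numpy as jnp
from jax import random
from numpyro.distributions.util import vec_to_tril_matrix

eps = random.normal(random.PRNGKey(1), (6,))        # 6 == 4 * 3 // 2 draws
rows = vec_to_tril_matrix(eps, diagonal=0)          # (3, 3) lower-triangular
u = rows / jnp.linalg.norm(rows, axis=-1, keepdims=True)
assert jnp.allclose(jnp.linalg.norm(u, axis=-1), 1.0, atol=1e-6)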
Example #3
def inv_vec_transform(y):
    matrix = vec_to_tril_matrix(y)
    if constraint is constraints.positive_definite:
        # fill the upper triangular part
        matrix = matrix + np.swapaxes(matrix, -2, -1) - np.diag(np.diag(matrix))
    return transform.inv(matrix)
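The symmetrization line above rebuilds a full symmetric matrix from its lower triangle; a small 2-D illustration (plain jax.numpy, no NumPyro needed):

import jax.numpy as jnp

m = jnp.array([[1., 0.], [2., 3.]])                        # lower triangle, diagonal included
sym = m + jnp.swapaxes(m, -2, -1) - jnp.diag(jnp.diag(m))  # subtract the doubled diagonal
assert jnp.allclose(sym, jnp.array([[1., 2.], [2., 3.]]))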
Example #4
def inv_vec_transform(y):
    matrix = vec_to_tril_matrix(y, diagonal=-1)
    if constraint is constraints.corr_matrix:
        # fill the upper triangular part
        matrix = matrix + np.swapaxes(matrix, -2, -1) + np.identity(matrix.shape[-1])
    return transform.inv(matrix)
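The correlation-matrix variant works the same way, except the input carries only the strict lower triangle and the unit diagonal is added back explicitly:

import jax.numpy as jnp

m = jnp.array([[0., 0.], [0.5, 0.]])                       # strict lower triangle
full = m + jnp.swapaxes(m, -2, -1) + jnp.identity(m.shape[-1])
assert jnp.allclose(full, jnp.array([[1., 0.5], [0.5, 1.]]))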
Example #5
def test_vec_to_tril_matrix(shape, diagonal):
    rng_key = random.PRNGKey(0)
    x = random.normal(rng_key, shape)
    actual = vec_to_tril_matrix(x, diagonal)
    expected = onp.zeros(shape[:-1] + actual.shape[-2:])
    tril_idxs = onp.tril_indices(expected.shape[-1], diagonal)
    # NB: in-place assignment needs a mutable NumPy array; JAX arrays are immutable
    expected[..., tril_idxs[0], tril_idxs[1]] = x
    assert_allclose(actual, expected)
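A concrete round trip, assuming both helpers are exported from `numpyro.distributions.util`: a length-6 vector fills the lower triangle (diagonal included) of a 3 x 3 matrix, and `matrix_to_tril_vec` reads it back in the same row-major order:

import jax.numpy as jnp
from numpyro.distributions.util import matrix_to_tril_vec, vec_to_tril_matrix

vec = jnp.arange(1., 7.)                       # 6 == 3 * 4 // 2 entries
mat = vec_to_tril_matrix(vec)                  # diagonal=0 by default
assert mat.shape == (3, 3)
assert jnp.allclose(matrix_to_tril_vec(mat), vec)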
Example #6
def __call__(self, x):
    # recover n from x.shape[-1] == n * (n + 1) / 2 (inverse triangular number)
    n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
    z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
    diag = softplus(x[..., -n:])
    return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n)
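A quick arithmetic check of that size recovery, for n = 3:

import math

L = 3 * (3 + 1) // 2                                  # a 3 x 3 factor needs 6 inputs
assert round((math.sqrt(1 + 8 * L) - 1) / 2) == 3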
Example #7
def _tril_cholesky_to_tril_corr(x):
    w = vec_to_tril_matrix(x, diagonal=-1)
    diag = np.sqrt(1 - np.sum(w**2, axis=-1))
    cholesky = w + np.expand_dims(diag, axis=-1) * np.identity(w.shape[-1])
    corr = np.matmul(cholesky, cholesky.T)
    return matrix_to_tril_vec(corr, diagonal=-1)
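Hedged usage of the helper above (its `vec_to_tril_matrix`/`matrix_to_tril_vec` dependencies assumed in scope): each row of the reconstructed Cholesky factor has unit norm by construction, so the product is a correlation matrix and every returned off-diagonal entry lies in (-1, 1):

import jax.numpy as np   # matching the alias used above

tril_corr = _tril_cholesky_to_tril_corr(np.array([0.3, -0.2, 0.5]))
assert tril_corr.shape == (3,)                 # n = 3 gives 3 strictly-lower entries
assert bool(np.all(np.abs(tril_corr) < 1))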
Example #8
def BlockMaskedDense(num_blocks,
                     in_factor,
                     out_factor,
                     bias=True,
                     W_init=glorot_uniform()):
    """
    Module that implements a linear layer with block matrices with positive diagonal blocks.
    Moreover, it uses Weight Normalization (https://arxiv.org/abs/1602.07868) for stability.

    :param int num_blocks: Number of block matrices.
    :param int in_factor: number of rows in each block.
    :param int out_factor: number of columns in each block.
    :param W_init: initialization method for the weights.
    :return: an (`init_fn`, `update_fn`) pair.
    """
    input_dim, out_dim = num_blocks * in_factor, num_blocks * out_factor
    # construct mask_d, mask_o for formula (8) of Ref [1]
    # Diagonal block mask
    mask_d = jnp.identity(num_blocks)[..., None]
    mask_d = jnp.tile(mask_d, (1, in_factor, out_factor)).reshape(input_dim, out_dim)
    # Off-diagonal block mask for upper triangular weight matrix
    mask_o = vec_to_tril_matrix(jnp.ones(num_blocks * (num_blocks - 1) // 2),
                                diagonal=-1).T[..., None]
    mask_o = jnp.tile(mask_o,
                      (1, in_factor, out_factor)).reshape(input_dim, out_dim)

    def init_fun(rng, input_shape):
        assert input_dim == input_shape[-1]
        *k1, k2, k3 = random.split(rng, num_blocks + 2)

        # Initialize each column block using W_init
        W = jnp.zeros((input_dim, out_dim))
        for i in range(num_blocks):
            W = ops.index_add(
                W, ops.index[:(i + 1) * in_factor,
                             i * out_factor:(i + 1) * out_factor],
                W_init(k1[i], ((i + 1) * in_factor, out_factor)))

        # initialize weight scale
        ws = jnp.log(uniform(1.)(k2, (out_dim, )))

        if bias:
            b = (uniform(1.)(k3, (out_dim, )) - 0.5) * (2 / jnp.sqrt(out_dim))
            params = (W, ws, b)
        else:
            params = (W, ws)
        return input_shape[:-1] + (out_dim, ), params

    def apply_fun(params, inputs, **kwargs):
        x, logdet = inputs
        if bias:
            W, ws, b = params
        else:
            W, ws = params

        # Form block weight matrix, making sure it's positive on diagonal!
        w = jnp.exp(W) * mask_d + W * mask_o

        # Compute norm of each column (i.e. each output feature)
        w_norm = jnp.linalg.norm(w, axis=-2, keepdims=True)

        # Normalize weight and rescale
        w = jnp.exp(ws) * w / w_norm

        out = jnp.dot(x, w)
        if bias:
            out = out + b

        dense_logdet = ws + W - jnp.log(w_norm)
        # logdet of block diagonal
        dense_logdet = dense_logdet[mask_d.astype(bool)].reshape(
            num_blocks, in_factor, out_factor)
        if logdet is None:
            logdet = jnp.broadcast_to(dense_logdet,
                                      x.shape[:-1] + dense_logdet.shape)
        else:
            logdet = logmatmulexp(logdet, dense_logdet)
        return out, logdet

    return init_fun, apply_fun
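A shapes-only usage sketch, under the snippet's own (older) API assumptions: `glorot_uniform`/`uniform` are `jax.nn.initializers`-style initializers, `ops.index_add` is the pre-0.3 `jax.ops` API, and `logmatmulexp` is assumed in scope (NumPyro defines one in `numpyro.distributions.util`):

from jax import random

init_fun, apply_fun = BlockMaskedDense(num_blocks=3, in_factor=2, out_factor=2)
out_shape, params = init_fun(random.PRNGKey(0), (7, 6))   # input_dim == 3 * 2 == 6
x = random.normal(random.PRNGKey(1), (7, 6))
out, logdet = apply_fun(params, (x, None))
assert out.shape == (7, 6)
assert logdet.shape == (7, 3, 2, 2)    # batch + (num_blocks, in_factor, out_factor)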
Example #9
def __call__(self, x):
    # recover n from x.shape[-1] == n * (n + 1) / 2 (inverse triangular number)
    n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
    z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
    diag = softplus(x[..., -n:])
    return (z + jnp.identity(n)) * diag[..., None]
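Unlike Example #6, the positive diagonal here rescales entire rows rather than only the diagonal; a concrete instance for n = 2 (imports assumed as in the snippet):

import math

import jax.numpy as jnp
from jax.nn import softplus
from numpyro.distributions.util import vec_to_tril_matrix

x = jnp.array([0.7, 0.1, -0.4])                # 1 strict-lower entry + 2 diagonal inputs
n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)          # n == 2
z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
diag = softplus(x[..., -n:])
expected = jnp.array([[diag[0], 0.], [0.7 * diag[1], diag[1]]])
assert jnp.allclose((z + jnp.identity(n)) * diag[..., None], expected)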