def test_biject_to(constraint, shape):
    transform = biject_to(constraint)
    if isinstance(constraint, constraints._Interval):
        assert transform.codomain.upper_bound == constraint.upper_bound
        assert transform.codomain.lower_bound == constraint.lower_bound
    elif isinstance(constraint, constraints._GreaterThan):
        assert transform.codomain.lower_bound == constraint.lower_bound
    if len(shape) < transform.event_dim:
        return
    rng = random.PRNGKey(0)
    x = random.normal(rng, shape)
    y = transform(x)

    # test codomain
    batch_shape = shape if transform.event_dim == 0 else shape[:-1]
    assert_array_equal(transform.codomain(y), np.ones(batch_shape, dtype=np.bool_))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain; currently every domain is constraints.real or constraints.real_vector
    assert_array_equal(transform.domain(z), np.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if constraint is constraints.simplex:
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x)[:-1, :])[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y)[:, :-1])[1]
        elif constraint is constraints.corr_cholesky:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x), diagonal=-1)  # noqa: E731
            y_tril = matrix_to_tril_vec(y, diagonal=-1)
            inv_vec_transform = lambda x: transform.inv(vec_to_tril_matrix(x, diagonal=-1))  # noqa: E731
            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(inv_vec_transform)(y_tril))[1]
        elif constraint is constraints.lower_cholesky:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x))  # noqa: E731
            y_tril = matrix_to_tril_vec(y)
            inv_vec_transform = lambda x: transform.inv(vec_to_tril_matrix(x))  # noqa: E731
            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(inv_vec_transform)(y_tril))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)
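
# A minimal round-trip sketch of what the test above exercises (a usage
# illustration, not part of the test suite; imports follow numpyro's public
# API). biject_to(constraints.simplex) is numpyro's stick-breaking transform,
# which maps R^n onto the (n + 1)-simplex, so the inverse should recover x.
from jax import random
import jax.numpy as np
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to

transform = biject_to(constraints.simplex)
x = random.normal(random.PRNGKey(0), (4,))   # unconstrained point in R^4
y = transform(x)                             # point on the 5-simplex
assert np.allclose(np.sum(y, -1), 1., atol=1e-6)
assert np.allclose(transform.inv(y), x, atol=1e-5)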
def _onion(self, key, size):
    key_beta, key_normal = random.split(key)
    # Now we generate w term in Algorithm 3.2 of [1].
    beta_sample = self._beta.sample(key_beta, size)
    # The following Normal distribution is used to create a uniform distribution on
    # a hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
    normal_sample = random.normal(
        key_normal,
        shape=size + self.batch_shape + (self.dimension * (self.dimension - 1) // 2,),
    )
    normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
    u_hypersphere = normal_sample / np.linalg.norm(normal_sample, axis=-1, keepdims=True)
    w = np.expand_dims(np.sqrt(beta_sample), axis=-1) * u_hypersphere

    # put w into the off-diagonal triangular part
    cholesky = ops.index_add(
        np.zeros(size + self.batch_shape + self.event_shape),
        ops.index[..., 1:, :-1],
        w,
    )
    # correct the diagonal
    # NB: we clip due to numerical precision
    diag = np.sqrt(np.clip(1 - np.sum(cholesky ** 2, axis=-1), a_min=0.))
    cholesky = cholesky + np.expand_dims(diag, axis=-1) * np.identity(self.dimension)
    return cholesky
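
# Hedged usage sketch for the sampler above: in numpyro the onion method is
# reachable through the public LKJCholesky distribution. Every sampled factor
# L should satisfy diag(L @ L.T) == 1, i.e. L is the Cholesky factor of a
# correlation matrix.
from jax import random
import jax.numpy as np
from numpyro.distributions import LKJCholesky

d = LKJCholesky(dimension=3, concentration=1., sample_method='onion')
L = d.sample(random.PRNGKey(0), (2,))           # shape (2, 3, 3), lower triangular
corr = np.matmul(L, np.swapaxes(L, -2, -1))
assert np.allclose(np.diagonal(corr, axis1=-2, axis2=-1), 1., atol=1e-6)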
def inv_vec_transform(y):
    matrix = vec_to_tril_matrix(y)
    if constraint is constraints.positive_definite:
        # fill the upper triangular part
        matrix = matrix + np.swapaxes(matrix, -2, -1) - np.diag(np.diag(matrix))
    return transform.inv(matrix)
def inv_vec_transform(y):
    matrix = vec_to_tril_matrix(y, diagonal=-1)
    if constraint is constraints.corr_matrix:
        # fill the upper triangular part
        matrix = matrix + np.swapaxes(matrix, -2, -1) + np.identity(matrix.shape[-1])
    return transform.inv(matrix)
def test_vec_to_tril_matrix(shape, diagonal):
    rng_key = random.PRNGKey(0)
    x = random.normal(rng_key, shape)
    actual = vec_to_tril_matrix(x, diagonal)
    # Build the expected matrix with classic NumPy (onp), since JAX arrays do
    # not support in-place item assignment.
    expected = onp.zeros(shape[:-1] + actual.shape[-2:])
    tril_idxs = onp.tril_indices(expected.shape[-1], diagonal)
    expected[..., tril_idxs[0], tril_idxs[1]] = x
    assert_allclose(actual, expected)
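
# Concrete illustration of the packing convention checked above (hedged;
# vec_to_tril_matrix is imported from numpyro.distributions.util as in the
# surrounding tests). With diagonal=0 a length-3 vector fills the lower
# triangle of a 2 x 2 matrix row by row; with diagonal=-1 it fills the
# strictly-lower triangle of a 3 x 3 matrix.
import jax.numpy as np
from numpyro.distributions.util import vec_to_tril_matrix

vec_to_tril_matrix(np.array([1., 2., 3.]))
# [[1., 0.],
#  [2., 3.]]
vec_to_tril_matrix(np.array([1., 2., 3.]), diagonal=-1)
# [[0., 0., 0.],
#  [1., 0., 0.],
#  [2., 3., 0.]]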
def __call__(self, x):
    # Recover the matrix size n from the packed vector length n * (n + 1) / 2.
    n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
    z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
    diag = softplus(x[..., -n:])
    return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n)
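
# Why the formula for n works (a hedged arithmetic check, not library code):
# a vector of length m = n * (n + 1) / 2 packs n * (n - 1) / 2 strictly-lower
# entries plus n diagonal entries, and the positive root of n^2 + n - 2m = 0
# is n = (sqrt(1 + 8 * m) - 1) / 2.
import math

for n in range(1, 10):
    m = n * (n + 1) // 2
    assert round((math.sqrt(1 + 8 * m) - 1) / 2) == n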
def _tril_cholesky_to_tril_corr(x):
    w = vec_to_tril_matrix(x, diagonal=-1)
    diag = np.sqrt(1 - np.sum(w ** 2, axis=-1))
    cholesky = w + np.expand_dims(diag, axis=-1) * np.identity(w.shape[-1])
    corr = np.matmul(cholesky, cholesky.T)
    return matrix_to_tril_vec(corr, diagonal=-1)
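
# Typical use of this helper in the tests (hedged sketch): the log-determinant
# of its Jacobian converts densities between the strictly-lower vector of a
# correlation Cholesky factor and that of the correlation matrix itself. The
# input rows must have norm below 1 so the reconstructed diagonal stays real.
import jax
import numpy as onp

x = onp.array([0.1, -0.2, 0.3])   # strictly-lower vector of a 3 x 3 factor
jac = jax.jacobian(_tril_cholesky_to_tril_corr)(x)
logdet = onp.linalg.slogdet(jac)[1]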
def BlockMaskedDense(num_blocks, in_factor, out_factor, bias=True, W_init=glorot_uniform()):
    """
    Module that implements a linear layer whose weight is a block matrix with
    positive diagonal blocks. Moreover, it uses Weight Normalization
    (https://arxiv.org/abs/1602.07868) for stability.

    :param int num_blocks: Number of block matrices.
    :param int in_factor: number of rows in each block.
    :param int out_factor: number of columns in each block.
    :param bool bias: whether to include a bias term.
    :param W_init: initialization method for the weights.
    :return: an (`init_fun`, `apply_fun`) pair.
    """
    input_dim, out_dim = num_blocks * in_factor, num_blocks * out_factor
    # construct mask_d, mask_o for formula (8) of Ref [1]
    # Diagonal block mask
    mask_d = jnp.identity(num_blocks)[..., None]
    mask_d = jnp.tile(mask_d, (1, in_factor, out_factor)).reshape(input_dim, out_dim)
    # Off-diagonal block mask for upper triangular weight matrix
    mask_o = vec_to_tril_matrix(jnp.ones(num_blocks * (num_blocks - 1) // 2), diagonal=-1).T[..., None]
    mask_o = jnp.tile(mask_o, (1, in_factor, out_factor)).reshape(input_dim, out_dim)

    def init_fun(rng, input_shape):
        assert input_dim == input_shape[-1]
        *k1, k2, k3 = random.split(rng, num_blocks + 2)

        # Initialize each column block using W_init
        W = jnp.zeros((input_dim, out_dim))
        for i in range(num_blocks):
            W = ops.index_add(
                W,
                ops.index[:(i + 1) * in_factor, i * out_factor:(i + 1) * out_factor],
                W_init(k1[i], ((i + 1) * in_factor, out_factor)),
            )

        # initialize weight scale
        ws = jnp.log(uniform(1.)(k2, (out_dim,)))
        if bias:
            b = (uniform(1.)(k3, (out_dim,)) - 0.5) * (2 / jnp.sqrt(out_dim))
            params = (W, ws, b)
        else:
            params = (W, ws)
        return input_shape[:-1] + (out_dim,), params

    def apply_fun(params, inputs, **kwargs):
        x, logdet = inputs
        if bias:
            W, ws, b = params
        else:
            W, ws = params

        # Form block weight matrix, making sure it's positive on the diagonal!
        w = jnp.exp(W) * mask_d + W * mask_o

        # Compute norm of each column (i.e. each output feature)
        w_norm = jnp.linalg.norm(w, axis=-2, keepdims=True)

        # Normalize weight and rescale
        w = jnp.exp(ws) * w / w_norm

        out = jnp.dot(x, w)
        if bias:
            out = out + b

        dense_logdet = ws + W - jnp.log(w_norm)
        # logdet of block diagonal
        dense_logdet = dense_logdet[mask_d.astype(bool)].reshape(num_blocks, in_factor, out_factor)
        if logdet is None:
            logdet = jnp.broadcast_to(dense_logdet, x.shape[:-1] + dense_logdet.shape)
        else:
            logdet = logmatmulexp(logdet, dense_logdet)
        return out, logdet

    return init_fun, apply_fun
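
# Minimal usage sketch (hedged illustration, not library code): initialize the
# layer, then push a batch through. apply_fun threads an (activations, logdet)
# pair so a flow can accumulate per-block log-Jacobians across layers.
from jax import random

init_fun, apply_fun = BlockMaskedDense(num_blocks=3, in_factor=2, out_factor=2)
out_shape, params = init_fun(random.PRNGKey(0), input_shape=(-1, 6))
x = random.normal(random.PRNGKey(1), (5, 6))
out, logdet = apply_fun(params, (x, None))
# out.shape == (5, 6); logdet.shape == (5, 3, 2, 2), one 2 x 2 entry per block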
def __call__(self, x):
    # Recover the matrix size n from the packed vector length, as above.
    n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
    z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
    diag = softplus(x[..., -n:])
    return (z + jnp.identity(n)) * diag[..., None]
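
# Design note on the two __call__ variants: the earlier transform adds a
# softplus-positive diagonal on top of the strictly-lower part,
# z + softplus(d) * I, while this one gives the matrix a unit diagonal and
# then rescales each row by its positive diagonal entry, (z + I) * d[..., None].
# Both guarantee a positive diagonal, hence a valid lower-Cholesky factor.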