def test_biject_to(constraint, shape):
    transform = biject_to(constraint)
    if isinstance(constraint, constraints._Interval):
        assert transform.codomain.upper_bound == constraint.upper_bound
        assert transform.codomain.lower_bound == constraint.lower_bound
    elif isinstance(constraint, constraints._GreaterThan):
        assert transform.codomain.lower_bound == constraint.lower_bound
    if len(shape) < transform.event_dim:
        return
    rng = random.PRNGKey(0)
    x = random.normal(rng, shape)
    y = transform(x)

    # test codomain
    batch_shape = shape if transform.event_dim == 0 else shape[:-1]
    assert_array_equal(transform.codomain(y), np.ones(batch_shape, dtype=np.bool_))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain, currently all is constraints.real or constraints.real_vector
    assert_array_equal(transform.domain(z), np.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if constraint is constraints.simplex:
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x)[:-1, :])[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y)[:, :-1])[1]
        elif constraint is constraints.corr_cholesky:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x), diagonal=-1)  # noqa: E731
            y_tril = matrix_to_tril_vec(y, diagonal=-1)
            inv_vec_transform = lambda x: transform.inv(vec_to_tril_matrix(x, diagonal=-1))  # noqa: E731
            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(inv_vec_transform)(y_tril))[1]
        elif constraint is constraints.lower_cholesky:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x))  # noqa: E731
            y_tril = matrix_to_tril_vec(y)
            inv_vec_transform = lambda x: transform.inv(vec_to_tril_matrix(x))  # noqa: E731
            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(inv_vec_transform)(y_tril))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)

def _inverse(self, y):
    # inverse stick-breaking
    z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
    pad_width = [(0, 0)] * y.ndim
    pad_width[-1] = (1, 0)
    z1m_cumprod_shifted = jnp.pad(z1m_cumprod[..., :-1], pad_width,
                                  mode="constant", constant_values=1.)
    t = matrix_to_tril_vec(y, diagonal=-1) / jnp.sqrt(
        matrix_to_tril_vec(z1m_cumprod_shifted, diagonal=-1))
    # inverse of tanh
    x = jnp.log((1 + t) / (1 - t)) / 2
    return x

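# Quick standalone check (a sketch; only plain jax.numpy is needed): the last step of
# `_inverse` above, jnp.log((1 + t) / (1 - t)) / 2, is exactly the closed form of arctanh.
import jax.numpy as jnp

t = jnp.array([-0.9, -0.5, 0., 0.5, 0.9])
assert jnp.allclose(jnp.log((1 + t) / (1 - t)) / 2, jnp.arctanh(t), atol=1e-6)
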
def test_log_prob_LKJCholesky_uniform(dimension):
    # When concentration=1, the distribution of correlation matrices is uniform.
    # We will test that fact here.
    d = dist.LKJCholesky(dimension=dimension, concentration=1)
    N = 5
    corr_log_prob = []
    for i in range(N):
        sample = d.sample(random.PRNGKey(i))
        log_prob = d.log_prob(sample)
        sample_tril = matrix_to_tril_vec(sample, diagonal=-1)
        cholesky_to_corr_jac = onp.linalg.slogdet(
            jax.jacobian(_tril_cholesky_to_tril_corr)(sample_tril))[1]
        corr_log_prob.append(log_prob - cholesky_to_corr_jac)

    corr_log_prob = np.array(corr_log_prob)
    # test that the log densities are constant across samples
    assert_allclose(corr_log_prob,
                    np.broadcast_to(corr_log_prob[0], corr_log_prob.shape),
                    rtol=1e-6)

    if dimension == 2:
        # When concentration = 1, LKJ gives a uniform distribution over correlation
        # matrices, so for dimension = 2 the matrix is determined by its single
        # off-diagonal element and its density is that of Uniform(-1, 1), i.e. 0.5.
        # In addition, the Jacobian of the transform cholesky -> corr is 1 (hence its
        # log value is 0) because the off-diagonal lower triangular element does not
        # change under the transform. So target_log_prob = log(0.5).
        assert_allclose(corr_log_prob[0], np.log(0.5), rtol=1e-6)

def log_abs_det_jacobian(self, x, y, intermediates=None):
    # NB: because the domain and codomain are two spaces with different dimensions, the
    # determinant of the Jacobian is not well-defined. Here we return `log_abs_det_jacobian`
    # of `x` and the flattened lower triangular part of `y`.

    # stick_breaking_logdet = log(y / r) = log(z_cumprod)  (modulo right shift)
    z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
    # by taking diagonal=-2, we don't need to shift z_cumprod to the right
    # NB: diagonal=-2 works fine for a (2 x 2) matrix, where we get an empty array
    z1m_cumprod_tril = matrix_to_tril_vec(z1m_cumprod, diagonal=-2)
    stick_breaking_logdet = 0.5 * jnp.sum(jnp.log(z1m_cumprod_tril), axis=-1)

    tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.), axis=-1)
    return stick_breaking_logdet + tanh_logdet

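# Worked check (a sketch, separate from the transform): the `tanh_logdet` term above relies
# on the identity log(1 - tanh(x)^2) = -2 * (x + softplus(-2x) - log 2), summed over the
# event dimension; the right-hand side stays finite where the naive form underflows.
import jax.numpy as jnp
from jax.nn import softplus

x = jnp.array([-3., -0.5, 0., 0.5, 3.])
naive = jnp.log(1 - jnp.tanh(x) ** 2)
stable = -2 * (x + softplus(-2 * x) - jnp.log(2.))
assert jnp.allclose(naive, stable, atol=1e-6)
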
def __init__(self, dimension, concentration=1., sample_method='onion', validate_args=None):
    if dimension < 2:
        raise ValueError("Dimension must be greater than or equal to 2.")
    self.dimension = dimension
    self.concentration = concentration
    batch_shape = np.shape(concentration)
    event_shape = (dimension, dimension)

    # We construct base distributions to generate samples for each method.
    # The purpose of this base distribution is to generate a distribution over
    # correlation matrices which is proportional to `det(M)^{\eta - 1}`.
    # (note that this is not the only way to define a base distribution)
    # Under both of the following methods, the marginal distribution of each
    # off-diagonal element of a sampled correlation matrix is
    # Beta(eta + (D-2)/2, eta + (D-2)/2) (up to a linear transform: x -> 2x - 1).
    Dm1 = self.dimension - 1
    marginal_concentration = concentration + 0.5 * (self.dimension - 2)
    offset = 0.5 * np.arange(Dm1)
    if sample_method == 'onion':
        # The following construction follows from the algorithm in Section 3.2 of [1]:
        # NB: in [1], the method for the case k > 1 also works for the case k = 1.
        beta_concentration0 = np.expand_dims(marginal_concentration, axis=-1) - offset
        beta_concentration1 = offset + 0.5
        self._beta = Beta(beta_concentration1, beta_concentration0)
    elif sample_method == 'cvine':
        # The following construction follows from the algorithm in Section 2.4 of [1]:
        # offset_tril is [0, 1, 1, 2, 2, 2, ...] / 2
        offset_tril = matrix_to_tril_vec(np.broadcast_to(offset, (Dm1, Dm1)))
        beta_concentration = np.expand_dims(marginal_concentration, axis=-1) - offset_tril
        self._beta = Beta(beta_concentration, beta_concentration)
    else:
        raise ValueError("`sample_method` should be one of 'cvine' or 'onion'.")
    self.sample_method = sample_method

    super(LKJCholesky, self).__init__(batch_shape=batch_shape,
                                      event_shape=event_shape,
                                      validate_args=validate_args)

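# Usage sketch (an assumption: that this class is exposed as NumPyro's dist.LKJCholesky,
# as in the test above): samples are lower Cholesky factors of correlation matrices,
# so L @ L.T should have a unit diagonal.
from jax import random
import jax.numpy as jnp
import numpyro.distributions as dist

d = dist.LKJCholesky(dimension=3, concentration=1., sample_method='cvine')
L = d.sample(random.PRNGKey(0))
corr = jnp.matmul(L, L.T)
assert jnp.allclose(jnp.diagonal(corr), 1., atol=1e-6)
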
def test_block_neural_arn(input_dim, hidden_factors, residual, batch_shape):
    arn_init, arn = BlockNeuralAutoregressiveNN(input_dim, hidden_factors, residual)
    rng = random.PRNGKey(0)
    input_shape = batch_shape + (input_dim,)
    out_shape, init_params = arn_init(rng, input_shape)
    assert out_shape == input_shape

    x = random.normal(random.PRNGKey(1), input_shape)
    output, logdet = arn(init_params, x)
    assert output.shape == input_shape
    assert logdet.shape == input_shape

    if len(batch_shape) == 1:
        jac = vmap(jacfwd(lambda x: arn(init_params, x)[0]))(x)
    else:
        jac = jacfwd(lambda x: arn(init_params, x)[0])(x)
    assert_allclose(logdet.sum(-1), np.linalg.slogdet(jac)[1], rtol=1e-6)

    # make sure jacobians are lower triangular
    assert onp.sum(onp.abs(onp.triu(jac, k=1))) == 0.0
    assert onp.all(onp.abs(matrix_to_tril_vec(jac)) > 0)

def test_flows(flow_class, flow_args, input_dim, batch_shape):
    transform = flow_class(*flow_args)
    x = random.normal(random.PRNGKey(0), batch_shape + (input_dim,))

    # test that the inverse is correct
    y = transform(x)
    try:
        inv = transform.inv(y)
        assert_allclose(x, inv, atol=1e-5)
    except NotImplementedError:
        pass

    # test jacobian shape
    actual = transform.log_abs_det_jacobian(x, y)
    assert onp.shape(actual) == batch_shape

    if batch_shape == ():
        # make sure transform.log_abs_det_jacobian is correct
        jac = jacfwd(transform)(x)
        expected = onp.linalg.slogdet(jac)[1]
        assert_allclose(actual, expected, atol=1e-5)

        # make sure the jacobian is triangular; first permute the jacobian as necessary
        if isinstance(transform, InverseAutoregressiveTransform):
            permuted_jac = onp.zeros(jac.shape)
            _, rng_key_perm = random.split(random.PRNGKey(0))
            perm = random.shuffle(rng_key_perm, onp.arange(input_dim))
            for j in range(input_dim):
                for k in range(input_dim):
                    permuted_jac[j, k] = jac[perm[j], perm[k]]
            jac = permuted_jac
        assert onp.sum(onp.abs(onp.triu(jac, 1))) == 0.0
        assert onp.all(onp.abs(matrix_to_tril_vec(jac)) > 0)

def _inverse(self, y):
    z = matrix_to_tril_vec(y, diagonal=-1)
    diag = _softplus_inv(jnp.diagonal(y, axis1=-2, axis2=-1))
    return jnp.concatenate([z, diag], axis=-1)

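# `_softplus_inv` is referenced above but not defined in this snippet; a numerically
# stable version could look like the following sketch (this implementation is an
# assumption, not necessarily the library's own helper):
import jax.numpy as jnp

def _softplus_inv(y):
    # inverse of softplus(x) = log(1 + exp(x)), i.e. log(exp(y) - 1),
    # rewritten as y + log(1 - exp(-y)) for numerical stability at large y
    return y + jnp.log(-jnp.expm1(-y))
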
def _inverse(self, y):
    z = matrix_to_tril_vec(y, diagonal=-1)
    return jnp.concatenate([z, jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))], axis=-1)

def _tril_cholesky_to_tril_corr(x):
    w = vec_to_tril_matrix(x, diagonal=-1)
    diag = np.sqrt(1 - np.sum(w ** 2, axis=-1))
    cholesky = w + np.expand_dims(diag, axis=-1) * np.identity(w.shape[-1])
    corr = np.matmul(cholesky, cholesky.T)
    return matrix_to_tril_vec(corr, diagonal=-1)

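# Worked check (a sketch; assumes `np` is jax.numpy and `onp` is classic NumPy, as in the
# surrounding snippets): for dimension = 2 the tril vector has a single entry and
# `_tril_cholesky_to_tril_corr` reduces to the identity map, so its log|det Jacobian| is 0,
# which is the fact the `dimension == 2` branch of the test above relies on.
import jax

x2 = onp.array([0.3])
jac = jax.jacobian(_tril_cholesky_to_tril_corr)(x2)
assert onp.allclose(onp.linalg.slogdet(jac)[1], 0., atol=1e-6)
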
def inv(self, y):
    z = matrix_to_tril_vec(y, diagonal=-1)
    return np.concatenate([z, np.log(np.diagonal(y, axis1=-2, axis2=-1))], axis=-1)

def test_biject_to(constraint, shape):
    transform = biject_to(constraint)
    if transform.event_dim == 2:
        event_dim = 1  # actual dim of unconstrained domain
    else:
        event_dim = transform.event_dim
    if isinstance(constraint, constraints._Interval):
        assert transform.codomain.upper_bound == constraint.upper_bound
        assert transform.codomain.lower_bound == constraint.lower_bound
    elif isinstance(constraint, constraints._GreaterThan):
        assert transform.codomain.lower_bound == constraint.lower_bound
    if len(shape) < event_dim:
        return
    rng_key = random.PRNGKey(0)
    x = random.normal(rng_key, shape)
    y = transform(x)

    # test codomain
    batch_shape = shape if event_dim == 0 else shape[:-1]
    assert_array_equal(transform.codomain(y), np.ones(batch_shape, dtype=np.bool_))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain, currently all is constraints.real or constraints.real_vector
    assert_array_equal(transform.domain(z), np.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape
    if len(shape) == event_dim:
        if constraint is constraints.simplex:
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x)[:-1, :])[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y)[:, :-1])[1]
        elif constraint is constraints.ordered_vector:
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        elif constraint in [constraints.corr_cholesky, constraints.corr_matrix]:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x), diagonal=-1)  # noqa: E731
            y_tril = matrix_to_tril_vec(y, diagonal=-1)

            def inv_vec_transform(y):
                matrix = vec_to_tril_matrix(y, diagonal=-1)
                if constraint is constraints.corr_matrix:
                    # fill the upper triangular part
                    matrix = matrix + np.swapaxes(matrix, -2, -1) + np.identity(matrix.shape[-1])
                return transform.inv(matrix)

            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(inv_vec_transform)(y_tril))[1]
        elif constraint in [constraints.lower_cholesky, constraints.positive_definite]:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x))  # noqa: E731
            y_tril = matrix_to_tril_vec(y)

            def inv_vec_transform(y):
                matrix = vec_to_tril_matrix(y)
                if constraint is constraints.positive_definite:
                    # fill the upper triangular part
                    matrix = matrix + np.swapaxes(matrix, -2, -1) - np.diag(np.diag(matrix))
                return transform.inv(matrix)

            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(inv_vec_transform)(y_tril))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6, rtol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6, rtol=1e-6)

def _inverse(self, y):
    diag = jnp.diagonal(y, axis1=-2, axis2=-1)
    z = matrix_to_tril_vec(y / diag[..., None], diagonal=-1)
    return jnp.concatenate([z, _softplus_inv(diag)], axis=-1)