def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of shape ``size`` intended to satisfy ``constraint``.

    The branch taken is selected by the concrete constraint class. Most
    branches draw their randomness from ``key``; the simplex branch calls
    ``osp.dirichlet.rvs`` and does not use ``key`` at all.

    :param constraint: constraint object whose class selects the generator.
    :param tuple size: shape of the generated array. It must be a tuple,
        because several branches slice it and concatenate tuples onto it.
    :param key: PRNG key passed to the ``random.*`` samplers.
    :return: array of generated values.
    :raises NotImplementedError: if the constraint class has no generator.
    """
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        # exp(.) is non-negative; eps keeps the result off the lower bound.
        return np.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        # randint excludes its upper limit, hence the +1 to make upper_bound reachable.
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        # NOTE: this draw ignores `key`; the last axis of `size` is the simplex dimension.
        return osp.dirichlet.rvs(alpha=np.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=np.ones((n,)) / n, n=constraint.upper_bound,
                           shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        # A d x d factor is built from d * (d - 1) / 2 values drawn in (-1, 1).
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
        # L @ L^T over the trailing two axes.
        return np.matmul(cholesky, np.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return np.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return np.matmul(x, np.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        # Cumulative sums of non-negative increments are ordered along the last axis.
        x = np.cumsum(random.exponential(key, size), -1)
        # One scalar shift per vector. The trailing singleton axis makes the
        # shift broadcast along the last axis for any batch shape; a shift of
        # shape size[:-1] would misalign (or fail) for batched sizes.
        return x - random.normal(key, size[:-1] + (1,))
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
def test_log_prob_LKJCholesky(dimension, concentration):
    """Check ``LKJCholesky.log_prob`` via a change-of-variables computation.

    With the C-vine sampling method, the distribution can be viewed as a
    transformed distribution: the base is the distribution of partial
    correlations (beta samples mapped affinely from (0, 1) to (-1, 1)) and
    the transform is the signed stick-breaking process. The expected log
    density is therefore the base log density minus the log-determinants of
    the affine map and of the signed stick-breaking map.
    """

    def inv_tanh(t):
        return np.log((1 + t) / (1 - t)) / 2

    lkj = dist.LKJCholesky(dimension, concentration, sample_method="cvine")

    # Base draw and its log density.
    betas = lkj._beta.sample(random.PRNGKey(0))
    base_log_prob = np.sum(lkj._beta.log_prob(betas))

    # Affine map (0, 1) -> (-1, 1); each coordinate contributes log(2).
    partial_corr = 2 * betas - 1
    logdet_affine = betas.shape[-1] * np.log(2)
    chol = signed_stick_breaking_tril(partial_corr)

    # Signed stick-breaking log-determinant: take the corr_cholesky bijector's
    # log-determinant at the unconstrained point and add the log-determinant
    # of inv_tanh, obtained elementwise with grad.
    logdet_inv_tanh = np.sum(np.log(vmap(grad(inv_tanh))(partial_corr)))
    raw = inv_tanh(partial_corr)
    bijector = biject_to(constraints.corr_cholesky)
    logdet_corr_cholesky = bijector.log_abs_det_jacobian(raw, chol)
    logdet_stick_breaking = logdet_corr_cholesky + logdet_inv_tanh

    observed = lkj.log_prob(chol)
    expected = base_log_prob - logdet_affine - logdet_stick_breaking
    assert_allclose(observed, expected, rtol=1e-5)

    # The jitted log_prob must agree with the eager one.
    jitted = jax.jit(lkj.log_prob)(chol)
    eager = lkj.log_prob(chol)
    assert_allclose(jitted, eager, atol=1e-7)
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of shape ``size`` intended to violate ``constraint``.

    The branch taken is selected by the concrete constraint class. Most
    branches draw their randomness from ``key``; the simplex branch calls
    ``osp.dirichlet.rvs`` and does not use ``key`` at all.

    :param constraint: constraint object whose class selects the generator.
    :param tuple size: shape of the generated array. It must be a tuple,
        because several branches slice it and concatenate tuples onto it.
    :param key: PRNG key passed to the ``random.*`` samplers.
    :return: array of generated values.
    :raises NotImplementedError: if the constraint class has no generator.
    """
    if isinstance(constraint, constraints._Boolean):
        # Shifts the 0/1 draws down by 2.
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - np.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        # randint excludes its upper limit, so every entry is lower_bound - 1.
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        # A Poisson draw can be 0, which would return lower_bound itself, a
        # value gen_values_within_bounds can also produce. The extra -1
        # keeps every entry strictly below lower_bound.
        return constraint.lower_bound - 1 - poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    elif isinstance(constraint, constraints._Real):
        return lax.full(size, np.nan)
    elif isinstance(constraint, constraints._Simplex):
        # NOTE: this draw ignores `key`. Adding a constant to every entry
        # pushes the sum along the last axis above 1.
        return osp.dirichlet.rvs(alpha=np.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        # Adding 1 to every count pushes the total above constraint.upper_bound.
        return multinomial(key, p=np.ones((n,)) / n, n=constraint.upper_bound,
                           shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        # Start from a generated factor and perturb every entry by 1e-2.
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1)) + 1e-2
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
def _cvine(self, key, size):
    """Draw samples with the C-vine method.

    Partial correlations are drawn from ``self._beta`` and then turned into a
    Cholesky factor by the signed stick-breaking transform.
    """
    # Sketch of why signed stick breaking reproduces the C-vine method of [1],
    # shown for the entry r_32.
    #
    # Using the notation of [1]:
    #   p: partial correlation matrix,
    #   c: cholesky factor,
    #   r: correlation matrix.
    # The recursive formula (2) in [1] gives
    #   r_32 = p_32 * sqrt{(1 - p_21^2)*(1 - p_31^2)} + p_21 * p_31 =: I
    # The signed stick-breaking process gives
    #   l_21 = p_21, l_31 = p_31, l_22 = sqrt(1 - p_21^2),
    #   l_32 = p_32 * sqrt(1 - p_31^2)
    # and therefore
    #   r_32 = l_21 * l_31 + l_22 * l_32
    #        = p_21 * p_31 + p_32 * sqrt{(1 - p_21^2)*(1 - p_31^2)} = I
    #
    # Rescale the beta draws to the domain (-1, 1) before the transform.
    partial_corr = 2 * self._beta.sample(key, size) - 1
    return signed_stick_breaking_tril(partial_corr)
def __call__(self, x):
    """Apply ``tanh`` to ``x`` and pass the result through signed stick breaking."""
    # Step 1 and step 2.a are interchanged here for better performance
    # (the step numbering is presumably defined in the enclosing class's
    # documentation, which is not visible from this method).
    return signed_stick_breaking_tril(jnp.tanh(x))