def test_multinomial_stats(p, n):
    """Check that normalized multinomial counts agree with ``p``.

    One sample is drawn with ``n`` trials and fixed key 0; the counts divided
    by the trial total must match ``p`` (broadcast to the sample's shape) to
    within an absolute tolerance of 0.01.
    """
    draws = multinomial(random.PRNGKey(0), p, n)
    if isinstance(n, Number):
        total = float(n)
    else:
        # Array-valued trial counts: match p's dtype and append a trailing
        # axis so the totals broadcast against the per-category counts.
        total = jnp.expand_dims(n.astype(p.dtype), -1)
    expected = jnp.broadcast_to(p, draws.shape)
    assert_allclose(draws / total, expected, atol=0.01)
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of shape ``size`` whose values violate ``constraint``.

    The generator is chosen by the type of ``constraint``. ``key`` seeds the
    JAX samplers. The simplex branch calls ``osp.dirichlet.rvs`` without
    ``key``, so that branch is not seeded by it.

    Raises:
        NotImplementedError: if no generator exists for ``constraint``'s type.
    """
    if isinstance(constraint, constraints._Boolean):
        # Bernoulli draws are 0/1, so subtracting 2 gives -2 or -1.
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        # exp(.) is positive, so every value lies below the lower bound.
        return constraint.lower_bound - np.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        # randint's maxval is exclusive, so this always yields lower_bound - 1.
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        # A Poisson draw can be 0, which would return the bound itself. An
        # integer ">=" support (e.g. non-negative integers) admits that value.
        # The extra -1 keeps every value strictly below the bound.
        # NOTE(review): assumes _IntegerGreaterThan checks x >= lower_bound.
        return constraint.lower_bound - 1 - poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        # Values in [upper_bound, upper_bound + 1).
        # NOTE(review): minval is inclusive, so a draw could equal upper_bound
        # exactly (vanishingly rare).
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    elif isinstance(constraint, constraints._Real):
        # NaN is presumably rejected by the real-line constraint.
        return lax.full(size, np.nan)
    elif isinstance(constraint, constraints._Simplex):
        # Adding 1e-2 to every entry pushes each row's sum above 1.
        return osp.dirichlet.rvs(alpha=np.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        # +1 per category makes each row total exceed constraint.upper_bound.
        return multinomial(key, p=np.ones((n,)) / n, n=constraint.upper_bound,
                           shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        # The 1e-2 offset is added everywhere, including above the diagonal.
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1)) + 1e-2
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of shape ``size`` whose values satisfy ``constraint``.

    The sampler is picked by the type of ``constraint``. The checks run in a
    fixed order, and the first match wins. ``key`` seeds the JAX samplers. The
    simplex branch uses ``osp.dirichlet.rvs``, which is not given ``key``.

    Raises:
        NotImplementedError: if no sampler exists for ``constraint``'s type.
    """
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)

    if isinstance(constraint, constraints._GreaterThan):
        # Positive log-normal draws shifted above the bound; the small offset
        # keeps values clear of it.
        offset = 1e-6
        return np.exp(random.normal(key, size)) + constraint.lower_bound + offset

    if isinstance(constraint, constraints._IntegerInterval):
        low = np.broadcast_to(constraint.lower_bound, size)
        high = np.broadcast_to(constraint.upper_bound, size)
        # randint excludes maxval, so widen by one to include the upper bound.
        return random.randint(key, size, low, high + 1)

    if isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + poisson(key, 5, shape=size)

    if isinstance(constraint, constraints._Interval):
        low = np.broadcast_to(constraint.lower_bound, size)
        high = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=low, maxval=high)

    if isinstance(constraint, constraints._Real):
        return random.normal(key, size)

    if isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=np.ones((size[-1],)), size=size[:-1])

    if isinstance(constraint, constraints._Multinomial):
        num_categories = size[-1]
        uniform_probs = np.ones((num_categories,)) / num_categories
        return multinomial(key, p=uniform_probs, n=constraint.upper_bound,
                           shape=size[:-1])

    if isinstance(constraint, constraints._CorrCholesky):
        # Number of strictly-lower-triangular entries of a size[-1] square matrix.
        num_offdiag = size[-1] * (size[-1] - 1) // 2
        unconstrained = random.uniform(key, size[:-2] + (num_offdiag,),
                                       minval=-1, maxval=1)
        return signed_stick_breaking_tril(unconstrained)

    if isinstance(constraint, constraints._LowerCholesky):
        return np.tril(random.uniform(key, size))

    if isinstance(constraint, constraints._PositiveDefinite):
        # x @ x^T over the last two axes.
        gaussian = random.normal(key, size)
        return np.matmul(gaussian, np.swapaxes(gaussian, -2, -1))

    raise NotImplementedError('{} not implemented.'.format(constraint))
def test_multinomial_inhomogeneous(n, device_array):
    """Check multinomial sampling with a possibly array-valued trial count.

    When ``device_array`` is true, ``n`` is first converted with
    ``jnp.asarray``. The sample's shape must be ``shape(n) + shape(p)``, and
    the counts along the last axis must sum to ``n``.
    """
    trials = jnp.asarray(n) if device_array else n
    probs = jnp.array([0.5, 0.5])
    counts = multinomial(random.PRNGKey(0), probs, trials)
    assert counts.shape == jnp.shape(trials) + jnp.shape(probs)
    assert_allclose(counts.sum(-1), trials)
def sample(self, key, size=()):
    """Draw multinomial samples using this distribution's parameters.

    Delegates to ``multinomial`` with ``self.probs`` and ``self.total_count``.
    It requests the shape ``size + self.batch_shape``, so ``size`` is
    expected to be a tuple.
    """
    requested_shape = size + self.batch_shape
    return multinomial(key, self.probs, self.total_count, shape=requested_shape)
def test_multinomial_shape(p, shape):
    """Check that ``multinomial`` broadcasts ``p``'s batch dims with ``shape``.

    The result's shape must be ``p``'s batch shape (all but the last axis)
    broadcast with ``shape``, followed by ``p``'s last (category) axis.
    """
    total_count = 10000
    batch_shape = lax.broadcast_shapes(p.shape[:-1], shape)
    sample = multinomial(random.PRNGKey(0), p, total_count, shape)
    assert jnp.shape(sample) == batch_shape + p.shape[-1:]
def sample(self, key, sample_shape=()):
    """Draw multinomial samples after checking that ``key`` is a PRNG key.

    Delegates to ``multinomial`` with ``self.probs`` and ``self.total_count``.
    It requests the shape ``sample_shape + self.batch_shape``, so
    ``sample_shape`` is expected to be a tuple.
    """
    assert is_prng_key(key)
    requested_shape = sample_shape + self.batch_shape
    return multinomial(key, self.probs, self.total_count, shape=requested_shape)