Beispiel #1
0
def test_multinomial_stats(p, n):
    """Check that empirical multinomial frequencies match the probabilities ``p``.

    ``n`` is either a Python number or an array of totals. In the array case
    it is cast to ``p``'s dtype and given a trailing axis so that dividing the
    counts by it broadcasts across the category dimension.
    """
    counts = multinomial(random.PRNGKey(0), p, n)
    if isinstance(n, Number):
        total = float(n)
    else:
        # trailing axis so the division below lines up with the category axis
        total = jnp.expand_dims(n.astype(p.dtype), -1)
    expected = jnp.broadcast_to(p, counts.shape)
    assert_allclose(counts / total, expected, atol=0.01)
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Return an array of shape ``size`` whose values violate ``constraint``.

    Args:
        constraint: constraint instance; dispatch is by ``isinstance`` against
            the private constraint classes, checked in the order below.
        size: target shape (tuple). Some branches interpret its trailing
            dimension(s) as event dimensions (Simplex, Multinomial,
            CorrCholesky).
        key: PRNG key. The default is created once when the function is
            defined, so calls that omit it reuse the same key. The Simplex
            branch does not use ``key``; it draws from SciPy's global state.

    Raises:
        NotImplementedError: if ``constraint`` matches none of the branches.
    """
    if isinstance(constraint, constraints._Boolean):
        # shift {0, 1} down to {-2, -1}
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        # exp(.) > 0, so every value is strictly below the lower bound
        return constraint.lower_bound - np.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        # randint's maxval is exclusive, so this always yields lower_bound - 1
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key,
                              size,
                              minval=upper_bound,
                              maxval=upper_bound + 1.)
    elif isinstance(constraint, constraints._Real):
        return lax.full(size, np.nan)
    elif isinstance(constraint, constraints._Simplex):
        # a valid simplex point shifted by 1e-2 no longer sums to one
        return osp.dirichlet.rvs(alpha=np.ones(
            (size[-1], )), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        # +1 on every category pushes the total count above upper_bound
        return multinomial(key,
                           p=np.ones((n, )) / n,
                           n=constraint.upper_bound,
                           shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        # a valid correlation Cholesky factor perturbed by 1e-2
        return signed_stick_breaking_tril(
            random.uniform(key,
                           size[:-2] + (size[-1] * (size[-1] - 1) // 2, ),
                           minval=-1,
                           maxval=1)) + 1e-2
    elif isinstance(constraint, constraints._LowerCholesky):
        # dense matrix: entries above the diagonal are non-zero, so it is
        # not lower triangular
        return random.uniform(key, size)
    elif isinstance(constraint, constraints._PositiveDefinite):
        # dense Gaussian matrix is (almost surely) not symmetric, hence not
        # positive definite
        return random.normal(key, size)
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Beispiel #3
0
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    """Return an array of shape ``size`` whose values satisfy ``constraint``.

    Dispatch is by ``isinstance`` against the private constraint classes, in
    the order written below. The default ``key`` is built once when the
    function is defined. The Simplex branch ignores ``key`` and draws from
    SciPy's global random state.

    Raises:
        NotImplementedError: if ``constraint`` matches none of the branches.
    """
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)

    if isinstance(constraint, constraints._GreaterThan):
        # exp(.) > 0 plus a small margin keeps values strictly above the bound
        eps = 1e-6
        return np.exp(random.normal(key, size)) + constraint.lower_bound + eps

    if isinstance(constraint, constraints._IntegerInterval):
        lo = np.broadcast_to(constraint.lower_bound, size)
        hi = np.broadcast_to(constraint.upper_bound, size)
        # randint excludes maxval, so +1 makes the upper bound reachable
        return random.randint(key, size, lo, hi + 1)

    if isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + poisson(key, 5, shape=size)

    if isinstance(constraint, constraints._Interval):
        lo = np.broadcast_to(constraint.lower_bound, size)
        hi = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lo, maxval=hi)

    if isinstance(constraint, constraints._Real):
        return random.normal(key, size)

    if isinstance(constraint, constraints._Simplex):
        concentration = np.ones((size[-1],))
        return osp.dirichlet.rvs(alpha=concentration, size=size[:-1])

    if isinstance(constraint, constraints._Multinomial):
        num_categories = size[-1]
        uniform_probs = np.ones((num_categories,)) / num_categories
        return multinomial(key, p=uniform_probs, n=constraint.upper_bound, shape=size[:-1])

    if isinstance(constraint, constraints._CorrCholesky):
        dim = size[-1]
        # one value in (-1, 1) per strictly-lower-triangular entry
        flat_shape = size[:-2] + (dim * (dim - 1) // 2,)
        partial = random.uniform(key, flat_shape, minval=-1, maxval=1)
        return signed_stick_breaking_tril(partial)

    if isinstance(constraint, constraints._LowerCholesky):
        return np.tril(random.uniform(key, size))

    if isinstance(constraint, constraints._PositiveDefinite):
        # M @ M^T over the last two axes
        mat = random.normal(key, size)
        return np.matmul(mat, np.swapaxes(mat, -2, -1))

    raise NotImplementedError('{} not implemented.'.format(constraint))
Beispiel #4
0
def test_multinomial_inhomogeneous(n, device_array):
    """Counts have shape ``shape(n) + shape(p)`` and sum to ``n`` along the last axis.

    When ``device_array`` is true, ``n`` is first converted with ``jnp.asarray``.
    """
    total = jnp.asarray(n) if device_array else n
    probs = jnp.array([0.5, 0.5])
    draws = multinomial(random.PRNGKey(0), probs, total)
    assert draws.shape == jnp.shape(total) + jnp.shape(probs)
    assert_allclose(draws.sum(-1), total)
Beispiel #5
0
 def sample(self, key, size=()):
     """Draw counts via ``multinomial`` from ``self.probs`` and ``self.total_count``.

     ``size`` is prepended to ``self.batch_shape`` to form the ``shape``
     argument passed to ``multinomial``.
     """
     out_shape = size + self.batch_shape
     return multinomial(key, p=self.probs, n=self.total_count, shape=out_shape)
def test_multinomial_shape(p, shape):
    """Sample shape is ``broadcast(p.shape[:-1], shape) + p.shape[-1:]``."""
    total_count = 10000
    batch = lax.broadcast_shapes(p.shape[:-1], shape)
    samples = multinomial(random.PRNGKey(0), p, total_count, shape)
    assert jnp.shape(samples) == batch + p.shape[-1:]
Beispiel #7
0
 def sample(self, key, sample_shape=()):
     """Draw counts via ``multinomial`` after checking ``key`` is a PRNG key.

     ``sample_shape`` is prepended to ``self.batch_shape`` to form the
     ``shape`` argument passed to ``multinomial``.
     """
     assert is_prng_key(key)
     full_shape = sample_shape + self.batch_shape
     return multinomial(key, p=self.probs, n=self.total_count, shape=full_shape)