def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Build sample values intended to fall outside ``constraint``.

    The constraint's class selects how the values are produced; each
    strategy shifts or replaces an otherwise ordinary draw (for example
    subtracting a positive amount from the lower bound, or sampling at or
    above the upper bound).

    Args:
        constraint: instance of one of the ``constraints._*`` classes
            tested below.
        size: shape tuple handed to the underlying samplers. The
            simplex, multinomial and Cholesky branches index
            ``size[-1]`` (and slice ``size[:-1]`` / ``size[:-2]``).
        key: PRNG key used by the ``random``-based branches. The default
            key is created once, when the function is defined. The
            simplex branch uses ``osp.dirichlet.rvs`` and does not
            consume ``key``.

    Returns:
        The generated values for the matching constraint type.

    Raises:
        NotImplementedError: if ``constraint`` matches none of the
            handled classes.
    """
    # NOTE: the checks run in this exact order; it is kept as-is in case
    # some constraint classes are related by inheritance.
    if isinstance(constraint, constraints._Boolean):
        # Bernoulli draws shifted down by 2.
        return random.bernoulli(key, shape=size) - 2

    if isinstance(constraint, constraints._GreaterThan):
        positive_offset = np.exp(random.normal(key, size))
        return constraint.lower_bound - positive_offset

    if isinstance(constraint, constraints._IntegerInterval):
        lo = np.broadcast_to(constraint.lower_bound, size)
        # Integer draw from the range [lo - 1, lo) handed to randint.
        return random.randint(key, size, lo - 1, lo)

    if isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - poisson(key, 5, shape=size)

    if isinstance(constraint, constraints._Interval):
        hi = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=hi, maxval=hi + 1.)

    if isinstance(constraint, constraints._Real):
        # Every entry is NaN.
        return lax.full(size, np.nan)

    if isinstance(constraint, constraints._Simplex):
        concentration = np.ones((size[-1], ))
        on_simplex = osp.dirichlet.rvs(alpha=concentration, size=size[:-1])
        # Add a small constant to every component of the Dirichlet draw.
        return on_simplex + 1e-2

    if isinstance(constraint, constraints._Multinomial):
        num_categories = size[-1]
        uniform_probs = np.ones((num_categories, )) / num_categories
        counts = multinomial(key,
                             p=uniform_probs,
                             n=constraint.upper_bound,
                             shape=size[:-1])
        # Add one to every count.
        return counts + 1

    if isinstance(constraint, constraints._CorrCholesky):
        dim = size[-1]
        # One uniform draw in [-1, 1) per strictly-lower-triangular entry.
        raw_shape = size[:-2] + (dim * (dim - 1) // 2, )
        raw = random.uniform(key, raw_shape, minval=-1, maxval=1)
        return signed_stick_breaking_tril(raw) + 1e-2

    raise NotImplementedError('{} not implemented.'.format(constraint))
Esempio n. 2
0
def test_poisson():
    """Check that the mean of many ``poisson`` draws matches the rate.

    Draws ``2**18`` samples at rate 1000 with a fixed PRNG key and
    requires the sample mean to agree with the rate to a relative
    tolerance of 1e-3.
    """
    expected_mean = 1000
    num_draws = 1 << 18  # 2**18

    draws = poisson(random.PRNGKey(64), rate=expected_mean, shape=(num_draws, ))
    assert_allclose(draws.mean(), expected_mean, rtol=0.001)
Esempio n. 3
0
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    """Build sample values intended to satisfy ``constraint``.

    The constraint's class selects which sampler is used.

    Args:
        constraint: instance of one of the ``constraints._*`` classes
            tested below.
        size: shape tuple handed to the underlying samplers. The
            simplex, multinomial and correlation-Cholesky branches index
            ``size[-1]`` (and slice ``size[:-1]`` / ``size[:-2]``).
        key: PRNG key used by the ``random``-based branches. The default
            key is created once, when the function is defined. The
            simplex branch uses ``osp.dirichlet.rvs`` and does not
            consume ``key``.

    Returns:
        The generated values for the matching constraint type.

    Raises:
        NotImplementedError: if ``constraint`` matches none of the
            handled classes.
    """
    # Small margin added on top of the lower bound in the _GreaterThan case.
    margin = 1e-6

    # NOTE: the checks run in this exact order; it is kept as-is in case
    # some constraint classes are related by inheritance.
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)

    if isinstance(constraint, constraints._GreaterThan):
        positive_offset = np.exp(random.normal(key, size))
        return positive_offset + constraint.lower_bound + margin

    if isinstance(constraint, constraints._IntegerInterval):
        lo = np.broadcast_to(constraint.lower_bound, size)
        hi = np.broadcast_to(constraint.upper_bound, size)
        # hi + 1 is passed as randint's upper limit.
        return random.randint(key, size, lo, hi + 1)

    if isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + poisson(key, 5, shape=size)

    if isinstance(constraint, constraints._Interval):
        lo = np.broadcast_to(constraint.lower_bound, size)
        hi = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lo, maxval=hi)

    if isinstance(constraint, constraints._Real):
        return random.normal(key, size)

    if isinstance(constraint, constraints._Simplex):
        concentration = np.ones((size[-1],))
        return osp.dirichlet.rvs(alpha=concentration, size=size[:-1])

    if isinstance(constraint, constraints._Multinomial):
        num_categories = size[-1]
        uniform_probs = np.ones((num_categories,)) / num_categories
        return multinomial(key, p=uniform_probs, n=constraint.upper_bound,
                           shape=size[:-1])

    if isinstance(constraint, constraints._CorrCholesky):
        dim = size[-1]
        # One uniform draw in [-1, 1) per strictly-lower-triangular entry.
        raw_shape = size[:-2] + (dim * (dim - 1) // 2,)
        raw = random.uniform(key, raw_shape, minval=-1, maxval=1)
        return signed_stick_breaking_tril(raw)

    if isinstance(constraint, constraints._LowerCholesky):
        # Lower-triangular part of a uniform draw.
        return np.tril(random.uniform(key, size))

    if isinstance(constraint, constraints._PositiveDefinite):
        factor = random.normal(key, size)
        # Product of the draw with its transpose over the last two axes.
        return np.matmul(factor, np.swapaxes(factor, -2, -1))

    raise NotImplementedError('{} not implemented.'.format(constraint))
Esempio n. 4
0
 def sample(self, key, size=()):
     """Draw samples by delegating to ``poisson`` with ``self.rate``.

     Args:
         key: PRNG key forwarded to ``poisson``.
         size: leading sample shape; it is prepended to
             ``self.batch_shape`` to form the requested shape.

     Returns:
         Whatever ``poisson`` returns for the combined shape.
     """
     full_shape = size + self.batch_shape
     return poisson(key, self.rate, shape=full_shape)
Esempio n. 5
0
 def sample(self, key, sample_shape=()):
     """Draw ``poisson`` samples and zero out entries chosen by a Bernoulli mask.

     The key is split in two: one half drives a Bernoulli draw with
     probability ``self.gate``, the other a ``poisson`` draw with
     ``self.rate``. Wherever the Bernoulli draw is true the result is 0;
     elsewhere it is the ``poisson`` draw.

     Args:
         key: PRNG key; split into two independent sub-keys.
         sample_shape: leading sample shape; it is prepended to
             ``self.batch_shape``.

     Returns:
         Array built by ``np.where`` from the mask and the draws.
     """
     gate_key, rate_key = random.split(key)
     full_shape = sample_shape + self.batch_shape
     is_zeroed = random.bernoulli(gate_key, self.gate, full_shape)
     draws = poisson(rate_key, self.rate, full_shape)
     return np.where(is_zeroed, 0, draws)