def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of values meant to violate ``constraint``.

    The concrete type of ``constraint`` selects one of the recipes below;
    each one takes an in-range draw (or the bound itself) and pushes it
    past the bound.

    Args:
        constraint: Constraint instance whose type picks the recipe. Its
            ``lower_bound`` / ``upper_bound`` attributes are read where the
            recipe needs them.
        size: Tuple describing the shape to generate.
        key: PRNG key handed to the random samplers. The default key is
            created once, when the function is defined.

    Returns:
        An array derived from ``size`` holding out-of-bounds values.

    Raises:
        NotImplementedError: If there is no recipe for the constraint type.
    """
    if isinstance(constraint, constraints._Boolean):
        # Bernoulli draws shifted down by 2.
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        # exp(.) is positive, so the result lies below lower_bound.
        return constraint.lower_bound - np.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        # The sampling range is [lower_bound - 1, lower_bound).
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        # A Poisson draw can be 0, which on its own would return lower_bound
        # itself -- a value gen_values_within_bounds also produces for this
        # constraint. The extra -1 keeps every entry strictly below it.
        return constraint.lower_bound - poisson(key, 5, shape=size) - 1
    elif isinstance(constraint, constraints._Interval):
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        # Draw from the unit-width band starting at upper_bound.
        return random.uniform(key, size, minval=upper_bound,
                              maxval=upper_bound + 1.)
    elif isinstance(constraint, constraints._Real):
        # NaN-filled array.
        return lax.full(size, np.nan)
    elif isinstance(constraint, constraints._Simplex):
        # Dirichlet draws with 1e-2 added to every entry. NOTE: this branch
        # does not use `key`; `osp` is presumably scipy.stats -- confirm.
        return osp.dirichlet.rvs(alpha=np.ones((size[-1],)),
                                 size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        # Uniform-probability multinomial draws with 1 added to every entry.
        return multinomial(key, p=np.ones((n,)) / n,
                           n=constraint.upper_bound, shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        # Output of signed_stick_breaking_tril with 1e-2 added to every entry.
        return signed_stick_breaking_tril(
            random.uniform(key,
                           size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1)) + 1e-2
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
def test_poisson():
    """Check that the mean of many ``poisson`` draws is close to the rate."""
    rate = 1000
    # The original test uses the rate itself as the expected mean.
    expected_mean = rate
    num_samples = 2 ** 18
    draws = poisson(random.PRNGKey(64), rate=rate, shape=(num_samples,))
    assert_allclose(draws.mean(), expected_mean, rtol=0.001)
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of values meant to satisfy ``constraint``.

    The concrete type of ``constraint`` selects the sampling recipe.

    Args:
        constraint: Constraint instance whose type picks the recipe. Its
            ``lower_bound`` / ``upper_bound`` attributes are read where the
            recipe needs them.
        size: Tuple describing the shape to generate.
        key: PRNG key handed to the random samplers. The default key is
            created once, when the function is defined.

    Returns:
        An array derived from ``size`` holding in-bounds values.

    Raises:
        NotImplementedError: If there is no recipe for the constraint type.
    """
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)

    if isinstance(constraint, constraints._GreaterThan):
        # Positive offset plus a small margin above lower_bound.
        margin = 1e-6
        positive = np.exp(random.normal(key, size))
        return positive + constraint.lower_bound + margin

    if isinstance(constraint, constraints._IntegerInterval):
        lo = np.broadcast_to(constraint.lower_bound, size)
        hi = np.broadcast_to(constraint.upper_bound, size)
        # hi + 1 so that the upper bound itself can be drawn.
        return random.randint(key, size, lo, hi + 1)

    if isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + poisson(key, 5, shape=size)

    if isinstance(constraint, constraints._Interval):
        lo = np.broadcast_to(constraint.lower_bound, size)
        hi = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lo, maxval=hi)

    if isinstance(constraint, constraints._Real):
        return random.normal(key, size)

    if isinstance(constraint, constraints._Simplex):
        # NOTE: this branch does not use `key`; `osp` is presumably
        # scipy.stats -- confirm.
        concentration = np.ones((size[-1],))
        return osp.dirichlet.rvs(alpha=concentration, size=size[:-1])

    if isinstance(constraint, constraints._Multinomial):
        num_categories = size[-1]
        uniform_probs = np.ones((num_categories,)) / num_categories
        return multinomial(key, p=uniform_probs, n=constraint.upper_bound,
                           shape=size[:-1])

    if isinstance(constraint, constraints._CorrCholesky):
        dim = size[-1]
        # One uniform(-1, 1) draw per strictly-lower-triangular entry.
        num_offdiag = dim * (dim - 1) // 2
        raw = random.uniform(key, size[:-2] + (num_offdiag,),
                             minval=-1, maxval=1)
        return signed_stick_breaking_tril(raw)

    if isinstance(constraint, constraints._LowerCholesky):
        return np.tril(random.uniform(key, size))

    if isinstance(constraint, constraints._PositiveDefinite):
        # Product of a random matrix with its own transpose.
        factor = random.normal(key, size)
        return np.matmul(factor, np.swapaxes(factor, -2, -1))

    raise NotImplementedError('{} not implemented.'.format(constraint))
def sample(self, key, size=()):
    """Draw samples by calling ``poisson`` with ``self.rate``.

    Args:
        key: PRNG key forwarded to ``poisson``.
        size: Leading shape tuple; ``self.batch_shape`` is appended to it.

    Returns:
        Whatever ``poisson`` returns for the combined shape.
    """
    full_shape = size + self.batch_shape
    return poisson(key, self.rate, shape=full_shape)
def sample(self, key, sample_shape=()):
    """Draw ``poisson`` samples and zero out entries picked by a Bernoulli mask.

    Args:
        key: PRNG key; it is split into one key for the mask and one for
            the ``poisson`` draws.
        sample_shape: Leading shape tuple; ``self.batch_shape`` is appended.

    Returns:
        An array equal to 0 where the Bernoulli(``self.gate``) mask is set
        and equal to the ``poisson`` draw elsewhere.
    """
    shape = sample_shape + self.batch_shape
    bern_key, pois_key = random.split(key)
    zero_mask = random.bernoulli(bern_key, self.gate, shape)
    counts = poisson(pois_key, self.rate, shape)
    return np.where(zero_mask, 0, counts)