def test_binomial_mean(n, p):
    """Check that the empirical mean of `binomial` draws is close to ``n * p``.

    A 100 x 100 array of draws is generated from a fixed key (seed 1), and
    its mean is compared with ``n * p`` using a 5% relative tolerance.

    Args:
        n: Value forwarded to `binomial` as its third positional argument
            (presumably the number of trials -- confirm against `binomial`).
        p: Value forwarded to `binomial` as its second positional argument
            (presumably the success probability -- confirm against `binomial`).
    """
    rng_key = random.PRNGKey(1)
    draws = binomial(rng_key, p, n, shape=(100, 100))
    # Only a relative tolerance is given, so the check scales with n * p.
    assert_allclose(jnp.mean(draws), n * p, rtol=0.05)
def sample(self, key, size=()):
    """Draw samples by delegating to `binomial`.

    Args:
        key: Forwarded unchanged to `binomial` as its first argument
            (presumably a PRNG key -- confirm against `binomial`).
        size: Leading dimensions to prepend to ``self.batch_shape``.
            Defaults to the empty tuple.

    Returns:
        Whatever `binomial` returns for ``self.probs`` and
        ``n=self.total_count`` with ``shape=size + self.batch_shape``.
    """
    probs = self.probs
    total_count = self.total_count
    # Requested leading dimensions come first, then the batch dimensions.
    output_shape = size + self.batch_shape
    return binomial(key, probs, n=total_count, shape=output_shape)
def sample(self, key, sample_shape=()):
    """Draw samples by delegating to `binomial` after checking the key.

    Args:
        key: Value checked with `is_prng_key` and then forwarded to
            `binomial` as its first argument.
        sample_shape: Leading dimensions to prepend to ``self.batch_shape``.
            Defaults to the empty tuple.

    Returns:
        Whatever `binomial` returns for ``self.probs`` and
        ``n=self.total_count`` with
        ``shape=sample_shape + self.batch_shape``.

    Raises:
        AssertionError: If ``is_prng_key(key)`` is falsy (only when
            assertions are enabled).
    """
    assert is_prng_key(key)
    probs = self.probs
    total_count = self.total_count
    # Requested leading dimensions come first, then the batch dimensions.
    output_shape = sample_shape + self.batch_shape
    return binomial(key, probs, n=total_count, shape=output_shape)