예제 #1
0
def test_binomial_mean(n, p):
    """Check that the empirical mean of binomial draws is close to ``n * p``.

    Requests samples of shape ``(100, 100)`` from ``binomial`` using a fixed
    PRNG key (seed 1), then compares their overall mean against ``n * p``
    with a relative tolerance of 5%.
    """
    rng_key = random.PRNGKey(1)
    draws = binomial(rng_key, p, n, shape=(100, 100))
    empirical_mean = jnp.mean(draws)
    # Reference value the sample mean is compared against.
    assert_allclose(empirical_mean, n * p, rtol=0.05)
예제 #2
0
 def sample(self, key, size=()):
     """Draw samples by delegating to ``binomial``.

     :param key: random key forwarded unchanged to ``binomial``.
     :param size: leading shape that is prepended to ``self.batch_shape``
         (defaults to the empty tuple).
     :return: whatever ``binomial`` returns for the requested shape.
     """
     probs = self.probs
     total_count = self.total_count
     # Requested leading dimensions come first, then the batch dimensions.
     full_shape = size + self.batch_shape
     return binomial(key, probs, n=total_count, shape=full_shape)
예제 #3
0
 def sample(self, key, sample_shape=()):
     """Draw samples by delegating to ``binomial`` after checking the key.

     :param key: random key; must satisfy ``is_prng_key`` and is forwarded
         unchanged to ``binomial``.
     :param sample_shape: leading shape that is prepended to
         ``self.batch_shape`` (defaults to the empty tuple).
     :return: whatever ``binomial`` returns for the requested shape.
     """
     assert is_prng_key(key)
     probs, total_count = self.probs, self.total_count
     # Requested sample dimensions come first, then the batch dimensions.
     out_shape = sample_shape + self.batch_shape
     return binomial(key, probs, n=total_count, shape=out_shape)