def test_sample_frequency_vectorized(rng_key):
    """Check per-component sample frequencies for an array-valued parameter.

    Draws 1,000,000 samples from a Bernoulli parameterised by the array
    ``[0, 1, 0.5]`` and asserts that the mean over the sample axis is within
    an absolute tolerance of 1e-3 of each corresponding probability.
    """
    probs = jnp.array([0, 1, 0.5])
    draws = Bernoulli(probs).sample(rng_key, (1_000_000, ))
    # Averaging over axis 0 (the sample axis) leaves one frequency per entry
    # of ``probs``.
    observed = jnp.mean(draws, axis=0)
    # Convert both sides to numpy so pytest.approx compares them elementwise.
    expected = np.asarray(probs)
    observed_np = np.asarray(observed)
    assert observed_np == pytest.approx(expected, abs=1e-3)
def test_sample_shape_scalar_arguments(rng_key, case):
    """Test the correctness of broadcasting when both arguments are scalars.

    Scalar arguments are tested separately from array arguments because
    scalars are edge cases when it comes to broadcasting. The trailing `1`
    in the result shapes stands for the batch size.
    """
    drawn = Bernoulli(0.5).sample(rng_key, case["sample_shape"])
    wanted_shape = case["expected_shape"]
    assert drawn.shape == wanted_shape
def test_sample_frequency(rng_key, p):
    """Check that the empirical frequency of Bernoulli(p) draws approximates p.

    Draws 1,000,000 samples and asserts that their mean, converted to a Python
    scalar, is within an absolute tolerance of 1e-3 of ``p``.
    """
    num_draws = 1_000_000
    draws = Bernoulli(p).sample(rng_key, (num_draws, ))
    # ``.item()`` turns the 0-d JAX array into a plain Python number for approx.
    frequency = jnp.mean(draws, axis=0).item()
    assert frequency == pytest.approx(p, abs=1e-3)