Example #1
0
def test_sample_frequency_vectorized(rng_key):
    """Empirical means of a vector-parameterized Bernoulli match its probabilities.

    One million draws are taken from a Bernoulli distribution parameterized
    by a vector of probabilities, including the degenerate values 0 and 1.
    The sample mean of each component is then compared against the matching
    probability with an absolute tolerance of 1e-3.
    """
    target = jnp.array([0, 1, 0.5])
    draws = Bernoulli(target).sample(rng_key, (1_000_000, ))

    # Reduce over the leading sample axis (axis 0), leaving one mean per
    # probability; both sides are converted to NumPy for pytest.approx.
    expected = np.asarray(target)
    empirical = np.asarray(jnp.mean(draws, axis=0))

    # Only an absolute tolerance is given, so no relative tolerance applies.
    assert empirical == pytest.approx(expected, abs=1e-3)
Example #2
0
def test_sample_shape_scalar_arguments(rng_key, case):
    """Check broadcasting of sample shapes when both arguments are scalars.

    Scalar arguments are tested separately from array arguments because
    scalars are edge cases when it comes to broadcasting. Each ``case``
    provides the requested ``sample_shape`` and the ``expected_shape`` of
    the drawn samples.

    The trailing `1` in the result shapes stands for the batch size.
    """
    requested_shape = case["sample_shape"]
    draws = Bernoulli(0.5).sample(rng_key, requested_shape)
    assert draws.shape == case["expected_shape"]
Example #3
0
def test_sample_frequency(rng_key, p):
    """The empirical frequency of a scalar-parameterized Bernoulli(p) approximates p.

    Draws one million samples, reduces them to a single Python scalar via
    their mean, and compares it to ``p`` with an absolute tolerance of 1e-3.
    """
    n_draws = 1_000_000
    draws = Bernoulli(p).sample(rng_key, (n_draws, ))
    # `.item()` requires the mean to hold exactly one element.
    frequency = jnp.mean(draws, axis=0).item()
    assert frequency == pytest.approx(p, abs=1e-3)