Пример #1
0
def test_dirichlet_samples():
    """Check value ordering and output shapes of `dirichlet` draws.

    Every concentration vector used here has one entry far larger than the
    others, and the assertions require the matching component of each draw
    to dominate.  The second half checks the shapes returned by
    `dirichlet.rng_fn` for several `size` arguments.
    """
    concentrations = np.array(
        [[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX
    )

    # Batched parameters.  `np.diag(draw) >= draw` broadcasts the diagonal
    # across rows, so diagonal entry j is compared with all of column j.
    draw = get_test_value(dirichlet(concentrations))
    assert np.all(np.diag(draw) >= draw)

    stacked = get_test_value(dirichlet(concentrations, size=2))
    assert stacked.shape == (2, 3, 3)
    for draw in stacked:
        assert np.all(np.diag(draw) >= draw)

    # One concentration vector at a time: component `idx` must be strictly
    # larger than every other component of the draw.
    for idx, row in enumerate(concentrations):
        draw = get_test_value(dirichlet(row))
        assert np.all(draw[idx] > np.delete(draw, [idx]))

        pair = get_test_value(dirichlet(row, size=2))
        assert pair.shape == (2, 3)
        for draw in pair:
            assert np.all(draw[idx] > np.delete(draw, [idx]))

    # Shape checks against `rng_fn` directly, using an explicitly seeded
    # legacy `RandomState` backed by an MT19937 bit generator.
    seed_seq = np.random.SeedSequence(1234)
    rng_state = np.random.RandomState(np.random.MT19937(seed_seq))

    concentrations = np.array(
        [[1000, 1, 1], [1, 1000, 1], [1, 1, 1000]], dtype=config.floatX
    )
    base_shape = concentrations.shape

    unsized = dirichlet.rng_fn(rng_state, concentrations, None)
    assert unsized.shape == base_shape

    ten = dirichlet.rng_fn(rng_state, concentrations, size=10)
    assert ten.shape == (10,) + base_shape

    grid = dirichlet.rng_fn(rng_state, concentrations, size=(10, 2))
    assert grid.shape == (10, 2) + base_shape
Пример #2
0
def test_dirichlet_rng():
    """Check that `dirichlet.rng_fn` rejects a `size` that omits a dimension.

    The concentration array is 3x3, and the requested `size` of `(10, 2)`
    leaves out the shape of the independent dimension; the call is expected
    to raise a `ValueError` whose message matches "shape mismatch".
    """
    concentrations = np.array(
        [[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX
    )

    # A valid `size` for these parameters would be `(10, 2, 3)`; the
    # trailing 3 (the independent dimension's shape) is left out on purpose.
    incomplete_size = (10, 2)

    with pytest.raises(ValueError, match="shape mismatch.*"):
        dirichlet.rng_fn(None, concentrations, size=incomplete_size)
Пример #3
0
 def dirichlet_test_fn(mean=None, cov=None, size=None, random_state=None):
     """Draw from `dirichlet.rng_fn` using `alphas` from the enclosing scope.

     `mean` and `cov` are accepted but never read here -- presumably they
     exist only to satisfy a signature the caller expects (TODO confirm
     against the enclosing test).  A `size` of `None` is forwarded as the
     empty tuple `()`.
     """
     draw_shape = () if size is None else size
     return dirichlet.rng_fn(random_state, alphas, draw_shape)