示例#1
0
def test_dirichlet_samples():
    """Check that `dirichlet` samples peak at the heavily-weighted component.

    Row ``i`` of ``alphas`` gives component ``i`` a much larger parameter
    (100 vs. 1) than the others. So in a sample drawn from row ``i``, entry
    ``i`` is expected to be the largest. The test also checks the shapes
    produced by ``dirichlet.rng_fn`` for several ``size`` values.
    """

    alphas = np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]],
                      dtype=config.floatX)

    # `np.diag(res)` on its own would broadcast along the last axis and
    # compare each diagonal entry against its *column*. The `[:, None]`
    # compares each diagonal entry against the rest of its own *row*, which
    # is the sample drawn from that row of `alphas`.
    res = get_test_value(dirichlet(alphas))
    assert np.all(np.diag(res)[:, None] >= res)

    res = get_test_value(dirichlet(alphas, size=2))
    assert res.shape == (2, 3, 3)
    assert all(np.all(np.diag(r)[:, None] >= r) for r in res)

    for i in range(alphas.shape[0]):
        # A single parameter vector: component `i` should dominate.
        res = get_test_value(dirichlet(alphas[i]))
        assert np.all(res[i] > np.delete(res, [i]))

        res = get_test_value(dirichlet(alphas[i], size=2))
        assert res.shape == (2, 3)
        assert all(np.all(r[i] > np.delete(r, [i])) for r in res)

    rng_state = np.random.RandomState(
        np.random.MT19937(np.random.SeedSequence(1234)))

    alphas = np.array([[1000, 1, 1], [1, 1000, 1], [1, 1, 1000]],
                      dtype=config.floatX)

    # `size` is expected to be prepended to the batch shape of `alphas`.
    assert dirichlet.rng_fn(rng_state, alphas, None).shape == alphas.shape
    assert dirichlet.rng_fn(rng_state, alphas,
                            size=10).shape == (10, ) + alphas.shape
    assert (dirichlet.rng_fn(rng_state, alphas,
                             size=(10, 2)).shape == (10, 2) + alphas.shape)
示例#2
0
def test_dirichlet_infer_shape():
    """Compare `dirichlet._infer_shape` with the shape of an actual sample.

    The scalar ``M`` (test value 3) is used to build parameter tensors and
    ``size`` tuples. For every combination, the test value of the inferred
    shape must equal the shape of the sampled test value.
    """
    m = iscalar("M")
    m.tag.test_value = 3

    cases = (
        ([aet.ones((m, ))], None),
        ([aet.ones((m, ))], (m + 1, )),
        ([aet.ones((m, ))], (2, m)),
        ([aet.ones((m, m + 1))], None),
        ([aet.ones((m, m + 1))], (m + 2, )),
        ([aet.ones((m, m + 1))], (2, m + 2, m + 3)),
    )
    for dist_args, rv_size in cases:
        sample = dirichlet(*dist_args, size=rv_size)
        inferred = tuple(
            dirichlet._infer_shape(rv_size or (), dist_args, None))

        inferred_val = tuple(get_test_value(inferred))
        sampled_shape = tuple(get_test_value(sample).shape)
        assert inferred_val == sampled_shape
示例#3
0
def test_dirichlet_infer_shape(M, size):
    """Check `dirichlet._infer_shape` against the shape of a compiled draw.

    Parameters
    ----------
    M
        Parameter tensor passed to `dirichlet`.
    size
        The ``size`` argument passed to `dirichlet`: a sequence of (possibly
        symbolic) dimensions, or ``None``/empty for no extra dimensions.
    """
    rv = dirichlet(M, size=size)
    rv_shape = list(dirichlet._infer_shape(size or (), [M], None))

    # Normalize `size` the same way as for `_infer_shape` above, so that
    # `size=None` (or a list) doesn't make the tuple concatenation fail.
    all_args = (M, ) + tuple(size or ())
    # Only free graph inputs (not constants or shared variables) become
    # inputs of the compiled function.
    fn_inputs = [
        i
        for i in graph_inputs([a for a in all_args if isinstance(a, Variable)])
        if not isinstance(i, (Constant, SharedVariable))
    ]
    aesara_fn = function(fn_inputs, [at.as_tensor(o) for o in rv_shape + [rv]],
                         mode=py_mode)

    # `fn_inputs` already excludes constants and shared variables, so no
    # further filtering is needed here. This assumes every free input has a
    # `tag.test_value` set by the caller/parametrization.
    *rv_shape_val, rv_val = aesara_fn(*[i.tag.test_value for i in fn_inputs])

    assert tuple(rv_shape_val) == tuple(rv_val.shape)
示例#4
0
def test_dirichlet_ShapeFeature():
    """Check that `ShapeFeature` can symbolically infer a `dirichlet` shape.

    This exercises `RandomVariable.infer_shape` through a `FunctionGraph`
    equipped with a `ShapeFeature`. Each inferred dimension must depend on
    the symbolic scalar that defined it.
    """
    rows = iscalar("M")
    rows.tag.test_value = 2
    cols = iscalar("N")
    cols.tag.test_value = 3

    alphas = at.ones((rows, cols))
    d_rv = dirichlet(alphas, name="Gamma")

    fg = FunctionGraph(outputs=[d_rv], clone=False, features=[ShapeFeature()])

    dim0, dim1 = fg.shape_feature.shape_of[d_rv]

    # Each inferred dimension should trace back to its symbolic source.
    assert rows in graph_inputs([dim0])
    assert cols in graph_inputs([dim1])
示例#5
0
def test_dirichlet_ShapeFeature():
    """Check that `ShapeFeature` can symbolically infer a `dirichlet` shape.

    The `FunctionGraph` is built with its non-constant inputs listed
    explicitly. The test then checks that each inferred dimension of the
    random variable depends on the scalar that defined it.
    """
    m_var = iscalar("M")
    m_var.tag.test_value = 2
    n_var = iscalar("N")
    n_var.tag.test_value = 3

    d_rv = dirichlet(aet.ones((m_var, n_var)), name="Gamma")

    # Constants are not valid explicit FunctionGraph inputs, so drop them.
    free_inputs = [
        node for node in graph_inputs([d_rv])
        if not isinstance(node, Constant)
    ]
    fg = FunctionGraph(
        free_inputs,
        [d_rv],
        clone=False,
        features=[ShapeFeature()],
    )

    first_dim, second_dim = fg.shape_feature.shape_of[d_rv]

    assert m_var in graph_inputs([first_dim])
    assert n_var in graph_inputs([second_dim])