Example 1
def test_categorical_samples():

    rng_state = np.random.RandomState(
        np.random.MT19937(np.random.SeedSequence(1234)))

    p = np.array([[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
                 dtype=config.floatX)
    p = p / p.sum(axis=-1)

    assert categorical.rng_fn(rng_state, p, size=None).shape == p.shape[:-1]

    with raises(ValueError):
        categorical.rng_fn(rng_state, p, size=10)

    assert categorical.rng_fn(rng_state, p, size=(10, 3)).shape == (10, 3)
    assert categorical.rng_fn(rng_state, p,
                              size=(10, 2, 3)).shape == (10, 2, 3)

    res = categorical(p)
    assert np.array_equal(get_test_value(res), np.arange(3))

    res = categorical(p, size=(10, 3))
    exp_res = np.tile(np.arange(3), (10, 1))
    assert np.array_equal(get_test_value(res), exp_res)

    res = categorical(p, size=(10, 2, 3))
    exp_res = np.tile(np.arange(3), (10, 2, 1))
    assert np.array_equal(get_test_value(res), exp_res)
Example 2
def test_dirichlet_samples():

    alphas = np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]],
                      dtype=config.floatX)

    res = get_test_value(dirichlet(alphas))
    assert np.all(np.diag(res) >= res)

    res = get_test_value(dirichlet(alphas, size=2))
    assert res.shape == (2, 3, 3)
    assert all(np.all(np.diag(r) >= r) for r in res)

    for i in range(alphas.shape[0]):
        res = get_test_value(dirichlet(alphas[i]))
        assert np.all(res[i] > np.delete(res, [i]))

        res = get_test_value(dirichlet(alphas[i], size=2))
        assert res.shape == (2, 3)
        assert all(np.all(r[i] > np.delete(r, [i])) for r in res)

    rng_state = np.random.RandomState(
        np.random.MT19937(np.random.SeedSequence(1234)))

    alphas = np.array([[1000, 1, 1], [1, 1000, 1], [1, 1, 1000]],
                      dtype=config.floatX)

    assert dirichlet.rng_fn(rng_state, alphas, None).shape == alphas.shape
    assert dirichlet.rng_fn(rng_state, alphas,
                            size=10).shape == (10, ) + alphas.shape
    assert (dirichlet.rng_fn(rng_state, alphas,
                             size=(10, 2)).shape == (10, 2) + alphas.shape)
Example 3
def test_jax_BatchedDot():
    # tensor3 . tensor3
    a = tensor3("a")
    a.tag.test_value = (
        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
    )
    b = tensor3("b")
    b.tag.test_value = (
        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
    )
    out = aet_blas.BatchedDot()(a, b)
    fgraph = FunctionGraph([a, b], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # A dimension mismatch should raise a TypeError for compatibility
    inputs = [get_test_value(a)[:-1], get_test_value(b)]
    opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
    jax_mode = Mode(JAXLinker(), opts)
    aesara_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
    with pytest.raises(TypeError):
        aesara_jax_fn(*inputs)

    # matrix . matrix
    a = matrix("a")
    a.tag.test_value = (
        np.linspace(-1, 1, 5 * 3).astype(config.floatX).reshape((5, 3))
    )
    b = matrix("b")
    b.tag.test_value = (
        np.linspace(1, -1, 5 * 3).astype(config.floatX).reshape((5, 3))
    )
    out = aet_blas.BatchedDot()(a, b)
    fgraph = FunctionGraph([a, b], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Example 4
def test_polyagamma_samples():

    _ = importorskip("pypolyagamma")

    # Sampled values should be scalars
    a = np.array(1.1, dtype=config.floatX)
    b = np.array(-10.5, dtype=config.floatX)
    pg_rv = polyagamma(a, b)
    assert get_test_value(pg_rv).shape == ()

    pg_rv = polyagamma(a, b, size=[1])
    assert get_test_value(pg_rv).shape == (1, )

    pg_rv = polyagamma(a, b, size=[2, 3])
    bcast_smpl = get_test_value(pg_rv)
    assert bcast_smpl.shape == (2, 3)
    # Make sure they're not all equal
    assert np.all(np.abs(np.diff(bcast_smpl.flat)) > 0.0)

    a = np.array([1.1, 3], dtype=config.floatX)
    b = np.array(-10.5, dtype=config.floatX)
    pg_rv = polyagamma(a, b)
    bcast_smpl = get_test_value(pg_rv)
    assert bcast_smpl.shape == (2, )
    assert np.all(np.abs(np.diff(bcast_smpl.flat)) > 0.0)

    pg_rv = polyagamma(a, b, size=(3, 2))
    bcast_smpl = get_test_value(pg_rv)
    assert bcast_smpl.shape == (3, 2)
    assert np.all(np.abs(np.diff(bcast_smpl.flat)) > 0.0)
Example 5
def test_tensor_basics():
    y = vector("y")
    y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
    x = vector("x")
    x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
    A = matrix("A")
    A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
    alpha = scalar("alpha")
    alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
    beta = scalar("beta")
    beta.tag.test_value = np.array(5.0, dtype=config.floatX)

    # This should be converted into a `Gemv` `Op` when the non-JAX compatible
    # optimizations are turned on; however, when using JAX mode, it should
    # leave the expression alone (see the sketch after this example).
    out = y.dot(alpha * A).dot(x) + beta * y
    fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = maximum(y, x)
    fgraph = FunctionGraph([y, x], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_max(y)
    fgraph = FunctionGraph([y], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
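For reference, here is a small, purely illustrative sketch (not part of the original test) of one way to check the `Gemv` claim in the comment above: compile the same expression with the default, non-JAX mode and inspect the rewritten graph. It assumes a standard Aesara installation; whether a `Gemv` node actually appears depends on the active rewrites and BLAS configuration.

import aesara
import aesara.tensor as at

y = at.vector("y")
x = at.vector("x")
A = at.matrix("A")
alpha = at.scalar("alpha")
beta = at.scalar("beta")
out = y.dot(alpha * A).dot(x) + beta * y

# Compile with the default (non-JAX) mode and list the Ops left after
# rewriting; with the BLAS rewrites enabled, a Gemv node is expected here.
fn = aesara.function([y, x, A, alpha, beta], out)
print([node.op for node in fn.maker.fgraph.toposort()])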
Example 6
def test_extra_ops():
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = aet_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # This function also cannot take symbolic input.
    c = aet.as_tensor(5)
    out = aet_extra_ops.bartlett(c)
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    indices = np.arange(np.product((3, 4)))
    out = aet_extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs],
                       must_be_device_array=False)

    multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4))
    out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs],
                       must_be_device_array=False)

    # The inputs are "concrete", yet it still has problems?
    out = aet_extra_ops.Unique()(aet.as_tensor(
        np.arange(6, dtype=config.floatX).reshape((3, 2))))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [])
Example 7
def compare_sample_values(rv, *params, rng=None, test_fn=None, **kwargs):
    """Test for equivalence between `RandomVariable` and NumPy/other samples.

    An equivalently named method on a NumPy RNG object will be used, unless
    `test_fn` is specified.

    """
    if rng is None:
        rng = np.random.default_rng()

    if test_fn is None:
        name = getattr(rv, "name", None)

        if name is None:
            name = rv.__name__

        def test_fn(*args, random_state=None, **kwargs):
            return getattr(random_state, name)(*args, **kwargs)

    param_vals = [
        get_test_value(p) if isinstance(p, Variable) else p for p in params
    ]
    kwargs_vals = {
        k: get_test_value(v) if isinstance(v, Variable) else v
        for k, v in kwargs.items()
    }

    at_rng = shared(rng, borrow=True)

    numpy_res = np.asarray(
        test_fn(*param_vals, random_state=copy(rng), **kwargs_vals))

    aesara_res = rv(*params, rng=at_rng, **kwargs)

    assert aesara_res.type.numpy_dtype.kind == numpy_res.dtype.kind

    numpy_shape = np.shape(numpy_res)
    numpy_bcast = [s == 1 for s in numpy_shape]
    np.testing.assert_array_equal(aesara_res.type.broadcastable, numpy_bcast)

    fn_inputs = [
        i for i in graph_inputs([aesara_res])
        if not isinstance(i, (Constant, SharedVariable))
    ]
    aesara_fn = function(fn_inputs, aesara_res, mode=py_mode)

    aesara_res_val = aesara_fn()

    assert aesara_res_val.flags.writeable

    np.testing.assert_array_equal(aesara_res_val.shape, numpy_res.shape)

    np.testing.assert_allclose(aesara_res_val, numpy_res)
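For context, here is a minimal usage sketch of the helper above (not part of the original excerpt). It assumes `compare_sample_values` and the module-level names it relies on (e.g. `py_mode`) are importable from the test module, and that `normal` comes from `aesara.tensor.random.basic`:

import numpy as np

from aesara.tensor.random.basic import normal

# Draw samples through the Aesara `normal` RandomVariable and through the
# equivalently named method on a copy of the same NumPy Generator, then
# compare their dtype kinds, shapes, and values.
compare_sample_values(normal, 0.0, 1.0, size=(3, 2), rng=np.random.default_rng(123))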
Example 8
def test_nnet():
    x = vector("x")
    x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)

    out = sigmoid(x)
    fgraph = FunctionGraph([x], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_nnet.ultra_fast_sigmoid(x)
    fgraph = FunctionGraph([x], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = softplus(x)
    fgraph = FunctionGraph([x], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Example 9
def test_normal_infer_shape():
    M_aet = iscalar("M")
    M_aet.tag.test_value = 3
    sd_aet = scalar("sd")
    sd_aet.tag.test_value = np.array(1.0, dtype=config.floatX)

    test_params = [
        ([aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)),
          sd_aet], None),
        (
            [
                aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)),
                sd_aet
            ],
            (M_aet, ),
        ),
        (
            [
                aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)),
                sd_aet
            ],
            (2, M_aet),
        ),
        ([aet.zeros((M_aet, )), sd_aet], None),
        ([aet.zeros((M_aet, )), sd_aet], (M_aet, )),
        ([aet.zeros((M_aet, )), sd_aet], (2, M_aet)),
        ([aet.zeros((M_aet, )), aet.ones((M_aet, ))], None),
        ([aet.zeros((M_aet, )), aet.ones((M_aet, ))], (2, M_aet)),
        (
            [
                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
                np.array([[1e-6, 2e-6]], dtype=config.floatX),
            ],
            (3, 2, 2),
        ),
        (
            [
                np.array([1], dtype=config.floatX),
                np.array([10], dtype=config.floatX)
            ],
            (1, 2),
        ),
    ]
    for args, size in test_params:
        rv = normal(*args, size=size)
        rv_shape = tuple(normal._infer_shape(size or (), args, None))
        assert tuple(get_test_value(rv_shape)) == tuple(
            get_test_value(rv).shape)
Example 10
def rv_numpy_tester(rv, *params, rng=None, test_fn=None, **kwargs):
    """Test for correspondence between `RandomVariable` and NumPy shape and
    broadcast dimensions.
    """
    if rng is None:
        rng = np.random.default_rng()

    if test_fn is None:
        name = getattr(rv, "name", None)

        if name is None:
            name = rv.__name__

        def test_fn(*args, random_state=None, **kwargs):
            return getattr(random_state, name)(*args, **kwargs)

    param_vals = [
        get_test_value(p) if isinstance(p, Variable) else p for p in params
    ]
    kwargs_vals = {
        k: get_test_value(v) if isinstance(v, Variable) else v
        for k, v in kwargs.items()
    }

    at_rng = shared(rng, borrow=True)

    numpy_res = np.asarray(
        test_fn(*param_vals, random_state=copy(rng), **kwargs_vals))

    aesara_res = rv(*params, rng=at_rng, **kwargs)

    assert aesara_res.type.numpy_dtype.kind == numpy_res.dtype.kind

    numpy_shape = np.shape(numpy_res)
    numpy_bcast = [s == 1 for s in numpy_shape]
    np.testing.assert_array_equal(aesara_res.type.broadcastable, numpy_bcast)

    fn_inputs = [
        i for i in graph_inputs([aesara_res])
        if not isinstance(i, (Constant, SharedVariable))
    ]
    aesara_fn = function(fn_inputs, aesara_res, mode=py_mode)

    aesara_res_val = aesara_fn()

    np.testing.assert_array_equal(aesara_res_val.shape, numpy_res.shape)

    np.testing.assert_allclose(aesara_res_val, numpy_res)
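A hypothetical call that exercises the `Variable` branch of the parameter handling above (again, not part of the original excerpt); it assumes `rv_numpy_tester` and its module-level dependencies (e.g. `py_mode`) are in scope and that `uniform` comes from `aesara.tensor.random.basic`:

import numpy as np

import aesara.tensor as at
from aesara.tensor.random.basic import uniform

# Constants are `Variable`s, so the helper recovers their values with
# `get_test_value` before drawing the NumPy reference samples.
low = at.as_tensor_variable(np.array([0.0, 1.0]))
rv_numpy_tester(uniform, low, 10.0, rng=np.random.default_rng(42))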
Example 11
def test_identity():
    a = scalar("a")
    a.tag.test_value = 10

    out = aes.identity(a)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Example 12
def test_unique_nonconcrete():
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = aet_extra_ops.Unique()(a)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Example 13
def test_softmax_grad(axis):
    dy = matrix("dy")
    dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
    sm = matrix("sm")
    sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
    out = SoftmaxGrad(axis=axis)(dy, sm)
    fgraph = FunctionGraph([dy, sm], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Example 14
def test_dirichlet_infer_shape():
    M_aet = iscalar("M")
    M_aet.tag.test_value = 3

    test_params = [
        ([aet.ones((M_aet, ))], None),
        ([aet.ones((M_aet, ))], (M_aet + 1, )),
        ([aet.ones((M_aet, ))], (2, M_aet)),
        ([aet.ones((M_aet, M_aet + 1))], None),
        ([aet.ones((M_aet, M_aet + 1))], (M_aet + 2, )),
        ([aet.ones((M_aet, M_aet + 1))], (2, M_aet + 2, M_aet + 3)),
    ]
    for args, size in test_params:
        rv = dirichlet(*args, size=size)
        rv_shape = tuple(dirichlet._infer_shape(size or (), args, None))
        assert tuple(get_test_value(rv_shape)) == tuple(
            get_test_value(rv).shape)
Example 15
def test_jax_variadic_Scalar():
    mu = vector("mu", dtype=config.floatX)
    mu.tag.test_value = np.r_[0.1, 1.1].astype(config.floatX)
    tau = vector("tau", dtype=config.floatX)
    tau.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)

    res = -tau * mu

    fgraph = FunctionGraph([mu, tau], [res])

    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    res = -tau * (tau - mu)**2

    fgraph = FunctionGraph([mu, tau], [res])

    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Example 16
def test_arange_nonconcrete():

    a = scalar("a")
    a.tag.test_value = 10

    out = aet.arange(a)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Example 17
def test_TransMatConjugateStep():

    with pm.Model() as test_model, pytest.raises(ValueError):
        p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2)
        transmat = TransMatConjugateStep(p_0_rv)

    np.random.seed(2032)

    poiszero_sim, _ = simulate_poiszero_hmm(30, 150)
    y_test = poiszero_sim["Y_t"]

    with pm.Model() as test_model:
        p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2)
        p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2)

        P_tt = at.stack([p_0_rv, p_1_rv])
        P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt))

        pi_0_tt = compute_steady_state(P_rv)

        S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=y_test.shape[0])

        PoissonZeroProcess("Y_t", 9.0, S_rv, observed=y_test)

    with test_model:
        transmat = TransMatConjugateStep(P_rv)

    test_point = test_model.test_point.copy()
    test_point["S_t"] = (y_test > 0).astype(int)

    res = transmat.step(test_point)

    p_0_smpl = get_test_value(
        p_0_rv.distribution.transform.backward(res[p_0_rv.transformed.name]))
    p_1_smpl = get_test_value(
        p_1_rv.distribution.transform.backward(res[p_1_rv.transformed.name]))

    sampled_trans_mat = np.stack([p_0_smpl, p_1_smpl])

    true_trans_mat = (
        compute_trans_freqs(poiszero_sim["S_t"], 2, counts_only=True) +
        np.c_[[1, 1], [1, 1]])
    true_trans_mat = true_trans_mat / true_trans_mat.sum(0)[..., None]

    assert np.allclose(sampled_trans_mat, true_trans_mat, atol=0.3)
Example 18
def test_jax_logp():

    mu = vector("mu")
    mu.tag.test_value = np.r_[0.0, 0.0].astype(config.floatX)
    tau = vector("tau")
    tau.tag.test_value = np.r_[1.0, 1.0].astype(config.floatX)
    sigma = vector("sigma")
    sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(config.floatX)
    value = vector("value")
    value.tag.test_value = np.r_[0.1, -10].astype(config.floatX)

    logp = (-tau * (value - mu)**2 + log(tau / np.pi / 2.0)) / 2.0
    conditions = [sigma > 0]
    alltrue = aet_all([aet_all(1 * val) for val in conditions])
    normal_logp = aet.switch(alltrue, logp, -np.inf)

    fgraph = FunctionGraph([mu, tau, sigma, value], [normal_logp])

    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Example 19
def test_normal_ShapeFeature():
    M_aet = iscalar("M")
    M_aet.tag.test_value = 3
    sd_aet = scalar("sd")
    sd_aet.tag.test_value = np.array(1.0, dtype=config.floatX)

    d_rv = normal(aet.ones((M_aet, )), sd_aet, size=(2, M_aet))
    d_rv.tag.test_value  # check that a test value was computed for the draw

    fg = FunctionGraph(
        [i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)],
        [d_rv],
        clone=False,
        features=[ShapeFeature()],
    )
    s1, s2 = fg.shape_feature.shape_of[d_rv]

    assert get_test_value(s1) == get_test_value(d_rv).shape[0]
    assert get_test_value(s2) == get_test_value(d_rv).shape[1]
Example 20
def test_extra_ops_omni():
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    # This function also cannot take symbolic input.
    c = aet.as_tensor(5)
    out = aet_extra_ops.bartlett(c)
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4))
    out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs],
                       must_be_device_array=False)

    # The inputs are "concrete", yet it still has problems?
    out = aet_extra_ops.Unique()(aet.as_tensor(
        np.arange(6, dtype=config.floatX).reshape((3, 2))))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [])
Example 21
def test_jax_multioutput():
    x = vector("x")
    x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
    y = vector("y")
    y.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)

    w = cosh(x**2 + y / 3.0)
    v = cosh(x / 3.0 + y**2)

    fgraph = FunctionGraph([x, y], [w, v])

    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Example 22
def rv_numpy_tester(rv, *params, **kwargs):
    """Test for correspondence between `RandomVariable` and NumPy shape and
    broadcast dimensions.
    """
    test_fn = kwargs.pop("test_fn", None)

    if test_fn is None:
        name = getattr(rv, "name", None)

        if name is None:
            name = rv.__name__

        test_fn = getattr(np.random, name)

    aesara_res = rv(*params, **kwargs)

    param_vals = [
        get_test_value(p) if isinstance(p, Variable) else p for p in params
    ]
    kwargs_vals = {
        k: get_test_value(v) if isinstance(v, Variable) else v
        for k, v in kwargs.items()
    }

    if "size" in kwargs:
        kwargs["size"] = get_test_value(kwargs["size"])

    numpy_res = np.asarray(test_fn(*param_vals, **kwargs_vals))

    assert aesara_res.type.numpy_dtype.kind == numpy_res.dtype.kind

    numpy_shape = np.shape(numpy_res)
    numpy_bcast = [s == 1 for s in numpy_shape]
    np.testing.assert_array_equal(aesara_res.type.broadcastable, numpy_bcast)

    aesara_res_val = aesara_res.get_test_value()
    np.testing.assert_array_equal(aesara_res_val.shape, numpy_res.shape)
Example 23
def test_jax_ifelse():

    true_vals = np.r_[1, 2, 3]
    false_vals = np.r_[-1, -2, -3]

    x = ifelse(np.array(True), true_vals, false_vals)
    x_fg = FunctionGraph([], [x])

    compare_jax_and_py(x_fg, [])

    a = dscalar("a")
    a.tag.test_value = np.array(0.2, dtype=config.floatX)
    x = ifelse(a < 0.5, true_vals, false_vals)
    x_fg = FunctionGraph([a], [x])  # with the test value 0.2, the condition is True

    compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])
Example 24
def test_extra_ops():
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = at_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    c = at.as_tensor(5)

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    indices = np.arange(np.product((3, 4)))
    out = at_extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )
Example 25
def test_mvnormal_ShapeFeature():
    M_aet = iscalar("M")
    M_aet.tag.test_value = 2

    d_rv = multivariate_normal(aet.ones((M_aet, )), aet.eye(M_aet), size=2)

    fg = FunctionGraph(
        [i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)],
        [d_rv],
        clone=False,
        features=[ShapeFeature()],
    )

    s1, s2 = fg.shape_feature.shape_of[d_rv]

    assert get_test_value(s1) == 2
    assert M_aet in graph_inputs([s2])

    # Test broadcasted shapes
    mean = tensor(config.floatX, [True, False])
    mean.tag.test_value = np.array([[0, 1, 2]], dtype=config.floatX)

    test_covar = np.diag(np.array([1, 10, 100], dtype=config.floatX))
    test_covar = np.stack([test_covar, test_covar * 10.0])
    cov = aet.as_tensor(test_covar).type()
    cov.tag.test_value = test_covar

    d_rv = multivariate_normal(mean, cov, size=[2, 3])

    fg = FunctionGraph(
        [i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)],
        [d_rv],
        clone=False,
        features=[ShapeFeature()],
    )

    s1, s2, s3, s4 = fg.shape_feature.shape_of[d_rv]

    assert s1.get_test_value() == 2
    assert s2.get_test_value() == 3
    assert s3.get_test_value() == 2
    assert s4.get_test_value() == 3
Example 26
def test_test_value_op():

    x = log(np.ones((5, 5)))
    v = op.get_test_value(x)

    assert np.allclose(v, np.zeros((5, 5)))
Example 27
def test_test_value_shared():
    x = shared(np.zeros((5, 5)))
    v = op.get_test_value(x)

    assert np.all(v == np.zeros((5, 5)))
Example 28
def test_test_value_constant():
    x = aet.as_tensor_variable(np.zeros((5, 5)))
    v = op.get_test_value(x)

    assert np.all(v == np.zeros((5, 5)))
Example 29
def test_test_value_ndarray():
    x = np.zeros((5, 5))
    v = op.get_test_value(x)
    assert np.all(v == x)
Example 30
def test_test_value_python_objects():
    for x in ([0, 1, 2], 0, 0.5, 1):
        assert np.all(op.get_test_value(x) == x)