def test_categorical_samples():
    """Check `categorical` sample shapes and its near-deterministic draws.

    Each row of `p` puts almost all of its mass on the diagonal entry, so
    draws are expected to equal the row index.
    """
    rng_state = np.random.RandomState(
        np.random.MT19937(np.random.SeedSequence(1234)))

    p = np.array([[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
                 dtype=config.floatX)
    # Normalize each row. `keepdims=True` makes the row sums broadcast along
    # the last axis; a bare `p.sum(axis=-1)` would divide entry (i, j) by the
    # sum of row j, which is only correct because this matrix is symmetric.
    p = p / p.sum(axis=-1, keepdims=True)

    assert categorical.rng_fn(rng_state, p, size=None).shape == p.shape[:-1]

    # A scalar `size` that does not cover the batch shape of `p` is rejected.
    with raises(ValueError):
        categorical.rng_fn(rng_state, p, size=10)

    assert categorical.rng_fn(rng_state, p, size=(10, 3)).shape == (10, 3)
    assert categorical.rng_fn(rng_state, p,
                              size=(10, 2, 3)).shape == (10, 2, 3)

    res = categorical(p)
    assert np.array_equal(get_test_value(res), np.arange(3))

    res = categorical(p, size=(10, 3))
    exp_res = np.tile(np.arange(3), (10, 1))
    assert np.array_equal(get_test_value(res), exp_res)

    res = categorical(p, size=(10, 2, 3))
    exp_res = np.tile(np.arange(3), (10, 2, 1))
    assert np.array_equal(get_test_value(res), exp_res)
def test_dirichlet_samples():
    """Dirichlet draws concentrate on the component with the dominant alpha.

    Covers both the symbolic `dirichlet` constructor (via test values) and the
    raw `dirichlet.rng_fn` shape handling.
    """
    alphas = np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]],
                      dtype=config.floatX)

    draws = get_test_value(dirichlet(alphas))
    assert np.all(np.diag(draws) >= draws)

    draws = get_test_value(dirichlet(alphas, size=2))
    assert draws.shape == (2, 3, 3)
    assert all(np.all(np.diag(draw) >= draw) for draw in draws)

    # One alpha vector at a time: component `idx` should dominate.
    for idx, alpha_row in enumerate(alphas):
        draws = get_test_value(dirichlet(alpha_row))
        assert np.all(draws[idx] > np.delete(draws, [idx]))

        draws = get_test_value(dirichlet(alpha_row, size=2))
        assert draws.shape == (2, 3)
        assert all(np.all(draw[idx] > np.delete(draw, [idx])) for draw in draws)

    rng_state = np.random.RandomState(
        np.random.MT19937(np.random.SeedSequence(1234)))

    alphas = np.array([[1000, 1, 1], [1, 1000, 1], [1, 1, 1000]],
                      dtype=config.floatX)
    batch_shape = alphas.shape

    assert dirichlet.rng_fn(rng_state, alphas, None).shape == batch_shape
    assert (dirichlet.rng_fn(rng_state, alphas, size=10).shape
            == (10, ) + batch_shape)
    assert (dirichlet.rng_fn(rng_state, alphas, size=(10, 2)).shape
            == (10, 2) + batch_shape)
def test_jax_BatchedDot():
    """`BatchedDot` agrees between JAX and Python for tensor3 and matrix inputs.

    Also checks that a batch-size mismatch raises `TypeError` under JAX.
    """

    def _ramp(start, stop, shape):
        # Evenly spaced values cast to floatX and reshaped to `shape`.
        count = int(np.prod(shape))
        return np.linspace(start, stop, count).astype(config.floatX).reshape(shape)

    # tensor3 . tensor3
    a = tensor3("a")
    a.tag.test_value = _ramp(-1, 1, (10, 5, 3))
    b = tensor3("b")
    b.tag.test_value = _ramp(1, -1, (10, 3, 2))
    fgraph = FunctionGraph([a, b], [aet_blas.BatchedDot()(a, b)])
    compare_jax_and_py(fgraph, [get_test_value(v) for v in fgraph.inputs])

    # Dropping one batch entry from `a` creates a dimension mismatch, which
    # the JAX-compiled function should report as a `TypeError`.
    mismatched = [get_test_value(a)[:-1], get_test_value(b)]
    opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
    jax_mode = Mode(JAXLinker(), opts)
    compiled = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
    with pytest.raises(TypeError):
        compiled(*mismatched)

    # matrix . matrix
    a = matrix("a")
    a.tag.test_value = _ramp(-1, 1, (5, 3))
    b = matrix("b")
    b.tag.test_value = _ramp(1, -1, (5, 3))
    fgraph = FunctionGraph([a, b], [aet_blas.BatchedDot()(a, b)])
    compare_jax_and_py(fgraph, [get_test_value(v) for v in fgraph.inputs])
def test_polyagamma_samples():
    """`polyagamma` sample shapes follow parameter broadcasting and `size`.

    Skipped when the optional `pypolyagamma` package is not installed.
    """
    importorskip("pypolyagamma")

    def _neighbours_differ(samples):
        # True when no two consecutive flattened draws are identical.
        return np.all(np.abs(np.diff(samples.flat)) > 0.0)

    # Scalar parameters: unsized draws are scalars.
    a = np.array(1.1, dtype=config.floatX)
    b = np.array(-10.5, dtype=config.floatX)

    assert get_test_value(polyagamma(a, b)).shape == ()
    assert get_test_value(polyagamma(a, b, size=[1])).shape == (1, )

    samples = get_test_value(polyagamma(a, b, size=[2, 3]))
    assert samples.shape == (2, 3)
    assert _neighbours_differ(samples)

    # Vector `a` broadcast against scalar `b`.
    a = np.array([1.1, 3], dtype=config.floatX)
    b = np.array(-10.5, dtype=config.floatX)

    samples = get_test_value(polyagamma(a, b))
    assert samples.shape == (2, )
    assert _neighbours_differ(samples)

    samples = get_test_value(polyagamma(a, b, size=(3, 2)))
    assert samples.shape == (3, 2)
    assert _neighbours_differ(samples)
def test_tensor_basics():
    """Basic tensor expressions agree between the JAX and Python backends."""
    y = vector("y")
    y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
    x = vector("x")
    x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
    A = matrix("A")
    # Use fixed values: `np.empty` returns uninitialized memory, which may
    # hold arbitrary (even non-finite) numbers and makes the backend
    # comparison nondeterministic.
    A.tag.test_value = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
    alpha = scalar("alpha")
    alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
    beta = scalar("beta")
    beta.tag.test_value = np.array(5.0, dtype=config.floatX)

    # This should be converted into a `Gemv` `Op` when the non-JAX compatible
    # optimizations are turned on; however, when using JAX mode, it should
    # leave the expression alone.
    out = y.dot(alpha * A).dot(x) + beta * y
    fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = maximum(y, x)
    fgraph = FunctionGraph([y, x], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_max(y)
    fgraph = FunctionGraph([y], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_extra_ops():
    """JAX translations of `extra_ops` agree with the Python backend.

    Ops without a JAX implementation are expected to raise
    `NotImplementedError`.
    """
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = aet_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = aet_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # This function also cannot take symbolic input.
    c = aet.as_tensor(5)

    out = aet_extra_ops.bartlett(c)
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = aet_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # `np.product` is deprecated (removed in NumPy 2.0); `np.prod` is the
    # supported spelling.
    indices = np.arange(np.prod((3, 4)))
    out = aet_extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs],
                       must_be_device_array=False)

    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
    out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs],
                       must_be_device_array=False)

    # NOTE(review): the input here is a concrete constant, yet the original
    # author noted it "still has problems"; confirm whether this case is
    # expected to pass.
    out = aet_extra_ops.Unique()(aet.as_tensor(
        np.arange(6, dtype=config.floatX).reshape((3, 2))))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [])
def compare_sample_values(rv, *params, rng=None, test_fn=None, **kwargs):
    """Test for equivalence between `RandomVariable` and NumPy/other samples.

    Unless `test_fn` is given, the reference sampler is the method of the
    NumPy RNG object whose name matches `rv.name` (or `rv.__name__`).
    The Aesara result must match the reference in dtype kind, broadcast
    pattern, shape and values, and must come back writeable.
    """
    if rng is None:
        rng = np.random.default_rng()

    if test_fn is None:
        method_name = getattr(rv, "name", None)
        if method_name is None:
            method_name = rv.__name__

        def test_fn(*args, random_state=None, **kw):
            return getattr(random_state, method_name)(*args, **kw)

    def _concrete(value):
        # Symbolic values are replaced by their test values.
        if isinstance(value, Variable):
            return get_test_value(value)
        return value

    param_vals = [_concrete(p) for p in params]
    kwargs_vals = {key: _concrete(val) for key, val in kwargs.items()}

    at_rng = shared(rng, borrow=True)

    # The reference draw uses a copy of `rng`, so it does not advance the
    # generator held by the shared variable above.
    numpy_res = np.asarray(
        test_fn(*param_vals, random_state=copy(rng), **kwargs_vals))

    aesara_res = rv(*params, rng=at_rng, **kwargs)
    assert aesara_res.type.numpy_dtype.kind == numpy_res.dtype.kind

    expected_bcast = [dim == 1 for dim in np.shape(numpy_res)]
    np.testing.assert_array_equal(aesara_res.type.broadcastable, expected_bcast)

    free_inputs = [
        var for var in graph_inputs([aesara_res])
        if not isinstance(var, (Constant, SharedVariable))
    ]
    compiled = function(free_inputs, aesara_res, mode=py_mode)
    sampled = compiled()

    assert sampled.flags.writeable
    np.testing.assert_array_equal(sampled.shape, numpy_res.shape)
    np.testing.assert_allclose(sampled, numpy_res)
def test_nnet():
    """Elementwise activation functions agree between JAX and Python."""
    x = vector("x")
    x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)

    for activation in (sigmoid, aet_nnet.ultra_fast_sigmoid, softplus):
        fgraph = FunctionGraph([x], [activation(x)])
        compare_jax_and_py(fgraph, [get_test_value(v) for v in fgraph.inputs])
def test_normal_infer_shape():
    """`normal._infer_shape` matches the shape of the actual test-value draw.

    Covers scalar, symbolic-vector and constant-array parameters combined
    with `None`, symbolic and literal `size` values.
    """
    M_aet = iscalar("M")
    M_aet.tag.test_value = 3
    sd_aet = scalar("sd")
    sd_aet.tag.test_value = np.array(1.0, dtype=config.floatX)

    def _unit_mean():
        # A fresh scalar constant 1.0 for each case.
        return aet.as_tensor_variable(np.array(1.0, dtype=config.floatX))

    def _zero_means():
        # A fresh length-M vector of zeros for each case.
        return aet.zeros((M_aet, ))

    test_params = [
        ([_unit_mean(), sd_aet], None),
        ([_unit_mean(), sd_aet], (M_aet, )),
        ([_unit_mean(), sd_aet], (2, M_aet)),
        ([_zero_means(), sd_aet], None),
        ([_zero_means(), sd_aet], (M_aet, )),
        ([_zero_means(), sd_aet], (2, M_aet)),
        ([_zero_means(), aet.ones((M_aet, ))], None),
        ([_zero_means(), aet.ones((M_aet, ))], (2, M_aet)),
        (
            [
                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
                np.array([[1e-6, 2e-6]], dtype=config.floatX),
            ],
            (3, 2, 2),
        ),
        (
            [
                np.array([1], dtype=config.floatX),
                np.array([10], dtype=config.floatX),
            ],
            (1, 2),
        ),
    ]

    for args, size in test_params:
        rv = normal(*args, size=size)
        inferred = tuple(normal._infer_shape(size or (), args, None))
        assert tuple(get_test_value(inferred)) == tuple(get_test_value(rv).shape)
def rv_numpy_tester(rv, *params, rng=None, test_fn=None, **kwargs):
    """Check a `RandomVariable` against NumPy: dtype kind, broadcast pattern,
    shape and sampled values.

    When `test_fn` is omitted, the NumPy RNG method named after `rv.name`
    (falling back to `rv.__name__`) is used as the reference sampler.
    """
    rng = np.random.default_rng() if rng is None else rng

    if test_fn is None:
        rv_name = getattr(rv, "name", None)
        if rv_name is None:
            rv_name = rv.__name__

        def test_fn(*args, random_state=None, **extra):
            return getattr(random_state, rv_name)(*args, **extra)

    param_vals = []
    for param in params:
        param_vals.append(
            get_test_value(param) if isinstance(param, Variable) else param)

    kwargs_vals = {}
    for key, val in kwargs.items():
        kwargs_vals[key] = get_test_value(val) if isinstance(val, Variable) else val

    at_rng = shared(rng, borrow=True)

    # Reference sample drawn from a copy so the shared generator is untouched.
    reference = np.asarray(
        test_fn(*param_vals, random_state=copy(rng), **kwargs_vals))

    aesara_res = rv(*params, rng=at_rng, **kwargs)
    assert aesara_res.type.numpy_dtype.kind == reference.dtype.kind

    reference_bcast = [extent == 1 for extent in np.shape(reference)]
    np.testing.assert_array_equal(aesara_res.type.broadcastable, reference_bcast)

    non_constant_inputs = [
        node_input for node_input in graph_inputs([aesara_res])
        if not isinstance(node_input, (Constant, SharedVariable))
    ]
    sampler = function(non_constant_inputs, aesara_res, mode=py_mode)
    drawn = sampler()

    np.testing.assert_array_equal(drawn.shape, reference.shape)
    np.testing.assert_allclose(drawn, reference)
def test_identity():
    """`aes.identity` on a scalar agrees between the JAX and Python backends."""
    a = scalar("a")
    a.tag.test_value = 10
    fgraph = FunctionGraph([a], [aes.identity(a)])
    compare_jax_and_py(fgraph, [get_test_value(var) for var in fgraph.inputs])
def test_unique_nonconcrete():
    """`Unique` applied to a symbolic (non-constant) matrix input."""
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
    fgraph = FunctionGraph([a], [aet_extra_ops.Unique()(a)])
    compare_jax_and_py(fgraph, [get_test_value(var) for var in fgraph.inputs])
def test_softmax_grad(axis):
    """`SoftmaxGrad` along `axis` agrees between JAX and Python.

    `axis` is presumably supplied by a pytest parametrization outside this
    view; verify against the decorator.
    """
    dy = matrix("dy")
    dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
    sm = matrix("sm")
    sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)

    grad_out = SoftmaxGrad(axis=axis)(dy, sm)
    fgraph = FunctionGraph([dy, sm], [grad_out])
    compare_jax_and_py(fgraph, [get_test_value(var) for var in fgraph.inputs])
def test_dirichlet_infer_shape():
    """`dirichlet._infer_shape` matches the shape of the test-value draw."""
    M_aet = iscalar("M")
    M_aet.tag.test_value = 3

    def _ones(*shape):
        # A fresh symbolic array of ones with the given (symbolic) shape.
        return aet.ones(shape)

    test_params = [
        ([_ones(M_aet)], None),
        ([_ones(M_aet)], (M_aet + 1, )),
        ([_ones(M_aet)], (2, M_aet)),
        ([_ones(M_aet, M_aet + 1)], None),
        ([_ones(M_aet, M_aet + 1)], (M_aet + 2, )),
        ([_ones(M_aet, M_aet + 1)], (2, M_aet + 2, M_aet + 3)),
    ]

    for args, size in test_params:
        rv = dirichlet(*args, size=size)
        inferred = tuple(dirichlet._infer_shape(size or (), args, None))
        assert tuple(get_test_value(inferred)) == tuple(get_test_value(rv).shape)
def test_jax_variadic_Scalar():
    """Elementwise expressions mixing unary and binary scalar ops agree."""
    mu = vector("mu", dtype=config.floatX)
    mu.tag.test_value = np.r_[0.1, 1.1].astype(config.floatX)
    tau = vector("tau", dtype=config.floatX)
    tau.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)

    builders = (
        lambda: -tau * mu,
        lambda: -tau * (tau - mu)**2,
    )
    for build in builders:
        fgraph = FunctionGraph([mu, tau], [build()])
        compare_jax_and_py(fgraph, [get_test_value(var) for var in fgraph.inputs])
def test_arange_nonconcrete():
    """`aet.arange` with a symbolic (non-constant) stop value."""
    a = scalar("a")
    a.tag.test_value = 10
    fgraph = FunctionGraph([a], [aet.arange(a)])
    compare_jax_and_py(fgraph, [get_test_value(var) for var in fgraph.inputs])
def test_TransMatConjugateStep():
    """`TransMatConjugateStep` samples transition rows close to the empirical ones.

    Also checks that constructing the step from a single Dirichlet variable
    raises `ValueError`.
    """
    with pm.Model(), pytest.raises(ValueError):
        p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2)
        TransMatConjugateStep(p_0_rv)

    np.random.seed(2032)

    sim, _ = simulate_poiszero_hmm(30, 150)
    y_obs = sim["Y_t"]

    with pm.Model() as test_model:
        p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2)
        p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2)

        P_rv = pm.Deterministic(
            "P_tt", at.shape_padleft(at.stack([p_0_rv, p_1_rv])))

        pi_0_tt = compute_steady_state(P_rv)

        S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=y_obs.shape[0])

        PoissonZeroProcess("Y_t", 9.0, S_rv, observed=y_obs)

    with test_model:
        step = TransMatConjugateStep(P_rv)

    point = test_model.test_point.copy()
    point["S_t"] = (y_obs > 0).astype(int)

    res = step.step(point)

    def _sampled_row(rv):
        # Map the transformed sample in `res` back to the simplex.
        return get_test_value(
            rv.distribution.transform.backward(res[rv.transformed.name]))

    sampled_trans_mat = np.stack([_sampled_row(p_0_rv), _sampled_row(p_1_rv)])

    counts = (compute_trans_freqs(sim["S_t"], 2, counts_only=True)
              + np.c_[[1, 1], [1, 1]])
    # NOTE(review): entry (i, j) is divided by the sum of column i, not row i;
    # confirm this matches the orientation returned by `compute_trans_freqs`.
    true_trans_mat = counts / counts.sum(0)[..., None]

    assert np.allclose(sampled_trans_mat, true_trans_mat, atol=0.3)
def test_jax_logp():
    """A guarded normal log-density graph agrees between JAX and Python."""
    mu = vector("mu")
    mu.tag.test_value = np.r_[0.0, 0.0].astype(config.floatX)
    tau = vector("tau")
    tau.tag.test_value = np.r_[1.0, 1.0].astype(config.floatX)
    sigma = vector("sigma")
    sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(config.floatX)
    value = vector("value")
    value.tag.test_value = np.r_[0.1, -10].astype(config.floatX)

    # Normal log-density written in terms of the precision `tau`.
    logp = (-tau * (value - mu)**2 + log(tau / np.pi / 2.0)) / 2.0

    # Return `logp` only when every parameter condition holds, else -inf.
    conditions = [sigma > 0]
    all_hold = aet_all([aet_all(1 * cond) for cond in conditions])
    normal_logp = aet.switch(all_hold, logp, -np.inf)

    fgraph = FunctionGraph([mu, tau, sigma, value], [normal_logp])
    compare_jax_and_py(fgraph, [get_test_value(var) for var in fgraph.inputs])
def test_normal_ShapeFeature():
    """`ShapeFeature` shapes for a sized `normal` match the drawn test value."""
    M_aet = iscalar("M")
    M_aet.tag.test_value = 3
    sd_aet = scalar("sd")
    sd_aet.tag.test_value = np.array(1.0, dtype=config.floatX)

    d_rv = normal(aet.ones((M_aet, )), sd_aet, size=(2, M_aet))
    # NOTE(review): bare attribute access, presumably a sanity check that a
    # test value was attached (it would raise if missing); confirm intent.
    d_rv.tag.test_value

    non_constants = [
        var for var in graph_inputs([d_rv]) if not isinstance(var, Constant)
    ]
    fg = FunctionGraph(
        non_constants,
        [d_rv],
        clone=False,
        features=[ShapeFeature()],
    )
    dim0, dim1 = fg.shape_feature.shape_of[d_rv]

    drawn_shape = get_test_value(d_rv).shape
    assert get_test_value(dim0) == drawn_shape[0]
    assert get_test_value(dim1) == drawn_shape[1]
def test_extra_ops_omni():
    """`extra_ops` calls on constant (non-symbolic) inputs agree across backends."""
    # This function also cannot take symbolic input.
    c = aet.as_tensor(5)
    out = aet_extra_ops.bartlett(c)
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # `np.product` is deprecated (removed in NumPy 2.0); use `np.prod`.
    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
    out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs],
                       must_be_device_array=False)

    # NOTE(review): the input is a concrete constant, yet the original author
    # noted it "still has problems"; confirm whether this case should pass.
    out = aet_extra_ops.Unique()(aet.as_tensor(
        np.arange(6, dtype=config.floatX).reshape((3, 2))))
    fgraph = FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [])
def test_jax_multioutput():
    """A graph with two outputs sharing the same inputs agrees across backends."""
    x = vector("x")
    x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
    y = vector("y")
    y.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)

    outputs = [cosh(x**2 + y / 3.0), cosh(x / 3.0 + y**2)]
    fgraph = FunctionGraph([x, y], outputs)
    compare_jax_and_py(fgraph, [get_test_value(var) for var in fgraph.inputs])
def rv_numpy_tester(rv, *params, **kwargs):
    """Test for correspondence between `RandomVariable` and NumPy shape and
    broadcast dimensions.

    Parameters
    ----------
    rv
        The `RandomVariable` under test.
    *params
        Distribution parameters; symbolic ones are replaced by their test
        values for the NumPy call.
    **kwargs
        Keyword arguments forwarded to `rv`. A `test_fn` keyword, if given,
        is removed and used as the NumPy reference sampler instead of the
        `np.random` function named after `rv`.
    """
    test_fn = kwargs.pop("test_fn", None)

    if test_fn is None:
        name = getattr(rv, "name", None)

        if name is None:
            name = rv.__name__

        test_fn = getattr(np.random, name)

    aesara_res = rv(*params, **kwargs)

    param_vals = [
        get_test_value(p) if isinstance(p, Variable) else p for p in params
    ]
    kwargs_vals = {
        k: get_test_value(v) if isinstance(v, Variable) else v
        for k, v in kwargs.items()
    }

    # `size` can be a plain container holding symbolic entries (e.g.
    # `(2, M)`); the `isinstance` check above leaves those symbolic, so
    # resolve them here. The conversion must land in `kwargs_vals`, the dict
    # actually passed to NumPy; writing it back into `kwargs` had no effect.
    if kwargs.get("size") is not None:
        kwargs_vals["size"] = get_test_value(kwargs["size"])

    numpy_res = np.asarray(test_fn(*param_vals, **kwargs_vals))

    assert aesara_res.type.numpy_dtype.kind == numpy_res.dtype.kind

    numpy_shape = np.shape(numpy_res)
    numpy_bcast = [s == 1 for s in numpy_shape]
    np.testing.assert_array_equal(aesara_res.type.broadcastable, numpy_bcast)

    aesara_res_val = aesara_res.get_test_value()
    np.testing.assert_array_equal(aesara_res_val.shape, numpy_res.shape)
def test_jax_ifelse():
    """`ifelse` with constant and symbolic conditions agrees across backends."""
    true_vals = np.r_[1, 2, 3]
    false_vals = np.r_[-1, -2, -3]

    # Constant condition.
    x = ifelse(np.array(True), true_vals, false_vals)
    x_fg = FunctionGraph([], [x])

    compare_jax_and_py(x_fg, [])

    # Symbolic condition. With the test value 0.2, `a < 0.5` is True, so the
    # true branch is selected.
    a = dscalar("a")
    a.tag.test_value = np.array(0.2, dtype=config.floatX)
    x = ifelse(a < 0.5, true_vals, false_vals)
    x_fg = FunctionGraph([a], [x])

    compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])
def test_extra_ops():
    """JAX translations of `extra_ops` agree with the Python backend.

    Ops without a JAX implementation are expected to raise
    `NotImplementedError`.
    """
    a = matrix("a")
    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = at_extra_ops.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.diff(a, n=2, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = at_extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    c = at.as_tensor(5)

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal(a, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = at_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # `np.product` is deprecated (removed in NumPy 2.0); use `np.prod`.
    indices = np.arange(np.prod((3, 4)))
    out = at_extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = FunctionGraph([], out)
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )
def test_mvnormal_ShapeFeature():
    """`ShapeFeature` shapes for `multivariate_normal`, incl. broadcast params."""

    def _shape_of(rv):
        # Attach a ShapeFeature to a graph over `rv`'s non-constant inputs and
        # return the symbolic shape it tracks for `rv`.
        fg = FunctionGraph(
            [var for var in graph_inputs([rv]) if not isinstance(var, Constant)],
            [rv],
            clone=False,
            features=[ShapeFeature()],
        )
        return fg.shape_feature.shape_of[rv]

    M_aet = iscalar("M")
    M_aet.tag.test_value = 2

    d_rv = multivariate_normal(aet.ones((M_aet, )), aet.eye(M_aet), size=2)

    s1, s2 = _shape_of(d_rv)
    assert get_test_value(s1) == 2
    assert M_aet in graph_inputs([s2])

    # Test broadcasted shapes
    mean = tensor(config.floatX, [True, False])
    mean.tag.test_value = np.array([[0, 1, 2]], dtype=config.floatX)

    base_covar = np.diag(np.array([1, 10, 100], dtype=config.floatX))
    test_covar = np.stack([base_covar, base_covar * 10.0])
    cov = aet.as_tensor(test_covar).type()
    cov.tag.test_value = test_covar

    d_rv = multivariate_normal(mean, cov, size=[2, 3])

    s1, s2, s3, s4 = _shape_of(d_rv)
    for dim, expected in zip((s1, s2, s3, s4), (2, 3, 2, 3)):
        assert dim.get_test_value() == expected
def test_test_value_op():
    """`op.get_test_value` of `log(ones)` built from a NumPy array is all zeros."""
    expected = np.zeros((5, 5))
    result = op.get_test_value(log(np.ones((5, 5))))
    assert np.allclose(result, expected)
def test_test_value_shared():
    """`op.get_test_value` on a shared variable returns its stored value."""
    expected = np.zeros((5, 5))
    result = op.get_test_value(shared(np.zeros((5, 5))))
    assert np.all(result == expected)
def test_test_value_constant():
    """`op.get_test_value` on a tensor constant returns the wrapped data."""
    expected = np.zeros((5, 5))
    result = op.get_test_value(aet.as_tensor_variable(np.zeros((5, 5))))
    assert np.all(result == expected)
def test_test_value_ndarray():
    """`op.get_test_value` accepts a raw ndarray and returns equal values."""
    arr = np.zeros((5, 5))
    assert np.all(op.get_test_value(arr) == arr)
def test_test_value_python_objects():
    """`op.get_test_value` accepts plain Python lists and numbers."""
    samples = ([0, 1, 2], 0, 0.5, 1)
    for obj in samples:
        value = op.get_test_value(obj)
        assert np.all(value == obj)