def test_gpu_eigh_opt(self):
    """Check that `eigh` compiled with `mode_with_gpu` contains a `GpuMagmaEigh` node."""
    A = fmatrix("A")
    fn = aesara.function([A], eigh(A), mode=mode_with_gpu)

    # Walk the compiled function graph in topological order and require that
    # at least one node applies a `GpuMagmaEigh` Op.
    compiled_nodes = fn.maker.fgraph.toposort()
    assert any(isinstance(node.op, GpuMagmaEigh) for node in compiled_nodes)
def test_jax_basic_multiout():
    """Run `compare_jax_and_py` on graphs of several multi-output linalg functions.

    One graph is built for each of `eig`, `eigh`, `qr` (with modes ``"full"``
    and ``"reduced"``) and `svd`; every graph is evaluated on the same
    randomly generated 3x3 matrix.
    """

    def assert_fn(x, y):
        # Cast the first value to `floatX` before comparing, using a relative
        # tolerance of 1e-3.
        np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3)

    generator = np.random.default_rng(213234)
    base = generator.normal(size=(3, 3))
    # The product of a matrix with its own transpose is symmetric.
    sym_value = base.dot(base.T)

    x = matrix("x")

    # (function, keyword arguments) pairs, checked in this order.
    op_calls = (
        (aet_nlinalg.eig, {}),
        (aet_nlinalg.eigh, {}),
        (aet_nlinalg.qr, {"mode": "full"}),
        (aet_nlinalg.qr, {"mode": "reduced"}),
        (aet_nlinalg.svd, {}),
    )

    for build_outputs, kwargs in op_calls:
        outs = build_outputs(x, **kwargs)
        out_fg = FunctionGraph([x], outs)
        compare_jax_and_py(
            out_fg, [sym_value.astype(config.floatX)], assert_fn=assert_fn
        )
def test_jax_basic_multiout():
    """Run `compare_jax_and_py` on graphs built from multi-output `Op`s.

    The first part builds one graph for each of `eig`, `eigh`, `qr` (modes
    ``"full"`` and ``"reduced"``) and `svd`, and evaluates each on the same
    randomly generated 3x3 matrix.  The second part checks that a single
    output of a multi-output `Op` can be used as input to another `Op`.
    """
    # Use a local `Generator` instead of `np.random.seed`, so the test does
    # not modify NumPy's global random state.
    rng = np.random.default_rng(213234)

    M = rng.normal(size=(3, 3))
    # The product of a matrix with its own transpose is symmetric.
    X = M.dot(M.T)

    x = matrix("x")

    def assert_fn(x, y):
        # Cast the first value to `floatX` before comparing, using a relative
        # tolerance of 1e-3.
        np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3)

    outs = aet_nlinalg.eig(x)
    out_fg = FunctionGraph([x], outs)
    compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)

    outs = aet_nlinalg.eigh(x)
    out_fg = FunctionGraph([x], outs)
    compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)

    outs = aet_nlinalg.qr(x, mode="full")
    out_fg = FunctionGraph([x], outs)
    compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)

    outs = aet_nlinalg.qr(x, mode="reduced")
    out_fg = FunctionGraph([x], outs)
    compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)

    outs = aet_nlinalg.svd(x)
    out_fg = FunctionGraph([x], outs)
    compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)

    # Test that a single output of a multi-output `Op` can be used as input to
    # another `Op`
    x = dvector()
    mx, amx = MaxAndArgmax([0])(x)
    out = mx * amx
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(out_fg, [np.r_[1, 2]])