def test_svd(self):
    """Check the symbolic `svd` against `np.linalg.svd` on a random 4x4 matrix.

    First the full factorization (U, S, VT) is compiled and every factor is
    compared with NumPy's result; then the singular-values-only variant
    (``compute_uv=False``) is compiled and compared with NumPy's S.
    """
    sym_a = matrix("A", dtype=self.dtype)
    u_var, s_var, vt_var = svd(sym_a)
    full_svd = function([sym_a], [u_var, s_var, vt_var])

    data = self.rng.rand(4, 4).astype(self.dtype)
    # Explicit unpacking keeps the check that both sides yield exactly three
    # results before any values are compared.
    ref_u, ref_s, ref_vt = np.linalg.svd(data)
    got_u, got_s, got_vt = full_svd(data)
    for reference, result in ((ref_u, got_u), (ref_s, got_s), (ref_vt, got_vt)):
        assert _allclose(reference, result)

    values_only_svd = function([sym_a], svd(sym_a, compute_uv=False))
    assert _allclose(ref_s, values_only_svd(data))
def test_jax_basic_multiout():
    """Compare JAX and Python results for graphs built from multi-output `Op`s.

    Each linear-algebra op below (``eig``, ``eigh``, ``qr`` in "full" and
    "reduced" modes, and ``svd``) is applied to the same symmetric matrix
    input, and its full list of outputs is checked with `compare_jax_and_py`.
    Finally, a single output of `MaxAndArgmax` is fed into another `Op` to
    check that a multi-output node can be partially consumed.
    """
    # A local Generator keeps the test reproducible without reseeding
    # NumPy's global random state (shared by every other test in the run).
    rng = np.random.default_rng(213234)
    M = rng.normal(size=(3, 3))
    # M @ M.T is symmetric, so it is also a valid input for `eigh`.
    test_value = M.dot(M.T).astype(config.floatX)

    x = matrix("x")

    def assert_fn(result, expected):
        # Cast the first result to `config.floatX` and compare with a loose
        # relative tolerance.
        np.testing.assert_allclose(
            result.astype(config.floatX), expected, rtol=1e-3
        )

    for outs in (
        aet_nlinalg.eig(x),
        aet_nlinalg.eigh(x),
        aet_nlinalg.qr(x, mode="full"),
        aet_nlinalg.qr(x, mode="reduced"),
        aet_nlinalg.svd(x),
    ):
        out_fg = FunctionGraph([x], outs)
        compare_jax_and_py(out_fg, [test_value], assert_fn=assert_fn)

    # Test that a single output of a multi-output `Op` can be used as input to
    # another `Op`
    v = dvector()
    mx, amx = MaxAndArgmax([0])(v)
    out = mx * amx
    out_fg = FunctionGraph([v], [out])
    compare_jax_and_py(out_fg, [np.r_[1, 2]])