def test_gpu_matrix_inverse_inplace_opt(self):
    """Check that the compiled GPU graph contains an in-place matrix inverse.

    Compiles ``matrix_inverse`` of an ``fmatrix`` input with ``mode_with_gpu``
    and asserts that at least one ``GpuMagmaMatrixInverse`` node in the final
    graph has its ``inplace`` flag set. The test name suggests an in-place
    optimization is expected to fire.
    """
    A = fmatrix("A")
    fn = aesara.function([A], matrix_inverse(A), mode=mode_with_gpu)
    # A generator (not a list) lets `any` stop at the first in-place node.
    assert any(
        node.op.inplace
        for node in fn.maker.fgraph.toposort()
        if isinstance(node.op, GpuMagmaMatrixInverse)
    )
def test_inverse_singular():
    """Inverting a singular matrix must raise ``np.linalg.LinAlgError``."""
    # The last two rows are identical, so this 3x3 matrix is singular.
    singular = np.array(
        [[1, 0, 0], [0, 1, 0], [0, 1, 0]],
        dtype=aesara.config.floatX,
    )
    inp = tensor.matrix()
    inverse_fn = function([inp], matrix_inverse(inp))
    with pytest.raises(np.linalg.LinAlgError):
        inverse_fn(singular)
def test_transinv_to_invtrans():
    """Check that ``matrix_inverse(X).transpose()`` compiles to inverse-of-transpose.

    After compilation, every ``MatrixInverse`` node must take a ``DimShuffle``
    output as input. Every ``DimShuffle`` node must read the variable named
    ``"X"`` directly.
    """
    X = matrix("X")
    Z = matrix_inverse(X).transpose()
    f = aesara.function([X], Z)

    # Nothing is checked in FAST_COMPILE mode (presumably because the
    # rewrite is not applied there).
    if config.mode == "FAST_COMPILE":
        return

    for node in f.maker.fgraph.toposort():
        op = node.op
        if isinstance(op, MatrixInverse):
            assert isinstance(node.inputs[0].owner.op, DimShuffle)
        elif isinstance(op, DimShuffle):
            assert node.inputs[0].name == "X"
def test_rop_lop():
    """Check ``Rop``/``Lop`` of ``matrix_inverse(mx).sum(axis=0)`` against references.

    * R-op: compared with a ``scan`` that computes each output element's
      gradient with respect to ``mx`` and contracts it with ``mv``.
    * The R-op of a graph passed through ``break_op`` must raise
      ``ValueError``, because that function is not differentiable.
    * L-op: compared with ``grad((v * y).sum(), mx)``.
    """
    # Local import: the file-level import block is not visible from here.
    import pytest

    mx = matrix("mx")
    mv = matrix("mv")
    v = vector("v")
    y = matrix_inverse(mx).sum(axis=0)

    yv = aesara.gradient.Rop(y, mx, mv)
    rop_f = function([mx, mv], yv)

    # Reference R-op: for each output index i, (d y[i] / d mx) . mv.
    sy, _ = aesara.scan(
        lambda i, y, x, v: (aesara.gradient.grad(y[i], x) * v).sum(),
        sequences=aet.arange(y.shape[0]),
        non_sequences=[y, mx, mv],
    )
    scan_f = function([mx, mv], sy)

    rng = np.random.default_rng(utt.fetch_seed())
    vx = np.asarray(rng.standard_normal((4, 4)), aesara.config.floatX)
    vv = np.asarray(rng.standard_normal((4, 4)), aesara.config.floatX)

    v1 = rop_f(vx, vv)
    v2 = scan_f(vx, vv)
    assert _allclose(v1, v2), f"ROP mismatch: {v1} {v2}"

    # The Op must raise an error because the function is not differentiable.
    with pytest.raises(ValueError):
        aesara.gradient.Rop(
            aesara.clone_replace(y, replace={mx: break_op(mx)}), mx, mv
        )

    vv = np.asarray(rng.uniform(size=(4,)), aesara.config.floatX)
    yv = aesara.gradient.Lop(y, mx, v)
    lop_f = function([mx, v], yv)

    # Reference L-op: gradient of the v-weighted sum of y.
    sy = aesara.gradient.grad((v * y).sum(), mx)
    scan_f = function([mx, v], sy)

    v1 = lop_f(vx, vv)
    v2 = scan_f(vx, vv)
    assert _allclose(v1, v2), f"LOP mismatch: {v1} {v2}"
def grad(self, inputs, g_outputs):
    """Return the gradient with respect to the single input ``x``.

    The gradient is ``gz * matrix_inverse(x).T``, where ``gz`` is the single
    output gradient. NOTE(review): this matches d(log|det x|)/dx, so the
    enclosing Op is presumably a log-abs-determinant. Confirm against the
    class, which is not visible here.
    """
    (gz,) = g_outputs
    (x,) = inputs
    inv_x_t = matrix_inverse(x).T
    return [gz * inv_x_t]
def test_matrix_inverse_solve():
    """``inv_as_solve`` must turn ``matrix_inverse(A).dot(b)`` into a ``Solve``."""
    A = dmatrix("A")
    b = dmatrix("b")
    dot_node = matrix_inverse(A).dot(b).owner
    (out,) = inv_as_solve.transform(None, dot_node)
    assert isinstance(out.owner.op, Solve)
def test_jax_basic():
    """Compare a selection of Aesara ops between the JAX backend and Python.

    Each case wraps one output in a ``FunctionGraph`` and passes it, with
    concrete input values, to ``compare_jax_and_py``. The first case
    (elementwise ops plus ``[Inc]Subtensor``) also checks the returned
    values directly.
    """
    rng = np.random.default_rng(28494)

    x = matrix("x")
    y = matrix("y")
    b = vector("b")

    def run_case(inputs, output, input_values):
        # Build a single-output graph and compare both backends on it.
        graph = FunctionGraph(inputs, [output])
        return compare_jax_and_py(graph, input_values)

    def square_arange():
        # 10x10 matrix holding 0..99 row by row.
        return np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)

    def perturbed_identity():
        # 10x10 identity plus 0.01-scaled standard-normal noise drawn from `rng`.
        noise = rng.standard_normal(size=(10, 10)) * 0.01
        return (np.eye(10) + noise).astype(config.floatX)

    # `ScalarOp`
    z = cosh(x**2 + y / 3.0)

    # `[Inc]Subtensor`
    sub = aet_subtensor.set_subtensor(z[0], -10.0)
    sub = aet_subtensor.inc_subtensor(sub[0, 1], 2.0)
    sub = sub[:5, :3]

    xy_values = [
        np.tile(np.arange(10), (10, 1)).astype(config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
    ]
    (jax_res,) = run_case([x, y], sub, xy_values)

    # The slice keeps the first 5 rows and the first 3 columns.
    assert jax_res.shape == (5, 3)
    # set_subtensor wrote -10 into row 0; inc_subtensor then added 2 at [0, 1].
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0

    run_case([x, y], clip(x, y, 5), xy_values)

    run_case([x], aet.diagonal(x, 0), [square_arange()])

    run_case([x], aet_slinalg.cholesky(x), [perturbed_identity()])

    # NOTE(review): an earlier comment said lower=False was "not working yet",
    # but this case is still exercised here. Confirm its current status.
    run_case([x], aet_slinalg.Cholesky(lower=False)(x), [perturbed_identity()])

    run_case(
        [x, b],
        aet_slinalg.solve(x, b),
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )

    run_case([b], aet.diag(b), [np.arange(10).astype(config.floatX)])

    run_case([x], aet_nlinalg.det(x), [square_arange()])

    run_case([x], aet_nlinalg.matrix_inverse(x), [perturbed_identity()])