def solve(a, b):
    """Solve the linear system defined by ``a`` and ``b`` via an LU factorization.

    Both arguments are converted with ``jnp.asarray``, passed through
    ``_promote_arg_dtypes`` and validated by ``_check_solve_shapes`` before
    any work is done.

    Args:
      a: coefficient array.
      b: right-hand side. When ``a.ndim == b.ndim + 1`` it is treated as
        ``[..., m]`` (vectors); otherwise as ``[..., m, k]`` (matrices whose
        last axis indexes independent right-hand sides).

    Returns:
      The result of ``lax.custom_linear_solve``; in the matrix case the solve
      is vmapped over the last axis of ``b`` and the mapped axis is placed at
      position ``max(a.ndim, b.ndim) - 1`` of the output.
    """
    a, b = _promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))
    _check_solve_shapes(a, b)

    # Factor once, on a gradient-stopped copy of `a`. Routing the solve
    # through custom_linear_solve lets the same factorization be reused when
    # computing sensitivities, which is considerably faster.
    factors, _, perm = lax_linalg.lu(lax.stop_gradient(a))

    def _apply_a(x):
        return _matvec_multiply(a, x)

    def _forward_solve(_, rhs):
        return lax_linalg.lu_solve(factors, perm, rhs, trans=0)

    def _transposed_solve(_, rhs):
        return lax_linalg.lu_solve(factors, perm, rhs, trans=1)

    def _solve_single(rhs):
        return lax.custom_linear_solve(
            _apply_a,
            rhs,
            solve=_forward_solve,
            transpose_solve=_transposed_solve,
        )

    if a.ndim == b.ndim + 1:
        # b.shape == [..., m]
        return _solve_single(b)

    # b.shape == [..., m, k]: map over the trailing axis of b.
    mapped_in_axis = b.ndim - 1
    mapped_out_axis = max(a.ndim, b.ndim) - 1
    return vmap(_solve_single, mapped_in_axis, mapped_out_axis)(b)
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve a linear system from a precomputed ``(lu, pivots)`` pair.

    Args:
      lu_and_piv: a two-element sequence ``(lu, pivots)``.
      b: right-hand side, forwarded unchanged to ``lax_linalg.lu_solve``.
      trans: forwarded unchanged to ``lax_linalg.lu_solve``.
      overwrite_b: accepted but ignored.
      check_finite: accepted but ignored.

    Returns:
      Whatever ``lax_linalg.lu_solve`` returns for the given factors,
      permutation, ``b`` and ``trans``.
    """
    # These flags are part of the signature but have no effect here.
    del overwrite_b, check_finite

    lu_matrix, pivots = lu_and_piv
    # Only the row count of the trailing 2-D block is needed; the two-element
    # unpack is kept so inputs with fewer than two dimensions fail as before.
    num_rows, _ = lu_matrix.shape[-2:]
    # Convert the pivot vector into the permutation that the lower-level
    # solver takes as its second argument.
    permutation = lax_linalg.lu_pivots_to_permutation(pivots, num_rows)
    return lax_linalg.lu_solve(lu_matrix, permutation, b, trans)