Esempio n. 1
0
def solve(a, b):
    """Solve ``a x = b`` for ``x``, where the product is ``_matvec_multiply(a, x)``.

    Both operands are converted with ``jnp.asarray`` and promoted to a common
    dtype, then checked by ``_check_solve_shapes``. ``a`` is LU-factorized a
    single time (under ``lax.stop_gradient``). The factors are reused by
    ``lax.custom_linear_solve`` for the forward solve (``trans=0``) and for the
    transpose solve (``trans=1``).

    When ``b`` has exactly one dimension fewer than ``a`` it is solved
    directly. Otherwise the solve is vmapped over the last axis of ``b``.
    """
    a, b = _promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))
    _check_solve_shapes(a, b)

    # Factor once, outside the differentiated computation. Reusing the same
    # LU factors when computing sensitivities is considerably faster than
    # refactoring.
    lu, _, permutation = lax_linalg.lu(lax.stop_gradient(a))

    def matvec(x):
        # Operator for custom_linear_solve. Note it uses the original ``a``,
        # not the stop_gradient'ed copy that was factorized above.
        return _matvec_multiply(a, x)

    def lu_forward(unused_matvec, rhs):
        return lax_linalg.lu_solve(lu, permutation, rhs, trans=0)

    def lu_transposed(unused_matvec, rhs):
        return lax_linalg.lu_solve(lu, permutation, rhs, trans=1)

    def custom_solve(rhs):
        return lax.custom_linear_solve(
            matvec, rhs, solve=lu_forward, transpose_solve=lu_transposed)

    vector_rhs = a.ndim == b.ndim + 1
    if vector_rhs:
        # b.shape == [..., m]
        return custom_solve(b)
    # b.shape == [..., m, k]: map over b's trailing axis and place the mapped
    # axis last in the output.
    out_axis = max(a.ndim, b.ndim) - 1
    return vmap(custom_solve, b.ndim - 1, out_axis)(b)
Esempio n. 2
0
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve a linear system from a precomputed LU factorization.

    Args:
      lu_and_piv: pair ``(lu, pivots)``. ``lu`` must have at least two
        dimensions; its second-to-last dimension ``m`` sizes the permutation
        built from ``pivots``. (Presumably the output of an ``lu_factor``-style
        routine, as in SciPy -- verify against callers.)
      b: right-hand side, passed unchanged to ``lax_linalg.lu_solve``.
      trans: passed unchanged to ``lax_linalg.lu_solve`` as its ``trans``
        argument (presumably 0 solves ``a x = b`` and 1 the transposed
        system, as in SciPy -- confirm against lax_linalg).
      overwrite_b: accepted but ignored.
      check_finite: accepted but ignored.

    Returns:
      The result of ``lax_linalg.lu_solve(lu, perm, b, trans)``.

    Raises:
      ValueError: if ``lu`` has fewer than two dimensions.
    """
    # These options have no effect here; discard them explicitly.
    del overwrite_b, check_finite
    lu, pivots = lu_and_piv
    # Check rank explicitly so callers get a clear message rather than a
    # tuple-unpacking failure; only the row count ``m`` is needed below.
    if len(lu.shape) < 2:
        raise ValueError(
            "lu_solve: lu must have at least 2 dimensions, got shape {}"
            .format(lu.shape))
    m = lu.shape[-2]
    # Expand the pivot sequence into an explicit permutation of length m,
    # which is what lax_linalg.lu_solve takes.
    perm = lax_linalg.lu_pivots_to_permutation(pivots, m)
    return lax_linalg.lu_solve(lu, perm, b, trans)