Example #1
import jax.numpy as jnp
from jax.numpy import linalg as la


def tensorsolve(a, b, axes=None):
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    an = a.ndim
    if axes is not None:
        # Move the axes listed in `axes` to the end so they become the
        # "solution" dimensions of the flattened system.
        allaxes = list(range(0, an))
        for k in axes:
            allaxes.remove(k)
            allaxes.insert(an, k)

        a = a.transpose(allaxes)

    # Trailing shape of `a` not matched by `b`: the shape of the unknown.
    Q = a.shape[-(an - b.ndim):]

    prod = 1
    for k in Q:
        prod *= k

    # Flatten to an equivalent 2-D linear system and solve it.
    a = a.reshape(-1, prod)
    b = b.ravel()

    res = jnp.asarray(la.solve(a, b))
    res = res.reshape(Q)

    return res
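A quick sanity check of the flatten-and-solve logic above; a minimal sketch assuming the imports shown with the function. With the identity tensor as coefficients, the solution must equal b:

import jax.numpy as jnp

# a has shape (2, 3, 2, 3): the leading (2, 3) axes match b, the trailing
# (2, 3) axes describe the unknown tensor x in the system a . x = b.
a = jnp.eye(2 * 3).reshape(2, 3, 2, 3)
b = jnp.arange(6.0).reshape(2, 3)
x = tensorsolve(a, b)
print(x.shape)                 # (2, 3)
print(jnp.allclose(x, b))      # True: identity coefficients return b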
Example #2
from functools import partial

import jax.numpy as jnp
from jax import lax, vmap
# Internal JAX modules; the exact paths may differ across JAX versions.
from jax._src.lax import linalg as lax_linalg
from jax._src.numpy import linalg as np_linalg
from jax._src.numpy.util import _promote_dtypes_inexact
from jax.scipy.linalg import cho_factor, cho_solve


def _solve(a, b, sym_pos, lower):
    if not sym_pos:
        # General matrices: fall back to the LU-based solver.
        return np_linalg.solve(a, b)

    a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
    lax_linalg._check_solve_shapes(a, b)

    # With custom_linear_solve, we can reuse the same Cholesky factorization
    # when computing sensitivities. This is considerably faster.
    factors = cho_factor(lax.stop_gradient(a), lower=lower)
    custom_solve = partial(lax.custom_linear_solve,
                           lambda x: lax_linalg._matvec_multiply(a, x),
                           solve=lambda _, x: cho_solve(factors, x),
                           symmetric=True)
    if a.ndim == b.ndim + 1:
        # b.shape == [..., m]: a single right-hand-side vector.
        return custom_solve(b)
    else:
        # b.shape == [..., m, k]: map the solve over the k columns of b.
        return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
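A minimal usage sketch, assuming the internal imports above resolve on your JAX version (the public entry point is jax.scipy.linalg.solve; calling _solve directly here just illustrates the Cholesky-backed path). A small symmetric positive-definite system with a vector right-hand side:

import jax.numpy as jnp

a = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])       # symmetric positive definite
b = jnp.array([1.0, 2.0])
x = _solve(a, b, sym_pos=True, lower=False)
print(jnp.allclose(a @ x, b))     # True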
Example #3
from jax.numpy import linalg as np_linalg
from jax.scipy.linalg import solve_triangular


def _solve_P_Q(P, Q, upper_triangular=False):
    # Solve Q @ X = P, exploiting triangular structure when available.
    if upper_triangular:
        return solve_triangular(Q, P)
    else:
        return np_linalg.solve(Q, P)
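A minimal sketch of the triangular path, assuming the imports above. Q is chosen upper triangular, so solve_triangular (which defaults to lower=False, i.e. upper triangular) applies:

import jax.numpy as jnp

P = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
Q = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])       # upper triangular
X = _solve_P_Q(P, Q, upper_triangular=True)
print(jnp.allclose(Q @ X, P))     # True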