def tensorsolve(a, b, axes=None):
  """Solve the tensor equation ``tensordot(a, x, axes=x.ndim) == b`` for ``x``.

  Args:
    a: coefficient tensor. After the optional ``axes`` transposition, its
      leading ``b.ndim`` dimensions index the equations and its trailing
      dimensions give the shape of the unknown ``x``.
    b: right-hand-side tensor.
    axes: optional sequence of (non-negative) axes of ``a`` that are moved,
      in the given order, to the end of ``a`` before solving.

  Returns:
    ``x`` with shape ``a.shape[b.ndim:]`` (taken after the ``axes``
    transposition).

  Raises:
    ValueError: if an entry of ``axes`` is not an axis index of ``a`` (from
      ``list.remove``), or if ``a`` cannot be flattened into a square
      ``prod x prod`` matrix, where ``prod`` is the product of
      ``a.shape[b.ndim:]``.
  """
  a = jnp.asarray(a)
  b = jnp.asarray(b)
  an = a.ndim

  if axes is not None:
    # Move each listed axis to the end, preserving the order given.
    allaxes = list(range(an))
    for k in axes:
      allaxes.remove(k)
      allaxes.append(k)
    a = a.transpose(allaxes)

  # The unknown takes the trailing dimensions of ``a``. Slicing from
  # ``b.ndim`` yields ``()`` when ``a.ndim == b.ndim``; the former
  # ``a.shape[-(an - b.ndim):]`` degenerated to ``a.shape[0:]`` there.
  out_shape = a.shape[b.ndim:]
  prod = 1
  for k in out_shape:
    prod *= k

  # Flattened, ``a`` must be a square matrix; ``reshape(-1, prod)`` would
  # otherwise silently produce a non-square system.
  if a.size != prod * prod:
    raise ValueError(
        "tensorsolve: prod(a.shape[b.ndim:]) must equal prod(a.shape[:b.ndim]) "
        f"(got a.shape={a.shape}, b.ndim={b.ndim}).")

  res = jnp.asarray(la.solve(a.reshape(prod, prod), b.ravel()))
  return res.reshape(out_shape)
def _solve(a, b, sym_pos, lower):
  """Solve ``a @ x = b``, using a Cholesky factorization when ``sym_pos``.

  Args:
    a: coefficient matrix (possibly batched).
    b: right-hand side, either a vector per system (``a.ndim == b.ndim + 1``)
      or a matrix of column vectors per system.
    sym_pos: if false, defer entirely to ``np_linalg.solve``; if true, treat
      ``a`` as symmetric positive definite and solve via ``cho_factor`` /
      ``cho_solve``.
    lower: forwarded to ``cho_factor`` to choose which triangle is used.

  Returns:
    The solution ``x``.
  """
  # Non-SPD case: nothing special to do.
  if not sym_pos:
    return np_linalg.solve(a, b)

  mat, rhs = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
  lax_linalg._check_solve_shapes(mat, rhs)

  # The factorization is computed on a stop_gradient'ed copy and then reused
  # by custom_linear_solve for the sensitivity solves as well, which is much
  # cheaper than differentiating through the factorization itself.
  chol = cho_factor(lax.stop_gradient(mat), lower=lower)

  def matvec(vec):
    return lax_linalg._matvec_multiply(mat, vec)

  def solve_with_factors(_, vec):
    return cho_solve(chol, vec)

  solve_one = partial(lax.custom_linear_solve, matvec,
                      solve=solve_with_factors, symmetric=True)

  # rhs.shape == [..., m]: a single right-hand side per system.
  if mat.ndim == rhs.ndim + 1:
    return solve_one(rhs)

  # rhs.shape == [..., m, k]: map the vector solve over the trailing k axis.
  in_axis = rhs.ndim - 1
  out_axis = max(mat.ndim, rhs.ndim) - 1
  return vmap(solve_one, in_axis, out_axis)(rhs)
def _solve_P_Q(P, Q, upper_triangular=False):
  """Return ``X`` solving ``Q @ X = P``.

  When ``upper_triangular`` is true, ``Q`` is passed to ``solve_triangular``
  with its default arguments; otherwise the general ``np_linalg.solve`` is
  used.
  """
  solver = solve_triangular if upper_triangular else np_linalg.solve
  return solver(Q, P)