def _solve_triangular(a, b, trans, lower, unit_diagonal):
  """Solves ``op(a) @ x = b`` for ``x`` where ``a`` is triangular.

  Args:
    a: triangular coefficient matrix.
    b: right-hand side. Either a matrix, or a vector with exactly one fewer
      dimension than ``a``.
    trans: operator applied to ``a``. 0 or "N" means as-is, 1 or "T" means
      transpose, and 2 or "C" means conjugate transpose.
    lower: forwarded to ``lax_linalg.triangular_solve`` to select the lower
      (True) or upper (False) triangle of ``a``.
    unit_diagonal: forwarded to ``lax_linalg.triangular_solve``.

  Returns:
    The solution ``x``, with the same number of dimensions as ``b``.

  Raises:
    ValueError: if ``trans`` is not one of the accepted values.
  """
  if trans in (0, "N"):
    transpose_a, conjugate_a = False, False
  elif trans in (1, "T"):
    transpose_a, conjugate_a = True, False
  elif trans in (2, "C"):
    transpose_a, conjugate_a = True, True
  else:
    raise ValueError(f"Invalid 'trans' value {trans}")

  a, b = np_linalg._promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))

  # lax_linalg.triangular_solve only accepts a matrix right-hand side, so a
  # vector ``b`` temporarily gets a trailing unit dimension, which is
  # stripped off again after the solve.
  vector_rhs = jnp.ndim(b) + 1 == jnp.ndim(a)
  rhs = b[..., None] if vector_rhs else b
  solution = lax_linalg.triangular_solve(
      a, rhs, left_side=True, lower=lower, transpose_a=transpose_a,
      conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)
  return solution[..., 0] if vector_rhs else solution
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
         overwrite_b=False, turbo=True, eigvals=None, type=1,
         check_finite=True):
  """Eigendecomposition of ``a`` computed with ``lax_linalg.eigh``.

  The signature follows ``scipy.linalg.eigh``, but only the standard problem
  (``b=None``, ``type=1``, ``eigvals=None``) is supported. ``overwrite_a``,
  ``overwrite_b``, ``turbo`` and ``check_finite`` are accepted and ignored.

  Args:
    a: input matrix.
    lower: forwarded to ``lax_linalg.eigh`` to choose which triangle of
      ``a`` is used.
    eigvals_only: if true, return only the eigenvalues.

  Returns:
    ``w`` when ``eigvals_only`` is true, otherwise ``(w, v)``. ``w`` and
    ``v`` are the eigenvalues and eigenvectors from ``lax_linalg.eigh``.

  Raises:
    NotImplementedError: if ``b`` is given, ``type != 1``, or ``eigvals``
      is given.
  """
  del overwrite_a, overwrite_b, turbo, check_finite  # accepted but unused

  # These checks are evaluated in this fixed order: b, then type, then eigvals.
  if b is not None:
    raise NotImplementedError(
        "Only the b=None case of eigh is implemented")
  if type != 1:
    raise NotImplementedError(
        "Only the type=1 case of eigh is implemented.")
  if eigvals is not None:
    raise NotImplementedError(
        "Only the eigvals=None case of eigh is implemented.")

  promoted = np_linalg._promote_arg_dtypes(jnp.asarray(a))
  vecs, vals = lax_linalg.eigh(promoted, lower=lower)
  return vals if eigvals_only else (vals, vecs)
def _cho_solve(c, b, lower):
  """Solves ``A x = b`` given a Cholesky factor ``c`` of ``A``.

  The solve is done as two triangular solves against ``c``:

  * ``lower=True``: ``A = c @ c^H``. Solve ``c y = b``, then ``c^H x = y``.
  * ``lower=False``: ``A = c^H @ c``. Solve ``c^H y = b``, then ``c x = y``.

  Args:
    c: triangular Cholesky factor.
    b: right-hand side.
    lower: whether ``c`` is the lower (True) or upper (False) factor.

  Returns:
    The solution ``x``.
  """
  c, b = np_linalg._promote_arg_dtypes(jnp.asarray(c), jnp.asarray(b))
  lax_linalg._check_solve_shapes(c, b)

  def _tri_solve(rhs, adjoint):
    # ``adjoint`` toggles transpose and conjugation together, i.e. solves
    # against c^H instead of c.
    return lax_linalg.triangular_solve(
        c, rhs, left_side=True, lower=lower,
        transpose_a=adjoint, conjugate_a=adjoint)

  intermediate = _tri_solve(b, adjoint=not lower)
  return _tri_solve(intermediate, adjoint=lower)
def svd(a, full_matrices=True, compute_uv=True, overwrite_a=False,
        check_finite=True, lapack_driver='gesdd'):
  """Singular value decomposition of ``a`` via ``lax_linalg.svd``.

  ``overwrite_a``, ``check_finite`` and ``lapack_driver`` are accepted and
  ignored.

  Args:
    a: input matrix.
    full_matrices: forwarded to ``lax_linalg.svd``.
    compute_uv: forwarded to ``lax_linalg.svd``.

  Returns:
    Whatever ``lax_linalg.svd`` returns for these options.
  """
  del overwrite_a, check_finite, lapack_driver  # accepted but unused
  promoted = np_linalg._promote_arg_dtypes(jnp.asarray(a))
  return lax_linalg.svd(promoted, full_matrices=full_matrices,
                        compute_uv=compute_uv)
def _lu(a, permute_l):
  """LU decomposition of a 2-D ``a``, returned as separate P, L and U factors.

  Args:
    a: input matrix of shape ``(m, n)``.
    permute_l: if true, return ``(p @ l, u)`` instead of ``(p, l, u)``.

  Returns:
    ``(p, l, u)``, or ``(p @ l, u)`` when ``permute_l`` is true.

    * ``p`` is an ``(m, m)`` permutation matrix.
    * ``l`` is ``(m, k)`` and unit lower triangular.
    * ``u`` is ``(k, n)`` and upper triangular.

    Here ``k = min(m, n)``.
  """
  a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
  lu, _, permutation = lax_linalg.lu(a)  # pivot indices are not needed here
  dtype = lax.dtype(a)
  m, n = jnp.shape(a)
  k = min(m, n)

  # Broadcasting gives p[i, j] == 1 exactly when permutation[j] == i.
  # jnp.real drops the (all-zero) imaginary part, so p is real even when
  # ``a`` is complex.
  row_ids = jnp.arange(m)[:, None]
  p = jnp.real(jnp.array(permutation == row_ids, dtype=dtype))

  # ``lu`` packs both factors. L is its strictly-lower part plus an identity
  # diagonal, and U is its upper part.
  l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
  u = jnp.triu(lu)[:k, :]
  return (jnp.matmul(p, l), u) if permute_l else (p, l, u)
def _qr(a, mode, pivoting):
  """QR decomposition of ``a`` using ``lax_linalg.qr``.

  Args:
    a: input matrix.
    mode: one of the following.

      * "full": complete Q and R.
      * "r": only R, computed as for "full".
      * "economic": reduced Q and R.

    pivoting: must be falsy, because column pivoting is not supported.

  Returns:
    ``(q, r)``, or only ``r`` when ``mode == "r"``.

  Raises:
    NotImplementedError: if ``pivoting`` is truthy.
    ValueError: if ``mode`` is not one of the supported values.
  """
  if pivoting:
    raise NotImplementedError(
        "The pivoting=True case of qr is not implemented.")
  if mode not in ("full", "r", "economic"):
    raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
  # Only "economic" asks for reduced factors; "r" reuses the full computation.
  full_matrices = mode != "economic"

  a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
  q, r = lax_linalg.qr(a, full_matrices=full_matrices)
  return r if mode == "r" else (q, r)
def _eigh(a, b, lower, eigvals_only, eigvals, type):
  """Core of ``eigh``: validates options, then calls ``lax_linalg.eigh``.

  Args:
    a: input matrix.
    b: must be None, because generalized problems are unsupported.
    lower: forwarded to ``lax_linalg.eigh`` to choose which triangle of
      ``a`` is used.
    eigvals_only: if true, return only the eigenvalues.
    eigvals: must be None, because eigenvalue subsets are unsupported.
    type: must be 1.

  Returns:
    ``w`` when ``eigvals_only`` is true, otherwise ``(w, v)``. ``w`` and
    ``v`` are the eigenvalues and eigenvectors from ``lax_linalg.eigh``.

  Raises:
    NotImplementedError: for any unsupported option. Options are checked in
      the order ``b``, ``type``, ``eigvals``.
  """
  if b is not None:
    raise NotImplementedError(
        "Only the b=None case of eigh is implemented")
  if type != 1:
    raise NotImplementedError(
        "Only the type=1 case of eigh is implemented.")
  if eigvals is not None:
    raise NotImplementedError(
        "Only the eigvals=None case of eigh is implemented.")

  matrix = np_linalg._promote_arg_dtypes(jnp.asarray(a))
  eigenvectors, eigenvalues = lax_linalg.eigh(matrix, lower=lower)
  if eigvals_only:
    return eigenvalues
  return eigenvalues, eigenvectors
def _solve(a, b, sym_pos, lower):
  """Solves ``a @ x = b``.

  If ``sym_pos`` is false, this defers to ``np_linalg.solve``. Otherwise it
  Cholesky-factors ``a`` (presumably symmetric/Hermitian positive definite,
  which the caller must ensure) and solves through
  ``lax.custom_linear_solve``.

  Args:
    a: coefficient matrix.
    b: right-hand side. Its shape is ``[..., m]`` when it has one fewer
      dimension than ``a``, otherwise ``[..., m, k]``.
    sym_pos: whether to use the Cholesky-based path.
    lower: forwarded to ``cho_factor``.

  Returns:
    The solution ``x``.
  """
  if not sym_pos:
    return np_linalg.solve(a, b)

  a, b = np_linalg._promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))
  lax_linalg._check_solve_shapes(a, b)

  # With custom_linear_solve the same factorization is reused when computing
  # sensitivities, which is considerably faster. The factorization itself is
  # taken outside of differentiation via stop_gradient.
  factors = cho_factor(lax.stop_gradient(a), lower=lower)

  def matvec(x):
    # Uses the un-stopped ``a`` so gradients with respect to ``a`` still
    # flow through custom_linear_solve.
    return lax_linalg._matvec_multiply(a, x)

  def solve_with_factors(_, x):
    return cho_solve(factors, x)

  custom_solve = partial(lax.custom_linear_solve, matvec,
                         solve=solve_with_factors, symmetric=True)

  if b.ndim + 1 == a.ndim:
    # b.shape == [..., m]
    return custom_solve(b)
  # b.shape == [..., m, k]: map over b's last axis, and place the mapped
  # axis at position max(a.ndim, b.ndim) - 1 of the output.
  return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
def _cholesky(a, lower):
  """Cholesky factor of ``a``, returning the lower or upper factor.

  ``lax_linalg.cholesky`` is called with ``symmetrize_input=False``. For the
  upper factor, it is applied to ``conj(_T(a))`` and its result is mapped
  back with ``conj(_T(.))``. ``_T`` is presumably a transpose of the last two
  axes; confirm against its definition.

  Args:
    a: input matrix.
    lower: return the lower factor if true, otherwise the upper factor.

  Returns:
    The requested triangular factor.
  """
  a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
  if lower:
    return lax_linalg.cholesky(a, symmetrize_input=False)
  lower_factor = lax_linalg.cholesky(jnp.conj(_T(a)), symmetrize_input=False)
  return jnp.conj(_T(lower_factor))
def lu_factor(a, overwrite_a=False, check_finite=True):
  """Packed LU factorization of ``a`` via ``lax_linalg.lu``.

  ``overwrite_a`` and ``check_finite`` are accepted and ignored.

  Args:
    a: input matrix.

  Returns:
    ``(lu, pivots)``: the packed L/U factors and the pivot array from
    ``lax_linalg.lu``. Its third output, the permutation, is discarded.
  """
  del overwrite_a, check_finite  # accepted but unused
  promoted = np_linalg._promote_arg_dtypes(jnp.asarray(a))
  lu, pivots, _permutation = lax_linalg.lu(promoted)
  return lu, pivots
def _svd(a, *, full_matrices, compute_uv):
  """Singular value decomposition of ``a`` after dtype promotion.

  Args:
    a: input matrix.
    full_matrices: forwarded to ``lax_linalg.svd``.
    compute_uv: forwarded to ``lax_linalg.svd``.

  Returns:
    Whatever ``lax_linalg.svd`` returns for these options.
  """
  promoted = np_linalg._promote_arg_dtypes(jnp.asarray(a))
  return lax_linalg.svd(promoted, full_matrices=full_matrices,
                        compute_uv=compute_uv)