def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
         overwrite_b=False, turbo=True, eigvals=None, type=1,
         check_finite=True):
    """Decompose ``a`` via ``lax_linalg.eigh`` behind a wide keyword signature.

    The parameter list appears to mirror ``scipy.linalg.eigh``; only a subset
    of it is actually supported here.

    Args:
      a: Input converted with ``jnp.asarray`` and then dtype-promoted.
      b: Must be ``None``.
      lower: Forwarded unchanged to ``lax_linalg.eigh``.
      eigvals_only: If true, return only ``w`` instead of ``(w, v)``.
      overwrite_a, overwrite_b, turbo, check_finite: Accepted and ignored.
      eigvals: Must be ``None``.
      type: Must be ``1``.

    Returns:
      ``w`` when ``eigvals_only`` is true, otherwise the tuple ``(w, v)``,
      where ``v, w`` are the first and second outputs of ``lax_linalg.eigh``
      (presumably eigenvectors and eigenvalues respectively).

    Raises:
      NotImplementedError: If ``b`` is not ``None``, ``type`` is not ``1`` or
        ``eigvals`` is not ``None`` (checked in that order).
    """
    # These options exist only for signature compatibility; drop them so they
    # cannot be used by accident further down.
    del overwrite_a, overwrite_b, turbo, check_finite

    # Reject every unsupported configuration before touching the input.
    if b is not None:
        raise NotImplementedError(
            "Only the b=None case of eigh is implemented")
    if type != 1:
        raise NotImplementedError(
            "Only the type=1 case of eigh is implemented.")
    if eigvals is not None:
        raise NotImplementedError(
            "Only the eigvals=None case of eigh is implemented.")

    operand = jnp.asarray(a)
    operand = np_linalg._promote_arg_dtypes(operand)
    vectors, values = lax_linalg.eigh(operand, lower=lower)
    return values if eigvals_only else (values, vectors)
def svd(a, full_matrices: bool = True, compute_uv: bool = True,
        hermitian: bool = False):
    """Singular value decomposition, with a shortcut for Hermitian input.

    Args:
      a: Input converted with ``jnp.asarray`` and promoted to an inexact dtype.
      full_matrices: Forwarded to ``lax_linalg.svd``; not consulted when
        ``hermitian`` is true.
      compute_uv: If true return ``(u, s, vh)``; otherwise only ``s``.
      hermitian: If true, derive the result from ``lax_linalg.eigh`` instead
        of calling ``lax_linalg.svd``.

    Returns:
      Whatever ``lax_linalg.svd`` returns in the general case. In the
      Hermitian case: ``s`` (absolute values of the second ``eigh`` output,
      sorted in descending order along the last axis), or ``(u, s, vh)`` when
      ``compute_uv`` is true.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))

    if not hermitian:
        return lax_linalg.svd(a, full_matrices=full_matrices,
                              compute_uv=compute_uv)

    # The second eigh output is used as the spectrum and the first as the
    # matching basis (presumably eigenvalues / eigenvectors).
    eigvecs, eigvals = lax_linalg.eigh(a)
    magnitudes = lax.abs(eigvals)
    last_axis = magnitudes.ndim - 1

    if not compute_uv:
        ascending = lax.sort(magnitudes, dimension=-1)
        return lax.rev(ascending, dimensions=[last_axis])

    signs = lax.sign(eigvals)
    order = lax.broadcasted_iota(np.int64, magnitudes.shape,
                                 dimension=last_axis)
    # Sort by magnitude only (num_keys=1); the positions and signs ride along
    # so the basis can be permuted consistently afterwards.
    magnitudes, order, signs = lax.sort((magnitudes, order, signs),
                                        dimension=-1, num_keys=1)
    # lax.sort is ascending; flip everything to get descending order.
    magnitudes = lax.rev(magnitudes, dimensions=[last_axis])
    order = lax.rev(order, dimensions=[last_axis])
    signs = lax.rev(signs, dimensions=[last_axis])

    # Reorder the columns of the basis to follow the sorted magnitudes.
    u = jnp.take_along_axis(eigvecs, order[..., None, :], axis=-1)
    # Fold the sign of each original value into the right factor; _H is
    # presumably the conjugate transpose -- confirm against its definition.
    vh = _H(u * signs[..., None, :])
    return u, magnitudes, vh
def eigh(a, UPLO=None, symmetrize_input=True):
    """Decompose ``a`` via ``lax_linalg.eigh`` and return ``(w, v)``.

    Args:
      a: Input converted with ``jnp.asarray`` and then dtype-promoted.
      UPLO: Which triangle of ``a`` to use. ``None`` or ``'L'`` selects the
        lower triangle, ``'U'`` the upper one. String values are matched
        case-insensitively (``'l'`` and ``'u'`` are accepted), as NumPy does.
      symmetrize_input: Forwarded unchanged to ``lax_linalg.eigh``.

    Returns:
      The tuple ``(w, v)``, where ``v, w`` are the first and second outputs of
      ``lax_linalg.eigh`` (presumably eigenvectors and eigenvalues).

    Raises:
      ValueError: If ``UPLO`` is not ``None``, ``'L'`` or ``'U'`` (in either
        case).
    """
    # NumPy upper-cases UPLO before validating it; do the same for strings so
    # that code written against NumPy with 'l'/'u' keeps working. Non-string
    # values are left untouched and fall through to the error below.
    triangle = UPLO.upper() if isinstance(UPLO, str) else UPLO

    if triangle is None or triangle == "L":
        lower = True
    elif triangle == "U":
        lower = False
    else:
        # Report the value exactly as the caller passed it.
        msg = "UPLO must be one of None, 'L', or 'U', got {}".format(UPLO)
        raise ValueError(msg)

    a = _promote_arg_dtypes(jnp.asarray(a))
    v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
    return w, v
def _eigh(a, b, lower, eigvals_only, eigvals, type): if b is not None: raise NotImplementedError( "Only the b=None case of eigh is implemented") if type != 1: raise NotImplementedError( "Only the type=1 case of eigh is implemented.") if eigvals is not None: raise NotImplementedError( "Only the eigvals=None case of eigh is implemented.") a, = _promote_dtypes_inexact(jnp.asarray(a)) v, w = lax_linalg.eigh(a, lower=lower) if eigvals_only: return w else: return w, v