コード例 #1
0
ファイル: linalg.py プロジェクト: gnecula/jax
def eigh(a,
         b=None,
         lower=True,
         eigvals_only=False,
         overwrite_a=False,
         overwrite_b=False,
         turbo=True,
         eigvals=None,
         type=1,
         check_finite=True):
    """Eigendecomposition of a matrix via ``lax_linalg.eigh``.

    The parameter list mirrors ``scipy.linalg.eigh``, but only the standard
    (non-generalized) problem is supported here. ``overwrite_a``,
    ``overwrite_b``, ``turbo`` and ``check_finite`` are accepted for
    signature compatibility and ignored.

    Args:
      a: input matrix; converted with ``jnp.asarray`` and dtype-promoted via
        ``np_linalg._promote_arg_dtypes`` before decomposition.
      b: must be ``None``.
      lower: forwarded unchanged to ``lax_linalg.eigh``.
      eigvals_only: when true, return only the eigenvalues.
      eigvals: must be ``None``.
      type: must be ``1``.

    Returns:
      The eigenvalues alone if ``eigvals_only``, otherwise the tuple
      ``(eigenvalues, eigenvectors)``.

    Raises:
      NotImplementedError: if ``b``, ``type`` or ``eigvals`` differ from the
        supported values. The options are checked in that order.
    """
    # SciPy performance/validation knobs have no meaning here.
    del overwrite_a, overwrite_b, turbo, check_finite

    # Unsupported parts of the SciPy interface, reported in a fixed order so
    # the first offending option determines the error message.
    unsupported = (
        (b is not None, "Only the b=None case of eigh is implemented"),
        (type != 1, "Only the type=1 case of eigh is implemented."),
        (eigvals is not None,
         "Only the eigvals=None case of eigh is implemented."),
    )
    for rejected, message in unsupported:
        if rejected:
            raise NotImplementedError(message)

    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    # lax_linalg.eigh yields (eigenvectors, eigenvalues); the public API
    # returns eigenvalues first.
    eigenvectors, eigenvalues = lax_linalg.eigh(a, lower=lower)
    return eigenvalues if eigvals_only else (eigenvalues, eigenvectors)
コード例 #2
0
ファイル: linalg.py プロジェクト: cloudhan/jax
def svd(a,
        full_matrices: bool = True,
        compute_uv: bool = True,
        hermitian: bool = False):
    """Singular value decomposition with an eigendecomposition-based path.

    General input is delegated to ``lax_linalg.svd``. When ``hermitian`` is
    true the result is derived from ``lax_linalg.eigh`` instead:

    * the singular values are the absolute eigenvalues in descending order;
    * ``u`` holds the eigenvectors permuted into that same order;
    * ``vh`` is ``_H`` applied to ``u`` with each column scaled by the sign
      of its eigenvalue (``_H`` is presumably the conjugate transpose;
      defined elsewhere in the file).

    Args:
      a: input array; converted with ``jnp.asarray`` and promoted to an
        inexact dtype.
      full_matrices: forwarded to ``lax_linalg.svd``; not consulted on the
        Hermitian path.
      compute_uv: when false, only singular values are produced.
      hermitian: take the ``eigh``-based path. Whether ``a`` really is
        Hermitian is not checked here.

    Returns:
      On the Hermitian path, ``(u, s, vh)`` if ``compute_uv`` else ``s``.
      Otherwise whatever ``lax_linalg.svd`` returns for the given flags.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    if not hermitian:
        return lax_linalg.svd(a,
                              full_matrices=full_matrices,
                              compute_uv=compute_uv)

    # lax_linalg.eigh yields (eigenvectors, eigenvalues).
    eigvecs, eigvals = lax_linalg.eigh(a)
    magnitudes = lax.abs(eigvals)
    last_axis = magnitudes.ndim - 1

    if not compute_uv:
        # Singular values only: |eigenvalues| sorted, then flipped to
        # descending order.
        return lax.rev(lax.sort(magnitudes, dimension=-1),
                       dimensions=[last_axis])

    signs = lax.sign(eigvals)
    positions = lax.broadcasted_iota(np.int64,
                                     magnitudes.shape,
                                     dimension=last_axis)
    # Sort ascending by magnitude only (num_keys=1), carrying the original
    # positions and signs along, then reverse all three to get descending
    # order.
    magnitudes, positions, signs = lax.sort((magnitudes, positions, signs),
                                            dimension=-1,
                                            num_keys=1)
    magnitudes, positions, signs = (
        lax.rev(arr, dimensions=[last_axis])
        for arr in (magnitudes, positions, signs))

    u = jnp.take_along_axis(eigvecs, positions[..., None, :], axis=-1)
    vh = _H(u * signs[..., None, :])
    return u, magnitudes, vh
コード例 #3
0
ファイル: linalg.py プロジェクト: ahoenselaar/jax
def eigh(a, UPLO=None, symmetrize_input=True):
    """Eigendecomposition with a NumPy-style ``UPLO`` triangle selector.

    Args:
      a: input matrix; converted with ``jnp.asarray`` and dtype-promoted via
        ``_promote_arg_dtypes``.
      UPLO: ``None`` or ``"L"`` maps to ``lower=True``; ``"U"`` maps to
        ``lower=False``. Any other value is rejected.
      symmetrize_input: forwarded unchanged to ``lax_linalg.eigh``.

    Returns:
      The tuple ``(eigenvalues, eigenvectors)``. This is the output of
      ``lax_linalg.eigh`` with its two elements swapped.

    Raises:
      ValueError: if ``UPLO`` is not one of ``None``, ``"L"`` or ``"U"``.
    """
    if UPLO == "U":
        lower = False
    elif UPLO is None or UPLO == "L":
        lower = True
    else:
        raise ValueError(
            "UPLO must be one of None, 'L', or 'U', got {}".format(UPLO))

    a = _promote_arg_dtypes(jnp.asarray(a))
    eigenvectors, eigenvalues = lax_linalg.eigh(
        a, lower=lower, symmetrize_input=symmetrize_input)
    return eigenvalues, eigenvectors
コード例 #4
0
ファイル: linalg.py プロジェクト: xueeinstein/jax
def _eigh(a, b, lower, eigvals_only, eigvals, type):
    if b is not None:
        raise NotImplementedError(
            "Only the b=None case of eigh is implemented")
    if type != 1:
        raise NotImplementedError(
            "Only the type=1 case of eigh is implemented.")
    if eigvals is not None:
        raise NotImplementedError(
            "Only the eigvals=None case of eigh is implemented.")

    a, = _promote_dtypes_inexact(jnp.asarray(a))
    v, w = lax_linalg.eigh(a, lower=lower)

    if eigvals_only:
        return w
    else:
        return w, v