# Example #1
0
def _solve_triangular(a, b, trans, lower, unit_diagonal):
    """Solve ``op(a) @ x = b`` for ``x``, where ``a`` is triangular.

    Args:
      a: triangular coefficient matrix (or stack of matrices).
      b: right-hand side. Either it has the same rank as ``a`` (matrix
        right-hand side) or one rank less (vector right-hand side).
      trans: operator applied to ``a``: ``0``/``"N"`` for none, ``1``/``"T"``
        for transpose, ``2``/``"C"`` for conjugate transpose.
      lower: whether ``a`` is lower triangular.
      unit_diagonal: forwarded to ``lax_linalg.triangular_solve`` as
        ``unit_diagonal``.

    Returns:
      The solution ``x``, with the same rank as ``b``.

    Raises:
      ValueError: if ``trans`` is not one of the accepted spellings.
    """
    # Accepted spellings of ``trans`` -> (transpose_a, conjugate_a) flags.
    trans_table = (
        ((0, "N"), (False, False)),
        ((1, "T"), (True, False)),
        ((2, "C"), (True, True)),
    )
    for spellings, flags in trans_table:
        if any(trans == code for code in spellings):
            transpose_a, conjugate_a = flags
            break
    else:
        raise ValueError("Invalid 'trans' value {}".format(trans))

    a, b = np_linalg._promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))

    # triangular_solve only handles matrix right-hand sides. A vector ``b`` is
    # therefore lifted to a one-column matrix and the column is dropped again
    # at the end.
    vector_rhs = jnp.ndim(b) + 1 == jnp.ndim(a)
    rhs = b[..., None] if vector_rhs else b
    solution = lax_linalg.triangular_solve(
        a,
        rhs,
        left_side=True,
        lower=lower,
        transpose_a=transpose_a,
        conjugate_a=conjugate_a,
        unit_diagonal=unit_diagonal,
    )
    return solution[..., 0] if vector_rhs else solution
# Example #2
0
def eigh(a,
         b=None,
         lower=True,
         eigvals_only=False,
         overwrite_a=False,
         overwrite_b=False,
         turbo=True,
         eigvals=None,
         type=1,
         check_finite=True):
    """Eigen-decomposition of ``a`` via ``lax_linalg.eigh``, SciPy-style.

    Only the standard problem is supported: ``b`` must be ``None``, ``type``
    must be ``1`` and ``eigvals`` must be ``None``. (``type`` shadows the
    builtin, but it is part of the SciPy-compatible signature.)

    Args:
      a: input matrix.
      lower: forwarded to ``lax_linalg.eigh`` as ``lower``.
      eigvals_only: if true, return only the eigenvalues.
      overwrite_a, overwrite_b, turbo, check_finite: accepted for signature
        compatibility and ignored.

    Returns:
      ``w`` if ``eigvals_only``, otherwise ``(w, v)``, where ``w`` holds the
      eigenvalues and ``v`` the eigenvectors.

    Raises:
      NotImplementedError: for any unsupported ``b``, ``type`` or ``eigvals``.
    """
    del overwrite_a, overwrite_b, turbo, check_finite  # no-ops in this backend

    if b is not None:
        raise NotImplementedError("Only the b=None case of eigh is implemented")
    if type != 1:
        raise NotImplementedError("Only the type=1 case of eigh is implemented.")
    if eigvals is not None:
        raise NotImplementedError(
            "Only the eigvals=None case of eigh is implemented.")

    matrix = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    # lax_linalg.eigh yields (vectors, values). SciPy orders them the other
    # way round.
    vectors, values = lax_linalg.eigh(matrix, lower=lower)
    return values if eigvals_only else (values, vectors)
# Example #3
0
def _cho_solve(c, b, lower):
  """Solve a linear system given its Cholesky factor ``c``.

  The solve is done as two triangular solves against ``c``. The pass that
  uses the conjugate transpose of ``c`` comes second when ``lower`` is true
  and first otherwise.

  Args:
    c: triangular Cholesky factor.
    b: right-hand side.
    lower: whether ``c`` is stored as a lower-triangular factor.

  Returns:
    The solution of the system.
  """
  c, b = np_linalg._promote_arg_dtypes(jnp.asarray(c), jnp.asarray(b))
  lax_linalg._check_solve_shapes(c, b)
  # Each pass applies either c itself (adjoint falsy) or its conjugate
  # transpose (adjoint truthy).
  for adjoint in (not lower, lower):
    b = lax_linalg.triangular_solve(c, b, left_side=True, lower=lower,
                                    transpose_a=adjoint, conjugate_a=adjoint)
  return b
# Example #4
0
def svd(a,
        full_matrices=True,
        compute_uv=True,
        overwrite_a=False,
        check_finite=True,
        lapack_driver='gesdd'):
    """Singular value decomposition with a ``scipy.linalg.svd``-style signature.

    Args:
      a: input matrix.
      full_matrices: forwarded to ``lax_linalg.svd``.
      compute_uv: forwarded to ``lax_linalg.svd``.
      overwrite_a, check_finite, lapack_driver: accepted for SciPy signature
        compatibility and ignored.

    Returns:
      Whatever ``lax_linalg.svd`` returns for these flags. In JAX this is
      ``(u, s, vh)`` when ``compute_uv`` is true and only ``s`` otherwise.
    """
    del overwrite_a, check_finite, lapack_driver
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    # Pass the flags by keyword. Newer lax.linalg.svd declares them
    # keyword-only, so positional passing raises TypeError there; keyword
    # passing also works with the older positional signature.
    return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
# Example #5
0
def _lu(a, permute_l):
    """LU decomposition with an explicit permutation matrix, SciPy-style.

    ``a`` must be 2-D, because its shape is unpacked into exactly two sizes.
    The permutation matrix ``p`` has entry ``(i, j)`` equal to one exactly
    when ``permutation[j] == i``. Given ``lax_linalg.lu``'s contract
    (``a[permutation] == l @ u``), this yields ``a == p @ l @ u``.

    Args:
      a: 2-D matrix to factor.
      permute_l: if true, return ``(p @ l, u)`` instead of ``(p, l, u)``.

    Returns:
      ``(p, l, u)`` or, when ``permute_l``, ``(p @ l, u)``.
    """
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    packed, _, permutation = lax_linalg.lu(a)  # pivot indices are not needed
    dtype = lax.dtype(a)
    num_rows, num_cols = jnp.shape(a)
    rank = min(num_rows, num_cols)
    # Boolean one-hot matrix cast to a's dtype. jnp.real keeps p real-valued
    # for complex inputs.
    p = jnp.real(
        jnp.array(permutation == jnp.arange(num_rows)[:, None], dtype=dtype))
    # The packed factor stores L strictly below the diagonal (its unit
    # diagonal is implicit) and U on and above it.
    l = jnp.tril(packed, -1)[:, :rank] + jnp.eye(num_rows, rank, dtype=dtype)
    u = jnp.triu(packed)[:rank, :]
    if not permute_l:
        return p, l, u
    return jnp.matmul(p, l), u
# Example #6
0
def _qr(a, mode, pivoting):
    """QR decomposition backing a ``scipy.linalg.qr``-style API.

    Args:
      a: matrix to factor.
      mode: ``"full"`` for the complete factorization, ``"economic"`` for the
        reduced one, or ``"r"`` to return only R from the complete
        factorization.
      pivoting: must be falsy; column pivoting is not implemented.

    Returns:
      ``r`` when ``mode == "r"``, otherwise the pair ``(q, r)``.

    Raises:
      NotImplementedError: if ``pivoting`` is truthy.
      ValueError: if ``mode`` is not one of the supported values.
    """
    if pivoting:
        raise NotImplementedError(
            "The pivoting=True case of qr is not implemented.")
    if mode in ("full", "r"):
        full_matrices = True
    elif mode == "economic":
        full_matrices = False
    else:
        raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    # Pass full_matrices by keyword. Newer lax.linalg.qr declares it
    # keyword-only (after ``pivoting``), so a positional argument raises
    # TypeError there; keyword passing also works with the older signature.
    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
    if mode == "r":
        return r
    return q, r
# Example #7
0
def _eigh(a, b, lower, eigvals_only, eigvals, type):
    """Eigen-decomposition of ``a`` via ``lax_linalg.eigh``.

    Only the standard problem is supported: ``b is None``, ``type == 1`` and
    ``eigvals is None``. (``type`` shadows the builtin, but callers pass it by
    that name.)

    Args:
      a: input matrix.
      b: must be ``None``.
      lower: forwarded to ``lax_linalg.eigh`` as ``lower``.
      eigvals_only: if true, return only the eigenvalues.
      eigvals: must be ``None``.
      type: must be ``1``.

    Returns:
      The eigenvalues ``w`` if ``eigvals_only``, otherwise ``(w, v)`` with
      eigenvectors ``v``.

    Raises:
      NotImplementedError: for any unsupported ``b``, ``type`` or ``eigvals``.
    """
    if b is not None:
        raise NotImplementedError("Only the b=None case of eigh is implemented")
    if type != 1:
        raise NotImplementedError("Only the type=1 case of eigh is implemented.")
    if eigvals is not None:
        raise NotImplementedError(
            "Only the eigvals=None case of eigh is implemented.")

    matrix = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    # lax_linalg.eigh yields (vectors, values); callers expect values first.
    eigenvectors, eigenvalues = lax_linalg.eigh(matrix, lower=lower)
    if eigvals_only:
        return eigenvalues
    return eigenvalues, eigenvectors
# Example #8
0
def _solve(a, b, sym_pos, lower):
    """Solve ``a @ x = b``, using a Cholesky factorization when ``sym_pos``.

    Without ``sym_pos`` this defers to ``np_linalg.solve``. With it, ``a`` is
    factored once via ``cho_factor``, and ``lax.custom_linear_solve`` reuses
    those factors, so differentiation does not have to go through the
    factorization.

    Args:
      a: coefficient matrix; treated as symmetric positive definite when
        ``sym_pos`` is true.
      b: right-hand side, either a vector (rank of ``a`` minus one) or a
        matrix.
      sym_pos: whether to take the Cholesky-based path.
      lower: forwarded to ``cho_factor``.

    Returns:
      The solution ``x``.
    """
    if not sym_pos:
        return np_linalg.solve(a, b)

    a, b = np_linalg._promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))
    lax_linalg._check_solve_shapes(a, b)

    # With custom_linear_solve, the same factorization is reused when
    # computing sensitivities, which is considerably faster. stop_gradient
    # keeps the factorization itself out of the gradient computation.
    factors = cho_factor(lax.stop_gradient(a), lower=lower)

    def matvec(x):
        return lax_linalg._matvec_multiply(a, x)

    def solve_with_factors(_, x):
        return cho_solve(factors, x)

    custom_solve = partial(lax.custom_linear_solve, matvec,
                           solve=solve_with_factors, symmetric=True)

    if a.ndim != b.ndim + 1:
        # Matrix right-hand side (b.shape == [..., m, k]): map the solve over
        # b's last axis and put the mapped axis back in last position.
        return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
    # Vector right-hand side (b.shape == [..., m]).
    return custom_solve(b)
# Example #9
0
def _cholesky(a, lower):
    """Cholesky factor of ``a``, lower form if ``lower`` else upper form.

    Args:
      a: matrix to factor.
      lower: if true, return ``lax_linalg.cholesky``'s factor directly.
        Otherwise, factor the conjugate transpose of ``a`` and
        conjugate-transpose the result.

    Returns:
      The triangular Cholesky factor.
    """
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    if lower:
        return lax_linalg.cholesky(a, symmetrize_input=False)
    # NOTE(review): _T is defined elsewhere; presumably it swaps the last two
    # axes, which makes conj(_T(.)) the conjugate transpose.
    factor = lax_linalg.cholesky(jnp.conj(_T(a)), symmetrize_input=False)
    return jnp.conj(_T(factor))
# Example #10
0
def lu_factor(a, overwrite_a=False, check_finite=True):
    """Packed LU factorization returned as ``(lu, pivots)``, SciPy-style.

    Args:
      a: matrix to factor.
      overwrite_a, check_finite: accepted for SciPy signature compatibility
        and ignored.

    Returns:
      The first two results of ``lax_linalg.lu``: the packed LU factor and
      the pivot indices. The permutation it also returns is discarded.
    """
    del overwrite_a, check_finite  # no-ops in this backend
    matrix = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    packed_lu, pivot_indices, _permutation = lax_linalg.lu(matrix)
    return packed_lu, pivot_indices
# Example #11
0
def _svd(a, *, full_matrices, compute_uv):
    """Singular value decomposition of ``a`` via ``lax_linalg.svd``.

    Args:
      a: input matrix.
      full_matrices: forwarded to ``lax_linalg.svd``.
      compute_uv: forwarded to ``lax_linalg.svd``.

    Returns:
      Whatever ``lax_linalg.svd`` returns for these flags. In JAX this is
      ``(u, s, vh)`` when ``compute_uv`` is true and only ``s`` otherwise.
    """
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    # Pass the flags by keyword. Newer lax.linalg.svd declares them
    # keyword-only, so positional passing raises TypeError there; keyword
    # passing also works with the older positional signature.
    return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)