コード例 #1
0
ファイル: eigh.py プロジェクト: cloudhan/jax
  def body_f_after_matmul(X):
    """Split an orthonormal basis for `X` and score how well it decouples `H`.

    Takes a complete QR factorization of `X`, separates the leading `rank`
    columns of `Q` (V1) from the remaining `n - rank` columns (V2), and
    measures the coupling norm(V2^H @ H @ V1) relative to `H_norm`.
    `n`, `N`, `rank`, `H` and `H_norm` come from the enclosing scope.

    Returns:
      `(V1, V2, error)`.
    """
    Q_full, _ = jnp_linalg.qr(X, mode="complete")
    # Conceptually V1 = Q[:, :rank] and V2 = Q[:, rank:]. `rank` is presumably
    # a traced value, so the `_mask` / `_slice` helpers are used instead of
    # plain indexing (likely to keep array shapes static -- confirm).
    V1 = _mask(Q_full, (n, rank))
    V2 = _slice(Q_full, (0, rank), (n, n - rank), (N, N))

    # TODO: might be able to get away with lower precision here
    coupling = jnp.dot(jnp.dot(V2.conj().T, H), V1)
    return V1, V2, jnp_linalg.norm(coupling) / H_norm
コード例 #2
0
ファイル: eigh.py プロジェクト: cloudhan/jax
  def base_case(B, offset, b, agenda, blocks, eigenvectors):
    """Terminal step of the recursion: diagonalize one small block directly.

    `n` and `N` are taken from the enclosing scope.

    Args (roles inferred from the calls below -- confirm against the caller):
      B: static (padded) side length used when slicing the block out of
        `blocks` and its eigenvector columns out of `eigenvectors`.
      offset: starting row of the block in `blocks` and starting column of
        its eigenvector panel in `eigenvectors`.
      b: dynamic (true) size of the block; everything beyond `b` is padding.
      agenda: passed through unchanged.
      blocks: array holding the block to diagonalize; the block's eigenvalues
        are written back into it at `(offset, 0)`.
      eigenvectors: accumulated eigenvector matrix; the block's columns are
        replaced by `V @ eig_vecs`.

    Returns:
      `(agenda, blocks, eigenvectors)` with `blocks` and `eigenvectors`
      updated.
    """
    # Base case: for blocks under a minimum size, we cutoff the recursion
    # and call the TPU Jacobi eigendecomposition implementation. The Jacobi
    # algorithm works well for small matrices but scales poorly, so the two
    # complement each other well.
    # H: the (b, b) block starting at (offset, 0), presumably padded to (B, B).
    # V: the (n, b) eigenvector panel starting at column `offset`, presumably
    # padded to (N, B). Confirm both against `_slice`.
    H = _slice(blocks, (offset, 0), (b, b), (B, B))
    V = _slice(eigenvectors, (0, offset), (n, b), (N, B))

    # We replace the masked-out part of the matrix with the identity matrix.
    # We know that the TPU Jacobi eigh implementation will not alter the order
    # of the eigenvalues, so we know the eigendecomposition of the original
    # matrix is in the top-left corner of the eigendecomposition of the padded
    # matrix.
    # It is very important that the underlying eigh implementation does not sort
    # the eigenvalues for this reason! This is currently not true of JAX's CPU
    # and GPU eigendecompositions, and for those platforms this algorithm will
    # only do the right thing if termination_size == 1.
    H = _mask(H, (b, b), jnp.eye(B, dtype=H.dtype))
    eig_vecs, eig_vals = lax.linalg.eigh(H, sort_eigenvalues=False)
    # Drop the parts that belong to the identity padding (presumably `_mask`
    # fills positions outside the given extent with zeros by default).
    eig_vecs = _mask(eig_vecs, (b, b))
    eig_vals = _mask(eig_vals, (b,))
    # Express the block's eigenvectors in the n-row coordinates of
    # `eigenvectors` by applying the accumulated panel `V`.
    eig_vecs = jnp.dot(V, eig_vecs)

    # Store the eigenvalues as a column starting at (offset, 0).
    # NOTE(review): the update has shape (B, 1) while `update_dims` is (b, b);
    # this relies on `_update_slice` masking each dimension independently --
    # confirm.
    blocks = _update_slice(blocks, eig_vals[:, None], (offset, 0), (b, b))
    eigenvectors = _update_slice(eigenvectors, eig_vecs, (0, offset), (n, b))
    return agenda, blocks, eigenvectors
コード例 #3
0
ファイル: linalg.py プロジェクト: yashk2810/jax
def multi_dot(arrays, *, precision=None):
    """Multiply a chain of arrays using the cheapest association order.

    The first operand may be 1-D (treated as a row vector) and the last may
    be 1-D (treated as a column vector); the remaining operands are checked
    to be 2-D by `_assert2d`.

    Args:
      arrays: sequence of at least two array-likes.
      precision: forwarded to every `jnp.dot` call.

    Returns:
      The chained product. A 1-D first and/or last operand removes the
      corresponding dimension again (both 1-D gives a scalar).

    Raises:
      ValueError: if fewer than two arrays are given.
    """
    count = len(arrays)
    if count < 2:
        raise ValueError("Expecting at least two arrays.")
    if count == 2:
        # With two operands there is no association order to choose.
        return jnp.dot(arrays[0], arrays[1], precision=precision)

    mats = list(map(jnp.asarray, arrays))

    # Remember which ends were vectors so the result can be reshaped back,
    # then lift them to 2-D so the internal helpers only see matrices.
    first_is_vec = mats[0].ndim == 1
    last_is_vec = mats[-1].ndim == 1
    if first_is_vec:
        mats[0] = jnp.atleast_2d(mats[0])
    if last_is_vec:
        mats[-1] = jnp.atleast_2d(mats[-1]).T
    _assert2d(*mats)

    if count == 3:
        # The closed-form cost comparison is much cheaper than the general
        # chain-order search for exactly three operands.
        product = _multi_dot_three(*mats, precision)
    else:
        split_table = _multi_dot_matrix_chain_order(mats)
        product = _multi_dot(mats, split_table, 0, count - 1, precision)

    # Undo the 1-D -> 2-D lifting.
    if first_is_vec and last_is_vec:
        return product[0, 0]  # scalar
    if first_is_vec or last_is_vec:
        return product.ravel()  # 1-D
    return product
コード例 #4
0
ファイル: linalg.py プロジェクト: yashk2810/jax
def _multi_dot(arrays, order, i, j, precision):
    """Multiply `arrays[i..j]` (inclusive) following the split table `order`.

    `order[i, j]` is the index `k` at which the sub-chain `i..j` is split into
    `(i..k) @ (k+1..j)`. A single-element range returns that array as-is.
    """
    if i != j:
        split = order[i, j]
        left = _multi_dot(arrays, order, i, split, precision)
        right = _multi_dot(arrays, order, split + 1, j, precision)
        return jnp.dot(left, right, precision=precision)
    return arrays[i]
コード例 #5
0
ファイル: linalg.py プロジェクト: yashk2810/jax
def _multi_dot_three(A, B, C, precision):
    """Compute `A @ B @ C` using whichever association needs fewer multiplies.

    For three operands this closed-form cost comparison is much faster
    (about 15x) than the general `_multi_dot_matrix_chain_order` search.
    Ties go to `A @ (B @ C)`.
    """
    rows_a, inner_ab = A.shape
    inner_bc, cols_c = C.shape
    # Scalar multiplications for (A @ B) @ C: rows_a*inner_ab*inner_bc + rows_a*inner_bc*cols_c
    cost_left = rows_a * inner_bc * (inner_ab + cols_c)
    # Scalar multiplications for A @ (B @ C): inner_ab*inner_bc*cols_c + rows_a*inner_ab*cols_c
    cost_right = inner_ab * cols_c * (rows_a + inner_bc)

    def mul(lhs, rhs):
        return jnp.dot(lhs, rhs, precision=precision)

    if cost_left >= cost_right:
        return mul(A, mul(B, C))
    return mul(mul(A, B), C)
コード例 #6
0
def split_spectrum(H, n, split_point, V0=None):
    """Split the Hermitian matrix `H` into two matrices `H_minus` and
    `H_plus`, respectively sharing its eigenspaces with eigenvalues below and
    above the value `split_point`.

    Returns, in addition, `V_minus` and `V_plus`, isometries such that
    `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are
    returned instead; this allows the overall isometries mapping from
    an initial input matrix to progressively smaller blocks to be formed.

    Args:
      H: The Hermitian matrix to split, of static shape (N, N).
      n: The dynamic size of the active top-left (n, n) block of `H`
        (passed to `qdwh.qdwh` as `dynamic_shape` and used to mask the
        identity).
      split_point: The scalar value (not an index) at which to split the
        spectrum. May be a Python scalar or a 0-d array.
      V0: Matrix of isometries to be updated, or None.
    Returns:
      H_minus: A Hermitian matrix sharing the eigenvalues of `H` beneath
        `split_point`.
      V_minus: An isometry from the input space of `V0` to `H_minus`.
      H_plus: A Hermitian matrix sharing the eigenvalues of `H` above
        `split_point`.
      V_plus: An isometry from the input space of `V0` to `H_plus`.
      rank: The dynamic size of the `H_minus` subblock, i.e. the rounded
        trace of the spectral projector `P` (the number of eigenvalues below
        `split_point`).
    """
    # `.dtype` is read below, so accept Python scalars as well as arrays.
    split_point = jnp.asarray(split_point)
    N, _ = H.shape
    H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(
        H.dtype)
    # Assuming `qdwh.qdwh` returns the unitary polar factor U of `H_shift`
    # (for Hermitian input, sign(H - split_point * I)), P = (I_n - U) / 2 is
    # the projector onto the eigenspace with eigenvalues below `split_point`.
    U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n))
    P = -0.5 * (U - _mask(jnp.eye(N, dtype=H.dtype), (n, n)))
    # The trace of a projector is its rank; round to absorb floating-point
    # error before using it as an integer size.
    rank = jnp.round(jnp.trace(jnp.real(P))).astype(jnp.int32)

    V_minus, V_plus = _projector_subspace(P, H, n, rank)
    H_minus = (V_minus.conj().T @ H) @ V_minus
    H_plus = (V_plus.conj().T @ H) @ V_plus
    if V0 is not None:
        V_minus = jnp.dot(V0, V_minus)
        V_plus = jnp.dot(V0, V_plus)
    return H_minus, V_minus, H_plus, V_plus, rank
コード例 #7
0
ファイル: polynomial.py プロジェクト: frederikwilde/jax
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
    """Least-squares polynomial fit, mirroring `numpy.polyfit`.

    Args:
      x: Non-empty 1-D array of sample x-coordinates.
      y: 1-D array of sample values, or 2-D array with one data set per
        column; must have the same length as `x`.
      deg: Degree of the fitting polynomial; must be a concrete, non-negative
        int.
      rcond: Relative cutoff for small singular values in the least-squares
        solve. Defaults to `len(x) * eps` for the dtype of `x`.
      full: If true, return the least-squares diagnostics as well.
      w: Optional 1-D array of per-sample weights, same length as `y`.
      cov: If truthy (and `full` is false), also return the covariance of the
        coefficients. With `cov="unscaled"` it is not scaled by the residual
        variance.

    Returns:
      `c` -- coefficients, highest power first (one column per data set for
      2-D `y`); or `(c, resids, rank, s, rcond)` if `full`; or `(c, V)` if
      `cov`, where `V` has shape `(deg + 1, deg + 1)` for 1-D `y` and
      `(deg + 1, deg + 1, K)` for a `y` with `K` columns.

    Raises:
      ValueError: If `deg < 0`, or if a scaled covariance is requested with
        `len(x) <= deg + 1`.
      TypeError: If `x`, `y` or `w` have the wrong rank or mismatched lengths.
    """
    _check_arraylike("polyfit", x, y)
    # Work in inexact dtypes (numpy does `x + 0.0`): `finfo(x.dtype)` below
    # rejects integer dtypes. Promote separately so a complex `y` does not
    # turn `x` complex.
    x, = _promote_dtypes_inexact(x)
    y, = _promote_dtypes_inexact(y)
    deg = core.concrete_or_error(int, deg, "deg must be int")
    order = deg + 1
    # check arguments
    if deg < 0:
        raise ValueError("expected deg >= 0")
    if x.ndim != 1:
        raise TypeError("expected 1D vector for x")
    if x.size == 0:
        raise TypeError("expected non-empty vector for x")
    if y.ndim < 1 or y.ndim > 2:
        raise TypeError("expected 1D or 2D array for y")
    if x.shape[0] != y.shape[0]:
        raise TypeError("expected x and y to have same length")

    # set rcond
    if rcond is None:
        rcond = len(x) * finfo(x.dtype).eps
    rcond = core.concrete_or_error(float, rcond, "rcond must be float")
    # set up least squares equation for powers of x
    lhs = vander(x, order)
    rhs = y

    # apply weighting
    if w is not None:
        _check_arraylike("polyfit", w)
        w, = _promote_dtypes_inexact(w)
        if w.ndim != 1:
            raise TypeError("expected a 1-d array for weights")
        if w.shape[0] != y.shape[0]:
            raise TypeError("expected w and y to have the same length")
        lhs *= w[:, np.newaxis]
        if rhs.ndim == 2:
            rhs *= w[:, np.newaxis]
        else:
            rhs *= w

    # scale lhs to improve condition number and solve
    scale = sqrt((lhs * lhs).sum(axis=0))
    lhs /= scale[np.newaxis, :]
    c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
    c = (c.T / scale).T  # broadcast scale coefficients

    if full:
        return c, resids, rank, s, rcond
    elif cov:
        Vbase = linalg.inv(dot(lhs.T, lhs))
        Vbase /= outer(scale, scale)
        if cov == "unscaled":
            fac = 1
        else:
            if len(x) <= order:
                raise ValueError("the number of data points must exceed order "
                                 "to scale the covariance matrix")
            # `resids` holds one residual sum of squares per column of `y`
            # (a single entry for 1-D `y`). Keep all of them so each data set
            # of a 2-D `y` is scaled by its own residual variance, as numpy
            # does; indexing `[0]` would reuse the first column's for all.
            fac = resids / (len(x) - order)
        if y.ndim == 1:
            return c, Vbase * fac
        else:
            return c, Vbase[:, :, np.newaxis] * fac
    else:
        return c
コード例 #8
0
ファイル: linalg.py プロジェクト: xueeinstein/jax
def _precise_dot(A, B):
    """Return `jnp.dot(A, B)` computed at `lax.Precision.HIGHEST`.

    Requests the highest matmul precision instead of the backend default.
    """
    highest = lax.Precision.HIGHEST
    return jnp.dot(A, B, precision=highest)
コード例 #9
0
ファイル: eigh.py プロジェクト: cloudhan/jax
 def body_f(args):
   """One loop iteration: multiply the current basis by `P`, then re-split it.

   Args:
     args: 4-tuple, presumably `(V1, V2, j, error)`; only the basis `V1`
       (entry 0) and the counter `j` (entry 2) are read.

   Returns:
     `(V1, V2, j + 1, error)` produced by `body_f_after_matmul` on `P @ V1`.
     `P` comes from the enclosing scope.
   """
   basis, _, iteration, _ = args
   new_V1, new_V2, new_error = body_f_after_matmul(jnp.dot(P, basis))
   return new_V1, new_V2, iteration + 1, new_error