Code example #1
File: linalg.py  Project: xueeinstein/jax
    def _update_T_Z(m, T, Z):
        """One Givens-rotation step of rsf2csf: rotates rows/columns m-1 and m
        of T (and the corresponding columns of Z) so that T[m, m-1] is zeroed.
        `N` (the static matrix size) comes from the enclosing scope."""
        mu = np_linalg.eigvals(lax.dynamic_slice(T, (m - 1, m - 1),
                                                 (2, 2))) - T[m, m]
        r = np_linalg.norm(jnp.array([mu[0], T[m, m - 1]])).astype(T.dtype)
        c = mu[0] / r
        s = T[m, m - 1] / r
        G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype)

        # T[m-1:m+1, m-1:] = G @ T[m-1:m+1, m-1:]
        T_rows = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=0)
        col_mask = jnp.arange(N) >= m - 1
        G_dot_T_zeroed_cols = G @ jnp.where(col_mask, T_rows, 0)
        T_rows_new = jnp.where(~col_mask, T_rows, G_dot_T_zeroed_cols)
        T = lax.dynamic_update_slice_in_dim(T, T_rows_new, m - 1, axis=0)

        # T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1] @ G.conj().T
        T_cols = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=1)
        row_mask = jnp.arange(N)[:, jnp.newaxis] < m + 1
        T_zeroed_rows_dot_GH = jnp.where(row_mask, T_cols, 0) @ G.conj().T
        T_cols_new = jnp.where(~row_mask, T_cols, T_zeroed_rows_dot_GH)
        T = lax.dynamic_update_slice_in_dim(T, T_cols_new, m - 1, axis=1)

        # Z[:, m-1:m+1] = Z[:, m-1:m+1] @ G.conj().T
        Z_cols = lax.dynamic_slice_in_dim(Z, m - 1, 2, axis=1)
        Z = lax.dynamic_update_slice_in_dim(Z,
                                            Z_cols @ G.conj().T,
                                            m - 1,
                                            axis=1)
        return T, Z
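
This helper is the Givens-rotation step inside jax's rsf2csf (conversion from real Schur form to complex Schur form). A minimal sketch of exercising it through the public jax.scipy.linalg API, assuming a standard setup (the example matrix is arbitrary; note that schur has historically only had a CPU lowering in jax):

import jax.numpy as jnp
from jax.scipy.linalg import schur, rsf2csf

# A real matrix with a complex-conjugate eigenvalue pair, so the real
# Schur form contains a 2x2 block that rsf2csf rotates away.
A = jnp.array([[0., -1., 0.],
               [1.,  0., 0.],
               [0.,  0., 2.]])
T, Z = schur(A)         # real Schur form: Z @ T @ Z.T ~= A
Tc, Zc = rsf2csf(T, Z)  # complex Schur form: Tc is upper triangular
print(jnp.allclose(Zc @ Tc @ Zc.conj().T, A, atol=1e-5))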
Code example #2
File: linalg.py  Project: xueeinstein/jax
def funm(A, func, disp=True):
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected square array_like input')

  T, Z = schur(A)
  T, Z = rsf2csf(T, Z)

  F = jnp.diag(func(jnp.diag(T)))
  F = F.astype(T.dtype.char)

  F, minden = _algorithm_11_1_1(F, T)
  F = Z @ F @ Z.conj().T

  if disp:
    return F

  if F.dtype.char.lower() == 'e':
    tol = jnp.finfo(jnp.float16).eps
  elif F.dtype.char.lower() == 'f':
    tol = jnp.finfo(jnp.float32).eps
  else:
    tol = jnp.finfo(jnp.float64).eps

  minden = jnp.where(minden == 0.0, tol, minden)
  err = jnp.where(jnp.any(jnp.isinf(F)), jnp.inf, jnp.minimum(1, jnp.maximum(
          tol, (tol / minden) * norm(jnp.triu(T, 1), 1))))

  return F, err
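
A quick usage sketch, assuming this is jax.scipy.linalg.funm: computing the matrix exponential by passing func=jnp.exp and cross-checking against expm (the example matrix is arbitrary):

import jax.numpy as jnp
from jax.scipy.linalg import expm, funm

A = jnp.array([[1.0, 0.5],
               [0.2, 1.0]])
F, err = funm(A, jnp.exp, disp=False)  # Schur-based matrix function
print(jnp.allclose(F, expm(A), atol=1e-4))  # True, up to roundoff
print(err)  # heuristic error estimate derived from _algorithm_11_1_1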
Code example #3
File: linalg.py  Project: xueeinstein/jax
def _calc_P_Q(A):
    A = jnp.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError('expected A to be a square matrix')
    A_L1 = np_linalg.norm(A, 1)
    n_squarings = 0
    if A.dtype == 'float64' or A.dtype == 'complex128':
        maxnorm = 5.371920351148152
        n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
        A = A / 2**n_squarings.astype(A.dtype)
        conds = jnp.array([
            1.495585217958292e-002, 2.539398330063230e-001,
            9.504178996162932e-001, 2.097847961257068e+000
        ],
                          dtype=A_L1.dtype)
        idx = jnp.digitize(A_L1, conds)
        U, V = lax.switch(idx, [_pade3, _pade5, _pade7, _pade9, _pade13], A)
    elif A.dtype == 'float32' or A.dtype == 'complex64':
        maxnorm = 3.925724783138660
        n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
        A = A / 2**n_squarings.astype(A.dtype)
        conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000],
                          dtype=A_L1.dtype)
        idx = jnp.digitize(A_L1, conds)
        U, V = lax.switch(idx, [_pade3, _pade5, _pade7], A)
    else:
        raise TypeError(f"A.dtype={A.dtype} is not supported.")
    P = U + V  # p_m(A) : numerator
    Q = -U + V  # q_m(A) : denominator
    return P, Q, n_squarings
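
_calc_P_Q only produces the Padé numerator P, denominator Q, and the scaling count; the caller (jax's expm) still has to form R = Q⁻¹P and square it n_squarings times to undo the scaling. A hedged sketch of that consumption step under the standard scaling-and-squaring algorithm; _expm_from_P_Q and max_squarings are illustrative names, not jax API:

import jax.numpy as jnp
from jax import lax

def _expm_from_P_Q(P, Q, n_squarings, max_squarings=16):
  # R approximates exp(A / 2**n_squarings); repeated squaring undoes
  # the scaling. n_squarings is a traced value, so we loop to a static
  # bound and mask out the extra iterations.
  R = jnp.linalg.solve(Q, P)
  def square_once(i, R):
    return jnp.where(i < n_squarings, R @ R, R)
  return lax.fori_loop(0, max_squarings, square_once, R)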
Code example #4
def _calc_P_Q(A):
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected A to be a square matrix')
  A_L1 = np_linalg.norm(A, 1)
  n_squarings = 0
  if A.dtype == 'float64' or A.dtype == 'complex128':
    U3, V3 = _pade3(A)
    U5, V5 = _pade5(A)
    U7, V7 = _pade7(A)
    U9, V9 = _pade9(A)
    maxnorm = 5.371920351148152
    n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings
    U13, V13 = _pade13(A)
    conds = jnp.array([1.495585217958292e-002, 2.539398330063230e-001,
                       9.504178996162932e-001, 2.097847961257068e+000])
    U = jnp.select((A_L1 < conds), (U3, U5, U7, U9), U13)
    V = jnp.select((A_L1 < conds), (V3, V5, V7, V9), V13)
  elif A.dtype == 'float32' or A.dtype == 'complex64':
    U3, V3 = _pade3(A)
    U5, V5 = _pade5(A)
    maxnorm = 3.925724783138660
    n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings
    U7, V7 = _pade7(A)
    conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000])
    U = jnp.select((A_L1 < conds), (U3, U5), U7)
    V = jnp.select((A_L1 < conds), (V3, V5), V7)
  else:
    raise TypeError("A.dtype={} is not supported.".format(A.dtype))
  P = U + V   # p_m(A) : numerator
  Q = -U + V  # q_m(A) : denominator
  return P, Q, n_squarings
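
Compared with example #3, this older revision evaluates every Padé order eagerly and picks one result with jnp.select, whereas the newer code dispatches with lax.switch so only the selected branch is executed. A toy sketch of the two idioms (dummy branches, not the real _pade* functions):

import jax.numpy as jnp
from jax import lax

x = jnp.float32(2.0)
branches = [lambda v: v + 1, lambda v: v * 10, lambda v: v ** 2]

# select-style: every candidate is computed, then one is kept.
candidates = [f(x) for f in branches]
y_select = jnp.select([x < 1, x < 3], candidates[:2], candidates[2])

# switch-style: only branches[idx] runs.
idx = jnp.digitize(x, jnp.array([1.0, 3.0]))
y_switch = lax.switch(idx, branches, x)
print(y_select, y_switch)  # both print 20.0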
Code example #5
File: linalg.py  Project: yashk2810/jax
def cond(x, p=None):
    _assertNoEmpty2d(x)
    if p in (None, 2):
        s = la.svd(x, compute_uv=False)
        r = s[..., 0] / s[..., -1]
    elif p == -2:
        s = la.svd(x, compute_uv=False)
        r = s[..., -1] / s[..., 0]
    else:
        _assertRankAtLeast2(x)
        _assertNdSquareness(x)
        invx = la.inv(x)
        r = la.norm(x, ord=p, axis=(-2, -1)) * la.norm(
            invx, ord=p, axis=(-2, -1))

    # Convert nans to infs unless the original array had nan entries
    orig_nan_check = jnp.full_like(r, ~jnp.isnan(r).any())
    nan_mask = jnp.logical_and(jnp.isnan(r), ~jnp.isnan(x).any(axis=(-2, -1)))
    r = jnp.where(orig_nan_check, jnp.where(nan_mask, jnp.inf, r), r)
    return r
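
Assuming this is the jnp.linalg.cond port, a quick check on a diagonal matrix whose condition number is known in closed form:

import jax.numpy as jnp

A = jnp.diag(jnp.array([1.0, 0.1]))
print(jnp.linalg.cond(A))       # 2-norm: sigma_max / sigma_min = 10.0
print(jnp.linalg.cond(A, p=1))  # norm(A, 1) * norm(inv(A), 1) = 10.0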
Code example #6
File: vq.py  Project: frederikwilde/jax
def vq(obs, code_book, check_finite=True):
    _check_arraylike("scipy.cluster.vq.vq", obs, code_book)
    if obs.ndim != code_book.ndim:
        raise ValueError("Observation and code_book should have the same rank")
    obs, code_book = _promote_dtypes_inexact(obs, code_book)
    if obs.ndim == 1:
        obs, code_book = obs[..., None], code_book[..., None]
    if obs.ndim != 2:
        raise ValueError("ndim different than 1 or 2 are not supported")

    # Explicit rank promotion: compute per-observation distances via vmap.
    dist = vmap(lambda ob: norm(ob[None] - code_book, axis=-1))(obs)
    code = argmin(dist, axis=-1)
    dist_min = vmap(operator.getitem)(dist, code)
    return code, dist_min
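
Assuming this is jax.scipy.cluster.vq.vq, a small usage example assigning observations to their nearest code-book rows:

import jax.numpy as jnp
from jax.scipy.cluster.vq import vq

code_book = jnp.array([[0.0, 0.0],
                       [10.0, 10.0]])
obs = jnp.array([[0.5, 0.1],
                 [9.0, 11.0]])
codes, dists = vq(obs, code_book)
print(codes)  # [0 1]: index of the nearest code-book row
print(dists)  # Euclidean distance to that row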
Code example #7
File: eigh.py  Project: cloudhan/jax
def _projector_subspace(P, H, n, rank, maxiter=2):
  """ Decomposes the `n x n` rank `rank` Hermitian projector `P` into
  an `n x rank` isometry `V_minus` such that `P = V_minus @ V_minus.conj().T`
  and an `n x (n - rank)` isometry `V_minus` such that
  -(I - P) = V_plus @ V_plus.conj().T`.

  The subspaces are computed using the naiive QR eigendecomposition
  algorithm, which converges very quickly due to the sharp separation
  between the relevant eigenvalues of the projector.

  Args:
    P: A rank-`rank` Hermitian projector into the space of `H`'s
       first `rank` eigenpairs. `P` is padded to NxN.
    H: The aforementioned Hermitian matrix, which is used to track
       convergence.
    n: the true (dynamic) shape of `P`.
    rank: Rank of `P`.
    maxiter: Maximum number of iterations.
  Returns:
    V_minus, V_plus: Isometries into the eigenspaces described in the docstring.
  """
  # Choose an initial guess: the `rank` largest-norm columns of P.
  N, _ = P.shape
  column_norms = jnp_linalg.norm(P, axis=1)
  # `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to NaN.
  column_norms = _mask(column_norms, (n,), jnp.nan)
  sort_idxs = jnp.argsort(column_norms)
  X = P[:, sort_idxs]
  # X = X[:, :rank]
  X = _mask(X, (n, rank))

  H_norm = jnp_linalg.norm(H)
  thresh = 10 * jnp.finfo(X.dtype).eps * H_norm

  # First iteration skips the matmul.
  def body_f_after_matmul(X):
    Q, _ = jnp_linalg.qr(X, mode="complete")
    # V1 = Q[:, :rank]
    # V2 = Q[:, rank:]
    V1 = _mask(Q, (n, rank))
    V2 = _slice(Q, (0, rank), (n, n - rank), (N, N))

    # TODO: might be able to get away with lower precision here
    error_matrix = jnp.dot(V2.conj().T, H)
    error_matrix = jnp.dot(error_matrix, V1)
    error = jnp_linalg.norm(error_matrix) / H_norm
    return V1, V2, error

  def cond_f(args):
    _, _, j, error = args
    still_counting = j < maxiter
    unconverged = error > thresh
    return jnp.logical_and(still_counting, unconverged)[0]

  def body_f(args):
    V1, _, j, _ = args
    X = jnp.dot(P, V1)
    V1, V2, error = body_f_after_matmul(X)
    return V1, V2, j + 1, error

  V1, V2, error = body_f_after_matmul(X)
  one = jnp.ones(1, dtype=jnp.int32)
  V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
  return V1, V2
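
_mask and _slice are helpers defined elsewhere in eigh.py for handling dynamically shaped data inside statically shaped (padded) arrays. They are not shown in this example, so here is a rough reconstruction of their assumed semantics, for illustration only:

import jax.numpy as jnp
from jax import lax

def _mask(x, dims, alternative=0):
  # Replaces every entry of the padded array `x` that lies outside the
  # dynamic shape `dims` with `alternative`.
  mask = None
  for i, d in enumerate(dims):
    if d is None:
      continue
    m = lax.broadcasted_iota(jnp.int32, jnp.shape(x), i) < d
    mask = m if mask is None else (mask & m)
  return x if mask is None else jnp.where(mask, x, alternative)

def _slice(operand, start_indices, dynamic_slice_sizes, static_slice_sizes):
  # Slices with a static size (so shapes stay fixed under jit), then
  # masks out entries beyond the dynamic slice sizes.
  out = lax.dynamic_slice(operand, start_indices, static_slice_sizes)
  return _mask(out, dynamic_slice_sizes)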