예제 #1
0
파일: linalg.py 프로젝트: xueeinstein/jax
def _calc_P_Q(A):
    """Return the Padé numerator ``P``, denominator ``Q`` and ``n_squarings``.

    The degree of the Padé approximant is picked by comparing the 1-norm of
    the (unscaled) input against a dtype-specific threshold table. Before the
    approximant is evaluated, ``A`` is divided by ``2**n_squarings`` where
    ``n_squarings = max(0, floor(log2(||A||_1 / maxnorm)))``. That value is
    nonzero only when the norm is at least ``2 * maxnorm``, which is above
    every threshold, so only the highest-degree approximant ever sees a
    rescaled matrix.

    Args:
        A: square 2-D matrix of dtype float32, float64, complex64 or
            complex128.

    Returns:
        ``(P, Q, n_squarings)`` with ``P = U + V`` and ``Q = V - U`` built
        from the ``(U, V)`` pair returned by the chosen ``_pade*`` helper.

    Raises:
        ValueError: if ``A`` is not a square 2-D matrix.
        TypeError: if ``A`` has any other dtype.
    """
    A = jnp.asarray(A)
    is_square = A.ndim == 2 and A.shape[0] == A.shape[1]
    if not is_square:
        raise ValueError('expected A to be a square matrix')
    A_L1 = np_linalg.norm(A, 1)

    # Per-precision table: largest norm handled without squaring, the norm
    # thresholds separating consecutive degrees, and the candidate degrees.
    if A.dtype == 'float64' or A.dtype == 'complex128':
        maxnorm = 5.371920351148152
        thresholds = [
            1.495585217958292e-002, 2.539398330063230e-001,
            9.504178996162932e-001, 2.097847961257068e+000
        ]
        approximants = [_pade3, _pade5, _pade7, _pade9, _pade13]
    elif A.dtype == 'float32' or A.dtype == 'complex64':
        maxnorm = 3.925724783138660
        thresholds = [4.258730016922831e-001, 1.880152677804762e+000]
        approximants = [_pade3, _pade5, _pade7]
    else:
        raise TypeError(f"A.dtype={A.dtype} is not supported.")

    n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
    scale = 2**n_squarings.astype(A.dtype)
    A = A / scale
    # Index of the first threshold the norm lies below (len(thresholds) if
    # none); lax.switch then evaluates only that approximant.
    degree_index = jnp.digitize(A_L1,
                                jnp.array(thresholds, dtype=A_L1.dtype))
    U, V = lax.switch(degree_index, approximants, A)

    numerator = U + V  # p_m(A)
    denominator = -U + V  # q_m(A)
    return numerator, denominator, n_squarings
예제 #2
0
def _calc_P_Q(A):
  """Return the Padé numerator ``P``, denominator ``Q`` and ``n_squarings``.

  The degree of the Padé approximant (3/5/7/9/13 for float64 and complex128,
  3/5/7 for float32 and complex64) is chosen by comparing the 1-norm of the
  unscaled ``A`` against a threshold table. ``A`` is divided by
  ``2**n_squarings`` before evaluation. ``n_squarings`` is nonzero only when
  the norm is at least ``2 * maxnorm``, which is above every threshold, so the
  low-degree approximants always see the original matrix.

  Only the selected approximant is evaluated, via ``lax.switch``. This
  function used to compute every candidate and discard all but one.

  NOTE(review): ``n_squarings`` is presumably the number of squarings the
  caller applies after solving ``Q X = P`` (scaling and squaring); confirm
  against the caller.

  Args:
    A: square 2-D matrix of dtype float32, float64, complex64 or complex128.

  Returns:
    ``(P, Q, n_squarings)`` where ``P = U + V`` and ``Q = V - U`` come from
    the ``(U, V)`` pair returned by the chosen ``_pade*`` helper.

  Raises:
    ValueError: if ``A`` is not a square 2-D matrix.
    TypeError: if ``A`` has any other dtype.
  """
  from jax import lax  # for lax.switch; harmless if already imported

  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected A to be a square matrix')
  A_L1 = np_linalg.norm(A, 1)
  if A.dtype == 'float64' or A.dtype == 'complex128':
    maxnorm = 5.371920351148152
    conds = jnp.array([1.495585217958292e-002, 2.539398330063230e-001,
                       9.504178996162932e-001, 2.097847961257068e+000],
                      dtype=A_L1.dtype)
    approximants = [_pade3, _pade5, _pade7, _pade9, _pade13]
  elif A.dtype == 'float32' or A.dtype == 'complex64':
    maxnorm = 3.925724783138660
    conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000],
                      dtype=A_L1.dtype)
    approximants = [_pade3, _pade5, _pade7]
  else:
    raise TypeError("A.dtype={} is not supported.".format(A.dtype))

  n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
  A = A / 2**n_squarings
  # digitize gives the index of the first threshold A_L1 lies below, or
  # len(conds) if there is none. That is the same first-true choice the
  # previous jnp.select(A_L1 < conds, ...) made.
  idx = jnp.digitize(A_L1, conds)
  U, V = lax.switch(idx, approximants, A)
  P = U + V  # p_m(A) : numerator
  Q = -U + V  # q_m(A) : denominator
  return P, Q, n_squarings
예제 #3
0
파일: linalg.py 프로젝트: xueeinstein/jax
def funm(A, func, disp=True):
  """Evaluate a matrix function of ``A`` through a Schur factorization.

  ``A`` is factored with ``schur`` followed by ``rsf2csf``, giving ``T`` and
  ``Z``. ``func`` is called once on the 1-D vector of diagonal entries of
  ``T``. The rest of the triangular result is filled in by
  ``_algorithm_11_1_1``, presumably Golub & Van Loan's Algorithm 11.1.1
  (Parlett recurrence), which should be confirmed. The result is mapped back
  as ``Z @ F @ Z^H``.

  Args:
    A: square 2-D array.
    func: callable applied to the 1-D vector ``diag(T)``.
    disp: if true, return only ``F``; otherwise also return an error
      estimate.

  Returns:
    ``F`` when ``disp`` is true, else ``(F, err)``. ``err`` is
    ``min(1, max(tol, tol / minden * ||triu(T, 1)||_1))``, or ``inf`` when
    ``F`` contains an infinite entry. ``tol`` is the machine epsilon that
    matches ``F``'s precision.

  Raises:
    ValueError: if ``A`` is not a square 2-D array.
  """
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected square array_like input')

  T, Z = schur(A)
  T, Z = rsf2csf(T, Z)

  F = jnp.diag(func(jnp.diag(T)))
  F = F.astype(T.dtype.char)

  F, minden = _algorithm_11_1_1(F, T)
  F = Z @ F @ Z.conj().T

  if disp:
    return F

  # Choose eps by precision: 'e' = half, 'f'/'F' = single, anything else
  # falls back to double. This must be an if/elif/else chain. With two
  # independent `if`s, a half-precision result fell through to the final
  # `else` and was wrongly given float64's eps.
  precision = F.dtype.char.lower()
  if precision == 'e':
    tol = jnp.finfo(jnp.float16).eps
  elif precision == 'f':
    tol = jnp.finfo(jnp.float32).eps
  else:
    tol = jnp.finfo(jnp.float64).eps

  # Replace a zero minimum denominator with tol so the ratio below is finite.
  minden = jnp.where(minden == 0.0, tol, minden)
  err = jnp.where(jnp.any(jnp.isinf(F)), jnp.inf, jnp.minimum(1, jnp.maximum(
          tol, (tol / minden) * norm(jnp.triu(T, 1), 1))))

  return F, err
예제 #4
0
파일: polynomial.py 프로젝트: jbampton/jax
def polyint(p, m=1, k=None):
  """Return the ``m``-th antiderivative of a polynomial.

  Coefficients are ordered from the highest power down.

  Args:
    p: polynomial coefficients, highest power first.
    m: order of integration; must be a concrete, non-negative integer.
    k: integration constants. ``None`` means zero. A scalar (or length-1
      array) is reused for each of the ``m`` integrations. Otherwise it must
      have length ``m``, with ``k[0]`` the constant of the first integration.

  Returns:
    Coefficients of the integrated polynomial, of length ``len(p) + m``, or
    ``p`` unchanged when ``m == 0``.

  Raises:
    ValueError: if ``m`` is negative or ``k`` has the wrong length.
  """
  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
  if k is None:
    k = 0
  _check_arraylike("polyint", p, k)
  p, k = _promote_dtypes_inexact(p, k)
  if m < 0:
    raise ValueError("Order of integral must be positive (see polyder)")

  constants = atleast_1d(k)
  if len(constants) == 1:
    # Broadcast a single constant to one per integration step.
    constants = full((m,), constants[0])
  if constants.shape != (m,):
    raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
  if m == 0:
    return p

  # Entry j of the output ends up multiplying x**(len(p) + m - 1 - j).
  # Integrating m times divides it by the product of the m consecutive
  # exponents below that power. Non-positive factors, which only appear for
  # the trailing constants, are clamped to 1.
  top_exponents = arange(len(p) + m, 0, -1)[np.newaxis, :] - 1
  step_offsets = arange(m)[:, np.newaxis]
  divisors = maximum(1, top_exponents - step_offsets).prod(0)
  return true_divide(concatenate((p, constants)), divisors)
예제 #5
0
파일: linalg.py 프로젝트: xueeinstein/jax
def eigh_tridiagonal(d,
                     e,
                     *,
                     eigvals_only=False,
                     select='a',
                     select_range=None,
                     tol=None):
    """Compute eigenvalues of a symmetric/Hermitian tridiagonal matrix.

    One bisection search runs per requested eigenvalue, all in parallel. Each
    probe uses a Sturm-sequence count of how many eigenvalues lie below the
    probe point.

    Args:
        d: diagonal of the matrix, a 1-D array of length ``n``.
        e: off-diagonal, a 1-D array of length ``n - 1`` with the same dtype
            as ``d``.
        eigvals_only: must be True; eigenvectors are not implemented.
        select: ``'a'`` for all eigenvalues, or ``'i'`` for the eigenvalues
            with ascending indices ``select_range[0]`` through
            ``select_range[1]`` inclusive. ``'v'`` is not implemented.
        select_range: ``(lo, hi)`` index pair, used when ``select == 'i'``.
        tol: optional absolute tolerance. The effective tolerance is the
            larger of ``tol`` and ``eps * t_norm``, where ``t_norm`` is a
            Gershgorin bound on the 2-norm.

    Returns:
        The requested eigenvalues in ascending order, as a real 1-D array.
        For ``n <= 1`` the real part of ``d`` is returned unchanged.

    Raises:
        NotImplementedError: if ``eigvals_only`` is False or
            ``select == 'v'``.
        TypeError: if ``d`` and ``e`` differ in dtype, or the dtype is not
            float32, float64, complex64 or complex128.
        ValueError: for an empty index range or an unknown ``select``.
    """
    if not eigvals_only:
        raise NotImplementedError(
            "Calculation of eigenvectors is not implemented")

    def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
        """Implements the Sturm sequence recurrence.

        Returns, for each entry of ``x``, the number of eigenvalues of the
        tridiagonal matrix that are smaller than that entry.
        """
        n = alpha.shape[0]
        zeros = jnp.zeros(x.shape, dtype=jnp.int32)
        ones = jnp.ones(x.shape, dtype=jnp.int32)

        # The first step in the Sturm sequence recurrence
        # requires special care if x is equal to alpha[0].
        def sturm_step0():
            q = alpha[0] - x
            count = jnp.where(q < 0, ones, zeros)
            q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
            return q, count

        # Subsequent steps all take this form:
        def sturm_step(i, q, count):
            q = alpha[i] - beta_sq[i - 1] / q - x
            count = jnp.where(q <= pivmin, count + 1, count)
            # Pivots at or below pivmin count as negative and are forced to
            # be <= -pivmin, so the next division by q stays bounded.
            q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
            return q, count

        def make_unrolled_steps(num_steps):
            """Return a loop body that applies ``num_steps`` Sturm steps.

            The unroll factor is bound explicitly here rather than read from
            a variable that is reassigned between uses.
            """
            def unrolled_steps(args):
                start, q, count = args
                for j in range(num_steps):
                    q, count = sturm_step(start + j, q, count)
                return start + num_steps, q, count
            return unrolled_steps

        # The first step initializes q and count.
        q, count = sturm_step0()

        # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
        # the bulk of the iterations unrolled by a factor of blocksize.
        blocksize = 16
        peel = (n - 1) % blocksize
        i, q, count = make_unrolled_steps(peel)((1, q, count))

        # Run the remaining steps of the Sturm sequence using a partially
        # unrolled while loop.
        def cond(iqc):
            i, _, _ = iqc
            return jnp.less(i, n)

        _, _, count = lax.while_loop(cond, make_unrolled_steps(blocksize),
                                     (i, q, count))
        return count

    alpha = jnp.asarray(d)
    beta = jnp.asarray(e)
    # NumPy scalar types compare equal to the matching np.dtype instances.
    supported_dtypes = (np.float32, np.float64, np.complex64, np.complex128)
    if alpha.dtype != beta.dtype:
        raise TypeError(
            "diagonal and off-diagonal values must have same dtype, "
            f"got {alpha.dtype} and {beta.dtype}")
    if alpha.dtype not in supported_dtypes or beta.dtype not in supported_dtypes:
        raise TypeError(
            "Only float32, float64, complex64 and complex128 inputs are "
            "supported as inputs to jax.scipy.linalg.eigh_tridiagonal, got "
            f"{alpha.dtype} and {beta.dtype}")
    n = alpha.shape[0]
    if n <= 1:
        return jnp.real(alpha)

    if jnp.issubdtype(alpha.dtype, jnp.complexfloating):
        # Only |e|^2 enters the Sturm recurrence and only |e| the Gershgorin
        # bounds, so the Hermitian case reduces to a real one. Any imaginary
        # part of the diagonal is discarded.
        alpha = jnp.real(alpha)
        beta_sq = jnp.real(beta * jnp.conj(beta))
        beta_abs = jnp.sqrt(beta_sq)
    else:
        beta_abs = jnp.abs(beta)
        beta_sq = jnp.square(beta)

    # Estimate the largest and smallest eigenvalues of T using the Gershgorin
    # circle theorem.
    off_diag_abs_row_sum = jnp.concatenate(
        [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
    lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
    lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)
    # Upper bound on 2-norm of T.
    t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))

    # Compute the smallest allowed pivot in the Sturm sequence to avoid
    # overflow.
    finfo = np.finfo(alpha.dtype)
    one = np.ones([], dtype=alpha.dtype)
    safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
    pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
    alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
    abs_tol = finfo.eps * t_norm
    if tol is not None:
        abs_tol = jnp.maximum(tol, abs_tol)

    # In the worst case, when the absolute tolerance is eps*lambda_est_max and
    # lambda_est_max = -lambda_est_min, we have to take as many bisection steps
    # as there are bits in the mantissa plus 1.
    max_it = finfo.nmant + 1

    # Determine the indices of the desired eigenvalues, based on select and
    # select_range.
    if select == 'a':
        target_counts = jnp.arange(n, dtype=jnp.int32)
    elif select == 'i':
        if select_range[0] > select_range[1]:
            raise ValueError('Got empty index range in select_range.')
        target_counts = jnp.arange(select_range[0],
                                   select_range[1] + 1,
                                   dtype=jnp.int32)
    elif select == 'v':
        # TODO(phawkins): requires dynamic shape support.
        raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
                                  "implemented")
    else:
        raise ValueError(
            f"'select' must have a value in {{'a', 'i', 'v'}}, got {select!r}.")

    # Run binary search for all desired eigenvalues in parallel, starting from
    # an interval slightly wider than the estimated
    # [lambda_est_min, lambda_est_max].
    fudge = 2.1  # We widen the starting Gershgorin interval a bit.
    norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
    lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
    upper = lambda_est_max + norm_slack + fudge * pivmin

    # Pre-broadcast the scalars used in the Sturm sequence for improved
    # performance.
    target_shape = jnp.shape(target_counts)
    lower = jnp.broadcast_to(lower, shape=target_shape)
    upper = jnp.broadcast_to(upper, shape=target_shape)
    mid = 0.5 * (upper + lower)
    pivmin = jnp.broadcast_to(pivmin, target_shape)
    alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

    # Start parallel binary searches.
    def cond(args):
        i, lower, _, upper = args
        return jnp.logical_and(jnp.less(i, max_it),
                               jnp.less(abs_tol, jnp.amax(upper - lower)))

    def body(args):
        i, lower, mid, upper = args
        counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
        # Fewer than target+1 eigenvalues below mid: the target lies above.
        lower = jnp.where(counts <= target_counts, mid, lower)
        upper = jnp.where(counts > target_counts, mid, upper)
        mid = 0.5 * (lower + upper)
        return i + 1, lower, mid, upper

    _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
    return mid