Example #1
File: linalg.py Project: xueeinstein/jax
def funm(A, func, disp=True):
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected square array_like input')

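  # Schur-Parlett approach: reduce A to complex triangular Schur form, apply
  # func to the eigenvalues on the diagonal, then fill in the strict upper
  # triangle via the Parlett recurrence (_algorithm_11_1_1).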
  T, Z = schur(A)
  T, Z = rsf2csf(T, Z)

  F = jnp.diag(func(jnp.diag(T)))
  F = F.astype(T.dtype.char)

  F, minden = _algorithm_11_1_1(F, T)
  F = Z @ F @ Z.conj().T

  if disp:
    return F

  if F.dtype.char.lower() == 'e':
    tol = jnp.finfo(jnp.float16).eps
  elif F.dtype.char.lower() == 'f':
    tol = jnp.finfo(jnp.float32).eps
  else:
    tol = jnp.finfo(jnp.float64).eps

  minden = jnp.where(minden == 0.0, tol, minden)
  err = jnp.where(jnp.any(jnp.isinf(F)), jnp.inf, jnp.minimum(1, jnp.maximum(
          tol, (tol / minden) * norm(jnp.triu(T, 1), 1))))

  return F, err
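
A minimal usage sketch (assuming this funm is exported as jax.scipy.linalg.funm; the underlying Schur decomposition is CPU-only in JAX). Applying func=jnp.exp computes the matrix exponential, so the result should match expm:

import jax.numpy as jnp
from jax.scipy.linalg import expm, funm

A = jnp.array([[0.0, 1.0], [-1.0, 0.0]])  # generator of a rotation
F = funm(A, jnp.exp)                      # f(A) with f = exp
assert jnp.allclose(F, expm(A), atol=1e-4)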
Example #2
File: linalg.py Project: xueeinstein/jax
def rsf2csf(T, Z, check_finite=True):
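    # check_finite is accepted for SciPy API compatibility; it is ignored here.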
    T = jnp.asarray(T)
    Z = jnp.asarray(Z)

    for ind, X in enumerate([Z, T]):
        if X.ndim != 2 or X.shape[0] != X.shape[1]:
            arg = 'ZT'[ind]
            raise ValueError(f"Input '{arg}' must be square.")
    if T.shape[0] != Z.shape[0]:
        raise ValueError(
            f"Input array shapes must match: Z: {Z.shape} vs. T: {T.shape}")

    T, Z = _promote_dtypes_complex(T, Z)
    eps = jnp.finfo(T.dtype).eps
    N = T.shape[0]

    if N == 1:
        return T, Z

    def _update_T_Z(m, T, Z):
        mu = np_linalg.eigvals(lax.dynamic_slice(T, (m - 1, m - 1),
                                                 (2, 2))) - T[m, m]
        r = np_linalg.norm(jnp.array([mu[0], T[m, m - 1]])).astype(T.dtype)
        c = mu[0] / r
        s = T[m, m - 1] / r
        G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype)

        # T[m-1:m+1, m-1:] = G @ T[m-1:m+1, m-1:]
        T_rows = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=0)
        col_mask = jnp.arange(N) >= m - 1
        G_dot_T_zeroed_cols = G @ jnp.where(col_mask, T_rows, 0)
        T_rows_new = jnp.where(~col_mask, T_rows, G_dot_T_zeroed_cols)
        T = lax.dynamic_update_slice_in_dim(T, T_rows_new, m - 1, axis=0)

        # T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1] @ G.conj().T
        T_cols = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=1)
        row_mask = jnp.arange(N)[:, jnp.newaxis] < m + 1
        T_zeroed_rows_dot_GH = jnp.where(row_mask, T_cols, 0) @ G.conj().T
        T_cols_new = jnp.where(~row_mask, T_cols, T_zeroed_rows_dot_GH)
        T = lax.dynamic_update_slice_in_dim(T, T_cols_new, m - 1, axis=1)

        # Z[:, m-1:m+1] = Z[:, m-1:m+1] @ G.conj().T
        Z_cols = lax.dynamic_slice_in_dim(Z, m - 1, 2, axis=1)
        Z = lax.dynamic_update_slice_in_dim(Z,
                                            Z_cols @ G.conj().T,
                                            m - 1,
                                            axis=1)
        return T, Z

    def _rsf2csf_iter(i, TZ):
        m = N - i
        T, Z = TZ
        T, Z = lax.cond(
            jnp.abs(T[m, m - 1]) > eps *
            (jnp.abs(T[m - 1, m - 1]) + jnp.abs(T[m, m])), _update_T_Z,
            lambda m, T, Z: (T, Z), m, T, Z)
        T = T.at[m, m - 1].set(0.0)
        return T, Z

    return lax.fori_loop(1, N, _rsf2csf_iter, (T, Z))
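
A quick sanity check, as a sketch (assumes schur and rsf2csf are both available from jax.scipy.linalg, with schur being CPU-only): the complex Schur form must be upper triangular and must still reconstruct the original matrix.

import jax.numpy as jnp
from jax.scipy.linalg import schur, rsf2csf

A = jnp.array([[0.0, -2.0], [1.0, 0.0]])  # real matrix, complex eigenvalues
T, Z = schur(A)                           # real Schur form: quasi-triangular
Tc, Zc = rsf2csf(T, Z)                    # complex Schur form: triangular
assert jnp.allclose(Zc @ Tc @ Zc.conj().T, A, atol=1e-4)
assert jnp.allclose(jnp.tril(Tc, -1), 0, atol=1e-4)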
Example #3
def zeta(x, q=None):
    assert q is not None, "Riemann zeta function is not implemented yet."
    # Reference: Johansson, Fredrik.
    # "Rigorous high-precision computation of the Hurwitz zeta function and its derivatives."
    # Numerical Algorithms 69.2 (2015): 253-270.
    # https://arxiv.org/abs/1309.2877 - formula (5)
    # here we keep the same notation as in reference
    s, a = _promote_args_inexact("zeta", x, q)
    dtype = lax.dtype(a).type
    s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1)
    # precision ~ N, M
    N = M = dtype(8) if lax.dtype(a) == jnp.float32 else dtype(16)
    assert M <= len(_BERNOULLI_COEFS)
    k = jnp.expand_dims(np.arange(N, dtype=N.dtype), tuple(range(a.ndim)))
    S = jnp.sum((a_ + k)**-s_, -1)
    I = lax.div((a + N)**(dtype(1) - s), s - dtype(1))
    T0 = (a + N)**-s
    m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
    s_over_a = (s_ + m) / (a_ + N)
    T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
    T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max)
    coefs = np.expand_dims(
        np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
        tuple(range(a.ndim)))
    T1 = T1 / coefs
    T = T0 * (dtype(0.5) + T1.sum(-1))
    return S + I + T
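
A sketch of the intended behavior (assuming this is jax.scipy.special.zeta): for s > 1 the result should agree with a long partial sum of the defining Hurwitz series sum_k (a + k)**-s.

import jax.numpy as jnp
from jax.scipy.special import zeta

s, a = 3.0, 1.0
k = jnp.arange(100000.0)
reference = jnp.sum((a + k) ** -s)   # truncated defining series
assert jnp.allclose(zeta(s, a), reference, rtol=1e-4)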
Example #4
File: linalg.py Project: ahoenselaar/jax
def matrix_rank(M, tol=None):
    M = _promote_arg_dtypes(jnp.asarray(M))
    if M.ndim > 2:
        raise TypeError("array should have 2 or fewer dimensions")
    if M.ndim < 2:
        return jnp.any(M != 0).astype(jnp.int32)
    S = svd(M, full_matrices=False, compute_uv=False)
    if tol is None:
        tol = S.max() * np.max(M.shape) * jnp.finfo(S.dtype).eps
    return jnp.sum(S > tol)
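
A usage sketch (assuming this is the public jnp.linalg.matrix_rank): an outer product has rank 1 by construction, and the identity has full rank.

import jax.numpy as jnp

u = jnp.array([1.0, 2.0, 3.0])
A = jnp.outer(u, u)                        # rank 1 by construction
assert int(jnp.linalg.matrix_rank(A)) == 1
assert int(jnp.linalg.matrix_rank(jnp.eye(3))) == 3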
Example #5
File: linalg.py Project: ahoenselaar/jax
def _lstsq(a, b, rcond, *, numpy_resid=False):
    # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
    # TODO: add custom jvp rule for more robust lstsq differentiation
    a, b = _promote_arg_dtypes(a, b)
    if a.shape[0] != b.shape[0]:
        raise ValueError("Leading dimensions of input arrays must match")
    b_orig_ndim = b.ndim
    if b_orig_ndim == 1:
        b = b[:, None]
    if a.ndim != 2:
        raise TypeError(
            f"{a.ndim}-dimensional array given. Array must be two-dimensional")
    if b.ndim != 2:
        raise TypeError(
            f"{b.ndim}-dimensional array given. Array must be one or two-dimensional"
        )
    m, n = a.shape
    dtype = a.dtype
    if rcond is None:
        rcond = jnp.finfo(dtype).eps * max(n, m)
    else:
        rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
    u, s, vt = svd(a, full_matrices=False)
    mask = s >= rcond * s[0]
    rank = mask.sum()
    safe_s = jnp.where(mask, s, 1)
    s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
    uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
    x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
    # Numpy returns empty residuals in some cases. To allow compilation, we
    # default to returning full residuals in all cases.
    if numpy_resid and (rank < n or m <= n):
        resid = jnp.asarray([])
    else:
        b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
        resid = norm(b - b_estimate, axis=0)**2
    if b_orig_ndim == 1:
        x = x.ravel()
    return x, resid, rank, s
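
A sketch of the public wrapper (assuming _lstsq backs jnp.linalg.lstsq): fitting an exactly consistent system recovers the coefficients with full rank.

import jax.numpy as jnp

a = jnp.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
b = jnp.array([6.0, 9.0, 12.0])            # exactly b = 3 + 3 * t
x, resid, rank, s = jnp.linalg.lstsq(a, b)
assert jnp.allclose(x, jnp.array([3.0, 3.0]), atol=1e-4)
assert int(rank) == 2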
Example #6
def _expn2(n, x):
    # x > 1.
    _c = _constant_like
    BIG = _c(x, 1.44115188075855872e17)
    MACHEP = jnp.finfo(BIG.dtype).eps  # stop once the relative change is below eps
    zero = _c(x, 0.0)
    one = _c(x, 1.0)

    init = dict(
        k=_c(n, 1),
        pkm2=one,
        qkm2=x,
        pkm1=one,
        qkm1=x + n,
        ans=one / (x + n),
        t=_c(x, jnp.inf),
        r=zero,
        x=x,
    )

    def body(d):
        x = d["x"]
        d["k"] += _c(d["k"], 1)
        k = d["k"]
        odd = k % _c(k, 2) == _c(k, 1)
        yk = jnp.where(odd, one, x)
        xk = jnp.where(odd, n + (k - _c(k, 1)) / _c(k, 2), k / _c(k, 2))
        pk = d["pkm1"] * yk + d["pkm2"] * xk
        qk = d["qkm1"] * yk + d["qkm2"] * xk
        nz = qk != zero
        d["r"] = r = jnp.where(nz, pk / qk, d["r"])
        d["t"] = jnp.where(nz, abs((d["ans"] - r) / r), one)
        d["ans"] = jnp.where(nz, r, d["ans"])
        d["pkm2"] = d["pkm1"]
        d["pkm1"] = pk
        d["qkm2"] = d["qkm1"]
        d["qkm1"] = qk
        is_big = abs(pk) > BIG
        for s in "pq":
            for i in "12":
                key = s + "km" + i
                d[key] = jnp.where(is_big, d[key] / BIG, d[key])
        return d

    def cond(d):
        return (d["x"] > _c(d["k"], 0)) & (d["t"] > MACHEP)

    d = lax.while_loop(cond, body, init)
    return d["ans"] * jnp.exp(-x)
Example #7
File: linalg.py Project: ahoenselaar/jax
def pinv(a, rcond=None):
    # Uses same algorithm as
    # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
    a = jnp.conj(a)
    if rcond is None:
        max_rows_cols = max(a.shape[-2:])
        rcond = 10. * max_rows_cols * jnp.finfo(a.dtype).eps
    rcond = jnp.asarray(rcond)
    u, s, vh = svd(a, full_matrices=False)
    # Singular values less than or equal to ``rcond * largest_singular_value``
    # are set to zero.
    cutoff = rcond[..., jnp.newaxis] * jnp.amax(
        s, axis=-1, keepdims=True, initial=-jnp.inf)
    s = jnp.where(s > cutoff, s, jnp.inf)
    res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]))
    return lax.convert_element_type(res, a.dtype)
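
A usage sketch (assuming this is jnp.linalg.pinv): the Moore-Penrose conditions for a tall, full-column-rank matrix.

import jax.numpy as jnp

A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
A_pinv = jnp.linalg.pinv(A)
assert jnp.allclose(A @ A_pinv @ A, A, atol=1e-4)
assert jnp.allclose(A_pinv @ A, jnp.eye(2), atol=1e-4)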
Example #8
def polydiv(u, v, *, trim_leading_zeros=False):
    _check_arraylike("polydiv", u, v)
    u, v = _promote_dtypes_inexact(u, v)
    m = len(u) - 1
    n = len(v) - 1
    scale = 1. / v[0]
    q = zeros(max(m - n + 1, 1), dtype=u.dtype)  # force same dtype
    for k in range(0, m - n + 1):
        d = scale * u[k]
        q = q.at[k].set(d)
        u = u.at[k:k + n + 1].add(-d * v)
    if trim_leading_zeros:
        # Use sqrt(finfo(dtype).eps) to approximate NumPy's absolute tolerance.
        return q, trim_zeros_tol(u, tol=sqrt(finfo(u.dtype).eps), trim='f')
    else:
        return q, u
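
A usage sketch (assuming this backs the public jnp.polydiv; coefficients are highest degree first): dividing x**2 - 1 by x - 1 gives quotient x + 1 and zero remainder.

import jax.numpy as jnp

u = jnp.array([1.0, 0.0, -1.0])  # x**2 - 1
v = jnp.array([1.0, -1.0])       # x - 1
q, r = jnp.polydiv(u, v)
assert jnp.allclose(q, jnp.array([1.0, 1.0]))  # x + 1
assert jnp.allclose(r, 0.0, atol=1e-6)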
Example #9
def _expn1(n, x):
    # exponential integral En
    _c = _constant_like
    x = jnp.array(x)
    MACHEP = jnp.finfo(x.dtype).eps

    zero = _c(x, 0.0)
    one = _c(x, 1.0)
    psi = -jnp.euler_gamma - jnp.log(x)
    psi = lax.fori_loop(_c(n, 1), n, lambda i, psi: psi + one / i, psi)
    n1 = jnp.where(n == _c(n, 1), one + one, n)
    init = dict(
        x=x,
        z=-x,
        xk=zero,
        yk=one,
        pk=one - n,
        ans=jnp.where(n == _c(n, 1), zero, one / (one - n1)),
        t=jnp.inf,
    )

    def body(d):
        d["xk"] += one
        d["yk"] *= d["z"] / d["xk"]
        d["pk"] += one
        d["ans"] += jnp.where(d["pk"] != zero, d["yk"] / d["pk"], zero)
        d["t"] = jnp.where(d["ans"] != zero, abs(d["yk"] / d["ans"]), one)
        return d

    def cond(d):
        return (d["x"] > _c(d["x"], 0.0)) & (d["t"] > MACHEP)

    d = lax.while_loop(cond, body, init)
    t = n
    r = n - _c(n, 1)
    return d["z"]**r * psi / jnp.exp(gammaln(t)) - d["ans"]
Example #10
File: linalg.py Project: ahoenselaar/jax
def norm(x,
         ord=None,
         axis: Union[None, Tuple[int, ...], int] = None,
         keepdims=False):
    x = _promote_arg_dtypes(jnp.asarray(x))
    x_shape = jnp.shape(x)
    ndim = len(x_shape)

    if axis is None:
        # NumPy has an undocumented behavior that admits arbitrary rank inputs if
        # `ord` is None: https://github.com/numpy/numpy/issues/14215
        if ord is None:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
        axis = tuple(range(ndim))
    elif isinstance(axis, tuple):
        axis = tuple(canonicalize_axis(ax, ndim) for ax in axis)
    else:
        axis = (canonicalize_axis(axis, ndim), )

    num_axes = len(axis)
    if num_axes == 1:
        if ord is None or ord == 2:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == jnp.inf:
            return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == -jnp.inf:
            return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == 0:
            return jnp.sum(x != 0,
                           dtype=jnp.finfo(lax.dtype(x)).dtype,
                           axis=axis,
                           keepdims=keepdims)
        elif ord == 1:
            # Numpy has a special case for ord == 1 as an optimization. We don't
            # really need the optimization (XLA could do it for us), but the Numpy
            # code has slightly different type promotion semantics, so we need a
            # special case too.
            return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
        else:
            abs_x = jnp.abs(x)
            ord = lax._const(abs_x, ord)
            out = jnp.sum(abs_x**ord, axis=axis, keepdims=keepdims)
            return jnp.power(out, 1. / ord)

    elif num_axes == 2:
        row_axis, col_axis = cast(Tuple[int, ...], axis)
        if ord is None or ord in ('f', 'fro'):
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == 1:
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == -1:
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == jnp.inf:
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord == -jnp.inf:
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord in ('nuc', 2, -2):
            x = jnp.moveaxis(x, axis, (-2, -1))
            if ord == 2:
                reducer = jnp.amax
            elif ord == -2:
                reducer = jnp.amin
            else:
                reducer = jnp.sum
            y = reducer(svd(x, compute_uv=False), axis=-1)
            if keepdims:
                result_shape = list(x_shape)
                result_shape[axis[0]] = 1
                result_shape[axis[1]] = 1
                y = jnp.reshape(y, result_shape)
            return y
        else:
            raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
    else:
        raise ValueError(
            "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
Example #11
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
    _check_arraylike("polyfit", x, y)
    deg = core.concrete_or_error(int, deg, "deg must be int")
    order = deg + 1
    # check arguments
    if deg < 0:
        raise ValueError("expected deg >= 0")
    if x.ndim != 1:
        raise TypeError("expected 1D vector for x")
    if x.size == 0:
        raise TypeError("expected non-empty vector for x")
    if y.ndim < 1 or y.ndim > 2:
        raise TypeError("expected 1D or 2D array for y")
    if x.shape[0] != y.shape[0]:
        raise TypeError("expected x and y to have same length")

    # set rcond
    if rcond is None:
        rcond = len(x) * finfo(x.dtype).eps
    rcond = core.concrete_or_error(float, rcond, "rcond must be float")
    # set up least squares equation for powers of x
    lhs = vander(x, order)
    rhs = y

    # apply weighting
    if w is not None:
        _check_arraylike("polyfit", w)
        w, = _promote_dtypes_inexact(w)
        if w.ndim != 1:
            raise TypeError("expected a 1-d array for weights")
        if w.shape[0] != y.shape[0]:
            raise TypeError("expected w and y to have the same length")
        lhs *= w[:, np.newaxis]
        if rhs.ndim == 2:
            rhs *= w[:, np.newaxis]
        else:
            rhs *= w

    # scale lhs to improve condition number and solve
    scale = sqrt((lhs * lhs).sum(axis=0))
    lhs /= scale[np.newaxis, :]
    c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
    c = (c.T / scale).T  # broadcast scale coefficients

    if full:
        return c, resids, rank, s, rcond
    elif cov:
        Vbase = linalg.inv(dot(lhs.T, lhs))
        Vbase /= outer(scale, scale)
        if cov == "unscaled":
            fac = 1
        else:
            if len(x) <= order:
                raise ValueError("the number of data points must exceed order "
                                 "to scale the covariance matrix")
            fac = resids / (len(x) - order)
            fac = fac[0]  # reduce the shape-(1,) residuals array to a scalar
        if y.ndim == 1:
            return c, Vbase * fac
        else:
            return c, Vbase[:, :, np.newaxis] * fac
    else:
        return c
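
A usage sketch (assuming this is jnp.polyfit): noiseless samples of a quadratic recover its coefficients, highest degree first.

import jax.numpy as jnp

t = jnp.linspace(-1.0, 1.0, 20)
y = 2.0 * t**2 - 3.0 * t + 1.0
coefs = jnp.polyfit(t, y, 2)
assert jnp.allclose(coefs, jnp.array([2.0, -3.0, 1.0]), atol=1e-4)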
Example #12
File: eigh.py Project: cloudhan/jax
def _projector_subspace(P, H, n, rank, maxiter=2):
  """ Decomposes the `n x n` rank `rank` Hermitian projector `P` into
  an `n x rank` isometry `V_minus` such that `P = V_minus @ V_minus.conj().T`
  and an `n x (n - rank)` isometry `V_minus` such that
  -(I - P) = V_plus @ V_plus.conj().T`.

  The subspaces are computed using the naiive QR eigendecomposition
  algorithm, which converges very quickly due to the sharp separation
  between the relevant eigenvalues of the projector.

  Args:
    P: A rank-`rank` Hermitian projector into the space of `H`'s
       first `rank` eigenpairs. `P` is padded to NxN.
    H: The aforementioned Hermitian matrix, which is used to track
       convergence.
    n: the true (dynamic) shape of `P`.
    rank: Rank of `P`.
    maxiter: Maximum number of iterations.
  Returns:
    V_minus, V_plus: Isometries into the eigenspaces described in the docstring.
  """
  # Choose an initial guess: the `rank` largest-norm columns of P.
  N, _ = P.shape
  # Negate the norms so that the ascending argsort puts the largest first.
  negative_column_norms = -jnp_linalg.norm(P, axis=1)
  # `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to NaN.
  negative_column_norms = _mask(negative_column_norms, (n,), jnp.nan)
  sort_idxs = jnp.argsort(negative_column_norms)
  X = P[:, sort_idxs]
  # X = X[:, :rank]
  X = _mask(X, (n, rank))

  H_norm = jnp_linalg.norm(H)
  thresh = 10 * jnp.finfo(X.dtype).eps * H_norm

  # First iteration skips the matmul.
  def body_f_after_matmul(X):
    Q, _ = jnp_linalg.qr(X, mode="complete")
    # V1 = Q[:, :rank]
    # V2 = Q[:, rank:]
    V1 = _mask(Q, (n, rank))
    V2 = _slice(Q, (0, rank), (n, n - rank), (N, N))

    # TODO: might be able to get away with lower precision here
    error_matrix = jnp.dot(V2.conj().T, H)
    error_matrix = jnp.dot(error_matrix, V1)
    error = jnp_linalg.norm(error_matrix) / H_norm
    return V1, V2, error

  def cond_f(args):
    _, _, j, error = args
    still_counting = j < maxiter
    unconverged = error > thresh
    return jnp.logical_and(still_counting, unconverged)[0]

  def body_f(args):
    V1, _, j, _ = args
    X = jnp.dot(P, V1)
    V1, V2, error = body_f_after_matmul(X)
    return V1, V2, j + 1, error

  V1, V2, error = body_f_after_matmul(X)
  one = jnp.ones(1, dtype=jnp.int32)
  V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
  return V1, V2
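
A standalone sketch of the idea (not calling the internal helper, whose _mask and _slice utilities are private): a complete QR factorization of independent columns of a rank-k projector yields orthonormal bases V1 for its range and V2 for the complement, and V2^H @ H @ V1 vanishes once those ranges align with eigenspaces of H.

import jax.numpy as jnp

H = jnp.array([[2.0, 1.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 5.0]])
_, v = jnp.linalg.eigh(H)
P = v[:, :2] @ v[:, :2].conj().T          # projector onto the first 2 eigenpairs
Q, _ = jnp.linalg.qr(P[:, :2], mode="complete")
V1, V2 = Q[:, :2], Q[:, 2:]
assert jnp.allclose(V2.conj().T @ H @ V1, 0.0, atol=1e-5)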