Example #1
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
    if overwrite_data is not None:
        raise NotImplementedError("overwrite_data argument not implemented.")
    if type not in ['constant', 'linear']:
        raise ValueError("Trend type must be 'linear' or 'constant'.")
    data, = _promote_dtypes_inexact(jnp.asarray(data))
    if type == 'constant':
        return data - data.mean(axis, keepdims=True)
    else:
        N = data.shape[axis]
        # bp is static, so we use np operations to avoid pushing to device.
        bp = np.sort(np.unique(np.r_[0, bp, N]))
        if bp[0] < 0 or bp[-1] > N:
            raise ValueError(
                "Breakpoints must be non-negative and less than length of data along given axis."
            )
        data = jnp.moveaxis(data, axis, 0)
        shape = data.shape
        data = data.reshape(N, -1)
        for m in range(len(bp) - 1):
            Npts = bp[m + 1] - bp[m]
            A = jnp.vstack([
                jnp.ones(Npts, dtype=data.dtype),
                jnp.arange(1, Npts + 1, dtype=data.dtype) / Npts
            ]).T
            sl = slice(bp[m], bp[m + 1])
            coef, *_ = linalg.lstsq(A, data[sl])
            data = data.at[sl].add(
                -jnp.matmul(A, coef, precision=lax.Precision.HIGHEST))
        return jnp.moveaxis(data.reshape(shape), 0, axis)
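
This appears to be the implementation behind jax.scipy.signal.detrend; the snippet relies on module-level imports from the JAX source (np, jnp, lax, linalg, and the internal helper _promote_dtypes_inexact) that are not shown. A minimal usage sketch via the public API, assuming jax.scipy.signal exposes detrend with this signature:

import jax.numpy as jnp
from jax.scipy import signal

t = jnp.arange(100.0)
line = 2.0 * t + 3.0
# A pure linear trend is removed almost entirely:
print(jnp.max(jnp.abs(signal.detrend(line))))  # ~0
# type='constant' only subtracts the mean:
print(jnp.allclose(signal.detrend(line, type='constant'), line - line.mean()))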
Example #2
def _roots_no_zeros(p):
    # build companion matrix and find its eigenvalues (the roots)
    if p.size < 2:
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))
    A = diag(ones((p.size - 2, ), p.dtype), -1)
    A = A.at[0, :].set(-p[1:] / p[0])
    return linalg.eigvals(A)
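
A quick check through the public jnp.roots wrapper, which strips zeros before delegating to a helper like this one (eigenvalue order is not guaranteed, so we sort):

import jax.numpy as jnp

# p(x) = x^2 - 3x + 2 = (x - 1)(x - 2)
roots = jnp.roots(jnp.array([1.0, -3.0, 2.0]))
print(jnp.sort(roots.real))  # ~[1. 2.]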
Example #3
def _unique_sorted_mask(ar, axis):
    aux = moveaxis(ar, axis, 0)
    if np.issubdtype(aux.dtype, np.complexfloating):
        # Work around issue in sorting of complex numbers with NaN only in the
        # imaginary component. This can be removed if sorting in this situation
        # is fixed to match numpy.
        aux = where(isnan(aux), _lax_const(aux, np.nan), aux)
    size, *out_shape = aux.shape
    if _prod(out_shape) == 0:
        size = 1
        perm = zeros(1, dtype=int)
    else:
        perm = lexsort(aux.reshape(size, _prod(out_shape)).T[::-1])
    aux = aux[perm]
    if aux.size:
        if dtypes.issubdtype(aux.dtype, np.inexact):
            # This is appropriate for both float and complex due to the documented behavior of np.unique:
            # See https://github.com/numpy/numpy/blob/v1.22.0/numpy/lib/arraysetops.py#L212-L220
            neq = lambda x, y: lax.ne(x, y) & ~(isnan(x) & isnan(y))
        else:
            neq = lax.ne
        mask = ones(size, dtype=bool).at[1:].set(
            any(neq(aux[1:], aux[:-1]), tuple(range(1, aux.ndim))))
    else:
        mask = zeros(size, dtype=bool)
    return aux, mask, perm
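
This helper sorts the array along the moved axis and builds the "a new value starts here" mask that the unique machinery consumes. A usage sketch through the public entry point:

import jax.numpy as jnp

x = jnp.array([3, 1, 2, 1, 3])
vals, idx = jnp.unique(x, return_index=True)
print(vals)  # [1 2 3]
print(idx)   # [1 2 0]: index of the first occurrence of each unique value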
Example #4
def _roots_no_zeros(p):
    # assume: p does not have leading zeros and has length > 1
    p, = _promote_dtypes_inexact(p)

    # build companion matrix and find its eigenvalues (the roots)
    A = diag(ones((p.size - 2, ), p.dtype), -1)
    A = A.at[0, :].set(-p[1:] / p[0])
    roots = linalg.eigvals(A)
    return roots
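
The companion-matrix trick itself is easy to reproduce with public APIs; a sketch for a cubic with known roots (note that the nonsymmetric eigvals in JAX has historically been CPU-only):

import jax.numpy as jnp

# p(x) = x^3 - 6x^2 + 11x - 6 = (x - 1)(x - 2)(x - 3)
p = jnp.array([1.0, -6.0, 11.0, -6.0])
# Companion matrix: ones on the subdiagonal, -p[1:]/p[0] in the first row.
A = jnp.diag(jnp.ones(p.size - 2, p.dtype), -1).at[0, :].set(-p[1:] / p[0])
print(jnp.sort(jnp.linalg.eigvals(A).real))  # ~[1. 2. 3.]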
Example #5
def poly(seq_of_zeros):
    _check_arraylike('poly', seq_of_zeros)
    seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
    seq_of_zeros = atleast_1d(seq_of_zeros)

    sh = seq_of_zeros.shape
    if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
        # import at runtime to avoid circular import
        from jax._src.numpy import linalg
        seq_of_zeros = linalg.eigvals(seq_of_zeros)

    if seq_of_zeros.ndim != 1:
        raise ValueError("input must be 1d or non-empty square 2d array.")

    dt = seq_of_zeros.dtype
    if len(seq_of_zeros) == 0:
        return ones((), dtype=dt)

    a = ones((1, ), dtype=dt)
    for k in range(len(seq_of_zeros)):
        a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full')

    return a
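
poly is the inverse of roots: it expands a list of roots (or the eigenvalues of a square matrix) into monic polynomial coefficients by repeated convolution with (1, -r_k). Usage via the public wrapper:

import jax.numpy as jnp

# Monic polynomial with roots 1, 2, 3:
print(jnp.poly(jnp.array([1.0, 2.0, 3.0])))  # ~[ 1. -6. 11. -6.]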
Example #6
    def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
        """Implements the Sturm sequence recurrence."""
        n = alpha.shape[0]
        zeros = jnp.zeros(x.shape, dtype=jnp.int32)
        ones = jnp.ones(x.shape, dtype=jnp.int32)

        # The first step in the Sturm sequence recurrence
        # requires special care if x is equal to alpha[0].
        def sturm_step0():
            q = alpha[0] - x
            count = jnp.where(q < 0, ones, zeros)
            q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
            return q, count

        # Subsequent steps all take this form:
        def sturm_step(i, q, count):
            q = alpha[i] - beta_sq[i - 1] / q - x
            count = jnp.where(q <= pivmin, count + 1, count)
            q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
            return q, count

        # The first step initializes q and count.
        q, count = sturm_step0()

        # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
        # the bulk of the iterations unrolled by a factor of blocksize.
        blocksize = 16
        i = 1
        peel = (n - 1) % blocksize
        unroll_cnt = peel

        def unrolled_steps(args):
            start, q, count = args
            for j in range(unroll_cnt):
                q, count = sturm_step(start + j, q, count)
            return start + unroll_cnt, q, count

        i, q, count = unrolled_steps((i, q, count))

        # Run the remaining steps of the Sturm sequence using a partially
        # unrolled while loop.
        unroll_cnt = blocksize

        def cond(iqc):
            i, q, count = iqc
            return jnp.less(i, n)

        _, _, count = lax.while_loop(cond, unrolled_steps, (i, q, count))
        return count
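
The Sturm count at a shift x equals the number of eigenvalues of the symmetric tridiagonal matrix lying below x, which is what makes bisection on eigenvalues work. Below is a plain-Python reference of the same recurrence, without the pivmin safeguards or the unrolling, checked against eigvalsh; the matrix is an illustrative choice:

import jax.numpy as jnp

alpha = jnp.array([2.0, 2.0, 2.0, 2.0])   # diagonal
beta = jnp.array([-1.0, -1.0, -1.0])      # off-diagonal
T = jnp.diag(alpha) + jnp.diag(beta, 1) + jnp.diag(beta, -1)

def sturm_count(alpha, beta_sq, x):
    # The number of negative q's equals the number of eigenvalues below x.
    q = alpha[0] - x
    count = int(q < 0)
    for i in range(1, alpha.shape[0]):
        q = alpha[i] - beta_sq[i - 1] / q - x
        count += int(q < 0)
    return count

x = 1.9
print(sturm_count(alpha, beta**2, x))            # 2
print(int(jnp.sum(jnp.linalg.eigvalsh(T) < x)))  # 2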
Example #7
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

  if not compute_uv:
    return (s,), (ds,)

  s_diffs = jnp.square(s_dim) - jnp.square(_T(s_dim))
  # 1. where s_diffs is 0. (i.e. on the diagonal) and 0. everywhere else;
  # equivalent to jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) when the
  # singular values are distinct.
  s_diffs_zeros = jnp.eye(s.shape[-1], dtype=A.dtype)
  F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
  dSS = s_dim * dS  # dS.dot(jnp.diag(s))
  SdS = _T(s_dim) * dS  # jnp.diag(s).dot(dS)

  s_zeros = jnp.ones((), dtype=A.dtype) * (s == 0.)
  s_inv = 1 / (s + s_zeros) - s_zeros
  s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv)
  dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat
  dU = jnp.matmul(U, F * (dSS + _H(dSS)) + dUdV_diag)
  dV = jnp.matmul(V, F * (SdS + _H(SdS)))

  m, n = A.shape[-2:]
  if m > n:
    dU = dU + jnp.matmul(jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut), jnp.matmul(dA, V)) / s_dim
  if n > m:
    dV = dV + jnp.matmul(jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt), jnp.matmul(_H(dA), U)) / s_dim

  return (s, U, Vt), (ds, dU, _H(dV))
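
This rule is what forward-mode differentiation of the SVD primitive dispatches to. A small check on a symmetric matrix, where U = V, so the singular-value tangents ds_i = u_i^T dA v_i are easy to predict by hand:

import jax
import jax.numpy as jnp

A = jnp.array([[3.0, 1.0],
               [1.0, 3.0]])   # singular values 4 and 2
dA = jnp.eye(2)
s, ds = jax.jvp(lambda a: jnp.linalg.svd(a, compute_uv=False), (A,), (dA,))
print(s)   # ~[4. 2.]
print(ds)  # ~[1. 1.]  since ds_i = u_i^T dA v_i and here u_i = v_i, dA = I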
Example #8
def _cofactor_solve(a, b):
    """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
  prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
  prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{n}
  x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
  x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b
  """
    a = _promote_arg_dtypes(jnp.asarray(a))
    b = _promote_arg_dtypes(jnp.asarray(b))
    a_shape = jnp.shape(a)
    b_shape = jnp.shape(b)
    a_ndims = len(a_shape)
    if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
            and b_shape[-2:] == a_shape[-2:]):
        msg = ("The arguments to _cofactor_solve must have shapes "
               "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
        raise ValueError(msg.format(a_shape, b_shape))
    if a_shape[-1] == 1:
        return a[..., 0, 0], b
    # lu contains u in the upper triangular matrix and l in the strict lower
    # triangular matrix.
    # The diagonal of l is set to ones without loss of generality.
    lu, pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
    x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
    lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
    # Compute (partial) determinant, ignoring last diagonal of LU
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
    sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
    # partial_det[:, -1] contains the full determinant and
    # partial_det[:, -2] contains det(u) / u_{nn}.
    partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
    lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
    permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1], ))
    iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1, )))
    # filter out any matrices that are not full rank
    d = jnp.ones(x.shape[:-1], x.dtype)
    d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
    d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
    d = jnp.tile(d[..., None, None], d.ndim * (1, ) + x.shape[-2:])
    x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
    x = x[iotas[:-1] + (permutation, slice(None))]
    x = lax_linalg.triangular_solve(lu,
                                    x,
                                    left_side=True,
                                    lower=True,
                                    unit_diagonal=True)
    x = jnp.concatenate(
        (x[..., :-1, :] * partial_det[..., -1, None, None], x[..., -1:, :]),
        axis=-2)
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
    x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

    return partial_det[..., -1], x
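
The docstring's motivation is visible through grad of det: the gradient is adjugate(a)^T, which _cofactor_solve produces in a form that stays finite even when det(a) = 0. A sketch through the public API:

import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
# grad(det)(A) = det(A) * inv(A)^T = adjugate(A)^T:
print(jax.grad(jnp.linalg.det)(A))                 # [[ 4. -3.] [-2.  1.]]
# Finite even for a rank n-1 matrix, where det(A) * inv(A)^T would be NaN:
print(jax.grad(jnp.linalg.det)(jnp.ones((2, 2))))  # [[ 1. -1.] [-1.  1.]]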
Example #9
def _projector_subspace(P, H, n, rank, maxiter=2):
  """ Decomposes the `n x n` rank `rank` Hermitian projector `P` into
  an `n x rank` isometry `V_minus` such that `P = V_minus @ V_minus.conj().T`
  and an `n x (n - rank)` isometry `V_minus` such that
  -(I - P) = V_plus @ V_plus.conj().T`.

  The subspaces are computed using the naive QR eigendecomposition
  algorithm, which converges very quickly due to the sharp separation
  between the relevant eigenvalues of the projector.

  Args:
    P: A rank-`rank` Hermitian projector into the space of `H`'s
       first `rank` eigenpairs. `P` is padded to NxN.
    H: The aforementioned Hermitian matrix, which is used to track
       convergence.
    n: the true (dynamic) shape of `P`.
    rank: Rank of `P`.
    maxiter: Maximum number of iterations.
  Returns:
    V_minus, V_plus: Isometries into the eigenspaces described in the docstring.
  """
  # Choose an initial guess: the `rank` largest-norm columns of P.
  N, _ = P.shape
  column_norms = jnp_linalg.norm(P, axis=1)
  # `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to NaN.
  column_norms = _mask(column_norms, (n,), jnp.nan)
  sort_idxs = jnp.argsort(column_norms)
  X = P[:, sort_idxs]
  # X = X[:, :rank]
  X = _mask(X, (n, rank))

  H_norm = jnp_linalg.norm(H)
  thresh = 10 * jnp.finfo(X.dtype).eps * H_norm

  # First iteration skips the matmul.
  def body_f_after_matmul(X):
    Q, _ = jnp_linalg.qr(X, mode="complete")
    # V1 = Q[:, :rank]
    # V2 = Q[:, rank:]
    V1 = _mask(Q, (n, rank))
    V2 = _slice(Q, (0, rank), (n, n - rank), (N, N))

    # TODO: might be able to get away with lower precision here
    error_matrix = jnp.dot(V2.conj().T, H)
    error_matrix = jnp.dot(error_matrix, V1)
    error = jnp_linalg.norm(error_matrix) / H_norm
    return V1, V2, error

  def cond_f(args):
    _, _, j, error = args
    still_counting = j < maxiter
    unconverged = error > thresh
    return jnp.logical_and(still_counting, unconverged)[0]

  def body_f(args):
    V1, _, j, _ = args
    X = jnp.dot(P, V1)
    V1, V2, error = body_f_after_matmul(X)
    return V1, V2, j + 1, error

  V1, V2, error = body_f_after_matmul(X)
  one = jnp.ones(1, dtype=jnp.int32)
  V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
  return V1, V2
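
The idea in miniature, using only public APIs and none of the dynamic-shape masking: a rank-r Hermitian projector P built from eigenvectors of H factors as V_minus @ V_minus.conj().T, and a complete QR of P's columns yields both V_minus and its orthogonal complement V_plus, which together block-diagonalize H. The matrix here is an illustrative choice:

import jax.numpy as jnp

H = jnp.array([[2.0, 1.0],
               [1.0, 2.0]])
w, V = jnp.linalg.eigh(H)
P = V[:, :1] @ V[:, :1].T        # rank-1 projector onto the lowest eigenspace
Q, _ = jnp.linalg.qr(P, mode='complete')
V_minus, V_plus = Q[:, :1], Q[:, 1:]
print(jnp.allclose(V_minus @ V_minus.T, P, atol=1e-6))  # True
print(jnp.abs(V_plus.T @ H @ V_minus))                  # ~0: off-diagonal block vanishes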