示例#1
0
文件: linalg.py 项目: xueeinstein/jax
def _slogdet_jvp(primals, tangents):
  """JVP rule for `slogdet`.

  Args:
    primals: one-element tuple holding the input `x`.
    tangents: one-element tuple holding the tangent `g` of `x`.

  Returns:
    `((sign, ans), (sign_dot, ans_dot))`, i.e. the outputs of `slogdet(x)`
    followed by their tangents.
  """
  (x,) = primals
  (g,) = tangents
  sign, ans = slogdet(x)
  # The directional derivative is the trace of solve(x, g).
  trace_dot = jnp.trace(solve(x, g), axis1=-1, axis2=-2)
  if not jnp.issubdtype(jnp._dtype(x), jnp.complexfloating):
    # Non-complex input: the sign tangent is identically zero.
    return (sign, ans), (jnp.zeros_like(sign), trace_dot)
  # Complex input: the real part of the trace drives the log-magnitude and
  # the remainder (scaled by `sign`) drives the sign.
  real_part = jnp.real(trace_dot)
  sign_dot = (trace_dot - real_part.astype(trace_dot.dtype)) * sign
  return (sign, ans), (sign_dot, real_part)
示例#2
0
文件: linalg.py 项目: varun-alla/jax
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  """JVP rule for the SVD primitive.

  The primal is always evaluated with `full_matrices=False, compute_uv=True`
  because the tangent formulas need the thin factors.

  NOTE(review): `_H` is assumed to be the conjugate (Hermitian) transpose
  helper and `_T` the plain transpose helper; both are defined outside this
  block. Conjugate transposes are used on all matrix-valued tangent terms so
  that the rule also holds for complex `A`; for real `A` the two coincide and
  the extra diagonal term below vanishes.

  Args:
    primals: one-element tuple holding the matrix `A`.
    tangents: one-element tuple holding the tangent `dA`.
    full_matrices: whether full-size factors were requested.
    compute_uv: whether the singular vectors were requested.

  Returns:
    `((s,), (ds,))` when `compute_uv` is false, otherwise
    `((s, U, Vt), (ds, dU, dVt))`.

  Raises:
    NotImplementedError: if both `compute_uv` and `full_matrices` are set.
  """
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  k = s.shape[-1]
  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

  if not compute_uv:
    return (s,), (ds,)

  # Reciprocal of s_j^2 - s_i^2 off the diagonal and zero on the diagonal;
  # adding the identity first keeps the diagonal away from a division by zero.
  eye_k = jnp.eye(k, dtype=A.dtype)
  F = 1 / (jnp.square(s_dim) - jnp.square(_T(s_dim)) + eye_k)
  F = F - eye_k
  dSS = s_dim * dS  # dS scaled column-wise by s
  SdS = _T(s_dim) * dS  # dS scaled row-wise by s

  # Safe elementwise inverse of s: zero wherever s is exactly zero.
  s_zeros = jnp.ones((), dtype=A.dtype) * (s == 0.)
  s_inv = 1 / (s + s_zeros) - s_zeros
  s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv)
  # Anti-Hermitian part of the diagonal of dS. It is zero for real inputs and
  # carries the imaginary diagonal contribution for complex inputs.
  dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat
  dU = jnp.matmul(U, F * (dSS + _H(dSS)) + dUdV_diag)
  dV = jnp.matmul(V, F * (SdS + _H(SdS)))

  m, n = A.shape[-2:]
  # Rectangular inputs: add the tangent component outside the span of the
  # thin factor.
  if m > n:
    dU = dU + jnp.matmul(jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut), jnp.matmul(dA, V)) / s_dim
  if n > m:
    dV = dV + jnp.matmul(jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt), jnp.matmul(_H(dA), U)) / s_dim
  return (s, U, Vt), (ds, dU, _H(dV))
示例#3
0
文件: signal.py 项目: jbampton/jax
def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
        detrend='constant', return_onesided=True, scaling='density',
        axis=-1, average='mean'):
  """Estimate the cross spectral density of `x` and `y`.

  All spectral parameters are forwarded unchanged to `_spectral_helper`
  (called with `mode='psd'`); the per-segment results along the last axis
  are then combined according to `average`.

  Args:
    x, y: input signals; `y` may be None.
    fs, window, nperseg, noverlap, nfft, detrend, return_onesided, scaling,
      axis: passed through to `_spectral_helper`.
    average: 'mean' or 'median', the reduction applied over segments.

  Returns:
    `(freqs, Pxy)`.

  Raises:
    ValueError: if more than one segment is present and `average` is neither
      'median' nor 'mean'.
  """
  freqs, _, Pxy = _spectral_helper(x, y, fs, window, nperseg, noverlap, nfft,
                                   detrend, return_onesided, scaling, axis,
                                   mode='psd')
  if y is not None:
    # Force a complex result whenever a second signal was supplied.
    Pxy = Pxy + 0j

  # Nothing to average for rank < 2 or empty results.
  if not (Pxy.ndim >= 2 and Pxy.size > 0):
    return freqs, Pxy

  num_segments = Pxy.shape[-1]
  if not num_segments > 1:
    # A single segment: just drop the trailing segment axis.
    return freqs, jnp.reshape(Pxy, Pxy.shape[:-1])

  if average == 'median':
    bias = signal_helper._median_bias(num_segments).astype(Pxy.dtype)
    if jnp.iscomplexobj(Pxy):
      # Take the median of the real and imaginary parts separately.
      real_median = jnp.median(jnp.real(Pxy), axis=-1)
      imag_median = jnp.median(jnp.imag(Pxy), axis=-1)
      Pxy = real_median + 1j * imag_median
    else:
      Pxy = jnp.median(Pxy, axis=-1)
    return freqs, Pxy / bias
  if average == 'mean':
    return freqs, Pxy.mean(axis=-1)
  raise ValueError(f'average must be "median" or "mean", got {average}')
示例#4
0
    def body(k, state):
        """Run one pivoted elimination step on column `k`.

        `m`, `n` and `ops` come from the enclosing scope. NOTE(review): this
        looks like the loop body of an unblocked LU factorization with partial
        pivoting; the enclosing function is not visible here.

        Args:
          k: index of the column being eliminated.
          state: tuple `(pivot, perm, a)`.

        Returns:
          The updated `(pivot, perm, a)` tuple.
        """
        pivot, perm, a = state
        row_ids = jnp.arange(m)
        col_ids = jnp.arange(n)

        # Pick the pivot row: largest magnitude in column k among rows >= k.
        # Complex entries are ranked by |re| + |im|.
        column = a[:, k]
        if jnp.issubdtype(a.dtype, jnp.complexfloating):
            magnitude = jnp.abs(jnp.real(column)) + jnp.abs(jnp.imag(column))
        else:
            magnitude = jnp.abs(column)
        candidates = jnp.where(row_ids >= k, magnitude, -jnp.inf)
        i = jnp.argmax(candidates)
        pivot = ops.index_update(pivot, ops.index[k], i)

        # Swap rows k and i of the matrix, and the matching permutation entries.
        a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])
        perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

        # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
        x = a[k, k]
        scaled_column = jnp.where(row_ids > k, a[:, k] / x, a[:, k])
        a = ops.index_update(a, ops.index[:, k], scaled_column)

        # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
        trailing = (row_ids[:, None] > k) & (col_ids > k)
        rank_one = jnp.where(trailing, jnp.outer(a[:, k], a[k, :]),
                             jnp.array(0, dtype=a.dtype))
        return pivot, perm, a - rank_one
示例#5
0
    def recursive_case(B, offset, b, agenda, blocks, eigenvectors):
        """Split one block around its median diagonal entry and queue both parts.

        The recursive case of the algorithm, specialized to a static block
        size of `B`. `n`, `N` and the helpers `_slice`, `_mask`,
        `_update_slice`, `split_spectrum` and `_Subproblem` come from outside
        this function.

        Args:
          B: static block size used for slicing.
          offset: starting offset of the block.
          b: dynamic size of the block.
          agenda: work list; two new subproblems are pushed onto it.
          blocks: storage holding the current blocks.
          eigenvectors: storage holding the accumulated eigenvectors.

        Returns:
          The updated `(agenda, blocks, eigenvectors)`.
        """
        H = _slice(blocks, (offset, 0), (b, b), (B, B))
        V = _slice(eigenvectors, (0, offset), (n, b), (N, B))

        # Split at the median of the real diagonal, with entries beyond `b`
        # masked to NaN so that nanmedian skips them.
        # TODO: Improve this?
        masked_diagonal = _mask(jnp.diag(jnp.real(H)), (b, ), jnp.nan)
        split_point = jnp.nanmedian(masked_diagonal)
        H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(
            H, b, split_point, V0=V)

        plus_offset = offset + rank
        plus_size = b - rank

        # Write the two sub-blocks back, "minus" part first.
        blocks = _update_slice(blocks, H_minus, (offset, 0), (rank, rank))
        blocks = _update_slice(blocks, H_plus, (plus_offset, 0),
                               (plus_size, plus_size))
        eigenvectors = _update_slice(eigenvectors, V_minus, (0, offset),
                                     (n, rank))
        eigenvectors = _update_slice(eigenvectors, V_plus, (0, plus_offset),
                                     (n, plus_size))

        # Queue both halves; the "minus" half is pushed last.
        agenda = agenda.push(_Subproblem(plus_offset, plus_size))
        agenda = agenda.push(_Subproblem(offset, rank))
        return agenda, blocks, eigenvectors
示例#6
0
def eigh_jvp_rule(primals, tangents, lower):
    """JVP rule for `eigh`, valid when the eigenvalues are distinct.

    This is classic nondegenerate perturbation theory; see also
    https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    The degenerate case is considerably more involved and is not handled
    here. For it, see degenerate perturbation theory in physics, or
    https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
    https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf

    Args:
      primals: one-element tuple holding the matrix `a`.
      tangents: one-element tuple holding the tangent `a_dot`.
      lower: forwarded to the `eigh` primitive.

    Returns:
      `((v, w_real), (dv, dw))`: eigenvectors and eigenvalues, followed by
      their tangents.
    """
    (a,) = primals
    (a_dot,) = tangents

    v, w_real = eigh_p.bind(symmetrize(a), lower=lower)

    # For complex inputs the eigenvalues must carry the full dtype of v and a.
    w = w_real.astype(a.dtype)
    identity = jnp.eye(a.shape[-1], dtype=a.dtype)
    # Reciprocal eigenvalue-difference matrix. The identity is added before
    # inverting (and subtracted afterwards) so the diagonal never produces NaN.
    shifted_gaps = identity + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]
    Fmat = jnp.reciprocal(shifted_gaps) - identity
    # The eigh impl doesn't support batch dims, but future-proof the grad.
    matmul = lax.dot if a.ndim == 2 else lax.batch_matmul
    dot = partial(matmul, precision=lax.Precision.HIGHEST)
    vdag_adot_v = dot(dot(_H(v), a_dot), v)
    dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
    dv = dot(v, jnp.multiply(Fmat, vdag_adot_v))
    return (v, w_real), (dv, dw)
示例#7
0
文件: linalg.py 项目: gnecula/jax
def _lu(a, permute_l):
    """LU-factorize a 2-D array and return the factors as explicit matrices.

    Args:
      a: array-like input; it is dtype-promoted and must unpack to a
        two-element shape `(m, n)`.
      permute_l: if true, return `(p @ l, u)`; otherwise `(p, l, u)`.

    Returns:
      `(p, l, u)` or `(p @ l, u)`, where `l` is unit lower-trapezoidal of
      shape `(m, k)`, `u` is upper-trapezoidal of shape `(k, n)` and
      `k = min(m, n)`.
    """
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    packed, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)
    k = min(m, n)

    # One-hot encode the permutation vector as a matrix in the input dtype.
    one_hot = permutation == jnp.arange(m)[:, None]
    p = jnp.real(jnp.array(one_hot, dtype=dtype))
    # The strict lower triangle holds L (unit diagonal implied); the upper
    # triangle holds U.
    l = jnp.tril(packed, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(packed)[:k, :]

    if not permute_l:
        return p, l, u
    return jnp.matmul(p, l), u
示例#8
0
文件: linalg.py 项目: xueeinstein/jax
def _lu(a, permute_l):
    """LU-factorize a 2-D array and return the factors as explicit matrices.

    Args:
      a: array-like input; it is promoted to an inexact dtype and must unpack
        to a two-element shape `(m, n)`.
      permute_l: if true, return `(p @ l, u)`; otherwise `(p, l, u)`.

    Returns:
      `(p, l, u)` or `(p @ l, u)`, where `l` is unit lower-trapezoidal of
      shape `(m, k)`, `u` is upper-trapezoidal of shape `(k, n)` and
      `k = min(m, n)`.
    """
    (a,) = _promote_dtypes_inexact(jnp.asarray(a))
    packed, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)
    k = min(m, n)

    # One-hot encode the permutation vector as an (m, m) matrix.
    row_ids = jnp.arange(m, dtype=permutation.dtype)[:, None]
    one_hot = permutation[None, :] == row_ids
    p = jnp.real(jnp.array(one_hot, dtype=dtype))
    # The strict lower triangle holds L (unit diagonal implied); the upper
    # triangle holds U.
    l = jnp.tril(packed, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(packed)[:k, :]

    if not permute_l:
        return p, l, u
    return jnp.matmul(p, l), u
示例#9
0
文件: linalg.py 项目: cloudhan/jax
def _slogdet_lu(a):
    """Compute the sign and log-magnitude of a determinant via LU.

    Args:
      a: array whose last two axes form the matrices to factor.

    Returns:
      `(sign, logdet)`. When any diagonal entry of the LU factor is exactly
      zero, `sign` is 0 and `logdet` is `-inf`.
    """
    dtype = lax.dtype(a)
    lu, pivot, _ = lax_linalg.lu(a)
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    singular = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)

    # Every pivot entry that differs from its own index is one row swap.
    identity_perm = lax.expand_dims(jnp.arange(a.shape[-1]),
                                    range(pivot.ndim - 1))
    parity = jnp.count_nonzero(pivot != identity_perm, axis=-1)

    if jnp.iscomplexobj(a):
        # Complex: the phase is the product of the unit-modulus diagonal.
        phase = jnp.prod(diag / jnp.abs(diag), axis=-1)
    else:
        # Real: each negative diagonal entry flips the sign once more.
        phase = jnp.array(1, dtype=dtype)
        parity = parity + jnp.count_nonzero(diag < 0, axis=-1)

    # Maps even parity to +1 and odd parity to -1.
    flip = jnp.array(-2 * (parity % 2) + 1, dtype=dtype)
    sign = jnp.where(singular, jnp.array(0, dtype=dtype), phase * flip)
    log_abs = jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)
    logdet = jnp.where(singular, jnp.array(-jnp.inf, dtype=dtype), log_abs)
    return sign, jnp.real(logdet)
示例#10
0
def qr_jvp_rule(primals, tangents, full_matrices):
    """JVP rule for the QR primitive (reduced factorization, m >= n only).

    See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.

    Args:
      primals: one-element tuple holding the input `x`.
      tangents: one-element tuple holding the tangent `dx`.
      full_matrices: whether the full factorization was requested.

    Returns:
      `((q, r), (dq, dr))`.

    Raises:
      NotImplementedError: if `full_matrices` is set or `x` has fewer rows
        than columns.
    """
    (x,) = primals
    (dx,) = tangents
    q, r = qr_p.bind(x, full_matrices=False)
    *_, m, n = x.shape
    if full_matrices or m < n:
        raise NotImplementedError(
            "Unimplemented case of QR decomposition derivative")

    dx_rinv = triangular_solve(r, dx)  # Right side solve by default
    qt_dx_rinv = jnp.matmul(_H(q), dx_rinv)
    strict_lower = jnp.tril(qt_dx_rinv, -1)
    skew = strict_lower - _H(strict_lower)  # This is skew-symmetric
    # Complex inputs additionally need the non-real part of the diagonal.
    non_real = qt_dx_rinv - jnp.real(qt_dx_rinv)
    do = skew + jnp.eye(n, dtype=skew.dtype) * non_real

    dq = jnp.matmul(q, do - qt_dx_rinv) + dx_rinv
    dr = jnp.matmul(qt_dx_rinv - do, r)
    return (q, r), (dq, dr)
示例#11
0
def eigh(H,
         *,
         precision="float32",
         termination_size=256,
         n=None,
         sort_eigenvalues=True):
    """Compute the eigendecomposition of the symmetric/Hermitian matrix H.

    Args:
      H: The `n x n` Hermitian input, padded to `N x N`.
      precision: Value handed to `jax.default_matmul_precision` for the
        duration of the main computation.
      termination_size: Inputs whose size `N` does not exceed this value are
        delegated directly to `lax_linalg.eigh_jacobi`.
      n: The true (dynamic) size of the matrix; defaults to `N`.
      sort_eigenvalues: If `True`, the eigenvalues will be sorted from lowest
        to highest.

    Returns:
      vals: The `n` eigenvalues of `H`.
      vecs: A unitary matrix such that `vecs[:, i]` is a normalized
        eigenvector of `H` corresponding to `vals[i]`. We have
        `H @ vecs = vals * vecs` up to numerical error.
      For inputs no larger than `termination_size`, whatever
      `lax_linalg.eigh_jacobi` returns is passed through unchanged.

    Raises:
      TypeError: if `H` is not square.
    """
    M, N = H.shape
    if M != N:
        raise TypeError(f"Input H of shape {H.shape} must be square.")

    small_enough = N <= termination_size
    if small_enough:
        if n is not None:
            # Replace everything outside the leading (n, n) block by the
            # identity before handing the matrix to the Jacobi solver.
            H = _mask(H, (n, n), jnp.eye(N, dtype=H.dtype))
        return lax_linalg.eigh_jacobi(H, sort_eigenvalues=sort_eigenvalues)

    # TODO(phawkins): consider rounding N up to a larger size to maximize reuse
    # between matrices.

    if n is None:
        n = N
    with jax.default_matmul_precision(precision):
        eig_vals, eig_vecs = _eigh_work(
            H, n, termination_size=termination_size)
    # Entries beyond the true size are filled with NaN.
    eig_vals = _mask(jnp.real(eig_vals), (n, ), jnp.nan)
    if not sort_eigenvalues:
        return eig_vals, eig_vecs

    order = jnp.argsort(eig_vals)
    return eig_vals[order], eig_vecs[:, order]
示例#12
0
def _sph_harm(m: jnp.ndarray, n: jnp.ndarray, theta: jnp.ndarray,
              phi: jnp.ndarray, n_max: int) -> jnp.ndarray:
    """Evaluate spherical harmonics of order `m` and degree `n`.

    Args:
      m: orders; negative entries are handled via conjugation.
      n: degrees, used to index the associated Legendre table.
      theta: angle multiplied by `|m|` in the complex exponential.
      phi: angle whose cosine is fed to the associated Legendre functions.
      n_max: forwarded to `_gen_associated_legendre`.

    Returns:
      Complex array of harmonic values.
    """
    abs_m = abs(m)

    # Associated Legendre values at cos(phi), picked per (|m|, n, sample).
    legendre = _gen_associated_legendre(n_max, jnp.cos(phi), True)
    legendre_val = legendre[abs_m, n, jnp.arange(len(n))]

    # exp(i * |m| * theta), assembled from its real and imaginary parts.
    angle = abs_m * theta
    vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
    positive_order = lax.complex(legendre_val * jnp.real(vandermonde),
                                 legendre_val * jnp.imag(vandermonde))

    # Negative order: (-1)^|m| times the conjugate of the positive-order value.
    negative_order = (-1.0)**abs_m * jnp.conjugate(positive_order)
    return jnp.where(m < 0, negative_order, positive_order)
示例#13
0
文件: linalg.py 项目: ahoenselaar/jax
def slogdet(a):
    """Compute the sign and log-magnitude of the determinant of `a`.

    Args:
      a: array-like of shape `[..., n, n]`.

    Returns:
      `(sign, logdet)`. When any diagonal entry of the LU factor is exactly
      zero, `sign` is 0 and `logdet` is `-inf`.

    Raises:
      ValueError: if `a` has fewer than two dimensions or its last two
        dimensions differ.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    dtype = lax.dtype(a)
    a_shape = jnp.shape(a)
    square = len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]
    if not square:
        msg = "Argument to slogdet() must have shape [..., n, n], got {}"
        raise ValueError(msg.format(a_shape))

    lu, pivot, _ = lax_linalg.lu(a)
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    singular = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
    # Every pivot entry that differs from its own index is one row swap.
    parity = jnp.count_nonzero(pivot != jnp.arange(a_shape[-1]), axis=-1)

    if jnp.iscomplexobj(a):
        # Complex: the phase is the product of the unit-modulus diagonal.
        phase = jnp.prod(diag / jnp.abs(diag), axis=-1)
    else:
        # Real: each negative diagonal entry flips the sign once more.
        phase = jnp.array(1, dtype=dtype)
        parity = parity + jnp.count_nonzero(diag < 0, axis=-1)

    # Maps even parity to +1 and odd parity to -1.
    flip = jnp.array(-2 * (parity % 2) + 1, dtype=dtype)
    sign = jnp.where(singular, jnp.array(0, dtype=dtype), phase * flip)
    log_abs = jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)
    logdet = jnp.where(singular, jnp.array(-jnp.inf, dtype=dtype), log_abs)
    return sign, jnp.real(logdet)
示例#14
0
文件: linalg.py 项目: nbswords/jax
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  """JVP rule for the SVD primitive.

  The primal is always evaluated with `full_matrices=False, compute_uv=True`
  because the tangent formulas need the thin factors.

  Args:
    primals: one-element tuple holding the matrix `A`.
    tangents: one-element tuple holding the tangent `dA`.
    full_matrices: whether full-size factors were requested.
    compute_uv: whether the singular vectors were requested.

  Returns:
    `((s,), (ds,))` when `compute_uv` is false, otherwise
    `((s, U, Vt), (ds, dU, dVt))`.

  Raises:
    NotImplementedError: if both `compute_uv` and `full_matrices` are set.
  """
  (A,) = primals
  (dA,) = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  Ut = _H(U)
  V = _H(Vt)
  s_row = s[..., None, :]
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

  if not compute_uv:
    return (s,), (ds,)

  s_col = _T(s_row)
  # Reciprocal of s_j^2 - s_i^2 off the diagonal and zero on the diagonal;
  # adding the identity first keeps the diagonal away from a division by zero.
  gaps = jnp.square(s_row) - jnp.square(s_col)
  eye_k = jnp.eye(s.shape[-1], dtype=A.dtype)
  F = 1 / (gaps + eye_k) - eye_k
  dSS = s_row * dS  # dS.dot(jnp.diag(s))
  SdS = s_col * dS  # jnp.diag(s).dot(dS)

  # Safe elementwise inverse of s: zero wherever s is exactly zero.
  zero_mask = jnp.ones((), dtype=A.dtype) * (s == 0.)
  s_inv = 1 / (s + zero_mask) - zero_mask
  s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv)
  dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat
  dU = jnp.matmul(U, F * (dSS + _H(dSS)) + dUdV_diag)
  dV = jnp.matmul(V, F * (SdS + _H(SdS)))

  m, n = A.shape[-2:]
  # Rectangular inputs: add the tangent component outside the span of the
  # thin factor.
  if m > n:
    u_complement = jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut)
    dU = dU + jnp.matmul(u_complement, jnp.matmul(dA, V)) / s_row
  if n > m:
    v_complement = jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt)
    dV = dV + jnp.matmul(v_complement, jnp.matmul(_H(dA), U)) / s_row

  return (s, U, Vt), (ds, dU, _H(dV))
示例#15
0
def split_spectrum(H, n, split_point, V0=None):
    """Split the Hermitian matrix `H` into two blocks around `split_point`.

    `H_minus` and `H_plus` share the eigenspaces of `H` beneath and above
    `split_point` respectively.

    Also returns `V_minus` and `V_plus`, isometries such that
    `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are returned
    instead; this allows the overall isometries mapping from an initial input
    matrix to progressively smaller blocks to be formed.

    Args:
      H: The Hermitian matrix to split.
      n: Dynamic size of `H`; passed to `qdwh` as `dynamic_shape=(n, n)` and
        used for masking.
      split_point: The eigenvalue to split along.
      V0: Matrix of isometries to be updated.

    Returns:
      H_minus: A Hermitian matrix sharing the eigenvalues of `H` beneath
        `split_point`.
      V_minus: An isometry from the input space of `V0` to `H_minus`.
      H_plus: A Hermitian matrix sharing the eigenvalues of `H` above
        `split_point`.
      V_plus: An isometry from the input space of `V0` to `H_plus`.
      rank: The dynamic size of the `H_minus` block.
    """
    N, _ = H.shape
    shift = split_point * jnp.eye(N, dtype=split_point.dtype)
    H_shift = H - shift.astype(H.dtype)
    U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n))

    # P = -(U - I) / 2, with the identity masked to the dynamic (n, n) block.
    masked_identity = _mask(jnp.eye(N, dtype=H.dtype), (n, n))
    P = -0.5 * (U - masked_identity)
    # The rounded real trace of P gives the size of the "minus" block.
    rank = jnp.round(jnp.trace(jnp.real(P))).astype(jnp.int32)

    V_minus, V_plus = _projector_subspace(P, H, n, rank)
    H_minus = (V_minus.conj().T @ H) @ V_minus
    H_plus = (V_plus.conj().T @ H) @ V_plus
    if V0 is None:
        return H_minus, V_minus, H_plus, V_plus, rank
    return H_minus, jnp.dot(V0, V_minus), H_plus, jnp.dot(V0, V_plus), rank
示例#16
0
文件: linalg.py 项目: ahoenselaar/jax
def norm(x,
         ord=None,
         axis: Union[None, Tuple[int, ...], int] = None,
         keepdims=False):
    """Compute a vector or matrix norm of `x`.

    Args:
      x: array-like input; it is dtype-promoted before use.
      ord: order of the norm. With one axis: None, 2, inf, -inf, 0, 1 or any
        other value usable as an exponent. With two axes: None, 'f', 'fro',
        1, -1, inf, -inf, 'nuc', 2 or -2.
      axis: None, an int, or a tuple of ints selecting the axes to reduce.
        With None and `ord` None, the norm is taken over all elements;
        with None and `ord` given, all axes of `x` are used.
      keepdims: whether reduced axes are kept with size one.

    Returns:
      The norm values as an array.

    Raises:
      ValueError: if `ord` is not a valid matrix norm order, or if the number
        of axes is neither one nor two.
    """
    x = _promote_arg_dtypes(jnp.asarray(x))
    x_shape = jnp.shape(x)
    ndim = len(x_shape)

    # Normalize `axis` to a tuple of canonical (non-negative) axis indices.
    if axis is None:
        # NumPy has an undocumented behavior that admits arbitrary rank inputs if
        # `ord` is None: https://github.com/numpy/numpy/issues/14215
        if ord is None:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
        axis = tuple(range(ndim))
    elif isinstance(axis, tuple):
        axis = tuple(canonicalize_axis(x, ndim) for x in axis)
    else:
        axis = (canonicalize_axis(axis, ndim), )

    num_axes = len(axis)
    if num_axes == 1:
        # Vector norms along a single axis.
        if ord is None or ord == 2:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == jnp.inf:
            return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == -jnp.inf:
            return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == 0:
            # Count of nonzero entries, accumulated in a floating-point dtype.
            return jnp.sum(x != 0,
                           dtype=jnp.finfo(lax.dtype(x)).dtype,
                           axis=axis,
                           keepdims=keepdims)
        elif ord == 1:
            # Numpy has a special case for ord == 1 as an optimization. We don't
            # really need the optimization (XLA could do it for us), but the Numpy
            # code has slightly different type promotion semantics, so we need a
            # special case too.
            return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
        else:
            # General p-norm: (sum |x|^ord)^(1/ord).
            abs_x = jnp.abs(x)
            ord = lax._const(abs_x, ord)
            out = jnp.sum(abs_x**ord, axis=axis, keepdims=keepdims)
            return jnp.power(out, 1. / ord)

    elif num_axes == 2:
        # Matrix norms over a pair of axes.
        row_axis, col_axis = cast(Tuple[int, ...], axis)
        if ord is None or ord in ('f', 'fro'):
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == 1:
            # Largest absolute column sum. Without keepdims the first
            # reduction removes an axis, so the second axis index may shift.
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == -1:
            # Smallest absolute column sum.
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == jnp.inf:
            # Largest absolute row sum.
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord == -jnp.inf:
            # Smallest absolute row sum.
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord in ('nuc', 2, -2):
            # Norms defined through singular values: move the matrix axes
            # last, then reduce the singular values with max, min or sum.
            x = jnp.moveaxis(x, axis, (-2, -1))
            if ord == 2:
                reducer = jnp.amax
            elif ord == -2:
                reducer = jnp.amin
            else:
                reducer = jnp.sum
            y = reducer(svd(x, compute_uv=False), axis=-1)
            if keepdims:
                # Restore the reduced axes as size-one dims of the input shape.
                result_shape = list(x_shape)
                result_shape[axis[0]] = 1
                result_shape[axis[1]] = 1
                y = jnp.reshape(y, result_shape)
            return y
        else:
            raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
    else:
        raise ValueError(
            "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
示例#17
0
文件: linalg.py 项目: xueeinstein/jax
def eigh_tridiagonal(d,
                     e,
                     *,
                     eigvals_only=False,
                     select='a',
                     select_range=None,
                     tol=None):
    """Compute eigenvalues of a tridiagonal matrix by parallel bisection.

    The matrix is described by its diagonal `d` and off-diagonal `e`. Each
    requested eigenvalue is located by bisection on an interval derived from
    the Gershgorin circle theorem, using Sturm sequence counts.

    Args:
      d: diagonal entries, 1-D array-like of length `n`.
      e: off-diagonal entries, 1-D array-like with the same dtype as `d`.
      eigvals_only: must be True; eigenvectors are not implemented.
      select: 'a' for all eigenvalues or 'i' for an index range. 'v' is
        recognized but not implemented.
      select_range: for `select='i'`, a pair `(lo, hi)` of inclusive indices.
      tol: optional absolute tolerance; the effective tolerance is never
        smaller than `eps * t_norm`, where `t_norm` bounds the matrix norm.

    Returns:
      Array of the selected eigenvalue estimates. For `n <= 1` the real part
      of `d` is returned directly.

    Raises:
      NotImplementedError: if `eigvals_only` is False or `select == 'v'`.
      TypeError: if `d` and `e` differ in dtype, or the dtype is not one of
        float32, float64, complex64, complex128.
      ValueError: if `select` is invalid or `select_range` is empty.
    """
    if not eigvals_only:
        raise NotImplementedError(
            "Calculation of eigenvectors is not implemented")

    def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
        """Implements the Sturm sequence recurrence."""
        n = alpha.shape[0]
        zeros = jnp.zeros(x.shape, dtype=jnp.int32)
        ones = jnp.ones(x.shape, dtype=jnp.int32)

        # The first step in the Sturm sequence recurrence
        # requires special care if x is equal to alpha[0].
        def sturm_step0():
            q = alpha[0] - x
            count = jnp.where(q < 0, ones, zeros)
            q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
            return q, count

        # Subsequent steps all take this form:
        def sturm_step(i, q, count):
            q = alpha[i] - beta_sq[i - 1] / q - x
            count = jnp.where(q <= pivmin, count + 1, count)
            q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
            return q, count

        # The first step initializes q and count.
        q, count = sturm_step0()

        # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
        # the bulk of the iterations unrolled by a factor of blocksize.
        blocksize = 16
        i = 1
        peel = (n - 1) % blocksize
        unroll_cnt = peel

        # NOTE: `unroll_cnt` is read from the enclosing scope at call time, so
        # rebinding it below changes how many steps this helper unrolls.
        def unrolled_steps(args):
            start, q, count = args
            for j in range(unroll_cnt):
                q, count = sturm_step(start + j, q, count)
            return start + unroll_cnt, q, count

        i, q, count = unrolled_steps((i, q, count))

        # Run the remaining steps of the Sturm sequence using a partially
        # unrolled while loop.
        unroll_cnt = blocksize

        def cond(iqc):
            i, q, count = iqc
            return jnp.less(i, n)

        _, _, count = lax.while_loop(cond, unrolled_steps, (i, q, count))
        return count

    alpha = jnp.asarray(d)
    beta = jnp.asarray(e)
    supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64,
                        jnp.complex128)
    if alpha.dtype != beta.dtype:
        raise TypeError(
            "diagonal and off-diagonal values must have same dtype, "
            f"got {alpha.dtype} and {beta.dtype}")
    if alpha.dtype not in supported_dtypes or beta.dtype not in supported_dtypes:
        # The message lists every dtype accepted by the check above.
        raise TypeError(
            "Only float32, float64, complex64 and complex128 inputs are "
            "supported as inputs to jax.scipy.linalg.eigh_tridiagonal, got "
            f"{alpha.dtype} and {beta.dtype}")
    n = alpha.shape[0]
    if n <= 1:
        return jnp.real(alpha)

    # From here on only real quantities are used: the real diagonal and the
    # magnitude / squared magnitude of the off-diagonal.
    if jnp.issubdtype(alpha.dtype, jnp.complexfloating):
        alpha = jnp.real(alpha)
        beta_sq = jnp.real(beta * jnp.conj(beta))
        beta_abs = jnp.sqrt(beta_sq)
    else:
        beta_abs = jnp.abs(beta)
        beta_sq = jnp.square(beta)

    # Estimate the largest and smallest eigenvalues of T using the Gershgorin
    # circle theorem.
    off_diag_abs_row_sum = jnp.concatenate(
        [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
    lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
    lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)
    # Upper bound on 2-norm of T.
    t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))

    # Compute the smallest allowed pivot in the Sturm sequence to avoid
    # overflow.
    finfo = np.finfo(alpha.dtype)
    one = np.ones([], dtype=alpha.dtype)
    safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
    pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
    alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
    abs_tol = finfo.eps * t_norm
    if tol is not None:
        abs_tol = jnp.maximum(tol, abs_tol)

    # In the worst case, when the absolute tolerance is eps*lambda_est_max and
    # lambda_est_max = -lambda_est_min, we have to take as many bisection steps
    # as there are bits in the mantissa plus 1.
    # The proof is left as an exercise to the reader.
    max_it = finfo.nmant + 1

    # Determine the indices of the desired eigenvalues, based on select and
    # select_range.
    if select == 'a':
        target_counts = jnp.arange(n, dtype=jnp.int32)
    elif select == 'i':
        if select_range[0] > select_range[1]:
            raise ValueError('Got empty index range in select_range.')
        target_counts = jnp.arange(select_range[0],
                                   select_range[1] + 1,
                                   dtype=jnp.int32)
    elif select == 'v':
        # TODO(phawkins): requires dynamic shape support.
        raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
                                  "implemented")
    else:
        raise ValueError("'select must have a value in {'a', 'i', 'v'}.")

    # Run binary search for all desired eigenvalues in parallel, starting from
    # the interval lightly wider than the estimated
    # [lambda_est_min, lambda_est_max].
    fudge = 2.1  # We widen starting interval the Gershgorin interval a bit.
    norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
    lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
    upper = lambda_est_max + norm_slack + fudge * pivmin

    # Pre-broadcast the scalars used in the Sturm sequence for improved
    # performance.
    target_shape = jnp.shape(target_counts)
    lower = jnp.broadcast_to(lower, shape=target_shape)
    upper = jnp.broadcast_to(upper, shape=target_shape)
    mid = 0.5 * (upper + lower)
    pivmin = jnp.broadcast_to(pivmin, target_shape)
    alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

    # Start parallel binary searches. Iteration stops after `max_it` steps or
    # once the widest remaining interval is no larger than `abs_tol`.
    def cond(args):
        i, lower, _, upper = args
        return jnp.logical_and(jnp.less(i, max_it),
                               jnp.less(abs_tol, jnp.amax(upper - lower)))

    def body(args):
        i, lower, mid, upper = args
        counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
        # Shrink each interval towards the side that still brackets the
        # target, based on the Sturm count at the midpoint.
        lower = jnp.where(counts <= target_counts, mid, lower)
        upper = jnp.where(counts > target_counts, mid, upper)
        mid = 0.5 * (lower + upper)
        return i + 1, lower, mid, upper

    _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
    return mid