def _slogdet_jvp(primals, tangents):
  x, = primals
  g, = tangents
  sign, ans = slogdet(x)
  ans_dot = jnp.trace(solve(x, g), axis1=-1, axis2=-2)
  if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating):
    sign_dot = (ans_dot - jnp.real(ans_dot).astype(ans_dot.dtype)) * sign
    ans_dot = jnp.real(ans_dot)
  else:
    sign_dot = jnp.zeros_like(sign)
  return (sign, ans), (sign_dot, ans_dot)
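# A minimal numerical check of the identity this rule implements,
# d(log|det(A)|) = tr(A^{-1} dA), using only public JAX APIs:
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0], [0.5, 3.0]])
dA = jnp.array([[0.1, 0.0], [0.0, -0.2]])
(_, logabsdet), (_, d_logabsdet) = jax.jvp(jnp.linalg.slogdet, (A,), (dA,))
expected = jnp.trace(jnp.linalg.solve(A, dA))  # agrees with d_logabsdet up to float error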
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here:
    # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  k = s.shape[-1]
  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

  if not compute_uv:
    return (s,), (ds,)

  F = 1 / (jnp.square(s_dim) - jnp.square(_T(s_dim)) + jnp.eye(k, dtype=A.dtype))
  F = F - jnp.eye(k, dtype=A.dtype)
  dSS = s_dim * dS
  SdS = _T(s_dim) * dS
  dU = jnp.matmul(U, F * (dSS + _T(dSS)))
  dV = jnp.matmul(V, F * (SdS + _T(SdS)))

  m, n = A.shape[-2:]
  if m > n:
    dU = dU + jnp.matmul(jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut),
                         jnp.matmul(dA, V)) / s_dim
  if n > m:
    dV = dV + jnp.matmul(jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt),
                         jnp.matmul(_H(dA), U)) / s_dim
  return (s, U, Vt), (ds, dU, _T(dV))
def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
        detrend='constant', return_onesided=True, scaling='density',
        axis=-1, average='mean'):
  freqs, _, Pxy = _spectral_helper(x, y, fs, window, nperseg, noverlap, nfft,
                                   detrend, return_onesided, scaling, axis,
                                   mode='psd')
  if y is not None:
    Pxy = Pxy + 0j  # Ensure complex output when x is not y

  # Average over windows.
  if Pxy.ndim >= 2 and Pxy.size > 0:
    if Pxy.shape[-1] > 1:
      if average == 'median':
        bias = signal_helper._median_bias(Pxy.shape[-1]).astype(Pxy.dtype)
        if jnp.iscomplexobj(Pxy):
          Pxy = (jnp.median(jnp.real(Pxy), axis=-1)
                 + 1j * jnp.median(jnp.imag(Pxy), axis=-1))
        else:
          Pxy = jnp.median(Pxy, axis=-1)
        Pxy /= bias
      elif average == 'mean':
        Pxy = Pxy.mean(axis=-1)
      else:
        raise ValueError(f'average must be "median" or "mean", got {average}')
    else:
      Pxy = jnp.reshape(Pxy, Pxy.shape[:-1])

  return freqs, Pxy
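# Usage sketch through the public wrapper (jax.scipy.signal.csd mirrors
# scipy.signal.csd):
import jax.numpy as jnp
from jax.scipy import signal as jsp_signal

fs = 1000.0
t = jnp.arange(0, 1.0, 1.0 / fs)
x = jnp.sin(2 * jnp.pi * 50.0 * t)
y = jnp.sin(2 * jnp.pi * 50.0 * t + 0.5)
freqs, Pxy = jsp_signal.csd(x, y, fs=fs, nperseg=256)  # complex cross-PSD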
def body(k, state):
  pivot, perm, a = state
  m_idx = jnp.arange(m)
  n_idx = jnp.arange(n)

  if jnp.issubdtype(a.dtype, jnp.complexfloating):
    t = a[:, k]
    magnitude = jnp.abs(jnp.real(t)) + jnp.abs(jnp.imag(t))
  else:
    magnitude = jnp.abs(a[:, k])
  i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf))

  pivot = ops.index_update(pivot, ops.index[k], i)
  a = ops.index_update(a, ops.index[[k, i],], a[[i, k],])
  perm = ops.index_update(perm, ops.index[[i, k],], perm[[k, i],])

  # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
  x = a[k, k]
  a = ops.index_update(a, ops.index[:, k],
                       jnp.where(m_idx > k, a[:, k] / x, a[:, k]))

  # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
  a = a - jnp.where((m_idx[:, None] > k) & (n_idx > k),
                    jnp.outer(a[:, k], a[k, :]),
                    jnp.array(0, dtype=a.dtype))
  return pivot, perm, a
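# For reference, the same partial-pivoting step written in plain NumPy with
# dynamic slices (a sketch; the loop body above instead masks with
# loop-invariant shapes so it can run under jit/fori_loop):
import numpy as np

def lu_step_reference(a, k):
  i = k + np.argmax(np.abs(a[k:, k]))                 # pick pivot row
  a[[k, i], :] = a[[i, k], :]                         # swap rows k and i
  a[k + 1:, k] /= a[k, k]                             # column of multipliers
  a[k + 1:, k + 1:] -= np.outer(a[k + 1:, k], a[k, k + 1:])  # Schur update
  return a, i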
def recursive_case(B, offset, b, agenda, blocks, eigenvectors):
  # The recursive case of the algorithm, specialized to a static block size
  # of B.
  H = _slice(blocks, (offset, 0), (b, b), (B, B))
  V = _slice(eigenvectors, (0, offset), (n, b), (N, B))

  split_point = jnp.nanmedian(_mask(jnp.diag(jnp.real(H)), (b,), jnp.nan))  # TODO: Improve this?
  H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(
      H, b, split_point, V0=V)

  blocks = _update_slice(blocks, H_minus, (offset, 0), (rank, rank))
  blocks = _update_slice(blocks, H_plus, (offset + rank, 0),
                         (b - rank, b - rank))
  eigenvectors = _update_slice(eigenvectors, V_minus, (0, offset), (n, rank))
  eigenvectors = _update_slice(eigenvectors, V_plus, (0, offset + rank),
                               (n, b - rank))

  agenda = agenda.push(_Subproblem(offset + rank, (b - rank)))
  agenda = agenda.push(_Subproblem(offset, rank))
  return agenda, blocks, eigenvectors
def eigh_jvp_rule(primals, tangents, lower):
  # Derivative for eigh in the simplest case of distinct eigenvalues.
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  # The general solution treating the case of degenerate eigenvalues is
  # considerably more complicated. Ambitious readers may refer to the general
  # methods below or refer to degenerate perturbation theory in physics.
  # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
  a, = primals
  a_dot, = tangents

  v, w_real = eigh_p.bind(symmetrize(a), lower=lower)

  # for complex numbers we need eigenvalues to be full dtype of v, a:
  w = w_real.astype(a.dtype)
  eye_n = jnp.eye(a.shape[-1], dtype=a.dtype)
  # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
  Fmat = jnp.reciprocal(eye_n + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]) - eye_n
  # eigh impl doesn't support batch dims, but future-proof the grad.
  dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, jnp.multiply(Fmat, vdag_adot_v))
  dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
  return (v, w_real), (dv, dw)
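# Sanity check of the eigenvalue tangent through the public API: for a
# Hermitian A with distinct eigenvalues, dw_i = real((V^H dA V)_{ii}).
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
dA = jnp.array([[0.0, 0.1], [0.1, 0.2]])  # keep the perturbation symmetric
(w, v), (dw, dv) = jax.jvp(jnp.linalg.eigh, (A,), (dA,))
expected_dw = jnp.real(jnp.diagonal(v.conj().T @ dA @ v))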
def _lu(a, permute_l):
  a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
  lu, pivots, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  m, n = jnp.shape(a)
  p = jnp.real(jnp.array(permutation == jnp.arange(m)[:, None], dtype=dtype))
  k = min(m, n)
  l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
  u = jnp.triu(lu)[:k, :]
  if permute_l:
    return jnp.matmul(p, l), u
  else:
    return p, l, u
def _lu(a, permute_l):
  a, = _promote_dtypes_inexact(jnp.asarray(a))
  lu, _, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  m, n = jnp.shape(a)
  p = jnp.real(jnp.array(
      permutation[None, :] == jnp.arange(m, dtype=permutation.dtype)[:, None],
      dtype=dtype))
  k = min(m, n)
  l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
  u = jnp.triu(lu)[:k, :]
  if permute_l:
    return jnp.matmul(p, l), u
  else:
    return p, l, u
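# Usage sketch through the public wrapper, which mirrors scipy.linalg.lu:
import jax.numpy as jnp
from jax.scipy import linalg as jsp_linalg

A = jnp.array([[0.0, 2.0], [1.0, 1.0]])
p, l, u = jsp_linalg.lu(A)                  # A == p @ l @ u up to float error
pl, u2 = jsp_linalg.lu(A, permute_l=True)   # A == pl @ u2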
def _slogdet_lu(a):
  dtype = lax.dtype(a)
  lu, pivot, _ = lax_linalg.lu(a)
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
  iota = lax.expand_dims(jnp.arange(a.shape[-1]), range(pivot.ndim - 1))
  parity = jnp.count_nonzero(pivot != iota, axis=-1)
  if jnp.iscomplexobj(a):
    sign = jnp.prod(diag / jnp.abs(diag), axis=-1)
  else:
    sign = jnp.array(1, dtype=dtype)
    parity = parity + jnp.count_nonzero(diag < 0, axis=-1)
  sign = jnp.where(is_zero,
                   jnp.array(0, dtype=dtype),
                   sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
  logdet = jnp.where(
      is_zero, jnp.array(-jnp.inf, dtype=dtype),
      jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
  return sign, jnp.real(logdet)
def qr_jvp_rule(primals, tangents, full_matrices):
  # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
  x, = primals
  dx, = tangents
  q, r = qr_p.bind(x, full_matrices=False)
  *_, m, n = x.shape
  if full_matrices or m < n:
    raise NotImplementedError(
      "Unimplemented case of QR decomposition derivative")
  dx_rinv = triangular_solve(r, dx)  # Right side solve by default
  qt_dx_rinv = jnp.matmul(_H(q), dx_rinv)
  qt_dx_rinv_lower = jnp.tril(qt_dx_rinv, -1)
  do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower)  # This is skew-symmetric
  # The following correction is necessary for complex inputs
  do = do + jnp.eye(n, dtype=do.dtype) * (qt_dx_rinv - jnp.real(qt_dx_rinv))
  dq = jnp.matmul(q, do - qt_dx_rinv) + dx_rinv
  dr = jnp.matmul(qt_dx_rinv - do, r)
  return (q, r), (dq, dr)
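# Product-rule check via the public API: if A = Q R, then dA ≈ dQ R + Q dR.
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
dA = 0.01 * jnp.arange(6.0).reshape(3, 2)
(q, r), (dq, dr) = jax.jvp(jnp.linalg.qr, (A,), (dA,))
residual = dA - (dq @ r + q @ dr)  # ~0 up to float error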
def eigh(H, *, precision="float32", termination_size=256, n=None,
         sort_eigenvalues=True):
  """ Computes the eigendecomposition of the symmetric/Hermitian matrix H.

  Args:
    H: The `n x n` Hermitian input, padded to `N x N`.
    precision: :class:`~jax.lax.Precision` object specifying the matmul
      precision.
    termination_size: Recursion ends once the blocks reach this linear size.
    n: the true (dynamic) size of the matrix.
    sort_eigenvalues: If `True`, the eigenvalues will be sorted from lowest to
      highest.
  Returns:
    vals: The `n` eigenvalues of `H`.
    vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
      of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
      to numerical error.
  """
  M, N = H.shape
  if M != N:
    raise TypeError(f"Input H of shape {H.shape} must be square.")

  if N <= termination_size:
    if n is not None:
      H = _mask(H, (n, n), jnp.eye(N, dtype=H.dtype))
    return lax_linalg.eigh_jacobi(H, sort_eigenvalues=sort_eigenvalues)

  # TODO(phawkins): consider rounding N up to a larger size to maximize reuse
  # between matrices.
  n = N if n is None else n
  with jax.default_matmul_precision(precision):
    eig_vals, eig_vecs = _eigh_work(H, n, termination_size=termination_size)
  eig_vals = _mask(jnp.real(eig_vals), (n,), jnp.nan)
  if sort_eigenvalues:
    sort_idxs = jnp.argsort(eig_vals)
    eig_vals = eig_vals[sort_idxs]
    eig_vecs = eig_vecs[:, sort_idxs]
  return eig_vals, eig_vecs
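# The contract stated in the docstring, checked numerically through the
# public jnp.linalg.eigh (that this divide-and-conquer routine is what the
# public call lowers to depends on the backend; treating them as equivalent
# here is an assumption):
import jax.numpy as jnp

H = jnp.array([[2.0, 1.0], [1.0, 2.0]])
vals, vecs = jnp.linalg.eigh(H)
residual = H @ vecs - vecs * vals[None, :]  # ~0 up to numerical error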
def _sph_harm(m: jnp.ndarray,
              n: jnp.ndarray,
              theta: jnp.ndarray,
              phi: jnp.ndarray,
              n_max: int) -> jnp.ndarray:
  """Computes the spherical harmonics."""
  cos_colatitude = jnp.cos(phi)
  legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
  legendre_val = legendre[abs(m), n, jnp.arange(len(n))]

  angle = abs(m) * theta
  vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
  harmonics = lax.complex(legendre_val * jnp.real(vandermonde),
                          legendre_val * jnp.imag(vandermonde))

  # Negative order.
  harmonics = jnp.where(m < 0,
                        (-1.0)**abs(m) * jnp.conjugate(harmonics),
                        harmonics)
  return harmonics
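# Usage sketch through the public wrapper (jax.scipy.special.sph_harm; note
# the scipy convention: theta is the azimuthal angle, phi the colatitude):
import jax.numpy as jnp
from jax.scipy.special import sph_harm

m = jnp.array([0])
n = jnp.array([1])
theta = jnp.array([0.3])
phi = jnp.array([1.2])
vals = sph_harm(m, n, theta, phi, n_max=1)  # complex array of shape (1,)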
def slogdet(a):
  a = _promote_arg_dtypes(jnp.asarray(a))
  dtype = lax.dtype(a)
  a_shape = jnp.shape(a)
  if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
    msg = "Argument to slogdet() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(a_shape))
  lu, pivot, _ = lax_linalg.lu(a)
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
  parity = jnp.count_nonzero(pivot != jnp.arange(a_shape[-1]), axis=-1)
  if jnp.iscomplexobj(a):
    sign = jnp.prod(diag / jnp.abs(diag), axis=-1)
  else:
    sign = jnp.array(1, dtype=dtype)
    parity = parity + jnp.count_nonzero(diag < 0, axis=-1)
  sign = jnp.where(is_zero,
                   jnp.array(0, dtype=dtype),
                   sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
  logdet = jnp.where(
      is_zero, jnp.array(-jnp.inf, dtype=dtype),
      jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
  return sign, jnp.real(logdet)
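# Quick check of the (sign, logabsdet) contract via the public API:
# sign * exp(logabsdet) reconstructs det(a).
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # det = -2
sign, logabsdet = jnp.linalg.slogdet(a)
det = sign * jnp.exp(logabsdet)          # ~ -2.0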
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here:
    # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

  if not compute_uv:
    return (s,), (ds,)

  s_diffs = jnp.square(s_dim) - jnp.square(_T(s_dim))
  s_diffs_zeros = jnp.eye(s.shape[-1], dtype=A.dtype)  # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.)
  # is 1. where s_diffs is 0. and is 0. everywhere else
  F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
  dSS = s_dim * dS  # dS.dot(jnp.diag(s))
  SdS = _T(s_dim) * dS  # jnp.diag(s).dot(dS)

  s_zeros = jnp.ones((), dtype=A.dtype) * (s == 0.)
  s_inv = 1 / (s + s_zeros) - s_zeros
  s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv)
  dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat
  dU = jnp.matmul(U, F * (dSS + _H(dSS)) + dUdV_diag)
  dV = jnp.matmul(V, F * (SdS + _H(SdS)))

  m, n = A.shape[-2:]
  if m > n:
    dU = dU + jnp.matmul(jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut),
                         jnp.matmul(dA, V)) / s_dim
  if n > m:
    dV = dV + jnp.matmul(jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt),
                         jnp.matmul(_H(dA), U)) / s_dim

  return (s, U, Vt), (ds, dU, _H(dV))
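# Numerical check of the singular-value tangent through the public API:
# ds = real(diag(U^H @ dA @ V)).
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
dA = 0.01 * jnp.ones_like(A)
s, ds = jax.jvp(lambda a: jnp.linalg.svd(a, compute_uv=False), (A,), (dA,))
U, _, Vt = jnp.linalg.svd(A, full_matrices=False)
expected_ds = jnp.real(jnp.diagonal(U.conj().T @ dA @ Vt.conj().T))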
def split_spectrum(H, n, split_point, V0=None):
  """ The Hermitian matrix `H` is split into two matrices `H_minus` and
  `H_plus`, respectively sharing its eigenspaces beneath and above its
  `split_point`th eigenvalue.

  Returns, in addition, `V_minus` and `V_plus`, isometries such that
  `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are returned
  instead; this allows the overall isometries mapping from an initial input
  matrix to progressively smaller blocks to be formed.

  Args:
    H: The Hermitian matrix to split.
    n: The true (dynamic) size of `H`.
    split_point: The eigenvalue to split along.
    V0: Matrix of isometries to be updated.
  Returns:
    H_minus: A Hermitian matrix sharing the eigenvalues of `H` beneath
      `split_point`.
    V_minus: An isometry from the input space of `V0` to `H_minus`.
    H_plus: A Hermitian matrix sharing the eigenvalues of `H` above
      `split_point`.
    V_plus: An isometry from the input space of `V0` to `H_plus`.
    rank: The dynamic size of the `H_minus` block.
  """
  N, _ = H.shape
  H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(H.dtype)
  U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n))
  P = -0.5 * (U - _mask(jnp.eye(N, dtype=H.dtype), (n, n)))
  rank = jnp.round(jnp.trace(jnp.real(P))).astype(jnp.int32)

  V_minus, V_plus = _projector_subspace(P, H, n, rank)
  H_minus = (V_minus.conj().T @ H) @ V_minus
  H_plus = (V_plus.conj().T @ H) @ V_plus
  if V0 is not None:
    V_minus = jnp.dot(V0, V_minus)
    V_plus = jnp.dot(V0, V_plus)
  return H_minus, V_minus, H_plus, V_plus, rank
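# Illustrative NumPy sketch of the projector/rank idea used above (names
# here are hypothetical): the spectral projector onto the eigenvalues below
# sigma has rank equal to the number of such eigenvalues, recoverable from
# its trace.
import numpy as np

H = np.diag([1.0, 2.0, 5.0, 7.0])
sigma = 3.0
w, v = np.linalg.eigh(H)
P = (v * (w < sigma)) @ v.conj().T       # projector onto {eigvals < sigma}
rank = int(round(np.trace(P).real))      # == 2 here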
def norm(x, ord=None, axis: Union[None, Tuple[int, ...], int] = None,
         keepdims=False):
  x = _promote_arg_dtypes(jnp.asarray(x))
  x_shape = jnp.shape(x)
  ndim = len(x_shape)

  if axis is None:
    # NumPy has an undocumented behavior that admits arbitrary rank inputs if
    # `ord` is None: https://github.com/numpy/numpy/issues/14215
    if ord is None:
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
    axis = tuple(range(ndim))
  elif isinstance(axis, tuple):
    axis = tuple(canonicalize_axis(x, ndim) for x in axis)
  else:
    axis = (canonicalize_axis(axis, ndim),)

  num_axes = len(axis)
  if num_axes == 1:
    if ord is None or ord == 2:
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
                              keepdims=keepdims))
    elif ord == jnp.inf:
      return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == 0:
      return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
                     axis=axis, keepdims=keepdims)
    elif ord == 1:
      # Numpy has a special case for ord == 1 as an optimization. We don't
      # really need the optimization (XLA could do it for us), but the Numpy
      # code has slightly different type promotion semantics, so we need a
      # special case too.
      return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
    else:
      abs_x = jnp.abs(x)
      ord = lax._const(abs_x, ord)
      out = jnp.sum(abs_x ** ord, axis=axis, keepdims=keepdims)
      return jnp.power(out, 1. / ord)

  elif num_axes == 2:
    row_axis, col_axis = cast(Tuple[int, ...], axis)
    if ord is None or ord in ('f', 'fro'):
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
                              keepdims=keepdims))
    elif ord == 1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == -1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord in ('nuc', 2, -2):
      x = jnp.moveaxis(x, axis, (-2, -1))
      if ord == 2:
        reducer = jnp.amax
      elif ord == -2:
        reducer = jnp.amin
      else:
        reducer = jnp.sum
      y = reducer(svd(x, compute_uv=False), axis=-1)
      if keepdims:
        result_shape = list(x_shape)
        result_shape[axis[0]] = 1
        result_shape[axis[1]] = 1
        y = jnp.reshape(y, result_shape)
      return y
    else:
      raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
  else:
    raise ValueError(
        "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
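# Usage sketch of the main matrix-norm branches via the public API:
import jax.numpy as jnp

x = jnp.array([[1.0, -2.0], [3.0, 4.0]])
fro = jnp.linalg.norm(x)             # Frobenius norm
spec = jnp.linalg.norm(x, ord=2)     # largest singular value
col = jnp.linalg.norm(x, ord=1)      # max absolute column sum (== 6 here)
nuc = jnp.linalg.norm(x, ord='nuc')  # nuclear norm (sum of singular values)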
def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
                     select_range=None, tol=None):
  if not eigvals_only:
    raise NotImplementedError(
        "Calculation of eigenvectors is not implemented")

  def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
    """Implements the Sturm sequence recurrence."""
    n = alpha.shape[0]
    zeros = jnp.zeros(x.shape, dtype=jnp.int32)
    ones = jnp.ones(x.shape, dtype=jnp.int32)

    # The first step in the Sturm sequence recurrence
    # requires special care if x is equal to alpha[0].
    def sturm_step0():
      q = alpha[0] - x
      count = jnp.where(q < 0, ones, zeros)
      q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
      return q, count

    # Subsequent steps all take this form:
    def sturm_step(i, q, count):
      q = alpha[i] - beta_sq[i - 1] / q - x
      count = jnp.where(q <= pivmin, count + 1, count)
      q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
      return q, count

    # The first step initializes q and count.
    q, count = sturm_step0()

    # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
    # the bulk of the iterations unrolled by a factor of blocksize.
    blocksize = 16
    i = 1
    peel = (n - 1) % blocksize
    unroll_cnt = peel

    def unrolled_steps(args):
      start, q, count = args
      for j in range(unroll_cnt):
        q, count = sturm_step(start + j, q, count)
      return start + unroll_cnt, q, count

    i, q, count = unrolled_steps((i, q, count))

    # Run the remaining steps of the Sturm sequence using a partially
    # unrolled while loop.
    unroll_cnt = blocksize

    def cond(iqc):
      i, q, count = iqc
      return jnp.less(i, n)

    _, _, count = lax.while_loop(cond, unrolled_steps, (i, q, count))
    return count

  alpha = jnp.asarray(d)
  beta = jnp.asarray(e)
  supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128)
  if alpha.dtype != beta.dtype:
    raise TypeError("diagonal and off-diagonal values must have same dtype, "
                    f"got {alpha.dtype} and {beta.dtype}")
  if alpha.dtype not in supported_dtypes or beta.dtype not in supported_dtypes:
    raise TypeError("Only float32, float64, complex64, and complex128 inputs "
                    "are supported by jax.scipy.linalg.eigh_tridiagonal, got "
                    f"{alpha.dtype} and {beta.dtype}")

  n = alpha.shape[0]
  if n <= 1:
    return jnp.real(alpha)

  if jnp.issubdtype(alpha.dtype, jnp.complexfloating):
    alpha = jnp.real(alpha)
    beta_sq = jnp.real(beta * jnp.conj(beta))
    beta_abs = jnp.sqrt(beta_sq)
  else:
    beta_abs = jnp.abs(beta)
    beta_sq = jnp.square(beta)

  # Estimate the largest and smallest eigenvalues of T using the Gershgorin
  # circle theorem.
  off_diag_abs_row_sum = jnp.concatenate(
      [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
  lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
  lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)
  # Upper bound on 2-norm of T.
  t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))

  # Compute the smallest allowed pivot in the Sturm sequence to avoid
  # overflow.
  finfo = np.finfo(alpha.dtype)
  one = np.ones([], dtype=alpha.dtype)
  safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
  pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
  alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
  abs_tol = finfo.eps * t_norm
  if tol is not None:
    abs_tol = jnp.maximum(tol, abs_tol)

  # In the worst case, when the absolute tolerance is eps*lambda_est_max and
  # lambda_est_max = -lambda_est_min, we have to take as many bisection steps
  # as there are bits in the mantissa plus 1.
  # The proof is left as an exercise to the reader.
  max_it = finfo.nmant + 1

  # Determine the indices of the desired eigenvalues, based on select and
  # select_range.
  if select == 'a':
    target_counts = jnp.arange(n, dtype=jnp.int32)
  elif select == 'i':
    if select_range[0] > select_range[1]:
      raise ValueError('Got empty index range in select_range.')
    target_counts = jnp.arange(select_range[0], select_range[1] + 1,
                               dtype=jnp.int32)
  elif select == 'v':
    # TODO(phawkins): requires dynamic shape support.
    raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
                              "implemented")
  else:
    raise ValueError("select must have a value in {'a', 'i', 'v'}.")

  # Run binary search for all desired eigenvalues in parallel, starting from
  # an interval slightly wider than the estimated
  # [lambda_est_min, lambda_est_max].
  fudge = 2.1  # We widen the Gershgorin interval a bit.
  norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
  lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
  upper = lambda_est_max + norm_slack + fudge * pivmin

  # Pre-broadcast the scalars used in the Sturm sequence for improved
  # performance.
  target_shape = jnp.shape(target_counts)
  lower = jnp.broadcast_to(lower, shape=target_shape)
  upper = jnp.broadcast_to(upper, shape=target_shape)
  mid = 0.5 * (upper + lower)
  pivmin = jnp.broadcast_to(pivmin, target_shape)
  alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

  # Start parallel binary searches.
  def cond(args):
    i, lower, _, upper = args
    return jnp.logical_and(jnp.less(i, max_it),
                           jnp.less(abs_tol, jnp.amax(upper - lower)))

  def body(args):
    i, lower, mid, upper = args
    counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
    lower = jnp.where(counts <= target_counts, mid, lower)
    upper = jnp.where(counts > target_counts, mid, upper)
    mid = 0.5 * (lower + upper)
    return i + 1, lower, mid, upper

  _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
  return mid
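# Usage sketch via the public API (eigenvalues only, as implemented above).
# For this tridiagonal matrix the exact eigenvalues are 2 - sqrt(2), 2,
# and 2 + sqrt(2).
import jax.numpy as jnp
from jax.scipy.linalg import eigh_tridiagonal

d = jnp.array([2.0, 2.0, 2.0])   # main diagonal
e = jnp.array([-1.0, -1.0])      # off-diagonal
w = eigh_tridiagonal(d, e, eigvals_only=True)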