def body_f_after_matmul(X):
    # Re-orthonormalize the iterated columns via a full QR factorization.
    Q, _ = jnp_linalg.qr(X, mode="complete")
    # The leading `rank` columns span the candidate subspace (V1); the
    # trailing columns span its orthogonal complement (V2).
    V1 = _mask(Q, (n, rank))
    V2 = _slice(Q, (0, rank), (n, n - rank), (N, N))
    # Convergence measure: the coupling V2^H @ H @ V1 vanishes once V1 is an
    # invariant subspace of H. TODO: lower precision may suffice here.
    coupling = jnp.dot(jnp.dot(V2.conj().T, H), V1)
    error = jnp_linalg.norm(coupling) / H_norm
    return V1, V2, error
def base_case(B, offset, b, agenda, blocks, eigenvectors):
    # Recursion cutoff: blocks below the minimum size are handed to the TPU
    # Jacobi eigendecomposition. Jacobi is efficient on small matrices but
    # scales poorly, so it complements the divide-and-conquer recursion.
    h = _slice(blocks, (offset, 0), (b, b), (B, B))
    v = _slice(eigenvectors, (0, offset), (n, b), (N, B))

    # Replace the masked-out region with the identity. Because the TPU Jacobi
    # eigh does not reorder eigenvalues, the decomposition of the original
    # matrix occupies the top-left corner of the padded result.
    # It is essential that the backend eigh does NOT sort eigenvalues; JAX's
    # CPU and GPU eigendecompositions do sort, so on those platforms this is
    # only correct when termination_size == 1.
    h = _mask(h, (b, b), jnp.eye(B, dtype=h.dtype))
    vecs, vals = lax.linalg.eigh(h, sort_eigenvalues=False)
    vecs = _mask(vecs, (b, b))
    vals = _mask(vals, (b,))
    vecs = jnp.dot(v, vecs)

    blocks = _update_slice(blocks, vals[:, None], (offset, 0), (b, b))
    eigenvectors = _update_slice(eigenvectors, vecs, (0, offset), (n, b))
    return agenda, blocks, eigenvectors
def multi_dot(arrays, *, precision=None):
    """Multiply a chain of arrays in the cheapest association order."""
    num = len(arrays)
    # Ordering optimization only pays off for more than two operands.
    if num < 2:
        raise ValueError("Expecting at least two arrays.")
    if num == 2:
        return jnp.dot(arrays[0], arrays[1], precision=precision)

    arrays = [jnp.asarray(a) for a in arrays]

    # Remember the original ranks so the result can be squeezed back later.
    ndim_first, ndim_last = arrays[0].ndim, arrays[-1].ndim
    # Promote boundary vectors to explicit row/column matrices so the
    # internal _multi_dot_* helpers only ever deal with 2D operands.
    if arrays[0].ndim == 1:
        arrays[0] = jnp.atleast_2d(arrays[0])
    if arrays[-1].ndim == 1:
        arrays[-1] = jnp.atleast_2d(arrays[-1]).T
    _assert2d(*arrays)

    if num == 3:
        # Closed-form comparison for three operands is much faster than the
        # general matrix-chain-order dynamic program.
        result = _multi_dot_three(*arrays, precision)
    else:
        order = _multi_dot_matrix_chain_order(arrays)
        result = _multi_dot(arrays, order, 0, num - 1, precision)

    # Restore the caller's dimensionality.
    if ndim_first == 1 and ndim_last == 1:
        return result[0, 0]  # scalar
    if ndim_first == 1 or ndim_last == 1:
        return result.ravel()  # 1-D
    return result
def _multi_dot(arrays, order, i, j, precision): """Actually do the multiplication with the given order.""" if i == j: return arrays[i] else: return jnp.dot(_multi_dot(arrays, order, i, order[i, j], precision), _multi_dot(arrays, order, order[i, j] + 1, j, precision), precision=precision)
def _multi_dot_three(A, B, C, precision): """ Find the best order for three arrays and do the multiplication. For three arguments `_multi_dot_three` is approximately 15 times faster than `_multi_dot_matrix_chain_order` """ a0, a1b0 = A.shape b1c0, c1 = C.shape # cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1 cost1 = a0 * b1c0 * (a1b0 + c1) # cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1 cost2 = a1b0 * c1 * (a0 + b1c0) if cost1 < cost2: return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision) else: return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision)
def split_spectrum(H, n, split_point, V0=None):
    """Splits the Hermitian matrix `H` along its `split_point`th eigenvalue.

    `H` is split into two matrices `H_minus` and `H_plus`, respectively
    sharing its eigenspaces beneath and above its `split_point`th eigenvalue.

    Returns, in addition, `V_minus` and `V_plus`, isometries such that
    `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are
    returned instead; this allows the overall isometries mapping from an
    initial input matrix to progressively smaller blocks to be formed.

    Args:
      H: The Hermitian matrix to split.
      n: The dynamic (unpadded) size of `H`; the remainder of the padded
        `N x N` buffer is masked out. (Passed to `qdwh` as the dynamic shape
        and used by `_mask`.)
      split_point: The eigenvalue to split along.
      V0: Matrix of isometries to be updated.

    Returns:
      H_minus: A Hermitian matrix sharing the eigenvalues of `H` beneath
        `split_point`.
      V_minus: An isometry from the input space of `V0` to `H_minus`.
      H_plus: A Hermitian matrix sharing the eigenvalues of `H` above
        `split_point`.
      V_plus: An isometry from the input space of `V0` to `H_plus`.
      rank: The dynamic size of the m subblock.
    """
    N, _ = H.shape
    # Shift the spectrum so the split point sits at zero; eigenvalues below
    # the split become negative, those above become positive.
    H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(
        H.dtype)
    # U is the unitary polar factor of the shifted matrix (sign function).
    U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n))
    # P = (I - U) / 2 acts as a projector onto the eigenspace below the
    # split point (identity restricted to the n x n active block).
    P = -0.5 * (U - _mask(jnp.eye(N, dtype=H.dtype), (n, n)))
    # The trace of a projector equals the dimension of its range.
    rank = jnp.round(jnp.trace(jnp.real(P))).astype(jnp.int32)
    V_minus, V_plus = _projector_subspace(P, H, n, rank)
    # Compress H onto each subspace: Hi = Vi^H @ H @ Vi.
    H_minus = (V_minus.conj().T @ H) @ V_minus
    H_plus = (V_plus.conj().T @ H) @ V_plus
    if V0 is not None:
        # Accumulate the isometry back to the original input space.
        V_minus = jnp.dot(V0, V_minus)
        V_plus = jnp.dot(V0, V_plus)
    return H_minus, V_minus, H_plus, V_plus, rank
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
    """Least-squares polynomial fit of degree `deg` to points (x, y).

    Mirrors `numpy.polyfit`: solves the Vandermonde system
    `vander(x, deg + 1) @ c ~= y` in the least-squares sense.

    Args:
      x: 1D array of sample x-coordinates (must be non-empty).
      y: 1D or 2D array of sample y-coordinates; if 2D, each column is fit
        independently.
      deg: Degree of the fitting polynomial (non-negative int; must be
        concrete, not traced).
      rcond: Cutoff for small singular values in the lstsq solve; defaults
        to `len(x) * eps` for x's dtype. Must be concrete.
      full: If True, also return lstsq diagnostics (residuals, rank,
        singular values, rcond).
      w: Optional 1D array of per-sample weights, same length as y.
      cov: If truthy, return the covariance of the coefficients instead of
        the diagnostics; the string "unscaled" skips the chi-square scaling
        factor.

    Returns:
      c (coefficients, highest power first), plus extras depending on
      `full`/`cov` as described above.

    Raises:
      ValueError: if deg < 0, or if covariance scaling is requested with
        too few data points.
      TypeError: on malformed x/y/w shapes.
    """
    _check_arraylike("polyfit", x, y)
    deg = core.concrete_or_error(int, deg, "deg must be int")
    order = deg + 1
    # check arguments
    if deg < 0:
        raise ValueError("expected deg >= 0")
    if x.ndim != 1:
        raise TypeError("expected 1D vector for x")
    if x.size == 0:
        raise TypeError("expected non-empty vector for x")
    if y.ndim < 1 or y.ndim > 2:
        raise TypeError("expected 1D or 2D array for y")
    if x.shape[0] != y.shape[0]:
        raise TypeError("expected x and y to have same length")

    # set rcond: default cutoff scales with problem size and dtype epsilon.
    if rcond is None:
        rcond = len(x) * finfo(x.dtype).eps
    rcond = core.concrete_or_error(float, rcond, "rcond must be float")

    # set up least squares equation for powers of x
    lhs = vander(x, order)
    rhs = y

    # apply weighting: scale each row of the system by its weight.
    if w is not None:
        _check_arraylike("polyfit", w)
        w, = _promote_dtypes_inexact(w)
        if w.ndim != 1:
            raise TypeError("expected a 1-d array for weights")
        if w.shape[0] != y.shape[0]:
            raise TypeError("expected w and y to have the same length")
        lhs *= w[:, np.newaxis]
        if rhs.ndim == 2:
            rhs *= w[:, np.newaxis]
        else:
            rhs *= w

    # scale lhs to improve condition number and solve
    scale = sqrt((lhs * lhs).sum(axis=0))
    lhs /= scale[np.newaxis, :]
    c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
    c = (c.T / scale).T  # broadcast scale coefficients

    if full:
        return c, resids, rank, s, rcond
    elif cov:
        # Covariance of the coefficients: inv(X^T X), undoing the column
        # scaling applied above.
        Vbase = linalg.inv(dot(lhs.T, lhs))
        Vbase /= outer(scale, scale)
        if cov == "unscaled":
            fac = 1
        else:
            # Chi-square scaling requires more data points than parameters.
            if len(x) <= order:
                raise ValueError("the number of data points must exceed order "
                                 "to scale the covariance matrix")
            fac = resids / (len(x) - order)
            fac = fac[0]  # making np.array() of shape (1,) to int
        if y.ndim == 1:
            return c, Vbase * fac
        else:
            return c, Vbase[:, :, np.newaxis] * fac
    else:
        return c
def _precise_dot(A, B): return jnp.dot(A, B, precision=lax.Precision.HIGHEST)
def body_f(args):
    # One subspace iteration: apply the projector to the current basis,
    # then re-orthonormalize and recompute the convergence error.
    V1, _, iteration, _ = args
    projected = jnp.dot(P, V1)
    V1, V2, error = body_f_after_matmul(projected)
    return V1, V2, iteration + 1, error