def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here:
    # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  k = s.shape[-1]
  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

  if not compute_uv:
    return (s,), (ds,)

  F = 1 / (jnp.square(s_dim) - jnp.square(_T(s_dim)) + jnp.eye(k, dtype=A.dtype))
  F = F - jnp.eye(k, dtype=A.dtype)
  dSS = s_dim * dS
  SdS = _T(s_dim) * dS
  dU = jnp.matmul(U, F * (dSS + _T(dSS)))
  dV = jnp.matmul(V, F * (SdS + _T(SdS)))

  m, n = A.shape[-2:]
  if m > n:
    dU = dU + jnp.matmul(jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut),
                         jnp.matmul(dA, V)) / s_dim
  if n > m:
    dV = dV + jnp.matmul(jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt),
                         jnp.matmul(_H(dA), U)) / s_dim

  return (s, U, Vt), (ds, dU, _T(dV))
def _lu_jvp_rule(primals, tangents):
  a, = primals
  a_dot, = tangents
  lu, pivots, permutation = lu_p.bind(a)

  a_shape = jnp.shape(a)
  m, n = a_shape[-2:]
  dtype = lax.dtype(a)
  k = min(m, n)

  batch_dims = a_shape[:-2]
  iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
  x = a_dot[iotas[:-1] + (permutation, slice(None))]

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U

  ndims = len(a_shape)
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  zero = jnp._constant_like(lu, 0)
  l = lax.pad(jnp.tril(lu[..., :, :k], -1), zero, l_padding)
  l = l + jnp.eye(m, m, dtype=dtype)

  u_eye = lax.pad(jnp.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u = lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) + u_eye

  la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True,
                        unit_diagonal=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = jnp.matmul(l, jnp.tril(lau, -1))
  u_dot = jnp.matmul(jnp.triu(lau), u)
  lu_dot = l_dot + u_dot
  return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_value(pivots),
                                     ad_util.Zero.from_value(permutation))
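# A minimal usage sketch (not part of the original file) comparing the LU JVP
# above with a central finite difference through the public jax.scipy.linalg.lu
# API. With permute_l=True the first output is P @ L, which is a smooth
# function of the input as long as the pivoting pattern does not change.
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu

a = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
da = jax.random.normal(jax.random.PRNGKey(1), (4, 4))

f = lambda x: lu(x, permute_l=True)[0]  # the P @ L factor
_, tangent = jax.jvp(f, (a,), (da,))

eps = 1e-3
fd = (f(a + eps * da) - f(a - eps * da)) / (2 * eps)
print(jnp.max(jnp.abs(tangent - fd)))   # agrees to a few decimal places in float32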
def _pinv_jvp(rcond, primals, tangents):
  # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
  # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
  # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
  # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
  a, = primals
  a_dot, = tangents
  p = pinv(a, rcond=rcond)
  m, n = a.shape[-2:]
  # TODO(phawkins): on TPU, we would need to opt into high precision here.
  # TODO(phawkins): consider if this can be simplified in the Hermitian case.
  p_dot = -p @ a_dot @ p
  p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (jnp.eye(m, dtype=a.dtype) - a @ p)
  p_dot = p_dot + (jnp.eye(n, dtype=a.dtype) - p @ a) @ _H(a_dot) @ _H(p) @ p
  return p, p_dot
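# A minimal sketch (not part of the original file) checking the Golub-Pereyra
# pinv derivative above against a central finite difference via the public
# jnp.linalg.pinv API.
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
da = jax.random.normal(jax.random.PRNGKey(1), (5, 3))

_, dp = jax.jvp(jnp.linalg.pinv, (a,), (da,))

eps = 1e-3
fd = (jnp.linalg.pinv(a + eps * da) - jnp.linalg.pinv(a - eps * da)) / (2 * eps)
print(jnp.max(jnp.abs(dp - fd)))  # agrees to a few decimal places in float32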
def eigh_jvp_rule(primals, tangents, lower):
  # Derivative for eigh in the simplest case of distinct eigenvalues.
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  # The general solution treating the case of degenerate eigenvalues is
  # considerably more complicated. Ambitious readers may refer to the general
  # methods below or refer to degenerate perturbation theory in physics.
  # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
  a, = primals
  a_dot, = tangents

  v, w_real = eigh_p.bind(symmetrize(a), lower=lower)

  # for complex numbers we need eigenvalues to be full dtype of v, a:
  w = w_real.astype(a.dtype)
  eye_n = jnp.eye(a.shape[-1], dtype=a.dtype)
  # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
  Fmat = jnp.reciprocal(eye_n + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]) - eye_n
  # eigh impl doesn't support batch dims, but future-proof the grad.
  dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, jnp.multiply(Fmat, vdag_adot_v))
  dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
  return (v, w_real), (dv, dw)
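# A minimal sketch (not part of the original file) illustrating the dw formula
# above: for a symmetric perturbation dA with distinct eigenvalues, the
# eigenvalue tangents equal diag(V^H dA V), i.e. first-order (Hellmann-Feynman)
# perturbation theory, exercised through the public jnp.linalg.eigh API.
import jax
import jax.numpy as jnp

def _sym(key, n=4):
  m = jax.random.normal(key, (n, n))
  return (m + m.T) / 2

a = _sym(jax.random.PRNGKey(0))
da = _sym(jax.random.PRNGKey(1))

(w, v), (dw, dv) = jax.jvp(jnp.linalg.eigh, (a,), (da,))
print(jnp.max(jnp.abs(dw - jnp.diag(v.T @ da @ v))))  # ~0 for distinct eigenvalues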
def base_case(B, offset, b, agenda, blocks, eigenvectors):
  # Base case: for blocks under a minimum size, we cut off the recursion
  # and call the TPU Jacobi eigendecomposition implementation. The Jacobi
  # algorithm works well for small matrices but scales poorly, so the two
  # complement each other well.
  H = _slice(blocks, (offset, 0), (b, b), (B, B))
  V = _slice(eigenvectors, (0, offset), (n, b), (N, B))

  # We replace the masked-out part of the matrix with the identity matrix.
  # We know that the TPU Jacobi eigh implementation will not alter the order
  # of the eigenvalues, so we know the eigendecomposition of the original
  # matrix is in the top-left corner of the eigendecomposition of the padded
  # matrix.
  # It is very important that the underlying eigh implementation does not sort
  # the eigenvalues for this reason! This is currently not true of JAX's CPU
  # and GPU eigendecompositions, and for those platforms this algorithm will
  # only do the right thing if termination_size == 1.
  H = _mask(H, (b, b), jnp.eye(B, dtype=H.dtype))
  eig_vecs, eig_vals = lax.linalg.eigh(H, sort_eigenvalues=False)
  eig_vecs = _mask(eig_vecs, (b, b))
  eig_vals = _mask(eig_vals, (b,))
  eig_vecs = jnp.dot(V, eig_vecs)

  blocks = _update_slice(blocks, eig_vals[:, None], (offset, 0), (b, b))
  eigenvectors = _update_slice(eigenvectors, eig_vecs, (0, offset), (n, b))
  return agenda, blocks, eigenvectors
def _pade3(A):
  b = (120., 60., 12., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = _precise_dot(A, A)
  U = _precise_dot(A, (b[3] * A2 + b[1] * ident))
  V = b[2] * A2 + b[0] * ident
  return U, V
def inv(a):
  if jnp.ndim(a) < 2 or a.shape[-1] != a.shape[-2]:
    raise ValueError(
      f"Argument to inv must have shape [..., n, n], got {a.shape}.")
  return solve(
    a, lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)), a.shape[:-2]))
def matrix_power(a, n):
  a = _promote_arg_dtypes(jnp.asarray(a))

  if a.ndim < 2:
    raise TypeError("{}-dimensional array given. Array must be at least "
                    "two-dimensional".format(a.ndim))
  if a.shape[-2] != a.shape[-1]:
    raise TypeError("Last 2 dimensions of the array must be square")
  try:
    n = operator.index(n)
  except TypeError as err:
    raise TypeError("exponent must be an integer, got {}".format(n)) from err

  if n == 0:
    return jnp.broadcast_to(jnp.eye(a.shape[-2], dtype=a.dtype), a.shape)
  elif n < 0:
    a = inv(a)
    n = np.abs(n)

  if n == 1:
    return a
  elif n == 2:
    return a @ a
  elif n == 3:
    return (a @ a) @ a

  z = result = None
  while n > 0:
    z = a if z is None else (z @ z)
    n, bit = divmod(n, 2)
    if bit:
      result = z if result is None else (result @ z)

  return result
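# A minimal usage sketch (public jnp.linalg API) exercising the cases handled
# above: n == 0 returns the identity, negative n inverts first, and larger
# exponents fall through to the binary (square-and-multiply) loop.
import jax.numpy as jnp

a = jnp.array([[2., 1.],
               [0., 3.]])
print(jnp.linalg.matrix_power(a, 0))   # identity
print(jnp.linalg.matrix_power(a, 5))   # a @ a @ a @ a @ a via repeated squaring
print(jnp.linalg.matrix_power(a, -1))  # inverse of a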
def _pade5(A):
  b = (30240., 15120., 3360., 420., 30., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = _precise_dot(A, A)
  A4 = _precise_dot(A2, A2)
  U = _precise_dot(A, b[5] * A4 + b[3] * A2 + b[1] * ident)
  V = b[4] * A4 + b[2] * A2 + b[0] * ident
  return U, V
def _pade7(A):
  b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = _precise_dot(A, A)
  A4 = _precise_dot(A2, A2)
  A6 = _precise_dot(A4, A2)
  U = _precise_dot(A, b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident)
  V = b[6] * A6 + b[4] * A4 + b[2] * A2 + b[0] * ident
  return U, V
def _pade13(A):
  b = (64764752532480000., 32382376266240000., 7771770303897600.,
       1187353796428800., 129060195264000., 10559470521600., 670442572800.,
       33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = _precise_dot(A, A)
  A4 = _precise_dot(A2, A2)
  A6 = _precise_dot(A4, A2)
  U = _precise_dot(A, _precise_dot(A6, b[13]*A6 + b[11]*A4 + b[9]*A2) +
                   b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = (_precise_dot(A6, b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 +
       b[2]*A2 + b[0]*ident)
  return U, V
def _pade9(A):
  b = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
       2162160., 110880., 3960., 90., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = _precise_dot(A, A)
  A4 = _precise_dot(A2, A2)
  A6 = _precise_dot(A4, A2)
  A8 = _precise_dot(A6, A2)
  U = _precise_dot(A, b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  return U, V
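# A minimal sketch (an assumption about how these helpers are combined, not
# code from this file): each _padeN returns the numerator/denominator pair
# (U, V) of a Pade approximant to the matrix exponential, which is then
# recovered as solve(V - U, V + U). Checked here against jax.scipy.linalg.expm
# using the _pade3 coefficients on a small-norm matrix.
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

def _precise_dot(a, b):  # stand-in for the helper used above
  return jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)

A = 0.02 * jnp.arange(9.0).reshape(3, 3)
ident = jnp.eye(3, dtype=A.dtype)
b = (120., 60., 12., 1.)
A2 = _precise_dot(A, A)
U = _precise_dot(A, b[3] * A2 + b[1] * ident)
V = b[2] * A2 + b[0] * ident
print(jnp.max(jnp.abs(jnp.linalg.solve(V - U, V + U) - expm(A))))  # small for small norms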
def _lu(a, permute_l):
  a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
  lu, pivots, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  m, n = jnp.shape(a)
  p = jnp.real(jnp.array(permutation == jnp.arange(m)[:, None], dtype=dtype))
  k = min(m, n)
  l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
  u = jnp.triu(lu)[:k, :]
  if permute_l:
    return jnp.matmul(p, l), u
  else:
    return p, l, u
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here:
    # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

  if not compute_uv:
    return (s,), (ds,)

  s_diffs = jnp.square(s_dim) - jnp.square(_T(s_dim))
  # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.): is 1. where s_diffs is 0.
  # and is 0. everywhere else.
  s_diffs_zeros = jnp.eye(s.shape[-1], dtype=A.dtype)
  F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
  dSS = s_dim * dS      # dS.dot(jnp.diag(s))
  SdS = _T(s_dim) * dS  # jnp.diag(s).dot(dS)

  s_zeros = jnp.ones((), dtype=A.dtype) * (s == 0.)
  s_inv = 1 / (s + s_zeros) - s_zeros
  s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv)
  dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat
  dU = jnp.matmul(U, F * (dSS + _H(dSS)) + dUdV_diag)
  dV = jnp.matmul(V, F * (SdS + _H(SdS)))

  m, n = A.shape[-2:]
  if m > n:
    dU = dU + jnp.matmul(jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut),
                         jnp.matmul(dA, V)) / s_dim
  if n > m:
    dV = dV + jnp.matmul(jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt),
                         jnp.matmul(_H(dA), U)) / s_dim

  return (s, U, Vt), (ds, dU, _H(dV))
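# A minimal sketch (public API only, not part of the file) checking the
# singular-value tangent ds = real(diag(U^H dA V)) computed above.
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
da = jax.random.normal(jax.random.PRNGKey(1), (5, 3))

_, ds = jax.jvp(lambda x: jnp.linalg.svd(x, compute_uv=False), (a,), (da,))
u, _, vt = jnp.linalg.svd(a, full_matrices=False)
print(jnp.max(jnp.abs(ds - jnp.diag(u.T @ da @ vt.T))))  # ~0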
def split_spectrum(H, n, split_point, V0=None):
  """ The Hermitian matrix `H` is split into two matrices `H_minus` and
  `H_plus`, respectively sharing its eigenspaces beneath and above its
  `split_point`th eigenvalue.

  Returns, in addition, `V_minus` and `V_plus`, isometries such that
  `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are returned
  instead; this allows the overall isometries mapping from an initial input
  matrix to progressively smaller blocks to be formed.

  Args:
    H: The Hermitian matrix to split.
    n: The true (dynamic) size of `H`.
    split_point: The eigenvalue to split along.
    V0: Matrix of isometries to be updated.
  Returns:
    H_minus: A Hermitian matrix sharing the eigenvalues of `H` beneath
      `split_point`.
    V_minus: An isometry from the input space of `V0` to `H_minus`.
    H_plus: A Hermitian matrix sharing the eigenvalues of `H` above
      `split_point`.
    V_plus: An isometry from the input space of `V0` to `H_plus`.
    rank: The dynamic size of the `H_minus` subblock.
  """
  N, _ = H.shape
  H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(
      H.dtype)
  U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n))
  P = -0.5 * (U - _mask(jnp.eye(N, dtype=H.dtype), (n, n)))
  rank = jnp.round(jnp.trace(jnp.real(P))).astype(jnp.int32)

  V_minus, V_plus = _projector_subspace(P, H, n, rank)
  H_minus = (V_minus.conj().T @ H) @ V_minus
  H_plus = (V_plus.conj().T @ H) @ V_plus
  if V0 is not None:
    V_minus = jnp.dot(V0, V_minus)
    V_plus = jnp.dot(V0, V_plus)
  return H_minus, V_minus, H_plus, V_plus, rank
def _empty_svd(a, *, full_matrices, compute_uv):
  batch_shape = a.shape[:-2]
  m, n = a.shape[-2:]
  s = jnp.empty(batch_shape + (0,),
                dtype=lax_internal._complex_basetype(a.dtype))
  if not compute_uv:
    return (s,)
  if full_matrices:
    size = max(m, n)
    u = jnp.broadcast_to(jnp.eye(size, dtype=a.dtype),
                         batch_shape + (size, size))
  else:
    u = jnp.empty(batch_shape + (m, n), dtype=a.dtype)
  v = jnp.empty(batch_shape + (0, 0), dtype=a.dtype)
  if m < n:
    u, v = v, u
  return s, u, v
def _lu(a, permute_l):
  a, = _promote_dtypes_inexact(jnp.asarray(a))
  lu, _, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  m, n = jnp.shape(a)
  p = jnp.real(jnp.array(
      permutation[None, :] == jnp.arange(m, dtype=permutation.dtype)[:, None],
      dtype=dtype))
  k = min(m, n)
  l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
  u = jnp.triu(lu)[:k, :]
  if permute_l:
    return jnp.matmul(p, l), u
  else:
    return p, l, u
def qr_jvp_rule(primals, tangents, full_matrices):
  # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
  x, = primals
  dx, = tangents
  q, r = qr_p.bind(x, full_matrices=False)
  *_, m, n = x.shape
  if full_matrices or m < n:
    raise NotImplementedError(
      "Unimplemented case of QR decomposition derivative")
  dx_rinv = triangular_solve(r, dx)  # Right side solve by default
  qt_dx_rinv = jnp.matmul(_H(q), dx_rinv)
  qt_dx_rinv_lower = jnp.tril(qt_dx_rinv, -1)
  do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower)  # This is skew-symmetric
  # The following correction is necessary for complex inputs
  do = do + jnp.eye(n, dtype=do.dtype) * (qt_dx_rinv - jnp.real(qt_dx_rinv))
  dq = jnp.matmul(q, do - qt_dx_rinv) + dx_rinv
  dr = jnp.matmul(qt_dx_rinv - do, r)
  return (q, r), (dq, dr)
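# A minimal sketch (public API only, not part of the file) checking two
# properties implied by the derivation above: dR stays upper triangular, and
# Q^H dQ is skew, so Q remains orthonormal to first order.
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
dx = jax.random.normal(jax.random.PRNGKey(1), (4, 3))

(q, r), (dq, dr) = jax.jvp(lambda m: jnp.linalg.qr(m, mode='reduced'), (x,), (dx,))
print(jnp.allclose(dr, jnp.triu(dr), atol=1e-6))         # dR is upper triangular
print(jnp.allclose(q.T @ dq, -(q.T @ dq).T, atol=1e-5))  # Q^T dQ is skew-symmetric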
def eigh(H, *, precision="float32", termination_size=256, n=None,
         sort_eigenvalues=True):
  """ Computes the eigendecomposition of the symmetric/Hermitian matrix H.

  Args:
    H: The `n x n` Hermitian input, padded to `N x N`.
    precision: :class:`~jax.lax.Precision` object specifying the matmul
      precision.
    termination_size: Recursion ends once the blocks reach this linear size.
    n: the true (dynamic) size of the matrix.
    sort_eigenvalues: If `True`, the eigenvalues will be sorted from lowest to
      highest.
  Returns:
    vals: The `n` eigenvalues of `H`.
    vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
      of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
      to numerical error.
  """
  M, N = H.shape
  if M != N:
    raise TypeError(f"Input H of shape {H.shape} must be square.")

  if N <= termination_size:
    if n is not None:
      H = _mask(H, (n, n), jnp.eye(N, dtype=H.dtype))
    return lax_linalg.eigh_jacobi(H, sort_eigenvalues=sort_eigenvalues)

  # TODO(phawkins): consider rounding N up to a larger size to maximize reuse
  # between matrices.

  n = N if n is None else n
  with jax.default_matmul_precision(precision):
    eig_vals, eig_vecs = _eigh_work(H, n, termination_size=termination_size)
  eig_vals = _mask(jnp.real(eig_vals), (n,), jnp.nan)
  if sort_eigenvalues:
    sort_idxs = jnp.argsort(eig_vals)
    eig_vals = eig_vals[sort_idxs]
    eig_vecs = eig_vecs[:, sort_idxs]
  return eig_vals, eig_vecs
def phi(X):
  l = jnp.tril(X)
  return l / (jnp._constant_like(X, 1) + jnp.eye(X.shape[-1], dtype=X.dtype))
def conv_general_dilated_patches(
    lhs: lax.Array,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Optional[Sequence[int]] = None,
    rhs_dilation: Optional[Sequence[int]] = None,
    dimension_numbers: Optional[lax.ConvGeneralDilatedDimensionNumbers] = None,
    precision: Optional[lax.PrecisionType] = None,
    preferred_element_type: Optional[DType] = None,
) -> lax.Array:
  """Extract patches subject to the receptive field of `conv_general_dilated`.

  Runs the input through a convolution with given parameters. The kernel of
  the convolution is constructed such that the output channel dimension `"C"`
  contains flattened image patches, so a single `"C"` dimension represents,
  for example, the three dimensions `"chw"` collapsed into one. The order of
  these dimensions is `"c" + ''.join(c for c in rhs_spec if c not in 'OI')`,
  where `rhs_spec == dimension_numbers[1]`, and the size of this `"C"`
  dimension is therefore the size of each patch, i.e.
  `np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`, where
  `lhs_spec == dimension_numbers[0]`.

  Docstring below adapted from `jax.lax.conv_general_dilated`.

  See Also:
    https://www.tensorflow.org/xla/operation_semantics#conv_convolution

  Args:
    lhs: a rank `n+2` dimensional input array.
    filter_shape: a sequence of `n` integers, representing the receptive
      window spatial shape in the order as specified in
      `rhs_spec = dimension_numbers[1]`.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
      of `n` `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the dilation
      factor to apply in each spatial dimension of `lhs`. LHS dilation is also
      known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the dilation
      factor to apply in each spatial dimension of `rhs`. RHS dilation is also
      known as atrous convolution.
    dimension_numbers: either `None`, or a 3-tuple `(lhs_spec, rhs_spec,
      out_spec)`, where each element is a string of length `n+2`. `None`
      defaults to `("NCHWD..., OIHWD..., NCHWD...")`.
    precision: Optional. Either ``None``, which means the default precision
      for the backend, or a ``lax.Precision`` enum value
      (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``).
    preferred_element_type: Optional. Either ``None``, which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.

  Returns:
    A rank `n+2` array containing the flattened image patches in the output
    channel (`"C"`) dimension. For example if
    `dimension_numbers = ("NcHW", "OIwh", "CNHW")`, the output has dimension
    numbers `"CNHW" = "{cwh}NHW"`, with the size of dimension `"C"` equal to
    the size of each patch
    (`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).
  """
  filter_shape = tuple(filter_shape)
  dimension_numbers = lax.conv_dimension_numbers(
      lhs.shape, (1, 1) + filter_shape, dimension_numbers)
  lhs_spec, rhs_spec, out_spec = dimension_numbers

  spatial_size = prod(filter_shape)
  n_channels = lhs.shape[lhs_spec[1]]

  # Move separate `lhs` spatial locations into separate `rhs` channels.
  rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)

  rhs = rhs.reshape((spatial_size, 1) + filter_shape)
  rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1))
  rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))

  out = lax.conv_general_dilated(
      lhs=lhs,
      rhs=rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      precision=None if precision is None else (precision,
                                                lax.Precision.DEFAULT),
      feature_group_count=n_channels,
      preferred_element_type=preferred_element_type)
  return out
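# A minimal usage sketch (assuming the function above is exposed as
# jax.lax.conv_general_dilated_patches, its public name) extracting 3x3
# patches from an NCHW input with two channels.
import jax
import jax.numpy as jnp

x = jnp.arange(2 * 8 * 8.0).reshape(1, 2, 8, 8)  # NCHW
patches = jax.lax.conv_general_dilated_patches(
    x, filter_shape=(3, 3), window_strides=(1, 1), padding='VALID')
# The output channel dimension holds flattened patches of size
# n_channels * prod(filter_shape) = 2 * 9 = 18.
print(patches.shape)  # (1, 18, 6, 6)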
def _eigh_work(H, n, termination_size=256):
  """ The main work loop performing the symmetric eigendecomposition of H.
  Each step recursively computes a projector into the space of eigenvalues
  above jnp.mean(jnp.diag(H)). The result of the projections into and out of
  that space, along with the isometries accomplishing these, are then
  computed. This is performed recursively until the projections have size 1,
  and thus store an eigenvalue of the original input; the corresponding
  isometry is the related eigenvector. The results are then composed.

  This function cannot be Jitted because the internal split_spectrum cannot
  be.

  Args:
    H: The Hermitian input.
    n: The true (dynamic) shape of H.

  Returns:
    H, V: The result of the projection.
  """
  # We turn what was originally a recursive algorithm into an iterative
  # algorithm with an explicit stack.
  N, _ = H.shape
  n = jnp.asarray(n, jnp.int32)
  agenda = Stack.create(
    N + 1, _Subproblem(jnp.array(0, jnp.int32), jnp.array(0, jnp.int32)))
  agenda = agenda.push(_Subproblem(offset=jnp.int32(0), size=n))

  # eigenvectors is the array in which we build the output eigenvectors.
  # We initialize it with the identity matrix so the initial matrix
  # multiplications in split_spectrum are the identity.
  eigenvectors = jnp.eye(N, dtype=H.dtype)

  # blocks is an array representing a stack of Hermitian matrix blocks that we
  # need to recursively decompose. Subproblems are different sizes, so the
  # stack of blocks is ragged. Subproblems are left-aligned (i.e. starting at
  # the 0th column). Here is an ASCII art picture of three blocks A, B, C,
  # embedded in the larger `blocks` workspace (represented with trailing
  # dots).
  #
  # A A A . . .
  # A A A . . .
  # A A A . . .
  # B B . . . .
  # B B . . . .
  # C C C C . .
  # C C C C . .
  # C C C C . .
  # C C C C . .
  #
  # Each step of the algorithm subdivides a block into two subblocks whose
  # sizes sum to the original block size. We overwrite the original block with
  # those two subblocks so we don't need any additional scratch space.
  #
  # At termination, "blocks" will contain 1x1 blocks (i.e., the eigenvalues)
  # in its first column.
  blocks = H

  def base_case(B, offset, b, agenda, blocks, eigenvectors):
    # Base case: for blocks under a minimum size, we cut off the recursion
    # and call the TPU Jacobi eigendecomposition implementation. The Jacobi
    # algorithm works well for small matrices but scales poorly, so the two
    # complement each other well.
    H = _slice(blocks, (offset, 0), (b, b), (B, B))
    V = _slice(eigenvectors, (0, offset), (n, b), (N, B))

    # We replace the masked-out part of the matrix with the identity matrix.
    # We know that the TPU Jacobi eigh implementation will not alter the order
    # of the eigenvalues, so we know the eigendecomposition of the original
    # matrix is in the top-left corner of the eigendecomposition of the padded
    # matrix.
    # It is very important that the underlying eigh implementation does not
    # sort the eigenvalues for this reason! This is currently not true of
    # JAX's CPU and GPU eigendecompositions, and for those platforms this
    # algorithm will only do the right thing if termination_size == 1.
    H = _mask(H, (b, b), jnp.eye(B, dtype=H.dtype))
    eig_vecs, eig_vals = lax.linalg.eigh(H, sort_eigenvalues=False)
    eig_vecs = _mask(eig_vecs, (b, b))
    eig_vals = _mask(eig_vals, (b,))
    eig_vecs = jnp.dot(V, eig_vecs)

    blocks = _update_slice(blocks, eig_vals[:, None], (offset, 0), (b, b))
    eigenvectors = _update_slice(eigenvectors, eig_vecs, (0, offset), (n, b))
    return agenda, blocks, eigenvectors

  def recursive_case(B, offset, b, agenda, blocks, eigenvectors):
    # The recursive case of the algorithm, specialized to a static block size
    # of B.
    H = _slice(blocks, (offset, 0), (b, b), (B, B))
    V = _slice(eigenvectors, (0, offset), (n, b), (N, B))

    # TODO: Improve this?
    split_point = jnp.nanmedian(_mask(jnp.diag(H), (b,), jnp.nan))
    H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(H, b, split_point,
                                                            V0=V)

    blocks = _update_slice(blocks, H_minus, (offset, 0), (rank, rank))
    blocks = _update_slice(blocks, H_plus, (offset + rank, 0),
                           (b - rank, b - rank))
    eigenvectors = _update_slice(eigenvectors, V_minus, (0, offset), (n, rank))
    eigenvectors = _update_slice(eigenvectors, V_plus, (0, offset + rank),
                                 (n, b - rank))

    agenda = agenda.push(_Subproblem(offset + rank, (b - rank)))
    agenda = agenda.push(_Subproblem(offset, rank))
    return agenda, blocks, eigenvectors

  def loop_cond(state):
    agenda, _, _ = state
    return ~agenda.empty()

  # It would be wasteful to perform all computation padded up to the original
  # matrix size. Instead, we form buckets of padded sizes e.g.,
  # [256, 512, 1024, ..., N], aiming for a balance between compilation time
  # and runtime.
  cutoff = min(N, termination_size)
  buckets = [cutoff]
  branches = [partial(base_case, cutoff)]
  i = cutoff
  while i < N:
    i = min(2 * i, N)
    buckets.append(i)
    branches.append(partial(recursive_case, i))
  buckets = jnp.array(buckets)

  def loop_body(state):
    agenda, blocks, eigenvectors = state
    (offset, b), agenda = agenda.pop()

    which = jnp.where(buckets < b, jnp.iinfo(jnp.int32).max, buckets)
    choice = jnp.argmin(which)
    return lax.switch(choice, branches, offset, b, agenda, blocks,
                      eigenvectors)

  _, blocks, eigenvectors = lax.while_loop(
      loop_cond, loop_body, (agenda, blocks, eigenvectors))
  return blocks[:, 0], eigenvectors