def eigh_jvp_rule(primals, tangents, lower):
    """Forward-mode derivative of ``eigh`` for distinct eigenvalues.

    This is classic nondegenerate perturbation theory; see also
    https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf

    The general solution treating degenerate eigenvalues is considerably
    more complicated. Ambitious readers may refer to the general methods in
    the references below, or to degenerate perturbation theory in physics:
    https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
    https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf

    Args:
      primals: one-element sequence holding the input matrix ``a``.
      tangents: one-element sequence holding the tangent ``a_dot`` of ``a``.
      lower: forwarded unchanged to ``eigh_p.bind``.

    Returns:
      A pair ``((v, w), (dv, dw))``: the two outputs of ``eigh_p.bind``
      (with ``w`` cast to the dtype of ``a``) followed by their tangents.
    """
    (a,) = primals
    (a_dot,) = tangents

    eigvecs, eigvals = eigh_p.bind(symmetrize(a), lower=lower)

    # For complex inputs the eigenvalues must carry the full dtype of the
    # eigenvectors and of `a`.
    eigvals = eigvals.astype(a.dtype)

    # Reciprocal eigenvalue-gap matrix, built carefully to avoid NaNs: entry
    # [i, j] of the argument is delta_ij + w_j - w_i, so the diagonal is
    # exactly 1 before the reciprocal and exactly 0 once the identity is
    # subtracted again.
    identity = np.eye(a.shape[-1], dtype=a.dtype)
    eigvals_row = eigvals[..., np.newaxis, :]
    eigvals_col = eigvals[..., np.newaxis]
    inv_gaps = np.reciprocal(identity + eigvals_row - eigvals_col) - identity

    # The eigh implementation doesn't support batch dimensions, but keep the
    # derivative future-proof by selecting a batched product when needed.
    if a.ndim == 2:
        product = lax.dot
    else:
        product = lax.batch_matmul
    matmul = partial(product, precision=lax.Precision.HIGHEST)

    # Tangent of `a` expressed in the eigenvector basis.
    projected = matmul(matmul(_H(eigvecs), a_dot), eigvecs)

    eigvecs_dot = matmul(eigvecs, np.multiply(inv_gaps, projected))
    eigvals_dot = np.diagonal(projected, axis1=-2, axis2=-1)
    return (eigvecs, eigvals), (eigvecs_dot, eigvals_dot)
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
    """Forward-mode derivative of the singular value decomposition.

    The primal decomposition is always bound with ``full_matrices=False`` and
    ``compute_uv=True``, whatever values the caller passed, because the
    tangent formulas below need the reduced factors.

    Args:
      primals: one-element sequence holding the input matrix ``A``.
      tangents: one-element sequence holding the tangent ``dA`` of ``A``.
      full_matrices: whether the caller requested full-size factors.
      compute_uv: whether the caller requested the singular vectors.

    Returns:
      A pair ``((s, U, Vt), (ds, dU, dVt))`` of primal outputs and tangents.

    Raises:
      NotImplementedError: if both ``compute_uv`` and ``full_matrices`` are
        true; that case is not implemented.
    """
    # Reject the unsupported configuration before doing any work, so the
    # caller gets this error without first paying for the decomposition.
    if compute_uv and full_matrices:
        # TODO: implement full matrices case, documented here:
        # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
        raise NotImplementedError(
            "Singular value decomposition JVP not implemented for full matrices")

    A, = primals
    dA, = tangents
    s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

    k = s.shape[-1]
    Ut, V = _H(U), _H(Vt)
    s_dim = s[..., None, :]

    # Tangent of A expressed in the singular-vector bases; the real part of
    # its diagonal is the tangent of the singular values.
    dS = np.matmul(np.matmul(Ut, dA), V)
    ds = np.real(np.diagonal(dS, 0, -2, -1))

    # F[i, j] = 1 / (s_j**2 - s_i**2) off the diagonal and 0 on it. Adding
    # the identity keeps the diagonal denominators at 1 instead of 0.
    # (Assumes _T swaps the last two axes -- TODO confirm.)
    eye_k = np.eye(k, dtype=A.dtype)
    F = 1 / (np.square(s_dim) - np.square(_T(s_dim)) + eye_k)
    F = F - eye_k

    # NOTE(review): the symmetrisation below uses _T rather than _H, which
    # presumably makes it exact only for real inputs -- confirm before
    # relying on complex dtypes.
    dSS = s_dim * dS
    SdS = _T(s_dim) * dS
    dU = np.matmul(U, F * (dSS + _T(dSS)))
    dV = np.matmul(V, F * (SdS + _T(SdS)))

    # For non-square A, add the part of the tangent that the square k-by-k
    # terms above do not cover.
    m, n = A.shape[-2:]
    if m > n:
        dU = dU + np.matmul(
            np.eye(m, dtype=A.dtype) - np.matmul(U, Ut),
            np.matmul(dA, V)) / s_dim
    if n > m:
        dV = dV + np.matmul(
            np.eye(n, dtype=A.dtype) - np.matmul(V, Vt),
            np.matmul(_H(dA), U)) / s_dim
    return (s, U, Vt), (ds, dU, _T(dV))