def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
    """Forward-mode differentiation (JVP) rule for the SVD primitive.

    The decomposition is always recomputed in reduced form with both factors
    (``full_matrices=False, compute_uv=True``), regardless of the flags the
    caller passed, and the tangents of all three outputs are derived from it.

    Args:
        primals: 1-tuple ``(A,)`` holding the matrix being decomposed. Only
            the last two axes are treated as the matrix dimensions.
        tangents: 1-tuple ``(dA,)`` holding the perturbation of ``A``.
        full_matrices: Whether the caller requested full-size factors.
        compute_uv: Whether the caller requested the ``U``/``Vt`` factors.

    Returns:
        A pair ``((s, U, Vt), (ds, dU, dVt))`` of primal outputs and their
        tangents. A triple is returned whatever the value of ``compute_uv``.

    Raises:
        NotImplementedError: If both ``compute_uv`` and ``full_matrices`` are
            true.
    """
    A, = primals
    dA, = tangents

    # Fail fast: reject the unsupported configuration before paying for the
    # decomposition.
    if compute_uv and full_matrices:
        # TODO: implement full matrices case, documented here:
        # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
        raise NotImplementedError(
            "Singular value decomposition JVP not implemented for full matrices")

    s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

    k = s.shape[-1]
    eye_k = np.eye(k, dtype=A.dtype)
    # NOTE(review): _H / _T are presumably the conjugate transpose / transpose
    # over the last two axes -- confirm against their definitions.
    Ut, V = _H(U), _H(Vt)
    # Singular values as a broadcastable row: s_dim[..., 0, j] == s[..., j].
    s_dim = s[..., None, :]

    # Perturbation expressed in the singular-vector bases; the real part of
    # its diagonal is the tangent of the singular values.
    dS = np.matmul(np.matmul(Ut, dA), V)
    ds = np.real(np.diagonal(dS, 0, -2, -1))

    # F[i, j] = 1 / (s_j**2 - s_i**2) off the diagonal and 0 on it. Adding the
    # identity before inverting avoids 0/0 on the diagonal; subtracting it
    # afterwards zeroes those entries again.
    F = 1 / (np.square(s_dim) - np.square(_T(s_dim)) + eye_k)
    F = F - eye_k

    dSS = s_dim * dS       # column j of dS scaled by s_j
    SdS = _T(s_dim) * dS   # row i of dS scaled by s_i
    dU = np.matmul(U, F * (dSS + _T(dSS)))
    dV = np.matmul(V, F * (SdS + _T(SdS)))

    m, n = A.shape[-2:]
    if m > n:
        # Tall input: add the part of dU outside the column space of U.
        dU = dU + np.matmul(
            np.eye(m, dtype=A.dtype) - np.matmul(U, Ut),
            np.matmul(dA, V)) / s_dim
    if n > m:
        # Wide input: add the part of dV outside the column space of V.
        dV = dV + np.matmul(
            np.eye(n, dtype=A.dtype) - np.matmul(V, Vt),
            np.matmul(_H(dA), U)) / s_dim

    return (s, U, Vt), (ds, dU, _T(dV))
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
    """Forward-mode differentiation (JVP) rule for the SVD primitive.

    The decomposition is recomputed in reduced form with both factors
    (``full_matrices=False, compute_uv=True``) and the tangents of all three
    outputs are derived from it. All matrix operations act on the last two
    axes, so leading batch axes are carried through; for a plain 2-D input the
    results are the same as with ordinary ``dot``/``.T``/``diag``.

    Args:
        primals: 1-tuple ``(A,)`` holding the matrix being decomposed.
        tangents: 1-tuple ``(dA,)`` holding the perturbation of ``A``.
        full_matrices: Whether the caller requested full-size factors; the
            rule refuses to run when this is true.
        compute_uv: Accepted for the rule's signature; not consulted here.

    Returns:
        ``(core.pack((s, U, Vt)), core.pack((ds, dU, dVt)))``: the packed
        primal outputs and the packed tangents.

    Raises:
        NotImplementedError: If ``full_matrices`` is true.
    """
    if full_matrices:
        # TODO: implement full matrices case, documented here:
        # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
        raise NotImplementedError(
            "Singular value decomposition JVP not implemented for full matrices")

    def transpose(x):
        # Swap only the two matrix axes so that batch axes stay in place.
        return np.swapaxes(x, -1, -2)

    def conj_transpose(x):
        return np.conj(transpose(x))

    A, = primals
    dA, = tangents
    s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

    k = s.shape[-1]
    # Build identities in A's dtype so the tangents are not promoted to the
    # library's default float type.
    eye_k = np.eye(k, dtype=A.dtype)
    Ut, V = conj_transpose(U), conj_transpose(Vt)
    # Singular values as a broadcastable row: s_dim[..., 0, j] == s[..., j].
    s_dim = s[..., None, :]
    s_col = transpose(s_dim)

    # Perturbation expressed in the singular-vector bases; the real part of
    # its diagonal is the tangent of the singular values.
    dS = np.matmul(np.matmul(Ut, dA), V)
    ds = np.real(np.diagonal(dS, 0, -2, -1))

    # F[i, j] = 1 / (s_j**2 - s_i**2) off the diagonal and 0 on it. Adding the
    # identity before inverting avoids 0/0 on the diagonal; subtracting it
    # afterwards zeroes those entries again.
    F = 1 / (np.square(s_dim) - np.square(s_col) + eye_k) - eye_k

    dSS = s_dim * dS   # column j of dS scaled by s_j
    SdS = s_col * dS   # row i of dS scaled by s_i
    dU = np.matmul(U, F * (dSS + transpose(dSS)))
    dV = np.matmul(V, F * (SdS + transpose(SdS)))

    m, n = A.shape[-2], A.shape[-1]
    if m > n:
        # Tall input: add the part of dU outside the column space of U.
        dU = dU + np.matmul(
            np.eye(m, dtype=A.dtype) - np.matmul(U, Ut),
            np.matmul(dA, V)) / s_dim
    if n > m:
        # Wide input: add the part of dV outside the column space of V.
        dV = dV + np.matmul(
            np.eye(n, dtype=A.dtype) - np.matmul(V, Vt),
            np.matmul(conj_transpose(dA), U)) / s_dim

    return core.pack((s, U, Vt)), core.pack((ds, dU, transpose(dV)))