def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  """JVP rule for ``svd_p``.

  Differentiates the thin SVD ``A = U @ diag(s) @ Vt`` with ``V = _H(Vt)``
  (``_H`` is assumed to be the conjugate transpose of the trailing two axes).
  Follows https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf.

  Assumes distinct, nonzero singular values. ``F`` has entries
  ``1 / (s_j**2 - s_i**2)``, and the ``m != n`` corrections divide by ``s``.

  Args:
    primals: 1-tuple ``(A,)``.
    tangents: 1-tuple ``(dA,)``.
    full_matrices: must be False whenever ``compute_uv`` is True.
    compute_uv: only consulted together with ``full_matrices``. ``U`` and
      ``Vt`` are always computed and returned.

  Returns:
    ``((s, U, Vt), (ds, dU, dVt))``.

  Raises:
    NotImplementedError: if both ``compute_uv`` and ``full_matrices`` are set.
  """
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices"
    )

  k = s.shape[-1]
  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]  # s_dim[..., i, j] == s[..., j]
  eye_k = np.eye(k, dtype=A.dtype)

  # dP = U^H dA V; the real part of its diagonal is the singular-value tangent.
  dS = np.matmul(np.matmul(Ut, dA), V)
  ds = np.real(np.diagonal(dS, 0, -2, -1))

  # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal and 0 on it. Adding the
  # identity before taking the reciprocal keeps the diagonal finite.
  F = 1 / (np.square(s_dim) - np.square(_T(s_dim)) + eye_k)
  F = F - eye_k

  dSS = s_dim * dS      # dP @ diag(s)
  SdS = _T(s_dim) * dS  # diag(s) @ dP

  # Reciprocal singular values with 0 where s == 0. The double `where` keeps
  # inf out of both branches.
  s_is_zero = s == 0
  s_safe = np.where(s_is_zero, np.ones_like(s), s)
  s_inv = np.where(s_is_zero, np.zeros_like(s), 1 / s_safe)

  # For complex inputs the off-diagonal equations leave the diagonal of
  # U^H dU undetermined. Fixing the phase gauge so that V^H dV has a zero
  # diagonal gives U^H dU the diagonal (dP_ii - conj(dP_ii)) / (2 s_i).
  # dS - _H(dS) has an exactly zero diagonal for real inputs, so this term
  # vanishes there.
  dUdV_diag = 0.5 * (dS - _H(dS)) * (eye_k * s_inv[..., None, :])

  # The symmetrizing partner of dP @ S is (dP @ S)^H = S @ dP^H, and of
  # S @ dP is dP^H @ S. These need the conjugate transpose, not a plain
  # transpose.
  dU = np.matmul(U, F * (dSS + _H(dSS)) + dUdV_diag)
  dV = np.matmul(V, F * (SdS + _H(SdS)))

  m, n = A.shape[-2:]
  if m > n:
    # Component of dU outside span(U): (I - U U^H) dA V diag(1/s).
    dU = dU + np.matmul(
      np.eye(m, dtype=A.dtype) - np.matmul(U, Ut), np.matmul(dA, V)) / s_dim
  if n > m:
    # Component of dV outside span(V): (I - V V^H) dA^H U diag(1/s).
    dV = dV + np.matmul(
      np.eye(n, dtype=A.dtype) - np.matmul(V, Vt), np.matmul(_H(dA), U)) / s_dim
  # d(Vt) = d(V^H) = (dV)^H.
  return (s, U, Vt), (ds, dU, _H(dV))
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  """JVP rule for ``svd_p`` (thin SVD only).

  Differentiates ``A = U @ diag(s) @ Vt`` with ``V`` the conjugate transpose
  of ``Vt``. Follows https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf.

  Assumes distinct, nonzero singular values. ``F`` has entries
  ``1 / (s_j**2 - s_i**2)``, and the ``m != n`` corrections divide by ``s``.
  All matrix operations act on the trailing two axes.

  Args:
    primals: 1-tuple ``(A,)``.
    tangents: 1-tuple ``(dA,)``.
    full_matrices: must be False.
    compute_uv: unused. ``U`` and ``Vt`` are always computed and returned.

  Returns:
    ``core.pack((s, U, Vt))`` and ``core.pack((ds, dU, dVt))``.

  Raises:
    NotImplementedError: if ``full_matrices`` is True.
  """
  if full_matrices:
    #TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices"
    )

  def _ct(x):
    # Conjugate transpose over the trailing two axes.
    return np.swapaxes(np.conj(x), -1, -2)

  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  k = s.shape[-1]
  Ut, V = _ct(U), _ct(Vt)
  s_dim = s[..., None, :]                  # s_dim[..., i, j] == s[..., j]
  s_dim_t = np.swapaxes(s_dim, -1, -2)     # s_dim_t[..., i, j] == s[..., i]
  # Identity in A's dtype so the tangents are not promoted to another dtype.
  eye_k = np.eye(k, dtype=A.dtype)

  # dP = U^H dA V; the real part of its diagonal is the singular-value tangent.
  dS = np.matmul(np.matmul(Ut, dA), V)
  ds = np.real(np.diagonal(dS, axis1=-2, axis2=-1))

  # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal, 0 on it.
  F = 1 / (np.square(s_dim) - np.square(s_dim_t) + eye_k) - eye_k
  dSS = s_dim * dS     # dP @ diag(s)
  SdS = s_dim_t * dS   # diag(s) @ dP

  # Reciprocal singular values with 0 where s == 0 (no inf in either branch).
  s_is_zero = s == 0
  s_safe = np.where(s_is_zero, np.ones_like(s), s)
  s_inv = np.where(s_is_zero, np.zeros_like(s), 1 / s_safe)

  # Phase-gauge term for complex inputs. It fixes the diagonal of U^H dU,
  # which the off-diagonal equations leave free. It is exactly zero for
  # real inputs.
  dUdV_diag = 0.5 * (dS - _ct(dS)) * (eye_k * s_inv[..., None, :])

  # Symmetrize with the conjugate transpose: (dP S)^H = S dP^H.
  dU = np.matmul(U, F * (dSS + _ct(dSS)) + dUdV_diag)
  dV = np.matmul(V, F * (SdS + _ct(SdS)))

  m, n = A.shape[-2], A.shape[-1]
  if m > n:
    # Component of dU outside span(U).
    dU = dU + np.matmul(np.eye(m, dtype=A.dtype) - np.matmul(U, Ut),
                        np.matmul(dA, V)) / s_dim
  if n > m:
    # Component of dV outside span(V).
    dV = dV + np.matmul(np.eye(n, dtype=A.dtype) - np.matmul(V, Vt),
                        np.matmul(_ct(dA), U)) / s_dim
  # d(Vt) = d(V^H) = (dV)^H.
  return core.pack((s, U, Vt)), core.pack((ds, dU, _ct(dV)))
def _lu_jvp_rule(primals, tangents):
  """JVP rule for ``lu_p``.

  Args:
    primals: 1-tuple ``(a,)``, where ``a`` has shape ``(..., m, n)``.
    tangents: 1-tuple ``(a_dot,)``, possibly the symbolic ``ad_util.zero``.

  Returns:
    ``((lu, pivots), (lu_dot, ad_util.zero))``: the factors produced by
    ``lu_p`` and their tangents. ``pivots`` always gets a symbolic zero
    tangent.
  """
  a, = primals
  a_dot, = tangents
  lu, pivots = lu_p.bind(a)

  if a_dot is ad_util.zero:
    # Use the same plain-tuple (primals, tangents) structure as the general
    # return at the end of this function.
    return (lu, pivots), (ad_util.zero, ad_util.zero)

  a_shape = np.shape(a)
  m, n = a_shape[-2:]
  dtype = lax.dtype(a)
  k = min(m, n)

  permutation = lu_pivots_to_permutation(pivots, m)
  # Gather the rows of a_dot in `permutation` order for each batch index:
  # x[..., i, :] == a_dot[..., permutation[..., i], :].
  batch_dims = a_shape[:-2]
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1,)))
  x = a_dot[iotas[:-1] + (permutation, slice(None))]

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U

  ndims = len(a_shape)
  zero = np._constant_like(lu, 0)

  # Square m x m L with unit diagonal: strict lower part of lu[..., :, :k],
  # zero-padded on the right.
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
  l = l + np.eye(m, m, dtype=dtype)

  # Square n x n U: upper part of lu[..., :k, :], zero-padded below, plus an
  # identity in the trailing (n - k) x (n - k) block.
  u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

  # lau = inv(L) . x . inv(U)
  la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True,
                        unit_diagonal=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = np.matmul(l, np.tril(lau, -1))
  u_dot = np.matmul(np.triu(lau), u)
  lu_dot = l_dot + u_dot
  return (lu, pivots), (lu_dot, ad_util.zero)
def lu_jvp_rule(primals, tangents):
  """JVP rule for ``lu_p``.

  Args:
    primals: 1-tuple ``(a,)``, where ``a`` has shape ``(..., m, n)``.
    tangents: 1-tuple ``(a_dot,)``.

  Returns:
    ``core.pack((lu, pivots))`` and
    ``ad.TangentTuple((lu_dot, ad_util.zero))``. ``pivots`` gets a symbolic
    zero tangent.
  """
  a, = primals
  a_dot, = tangents
  lu, pivots = lu_p.bind(a)

  a_shape = np.shape(a)
  m, n = a_shape[-2:]
  dtype = lax._dtype(a)
  k = min(m, n)

  # TODO(phawkins): use a gather rather than a matrix multiplication here.
  permutation = lu_pivots_to_permutation(pivots, m)
  # One-hot row-permutation matrix p[..., i, j] = (permutation[..., i] == j),
  # so that x[..., i, :] == a_dot[..., permutation[..., i], :]. Indexing with
  # `...` (not `[:, None]`) keeps leading batch dimensions aligned, matching
  # the batch-aware slicing and padding below. For 2-D inputs it is identical.
  p = np.array(permutation[..., None] == np.arange(m), dtype=dtype)
  x = np.matmul(p, a_dot)

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U

  ndims = len(a_shape)
  zero = np._constant_like(lu, 0)

  # Square m x m L with unit diagonal, built from the strict lower part of lu.
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
  l = l + np.eye(m, m, dtype=dtype)

  # Square n x n U: upper part of lu[..., :k, :], zero-padded below, plus an
  # identity in the trailing (n - k) x (n - k) block.
  u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

  # lau = inv(L) . x . inv(U)
  la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = np.matmul(l, np.tril(lau, -1))
  u_dot = np.matmul(np.triu(lau), u)
  lu_dot = l_dot + u_dot
  return core.pack((lu, pivots)), ad.TangentTuple((lu_dot, ad_util.zero))
def eigh_jvp_rule(primals, tangents, lower):
  """JVP rule for ``eigh_p``, valid when the eigenvalues are distinct.

  Args:
    primals: 1-tuple ``(a,)``; ``a`` is symmetrized before decomposition.
    tangents: 1-tuple ``(a_dot,)``.
    lower: forwarded to ``eigh_p``.

  Returns:
    ``((v, w), (dv, dw))``. ``w`` is the eigenvalue output of ``eigh_p``,
    returned without a dtype cast, so the primal matches the primitive's
    output. ``dw`` is real, so its dtype matches ``w``.
  """
  # Derivative for eigh in the simplest case of distinct eigenvalues.
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  # The general solution treating the case of degenerate eigenvalues is
  # considerably more complicated. Ambitious readers may refer to the general
  # methods below or refer to degenerate perturbation theory in physics.
  # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
  a, = primals
  a_dot, = tangents

  v, w_real = eigh_p.bind(symmetrize(a), lower=lower)

  # For complex inputs the eigenvalue arithmetic below needs the full dtype
  # of v and a. This cast copy is used only internally and is not returned.
  w = w_real.astype(a.dtype)
  eye_n = np.eye(a.shape[-1], dtype=a.dtype)
  # Reciprocal eigenvalue-gap matrix Fmat[i, j] = 1 / (w_j - w_i), 0 on the
  # diagonal. Adding the identity first avoids dividing by zero there.
  Fmat = np.reciprocal(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis]) - eye_n
  # eigh impl doesn't support batch dims, but future-proof the grad.
  dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, np.multiply(Fmat, vdag_adot_v))
  # Re(v_i^H a_dot v_i) equals v_i^H sym(a_dot) v_i, i.e. the derivative along
  # the Hermitian part of a_dot (consistent with symmetrize(a) above). Taking
  # the real part keeps dw in w_real's real dtype.
  dw = np.real(np.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
  return (v, w_real), (dv, dw)
def cholesky_jvp_rule(primals, tangents):
  """JVP rule for ``cholesky_p``.

  Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf. It assumes
  ``cholesky_p`` returns lower-triangular ``L`` with ``x = L @ L^H``. Then
  ``L_dot = L @ phi(inv(L) @ sigma_dot @ inv(L)^H)``, where ``phi`` keeps the
  lower triangle and halves the diagonal.

  Args:
    primals: 1-tuple ``(x,)``.
    tangents: 1-tuple ``(sigma_dot,)``.

  Returns:
    ``(L, L_dot)``.
  """
  x, = primals
  sigma_dot, = tangents
  L = cholesky_p.bind(x)

  def phi(X):
    # Lower triangle of X with its diagonal halved (divisor 2 on the diagonal,
    # 1 elsewhere).
    return np.tril(X) / (1 + np.eye(X.shape[-1], dtype=X.dtype))

  # tmp = sigma_dot @ inv(L)^H. conjugate_a is needed for complex inputs,
  # where x = L L^H. It has no effect on real inputs.
  tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True,
                         conjugate_a=True, lower=True)
  # inv(L) @ tmp, projected by phi, then mapped back through L.
  L_dot = lax.batch_matmul(L, phi(triangular_solve(
      L, tmp, left_side=True, transpose_a=False, lower=True)))
  return L, L_dot
def qr_jvp_rule(primals, tangents, full_matrices):
  """JVP rule for ``qr_p``.

  See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
  The primal is always computed with ``full_matrices=False``. For square
  inputs (m == n) the complete and reduced factorizations coincide, so
  ``full_matrices=True`` is accepted in that case.

  Args:
    primals: 1-tuple ``(x,)`` with ``x`` of shape ``(..., m, n)``.
    tangents: 1-tuple ``(dx,)``.
    full_matrices: requested QR mode.

  Returns:
    ``((q, r), (dq, dr))``.

  Raises:
    NotImplementedError: if m < n, or if ``full_matrices`` is set and m != n.
  """
  x, = primals
  dx, = tangents
  q, r = qr_p.bind(x, full_matrices=False)
  *_, m, n = x.shape
  if m < n or (full_matrices and m != n):
    raise NotImplementedError(
      "Unimplemented case of QR decomposition derivative")
  dx_rinv = triangular_solve(r, dx)  # Right side solve by default
  qt_dx_rinv = jnp.matmul(_H(q), dx_rinv)
  qt_dx_rinv_lower = jnp.tril(qt_dx_rinv, -1)
  do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower)  # This is skew-symmetric
  # The following correction is necessary for complex inputs
  do = do + jnp.eye(n, dtype=do.dtype) * (qt_dx_rinv - jnp.real(qt_dx_rinv))
  dq = jnp.matmul(q, do - qt_dx_rinv) + dx_rinv
  dr = jnp.matmul(qt_dx_rinv - do, r)
  return (q, r), (dq, dr)
def phi(X):
  """Return the lower triangle of ``X`` with its diagonal entries halved.

  The divisor is ``1 + eye`` over the last axis of ``X``: 2 on the diagonal
  and 1 elsewhere. The identity is built in ``X``'s dtype.
  """
  size = X.shape[-1]
  divisor = np._constant_like(X, 1) + np.eye(size, dtype=X.dtype)
  lower_triangle = np.tril(X)
  return lower_triangle / divisor