def qr_jvp_rule(primals, tangents, full_matrices):
  # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
  x, = primals
  dx, = tangents
  q, r = qr_p.bind(x, full_matrices=False)
  if full_matrices or np.shape(x)[-2] < np.shape(x)[-1]:
    raise NotImplementedError
  dx_rinv = triangular_solve(r, dx)  # Right side solve by default
  qt_dx_rinv = np.matmul(_H(q), dx_rinv)
  qt_dx_rinv_lower = np.tril(qt_dx_rinv, -1)
  domega = qt_dx_rinv_lower - _H(qt_dx_rinv_lower)  # This is skew-symmetric
  dq = np.matmul(q, domega - qt_dx_rinv) + dx_rinv
  dr = np.matmul(qt_dx_rinv - domega, r)
  return (q, r), (dq, dr)
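
# A minimal consistency check for the rule above, written against the public
# jax.jvp / jnp.linalg.qr API rather than the module-internal names (qr_p,
# triangular_solve, _H). It only assumes that jnp.linalg.qr reaches a QR JVP
# for tall or square inputs. Differentiating A = Q R gives dA = dQ R + Q dR,
# which is what the sketch verifies to first order.
import jax
import jax.numpy as jnp

x = jnp.array([[3.0, 1.0],
               [1.0, 2.0],
               [0.0, 1.0]])        # m >= n, so the reduced-QR path applies
dx = 0.1 * jnp.ones_like(x)

(q, r), (dq, dr) = jax.jvp(lambda a: jnp.linalg.qr(a, mode='reduced'),
                           (x,), (dx,))

print(jnp.allclose(q @ r, x, atol=1e-5))             # Q R reconstructs A
print(jnp.allclose(dq @ r + q @ dr, dx, atol=1e-5))  # product rule holds
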
def _lu_jvp_rule(primals, tangents):
  a, = primals
  a_dot, = tangents
  lu, pivots = lu_p.bind(a)
  if a_dot is ad_util.zero:
    return (lu, pivots), (ad_util.zero, ad_util.zero)

  a_shape = np.shape(a)
  m, n = a_shape[-2:]
  dtype = lax.dtype(a)
  k = min(m, n)

  permutation = lu_pivots_to_permutation(pivots, m)
  batch_dims = a_shape[:-2]
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1,)))
  x = a_dot[iotas[:-1] + (permutation, slice(None))]

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U

  ndims = len(a_shape)
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  zero = np._constant_like(lu, 0)
  l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
  l = l + np.eye(m, m, dtype=dtype)

  u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

  la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True,
                        unit_diagonal=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = np.matmul(l, np.tril(lau, -1))
  u_dot = np.matmul(np.triu(lau), u)
  lu_dot = l_dot + u_dot
  return (lu, pivots), (lu_dot, ad_util.zero)
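
# A quick sanity check of the LU tangent rule through the public
# jax.scipy.linalg.lu API (the rule above is reached indirectly). With
# A = P L U and the permutation held fixed, the product rule gives
# P^T dA = dL U + L dU; the sketch below verifies that identity numerically.
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu

a = jnp.array([[4.0, 3.0, 1.0],
               [6.0, 3.0, 2.0],
               [2.0, 1.0, 5.0]])
da = 0.01 * jnp.eye(3)

# Differentiate only the triangular factors; the permutation is piecewise
# constant in A, so it carries no tangent.
(l, u), (dl, du) = jax.jvp(lambda x: lu(x)[1:], (a,), (da,))

p = lu(a)[0]                          # a = p @ l @ u, hence p.T @ a = l @ u
print(jnp.allclose(p.T @ da, dl @ u + l @ du, atol=1e-5))
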
def lu_jvp_rule(primals, tangents):
  a, = primals
  a_dot, = tangents
  lu, pivots = lu_p.bind(a)

  a_shape = np.shape(a)
  m, n = a_shape[-2:]
  dtype = lax._dtype(a)
  k = min(m, n)

  # TODO(phawkins): use a gather rather than a matrix multiplication here.
  permutation = lu_pivots_to_permutation(pivots, m)
  p = np.array(permutation[:, None] == np.arange(m), dtype=dtype)
  x = np.matmul(p, a_dot)

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U

  ndims = len(a_shape)
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  zero = np._constant_like(lu, 0)
  l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
  l = l + np.eye(m, m, dtype=dtype)

  u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

  la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = np.matmul(l, np.tril(lau, -1))
  u_dot = np.matmul(np.triu(lau), u)
  lu_dot = l_dot + u_dot
  return core.pack((lu, pivots)), ad.TangentTuple((lu_dot, ad_util.zero))
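
# A small NumPy-only sketch of what the TODO above suggests: the one-hot
# permutation matrix built from `permutation[:, None] == arange(m)` reorders
# rows exactly like a direct gather, so the matmul could be replaced by
# indexing (as the gather-based variant above does). Names are illustrative.
import numpy as onp

permutation = onp.array([2, 0, 1])
a_dot = onp.arange(9.0).reshape(3, 3)

p = (permutation[:, None] == onp.arange(3)).astype(a_dot.dtype)
x_matmul = p @ a_dot              # row i of the result is a_dot[permutation[i]]
x_gather = a_dot[permutation]     # the same reordering, done as a gather

print(onp.array_equal(x_matmul, x_gather))   # True
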
def lu_pivots_to_permutation(swaps, k):
  """Converts the pivots (row swaps) returned by LU to a permutation."""

  def body_fn(i, loop_carry):
    swaps, permutation = loop_carry
    j = swaps[i]
    x, y = np.ravel(permutation[i]), np.ravel(permutation[j])
    permutation = lax.dynamic_update_index_in_dim(permutation, y, i, axis=0)
    permutation = lax.dynamic_update_index_in_dim(permutation, x, j, axis=0)
    return swaps, permutation

  n, = np.shape(swaps)
  permutation = np.arange(k)
  _, permutation = lax.fori_loop(
      onp.array(0, onp.int32), onp.array(n, onp.int32), body_fn,
      (swaps, permutation))
  return permutation
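
# A plain-Python sketch of the same swap loop, to make the conversion concrete
# without the lax control-flow machinery (names here are illustrative only).
import numpy as onp

def swaps_to_permutation(swaps, k):
  permutation = onp.arange(k)
  for i, j in enumerate(swaps):
    permutation[i], permutation[j] = permutation[j], permutation[i]
  return permutation

# LAPACK-style pivots [2, 2] on 3 rows mean: swap row 0 with row 2, then
# swap row 1 with row 2, which yields the permutation [2, 0, 1].
print(swaps_to_permutation([2, 2], 3))   # [2 0 1]
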