def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  """Forward-mode (JVP) rule for the thin singular value decomposition.

  Given the thin SVD ``A = U diag(s) Vt`` (``k = min(m, n)``) and a tangent
  ``dA``, returns the primal outputs ``(s, U, Vt)`` and their tangents
  ``(ds, dU, dVt)``. Leading batch dimensions are supported. Every transpose
  acts on the last two axes only, and transposes of complex quantities are
  conjugate transposes.

  Args:
    primals: one-element sequence holding the operand ``A`` (..., m, n).
    tangents: one-element sequence holding the tangent ``dA`` (same shape).
    full_matrices: must be False; the full-matrices case is not implemented.
    compute_uv: accepted for signature compatibility. The primitive is always
      bound with ``compute_uv=True`` here, because U and V are needed.

  Returns:
    Pair ``(core.pack((s, U, Vt)), core.pack((ds, dU, dVt)))``.

  Raises:
    NotImplementedError: if ``full_matrices`` is True.

  Note:
    The rule divides by ``s_j**2 - s_i**2``. For m != n it also divides by
    ``s``. Repeated (or, for m != n, zero) singular values therefore yield
    non-finite tangents.
  """
  if full_matrices:
    # TODO: implement full matrices case, documented here:
    # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices"
    )

  def _H(x):
    # Conjugate transpose over the last two (matrix) axes. Unlike `.T`, which
    # reverses every axis, this is safe for batched inputs.
    return np.conj(np.swapaxes(x, -1, -2))

  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  k = s.shape[-1]
  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]               # singular values as a row, (..., 1, k)
  s_col = np.swapaxes(s_dim, -1, -2)    # singular values as a column, (..., k, 1)
  eye_k = np.eye(k, dtype=s.dtype)

  # Tangent expressed in the singular bases: dP = U^H dA V, shape (..., k, k).
  dP = np.matmul(np.matmul(Ut, dA), V)
  ds = np.real(np.diagonal(dP, 0, -2, -1))

  # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal and 0 on it. The identity
  # is added before inverting only to avoid dividing by zero on the diagonal.
  F = 1 / (np.square(s_dim) - np.square(s_col) + eye_k) - eye_k
  dPS = s_dim * dP   # dP @ diag(s)
  SdP = s_col * dP   # diag(s) @ dP

  # Purely imaginary diagonal of U^H dU, i * Im(dP_jj) / s_j. This fixes the
  # complex phase gauge and is identically zero for real inputs. Entries where
  # s == 0 map to 0 instead of inf.
  s_zero = (s == 0).astype(s.dtype)
  s_inv = 1 / (s + s_zero) - s_zero
  dU_diag = 0.5 * (dP - _H(dP)) * (eye_k * s_inv[..., None, :])

  dU = np.matmul(U, F * (dPS + _H(dPS)) + dU_diag)
  dV = np.matmul(V, F * (SdP + _H(SdP)))

  m, n = A.shape[-2], A.shape[-1]
  if m > n:
    # Component of dU outside span(U): (I - U U^H) dA V diag(1/s).
    dU = dU + np.matmul(np.eye(m, dtype=A.dtype) - np.matmul(U, Ut),
                        np.matmul(dA, V)) / s_dim
  if n > m:
    # Component of dV outside span(V): (I - V V^H) dA^H U diag(1/s).
    dV = dV + np.matmul(np.eye(n, dtype=A.dtype) - np.matmul(V, Vt),
                        np.matmul(_H(dA), U)) / s_dim

  # Vt = V^H, so its tangent is dV^H; a plain transpose is wrong for complex A.
  return core.pack((s, U, Vt)), core.pack((ds, dU, _H(dV)))
def multi_dot(arrays, *, precision=None):
  """Multiply a chain of two or more arrays in a cost-reducing order.

  The first operand may be 1-D and is treated as a row vector. The last
  operand may be 1-D and is treated as a column vector. After that promotion,
  all operands are passed to ``_assert2d`` for validation.

  Args:
    arrays: sequence of at least two array-likes.
    precision: forwarded to every ``jnp.dot`` call.

  Returns:
    The chained product:
      * a scalar if both end operands were 1-D;
      * a 1-D array if exactly one end operand was 1-D;
      * otherwise the 2-D product.

  Raises:
    ValueError: if fewer than two arrays are given.
  """
  count = len(arrays)
  if count < 2:
    raise ValueError("Expecting at least two arrays.")
  if count == 2:
    # A two-operand chain has only one ordering; no planning needed.
    return jnp.dot(arrays[0], arrays[1], precision=precision)

  mats = list(map(jnp.asarray, arrays))
  first_is_vec = mats[0].ndim == 1
  last_is_vec = mats[-1].ndim == 1

  # Promote boundary vectors to matrices so the helpers only handle 2-D
  # operands. The caller-visible rank is restored at the end.
  if first_is_vec:
    mats[0] = jnp.atleast_2d(mats[0])
  if last_is_vec:
    mats[-1] = jnp.atleast_2d(mats[-1]).T
  _assert2d(*mats)

  if count == 3:
    # The direct three-operand cost comparison is much cheaper than the
    # general chain-order search.
    product = _multi_dot_three(*mats, precision)
  else:
    chain_order = _multi_dot_matrix_chain_order(mats)
    product = _multi_dot(mats, chain_order, 0, count - 1, precision)

  if first_is_vec and last_is_vec:
    return product[0, 0]
  if first_is_vec or last_is_vec:
    return product.ravel()
  return product
def _multi_dot(arrays, order, i, j, precision):
  """Multiply ``arrays[i..j]`` (inclusive) following a precomputed split table.

  ``order[i, j]`` is the index ``k`` at which the sub-chain ``i..j`` is split
  into ``(i..k) @ (k+1..j)``. Both halves are evaluated recursively, left
  first. A single-element range returns that array unchanged.
  """
  if i == j:
    return arrays[i]
  split = order[i, j]
  left = _multi_dot(arrays, order, i, split, precision)
  right = _multi_dot(arrays, order, split + 1, j, precision)
  return jnp.dot(left, right, precision=precision)
def _multi_dot_three(A, B, C, precision):
  """Multiply three matrices using the cheaper of the two parenthesizations.

  Shapes are ``A: (p, q)``, ``B: (q, r)`` and ``C: (r, s)``. The scalar
  multiplication counts are:
    * ``p*r*(q + s)`` for ``(A B) C``;
    * ``q*s*(p + r)`` for ``A (B C)``.
  Comparing these closed forms is much faster than the general chain-order
  search (roughly 15x for three operands). On a tie, ``A (B C)`` is used.
  """
  p, q = A.shape
  r, s = C.shape
  left_first = p * r * (q + s)    # cost of (A B) C
  right_first = q * s * (p + r)   # cost of A (B C)
  if left_first >= right_first:
    return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision)
  return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision)