def inner(self, u, vec1, vec2):
    """Compute the manifold-wise inner product of two tangent vectors.

    The formula depends on ``self._metric``:

    * ``'euclidean'``: real part of ``trace(adj(vec1) @ vec2)``
    * ``'canonical'``: real part of ``trace(adj(vec1) @ G @ vec2)`` with
      ``G = I - u @ adj(u) / 2``

    Args:
        u: complex valued tensor of shape (..., n, p), a set of points
            from the complex Stiefel manifold.
        vec1: complex valued tensor of shape (..., n, p), a set of
            tangent vectors at ``u``.
        vec2: complex valued tensor of shape (..., n, p), a set of
            tangent vectors at ``u``.

    Returns:
        real valued tensor of shape (..., 1, 1): the real part of the
        trace with two singleton axes appended.
    """
    jnp = jax.numpy
    metric = self._metric

    if metric == 'euclidean':
        gram = adj(vec1) @ vec2
    elif metric == 'canonical':
        n_rows = u.shape[-2]
        weight = jnp.eye(n_rows, dtype=u.dtype) - u @ adj(u) / 2
        gram = adj(vec1) @ weight @ vec2

    # The trace is taken over the last two axes; the two trailing
    # singleton axes are re-added afterwards.
    overlap = jnp.trace(gram, axis1=-1, axis2=-2)
    return jnp.real(overlap[..., None, None])
def retraction(self, u, vec):
    """Transport points of the complex Stiefel manifold via a retraction.

    The retraction is selected by ``self._retraction``:

    * ``'svd'``: SVD of ``u + vec``, returning the product of the two
      unitary factors.
    * ``'cayley'``: Cayley transform built from a skew-Hermitian
      combination of ``u`` and ``vec``.
    * ``'qr'``: Q factor of the QR decomposition of ``u + vec``, with
      columns rescaled by the phases of the diagonal of R.

    For any other value of ``self._retraction`` the method falls through
    and returns ``None``.

    Args:
        u: complex valued tensor of shape (..., n, p), a set of points
            to be transported.
        vec: complex valued tensor of shape (..., n, p), a set of
            direction vectors.

    Returns:
        complex valued tensor of shape (..., n, p), a set of
        transported points.
    """
    jnp = jax.numpy

    if self._retraction == 'svd':
        new_u = u + vec
        # jax.numpy.linalg.svd returns (U, S, V^H); only the two unitary
        # factors are needed.
        v, _, wh = jnp.linalg.svd(new_u, full_matrices=False)
        return v @ wh

    elif self._retraction == 'cayley':
        W = vec @ adj(u) - 0.5 * u @ (adj(u) @ vec @ adj(u))
        # Subtracting adj(W) makes the generator skew with respect to adj.
        W = W - adj(W)
        Id = jnp.eye(W.shape[-1], dtype=W.dtype)
        return jnp.linalg.inv(Id - W / 2) @ (Id + W / 2) @ u

    elif self._retraction == 'qr':
        new_u = u + vec
        q, r = jnp.linalg.qr(new_u)
        # jax.numpy.diag only accepts 1-D/2-D input and would raise on
        # batched R factors; diagonal over the last two axes handles any
        # number of leading batch dimensions and is identical in 2-D.
        diag = jnp.diagonal(r, axis1=-2, axis2=-1)
        # Scale each column of Q by the phase of the matching diagonal
        # entry of R.
        sign = jnp.sign(diag)[..., None, :]
        return q * sign
def proj(self, u, vec):
    """Project vectors onto the tangent space of the complex Stiefel manifold.

    Computes ``vec - 0.5 * u @ (adj(u) @ vec + adj(vec) @ u)``.

    Args:
        u: complex valued tensor of shape (..., n, p), a set of points
            from the complex Stiefel manifold.
        vec: complex valued tensor of shape (..., n, p), a set of
            vectors to be projected.

    Returns:
        complex valued tensor of shape (..., n, p), the projected vectors.
    """
    u_dag_vec = adj(u) @ vec
    vec_dag_u = adj(vec) @ u
    correction = 0.5 * u @ (u_dag_vec + vec_dag_u)
    return vec - correction
def egrad_to_rgrad(self, u, egrad):
    """Convert a Euclidean gradient into a Riemannian gradient.

    The conversion depends on ``self._metric``:

    * ``'euclidean'``: ``egrad - 0.5 * u @ (adj(u) @ egrad + adj(egrad) @ u)``
    * ``'canonical'``: ``egrad - u @ adj(egrad) @ u``

    Any other metric value yields ``None``.

    Args:
        u: complex valued tensor of shape (..., n, p), a set of points
            from the complex Stiefel manifold.
        egrad: complex valued tensor of shape (..., n, p), a set of
            Euclidean gradients.

    Returns:
        complex valued tensor of shape (..., n, p), the set of
        Riemannian gradients.
    """
    metric = self._metric

    if metric == 'canonical':
        return egrad - u @ adj(egrad) @ u

    if metric == 'euclidean':
        symmetrised = adj(u) @ egrad + adj(egrad) @ u
        return egrad - 0.5 * u @ symmetrised

    return None
def is_in_manifold(self, u, tol=1e-5):
    """Check whether points lie on the Stiefel manifold.

    Compares ``adj(u) @ u`` with the identity and reports whether the
    relative deviation is below ``tol``.

    Args:
        u: complex valued tensor of shape (..., n, p), a point to be
            checked.
        tol: small real value, the tolerance on the relative deviation.

    Returns:
        boolean tensor of shape (...).
    """
    jnp = jax.numpy

    def _matrix_norm(mat):
        # With the default ``ord`` and a pair of axes this is the
        # Frobenius norm over the trailing matrix dimensions.
        return jnp.linalg.norm(mat, axis=(-1, -2))

    identity = jnp.eye(u.shape[-1], dtype=u.dtype)
    gram = adj(u) @ u

    deviation = _matrix_norm(identity - gram)
    # Normalise by the geometric mean of the two matrix norms.
    scale = jnp.sqrt(_matrix_norm(identity) * _matrix_norm(gram))
    relative_deviation = jnp.abs(deviation / scale)
    return tol > relative_deviation