Example #1
0
    def inner(self, u, vec1, vec2):
        """Returns manifold wise inner product of vectors from
        a tangent space.

        Args:
            u: complex valued tensor of shape (..., n, p),
                a set of points from the complex Stiefel
                manifold.
            vec1: complex valued tensor of shape (..., n, p),
                a set of tangent vectors from the complex
                Stiefel manifold.
            vec2: complex valued tensor of shape (..., n, p),
                a set of tangent vectors from the complex
                Stiefel manifold.

        Returns:
            real valued tensor of shape (..., 1, 1),
            manifold wise inner product (only the real part of
            the trace is returned).

        Raises:
            ValueError: if the metric stored in self._metric is
                neither 'euclidean' nor 'canonical'."""

        if self._metric == 'euclidean':
            gram = adj(vec1) @ vec2
        elif self._metric == 'canonical':
            # Weight matrix of the canonical metric: Id - u @ adj(u) / 2.
            G = jax.numpy.eye(u.shape[-2], dtype=u.dtype) - u @ adj(u) / 2
            gram = adj(vec1) @ G @ vec2
        else:
            # Previously an unknown metric crashed with UnboundLocalError.
            raise ValueError(
                "Unknown metric {!r}: expected 'euclidean' or "
                "'canonical'.".format(self._metric))
        # Trace over the two matrix axes; keep them as singleton axes so the
        # result broadcasts against tensors of shape (..., n, p).
        s_sq = jax.numpy.trace(gram, axis1=-1, axis2=-2)[..., None, None]
        return jax.numpy.real(s_sq)
Example #2
0
    def retraction(self, u, vec):
        """Transports a set of points from the complex Stiefel
        manifold via a retraction map.

        Args:
            u: complex valued tensor of shape (..., n, p), a set
                of points to be transported.
            vec: complex valued tensor of shape (..., n, p),
                a set of direction vectors.

        Returns:
            complex valued tensor of shape (..., n, p),
            a set of transported points.

        Raises:
            ValueError: if the retraction stored in self._retraction
                is not one of 'svd', 'cayley' or 'qr'."""

        if self._retraction == 'svd':
            # v @ wh is the isometric polar factor of u + vec.
            v, _, wh = jax.numpy.linalg.svd(u + vec, full_matrices=False)
            return v @ wh

        elif self._retraction == 'cayley':
            W = vec @ adj(u) - 0.5 * u @ (adj(u) @ vec @ adj(u))
            W = W - adj(W)
            Id = jax.numpy.eye(W.shape[-1], dtype=W.dtype)
            # Cayley transform (Id - W/2)^{-1} (Id + W/2) u, computed with a
            # linear solve instead of forming the explicit inverse.
            return jax.numpy.linalg.solve(Id - W / 2, (Id + W / 2) @ u)

        elif self._retraction == 'qr':
            q, r = jax.numpy.linalg.qr(u + vec)
            # jax.numpy.diag only accepts 1-D/2-D input; diagonal() over the
            # last two axes also works for batched r of shape (..., p, p).
            diag = jax.numpy.diagonal(r, axis1=-2, axis2=-1)
            # Rescale every column of q by the sign of the matching diagonal
            # entry of r (for complex input jax documents sign as z / |z|),
            # which removes the sign/phase ambiguity of the QR factorization.
            sign = jax.numpy.sign(diag)[..., None, :]
            return q * sign

        # Previously an unknown retraction silently returned None.
        raise ValueError(
            "Unknown retraction {!r}: expected 'svd', 'cayley' or "
            "'qr'.".format(self._retraction))
Example #3
0
    def proj(self, u, vec):
        """Projects a set of vectors onto tangent spaces of the
        complex Stiefel manifold.

        Args:
            u: complex valued tensor of shape (..., n, p),
                points of the complex Stiefel manifold that
                define the tangent spaces.
            vec: complex valued tensor of shape (..., n, p),
                vectors that have to be projected.

        Returns:
            complex valued tensor of shape (..., n, p),
            the projected vectors."""

        # Sum of adj(u) @ vec and adj(vec) @ u; half of it, multiplied by u
        # from the left, is the part of vec that gets removed.
        overlap_sum = adj(u) @ vec + adj(vec) @ u
        correction = 0.5 * u @ overlap_sum
        return vec - correction
Example #4
0
    def egrad_to_rgrad(self, u, egrad):
        """Returns the Riemannian gradient from an Euclidean gradient.

        Args:
            u: complex valued tensor of shape (..., n, p),
                a set of points from the complex Stiefel
                manifold.
            egrad: complex valued tensor of shape (..., n, p),
                a set of Euclidean gradients.

        Returns:
            complex valued tensor of shape (..., n, p),
            the set of Riemannian gradients.

        Raises:
            ValueError: if the metric stored in self._metric is
                neither 'euclidean' nor 'canonical'."""

        if self._metric == 'euclidean':
            # Same expression as the tangent space projection of egrad.
            return egrad - 0.5 * u @ (adj(u) @ egrad + adj(egrad) @ u)

        elif self._metric == 'canonical':
            return egrad - u @ adj(egrad) @ u

        # Previously an unknown metric silently returned None.
        raise ValueError(
            "Unknown metric {!r}: expected 'euclidean' or "
            "'canonical'.".format(self._metric))
Example #5
0
    def is_in_manifold(self, u, tol=1e-5):
        """Tests whether points belong to the Stiefel manifold.

        The check compares adj(u) @ u with the identity matrix: the
        norm of their difference, divided by the square root of the
        product of their norms, has to be smaller than tol.

        Args:
            u: complex valued tensor of shape (..., n, p),
                points to be checked.
            tol: small real value, the tolerance of the check.

        Returns:
            boolean tensor of shape (...)."""

        def _matrix_norm(mat):
            # Norm over the two trailing (matrix) axes, one value per
            # batch entry (Frobenius norm by default).
            return jax.numpy.linalg.norm(mat, axis=(-1,-2))

        gram = adj(u) @ u
        identity = jax.numpy.eye(u.shape[-1], dtype=u.dtype)
        residual_norm = _matrix_norm(identity - gram)
        scale = jax.numpy.sqrt(_matrix_norm(identity) * _matrix_norm(gram))
        rel_diff = jax.numpy.abs(residual_norm / scale)
        return tol > rel_diff