Ejemplo n.º 1
0
    def parallel_transport(self, tangent_vec_a, tangent_vec_b, base_point):
        r"""Transport a tangent vector in parallel along a geodesic.

        Closed-form parallel transport of `tangent_vec_a` along the geodesic
        that starts at `base_point` with initial velocity `tangent_vec_b`.

        Write `S` for `tangent_vec_a` and `A` for `base_point`, and let
        :math:`B = Exp_A(tangent\_vec\_b)` be the end point of the geodesic
        and :math:`E = (BA^{-1})^{1/2}`. The vector transported to `B` is:

        .. math::

            S' = ESE^T

        Parameters
        ----------
        tangent_vec_a : array-like, shape=[..., n, n]
            Tangent vector at base point to be transported.
        tangent_vec_b : array-like, shape=[..., n, n]
            Tangent vector at base point, initial speed of the geodesic along
            which the parallel transport is computed.
        base_point : array-like, shape=[..., n, n]
            Point on the manifold of SPD matrices.

        Returns
        -------
        transported_tangent_vec : array-like, shape=[..., n, n]
            Transported tangent vector at exp_(base_point)(tangent_vec_b).
        """
        # B: where the geodesic driven by tangent_vec_b lands.
        geodesic_end = self.exp(tangent_vec_b, base_point)
        # B A^{-1}; its matrix square root is the transport operator E.
        end_times_inv_base = GeneralLinear.mul(
            geodesic_end, GeneralLinear.inverse(base_point)
        )
        transport_op = gs.linalg.sqrtm(end_times_inv_base)
        # Congruence of S by E, i.e. the S' = E S E^T of the formula above.
        return GeneralLinear.congruent(tangent_vec_a, transport_op)
    def transp(self, base_point, end_point, tangent):
        """Transport a tangent vector from `base_point` to `end_point`.

        If ``self.exact_transport`` is falsy, the vector is returned
        unchanged (the approach of the manopt implementation linked below).
        Otherwise it is mapped by the congruence built from the matrix
        square root of ``end_point * base_point^{-1}``, as in the linked
        geomstats implementation.

        Parameters
        ----------
        base_point : array-like
            Point at which `tangent` is currently attached.
            Presumably an SPD matrix -- confirm against callers.
        end_point : array-like
            Point whose tangent space receives the vector.
        tangent : array-like
            Tangent vector at `base_point`. In the exact branch it is also
            the target passed to ``.to(...)`` for the square-root result.

        Returns
        -------
        transported : array-like
            The transported tangent vector; `tangent` itself when exact
            transport is disabled.
        """
        if not self.exact_transport:
            # https://github.com/NicolasBoumal/manopt/blob/master/manopt/manifolds/symfixedrank/sympositivedefinitefactory.m#L181
            return tangent

        # https://github.com/geomstats/geomstats/blob/master/geomstats/geometry/spd_matrices.py#L613
        base_inv = GeneralLinear.inverse(base_point)
        end_times_base_inv = GeneralLinear.mul(end_point, base_inv)
        # The square root is taken on a CPU copy, then converted back with
        # `.to(tangent)` (presumably to match tangent's device/dtype).
        transport_op = gs.linalg.sqrtm(end_times_base_inv.cpu()).to(tangent)
        return GeneralLinear.congruent(tangent, transport_op)