Ejemplo n.º 1
0
    def _log_translation_transform(self, rot_vec):
        """Compute the translation transform used by the logarithm map.

        The transpose of ``self._exp_translation_transform(rot_vec)`` is
        rescaled by ``0.5 / c``. Here ``c`` is
        ``utils.taylor_exp_even_func(rot_vec ** 2, utils.cosc_close_0,
        order=4)``.

        NOTE(review): this looks like inverting the exp translation
        transform as ``transpose / determinant``. That identity holds when
        the matrix is a scaled rotation; confirm against
        ``_exp_translation_transform``.

        Parameters
        ----------
        rot_vec : array-like
            Rotation vector. The einsum below contracts its last axis, so
            presumably that axis has size 1 (a single angle). TODO confirm.

        Returns
        -------
        transform : array-like
            Rescaled transpose of the exp translation transform.
        """
        exp_matrix = self._exp_translation_transform(rot_vec)

        # Presumably the helper switches to a Taylor expansion near zero
        # (per the ``close_0`` naming), keeping the division finite there.
        cosc_value = utils.taylor_exp_even_func(
            rot_vec ** 2, utils.cosc_close_0, order=4)
        scale = .5 / cosc_value

        exp_matrix_t = GeneralLinear.transpose(exp_matrix)
        # Contract the trailing axis of ``scale`` to scale each matrix.
        return gs.einsum('...l, ...jk -> ...jk', scale, exp_matrix_t)
Ejemplo n.º 2
0
 def _to_lie_algebra(self, tangent_vec):
     """Project vector rotation part onto skew-symmetric matrices.

     Parameters
     ----------
     tangent_vec : array-like
         Matrix (or stack of matrices) expected to broadcast against an
         ``(n + 1, n + 1)`` mask.

     Returns
     -------
     projected : array-like
         The result has three parts:
         - the skew-symmetric part of the top-left ``n x n`` block;
         - the top ``n`` entries of the last column, carried over unchanged;
         - a zero last row.
     """
     n = self.n
     # Weights: 1 on the top-left n x n block, 2 on the top n entries of
     # the last column, 0 on the whole last row.
     upper_rows = gs.hstack([gs.ones((n, n)), 2 * gs.ones((n, 1))])
     weights = gs.concatenate([upper_rows, gs.zeros((1, n + 1))], axis=0)

     # Binary support of the weights: zeroes out the last row before
     # skew-symmetrizing.
     support = gs.where(weights != 0., gs.array(1.), gs.array(0.))
     masked = tangent_vec * support

     skew = (masked - GeneralLinear.transpose(masked)) / 2.
     # Skew-symmetrizing halved the last-column entries (their transposed
     # partners were zeroed), so weight 2 restores them; weight 0 clears
     # the last row again.
     return skew * weights
Ejemplo n.º 3
0
    def regularize_tangent_vec_at_identity(self,
                                           tangent_vec,
                                           metric=None,
                                           point_type=None):
        """Regularize a tangent vector at the identity.

        Parameters
        ----------
        tangent_vec : array-like, shape=[n_samples, {dim, [n + 1, n + 1]}]
            Tangent vector at the identity.
        metric : RiemannianMetric, optional
            Forwarded to ``self.regularize_tangent_vec`` for vector inputs.
            Not used for matrix inputs.
        point_type : str, {'vector', 'matrix'}, optional
            Representation of ``tangent_vec``.
            default: self.default_point_type

        Returns
        -------
        regularized_vec : array-like
            The regularized tangent vector.

        Raises
        ------
        ValueError
            If ``point_type`` is neither 'vector' nor 'matrix'.
        """
        # Apply the documented default. Without this, ``None`` fell through
        # to the ValueError below.
        if point_type is None:
            point_type = self.default_point_type

        if point_type == 'vector':
            return self.regularize_tangent_vec(tangent_vec,
                                               self.identity,
                                               metric,
                                               point_type=point_type)

        if point_type == 'matrix':
            # Weights: 1 on the top-left n x n block, 2 on the top n entries
            # of the last column, 0 on the last row.
            translation_mask = gs.hstack(
                [gs.ones((self.n, ) * 2), 2 * gs.ones((self.n, 1))])
            translation_mask = gs.concatenate(
                [translation_mask, gs.zeros((1, self.n + 1))], axis=0)
            # Zero the last row before skew-symmetrizing.
            tangent_vec = tangent_vec * gs.where(translation_mask != 0.,
                                                 gs.array(1.), gs.array(0.))
            tangent_vec = (tangent_vec -
                           GeneralLinear.transpose(tangent_vec)) / 2.
            # Skew-symmetrizing halved the last-column entries; weight 2
            # restores them and weight 0 clears the last row again.
            return tangent_vec * translation_mask

        raise ValueError('Invalid point_type, expected \'vector\' or '
                         '\'matrix\'.')
Ejemplo n.º 4
0
 def tangent_submersion(tangent_vec, base_point):
     """Return ``2 * to_symmetric(mul(base_point, tangent_vec^T))``.

     ``GeneralLinear.mul`` is presumably the matrix product. If
     ``to_symmetric(M)`` is ``(M + M^T) / 2`` (confirm in GeneralLinear),
     the result equals ``X V^T + V X^T``, where ``X = base_point`` and
     ``V = tangent_vec``. That is the derivative of ``X -> X X^T`` at
     ``base_point`` in the direction ``tangent_vec``.
     """
     tangent_vec_t = GeneralLinear.transpose(tangent_vec)
     base_times_tangent_t = GeneralLinear.mul(base_point, tangent_vec_t)
     return 2 * GeneralLinear.to_symmetric(base_times_tangent_t)
Ejemplo n.º 5
0
 def submersion(point):
     """Return ``GeneralLinear.mul(point, point^T)``.

     Presumably this is the matrix product ``point @ point^T``; confirm
     the semantics of ``GeneralLinear.mul``.
     """
     point_t = GeneralLinear.transpose(point)
     return GeneralLinear.mul(point, point_t)
Ejemplo n.º 6
0
 def _to_lie_algebra(self, tangent_vec):
     """Project vector rotation part onto skew-symmetric matrices.

     The input is restricted to the non-zero support of
     ``self.translation_mask`` and skew-symmetrized. The result is then
     multiplied entrywise by ``self.translation_mask``.

     Parameters
     ----------
     tangent_vec : array-like
         Matrix (or stack of matrices) expected to broadcast against
         ``self.translation_mask``.

     Returns
     -------
     projected : array-like
         Skew-symmetrized input, reweighted by the translation mask.
     """
     weights = self.translation_mask
     support = gs.where(weights != 0., gs.array(1.), gs.array(0.))
     restricted = tangent_vec * support
     skew_part = (restricted - GeneralLinear.transpose(restricted)) / 2.
     return skew_part * weights