def test_skew_to_vec_and_back(self):
    """Round-trip a random vector through its skew-symmetric matrix on SO(4)."""
    so4 = SpecialOrthogonal(n=4)
    original_vec = gs.random.rand(so4.dim)
    skew_mat = so4.skew_matrix_from_vector(original_vec)
    recovered_vec = so4.vector_from_skew_matrix(skew_mat)
    self.assertAllClose(recovered_vec, original_vec)
Esempio n. 2
0
"""Predict on SO(3): losses."""

import logging

import geomstats.backend as gs
import geomstats.geometry.lie_group as lie_group
from geomstats.geometry.special_orthogonal import SpecialOrthogonal


# Module-level SO(3) instance using the 'vector' point type; its
# bi-invariant metric is the default `metric` argument of `loss` below.
SO3 = SpecialOrthogonal(n=3, point_type='vector')


def loss(y_pred, y_true,
         metric=SO3.bi_invariant_metric,
         representation='vector'):
    """Loss function given by a Riemannian metric on a Lie group.

    Parameters
    ----------
    y_pred : array-like
        Prediction on SO(3).
    y_true : array-like
        Ground-truth on SO(3).
    metric : RiemannianMetric
        Metric used to compute the loss and gradient.
    representation : str, {'vector', 'matrix'}
        Representation chosen for points in SO(3).

    Returns
    -------
    lie_loss : array-like
Esempio n. 3
0
"""Perform tangent PCA at the mean."""

import matplotlib.pyplot as plt
import numpy as np

import geomstats.visualization as visualization
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
from geomstats.learning.pca import TangentPCA

SO3_GROUP = SpecialOrthogonal(n=3)
# Bi-invariant metric of SO(3); used both for the mean and for TangentPCA.
METRIC = SO3_GROUP.bi_invariant_metric

N_SAMPLES = 10  # number of random points drawn in main()
N_COMPONENTS = 2  # number of principal components kept by TangentPCA


def main():
    """Perform tangent PCA at the mean.

    Draws N_SAMPLES points with ``SO3_GROUP.random_uniform``, computes their
    mean with ``METRIC.mean``, fits a TangentPCA with N_COMPONENTS components
    at that mean, and prints the projected coordinates of the first 5 points.

    NOTE(review): `fig` is created but not used in the visible code; the
    plotting part of this example looks truncated. Confirm against the full
    script.
    """
    fig = plt.figure(figsize=(15, 5))

    data = SO3_GROUP.random_uniform(n_samples=N_SAMPLES)
    mean = METRIC.mean(data)

    # Fit the principal components in the tangent space at the mean, then
    # express every sample in those components.
    tpca = TangentPCA(metric=METRIC, n_components=N_COMPONENTS)
    tpca = tpca.fit(data, base_point=mean)
    tangent_projected_data = tpca.transform(data)
    print(
        'Coordinates of the Log of the first 5 data points at the mean, '
        'projected on the principal components:')
    print(tangent_projected_data[:5])
Esempio n. 4
0
"""Plot the pole ladder scheme for parallel transport on S2.

Sample a point on S2 and two tangent vectors to transport one along the
other.
"""

import matplotlib.pyplot as plt

import geomstats.backend as gs
import geomstats.visualization as visualization
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.geometry.special_orthogonal import SpecialOrthogonal

SPACE = Hypersphere(2)  # the sphere S2 on which the ladder is built
METRIC = SPACE.metric
# SO(3) with the 'vector' point type (rotation vectors).
ROTATIONS = SpecialOrthogonal(3, 'vector')

# NOTE(review): N_STEPS / N_POINTS are not used in the visible part of main().
N_STEPS = 4
N_POINTS = 10

# Seed the backend RNG so the sampled base point and tangent vector are
# reproducible.
gs.random.seed(1)


def main():
    """Compute pole ladder and plot the construction."""
    # Random base point on S2.
    base_point = SPACE.random_uniform(1)
    # Tangent vector b: a random point projected onto the tangent space at
    # base_point, then rescaled to unit Euclidean norm.
    tangent_vec_b = SPACE.random_uniform(1)
    tangent_vec_b = SPACE.to_tangent(tangent_vec_b, base_point)
    tangent_vec_b = tangent_vec_b / gs.linalg.norm(tangent_vec_b)

    # Rotation vector of angle pi / 2 about the axis given by base_point
    # (base_point lies on the unit sphere, so its norm is 1).
    rotation_vector = gs.pi / 2 * base_point
    # NOTE(review): the snippet ends here; the ladder computation and the
    # plotting announced in the docstring are not visible.
Esempio n. 5
0
    def setUp(self):
        """Build SE(3)/SO(3) groups, invariant metrics and reference points."""
        # Silence the root logger and import warnings during the tests.
        logging.getLogger().disabled = True
        warnings.simplefilter('ignore', category=ImportWarning)

        gs.random.seed(1234)

        size = 3
        self.group = SpecialEuclidean(n=size, point_type='vector')
        self.matrix_se3 = SpecialEuclidean(n=size)
        self.matrix_so3 = SpecialOrthogonal(n=size)
        so3_as_vectors = SpecialOrthogonal(n=size, point_type='vector')

        # Diagonal left and right invariant metrics on the vector SE(3).
        diag_metric_mat = gs.eye(self.group.dim)
        self.left_diag_metric = InvariantMetric(
            group=self.group,
            metric_mat_at_identity=None,
            left_or_right='left')
        self.right_diag_metric = InvariantMetric(
            group=self.group,
            metric_mat_at_identity=diag_metric_mat,
            left_or_right='right')

        # Left / right invariant metrics sharing one metric matrix.
        sym_metric_mat = gs.eye(self.group.dim)
        self.left_metric = InvariantMetric(
            group=self.group,
            metric_mat_at_identity=sym_metric_mat,
            left_or_right='left')
        self.right_metric = InvariantMetric(
            group=self.group,
            metric_mat_at_identity=sym_metric_mat,
            left_or_right='right')

        # Invariant metrics on the matrix form of SO(3).
        self.matrix_left_metric = InvariantMetric(group=self.matrix_so3)
        self.matrix_right_metric = InvariantMetric(
            group=self.matrix_so3, left_or_right='right')

        # General-case points: the first three entries are rotation vectors
        # (presumably followed by a translation part).
        self.point_1 = gs.array([-0.2, 0.9, 0.5, 5., 5., 5.])
        self.point_2 = gs.array([0., 2., -0.1, 30., 400., 2.])
        self.point_1_matrix = so3_as_vectors.matrix_from_rotation_vector(
            self.point_1[..., :3])
        self.point_2_matrix = so3_as_vectors.matrix_from_rotation_vector(
            self.point_2[..., :3])
        # Edge case for the point: angle < epsilon.
        self.point_small = gs.array([-1e-7, 0., -7 * 1e-8, 6., 5., 9.])
class BiInvariantMetricTestData(_InvariantMetricTestData):
    """Test data for the bi-invariant metric on SO(n).

    The list attributes below are drawn with ``random.sample`` while the
    class body executes, so their values depend on the global ``random``
    state. Keep the order of these draws unchanged.
    """

    # A random ordering of the two sizes n = 2 and n = 3.
    dim_list = random.sample(range(2, 4), 2)
    metric_args_list = [(SpecialOrthogonal(dim), ) for dim in dim_list]
    shape_list = [(dim, dim) for dim in dim_list]
    space_list = [SpecialOrthogonal(dim) for dim in dim_list]
    n_points_list = random.sample(range(1, 4), 2)
    n_tangent_vecs_list = random.sample(range(1, 4), 2)
    n_points_a_list = random.sample(range(1, 4), 2)
    n_points_b_list = [1]
    batch_size_list = random.sample(range(2, 4), 2)
    alpha_list = [1] * 2
    n_rungs_list = [1] * 2
    scheme_list = ["pole"] * 2

    @staticmethod
    def _so3_exp_inv_jacobian():
        """Return the inverse Jacobian shared by the exp/log smoke cases.

        It is built from the half-angle pi / 10 and the skew matrix of the
        axis (1, 1, 1), matching the base point of angle pi / 5 about
        (1, 1, 1) / sqrt(3) used in ``exp_test_data`` and ``log_test_data``.
        """
        phi = (gs.pi / 10) / (gs.tan(gs.array(gs.pi / 10)))
        skew = gs.array([[0.0, -1.0, 1.0], [1.0, 0.0, -1.0], [-1.0, 1.0, 0.0]])
        jacobian = (phi * gs.eye(3) + (1 - phi) / 3 * gs.ones([3, 3]) + gs.pi /
                    (10 * gs.sqrt(3.0)) * skew)
        return gs.linalg.inv(jacobian)

    def exp_shape_test_data(self):
        return self._exp_shape_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
        )

    def log_shape_test_data(self):
        return self._log_shape_test_data(
            self.metric_args_list,
            self.space_list,
        )

    def squared_dist_is_symmetric_test_data(self):
        return self._squared_dist_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
            atol=gs.atol * 1000,
        )

    def exp_belongs_test_data(self):
        return self._exp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            belongs_atol=gs.atol * 1000,
        )

    def log_is_tangent_test_data(self):
        return self._log_is_tangent_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            is_tangent_atol=gs.atol * 1000,
        )

    def geodesic_ivp_belongs_test_data(self):
        return self._geodesic_ivp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_points_list,
            belongs_atol=gs.atol * 1000,
        )

    def geodesic_bvp_belongs_test_data(self):
        return self._geodesic_bvp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            belongs_atol=gs.atol * 1000,
        )

    def exp_after_log_test_data(self):
        return self._exp_after_log_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            rtol=gs.rtol * 10000,
            atol=gs.atol * 10000,
        )

    def log_after_exp_test_data(self):
        return self._log_after_exp_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            rtol=gs.rtol * 10000,
            atol=gs.atol * 10000,
        )

    def exp_ladder_parallel_transport_test_data(self):
        return self._exp_ladder_parallel_transport_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            self.n_rungs_list,
            self.alpha_list,
            self.scheme_list,
        )

    def exp_geodesic_ivp_test_data(self):
        return self._exp_geodesic_ivp_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            self.n_points_list,
            rtol=gs.rtol * 100,
            atol=gs.atol * 100,
        )

    def parallel_transport_ivp_is_isometry_test_data(self):
        return self._parallel_transport_ivp_is_isometry_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            is_tangent_atol=gs.atol * 1000,
            rtol=gs.rtol * 1000,
            atol=gs.atol * 1000,
        )

    def parallel_transport_bvp_is_isometry_test_data(self):
        return self._parallel_transport_bvp_is_isometry_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            is_tangent_atol=gs.atol * 1000,
            rtol=gs.rtol * 1000,
            atol=gs.atol * 1000,
        )

    def dist_is_symmetric_test_data(self):
        return self._dist_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def dist_is_positive_test_data(self):
        return self._dist_is_positive_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def squared_dist_is_positive_test_data(self):
        return self._squared_dist_is_positive_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def dist_is_norm_of_log_test_data(self):
        return self._dist_is_norm_of_log_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def dist_point_to_itself_is_zero_test_data(self):
        return self._dist_point_to_itself_is_zero_test_data(
            self.metric_args_list, self.space_list, self.n_points_list)

    def inner_product_is_symmetric_test_data(self):
        return self._inner_product_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )

    def triangle_inequality_of_dist_test_data(self):
        return self._triangle_inequality_of_dist_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            atol=gs.atol * 100000,
        )

    def exp_at_identity_of_lie_algebra_belongs_test_data(self):
        return self._exp_at_identity_of_lie_algebra_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_tangent_vecs_list,
            belongs_atol=gs.atol * 100,
        )

    def log_at_identity_belongs_to_lie_algebra_test_data(self):
        return self._log_at_identity_belongs_to_lie_algebra_test_data(
            self.metric_args_list, self.space_list, self.n_points_list)

    def exp_after_log_at_identity_test_data(self):
        return self._exp_after_log_at_identity_test_data(
            self.metric_args_list, self.space_list, self.n_points_list)

    def log_after_exp_at_identity_test_data(self):
        return self._log_after_exp_at_identity_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            amplitude=100.0,
            atol=1e-2,
        )

    def exp_after_log_intrinsic_ball_extrinsic_test_data(self):
        # NOTE(review): these inputs (dim, intrinsic coordinates) do not look
        # SO(n)-specific; confirm a matching test exists for this data class.
        smoke_data = [
            dict(
                dim=2,
                x_intrinsic=gs.array([4.0, 0.2]),
                y_intrinsic=gs.array([3.0, 3]),
            )
        ]
        return self.generate_tests(smoke_data)

    def squared_dist_is_less_than_squared_pi_test_data(self):
        """Pair every reference element with every other one.

        Uses the Cartesian product of the keys so distances between
        *distinct* elements are checked; pairing each element only with
        itself would test a distance that is trivially zero.
        Assumes ``elements`` is a module-level mapping from angle-type names
        to points (defined outside this view).
        """
        import itertools

        smoke_data = [
            dict(point_1=elements[angle_type_1],
                 point_2=elements[angle_type_2])
            for angle_type_1, angle_type_2 in itertools.product(
                elements, repeat=2)
        ]
        return self.generate_tests(smoke_data)

    def exp_test_data(self):
        # Base point: rotation of angle pi / 5 about the axis (1, 1, 1).
        theta = gs.pi / 5.0
        rot_vec_base_point = theta / gs.sqrt(3.0) * gs.array([1.0, 1.0, 1.0])
        rot_vec_2 = gs.pi / 4 * gs.array([1.0, 0.0, 0.0])
        inv_jacobian = self._so3_exp_inv_jacobian()
        expected = SpecialOrthogonal(3, "vector").compose(
            rot_vec_base_point,
            gs.dot(inv_jacobian, rot_vec_2),
        )
        smoke_data = [
            # A zero tangent vector leaves the base point unchanged.
            dict(
                tangent_vec=gs.array([0.0, 0.0, 0.0]),
                base_point=rot_vec_base_point,
                expected=rot_vec_base_point,
            ),
            dict(
                tangent_vec=rot_vec_2,
                base_point=rot_vec_base_point,
                expected=expected,
            ),
        ]
        return self.generate_tests(smoke_data)

    def log_test_data(self):
        theta = gs.pi / 5.0
        rot_vec_base_point = theta / gs.sqrt(3.0) * gs.array([1.0, 1.0, 1.0])
        # Note: the rotation vector for the reference point
        # needs to be regularized.

        # General case: this is the inverse test of test 1 for Riemannian exp.
        expected = gs.pi / 4 * gs.array([1.0, 0.0, 0.0])
        aux = gs.dot(self._so3_exp_inv_jacobian(), expected)
        rot_vec_2 = SpecialOrthogonal(3, "vector").compose(
            rot_vec_base_point, aux)

        smoke_data = [
            # The logarithm of a point at itself gives 0.
            dict(
                point=rot_vec_base_point,
                base_point=rot_vec_base_point,
                expected=gs.array([0.0, 0.0, 0.0]),
            ),
            dict(
                point=rot_vec_2,
                base_point=rot_vec_base_point,
                expected=expected,
            ),
        ]
        return self.generate_tests(smoke_data)

    def distance_broadcast_test_data(self):
        smoke_data = [dict(n=2)]
        return self.generate_tests(smoke_data)
Esempio n. 7
0
 def test_regularize(self, point, expected):
     """Check that regularizing ``point`` in vector SO(3) gives ``expected``."""
     so3_vectors = SpecialOrthogonal(3, "vector")
     self.assertAllClose(so3_vectors.regularize(point), expected)
 def shape_test_data(self):
     """Two copies of SO(3) must have shape (2, 3, 3)."""
     cases = [dict(base=SpecialOrthogonal(3), power=2, shape=(2, 3, 3))]
     return self.generate_tests(cases)
Esempio n. 9
0
 def test_compose(self, n, point_type, point_a, point_b, expected):
     """Check that SO(n) composition of ``point_a`` and ``point_b`` matches."""
     so_n = SpecialOrthogonal(n, point_type)
     self.assertAllClose(so_n.compose(point_a, point_b), expected)
Esempio n. 10
0
 def test_regularize(self, n, point_type, angle, expected):
     """Check that SO(n) regularization of ``angle`` gives ``expected``."""
     so_n = SpecialOrthogonal(n, point_type)
     self.assertAllClose(so_n.regularize(angle), expected)
Esempio n. 11
0
 def geodesic_invalid_initial_conditions_test_data(self):
     """Use SO(4) as the space for the invalid-initial-conditions case."""
     so4 = SpecialOrthogonal(n=4)
     return self.generate_tests([dict(space=so4)])
Esempio n. 12
0
 def setUp(self):
     """Create SE(3) and SO(3) groups with matrix and vector point types."""
     size = 3
     self.se_mat = SpecialEuclidean(n=size, default_point_type='matrix')
     self.so_vec = SpecialOrthogonal(n=size, default_point_type='vector')
     self.so = SpecialOrthogonal(n=size, default_point_type='matrix')
     self.n_samples = 3
"""Plot the pole ladder scheme for parallel transport on S2.

Sample a point on S2 and two tangent vectors to transport one along the
other.
"""

import matplotlib.pyplot as plt

import geomstats.backend as gs
import geomstats.visualization as visualization
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.geometry.special_orthogonal import SpecialOrthogonal

SPACE = Hypersphere(2)  # the sphere S2 on which the ladder is built
METRIC = SPACE.metric
# SO(3) with the "vector" point type (rotation vectors).
ROTATIONS = SpecialOrthogonal(3, "vector")

# NOTE(review): N_STEPS / N_POINTS are not used in the visible part of main().
N_STEPS = 4
N_POINTS = 10

# Seed the backend RNG so the sampled base point and tangent vector are
# reproducible.
gs.random.seed(1)


def main():
    """Compute pole ladder and plot the construction."""
    # Random base point on S2.
    base_point = SPACE.random_uniform(1)
    # Tangent vector b: a random point projected onto the tangent space at
    # base_point, then rescaled to unit Euclidean norm.
    tangent_vec_b = SPACE.random_uniform(1)
    tangent_vec_b = SPACE.to_tangent(tangent_vec_b, base_point)
    tangent_vec_b = tangent_vec_b / gs.linalg.norm(tangent_vec_b)

    # Rotation vector of angle pi / 2 about the axis given by base_point
    # (base_point lies on the unit sphere, so its norm is 1).
    rotation_vector = gs.pi / 2 * base_point
    # NOTE(review): the snippet ends here; the ladder computation and the
    # plotting announced in the docstring are not visible.
Esempio n. 14
0
    def setUp(self):
        """Ignore import warnings and build the SO(3) group under test."""
        warnings.simplefilter('ignore', category=ImportWarning)

        group = SpecialOrthogonal(n=3)
        self.so3_group = group
        self.n_samples = 2
Esempio n. 15
0
 def test_exp(self, tangent_vec, base_point, expected):
     """Compare the metric exponential on vector SO(3) with ``expected``."""
     so3_metric = self.metric(SpecialOrthogonal(3, "vector"))
     self.assertAllClose(so3_metric.exp(tangent_vec, base_point), expected)
class NFoldManifoldTestData(_ManifoldTestData):
    """Test data for NFoldManifold spaces built from copies of SO(n).

    The list attributes are drawn with ``random.sample`` while the class
    body executes; their values depend on the global ``random`` state and on
    the order of these draws.
    """

    # A random ordering of the sizes n = 2 and n = 3.
    n_list = random.sample(range(2, 4), 2)
    base_list = [SpecialOrthogonal(n) for n in n_list]
    # A random ordering of the powers 2 and 3.
    power_list = random.sample(range(2, 4), 2)
    space_args_list = list(zip(base_list, power_list))
    shape_list = [(power, n, n) for n, power in zip(n_list, power_list)]
    n_points_list = random.sample(range(2, 5), 2)
    n_vecs_list = random.sample(range(2, 5), 2)

    def belongs_test_data(self):
        """Membership cases for two copies of SO(3)."""
        smoke_data = [
            # identity + 1 in every entry is not orthogonal; `[None]` adds a
            # leading batch axis.
            dict(
                base=SpecialOrthogonal(3),
                power=2,
                point=gs.stack([gs.eye(3) + 1.0, gs.eye(3)])[None],
                expected=gs.array(False),
            ),
            # Two identity matrices: a valid point.
            dict(
                base=SpecialOrthogonal(3),
                power=2,
                point=gs.array([gs.eye(3), gs.eye(3)]),
                expected=gs.array(True),
            ),
        ]
        return self.generate_tests(smoke_data)

    def shape_test_data(self):
        """Two copies of SO(3) have shape (2, 3, 3)."""
        smoke_data = [
            dict(base=SpecialOrthogonal(3), power=2, shape=(2, 3, 3))
        ]
        return self.generate_tests(smoke_data)

    def random_point_belongs_test_data(self):
        smoke_space_args_list = [
            (SpecialOrthogonal(2), 2),
            (SpecialOrthogonal(2), 2),
        ]
        smoke_n_points_list = [1, 2]
        return self._random_point_belongs_test_data(
            smoke_space_args_list,
            smoke_n_points_list,
            self.space_args_list,
            self.n_points_list,
        )

    def projection_belongs_test_data(self):
        # NOTE(review): belongs_atol=1e-1 is a very loose tolerance.
        return self._projection_belongs_test_data(
            self.space_args_list,
            self.shape_list,
            self.n_points_list,
            belongs_atol=1e-1,
        )

    def to_tangent_is_tangent_test_data(self):
        return self._to_tangent_is_tangent_test_data(
            NFoldManifold,
            self.space_args_list,
            self.shape_list,
            self.n_vecs_list,
            is_tangent_atol=gs.atol * 1000,
        )

    def random_tangent_vec_is_tangent_test_data(self):
        return self._random_tangent_vec_is_tangent_test_data(
            NFoldManifold, self.space_args_list, self.n_vecs_list)
Esempio n. 17
0
 def test_log(self, point, base_point, expected):
     """Compare the metric logarithm on vector SO(3) with ``expected``."""
     so3_metric = self.metric(SpecialOrthogonal(3, "vector"))
     self.assertAllClose(so3_metric.log(point, base_point), expected)
class NFoldMetricTestData(_RiemannianMetricTestData):
    """Test data for the metric of NFoldManifold spaces built from SO(n).

    The list attributes are drawn with ``random.sample`` while the class
    body executes; keep the order of these draws unchanged so seeded runs
    produce the same cases.
    """

    # A random ordering of the sizes n = 3 and n = 4.
    n_list = random.sample(range(3, 5), 2)
    power_list = random.sample(range(2, 5), 2)
    base_list = [SpecialOrthogonal(n) for n in n_list]
    metric_args_list = [(base.metric, power)
                        for base, power in zip(base_list, power_list)]
    shape_list = [(power, n, n) for n, power in zip(n_list, power_list)]
    space_list = [
        NFoldManifold(base, power)
        for base, power in zip(base_list, power_list)
    ]
    n_points_list = random.sample(range(2, 5), 2)
    n_tangent_vecs_list = random.sample(range(2, 5), 2)
    n_points_a_list = random.sample(range(2, 5), 2)
    n_points_b_list = [1]
    alpha_list = [1] * 2
    n_rungs_list = [1] * 2
    scheme_list = ["pole"] * 2

    def exp_shape_test_data(self):
        return self._exp_shape_test_data(self.metric_args_list,
                                         self.space_list, self.shape_list)

    def log_shape_test_data(self):
        return self._log_shape_test_data(self.metric_args_list,
                                         self.space_list)

    def squared_dist_is_symmetric_test_data(self):
        return self._squared_dist_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
            atol=gs.atol * 1000,
        )

    def exp_belongs_test_data(self):
        return self._exp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            belongs_atol=gs.atol * 1000,
        )

    def log_is_tangent_test_data(self):
        return self._log_is_tangent_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            is_tangent_atol=1e-1,
        )

    def geodesic_ivp_belongs_test_data(self):
        return self._geodesic_ivp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_points_list,
            belongs_atol=gs.atol * 100000,
        )

    def geodesic_bvp_belongs_test_data(self):
        return self._geodesic_bvp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            belongs_atol=gs.atol * 100000,
        )

    def exp_after_log_test_data(self):
        return self._exp_after_log_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            rtol=gs.rtol * 10000,
            atol=1e-1,
        )

    def log_after_exp_test_data(self):
        return self._log_after_exp_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            amplitude=10.0,
            rtol=gs.rtol * 10000,
            atol=1e-1,
        )

    def exp_ladder_parallel_transport_test_data(self):
        return self._exp_ladder_parallel_transport_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            self.n_rungs_list,
            self.alpha_list,
            self.scheme_list,
        )

    def exp_geodesic_ivp_test_data(self):
        return self._exp_geodesic_ivp_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            self.n_points_list,
            rtol=gs.rtol * 100000,
            atol=gs.atol * 100000,
        )

    def parallel_transport_ivp_is_isometry_test_data(self):
        return self._parallel_transport_ivp_is_isometry_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            is_tangent_atol=gs.atol * 1000,
            atol=gs.atol * 1000,
        )

    def parallel_transport_bvp_is_isometry_test_data(self):
        return self._parallel_transport_bvp_is_isometry_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            is_tangent_atol=gs.atol * 1000,
            atol=gs.atol * 1000,
        )

    def dist_is_symmetric_test_data(self):
        return self._dist_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def dist_is_positive_test_data(self):
        return self._dist_is_positive_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def squared_dist_is_positive_test_data(self):
        return self._squared_dist_is_positive_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def dist_is_norm_of_log_test_data(self):
        return self._dist_is_norm_of_log_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def dist_point_to_itself_is_zero_test_data(self):
        return self._dist_point_to_itself_is_zero_test_data(
            self.metric_args_list, self.space_list, self.n_points_list)

    def inner_product_is_symmetric_test_data(self):
        return self._inner_product_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )

    def inner_product_shape_test_data(self):
        """Shape case: a batch of identity points on two copies of SO(3)."""
        space = NFoldManifold(SpecialOrthogonal(3), 2)
        n_samples = 4
        # n_samples points, each made of `space.n_copies` identity matrices.
        point = gs.stack([gs.eye(3)] * space.n_copies * n_samples)
        point = gs.reshape(point, (n_samples, *space.shape))
        tangent_vec = space.to_tangent(gs.zeros((n_samples, *space.shape)),
                                       point)
        smoke_data = [
            dict(space=space,
                 n_samples=n_samples,
                 point=point,
                 tangent_vec=tangent_vec)
        ]
        return self.generate_tests(smoke_data)
Esempio n. 19
0
 def test_matrix_from_rotation_vector(self, n, rot_vec, expected):
     """Check the rotation matrix built from ``rot_vec`` against ``expected``."""
     so_n = SpecialOrthogonal(n)
     self.assertAllClose(so_n.matrix_from_rotation_vector(rot_vec), expected)
Esempio n. 20
0
"""
Predict on manifolds: losses.
"""

import logging

import geomstats.backend as gs
import geomstats.geometry.lie_group as lie_group
from geomstats.geometry.special_orthogonal import SpecialOrthogonal

# Module-level SO(3) instance: supplies the default metric of `loss` and is
# the group passed to `lie_group.loss`.
SO3 = SpecialOrthogonal(n=3)


def loss(y_pred,
         y_true,
         metric=SO3.bi_invariant_metric,
         representation='vector'):
    """Loss function given by a Riemannian metric on SO(3).

    Parameters
    ----------
    y_pred : array-like
        Prediction on SO(3).
    y_true : array-like
        Ground-truth on SO(3).
    metric : RiemannianMetric, optional
        Metric used to compute the loss. Defaults to the bi-invariant metric
        of the module-level SO(3) group.
    representation : str, optional
        Representation of the inputs. With ``'quaternion'``, both inputs are
        first converted to rotation vectors; any other value passes them
        through unchanged.

    Returns
    -------
    lie_loss : array-like
        Loss computed by ``lie_group.loss`` on SO(3).
    """
    if representation == 'quaternion':
        y_pred = SO3.rotation_vector_from_quaternion(y_pred)
        y_true = SO3.rotation_vector_from_quaternion(y_true)

    # Use a distinct local name so the function `loss` is not shadowed.
    lie_loss = lie_group.loss(y_pred, y_true, SO3, metric)
    return lie_loss


def grad(y_pred,
         y_true,
         metric=SO3.bi_invariant_metric,
         representation='vector'):
Esempio n. 21
0
 def test_compose_with_inverse_is_identity(self, space_args):
     """Composing a random point with its inverse must give the identity."""
     group = SpecialOrthogonal(*space_args)
     point = gs.squeeze(group.random_point())
     product = group.compose(point, group.inverse(point))
     self.assertAllClose(product, group.identity)
Esempio n. 22
0
"""Perform tangent PCA at the mean on SO(3)."""

import logging

import matplotlib.pyplot as plt
import numpy as np

import geomstats.visualization as visualization
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
from geomstats.learning.frechet_mean import FrechetMean
from geomstats.learning.pca import TangentPCA

# SO(3) with the "vector" point type (rotation vectors).
SO3_GROUP = SpecialOrthogonal(n=3, point_type="vector")
# Bi-invariant metric of SO(3); used for the Frechet mean and TangentPCA.
METRIC = SO3_GROUP.bi_invariant_metric

N_SAMPLES = 10  # number of random points drawn in main()
N_COMPONENTS = 2  # number of principal components requested from TangentPCA


def main():
    """Perform tangent PCA at the mean on SO(3).

    NOTE(review): the visible code stops right after creating the TangentPCA
    estimator; fitting, projection and the use of `fig` / `mean_estimate`
    are not shown here. Confirm against the full script.
    """
    fig = plt.figure(figsize=(15, 5))

    data = SO3_GROUP.random_uniform(n_samples=N_SAMPLES)

    # Estimate the Frechet mean of the samples under METRIC.
    mean = FrechetMean(metric=METRIC)
    mean.fit(data)

    mean_estimate = mean.estimate_

    tpca = TangentPCA(metric=METRIC, n_components=N_COMPONENTS)
Esempio n. 23
0
 def inner_product_test_data(self):
     """Smoke cases for left/right invariant inner products on SO(3).

     Two Lie-algebra vectors (and a batch of two) are used at the identity,
     then translated to ``self.point_1_matrix`` by left (resp. right)
     composition; the same expected inner products are used in each case.
     """
     so3 = SpecialOrthogonal(n=3)
     to_matrix = so3.lie_algebra.matrix_representation
     vec_a = to_matrix(gs.array([1.0, 0, 2.0]))
     vec_b = to_matrix(gs.array([1.0, 0, 0.5]))
     batch_vec_a = to_matrix(gs.array([[1.0, 0, 2.0], [0, 3.0, 5.0]]))
     base = self.point_1_matrix

     def make_case(side, tangent_vec_a, tangent_vec_b, base_point, expected):
         # Keep this key order: the dicts are handed to generate_tests as-is.
         return dict(
             group=so3,
             metric_mat_at_identity=None,
             left_or_right=side,
             tangent_vec_a=tangent_vec_a,
             tangent_vec_b=tangent_vec_b,
             base_point=base_point,
             expected=expected,
         )

     smoke_data = [
         make_case("left", vec_a, vec_b, None, 4.0),
         make_case("left", batch_vec_a, vec_b, None, gs.array([4.0, 5.0])),
         make_case("left",
                   so3.compose(base, vec_a),
                   so3.compose(base, vec_b),
                   base,
                   4.0),
         make_case("left",
                   so3.compose(base, batch_vec_a),
                   so3.compose(base, vec_b),
                   base,
                   gs.array([4.0, 5.0])),
         make_case("right",
                   so3.compose(vec_a, base),
                   so3.compose(vec_b, base),
                   base,
                   4.0),
         make_case("right",
                   so3.compose(batch_vec_a, base),
                   so3.compose(vec_b, base),
                   base,
                   gs.array([4.0, 5.0])),
     ]
     return self.generate_tests(smoke_data)
Esempio n. 24
0
with our BCH-implementation, for small orders approximation by BCH is faster
than the scikit-learn version, while being close to the actual value.

"""
import timeit

import matplotlib.pyplot as plt

import geomstats.backend as gs
from geomstats.geometry.skew_symmetric_matrices import SkewSymmetricMatrices
from geomstats.geometry.special_orthogonal import SpecialOrthogonal

N = 3
MAX_ORDER = 10

GROUP = SpecialOrthogonal(n=N)
GROUP.default_point_type = 'matrix'

# Dimension of the space of N x N skew-symmetric matrices, N (N - 1) / 2.
# Integer division keeps the computation exact (no float round trip).
DIM = N * (N - 1) // 2
ALGEBRA = SkewSymmetricMatrices(n=N)


def main():
    """Sample random skew-symmetric matrices of size N and exponentiate them.

    NOTE(review): this function appears truncated in this file; the
    exponential of the second sample, the BCH comparison and any plotting
    are not present. Only the sampling steps below are visible.
    """
    # Normally distributed coefficients, rescaled to norm 1/2, then mapped to
    # an N x N skew-symmetric matrix by the algebra.
    norm_rv_1 = gs.random.normal(size=DIM)
    tan_rv_1 = ALGEBRA.matrix_representation(
        norm_rv_1 / gs.linalg.norm(norm_rv_1, axis=0) / 2)
    exp_1 = gs.linalg.expm(tan_rv_1)

    # Same construction for the second random algebra element.
    norm_rv_2 = gs.random.normal(size=DIM)
    tan_rv_2 = ALGEBRA.matrix_representation(
        norm_rv_2 / gs.linalg.norm(norm_rv_2, axis=0) / 2)
Esempio n. 25
0
    class InvariantMetricTestData(_RiemannianMetricTestData):
        """Test data for ``InvariantMetric`` on SE(3) and SO(3).

        Class attributes hold the groups, reference points and per-metric
        parameter lists. Each ``*_test_data`` method either builds a list of
        parameter dicts and passes it to ``generate_tests``, or forwards the
        per-metric lists to a ``_*_test_data`` helper of the base class.
        """

        group = SpecialEuclidean(n=3, point_type="vector")
        matrix_se3 = SpecialEuclidean(n=3)
        matrix_so3 = SpecialOrthogonal(n=3)
        vector_so3 = SpecialOrthogonal(n=3, point_type="vector")
        # The first three coordinates are used as rotation vectors below.
        point_1 = gs.array([-0.2, 0.9, 0.5, 5.0, 5.0, 5.0])
        point_2 = gs.array([0.0, 2.0, -0.1, 30.0, 400.0, 2.0])
        point_1_matrix = vector_so3.matrix_from_rotation_vector(
            point_1[..., :3])
        point_2_matrix = vector_so3.matrix_from_rotation_vector(
            point_2[..., :3])
        # Edge case for the point: rotation angle below epsilon.
        point_small = gs.array([-1e-7, 0.0, -7 * 1e-8, 6.0, 5.0, 9.0])

        diag_mat_at_identity = gs.eye(group.dim)
        # Each entry: (group, matrix at identity or None, "left"/"right").
        metric_args_list = [
            (group, None, "left"),
            (group, None, "right"),
            (group, gs.eye(group.dim), "left"),
            (group, gs.eye(group.dim), "right"),
            (matrix_so3, None, "right"),
            (matrix_so3, None, "left"),
        ]
        shape_list = [metric_args[0].shape for metric_args in metric_args_list]
        space_list = [metric_args[0] for metric_args in metric_args_list]
        # The lists below hold one entry per element of metric_args_list.
        n_points_list = [1, 2] * 3
        n_tangent_vecs_list = [1, 2] * 3
        n_points_a_list = [1, 2] * 3
        # NOTE(review): shorter than metric_args_list. If the base-class
        # helpers zip these lists, only the first metric is exercised by the
        # tests that use n_points_b_list. Confirm this is intended.
        n_points_b_list = [1]
        alpha_list = [1] * 6
        n_rungs_list = [1] * 6
        scheme_list = ["pole"] * 6

        def inner_product_mat_at_identity_shape_test_data(self):
            """Left metric on vector SE(3) given by the identity matrix."""
            group = SpecialEuclidean(n=3, point_type="vector")
            sym_mat_at_identity = gs.eye(group.dim)
            smoke_data = [
                dict(
                    group=group,
                    metric_mat_at_identity=sym_mat_at_identity,
                    left_or_right="left",
                )
            ]
            return self.generate_tests(smoke_data)

        def inner_product_matrix_shape_test_data(self):
            """Same metric, evaluated with no base point and at identity."""
            group = SpecialEuclidean(n=3, point_type="vector")
            sym_mat_at_identity = gs.eye(group.dim)
            smoke_data = [
                dict(
                    group=group,
                    metric_mat_at_identity=sym_mat_at_identity,
                    left_or_right="left",
                    base_point=None,
                ),
                dict(
                    group=group,
                    metric_mat_at_identity=sym_mat_at_identity,
                    left_or_right="left",
                    base_point=group.identity,
                ),
            ]
            return self.generate_tests(smoke_data)

        def inner_product_matrix_and_its_inverse_test_data(self):
            """Default left metric on vector SE(3)."""
            group = SpecialEuclidean(n=3, point_type="vector")
            smoke_data = [
                dict(group=group,
                     metric_mat_at_identity=None,
                     left_or_right="left")
            ]
            return self.generate_tests(smoke_data)

        def inner_product_test_data(self):
            """Inner products on matrix SO(3), at identity and translated.

            Tangent vectors at ``point_1_matrix`` are obtained by composing
            algebra elements with that point on the left (left metric) or
            on the right (right metric); the expected values match those at
            identity.
            """
            group = SpecialOrthogonal(n=3)
            algebra = group.lie_algebra
            tangent_vec_a = algebra.matrix_representation(
                gs.array([1.0, 0, 2.0]))
            tangent_vec_b = algebra.matrix_representation(
                gs.array([1.0, 0, 0.5]))
            batch_tangent_vec = algebra.matrix_representation(
                gs.array([[1.0, 0, 2.0], [0, 3.0, 5.0]]))
            smoke_data = [
                dict(
                    group=group,
                    metric_mat_at_identity=None,
                    left_or_right="left",
                    tangent_vec_a=tangent_vec_a,
                    tangent_vec_b=tangent_vec_b,
                    base_point=None,
                    expected=4.0,
                ),
                dict(
                    group=group,
                    metric_mat_at_identity=None,
                    left_or_right="left",
                    tangent_vec_a=batch_tangent_vec,
                    tangent_vec_b=tangent_vec_b,
                    base_point=None,
                    expected=gs.array([4.0, 5.0]),
                ),
                dict(
                    group=group,
                    metric_mat_at_identity=None,
                    left_or_right="left",
                    tangent_vec_a=group.compose(self.point_1_matrix,
                                                tangent_vec_a),
                    tangent_vec_b=group.compose(self.point_1_matrix,
                                                tangent_vec_b),
                    base_point=self.point_1_matrix,
                    expected=4.0,
                ),
                dict(
                    group=group,
                    metric_mat_at_identity=None,
                    left_or_right="left",
                    tangent_vec_a=group.compose(self.point_1_matrix,
                                                batch_tangent_vec),
                    tangent_vec_b=group.compose(self.point_1_matrix,
                                                tangent_vec_b),
                    base_point=self.point_1_matrix,
                    expected=gs.array([4.0, 5.0]),
                ),
                dict(
                    group=group,
                    metric_mat_at_identity=None,
                    left_or_right="right",
                    tangent_vec_a=group.compose(tangent_vec_a,
                                                self.point_1_matrix),
                    tangent_vec_b=group.compose(tangent_vec_b,
                                                self.point_1_matrix),
                    base_point=self.point_1_matrix,
                    expected=4.0,
                ),
                dict(
                    group=group,
                    metric_mat_at_identity=None,
                    left_or_right="right",
                    tangent_vec_a=group.compose(batch_tangent_vec,
                                                self.point_1_matrix),
                    tangent_vec_b=group.compose(tangent_vec_b,
                                                self.point_1_matrix),
                    base_point=self.point_1_matrix,
                    expected=gs.array([4.0, 5.0]),
                ),
            ]
            return self.generate_tests(smoke_data)

        def log_antipodals_test_data(self):
            """Log between identity and diag(1, -1, -1) must raise."""
            group = self.matrix_so3
            smoke_data = [
                dict(
                    group=group,
                    rotation_mat1=gs.eye(3),
                    rotation_mat2=gs.array([[1.0, 0.0, 0.0], [0.0, -1.0, 0.0],
                                            [0.0, 0.0, -1.0]]),
                    expected=pytest.raises(ValueError),
                )
            ]
            return self.generate_tests(smoke_data)

        def structure_constant_test_data(self):
            """Structure constants on the ``normal_basis`` of so(3).

            Cyclic orderings (a, b, c) of the basis expect sqrt(2) / 2.
            Swapping the first two arguments expects the opposite value.
            Repeating the first argument expects zero.
            """
            group = self.matrix_so3
            metric = InvariantMetric(group)
            x, y, z = metric.normal_basis(group.lie_algebra.basis)
            value = 2.0**0.5 / 2.0
            smoke_data = []
            # Emits (x,y,z)+, (y,x,z)-, (y,z,x)+, (z,y,x)-, (z,x,y)+, (x,z,y)-.
            for vec_a, vec_b, vec_c in [(x, y, z), (y, z, x), (z, x, y)]:
                smoke_data += [
                    dict(
                        group=self.matrix_so3,
                        tangent_vec_a=vec_a,
                        tangent_vec_b=vec_b,
                        tangent_vec_c=vec_c,
                        expected=value,
                    ),
                    dict(
                        group=self.matrix_so3,
                        tangent_vec_a=vec_b,
                        tangent_vec_b=vec_a,
                        tangent_vec_c=vec_c,
                        expected=-value,
                    ),
                ]

            # Fresh loop names so the basis vectors x, y, z are not rebound.
            for vec_a, vec_c in itertools.permutations((x, y, z), 2):
                smoke_data += [
                    dict(
                        group=self.matrix_so3,
                        tangent_vec_a=vec_a,
                        tangent_vec_b=vec_a,
                        tangent_vec_c=vec_c,
                        expected=0.0,
                    )
                ]

            return self.generate_tests(smoke_data)

        def dual_adjoint_structure_constant_test_data(self):
            """All orderings of the ``normal_basis`` of so(3)."""
            group = self.matrix_so3
            metric = InvariantMetric(group)
            x, y, z = metric.normal_basis(group.lie_algebra.basis)
            smoke_data = []
            for vec_a, vec_b, vec_c in itertools.permutations((x, y, z)):
                smoke_data += [
                    dict(
                        group=group,
                        tangent_vec_a=vec_a,
                        tangent_vec_b=vec_b,
                        tangent_vec_c=vec_c,
                    )
                ]

            return self.generate_tests(smoke_data)

        def connection_test_data(self):
            """Connection of the first two basis vectors, expected along z."""
            group = self.matrix_so3
            metric = InvariantMetric(group)
            x, y, z = metric.normal_basis(group.lie_algebra.basis)
            smoke_data = [
                dict(
                    group=group,
                    tangent_vec_a=x,
                    tangent_vec_b=y,
                    expected=1.0 / 2**0.5 / 2.0 * z,
                )
            ]
            return self.generate_tests(smoke_data)

        def connection_translation_map_test_data(self):
            """Same connection case, with a random point of SO(3)."""
            group = self.matrix_so3
            metric = InvariantMetric(group)
            x, y, z = metric.normal_basis(group.lie_algebra.basis)
            smoke_data = [
                dict(
                    group=group,
                    tangent_vec_a=x,
                    tangent_vec_b=y,
                    point=group.random_point(),
                    expected=1.0 / 2**0.5 / 2.0 * z,
                )
            ]
            return self.generate_tests(smoke_data)

        def sectional_curvature_test_data(self):
            """Expect 1/8 for distinct basis vectors and 0 for equal ones."""
            group = self.matrix_so3
            metric = InvariantMetric(group)

            x, y, z = metric.normal_basis(group.lie_algebra.basis)
            smoke_data = [
                dict(group=group,
                     tangent_vec_a=x,
                     tangent_vec_b=y,
                     expected=1.0 / 8),
                dict(group=group,
                     tangent_vec_a=y,
                     tangent_vec_b=y,
                     expected=0.0),
                dict(
                    group=group,
                    tangent_vec_a=gs.stack([x, y]),
                    tangent_vec_b=gs.stack([z] * 2),
                    expected=gs.array([1.0 / 8, 1.0 / 8]),
                ),
            ]
            return self.generate_tests(smoke_data)

        def sectional_curvature_translation_point_test_data(self):
            """Reuse the connection translation-map data (x, y, random point)."""
            return self.connection_translation_map_test_data()

        def curvature_test_data(self):
            """Curvature on basis vectors, single and batched."""
            group = self.matrix_so3
            metric = InvariantMetric(group)
            x, y, z = metric.normal_basis(group.lie_algebra.basis)
            smoke_data = [
                dict(
                    group=group,
                    tangent_vec_a=x,
                    tangent_vec_b=y,
                    tangent_vec_c=x,
                    expected=1.0 / 8 * y,
                ),
                dict(
                    group=group,
                    tangent_vec_a=gs.stack([x, x]),
                    tangent_vec_b=gs.stack([y] * 2),
                    tangent_vec_c=gs.stack([x, x]),
                    expected=gs.array([1.0 / 8 * y] * 2),
                ),
                dict(
                    group=group,
                    tangent_vec_a=y,
                    tangent_vec_b=y,
                    tangent_vec_c=z,
                    expected=gs.zeros_like(z),
                ),
            ]
            return self.generate_tests(smoke_data)

        def curvature_translation_point_test_data(self):
            """Curvature case (x, y, x) with a random point of SO(3)."""
            group = self.matrix_so3
            metric = InvariantMetric(group)
            x, y, _ = metric.normal_basis(group.lie_algebra.basis)

            smoke_data = [
                dict(
                    group=group,
                    tangent_vec_a=x,
                    tangent_vec_b=y,
                    tangent_vec_c=x,
                    point=group.random_point(),
                    expected=1.0 / 8 * y,
                )
            ]
            return self.generate_tests(smoke_data)

        def curvature_derivative_at_identity_test_data(self):
            """Expect a zero curvature derivative on basis quadruples.

            Only quadruples where ``z`` does not come before ``y`` in the
            basis are generated.
            """
            group = self.matrix_so3
            metric = InvariantMetric(group)
            basis = metric.normal_basis(group.lie_algebra.basis)
            smoke_data = []
            for x in basis:
                for i, y in enumerate(basis):
                    for z in basis[i:]:
                        for t in basis:
                            smoke_data.append(
                                dict(
                                    group=group,
                                    tangent_vec_a=x,
                                    tangent_vec_b=y,
                                    tangent_vec_c=z,
                                    tangent_vec_d=t,
                                    expected=gs.zeros_like(x),
                                ))

            return self.generate_tests(smoke_data)

        def curvature_derivative_tangent_translation_map_test_data(self):
            """Zero curvature derivative at a random base point."""
            group = self.matrix_so3
            metric = InvariantMetric(group=group)
            x, y, z = metric.normal_basis(group.lie_algebra.basis)
            smoke_data = [
                dict(
                    group=group,
                    tangent_vec_a=x,
                    tangent_vec_b=y,
                    tangent_vec_c=z,
                    tangent_vec_d=x,
                    base_point=group.random_point(),
                    expected=gs.zeros_like(x),
                )
            ]
            return self.generate_tests(smoke_data)

        def integrated_exp_at_id_test_data(self):
            """Matrix SO(3)."""
            smoke_data = [dict(group=self.matrix_so3)]
            return self.generate_tests(smoke_data)

        def integrated_se3_exp_at_id_test_data(self):
            """Matrix SE(3)."""
            smoke_data = [dict(group=self.matrix_se3)]
            return self.generate_tests(smoke_data)

        def integrated_exp_and_log_at_id_test_data(self):
            """Matrix SO(3)."""
            smoke_data = [dict(group=self.matrix_so3)]
            return self.generate_tests(smoke_data)

        def integrated_parallel_transport_test_data(self):
            """Matrix SE(3) with n=3 and 20 samples."""
            smoke_data = [dict(group=self.matrix_se3, n=3, n_samples=20)]
            return self.generate_tests(smoke_data)

        # Data for the generic Riemannian-metric tests: each method forwards
        # the per-metric lists (and tolerances) to the base-class helper.

        def exp_shape_test_data(self):
            return self._exp_shape_test_data(self.metric_args_list,
                                             self.space_list, self.shape_list)

        def log_shape_test_data(self):
            return self._log_shape_test_data(self.metric_args_list,
                                             self.space_list)

        def squared_dist_is_symmetric_test_data(self):
            return self._squared_dist_is_symmetric_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
                atol=gs.atol * 1000,
            )

        def exp_belongs_test_data(self):
            return self._exp_belongs_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                belongs_atol=1e-2,
            )

        def log_is_tangent_test_data(self):
            return self._log_is_tangent_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_list,
                is_tangent_atol=1e-2,
            )

        def geodesic_ivp_belongs_test_data(self):
            return self._geodesic_ivp_belongs_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_points_list,
                belongs_atol=gs.atol * 100000,
            )

        def geodesic_bvp_belongs_test_data(self):
            return self._geodesic_bvp_belongs_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_list,
                belongs_atol=gs.atol * 100000,
            )

        def log_then_exp_test_data(self):
            return self._log_then_exp_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_list,
                rtol=gs.rtol * 1000,
                atol=1e-1,
            )

        def exp_then_log_test_data(self):
            return self._exp_then_log_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                amplitude=1000,
                rtol=gs.rtol * 1000,
                atol=1e-1,
            )

        def exp_ladder_parallel_transport_test_data(self):
            return self._exp_ladder_parallel_transport_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                self.n_rungs_list,
                self.alpha_list,
                self.scheme_list,
            )

        def exp_geodesic_ivp_test_data(self):
            return self._exp_geodesic_ivp_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                self.n_points_list,
                rtol=gs.rtol * 100000,
                atol=gs.atol * 100000,
            )

        def parallel_transport_ivp_is_isometry_test_data(self):
            return self._parallel_transport_ivp_is_isometry_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                is_tangent_atol=gs.atol * 1000,
                atol=gs.atol * 1000,
            )

        def parallel_transport_bvp_is_isometry_test_data(self):
            return self._parallel_transport_bvp_is_isometry_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                is_tangent_atol=gs.atol * 1000,
                atol=gs.atol * 1000,
            )

        def dist_is_symmetric_test_data(self):
            return self._dist_is_symmetric_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
            )

        def dist_is_positive_test_data(self):
            return self._dist_is_positive_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
            )

        def squared_dist_is_positive_test_data(self):
            return self._squared_dist_is_positive_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
            )

        def dist_is_norm_of_log_test_data(self):
            return self._dist_is_norm_of_log_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
            )

        def dist_point_to_itself_is_zero_test_data(self):
            return self._dist_point_to_itself_is_zero_test_data(
                self.metric_args_list, self.space_list, self.n_points_list)

        def inner_product_is_symmetric_test_data(self):
            return self._inner_product_is_symmetric_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
            )

        def exp_log_composition_at_identity_test_data(self):
            """First four metrics (vector SE(3)) with point_1 and point_small."""
            smoke_data = []
            for metric_args in self.metric_args_list[:4]:
                for tangent_vec in [self.point_1, self.point_small]:
                    smoke_data += [
                        dict(metric_args=metric_args, tangent_vec=tangent_vec)
                    ]
            return self.generate_tests(smoke_data)

        def log_exp_composition_at_identity_test_data(self):
            """First four metrics (vector SE(3)) with point_1 and point_small."""
            smoke_data = []
            for metric_args in self.metric_args_list[:4]:
                for point in [self.point_1, self.point_small]:
                    smoke_data += [dict(metric_args=metric_args, point=point)]
            return self.generate_tests(smoke_data)

        def left_exp_and_exp_from_identity_left_diag_metrics_test_data(self):
            """Default left metric on vector SE(3) at point_1."""
            smoke_data = [
                dict(metric_args=self.metric_args_list[0], point=self.point_1)
            ]
            return self.generate_tests(smoke_data)

        def left_log_and_log_from_identity_left_diag_metrics_test_data(self):
            """Default left metric on vector SE(3) at point_1."""
            smoke_data = [
                dict(metric_args=self.metric_args_list[0], point=self.point_1)
            ]
            return self.generate_tests(smoke_data)
 def setUp(self):
     """Build SO(2) and fix the number of samples used by the tests."""
     dimension = 2
     self.n = dimension
     self.group = SpecialOrthogonal(n=dimension)
     self.n_samples = 4
Esempio n. 27
0
    def setup_method(self):
        """Silence ImportWarning, then create SO(3) and the sample count."""
        warnings.simplefilter("ignore", category=ImportWarning)

        self.n_samples = 2
        self.so3_group = SpecialOrthogonal(n=3)
 def test_dim(self):
     """Check that SO(n) has dimension n * (n - 1) / 2 for n = 2, ..., 6."""
     for size in range(2, 7):
         expected = size * (size - 1) / 2
         result = SpecialOrthogonal(n=size).dim
         self.assertAllClose(result, expected)
Esempio n. 29
0
 def __init__(self, n):
     """Initialize with group SpecialOrthogonal(n) and MatricesMetric(n, n)."""
     rotations = SpecialOrthogonal(n)
     ambient = MatricesMetric(n, n)
     super(BuresWassersteinBundle, self).__init__(
         n=n,
         group=rotations,
         ambient_metric=ambient,
     )