Example #1
0
class TestL2CurvesMetric(RiemannianMetricTestCase, metaclass=Parametrizer):
    """Parametrized tests for L2CurvesMetric."""

    metric = connection = L2CurvesMetric
    skip_test_exp_belongs = True
    skip_test_exp_shape = True
    skip_test_log_shape = True
    skip_test_exp_geodesic_ivp = True
    skip_test_parallel_transport_ivp_is_isometry = True
    skip_test_parallel_transport_bvp_is_isometry = True
    skip_test_dist_is_norm_of_log = tf_backend()
    skip_test_dist_is_symmetric = tf_backend()
    skip_test_squared_dist_is_symmetric = tf_backend()
    skip_test_inner_product_is_symmetric = tf_backend()

    testing_data = L2CurvesMetricTestData()

    def test_l2_metric_geodesic(
        self, ambient_manifold, curve_a, curve_b, times, n_sampling_points
    ):
        """Test the geodesic method of L2CurvesMetric.

        The L2 geodesic between two discrete curves is compared with the
        stack of ambient-metric geodesics joining corresponding sampling
        points of ``curve_a`` and ``curve_b``.

        Parameters
        ----------
        ambient_manifold : Manifold
            Manifold in which the curves take values.
        curve_a, curve_b : array-like
            Discrete curves, indexed as ``curve[k, :]`` for sampling point k.
        times : array-like
            Times at which the geodesics are evaluated.
        n_sampling_points : int
            Number of sampling points of each curve.
        """
        # Use the manifold supplied by the test data instead of a hard-coded
        # global, so the test checks the case it is parametrized with.
        l2_metric = L2CurvesMetric(ambient_manifold=ambient_manifold)
        result = l2_metric.geodesic(curve_a, curve_b)(times)

        expected = []
        for k in range(n_sampling_points):
            geod = l2_metric.ambient_metric.geodesic(
                initial_point=curve_a[k, :], end_point=curve_b[k, :]
            )
            expected.append(geod(times))
        # Stacking on axis=1 places the sampling-point index second.
        expected = gs.stack(expected, axis=1)
        self.assertAllClose(result, expected)


class TestHermitianMetric(RiemannianMetricTestCase, metaclass=Parametrizer):
    """Parametrized tests for HermitianMetric.

    Every test converts its raw test-data inputs with ``gs.array`` before
    handing them to the metric.
    """

    metric = connection = HermitianMetric
    skip_test_exp = tf_backend()
    skip_test_log = tf_backend()
    skip_test_inner_product = tf_backend()
    # Same helper as the sibling skip flags above.
    skip_test_dist = tf_backend()
    skip_test_parallel_transport_ivp_is_isometry = True
    skip_test_parallel_transport_bvp_is_isometry = True
    skip_test_exp_geodesic_ivp = True

    testing_data = HermitianMetricTestData()

    def test_exp(self, dim, tangent_vec, base_point, expected):
        """Check ``exp`` of ``tangent_vec`` at ``base_point``."""
        metric = HermitianMetric(dim)
        self.assertAllClose(
            metric.exp(gs.array(tangent_vec), gs.array(base_point)), gs.array(expected)
        )

    def test_log(self, dim, point, base_point, expected):
        """Check ``log`` of ``point`` at ``base_point``."""
        metric = HermitianMetric(dim)
        self.assertAllClose(
            metric.log(gs.array(point), gs.array(base_point)), gs.array(expected)
        )

    def test_inner_product(self, dim, tangent_vec_a, tangent_vec_b, expected):
        """Check ``inner_product`` of two tangent vectors."""
        metric = HermitianMetric(dim)
        self.assertAllClose(
            metric.inner_product(gs.array(tangent_vec_a), gs.array(tangent_vec_b)),
            gs.array(expected),
        )

    def test_squared_norm(self, dim, vec, expected):
        """Check ``squared_norm`` of ``vec``."""
        metric = HermitianMetric(dim)
        self.assertAllClose(metric.squared_norm(gs.array(vec)), gs.array(expected))

    def test_norm(self, dim, vec, expected):
        """Check ``norm`` of ``vec``."""
        metric = HermitianMetric(dim)
        self.assertAllClose(metric.norm(gs.array(vec)), gs.array(expected))

    def test_metric_matrix(self, dim, expected):
        """Check the matrix returned by ``metric_matrix``."""
        self.assertAllClose(HermitianMetric(dim).metric_matrix(), gs.array(expected))

    def test_squared_dist(self, dim, point_a, point_b, expected):
        """Check ``squared_dist`` between two points."""
        metric = HermitianMetric(dim)
        # Convert inputs like the other tests do, so list data works too.
        result = metric.squared_dist(gs.array(point_a), gs.array(point_b))
        self.assertAllClose(result, gs.array(expected))

    def test_dist(self, dim, point_a, point_b, expected):
        """Check ``dist`` between two points."""
        metric = HermitianMetric(dim)
        result = metric.dist(gs.array(point_a), gs.array(point_b))
        self.assertAllClose(result, gs.array(expected))


class TestHermitian(VectorSpaceTestCase, metaclass=Parametrizer):
    """Parametrized tests for the Hermitian space."""

    space = Hermitian
    # Basis-related generic tests are disabled for this space.
    skip_test_basis_belongs = True
    skip_test_basis_cardinality = True
    # Membership test is skipped on the TensorFlow backend.
    skip_test_belongs = tf_backend()

    testing_data = HermitianTestData()

    def test_belongs(self, dim, vec, expected):
        """Compare ``belongs`` on ``vec`` with the expected boolean(s)."""
        hermitian_space = self.space(dim)
        membership = hermitian_space.belongs(gs.array(vec))
        self.assertAllClose(membership, gs.array(expected))
Example #4
0
class TestClosedDiscreteCurves(ManifoldTestCase, metaclass=Parametrizer):
    """Parametrized tests for ClosedDiscreteCurves."""

    space = ClosedDiscreteCurves
    skip_test_projection_belongs = tf_backend()
    skip_test_random_tangent_vec_is_tangent = True
    skip_test_to_tangent_is_tangent = True

    testing_data = ClosedDiscreteCurvesTestData()

    @geomstats.tests.np_and_autograd_only
    def test_projection_closed_curves(self, ambient_manifold, curve):
        """Check projection onto closed curves.

        The projection must be idempotent, and the projected curve's last
        sampling point must match its first one.
        """
        closed_curves = ClosedDiscreteCurves(ambient_manifold)
        projected = closed_curves.projection(curve)

        # Projecting an already projected curve changes nothing.
        reprojected = closed_curves.projection(projected)
        self.assertAllClose(reprojected, projected)

        # The projected curve closes up: last point equals first point.
        last_point = projected[-1, :]
        first_point = projected[0, :]
        self.assertAllClose(last_point, first_point, rtol=10 * gs.rtol)
Example #5
0
class TestSRVMetric(RiemannianMetricTestCase, metaclass=Parametrizer):
    """Parametrized tests for the square root velocity (SRV) metric on curves.

    Several tests use ``s2`` and ``r3``, which are not defined in this class.
    They are presumably module-level manifolds defined elsewhere in the file
    (TODO confirm).
    """

    metric = connection = SRVMetric
    skip_test_exp_shape = True
    skip_test_log_shape = True
    skip_test_exp_geodesic_ivp = True
    skip_test_parallel_transport_ivp_is_isometry = True
    skip_test_parallel_transport_bvp_is_isometry = True
    skip_test_geodesic_bvp_belongs = True
    skip_test_geodesic_ivp_belongs = True
    skip_test_exp_after_log = tf_backend()
    skip_test_exp_belongs = tf_backend()
    skip_test_exp_ladder_parallel_transport = tf_backend()
    skip_test_inner_product_is_symmetric = tf_backend()
    skip_test_log_after_exp = tf_backend()
    skip_test_log_is_tangent = tf_backend()

    testing_data = SRVMetricTestData()

    def test_srv_inner_product(self, curve_a, curve_b, curve_c, times):
        """Check the L2 inner product of SRV representations.

        The SRVs of two L2 geodesics (a->b and b->c) are compared with an
        explicit sum of pointwise products. That sum is divided by
        (number of SRV points + 1) per curve. The result must have one
        value per curve.
        """
        l2_metric_s2 = L2CurvesMetric(ambient_manifold=s2)
        srv_metric_r3 = SRVMetric(ambient_manifold=r3)
        curves_ab = l2_metric_s2.geodesic(curve_a, curve_b)
        curves_bc = l2_metric_s2.geodesic(curve_b, curve_c)
        curves_ab = curves_ab(times)
        curves_bc = curves_bc(times)
        srvs_ab = srv_metric_r3.srv_transform(curves_ab)
        srvs_bc = srv_metric_r3.srv_transform(curves_bc)

        result = srv_metric_r3.l2_curves_metric.inner_product(srvs_ab, srvs_bc)
        products = srvs_ab * srvs_bc
        expected = [gs.sum(product) for product in products]
        expected = gs.array(expected) / (srvs_ab.shape[-2] + 1)
        self.assertAllClose(result, expected)

        result = result.shape
        expected = [srvs_ab.shape[0]]
        self.assertAllClose(result, expected)

    def test_srv_norm(self, curve_a, curve_b, times):
        """Check the L2 norm of SRV representations.

        The expected value is the square root of the pointwise squared sum.
        That sum is divided by (number of SRV points + 1) per curve. The
        result must have one value per curve.
        """
        l2_metric_s2 = L2CurvesMetric(ambient_manifold=s2)
        srv_metric_r3 = SRVMetric(ambient_manifold=r3)
        curves_ab = l2_metric_s2.geodesic(curve_a, curve_b)
        curves_ab = curves_ab(times)
        srvs_ab = srv_metric_r3.srv_transform(curves_ab)

        result = srv_metric_r3.l2_curves_metric.norm(srvs_ab)
        products = srvs_ab * srvs_ab
        sums = [gs.sum(product) for product in products]
        squared_norm = gs.array(sums) / (srvs_ab.shape[-2] + 1)
        expected = gs.sqrt(squared_norm)
        self.assertAllClose(result, expected)

        result = result.shape
        expected = [srvs_ab.shape[0]]
        self.assertAllClose(result, expected)

    def test_srv_transform_and_srv_transform_inverse(self, rtol, atol):
        """Test that srv and its inverse are inverse.

        Also checks that the SRV of a curve has one sampling point fewer
        than the curve itself.
        """
        metric = SRVMetric(ambient_manifold=r3)
        curve = DiscreteCurves(r3).random_point(n_samples=2)

        srv = metric.srv_transform(curve)
        srv_inverse = metric.srv_transform_inverse(srv, curve[:, 0])

        result = srv.shape
        expected = (curve.shape[0], curve.shape[1] - 1, 3)
        self.assertAllClose(result, expected)

        result = srv_inverse
        expected = curve
        self.assertAllClose(result, expected, rtol, atol)

    @geomstats.tests.np_and_autograd_only
    def test_aux_differential_srv_transform(
        self, dim, n_sampling_points, n_curves, curve_fun_a
    ):
        """Test differential of square root velocity transform.

        Check that its value at (curve, tangent_vec) coincides
        with the derivative at zero of the square root velocity
        transform of a path of curves starting at curve with
        initial derivative tangent_vec.
        """
        srv_metric_r3 = SRVMetric(r3)
        sampling_times = gs.linspace(0.0, 1.0, n_sampling_points)
        curve_a = curve_fun_a(sampling_times)
        tangent_vec = gs.transpose(
            gs.tile(gs.linspace(1.0, 2.0, n_sampling_points), (dim, 1))
        )
        result = srv_metric_r3.aux_differential_srv_transform(tangent_vec, curve_a)

        times = gs.linspace(0.0, 1.0, n_curves)
        path_of_curves = curve_a + gs.einsum("i,jk->ijk", times, tangent_vec)
        srv_path = srv_metric_r3.srv_transform(path_of_curves)
        # Forward finite difference at t=0. ``linspace`` with n_curves points
        # on [0, 1] spaces consecutive times 1 / (n_curves - 1) apart, so the
        # step's inverse is (n_curves - 1), not n_curves.
        expected = (n_curves - 1) * (srv_path[1] - srv_path[0])
        self.assertAllClose(result, expected, atol=1e-3, rtol=1e-3)

    @geomstats.tests.np_and_autograd_only
    def test_aux_differential_srv_transform_inverse(
        self, dim, n_sampling_points, curve_a
    ):
        """Test inverse of differential of square root velocity transform.

        Check that it is the inverse of aux_differential_srv_transform.
        """
        tangent_vec = gs.transpose(
            gs.tile(gs.linspace(0.0, 1.0, n_sampling_points), (dim, 1))
        )
        srv_metric_r3 = SRVMetric(r3)
        d_srv = srv_metric_r3.aux_differential_srv_transform(tangent_vec, curve_a)
        result = srv_metric_r3.aux_differential_srv_transform_inverse(d_srv, curve_a)
        expected = tangent_vec
        self.assertAllClose(result, expected, atol=1e-3, rtol=1e-3)

    def test_aux_differential_srv_transform_vectorization(
        self, dim, n_sampling_points, curve_a, curve_b
    ):
        """Test differential of square root velocity transform.

        Check that the batched call matches the per-curve calls.
        """
        # Tangent vectors must match the ambient space ``r3`` used below, so
        # the supplied ``dim`` is overridden. NOTE(review): presumably r3 is
        # 3-dimensional; confirm.
        dim = 3
        curves = gs.stack((curve_a, curve_b))
        tangent_vecs = gs.random.rand(2, n_sampling_points, dim)
        srv_metric_r3 = SRVMetric(r3)
        result = srv_metric_r3.aux_differential_srv_transform(tangent_vecs, curves)

        res_a = srv_metric_r3.aux_differential_srv_transform(tangent_vecs[0], curve_a)
        res_b = srv_metric_r3.aux_differential_srv_transform(tangent_vecs[1], curve_b)
        expected = gs.stack([res_a, res_b])
        self.assertAllClose(result, expected)

    def test_srv_inner_product_elastic(self, dim, n_sampling_points, curve_a):
        """Test inner product of SRVMetric.

        Check that the pullback metric gives an elastic metric
        with parameters a=1, b=1/2.
        """
        tangent_vec_a = gs.random.rand(n_sampling_points, dim)
        tangent_vec_b = gs.random.rand(n_sampling_points, dim)
        # Local Euclidean space (kept distinct from any module-level ``r3``).
        euclidean = Euclidean(dim)
        srv_metric_r3 = SRVMetric(euclidean)
        result = srv_metric_r3.inner_product(tangent_vec_a, tangent_vec_b, curve_a)

        # Finite-difference derivatives along the curve parameter.
        d_vec_a = (n_sampling_points - 1) * (
            tangent_vec_a[1:, :] - tangent_vec_a[:-1, :]
        )
        d_vec_b = (n_sampling_points - 1) * (
            tangent_vec_b[1:, :] - tangent_vec_b[:-1, :]
        )
        velocity_vec = (n_sampling_points - 1) * (curve_a[1:, :] - curve_a[:-1, :])
        velocity_norm = euclidean.metric.norm(velocity_vec)
        unit_velocity_vec = gs.einsum("ij,i->ij", velocity_vec, 1 / velocity_norm)
        a_param = 1
        b_param = 1 / 2
        integrand = (
            a_param**2 * gs.sum(d_vec_a * d_vec_b, axis=1)
            - (a_param**2 - b_param**2)
            * gs.sum(d_vec_a * unit_velocity_vec, axis=1)
            * gs.sum(d_vec_b * unit_velocity_vec, axis=1)
        ) / velocity_norm
        expected = gs.sum(integrand) / n_sampling_points
        self.assertAllClose(result, expected)

    def test_srv_inner_product_and_dist(self, dim, curve_a, curve_b):
        """Test that norm of log and dist coincide.

        Covered cases: curves with the same or different starting points,
        and the translation-invariant or non-invariant SRV metric.
        """
        euclidean = Euclidean(dim=dim)
        curve_b_transl = curve_b + gs.array([1.0, 0.0, 0.0])
        # Same end curve, then a translated copy with a different start.
        end_curves = [curve_b, curve_b_transl]
        translation_invariant = [True, False]
        for curve in end_curves:
            for param in translation_invariant:
                srv_metric = SRVMetric(
                    ambient_manifold=euclidean, translation_invariant=param
                )
                log = srv_metric.log(point=curve, base_point=curve_a)
                result = srv_metric.norm(vector=log, base_point=curve_a)
                expected = srv_metric.dist(curve_a, curve)
                self.assertAllClose(result, expected)

    def test_srv_inner_product_vectorization(
        self, dim, n_sampling_points, curve_a, curve_b
    ):
        """Test inner product of SRVMetric.

        Check that the batched call matches the per-curve calls.
        """
        curves = gs.stack((curve_a, curve_b))
        tangent_vecs_1 = gs.random.rand(2, n_sampling_points, dim)
        tangent_vecs_2 = gs.random.rand(2, n_sampling_points, dim)
        srv_metric_r3 = SRVMetric(r3)
        result = srv_metric_r3.inner_product(tangent_vecs_1, tangent_vecs_2, curves)

        res_a = srv_metric_r3.inner_product(
            tangent_vecs_1[0], tangent_vecs_2[0], curve_a
        )
        res_b = srv_metric_r3.inner_product(
            tangent_vecs_1[1], tangent_vecs_2[1], curve_b
        )
        expected = gs.stack((res_a, res_b))
        self.assertAllClose(result, expected)

    @geomstats.tests.np_autograd_and_torch_only
    def test_split_horizontal_vertical(
        self, times, n_discretized_curves, curve_a, curve_b
    ):
        """Test split horizontal vertical.

        Check that the horizontal and vertical parts of any tangent vector
        are orthogonal with respect to the SRVMetric inner product. Also
        check vectorization.
        """
        srv_metric_r3 = SRVMetric(r3)
        quotient_srv_metric_r3 = DiscreteCurves(ambient_manifold=r3).quotient_srv_metric
        geod = srv_metric_r3.geodesic(initial_curve=curve_a, end_curve=curve_b)
        geod = geod(times)
        # The scale factor does not affect orthogonality or the
        # vectorization comparison below.
        tangent_vec = n_discretized_curves * (geod[1, :, :] - geod[0, :, :])
        (
            tangent_vec_hor,
            tangent_vec_ver,
            _,
        ) = quotient_srv_metric_r3.split_horizontal_vertical(tangent_vec, curve_a)
        result = srv_metric_r3.inner_product(tangent_vec_hor, tangent_vec_ver, curve_a)
        expected = 0.0
        self.assertAllClose(result, expected, atol=1e-4)

        tangent_vecs = n_discretized_curves * (geod[1:] - geod[:-1])
        _, _, result = quotient_srv_metric_r3.split_horizontal_vertical(
            tangent_vecs, geod[:-1]
        )
        expected = []
        for i in range(n_discretized_curves - 1):
            _, _, res = quotient_srv_metric_r3.split_horizontal_vertical(
                tangent_vecs[i], geod[i]
            )
            expected.append(res)
        expected = gs.stack(expected)
        self.assertAllClose(result, expected)

    def test_space_derivative(
        self, dim, n_points, n_discretized_curves, n_sampling_points
    ):
        """Test space derivative.

        Check the result on a 3-point example and check vectorization.
        """
        # The hand-written expected value below is specific to a 3-point curve
        # in 3-D, so the supplied n_points and dim are overridden.
        n_points = 3
        dim = 3
        srv_metric_r3 = SRVMetric(Euclidean(dim))
        curve = gs.random.rand(n_points, dim)
        result = srv_metric_r3.space_derivative(curve)
        delta = 1 / n_points
        # One-sided differences at the ends, centered difference inside.
        d_curve_1 = (curve[1] - curve[0]) / delta
        d_curve_2 = (curve[2] - curve[0]) / (2 * delta)
        d_curve_3 = (curve[2] - curve[1]) / delta
        expected = gs.squeeze(
            gs.vstack(
                (
                    gs.to_ndarray(d_curve_1, 2),
                    gs.to_ndarray(d_curve_2, 2),
                    gs.to_ndarray(d_curve_3, 2),
                )
            )
        )
        self.assertAllClose(result, expected)

        path_of_curves = gs.random.rand(n_discretized_curves, n_sampling_points, dim)
        result = srv_metric_r3.space_derivative(path_of_curves)
        expected = []
        for i in range(n_discretized_curves):
            expected.append(srv_metric_r3.space_derivative(path_of_curves[i]))
        expected = gs.stack(expected)
        self.assertAllClose(result, expected)

    def test_srv_metric_pointwise_inner_products(
        self, times, curve_a, curve_b, curve_c, n_discretized_curves, n_sampling_points
    ):
        """Check output shapes of ``pointwise_inner_products``.

        Batched input yields (n_discretized_curves, n_sampling_points).
        A single curve yields (n_sampling_points,).
        """
        l2_metric_s2 = L2CurvesMetric(ambient_manifold=s2)
        srv_metric_r3 = SRVMetric(ambient_manifold=r3)
        curves_ab = l2_metric_s2.geodesic(curve_a, curve_b)
        curves_bc = l2_metric_s2.geodesic(curve_b, curve_c)
        curves_ab = curves_ab(times)
        curves_bc = curves_bc(times)

        tangent_vecs = l2_metric_s2.log(point=curves_bc, base_point=curves_ab)
        result = srv_metric_r3.l2_curves_metric.pointwise_inner_products(
            tangent_vec_a=tangent_vecs, tangent_vec_b=tangent_vecs, base_curve=curves_ab
        )
        expected_shape = (n_discretized_curves, n_sampling_points)
        self.assertAllClose(gs.shape(result), expected_shape)

        result = srv_metric_r3.l2_curves_metric.pointwise_inner_products(
            tangent_vec_a=tangent_vecs[0],
            tangent_vec_b=tangent_vecs[0],
            base_curve=curves_ab[0],
        )
        expected_shape = (n_sampling_points,)
        self.assertAllClose(gs.shape(result), expected_shape)

    def test_srv_transform_and_inverse(self, times, curve_a, curve_b):
        """Test of SRVT and its inverse.

        N.B: Here curves_ab are seen as curves in R3 and not S2.
        """
        l2_metric_s2 = L2CurvesMetric(ambient_manifold=s2)
        srv_metric_r3 = SRVMetric(ambient_manifold=r3)
        curves_ab = l2_metric_s2.geodesic(curve_a, curve_b)
        curves_ab = curves_ab(times)

        curves = curves_ab
        srv_curves = srv_metric_r3.srv_transform(curves)
        starting_points = curves[:, 0, :]
        result = srv_metric_r3.srv_transform_inverse(srv_curves, starting_points)
        expected = curves

        self.assertAllClose(result, expected)

# Reference 3-vectors used as SE(2) elements in vector form.
point_1 = gs.array([0.1, 0.2, 0.3])
point_2 = gs.array([0.5, 5.0, 60.0])

translation_large = gs.array([0.0, 5.0, 6.0])
translation_small = gs.array([0.0, 0.6, 0.7])

elements_all = dict(
    translation_large=translation_large,
    translation_small=translation_small,
    point_1=point_1,
    point_2=point_2,
)
# The TensorFlow backend is extremely slow, so only the two generic points
# are kept there.
elements = (
    {"point_1": point_1, "point_2": point_2} if tf_backend() else elements_all
)

# Matrix form of every reference element, converted through SE(2) in
# vector representation.
elements_matrices_all = {
    name: SpecialEuclidean(2, point_type="vector").matrix_from_vector(vec)
    for name, vec in elements_all.items()
}
elements_matrices = elements_matrices_all


class SpecialEuclideanTestData(_LieGroupTestData):
    """Test data shared by the SpecialEuclidean tests."""

    n_list = random.sample(range(2, 4), 2)
    # One argument tuple per sampled n (default point type), followed by the
    # explicit vector-form spaces for n=2 and n=3.
    space_args_list = [*[(n,) for n in n_list], (2, "vector"), (3, "vector")]


class TestSpecialEuclidean(LieGroupTestCase, metaclass=Parametrizer):
    """Parametrized tests for SpecialEuclidean.

    NOTE(review): a second class with this same name is defined later in this
    file and rebinds the module-level name; confirm which one is intended.
    """

    space = group = SpecialEuclidean
    skip_test_log_after_exp = tf_backend()
    skip_test_exp_after_log = tf_backend()

    testing_data = SpecialEuclideanTestData()

    def test_belongs(self, n, mat, expected):
        """Check ``belongs`` on one or several matrices."""
        self.assertAllClose(
            SpecialEuclidean(n).belongs(gs.array(mat)), gs.array(expected)
        )

    def test_random_point_belongs(self, n, n_samples):
        """Check that random points of SE(n) belong to SE(n)."""
        # Build the space like the other tests (``self.space``). Test
        # membership with ``belongs``: ``gs.all(points)`` would only check
        # that every entry is non-zero.
        group = self.space(n)
        points = group.random_point(n_samples)
        self.assertTrue(gs.all(group.belongs(points)))

    def test_identity(self, n, expected):
        """Check the identity element of SE(n)."""
        self.assertAllClose(SpecialEuclidean(n).identity, gs.array(expected))

    def test_is_tangent(self, n, tangent_vec, base_point, expected):
        """Check ``is_tangent`` at ``base_point``."""
        result = SpecialEuclidean(n).is_tangent(
            gs.array(tangent_vec), gs.array(base_point)
        )
        self.assertAllClose(result, gs.array(expected))

    def test_metrics_default_point_type(self, n, metric_str):
        """Check that the named metric uses the "matrix" point type."""
        group = self.space(n)
        self.assertTrue(getattr(group, metric_str).default_point_type == "matrix")

    def test_inverse_shape(self, n, points, expected):
        """Check the shape returned by ``inverse``."""
        group = self.space(n)
        self.assertAllClose(gs.shape(group.inverse(points)), expected)

    def test_compose_shape(self, n, point_a, point_b, expected):
        """Check the shape returned by ``compose``."""
        group = self.space(n)
        result = gs.shape(group.compose(gs.array(point_a), gs.array(point_b)))
        self.assertAllClose(result, expected)

    def test_regularize_shape(self, n, point_type, n_samples):
        """Check that ``regularize`` preserves the batch shape."""
        group = self.space(n, point_type)
        points = group.random_point(n_samples=n_samples)
        regularized_points = group.regularize(points)

        self.assertAllClose(
            gs.shape(regularized_points),
            (n_samples, *group.get_point_type_shape()),
        )

    def test_compose(self, n, point_type, point_1, point_2, expected):
        """Check ``compose`` of two points."""
        group = self.space(n, point_type)
        result = group.compose(point_1, point_2)
        self.assertAllClose(result, expected)

    def test_group_exp_from_identity(self, n, point_type, tangent_vec, expected):
        """Check the group exponential at the identity."""
        group = self.space(n, point_type)
        result = group.exp(base_point=group.identity, tangent_vec=tangent_vec)
        self.assertAllClose(result, expected)

    def test_group_log_from_identity(self, n, point_type, point, expected):
        """Check the group logarithm at the identity."""
        group = self.space(n, point_type)
        result = group.log(base_point=group.identity, point=point)
        self.assertAllClose(result, expected)


class TestSpecialEuclidean(LieGroupTestCase, metaclass=Parametrizer):
    """Parametrized tests for SpecialEuclidean, with their test data nested."""

    space = group = SpecialEuclidean
    skip_test_exp_then_log = tf_backend()
    skip_test_log_then_exp = tf_backend()

    class SpecialEuclideanTestData(_LieGroupTestData):
        """Test data for the SpecialEuclidean tests."""

        n_list = random.sample(range(2, 4), 2)
        space_args_list = [(n,) for n in n_list] + [(2, "vector"), (3, "vector")]
        # Point shapes matching space_args_list: (n+1, n+1) matrices, then the
        # vector forms of SE(2) and SE(3).
        shape_list = [(n + 1, n + 1) for n in n_list] + [(3,)] + [(6,)]
        n_tangent_vecs_list = [2, 3] * 2
        n_points_list = [2, 3] * 2
        n_vecs_list = [2, 3] * 2

        def belongs_test_data(self):
            smoke_data = [
                dict(
                    n=2, mat=group_useful_matrix(gs.pi / 3, elem_33=1.0), expected=True
                ),
                dict(
                    n=2, mat=group_useful_matrix(gs.pi / 3, elem_33=0.0), expected=False
                ),
                dict(
                    n=2,
                    mat=[
                        group_useful_matrix(gs.pi / 3, elem_33=1.0),
                        group_useful_matrix(gs.pi / 3, elem_33=0.0),
                    ],
                    expected=[True, False],
                ),
            ]
            return self.generate_tests(smoke_data)

        def identity_test_data(self):
            smoke_data = [
                dict(n=2, expected=gs.eye(3)),
                dict(n=3, expected=gs.eye(4)),
                dict(n=10, expected=gs.eye(11)),
            ]
            return self.generate_tests(smoke_data)

        def is_tangent_test_data(self):
            theta = gs.pi / 3
            vec_1 = gs.array([[0.0, -theta, 2.0], [theta, 0.0, 3.0], [0.0, 0.0, 0.0]])
            # Same as vec_1 but with a non-zero bottom-right entry.
            vec_2 = gs.array([[0.0, -theta, 2.0], [theta, 0.0, 3.0], [0.0, 0.0, 1.0]])
            point = group_useful_matrix(theta)
            smoke_data = [
                dict(n=2, tangent_vec=point @ vec_1, base_point=point, expected=True),
                dict(n=2, tangent_vec=point @ vec_2, base_point=point, expected=False),
                dict(
                    n=2,
                    tangent_vec=[point @ vec_1, point @ vec_2],
                    base_point=point,
                    expected=[True, False],
                ),
            ]
            return self.generate_tests(smoke_data)

        def basis_representation_test_data(self):
            # NOTE(review): ``self.group`` is not defined in this class, and a
            # single ``dim`` is used for every ``n``; confirm against
            # _LieGroupTestData.
            n_list = random.sample(range(2, 50), 10)
            n_samples = 100
            random_data = [
                dict(n=n, vec=gs.random.rand(n_samples, self.group.dim)) for n in n_list
            ]
            return self.generate_tests([], random_data)

        def metrics_default_point_type_test_data(self):
            n_list = random.sample(range(2, 5), 2)
            metric_str_list = [
                "left_canonical_metric",
                "right_canonical_metric",
                "metric",
            ]
            random_data = itertools.product(n_list, metric_str_list)
            return self.generate_tests([], random_data)

        def inverse_shape_test_data(self):
            n_list = random.sample(range(2, 50), 10)
            n_samples = 10
            random_data = [
                dict(
                    n=n,
                    points=SpecialEuclidean(n).random_point(n_samples),
                    expected=(n_samples, n + 1, n + 1),
                )
                for n in n_list
            ]
            return self.generate_tests([], random_data)

        def compose_shape_test_data(self):
            # Cases: batch with batch, single point with batch, and batch
            # with single point.
            n_list = random.sample(range(2, 50), 10)
            n_samples = 10
            random_data = [
                dict(
                    n=n,
                    point_a=SpecialEuclidean(n).random_point(n_samples),
                    point_b=SpecialEuclidean(n).random_point(n_samples),
                    expected=(n_samples, n + 1, n + 1),
                )
                for n in n_list
            ]
            random_data += [
                dict(
                    n=n,
                    point_a=SpecialEuclidean(n).random_point(),
                    point_b=SpecialEuclidean(n).random_point(n_samples),
                    expected=(n_samples, n + 1, n + 1),
                )
                for n in n_list
            ]
            random_data += [
                dict(
                    n=n,
                    point_a=SpecialEuclidean(n).random_point(n_samples),
                    point_b=SpecialEuclidean(n).random_point(),
                    expected=(n_samples, n + 1, n + 1),
                )
                for n in n_list
            ]
            return self.generate_tests([], random_data)

        def random_point_belongs_test_data(self):
            # NOTE(review): the second tuple entries are booleans, whereas
            # other space args here pass a point type string such as
            # "vector"; confirm.
            smoke_space_args_list = [(2, True), (3, True), (2, False)]
            smoke_n_points_list = [1, 2, 1]
            return self._random_point_belongs_test_data(
                smoke_space_args_list,
                smoke_n_points_list,
                self.space_args_list,
                self.n_points_list,
            )

        def projection_belongs_test_data(self):
            return self._projection_belongs_test_data(
                self.space_args_list,
                self.shape_list,
                self.n_points_list,
                belongs_atol=1e-2,
            )

        def to_tangent_is_tangent_test_data(self):
            return self._to_tangent_is_tangent_test_data(
                SpecialEuclidean,
                self.space_args_list,
                self.shape_list,
                self.n_vecs_list,
            )

        def random_tangent_vec_is_tangent_test_data(self):
            return self._random_tangent_vec_is_tangent_test_data(
                SpecialEuclidean,
                self.space_args_list,
                self.n_vecs_list,
                is_tangent_atol=gs.atol * 100,
            )

        def exp_then_log_test_data(self):
            return self._exp_then_log_test_data(
                SpecialEuclidean,
                self.space_args_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                amplitude=100,
                atol=gs.atol * 10000,
            )

        def log_then_exp_test_data(self):
            return self._log_then_exp_test_data(
                SpecialEuclidean,
                self.space_args_list,
                self.n_points_list,
                atol=gs.atol * 10000,
            )

        def regularize_test_data(self):
            smoke_data = [
                dict(
                    n=2,
                    point_type="vector",
                    point=elements_all["point_1"],
                    expected=elements_all["point_1"],
                )
            ]
            return self.generate_tests(smoke_data)

        def regularize_shape_test_data(self):
            smoke_data = [dict(n=2, point_type="vector", n_samples=3)]
            return self.generate_tests(smoke_data)

        def compose_inverse_point_with_point_is_identity_test_data(self):
            return self._compose_inverse_point_with_point_is_identity_test_data(
                SpecialEuclidean, self.space_args_list, self.n_points_list
            )

        def compose_point_with_inverse_point_is_identity_test_data(self):
            return self._compose_point_with_inverse_point_is_identity_test_data(
                SpecialEuclidean, self.space_args_list, self.n_points_list
            )

        def compose_point_with_identity_is_point_test_data(self):
            return self._compose_point_with_identity_is_point_test_data(
                SpecialEuclidean, self.space_args_list, self.n_points_list
            )

        def compose_identity_with_point_is_point_test_data(self):
            return self._compose_identity_with_point_is_point_test_data(
                SpecialEuclidean, self.space_args_list, self.n_points_list
            )

        def compose_test_data(self):
            # The key must be ``point_type`` to match the signature of
            # test_compose; ``point_typ`` would not bind to that argument.
            smoke_data = [
                dict(
                    n=2,
                    point_type="vector",
                    point_1=elements_all["translation_small"],
                    point_2=elements_all["translation_large"],
                    expected=elements_all["translation_small"]
                    + elements_all["translation_large"],
                )
            ]
            return self.generate_tests(smoke_data)

        def group_exp_from_identity_test_data(self):
            smoke_data = [
                dict(
                    n=2,
                    point_type="vector",
                    tangent_vec=elements_all["translation_small"],
                    expected=elements_all["translation_small"],
                ),
                dict(
                    n=2,
                    point_type="vector",
                    tangent_vec=gs.stack([elements_all["translation_small"]] * 2),
                    expected=gs.stack([elements_all["translation_small"]] * 2),
                ),
            ]
            return self.generate_tests(smoke_data)

        def group_log_from_identity_test_data(self):
            smoke_data = [
                dict(
                    n=2,
                    point_type="vector",
                    point=elements_all["translation_small"],
                    expected=elements_all["translation_small"],
                ),
                dict(
                    n=2,
                    point_type="vector",
                    point=gs.stack([elements_all["translation_small"]] * 2),
                    expected=gs.stack([elements_all["translation_small"]] * 2),
                ),
            ]
            return self.generate_tests(smoke_data)

    testing_data = SpecialEuclideanTestData()

    def test_belongs(self, n, mat, expected):
        """Check ``belongs`` on one or several matrices."""
        self.assertAllClose(
            SpecialEuclidean(n).belongs(gs.array(mat)), gs.array(expected)
        )

    def test_random_point_belongs(self, n, n_samples):
        """Check that random points of SE(n) belong to SE(n)."""
        # ``self.space`` builds the group; the old ``group(n)`` call on an
        # instance was a bug. Membership is tested with ``belongs``, not by
        # checking that entries are non-zero.
        group = self.space(n)
        points = group.random_point(n_samples)
        self.assertAllClose(gs.all(group.belongs(points)), gs.array(True))

    def test_identity(self, n, expected):
        """Check the identity element of SE(n)."""
        self.assertAllClose(SpecialEuclidean(n).identity, gs.array(expected))

    def test_is_tangent(self, n, tangent_vec, base_point, expected):
        """Check ``is_tangent`` at ``base_point``."""
        result = SpecialEuclidean(n).is_tangent(
            gs.array(tangent_vec), gs.array(base_point)
        )
        self.assertAllClose(result, gs.array(expected))

    def test_metrics_default_point_type(self, n, metric_str):
        """Check that the named metric uses the "matrix" point type."""
        group = self.space(n)
        self.assertTrue(getattr(group, metric_str).default_point_type == "matrix")

    def test_inverse_shape(self, n, points, expected):
        """Check the shape returned by ``inverse``."""
        group = self.space(n)
        self.assertAllClose(gs.shape(group.inverse(points)), expected)

    def test_compose_shape(self, n, point_a, point_b, expected):
        """Check the shape returned by ``compose``."""
        group = self.space(n)
        result = gs.shape(group.compose(gs.array(point_a), gs.array(point_b)))
        self.assertAllClose(result, expected)

    def test_regularize_shape(self, n, point_type, n_samples):
        """Check that ``regularize`` preserves the batch shape."""
        group = self.space(n, point_type)
        points = group.random_point(n_samples=n_samples)
        regularized_points = group.regularize(points)

        self.assertAllClose(
            gs.shape(regularized_points),
            (n_samples, *group.get_point_type_shape()),
        )

    def test_compose(self, n, point_type, point_1, point_2, expected):
        """Check ``compose`` of two points."""
        group = self.space(n, point_type)
        result = group.compose(point_1, point_2)
        self.assertAllClose(result, expected)

    def test_group_exp_from_identity(self, n, point_type, tangent_vec, expected):
        """Check the group exponential at the identity."""
        group = self.space(n, point_type)
        result = group.exp(base_point=group.identity, tangent_vec=tangent_vec)
        self.assertAllClose(result, expected)

    def test_group_log_from_identity(self, n, point_type, point, expected):
        """Check the group logarithm at the identity."""
        group = self.space(n, point_type)
        result = group.log(base_point=group.identity, point=point)
        self.assertAllClose(result, expected)