Пример #1
0
 def test_connection(self, group, tangent_vec_a, tangent_vec_b, expected):
     """Check ``InvariantMetric(group).connection`` against ``expected``.

     The metric is built from ``group`` alone, its ``connection`` is
     evaluated on the two tangent vectors, and the outcome is compared with
     ``expected`` through ``assertAllClose``.
     """
     invariant_metric = InvariantMetric(group)
     computed = invariant_metric.connection(tangent_vec_a, tangent_vec_b)
     self.assertAllClose(computed, expected)
Пример #2
0
    def setUp(self):
        """Prepare the groups, invariant metrics and sample points for tests.

        Disables the root logger, ignores ``ImportWarning``, seeds
        ``gs.random`` with 1234, and then stores on ``self``:

        - ``group``: ``SpecialEuclidean(n=3, point_type='vector')``;
        - ``matrix_so3``: ``SpecialOrthogonal(n=3)``;
        - left/right invariant metrics on both groups;
        - sample points, including the matrices obtained from the first three
          components of ``point_1`` and ``point_2``.
        """
        logging.getLogger().disabled = True
        warnings.simplefilter('ignore', category=ImportWarning)

        gs.random.seed(1234)

        dim = 3
        se3_vec = SpecialEuclidean(n=dim, point_type='vector')
        so3_mat = SpecialOrthogonal(n=dim)
        so3_vec = SpecialOrthogonal(n=dim, point_type='vector')

        # Diagonal left and right invariant metrics. The left one is built
        # with ``None`` as matrix, the right one with an explicit identity.
        diag_identity = gs.eye(se3_vec.dim)
        diag_left = InvariantMetric(
            group=se3_vec,
            inner_product_mat_at_identity=None,
            left_or_right='left')
        diag_right = InvariantMetric(
            group=se3_vec,
            inner_product_mat_at_identity=diag_identity,
            left_or_right='right')

        # General left and right invariant metrics
        # FIXME (nina): This is valid only for bi-invariant metrics
        sym_identity = gs.eye(se3_vec.dim)
        general_left = InvariantMetric(
            group=se3_vec,
            inner_product_mat_at_identity=sym_identity,
            left_or_right='left')
        general_right = InvariantMetric(
            group=se3_vec,
            inner_product_mat_at_identity=sym_identity,
            left_or_right='right')

        # Metrics on the matrix rotation group.
        so3_left = InvariantMetric(group=so3_mat)
        so3_right = InvariantMetric(group=so3_mat, left_or_right='right')

        # General case for the point
        first_point = gs.array([[-0.2, 0.9, 0.5, 5., 5., 5.]])
        second_point = gs.array([[0., 2., -0.1, 30., 400., 2.]])
        first_matrix = so3_vec.matrix_from_rotation_vector(
            first_point[:, :3])
        second_matrix = so3_vec.matrix_from_rotation_vector(
            second_point[:, :3])
        # Edge case for the point, angle < epsilon,
        tiny_angle_point = gs.array([[-1e-7, 0., -7 * 1e-8, 6., 5., 9.]])

        # Groups
        self.group = se3_vec
        self.matrix_so3 = so3_mat

        # Metrics
        self.left_diag_metric = diag_left
        self.right_diag_metric = diag_right
        self.left_metric = general_left
        self.right_metric = general_right
        self.matrix_left_metric = so3_left
        self.matrix_right_metric = so3_right

        # Points
        self.point_1 = first_point
        self.point_2 = second_point
        self.point_1_matrix = first_matrix
        self.point_2_matrix = second_matrix
        self.point_small = tiny_angle_point
Пример #3
0
 def test_structure_constant(self, group, tangent_vec_a, tangent_vec_b,
                             tangent_vec_c, expected):
     """Check ``InvariantMetric.structure_constant`` against ``expected``.

     The metric is built with ``group=group``; the three tangent vectors are
     forwarded positionally, in the order a, b, c, and the outcome is
     compared with ``expected`` through ``assertAllClose``.
     """
     tangent_vecs = (tangent_vec_a, tangent_vec_b, tangent_vec_c)
     computed = InvariantMetric(group=group).structure_constant(*tangent_vecs)
     self.assertAllClose(computed, expected)
Пример #4
0
 def test_sectional_curvature(self, group, tangent_vec_a, tangent_vec_b,
                              expected):
     """Check ``InvariantMetric.sectional_curvature`` against ``expected``.

     The metric is built from ``group`` alone, ``sectional_curvature`` is
     evaluated on the two tangent vectors, and the outcome is compared with
     ``expected`` through ``assertAllClose``.
     """
     computed = InvariantMetric(group).sectional_curvature(
         tangent_vec_a, tangent_vec_b)
     self.assertAllClose(computed, expected)
Пример #5
0
class SpecialEuclidean3VectorsTestData(TestData):
    """Test data built on ``SpecialEuclidean(n=3, point_type="vector")``.

    Class attributes hold a catalogue of 6-component elements (the first
    three components carry the rotation-vector part, the last three the
    translation part), a set of invariant metrics, and the names of the
    elements whose rotation angle is close to pi. The ``*_test_data``
    methods combine them into lists of keyword dictionaries that are passed
    to ``generate_tests``.

    When ``geomstats.tests.tf_backend()`` is true, reduced versions of
    ``elements``, ``metrics`` and ``angles_close_to_pi`` are used.
    """

    group = SpecialEuclidean(n=3, point_type="vector")

    # Elements: a scaled rotation direction plus a translation offset.
    angle_0 = gs.zeros(6)
    angle_close_0 = (
        1e-10 * gs.array([1.0, -1.0, 1.0, 0.0, 0.0, 0.0])
        + gs.array([0.0, 0.0, 0.0, 1.0, 5.0, 2])
    )
    angle_close_pi_low = (
        (gs.pi - 1e-9) / gs.sqrt(2.0)
        * gs.array([0.0, 1.0, -1.0, 0.0, 0.0, 0.0])
        + gs.array([0.0, 0.0, 0.0, -100.0, 0.0, 2.0])
    )
    angle_pi = (
        gs.pi / gs.sqrt(3.0) * gs.array([1.0, 1.0, -1.0, 0.0, 0.0, 0.0])
        + gs.array([0.0, 0.0, 0.0, -10.2, 0.0, 2.6])
    )
    angle_close_pi_high = (
        (gs.pi + 1e-9) / gs.sqrt(3.0)
        * gs.array([-1.0, 1.0, -1.0, 0.0, 0.0, 0.0])
        + gs.array([0.0, 0.0, 0.0, -100.0, 0.0, 2.0])
    )
    angle_in_pi_2pi = (
        (gs.pi + 0.3) / gs.sqrt(5.0)
        * gs.array([-2.0, 1.0, 0.0, 0.0, 0.0, 0.0])
        + gs.array([0.0, 0.0, 0.0, -100.0, 0.0, 2.0])
    )
    angle_close_2pi_low = (
        (2 * gs.pi - 1e-9) / gs.sqrt(6.0)
        * gs.array([2.0, 1.0, -1.0, 0.0, 0.0, 0.0])
        + gs.array([0.0, 0.0, 0.0, 8.0, 555.0, -2.0])
    )
    angle_2pi = (
        2.0 * gs.pi / gs.sqrt(3.0)
        * gs.array([1.0, 1.0, -1.0, 0.0, 0.0, 0.0])
        + gs.array([0.0, 0.0, 0.0, 1.0, 8.0, -10.0])
    )
    angle_close_2pi_high = (
        (2.0 * gs.pi + 1e-9) / gs.sqrt(2.0)
        * gs.array([1.0, 0.0, -1.0, 0.0, 0.0, 0.0])
        + gs.array([0.0, 0.0, 0.0, 1.0, 8.0, -10.0])
    )

    point_1 = gs.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
    point_2 = gs.array([0.5, 0.0, -0.3, 0.4, 5.0, 60.0])

    translation_large = gs.array([0.0, 0.0, 0.0, 0.4, 0.5, 0.6])
    translation_small = gs.array([0.0, 0.0, 0.0, 0.5, 0.6, 0.7])
    rot_with_parallel_trans = gs.array([gs.pi / 3.0, 0.0, 0.0, 1.0, 0.0, 0.0])

    # Full catalogue of named elements.
    elements_all = {
        "angle_0": angle_0,
        "angle_close_0": angle_close_0,
        "angle_close_pi_low": angle_close_pi_low,
        "angle_pi": angle_pi,
        "angle_close_pi_high": angle_close_pi_high,
        "angle_in_pi_2pi": angle_in_pi_2pi,
        "angle_close_2pi_low": angle_close_2pi_low,
        "angle_2pi": angle_2pi,
        "angle_close_2pi_high": angle_close_2pi_high,
        "translation_large": translation_large,
        "translation_small": translation_small,
        "point_1": point_1,
        "point_2": point_2,
        "rot_with_parallel_trans": rot_with_parallel_trans,
    }
    # Tf is extremely slow, so it only gets a small subset.
    elements = (
        {
            "point_1": point_1,
            "point_2": point_2,
            "angle_close_pi_low": angle_close_pi_low,
        }
        if geomstats.tests.tf_backend()
        else elements_all
    )

    # Metrics - only diagonals
    diag_mat_at_identity = gs.eye(6) * gs.array([2.0, 2.0, 2.0, 3.0, 3.0, 3.0])

    left_diag_metric = InvariantMetric(
        group=group,
        metric_mat_at_identity=diag_mat_at_identity,
        left_or_right="left",
    )
    right_diag_metric = InvariantMetric(
        group=group,
        metric_mat_at_identity=diag_mat_at_identity,
        left_or_right="right",
    )

    metrics_all = {
        "left_canonical": group.left_canonical_metric,
        "right_canonical": group.right_canonical_metric,
        "left_diag": left_diag_metric,
        "right_diag": right_diag_metric,
    }
    # FIXME:
    # 'left': left_metric,
    # 'right': right_metric}
    metrics = (
        {"left_diag": left_diag_metric}
        if geomstats.tests.tf_backend()
        else metrics_all
    )

    # Names of the elements whose rotation angle is close to pi.
    angles_close_to_pi_all = [
        "angle_close_pi_low",
        "angle_pi",
        "angle_close_pi_high",
    ]
    angles_close_to_pi = (
        ["angle_close_pi_low"]
        if geomstats.tests.tf_backend()
        else angles_close_to_pi_all
    )

    def exp_after_log_right_with_angles_close_to_pi_test_data(self):
        """Return exp-after-log cases whose point has an angle close to pi.

        Every entry of ``metrics`` plus a fresh ``SpecialEuclidean(3,
        "vector")`` is paired with every base point of ``elements`` and every
        element named in ``angles_close_to_pi``.
        """
        candidates = list(self.metrics.values())
        candidates = candidates + [SpecialEuclidean(3, "vector")]
        smoke_data = [
            dict(
                metric=candidate,
                point=self.elements[angle_name],
                base_point=base_point,
            )
            for candidate in candidates
            for base_point in self.elements.values()
            for angle_name in self.angles_close_to_pi
        ]
        return self.generate_tests(smoke_data)

    def exp_after_log_test_data(self):
        """Return exp-after-log cases for points away from the pi angles.

        Same pairing as the "close to pi" variant, except that the points
        are the entries of ``elements`` whose name is *not* listed in
        ``angles_close_to_pi``.
        """
        candidates = list(self.metrics.values())
        candidates = candidates + [SpecialEuclidean(3, "vector")]
        smoke_data = [
            dict(metric=candidate, point=point, base_point=base_point)
            for candidate in candidates
            for base_point in self.elements.values()
            for name, point in self.elements.items()
            if name not in self.angles_close_to_pi
        ]
        return self.generate_tests(smoke_data)

    def log_after_exp_with_angles_close_to_pi_test_data(self):
        """Return log-after-exp cases whose tangent vector is close to pi.

        Every entry of ``metrics_all`` is paired with every base point of
        ``elements``; tangent vectors are looked up in ``elements_all`` for
        each name in ``angles_close_to_pi``.
        """
        smoke_data = [
            dict(
                metric=metric,
                tangent_vec=self.elements_all[angle_name],
                base_point=base_point,
            )
            for metric in self.metrics_all.values()
            for base_point in self.elements.values()
            for angle_name in self.angles_close_to_pi
        ]
        return self.generate_tests(smoke_data)

    def log_after_exp_test_data(self):
        """Return log-after-exp cases for the two left metrics.

        Only ``left_canonical`` and ``left_diag`` from ``metrics_all`` are
        used; tangent vectors are the entries of ``elements`` whose name is
        not listed in ``angles_close_to_pi``.
        """
        left_metrics = [
            self.metrics_all["left_canonical"],
            self.metrics_all["left_diag"],
        ]
        smoke_data = [
            dict(metric=metric, tangent_vec=tangent_vec, base_point=base_point)
            for metric in left_metrics
            for base_point in self.elements.values()
            for name, tangent_vec in self.elements.items()
            if name not in self.angles_close_to_pi
        ]
        return self.generate_tests(smoke_data)

    def exp_test_data(self):
        """Return two ``exp`` cases where everything is a pure translation.

        In both cases the rotation parts are zero and the expected result is
        the sum of the translation of the base point and the translation of
        the tangent vector.
        """
        base_point = gs.concatenate(
            [gs.array([0.0, 0.0, 0.0]), gs.array([4.0, -1.0, 10000.0])],
            axis=0,
        )
        # Tangent vector is a translation (no infinitesimal rotational part)
        tangent_vec = gs.concatenate(
            [gs.array([0.0, 0.0, 0.0]), gs.array([1.0, 0.0, -3.0])], axis=0
        )
        expected = gs.concatenate(
            [gs.array([0.0, 0.0, 0.0]), gs.array([5.0, -1.0, 9997.0])], axis=0
        )
        canonical_case = dict(
            metric=self.metrics_all["left_canonical"],
            tangent_vec=tangent_vec,
            base_point=base_point,
            expected=expected,
        )

        large = self.elements_all["translation_large"]
        small = self.elements_all["translation_small"]
        group_case = dict(
            metric=self.group,
            tangent_vec=small,
            base_point=large,
            expected=large + small,
        )
        return self.generate_tests([canonical_case, group_case])

    def log_test_data(self):
        """Return two ``log`` cases where everything is a pure translation.

        In both cases the rotation parts are zero and the expected result is
        the translation of the point minus the translation of the base
        point.
        """
        base_point = gs.concatenate(
            [gs.array([0.0, 0.0, 0.0]), gs.array([4.0, 0.0, 0.0])], axis=0
        )
        # Point is a translation (no rotational part)
        point = gs.concatenate(
            [gs.array([0.0, 0.0, 0.0]), gs.array([-1.0, -1.0, -1.2])], axis=0
        )
        expected = gs.concatenate(
            [gs.array([0.0, 0.0, 0.0]), gs.array([-5.0, -1.0, -1.2])], axis=0
        )
        canonical_case = dict(
            metric=self.metrics_all["left_canonical"],
            point=point,
            base_point=base_point,
            expected=expected,
        )

        large = self.elements_all["translation_large"]
        small = self.elements_all["translation_small"]
        group_case = dict(
            metric=self.group,
            point=large,
            base_point=small,
            expected=large - small,
        )
        return self.generate_tests([canonical_case, group_case])

    def regularize_extreme_cases_test_data(self):
        """Return ``regularize`` cases for angles near 0, pi and 2 * pi.

        The first three elements are expected to be returned unchanged. The
        remaining cases are skipped when ``geomstats.tests.tf_backend()`` is
        true; for them the expected rotation part is recomputed from the
        norm of the first three components while the translation part is
        kept as is.
        """

        def with_rotation(rotation_part, source_point):
            # Zero-pad the rotation and the translation of ``source_point``
            # to six components each, then add them.
            padded_rot = gs.concatenate([rotation_part, gs.zeros(3)], axis=0)
            padded_trans = gs.concatenate(
                [gs.zeros(3), source_point[3:6]], axis=0)
            return padded_rot + padded_trans

        unchanged = ["angle_close_0", "angle_close_pi_low", "angle_0"]
        smoke_data = [
            dict(point=self.elements_all[name], expected=self.elements_all[name])
            for name in unchanged
        ]

        if geomstats.tests.tf_backend():
            return self.generate_tests(smoke_data)

        # Exactly pi: expected unchanged.
        point = self.elements_all["angle_pi"]
        smoke_data.append(dict(point=point, expected=point))

        # Slightly above pi: the angle is shifted by -2 * pi.
        point = self.elements_all["angle_close_pi_high"]
        norm = gs.linalg.norm(point[:3])
        expected = with_rotation(point[:3] / norm * (norm - 2 * gs.pi), point)
        smoke_data.append(dict(point=point, expected=expected))

        # Strictly between pi and 2 * pi: opposite axis, angle reflected
        # around pi.
        for name in ("angle_in_pi_2pi", "angle_close_2pi_low"):
            point = self.elements_all[name]
            angle = gs.linalg.norm(point[:3])
            new_angle = gs.pi - (angle - gs.pi)
            expected = with_rotation(-new_angle * (point[:3] / angle), point)
            smoke_data.append(dict(point=point, expected=expected))

        # Exactly 2 * pi: the expected rotation part is zero.
        point = self.elements_all["angle_2pi"]
        expected = gs.concatenate([gs.zeros(3), point[3:6]], axis=0)
        smoke_data.append(dict(point=point, expected=expected))

        # Slightly above 2 * pi: the angle is shifted by -2 * pi.
        point = self.elements_all["angle_close_2pi_high"]
        angle = gs.linalg.norm(point[:3])
        new_angle = angle - 2 * gs.pi
        expected = with_rotation(new_angle * point[:3] / angle, point)
        smoke_data.append(dict(point=point, expected=expected))

        return self.generate_tests(smoke_data)
Пример #6
0
    def setup_method(self):
        """Prepare the groups, invariant metrics and sample points for tests.

        Disables the root logger, ignores ``ImportWarning``, seeds
        ``gs.random`` with 1234, and then stores on ``self``:

        - ``group``: ``SpecialEuclidean(n=3, point_type="vector")``;
        - ``matrix_se3`` and ``matrix_so3``: the groups built without a
          ``point_type`` argument;
        - left/right invariant metrics on ``group`` and on ``matrix_so3``;
        - sample points, including the matrices obtained from the first three
          components of ``point_1`` and ``point_2``.
        """
        logging.getLogger().disabled = True
        warnings.simplefilter("ignore", category=ImportWarning)

        gs.random.seed(1234)

        dim = 3
        se3_vec = SpecialEuclidean(n=dim, point_type="vector")
        se3_mat = SpecialEuclidean(n=dim)
        so3_mat = SpecialOrthogonal(n=dim)
        so3_vec = SpecialOrthogonal(n=dim, point_type="vector")

        # Diagonal left and right invariant metrics. The left one is built
        # with ``None`` as matrix, the right one with an explicit identity.
        diag_identity = gs.eye(se3_vec.dim)
        diag_left = InvariantMetric(
            group=se3_vec,
            metric_mat_at_identity=None,
            left_or_right="left",
        )
        diag_right = InvariantMetric(
            group=se3_vec,
            metric_mat_at_identity=diag_identity,
            left_or_right="right",
        )

        # Left and right invariant metrics from a second identity matrix.
        sym_identity = gs.eye(se3_vec.dim)
        general_left = InvariantMetric(
            group=se3_vec,
            metric_mat_at_identity=sym_identity,
            left_or_right="left",
        )
        general_right = InvariantMetric(
            group=se3_vec,
            metric_mat_at_identity=sym_identity,
            left_or_right="right",
        )

        # Metrics on the matrix rotation group.
        so3_left = InvariantMetric(group=so3_mat)
        so3_right = InvariantMetric(group=so3_mat, left_or_right="right")

        # General case for the point
        first_point = gs.array([-0.2, 0.9, 0.5, 5.0, 5.0, 5.0])
        second_point = gs.array([0.0, 2.0, -0.1, 30.0, 400.0, 2.0])
        first_matrix = so3_vec.matrix_from_rotation_vector(
            first_point[..., :3])
        second_matrix = so3_vec.matrix_from_rotation_vector(
            second_point[..., :3])
        # Edge case for the point, angle < epsilon,
        tiny_angle_point = gs.array([-1e-7, 0.0, -7 * 1e-8, 6.0, 5.0, 9.0])

        # Groups
        self.group = se3_vec
        self.matrix_so3 = so3_mat
        self.matrix_se3 = se3_mat

        # Metrics
        self.left_diag_metric = diag_left
        self.right_diag_metric = diag_right
        self.left_metric = general_left
        self.right_metric = general_right
        self.matrix_left_metric = so3_left
        self.matrix_right_metric = so3_right

        # Points
        self.point_1 = first_point
        self.point_2 = second_point
        self.point_1_matrix = first_matrix
        self.point_2_matrix = second_matrix
        self.point_small = tiny_angle_point
Пример #7
0
    def structure_constant_test_data(self):
        """Return structure-constant cases on ``self.matrix_so3``.

        The three tangent vectors come from ``normal_basis`` applied to the
        Lie algebra basis of the group. The first six cases cover every
        ordering of the three vectors, with expected value ``2.0**0.5 / 2.0``
        or its negative. The remaining cases repeat the same vector in the
        first two slots and expect ``0.0``.
        """
        so3 = self.matrix_so3
        basis_x, basis_y, basis_z = InvariantMetric(so3).normal_basis(
            so3.lie_algebra.basis)
        half_sqrt_two = 2.0**0.5 / 2.0

        # (tangent_vec_a, tangent_vec_b, tangent_vec_c, expected)
        signed_orderings = [
            (basis_x, basis_y, basis_z, half_sqrt_two),
            (basis_y, basis_x, basis_z, -half_sqrt_two),
            (basis_y, basis_z, basis_x, half_sqrt_two),
            (basis_z, basis_y, basis_x, -half_sqrt_two),
            (basis_z, basis_x, basis_y, half_sqrt_two),
            (basis_x, basis_z, basis_y, -half_sqrt_two),
        ]
        smoke_data = [
            dict(
                group=self.matrix_so3,
                tangent_vec_a=vec_a,
                tangent_vec_b=vec_b,
                tangent_vec_c=vec_c,
                expected=value,
            )
            for vec_a, vec_b, vec_c, value in signed_orderings
        ]

        # Same vector in the first two slots: expected value is zero.
        ordered_pairs = itertools.permutations(
            (basis_x, basis_y, basis_z), 2)
        smoke_data.extend(
            dict(
                group=self.matrix_so3,
                tangent_vec_a=repeated,
                tangent_vec_b=repeated,
                tangent_vec_c=other,
                expected=0.0,
            )
            for repeated, other in ordered_pairs
        )

        return self.generate_tests(smoke_data)