Exemplo n.º 1
0
    def test_cholesky_factor(self, n, spd_mat, cf):
        """Check the Cholesky factor of ``spd_mat`` against ``cf``.

        The computed factor must match the expected factor ``cf`` and every
        factor must belong to the space of positive lower triangular
        ``n x n`` matrices.
        """
        factor = SPDMatrices.cholesky_factor(gs.array(spd_mat))
        expected_factor = gs.array(cf)
        self.assertAllClose(factor, expected_factor)

        space = PositiveLowerTriangularMatrices(n)
        in_space = gs.all(space.belongs(factor))
        self.assertAllClose(in_space, gs.array(True))
Exemplo n.º 2
0
    def differential_cholesky_factor(cls, tangent_vec, base_point):
        """Compute the differential of the Cholesky factor map.

        The differential at ``base_point`` is obtained by applying the
        inverse differential of the Gram map, evaluated at the Cholesky
        factor of ``base_point``, to ``tangent_vec``.

        Parameters
        ----------
        tangent_vec : array_like, shape=[..., n, n]
            Tangent vector at base point.
            Symmetric matrix.
        base_point : array_like, shape=[..., n, n]
            Base point.
            SPD matrix.

        Returns
        -------
        differential_cf : array_like, shape=[..., n, n]
            Differential of the Cholesky factor map.
            Lower triangular matrix.
        """
        factor = cls.cholesky_factor(base_point)
        return PositiveLowerTriangularMatrices.inverse_differential_gram(
            tangent_vec, factor
        )
Exemplo n.º 3
0
 def test_cholesky_factor_belongs(self, n, mat):
     """Check that the Cholesky factor of ``mat`` is positive lower triangular."""
     factor = SPDMatrices(n).cholesky_factor(gs.array(mat))
     space = PositiveLowerTriangularMatrices(n)
     self.assertAllClose(gs.all(space.belongs(factor)), True)
class CholeskyMetricTestData(_RiemannianMetricTestData):
    """Test data for the Cholesky metric on positive lower triangular matrices.

    The ``*_list`` class attributes parametrise the randomised property
    tests and are passed positionally to the ``_..._test_data`` helpers
    inherited from ``_RiemannianMetricTestData``. The ``random.sample``
    calls run once, at class-definition time, in the order written here.

    NOTE(review): ``n_points_b_list`` holds a single entry while the other
    lists hold two. If the base-class helpers zip their list arguments
    together, the pairwise distance tests only exercise the first sampled
    dimension -- confirm whether this is intentional.
    """

    n_list = random.sample(range(2, 5), 2)
    metric_args_list = [(dim,) for dim in n_list]
    shape_list = [(dim, dim) for dim in n_list]
    space_list = [PositiveLowerTriangularMatrices(dim) for dim in n_list]
    n_points_list = random.sample(range(1, 5), 2)
    n_tangent_vecs_list = random.sample(range(1, 5), 2)
    n_points_a_list = random.sample(range(1, 5), 2)
    n_points_b_list = [1]
    batch_size_list = random.sample(range(2, 5), 2)
    alpha_list = [1, 1]
    n_rungs_list = [1, 1]
    scheme_list = ["pole", "pole"]

    def diag_inner_product_test_data(self):
        """Smoke case for the diagonal part of the inner product."""
        case = {
            "n": 2,
            "tangent_vec_a": [[1.0, 0.0], [-2.0, -1.0]],
            "tangent_vec_b": [[2.0, 0.0], [-3.0, -1.0]],
            "base_point": [[SQRT_2, 0.0], [-3.0, 1.0]],
            "expected": 2.0,
        }
        return self.generate_tests([case])

    def strictly_lower_inner_product_test_data(self):
        """Smoke case for the strictly lower triangular inner product."""
        case = {
            "n": 2,
            "tangent_vec_a": [[1.0, 0.0], [-2.0, -1.0]],
            "tangent_vec_b": [[2.0, 0.0], [-3.0, -1.0]],
            "expected": 6.0,
        }
        return self.generate_tests([case])

    def inner_product_test_data(self):
        """Smoke cases for the inner product: one single, one batched."""
        single = {
            "n": 2,
            "tangent_vec_a": [[1.0, 0.0], [-2.0, -1.0]],
            "tangent_vec_b": [[2.0, 0.0], [-3.0, -1.0]],
            "base_point": [[SQRT_2, 0.0], [-3.0, 1.0]],
            "expected": 8.0,
        }
        batched = {
            "n": 2,
            "tangent_vec_a": [
                [[3.0, 0.0], [4.0, 2.0]],
                [[-1.0, 0.0], [2.0, -4.0]],
            ],
            "tangent_vec_b": [
                [[4.0, 0.0], [3.0, 3.0]],
                [[3.0, 0.0], [-6.0, 2.0]],
            ],
            "base_point": [
                [[3, 0.0], [-2.0, 6.0]],
                [[1, 0.0], [-1.0, 1.0]],
            ],
            "expected": [13.5, -23.0],
        }
        return self.generate_tests([single, batched])

    def exp_test_data(self):
        """Smoke cases for the Riemannian exponential."""
        single = {
            "n": 2,
            "tangent_vec": [[-1.0, 0.0], [2.0, 3.0]],
            "base_point": [[1.0, 0.0], [2.0, 2.0]],
            "expected": [[1 / EULER, 0.0], [4.0, 2 * gs.exp(1.5)]],
        }
        batched = {
            "n": 2,
            "tangent_vec": [
                [[0.0, 0.0], [2.0, 0.0]],
                [[1.0, 0.0], [0.0, 0.0]],
            ],
            "base_point": [
                [[1.0, 0.0], [2.0, 2.0]],
                [[1.0, 0.0], [0.0, 2.0]],
            ],
            "expected": [
                [[1.0, 0.0], [4.0, 2.0]],
                [[gs.exp(1.0), 0.0], [0.0, 2.0]],
            ],
        }
        return self.generate_tests([single, batched])

    def log_test_data(self):
        """Smoke cases for the Riemannian logarithm."""
        single = {
            "n": 2,
            "point": [[EULER, 0.0], [2.0, EULER**3]],
            "base_point": [[EULER**3, 0.0], [4.0, EULER**4]],
            "expected": [[-2.0 * EULER**3, 0.0], [-2.0, -1 * EULER**4]],
        }
        batched = {
            "n": 2,
            "point": [
                [[gs.exp(-2.0), 0.0], [0.0, gs.exp(2.0)]],
                [[gs.exp(-3.0), 0.0], [2.0, gs.exp(3.0)]],
            ],
            "base_point": [
                [[1.0, 0.0], [-1.0, 1.0]],
                [[1.0, 0.0], [0.0, 1.0]],
            ],
            "expected": [
                [[-2.0, 0.0], [1.0, 2.0]],
                [[-3.0, 0.0], [2.0, 3.0]],
            ],
        }
        return self.generate_tests([single, batched])

    def squared_dist_test_data(self):
        """Smoke cases for the squared geodesic distance."""
        single = {
            "n": 2,
            "point_a": [[EULER, 0.0], [2.0, EULER**3]],
            "point_b": [[EULER**3, 0.0], [4.0, EULER**4]],
            "expected": 9,
        }
        batched = {
            "n": 2,
            "point_a": [
                [[EULER, 0.0], [2.0, EULER**3]],
                [[EULER, 0.0], [4.0, EULER**3]],
            ],
            "point_b": [
                [[EULER**3, 0.0], [4.0, EULER**4]],
                [[EULER**3, 0.0], [7.0, EULER**4]],
            ],
            "expected": [9, 14],
        }
        return self.generate_tests([single, batched])

    def exp_shape_test_data(self):
        """Random cases checking the shape returned by ``exp``."""
        args = (self.metric_args_list, self.space_list, self.shape_list)
        return self._exp_shape_test_data(*args)

    def log_shape_test_data(self):
        """Random cases checking the shape returned by ``log``."""
        args = (self.metric_args_list, self.space_list)
        return self._log_shape_test_data(*args)

    def squared_dist_is_symmetric_test_data(self):
        """Random cases checking that the squared distance is symmetric."""
        args = (
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )
        return self._squared_dist_is_symmetric_test_data(
            *args, atol=gs.atol * 1000
        )

    def exp_belongs_test_data(self):
        """Random cases checking that ``exp`` lands on the manifold."""
        args = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )
        return self._exp_belongs_test_data(*args, belongs_atol=gs.atol * 1000)

    def log_is_tangent_test_data(self):
        """Random cases checking that ``log`` returns tangent vectors."""
        args = (self.metric_args_list, self.space_list, self.n_points_list)
        return self._log_is_tangent_test_data(
            *args, is_tangent_atol=gs.atol * 1000
        )

    def geodesic_ivp_belongs_test_data(self):
        """Random cases checking initial-value geodesics stay on the manifold."""
        args = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_points_list,
        )
        return self._geodesic_ivp_belongs_test_data(
            *args, belongs_atol=gs.atol * 1000
        )

    def geodesic_bvp_belongs_test_data(self):
        """Random cases checking boundary-value geodesics stay on the manifold."""
        args = (self.metric_args_list, self.space_list, self.n_points_list)
        return self._geodesic_bvp_belongs_test_data(
            *args, belongs_atol=gs.atol * 1000
        )

    def exp_after_log_test_data(self):
        """Random cases checking that ``exp`` inverts ``log``."""
        args = (self.metric_args_list, self.space_list, self.n_points_list)
        return self._exp_after_log_test_data(
            *args, rtol=gs.rtol * 100, atol=gs.atol * 10000
        )

    def log_after_exp_test_data(self):
        """Random cases checking that ``log`` inverts ``exp``."""
        args = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )
        return self._log_after_exp_test_data(
            *args, rtol=gs.rtol * 100, atol=gs.atol * 10000
        )

    def exp_ladder_parallel_transport_test_data(self):
        """Random cases for exp after ladder parallel transport (pole scheme)."""
        args = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            self.n_rungs_list,
            self.alpha_list,
            self.scheme_list,
        )
        return self._exp_ladder_parallel_transport_test_data(*args)

    def exp_geodesic_ivp_test_data(self):
        """Random cases comparing ``exp`` with the initial-value geodesic."""
        args = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            self.n_points_list,
        )
        return self._exp_geodesic_ivp_test_data(
            *args, rtol=gs.rtol * 100000, atol=gs.atol * 100000
        )

    def parallel_transport_ivp_is_isometry_test_data(self):
        """Random cases checking initial-value parallel transport is isometric."""
        args = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )
        return self._parallel_transport_ivp_is_isometry_test_data(
            *args, is_tangent_atol=gs.atol * 1000, atol=gs.atol * 1000
        )

    def parallel_transport_bvp_is_isometry_test_data(self):
        """Random cases checking boundary-value parallel transport is isometric."""
        args = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )
        return self._parallel_transport_bvp_is_isometry_test_data(
            *args, is_tangent_atol=gs.atol * 1000, atol=gs.atol * 1000
        )

    def dist_is_symmetric_test_data(self):
        """Random cases checking that the distance is symmetric."""
        args = (
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )
        return self._dist_is_symmetric_test_data(*args)

    def dist_is_positive_test_data(self):
        """Random cases checking that the distance is positive."""
        args = (
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )
        return self._dist_is_positive_test_data(*args)

    def squared_dist_is_positive_test_data(self):
        """Random cases checking that the squared distance is positive."""
        args = (
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )
        return self._squared_dist_is_positive_test_data(*args)

    def dist_is_norm_of_log_test_data(self):
        """Random cases checking that the distance equals the norm of ``log``."""
        args = (
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )
        return self._dist_is_norm_of_log_test_data(*args)

    def dist_point_to_itself_is_zero_test_data(self):
        """Random cases checking that a point is at distance zero from itself."""
        args = (self.metric_args_list, self.space_list, self.n_points_list)
        return self._dist_point_to_itself_is_zero_test_data(*args)

    def inner_product_is_symmetric_test_data(self):
        """Random cases checking that the inner product is symmetric."""
        args = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )
        return self._inner_product_is_symmetric_test_data(*args)

    def retraction_lifting_test_data(self):
        """Random cases for the retraction/lifting test.

        Reuses the log-after-exp data generator with the same arguments and
        tolerances as ``log_after_exp_test_data``.
        """
        args = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )
        return self._log_after_exp_test_data(
            *args, rtol=gs.rtol * 100, atol=gs.atol * 10000
        )