예제 #1
0
    def setUp(self):
        """Prepare the Stiefel fixtures shared by the tests.

        The tangent vectors are assembled as
        ``point_a @ A + point_perp @ B`` with ``A`` skew-symmetric,
        following:
        http://noodle.med.yale.edu/hdtag/notes/steifel_notes.pdf
        """
        warnings.filterwarnings('ignore')
        gs.random.seed(1234)

        # Manifold of n x p matrices (see ``Stiefel(self.n, self.p)``).
        self.p = 3
        self.n = 4
        self.space = Stiefel(self.n, self.p)
        self.n_samples = 10

        ambient_size = self.p * self.n
        constraint_count = self.p * (self.p + 1) / 2
        self.dimension = int(ambient_size - constraint_count)

        # First reference point: the first three columns of the identity.
        self.point_a = gs.array([
            [1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.],
            [0., 0., 0.]])

        # Second reference point: first column spread over rows 0 and 3.
        inv_sqrt_two = 1. / gs.sqrt(2.)
        self.point_b = gs.array([
            [inv_sqrt_two, 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.],
            [inv_sqrt_two, 0., 0.]])

        # Unit column orthogonal to every column of ``point_a``.
        orthogonal_column = gs.array([
            [0.],
            [0.],
            [0.],
            [1.]])

        skew_first = gs.array([
            [0., 2., -5.],
            [-2., 0., -1.],
            [5., 1., 0.]])
        row_first = gs.array([
            [-2., 1., 4.]])

        skew_second = gs.array([
            [0., 2., -5.],
            [-2., 0., -1.],
            [5., 1., 0.]])
        row_second = gs.array([
            [-2., 1., 4.]])

        along_point_first = gs.matmul(self.point_a, skew_first)
        along_perp_first = gs.matmul(orthogonal_column, row_first)
        self.tangent_vector_1 = along_point_first + along_perp_first

        along_point_second = gs.matmul(self.point_a, skew_second)
        along_perp_second = gs.matmul(orthogonal_column, row_second)
        self.tangent_vector_2 = along_point_second + along_perp_second

        self.metric = self.space.canonical_metric
예제 #2
0
 def test_stiefel_n_samples(self):
     """Check that the fitted Frechet mean lies on the manifold.

     Two random points of ``Stiefel(3, 2)`` are fitted with the
     ``"default"`` method, a learning rate of 0.5 and verbose output.
     """
     stiefel = Stiefel(3, 2)
     stiefel_metric = stiefel.metric
     samples = stiefel.random_point(2)

     estimator = FrechetMean(
         stiefel_metric, lr=0.5, verbose=True, method="default")
     estimator.fit(samples)

     self.assertTrue(stiefel.belongs(estimator.estimate_))
예제 #3
0
 def test_stiefel_two_samples(self):
     """Compare the Frechet mean of two points with a half-log step.

     The reference value is obtained by taking the log of the first
     sample at the second one, halving it, and mapping it back with
     the exponential at the second sample.
     """
     stiefel = Stiefel(3, 2)
     stiefel_metric = stiefel.metric
     samples = stiefel.random_point(2)

     estimator = FrechetMean(stiefel_metric)
     estimator.fit(samples)
     estimate = estimator.estimate_

     first = samples[0]
     second = samples[1]
     half_log = stiefel_metric.log(first, second) / 2
     halfway = stiefel_metric.exp(half_log, second)

     self.assertAllClose(halfway, estimate)
예제 #4
0
    def test_to_grassmanniann_vectorized(self):
        """Check ``Stiefel.to_grassmannian`` on a batch of three points.

        The batch is built by applying ``GeneralLinear.exp`` of
        ``pi * r_z / k`` for ``k`` in 2, 3, 4 to ``point1``; every
        element is expected to map to ``p_xy``.  ``r_z``, ``point1``
        and ``p_xy`` are module-level fixtures defined outside this
        method.
        """
        divisors = [2, 3, 4]
        generators = gs.array([gs.pi * r_z / k for k in divisors])
        transforms = GeneralLinear.exp(generators)
        frames = Matrices.mul(transforms, point1)

        projectors = Stiefel.to_grassmannian(frames)
        reference = gs.array([p_xy] * 3)
        self.assertAllClose(projectors, reference)
예제 #5
0
    def log_two_sheets_error_test_data(self):
        """Build one random case where ``ValueError`` is expected.

        Two random points of ``Stiefel(n=3, p=3)`` are drawn.  When the
        product of their determinants is positive, the second point is
        negated in place, which flips the sign of its 3 x 3 determinant.
        The case is passed to ``generate_tests`` as random data with
        ``pytest.raises(ValueError)`` as the expected outcome.
        """
        size = 3
        manifold = Stiefel(n=size, p=size)

        # Draw order matters for reproducibility: base point first.
        origin = manifold.random_point()
        det_origin = gs.linalg.det(origin)
        target = manifold.random_point()
        det_target = gs.linalg.det(target)

        same_sign = gs.all(det_origin * det_target > 0.0)
        if same_sign:
            target *= -1.0

        case = {
            "n": size,
            "p": size,
            "point": target,
            "base_point": origin,
            "expected": pytest.raises(ValueError),
        }
        return self.generate_tests([], [case])
예제 #6
0
class TestStiefelMethods(geomstats.tests.TestCase):
    """Tests for the ``Stiefel`` space and its canonical metric.

    ``setUp`` builds a ``Stiefel(4, 3)`` space, two reference points and
    two tangent vectors at the first reference point; the tests then
    check membership, output shapes and round trips of exp/log and
    retraction/lifting.

    The ``geomstats.tests`` decorators (``np_only``, ``np_and_tf_only``,
    ``np_and_pytorch_only``) presumably restrict each test to the named
    backends -- confirm in ``geomstats.tests``.  ``ATOL``, ``p_xy``,
    ``r_z`` and ``point1`` are module-level names defined outside this
    class.
    """

    def setUp(self):
        """Create the space, reference points and tangent vectors.

        Tangent vectors are constructed as
        ``point_a @ A + point_perp @ B`` following:
        http://noodle.med.yale.edu/hdtag/notes/steifel_notes.pdf
        """
        warnings.filterwarnings('ignore')

        gs.random.seed(1234)

        self.p = 3
        self.n = 4
        self.space = Stiefel(self.n, self.p)
        self.n_samples = 10
        # n * p - p * (p + 1) / 2, truncated to an int (6 for n=4, p=3).
        self.dimension = int(
            self.p * self.n - (self.p * (self.p + 1) / 2))

        # First three columns of the 4 x 4 identity.
        self.point_a = gs.array([
            [1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.],
            [0., 0., 0.]])

        # Same as point_a except that the first column is spread evenly
        # over rows 0 and 3.
        self.point_b = gs.array([
            [1. / gs.sqrt(2.), 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.],
            [1. / gs.sqrt(2.), 0., 0.]])

        # Unit column orthogonal to every column of point_a.
        point_perp = gs.array([
            [0.],
            [0.],
            [0.],
            [1.]])

        # Skew-symmetric 3 x 3 block multiplied by point_a below.
        matrix_a_1 = gs.array([
            [0., 2., -5.],
            [-2., 0., -1.],
            [5., 1., 0.]])

        # 1 x 3 block multiplied by point_perp below.
        matrix_b_1 = gs.array([
            [-2., 1., 4.]])

        # NOTE(review): matrix_a_2 / matrix_b_2 hold exactly the same
        # values as matrix_a_1 / matrix_b_1, so tangent_vector_2 is
        # numerically equal to tangent_vector_1 -- confirm this is
        # intended.
        matrix_a_2 = gs.array([
            [0., 2., -5.],
            [-2., 0., -1.],
            [5., 1., 0.]])

        matrix_b_2 = gs.array([
            [-2., 1., 4.]])

        self.tangent_vector_1 = (
            gs.matmul(self.point_a, matrix_a_1)
            + gs.matmul(point_perp, matrix_b_1))

        self.tangent_vector_2 = (
            gs.matmul(self.point_a, matrix_a_2)
            + gs.matmul(point_perp, matrix_b_2))

        self.metric = self.space.canonical_metric

    @geomstats.tests.np_and_tf_only
    def test_belongs_shape(self):
        """Check that ``belongs`` on one random point has shape ``()``."""
        point = self.space.random_uniform()
        belongs = self.space.belongs(point)

        self.assertAllClose(gs.shape(belongs), ())

    @geomstats.tests.np_and_tf_only
    def test_random_uniform_and_belongs(self):
        """Check that a ``random_uniform`` point belongs to the space.

        Membership is tested with a tolerance of 1e-4.
        """
        point = self.space.random_uniform()
        result = self.space.belongs(point, tolerance=1e-4)
        expected = True

        self.assertAllClose(result, expected)

    @geomstats.tests.np_and_tf_only
    def test_random_uniform_shape(self):
        """Check that one ``random_uniform`` point has shape ``(n, p)``."""
        result = self.space.random_uniform()

        self.assertAllClose(gs.shape(result), (self.n, self.p))

    @geomstats.tests.np_only
    def test_log_and_exp(self):
        """
        Test that the Riemannian exponential
        and the Riemannian logarithm are inverse.

        Expect their composition to give the identity function:
        the exponential of ``log(point_b)`` at ``point_a`` must be
        ``point_b`` up to ``ATOL``.
        """
        # Riemannian Log then Riemannian Exp
        # General case
        base_point = self.point_a
        point = self.point_b

        log = self.metric.log(point=point, base_point=base_point)
        result = self.metric.exp(tangent_vec=log, base_point=base_point)
        expected = point

        self.assertAllClose(result, expected, atol=ATOL)

    @geomstats.tests.np_and_tf_only
    def test_exp_and_belongs(self):
        """Check that the exponential of a tangent vector is on the space."""
        base_point = self.point_a
        tangent_vec = self.tangent_vector_1

        exp = self.metric.exp(
            tangent_vec=tangent_vec,
            base_point=base_point)
        result = self.space.belongs(exp)
        expected = True
        self.assertAllClose(result, expected)

    @geomstats.tests.np_and_tf_only
    def test_exp_vectorization_shape(self):
        """Check the output shape of ``exp`` for single and batched inputs.

        Covers one/one, many tangent vectors with one base point, and one
        tangent vector with many base points.  The many/many combination
        is not exercised here.
        """
        n_samples = self.n_samples
        n = self.n
        p = self.p

        one_base_point = self.point_a
        one_tangent_vec = self.tangent_vector_1

        # Batches are n_samples copies of a single point / tangent vector.
        n_base_points = gs.tile(
            gs.to_ndarray(self.point_a, to_ndim=3),
            (n_samples, 1, 1))
        n_tangent_vecs = gs.tile(
            gs.to_ndarray(self.tangent_vector_2, to_ndim=3),
            (n_samples, 1, 1))

        # With single tangent vec and base point
        result = self.metric.exp(one_tangent_vec, one_base_point)
        self.assertAllClose(gs.shape(result), (n, p))

        # With n_samples tangent vecs and base points
        result = self.metric.exp(n_tangent_vecs, one_base_point)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.exp(one_tangent_vec, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

    @geomstats.tests.np_and_tf_only
    def test_log_vectorization_shape(self):
        """Check the output shape of ``log`` for single and batched inputs.

        All four combinations of one/many points and one/many base
        points are exercised, using ``random_uniform`` samples.
        """
        n_samples = self.n_samples
        n = self.n
        p = self.p

        one_point = self.space.random_uniform()
        one_base_point = self.space.random_uniform()

        n_points = self.space.random_uniform(n_samples=n_samples)
        n_base_points = self.space.random_uniform(n_samples=n_samples)

        # With single point and base point
        result = self.metric.log(one_point, one_base_point)
        self.assertAllClose(gs.shape(result), (n, p))

        # With multiple points and base points
        result = self.metric.log(n_points, one_base_point)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.log(one_point, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.log(n_points, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

    @geomstats.tests.np_only
    def test_retractation_and_lifting(self):
        """
        Test that the retraction and the lifting are inverse.

        Expect their composition, in both orders, to give the
        identity function up to ``ATOL``.
        """
        # Lifting then retraction
        # General case
        base_point = self.point_a
        point = self.point_b
        tangent_vec = self.tangent_vector_1

        lifted = self.metric.lifting(point=point, base_point=base_point)
        result = self.metric.retraction(
            tangent_vec=lifted, base_point=base_point)
        expected = point

        self.assertAllClose(result, expected, atol=ATOL)

        # Retraction then lifting
        retract = self.metric.retraction(
            tangent_vec=tangent_vec, base_point=base_point)
        result = self.metric.lifting(point=retract, base_point=base_point)
        expected = tangent_vec

        self.assertAllClose(result, expected, atol=ATOL)

    @geomstats.tests.np_only
    def test_lifting_vectorization_shape(self):
        """Check the output shape of ``lifting`` for single/batched inputs.

        All four combinations of one/many points and one/many base
        points are exercised; batches are tiled copies of the reference
        points.
        """
        n_samples = self.n_samples
        n = self.n
        p = self.p

        one_point = self.point_a
        one_base_point = self.point_b
        n_points = gs.tile(
            gs.to_ndarray(self.point_a, to_ndim=3),
            (n_samples, 1, 1))
        n_base_points = gs.tile(
            gs.to_ndarray(self.point_b, to_ndim=3),
            (n_samples, 1, 1))

        result = self.metric.lifting(one_point, one_base_point)
        self.assertAllClose(gs.shape(result), (n, p))

        result = self.metric.lifting(n_points, one_base_point)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.lifting(one_point, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.lifting(n_points, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

    @geomstats.tests.np_and_tf_only
    def test_retraction_vectorization_shape(self):
        """Check the output shape of ``retraction`` for single/batched inputs.

        All four combinations of one/many tangent vectors and one/many
        base points are exercised; batches are tiled copies of the
        fixtures.
        """
        n_samples = self.n_samples
        n = self.n
        p = self.p

        one_point = self.point_a
        n_points = gs.tile(
            gs.to_ndarray(one_point, to_ndim=3),
            (n_samples, 1, 1))
        one_tangent_vec = self.tangent_vector_1
        n_tangent_vecs = gs.tile(
            gs.to_ndarray(self.tangent_vector_2, to_ndim=3),
            (n_samples, 1, 1))

        result = self.metric.retraction(one_tangent_vec, one_point)
        self.assertAllClose(gs.shape(result), (n, p))

        result = self.metric.retraction(n_tangent_vecs, one_point)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.retraction(one_tangent_vec, n_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.retraction(n_tangent_vecs, n_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

    def test_inner_product(self):
        """Check that the inner product of two tangent vectors is a scalar.

        Only the shape ``()`` of the result is asserted, not its value.
        """
        base_point = self.point_a
        tangent_vector_1 = self.tangent_vector_1
        tangent_vector_2 = self.tangent_vector_2

        result = self.metric.inner_product(
            tangent_vector_1,
            tangent_vector_2,
            base_point=base_point)
        self.assertAllClose(gs.shape(result), ())

    @geomstats.tests.np_and_pytorch_only
    def test_to_grassmannian(self):
        """Check ``to_grassmannian`` on a single point against ``p_xy``."""
        # NOTE(review): point2 has shape (3, 2) whereas self.space was
        # built with n=4, p=3; the call relies on to_grassmannian not
        # depending on the instance dimensions -- confirm.
        point2 = gs.array([[1., -1.], [1., 1.], [0., 0.]]) / gs.sqrt(2)
        result = self.space.to_grassmannian(point2)
        expected = p_xy
        self.assertAllClose(result, expected)

    @geomstats.tests.np_only
    def test_to_grassmanniann_vectorized(self):
        """Check ``Stiefel.to_grassmannian`` on a batch of three points.

        The batch is ``point1`` multiplied by ``GeneralLinear.exp`` of
        ``pi * r_z / n`` for ``n`` in 2, 3, 4; each element is expected
        to map to ``p_xy``.
        """
        inf_rots = gs.array([gs.pi * r_z / n for n in [2, 3, 4]])
        rots = GeneralLinear.exp(inf_rots)
        points = Matrices.mul(rots, point1)

        result = Stiefel.to_grassmannian(points)
        expected = gs.array([p_xy, p_xy, p_xy])
        self.assertAllClose(result, expected)
예제 #7
0
class StiefelCanonicalMetricTestData(_RiemannianMetricTestData):
    """Test-data provider for the canonical metric on ``Stiefel``.

    Class attributes hold randomly chosen ``(n, p)`` configurations and
    sample counts, evaluated once when the class body runs.  Each
    ``*_test_data`` method forwards them, together with per-test
    tolerances, to the matching ``_*_test_data`` generator that is not
    defined in this class (presumably inherited from
    ``_RiemannianMetricTestData`` -- confirm).
    """

    # Both values of range(3, 5), i.e. 3 and 4, in random order
    # (assuming ``random`` is the standard-library module).
    n_list = random.sample(range(3, 5), 2)
    # One p per n, drawn from range(2, n), so that 2 <= p < n.
    p_list = [random.sample(range(2, n), 1)[0] for n in n_list]
    # (n, p) pairs, passed as the metric arguments.
    metric_args_list = list(zip(n_list, p_list))
    # Same list object as metric_args_list: the (n, p) pairs are also
    # passed as the shapes.
    shape_list = metric_args_list
    space_list = [Stiefel(n, p) for n, p in metric_args_list]
    # Two distinct sample counts each, drawn from 1..4.
    n_points_list = random.sample(range(1, 5), 2)
    n_points_a_list = random.sample(range(1, 5), 2)
    n_points_b_list = [1]
    n_tangent_vecs_list = random.sample(range(1, 5), 2)
    # Parameters forwarded to the ladder parallel transport data below.
    alpha_list = [1] * 2
    n_rungs_list = [1] * 2
    scheme_list = ["pole"] * 2

    def log_two_sheets_error_test_data(self):
        """Build one random case where ``ValueError`` is expected.

        Two random points of ``Stiefel(n=3, p=3)`` are drawn; when the
        product of their determinants is positive, ``point`` is negated
        in place, which flips the sign of its 3 x 3 determinant.
        """
        stiefel = Stiefel(n=3, p=3)
        base_point = stiefel.random_point()
        det_base = gs.linalg.det(base_point)
        point = stiefel.random_point()
        det_point = gs.linalg.det(point)
        if gs.all(det_base * det_point > 0.0):
            point *= -1.0

        random_data = [
            dict(
                n=3,
                p=3,
                point=point,
                base_point=base_point,
                expected=pytest.raises(ValueError),
            )
        ]
        # No fixed ("smoke") cases, only the random one.
        return self.generate_tests([], random_data)

    def exp_shape_test_data(self):
        """Return the data built by ``_exp_shape_test_data``."""
        return self._exp_shape_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
        )

    def log_shape_test_data(self):
        """Return the data built by ``_log_shape_test_data``."""
        return self._log_shape_test_data(
            self.metric_args_list,
            self.space_list,
        )

    def squared_dist_is_symmetric_test_data(self):
        """Return the data built by ``_squared_dist_is_symmetric_test_data``.

        The absolute tolerance is ``gs.atol`` scaled by 1000.
        """
        return self._squared_dist_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
            atol=gs.atol * 1000,
        )

    def exp_belongs_test_data(self):
        """Return the data built by ``_exp_belongs_test_data``.

        The membership tolerance is ``gs.atol`` scaled by 10000.
        """
        return self._exp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            belongs_atol=gs.atol * 10000,
        )

    def log_is_tangent_test_data(self):
        """Return the data built by ``_log_is_tangent_test_data``.

        The tangency tolerance is ``gs.atol`` scaled by 1000.
        """
        return self._log_is_tangent_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            is_tangent_atol=gs.atol * 1000,
        )

    def geodesic_ivp_belongs_test_data(self):
        """Return the data built by ``_geodesic_ivp_belongs_test_data``.

        The membership tolerance is ``gs.atol`` scaled by 1000.
        """
        return self._geodesic_ivp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_points_list,
            belongs_atol=gs.atol * 1000,
        )

    def geodesic_bvp_belongs_test_data(self):
        """Return the data built by ``_geodesic_bvp_belongs_test_data``.

        The membership tolerance is ``gs.atol`` scaled by 1000.
        """
        return self._geodesic_bvp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            belongs_atol=gs.atol * 1000,
        )

    def exp_after_log_test_data(self):
        """Return the data built by ``_exp_after_log_test_data``.

        Tolerances are ``gs.rtol * 100`` and ``gs.atol * 10000``.
        """
        return self._exp_after_log_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            rtol=gs.rtol * 100,
            atol=gs.atol * 10000,
        )

    def log_after_exp_test_data(self):
        """Return the data built by ``_log_after_exp_test_data``.

        Tolerances are ``gs.rtol * 100`` and ``gs.atol * 10000``.
        """
        return self._log_after_exp_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            rtol=gs.rtol * 100,
            atol=gs.atol * 10000,
        )

    def exp_ladder_parallel_transport_test_data(self):
        """Return the data built by ``_exp_ladder_parallel_transport_test_data``.

        Uses one rung, ``alpha`` of 1 and the ``"pole"`` scheme for each
        configuration, with a fixed absolute tolerance of 1e-1.
        """
        return self._exp_ladder_parallel_transport_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            self.n_rungs_list,
            self.alpha_list,
            self.scheme_list,
            atol=1e-1,
        )

    def exp_geodesic_ivp_test_data(self):
        """Return the data built by ``_exp_geodesic_ivp_test_data``.

        Tolerances are ``gs.rtol * 1000`` and ``gs.atol * 1000``.
        """
        return self._exp_geodesic_ivp_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            self.n_points_list,
            rtol=gs.rtol * 1000,
            atol=gs.atol * 1000,
        )

    def dist_is_symmetric_test_data(self):
        """Return the data built by ``_dist_is_symmetric_test_data``.

        The absolute tolerance is ``gs.atol`` scaled by 1000.
        """
        return self._dist_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
            atol=gs.atol * 1000,
        )

    def dist_is_positive_test_data(self):
        """Return the data built by ``_dist_is_positive_test_data``."""
        return self._dist_is_positive_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def squared_dist_is_positive_test_data(self):
        """Return the data built by ``_squared_dist_is_positive_test_data``."""
        return self._squared_dist_is_positive_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def dist_is_norm_of_log_test_data(self):
        """Return the data built by ``_dist_is_norm_of_log_test_data``."""
        return self._dist_is_norm_of_log_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def dist_point_to_itself_is_zero_test_data(self):
        """Return the data built by ``_dist_point_to_itself_is_zero_test_data``.

        The absolute tolerance is ``gs.atol`` scaled by 1000.
        """
        return self._dist_point_to_itself_is_zero_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            atol=gs.atol * 1000,
        )

    def inner_product_is_symmetric_test_data(self):
        """Return the data built by ``_inner_product_is_symmetric_test_data``."""
        return self._inner_product_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )

    def triangle_inequality_of_dist_test_data(self):
        """Return the data built by ``_triangle_inequality_of_dist_test_data``.

        Uses a fixed absolute tolerance of 1e-3.
        """
        return self._triangle_inequality_of_dist_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            atol=1e-3,
        )

    def retraction_lifting_test_data(self):
        """Return data for the retraction/lifting round-trip test.

        Reuses the ``_log_after_exp_test_data`` generator with the same
        arguments as ``log_after_exp_test_data`` -- presumably because
        the test needs the same kind of inputs (base point and tangent
        vectors); confirm against the test case.
        """
        return self._log_after_exp_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            rtol=gs.rtol * 100,
            atol=gs.atol * 10000,
        )

    def lifting_retraction_test_data(self):
        """Return data for the lifting/retraction round-trip test.

        Reuses the ``_exp_after_log_test_data`` generator with the same
        arguments as ``exp_after_log_test_data`` -- presumably because
        the test needs the same kind of inputs (pairs of points);
        confirm against the test case.
        """
        return self._exp_after_log_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            rtol=gs.rtol * 100,
            atol=gs.atol * 10000,
        )

    def retraction_shape_test_data(self):
        """Return the same data as ``exp_shape_test_data``."""
        return self.exp_shape_test_data()

    def lifting_shape_test_data(self):
        """Return the same data as ``log_shape_test_data``."""
        return self.log_shape_test_data()