Ejemplo n.º 1
0
 def test_parallel_transport(self, dim, n_samples):
     """Compare ladder parallel transport with the closed-form transport.

     Both the pole and Schild ladders are run with a small and a larger
     number of rungs; the transported vector must match
     ``metric.parallel_transport`` and the ladder's end point must match
     ``metric.exp(tan_vec_b, base_point)``.
     """
     sphere = Hypersphere(dim)
     # Points of S^dim live in R^{dim + 1}: size the random ambient vectors
     # from ``dim`` instead of hard-coding 3, which only fits dim == 2.
     ambient_dim = dim + 1
     base_point = sphere.random_uniform(n_samples)
     tan_vec_a = sphere.to_tangent(
         gs.random.rand(n_samples, ambient_dim), base_point)
     tan_vec_b = sphere.to_tangent(
         gs.random.rand(n_samples, ambient_dim), base_point)
     expected = sphere.metric.parallel_transport(tan_vec_a, base_point,
                                                 tan_vec_b)
     expected_point = sphere.metric.exp(tan_vec_b, base_point)

     # Run the ladders in a common, wider float dtype.
     base_point = gs.cast(base_point, gs.float64)
     base_point, tan_vec_a, tan_vec_b = gs.convert_to_wider_dtype(
         [base_point, tan_vec_a, tan_vec_b])
     for step, alpha in zip(["pole", "schild"], [1, 2]):
         # Schild's ladder gets more rungs and a looser tolerance than the
         # pole ladder (presumably because it is less accurate per rung).
         min_n = 1 if step == "pole" else 50
         tol = 1e-5 if step == "pole" else 1e-2
         for n_rungs in [min_n, 11]:
             ladder = sphere.metric.ladder_parallel_transport(
                 tan_vec_a,
                 base_point,
                 tan_vec_b,
                 n_rungs=n_rungs,
                 scheme=step,
                 alpha=alpha,
             )
             result = ladder["transported_tangent_vec"]
             result_point = ladder["end_point"]
             self.assertAllClose(result, expected, rtol=tol, atol=tol)
             self.assertAllClose(result_point, expected_point)
Ejemplo n.º 2
0
    def test_parallel_transport_trajectory(self, dim, n_samples):
        """Ladder transport with ``return_geodesics=True`` matches closed form.

        For each scheme, fresh random tangent vectors are drawn and the
        ladder's transported vector and end point are compared with
        ``metric.parallel_transport`` and ``metric.exp``.
        """
        sphere = Hypersphere(dim)
        # Points of S^dim live in R^{dim + 1}; a hard-coded 3 only fits dim 2.
        ambient_dim = dim + 1
        for step in ["pole", "schild"]:
            # Schild's ladder gets more rungs and a looser tolerance.
            n_steps = 1 if step == "pole" else 50
            tol = 1e-6 if step == "pole" else 1e-2
            base_point = sphere.random_uniform(n_samples)
            tan_vec_a = sphere.to_tangent(
                gs.random.rand(n_samples, ambient_dim), base_point)
            tan_vec_b = sphere.to_tangent(
                gs.random.rand(n_samples, ambient_dim), base_point)
            expected = sphere.metric.parallel_transport(
                tan_vec_a, base_point, tan_vec_b)
            expected_point = sphere.metric.exp(tan_vec_b, base_point)
            ladder = sphere.metric.ladder_parallel_transport(
                tan_vec_a,
                base_point,
                tan_vec_b,
                n_rungs=n_steps,
                scheme=step,
                return_geodesics=True,
            )
            result = ladder["transported_tangent_vec"]
            result_point = ladder["end_point"]

            self.assertAllClose(result, expected, rtol=tol, atol=tol)
            self.assertAllClose(result_point, expected_point)
Ejemplo n.º 3
0
def main():
    """Cluster von Mises-Fisher samples on S^2 with online K-means and plot.

    The first figure shows the fitted cluster centers in red; the second
    draws the sphere together with the samples of each cluster.
    """
    n_clusters = 4
    sphere = Hypersphere(dimension=2)
    data = sphere.random_von_mises_fisher(kappa=10, n_samples=1000)

    model = OnlineKMeans(metric=sphere.metric, n_clusters=n_clusters)
    model = model.fit(data)

    # Figure 0: cluster centers only.
    plt.figure(0)
    centers_ax = plt.subplot(111, projection="3d")
    visualization.plot(points=model.cluster_centers_, ax=centers_ax,
                       space='S2', c='r')
    plt.show()

    # Figure 1: the sphere with each cluster's samples.
    plt.figure(1)
    clusters_ax = plt.subplot(111, projection="3d")
    sphere_plot = visualization.Sphere()
    sphere_plot.draw(ax=clusters_ax)
    labels = model.labels_
    for label in range(n_clusters):
        members = data[labels == label, :]
        sphere_plot.draw_points(ax=clusters_ax, points=members)
    plt.show()
Ejemplo n.º 4
0
    def test_single_cluster(self):
        """With infinite bandwidth, the single mean-shift center is the mean.

        The center predicted after one mean-shift iteration must match the
        Frechet mean of the same concentrated cluster.
        """
        gs.random.seed(10)

        sphere = Hypersphere(dim=2)
        metric = sphere.metric
        cluster = sphere.random_von_mises_fisher(kappa=100, n_samples=10)

        shifter = riemannian_mean_shift(
            manifold=sphere,
            metric=metric,
            bandwidth=float("inf"),
            tol=1e-4,
            n_centers=1,
            max_iter=1,
        )
        shifter.fit(cluster)
        predicted_center = shifter.predict(cluster)[0]

        frechet = FrechetMean(metric=metric, init_step_size=1.0)
        frechet.fit(cluster)

        self.assertAllClose(frechet.estimate_, predicted_center)
Ejemplo n.º 5
0
    def test_hypersphere_predict(self):
        """``predict`` maps each point to its closest fitted center."""
        gs.random.seed(1234)

        sphere = Hypersphere(dim=2)
        metric = sphere.metric
        cluster = sphere.random_von_mises_fisher(kappa=100, n_samples=10)

        rms = riemannian_mean_shift(
            manifold=sphere,
            metric=metric,
            bandwidth=0.6,
            tol=1e-4,
            n_centers=2,
            max_iter=100,
        )
        rms.fit(cluster)
        result = rms.predict(cluster)

        # Reference: nearest center computed point by point.
        centers = rms.centers
        expected = gs.array([
            centers[metric.closest_neighbor_index(point, centers), :]
            for point in cluster
        ])

        self.assertAllClose(expected, result)
Ejemplo n.º 6
0
def load_cities():
    """Load data from data/cities/cities.json.

    Each city's latitude/longitude (stored in degrees in the JSON file) is
    converted to radians, then to spherical coordinates
    (colatitude, longitude + pi), and finally to extrinsic coordinates of a
    point on the 2-sphere embedded in R^3.

    Returns
    -------
    data : array-like, shape=[50, 3]
        Array with each row representing one city as a point on the
        2-sphere, in extrinsic (ambient R^3) coordinates.
    names : list
        List of city names, in the same order as the rows of ``data``.
    """
    with open(CITIES_PATH, encoding='utf-8') as json_file:
        data_file = json.load(json_file)

    names = [row['city'] for row in data_file]
    # Degrees -> radians for each (latitude, longitude) pair.
    data = gs.array([
        [row[col_name] / 180 * gs.pi for col_name in ['lat', 'lng']]
        for row in data_file
    ])

    # Latitude is measured from the equator; spherical coordinates use the
    # colatitude, measured from the pole.
    colat = gs.expand_dims(gs.pi / 2 - data[:, 0], axis=1)
    # NOTE(review): longitude is shifted by pi, presumably to map [-pi, pi]
    # onto [0, 2 pi]; confirm against spherical_to_extrinsic's convention.
    lng = gs.expand_dims(data[:, 1] + gs.pi, axis=1)

    data = gs.concatenate([colat, lng], axis=1)
    sphere = Hypersphere(dim=2)
    data = sphere.spherical_to_extrinsic(data)
    return data, names
Ejemplo n.º 7
0
def main():
    """Plot an Agglomerative Hierarchical Clustering on the sphere."""
    n_clusters = 2
    n_samples_per_dataset = 50

    sphere = Hypersphere(dim=2)

    # Two groups of samples: one von Mises-Fisher draw, and a second draw
    # mirrored through the origin by negation.
    near_group = sphere.random_von_mises_fisher(
        kappa=10, n_samples=n_samples_per_dataset)
    mirrored_group = -sphere.random_von_mises_fisher(
        kappa=10, n_samples=n_samples_per_dataset)
    dataset = gs.concatenate((near_group, mirrored_group), axis=0)

    clustering = AgglomerativeHierarchicalClustering(
        n_clusters=n_clusters, distance=sphere.metric.dist)
    clustering.fit(dataset)
    labels = clustering.labels_

    plt.figure(0)
    ax = plt.subplot(111, projection='3d')
    plt.title('Agglomerative Hierarchical Clustering')
    sphere_plot = visualization.Sphere()
    sphere_plot.draw(ax=ax)
    for label in range(n_clusters):
        sphere_plot.draw_points(ax=ax, points=dataset[labels == label, ...])

    plt.show()
Ejemplo n.º 8
0
    def setUp(self):
        """Seed the RNG and build a 4-dimensional hypersphere with its metric."""
        gs.random.seed(1234)

        self.n_samples = 10
        self.dimension = 4
        space = Hypersphere(dimension=self.dimension)
        self.space, self.metric = space, space.metric
Ejemplo n.º 9
0
 def test_random_von_mises_kappa(self):
     """Recover the concentration parameter from von Mises-Fisher samples.

     For a dispersed distribution (kappa = 1) the mean resultant length of
     the samples gives a closed-form initial guess for kappa. That guess is
     refined with Newton iterations solving A_p(kappa) = mean_norm, where
     A_p(kappa) = I_{p/2}(kappa) / I_{p/2 - 1}(kappa) and p = dim + 1 is the
     ambient dimension.
     """
     kappa = 1.0
     n_points = 100000
     n_steps = 100
     for dim in [2, 9]:
         sphere = Hypersphere(dim)
         points = sphere.random_von_mises_fisher(kappa=kappa,
                                                 n_samples=n_points)
         p = dim + 1
         sum_points = gs.sum(points, axis=0)
         # Mean resultant length. It does not change across the Newton
         # iterations, so cast it to float64 once here.
         mean_norm = gs.cast(gs.linalg.norm(sum_points) / n_points,
                             gs.float64)
         # Closed-form approximation used as the Newton starting point.
         kappa_estimate = (mean_norm * (p - mean_norm**2) /
                           (1.0 - mean_norm**2))
         kappa_estimate = gs.cast(kappa_estimate, gs.float64)
         for _ in range(n_steps):
             bessel_func_1 = scipy.special.iv(p / 2.0, kappa_estimate)
             bessel_func_2 = scipy.special.iv(p / 2.0 - 1.0, kappa_estimate)
             ratio = bessel_func_1 / bessel_func_2
             # Derivative of A_p: 1 - A_p^2 - (p - 1) / kappa * A_p.
             denominator = (1.0 - ratio**2
                            - (p - 1.0) * ratio / kappa_estimate)
             kappa_estimate = (kappa_estimate
                               - (ratio - mean_norm) / denominator)
         result = kappa_estimate
         expected = kappa
         self.assertAllClose(result, expected, atol=KAPPA_ESTIMATION_TOL)
Ejemplo n.º 10
0
    def setup_method(self):
        """Build 2-D Euclidean/spherical metrics and matrix-defined copies.

        ``new_euc_metric`` and ``new_sphere_metric`` are generic
        ``RiemannianMetric`` instances whose ``metric_matrix`` is replaced by
        hand-written formulas: the identity, and diag(1, sin(theta)^2) in
        spherical coordinates.
        """
        warnings.simplefilter("ignore", category=UserWarning)
        gs.random.seed(0)
        self.dim = 2
        dim = self.dim
        self.euc = Euclidean(dim=dim)
        self.sphere = Hypersphere(dim=dim)
        self.euc_metric = EuclideanMetric(dim=dim)
        self.sphere_metric = HypersphereMetric(dim=dim)

        def identity_matrix(base_point):
            """Return matrix of Euclidean inner-product."""
            return gs.eye(base_point.shape[-1])

        def spherical_coords_matrix(base_point):
            """Return sphere's metric in spherical coordinates."""
            sin_theta = gs.sin(base_point[..., 0])
            return gs.array([[1.0, 0.0], [0.0, sin_theta**2]])

        self.new_euc_metric = RiemannianMetric(dim=dim)
        self.new_euc_metric.metric_matrix = identity_matrix

        self.new_sphere_metric = RiemannianMetric(dim=dim)
        self.new_sphere_metric.metric_matrix = spherical_coords_matrix
    def setup_method(self):
        """Generate perturbed geodesic data on S^2 and a constant*RBF kernel.

        Targets lie on the geodesic ``exp_{intercept}(x * coef)``. They are
        then displaced along a fixed tangent direction by a sinusoidal
        amplitude.
        """
        gs.random.seed(1234)
        self.n_samples = 20

        # Hypersphere setup.
        self.dim_sphere = 2
        self.shape_sphere = (self.dim_sphere + 1, )
        self.sphere = Hypersphere(dim=self.dim_sphere)

        self.intercept_sphere_true = gs.array([0.0, -1.0, 0.0])
        self.coef_sphere_true = gs.array([1.0, 0.0, 0.5])

        def prior(x):
            # Geodesic through the true intercept with the true coefficient.
            return self.sphere.metric.exp(
                x * self.coef_sphere_true,
                base_point=self.intercept_sphere_true,
            )

        self.prior = prior

        self.kernel = ConstantKernel(1.0, (1e-3, 1e3)) * RBF(10.0, (1e-2, 1e2))

        # Centered, evenly spaced inputs as a column vector.
        grid = gs.linspace(0.0, 1.5 * gs.pi, self.n_samples)
        self.X_sphere = gs.reshape(grid - gs.mean(grid), (-1, 1))

        # Points on the geodesic, then orthogonal sinusoidal oscillations.
        on_geodesic = self.prior(self.X_sphere)
        direction = self.sphere.to_tangent(
            (1.0 / 20.0) * gs.array([-0.5, 0.0, 1.0]), base_point=on_geodesic)
        amplitude = self.X_sphere * gs.sin(5.0 * gs.pi * self.X_sphere)
        self.y_sphere = self.sphere.metric.exp(
            amplitude * direction, base_point=on_geodesic)
Ejemplo n.º 12
0
    def setUp(self):
        """Seed the RNG and draw 50 concentrated samples on S^2."""
        gs.random.seed(1234)

        self.dimension = 2
        space = Hypersphere(dim=self.dimension)
        self.space = space
        self.metric = space.metric
        self.data = space.random_von_mises_fisher(kappa=100, n_samples=50)
Ejemplo n.º 13
0
class GeomstatsSphere(Manifold):
    """Adapter letting pymanopt's solvers drive a geomstats ``Hypersphere``.

    Each ``Manifold`` method pymanopt calls is forwarded either to the
    wrapped ``Hypersphere`` instance or to its metric.
    """

    def __init__(self, ambient_dimension):
        # A sphere embedded in R^n has intrinsic dimension n - 1.
        sphere_dim = ambient_dimension - 1
        self._sphere = Hypersphere(sphere_dim)

    def norm(self, base_vector, tangent_vector):
        """Norm of ``tangent_vector`` at ``base_vector``."""
        metric = self._sphere.metric
        return metric.norm(tangent_vector, base_point=base_vector)

    def inner(self, base_vector, tangent_vector_a, tangent_vector_b):
        """Inner product of two tangent vectors at ``base_vector``."""
        metric = self._sphere.metric
        return metric.inner_product(
            tangent_vector_a, tangent_vector_b, base_point=base_vector)

    def proj(self, base_vector, tangent_vector):
        """Map ``tangent_vector`` into the tangent space at ``base_vector``."""
        return self._sphere.to_tangent(tangent_vector, base_point=base_vector)

    def retr(self, base_vector, tangent_vector):
        """Map a tangent vector at ``base_vector`` back onto the sphere.

        geomstats's ``Hypersphere`` class offers no dedicated retraction, so
        the exponential map, which follows the geodesic exactly, is used
        instead (see also
        https://hal.archives-ouvertes.fr/hal-00651608/document).
        """
        metric = self._sphere.metric
        return metric.exp(tangent_vector, base_point=base_vector)

    def rand(self):
        """Draw a random point on the sphere via ``random_uniform``."""
        return self._sphere.random_uniform()
Ejemplo n.º 14
0
        def tangent_extrinsic_to_spherical_raises_test_data(self):
            """Cases for the error paths of the extrinsic-to-spherical map.

            On S^2 two cases are produced. Omitting both base points must
            raise ``ValueError``; passing the extrinsic base point must not
            raise. On S^3 a ``NotImplementedError`` is expected.
            """

            def case(dim, tangent_vec, base_point, expected):
                # One parametrization; the spherical base point is never set.
                return dict(
                    dim=dim,
                    tangent_vec=tangent_vec,
                    base_point=base_point,
                    base_point_spherical=None,
                    expected=expected,
                )

            smoke_data = []
            for dim in (2, 3):
                space = Hypersphere(dim)
                base_point = space.random_point()
                tangent_vec = space.to_tangent(space.random_point(), base_point)
                if dim == 2:
                    smoke_data += [
                        case(2, tangent_vec, None, pytest.raises(ValueError)),
                        case(dim, tangent_vec, base_point, does_not_raise()),
                    ]
                else:
                    smoke_data.append(
                        case(dim, tangent_vec, base_point,
                             pytest.raises(NotImplementedError)))

            return self.generate_tests(smoke_data)
Ejemplo n.º 15
0
 def test_random_von_mises_one_sample_belongs(self):
     """A single vMF sample has shape (dim + 1,) and lies on the sphere."""
     for dim in (2, 9):
         sphere = Hypersphere(dim)
         sample = sphere.random_von_mises_fisher()
         expected_shape = (dim + 1, )
         self.assertAllClose(sample.shape, expected_shape)
         self.assertTrue(sphere.belongs(sample))
Ejemplo n.º 16
0
    def setUp(self):
        """Silence UserWarnings and build metric and connection fixtures."""
        warnings.simplefilter('ignore', category=UserWarning)

        self.dim = 4
        self.euc_metric = EuclideanMetric(dim=self.dim)

        small_dim = 2
        self.connection = Connection(dim=small_dim)
        self.hypersphere = Hypersphere(dim=small_dim)
Ejemplo n.º 17
0
 def setUp(self):
     """Seed the RNG and build one instance of each space under test."""
     gs.random.seed(123)
     self.sphere = Hypersphere(dim=4)
     self.hyperbolic = Hyperboloid(dim=3)
     flat_dim = 2
     self.euclidean = Euclidean(dim=flat_dim)
     self.minkowski = Minkowski(dim=flat_dim)
     n_rot = 3
     self.so3 = SpecialOrthogonal(n=n_rot, point_type='vector')
     self.so_matrix = SpecialOrthogonal(n=n_rot, point_type='matrix')
Ejemplo n.º 18
0
    def setup_method(self):
        """Silence UserWarnings, seed the RNG and build fixtures."""
        warnings.simplefilter("ignore", category=UserWarning)
        gs.random.seed(0)

        self.dim = 4
        self.euc_metric = EuclideanMetric(dim=self.dim)

        small_dim = 2
        self.connection = Connection(dim=small_dim)
        self.hypersphere = Hypersphere(dim=small_dim)
Ejemplo n.º 19
0
    def setUp(self):
        """Seed the RNG and build matrix- and vector-type product manifolds."""
        gs.random.seed(1234)

        matrix_factors = [Hypersphere(dim=2), Hyperboloid(dim=2)]
        self.space_matrix = ProductManifold(
            manifolds=matrix_factors, default_point_type='matrix')

        vector_factors = [Hypersphere(dim=2), Hyperboloid(dim=5)]
        self.space_vector = ProductManifold(
            manifolds=vector_factors, default_point_type='vector')
Ejemplo n.º 20
0
    def setUp(self):
        """Build the groups and spaces whose points are plotted."""
        n_group = 3
        self.n_samples = 10
        self.SO3_GROUP = SpecialOrthogonalGroup(n=n_group)
        self.SE3_GROUP = SpecialEuclideanGroup(n=n_group)
        self.S1, self.S2 = Hypersphere(dimension=1), Hypersphere(dimension=2)
        self.H2 = HyperbolicSpace(dimension=2)

        plt.figure()
Ejemplo n.º 21
0
    def setUp(self):
        """Build the groups and spaces whose points are plotted."""
        n_group = 3
        self.n_samples = 10
        self.SO3_GROUP = SpecialOrthogonal(n=n_group)
        self.SE3_GROUP = SpecialEuclidean(n=n_group)
        self.S1, self.S2 = Hypersphere(dim=1), Hypersphere(dim=2)
        self.H2 = Hyperbolic(dim=2)

        plt.figure()
Ejemplo n.º 22
0
class TestVisualization(geomstats.tests.TestCase):
    """Smoke tests: ``visualization.plot`` runs on each supported space."""

    def setUp(self):
        """Create the spaces to sample from and open a fresh figure."""
        self.n_samples = 10
        self.SO3_GROUP = SpecialOrthogonal(n=3, point_type='vector')
        self.SE3_GROUP = SpecialEuclidean(n=3, point_type='vector')
        self.S1 = Hypersphere(dim=1)
        self.S2 = Hypersphere(dim=2)
        self.H2 = Hyperbolic(dim=2)
        self.H2_half_plane = PoincareHalfSpace(dim=2)

        plt.figure()

    def tearDown(self):
        """Close every open figure.

        ``setUp`` opens a new figure for each test, and plotting may open
        more. Without closing them, figures accumulate over the whole suite
        and matplotlib warns once too many are open.
        """
        plt.close('all')

    @staticmethod
    def test_tutorial_matplotlib():
        """The matplotlib tutorial helper runs."""
        visualization.tutorial_matplotlib()

    def test_plot_points_so3(self):
        """Random SO(3) points (vector form) can be plotted."""
        points = self.SO3_GROUP.random_uniform(self.n_samples)
        visualization.plot(points, space='SO3_GROUP')

    def test_plot_points_se3(self):
        """Random SE(3) points (vector form) can be plotted."""
        points = self.SE3_GROUP.random_uniform(self.n_samples)
        visualization.plot(points, space='SE3_GROUP')

    @geomstats.tests.np_and_pytorch_only
    def test_plot_points_s1(self):
        """Random points on the circle can be plotted."""
        points = self.S1.random_uniform(self.n_samples)
        visualization.plot(points, space='S1')

    def test_plot_points_s2(self):
        """Random points on the 2-sphere can be plotted."""
        points = self.S2.random_uniform(self.n_samples)
        visualization.plot(points, space='S2')

    def test_plot_points_h2_poincare_disk(self):
        """Hyperbolic points can be plotted in the Poincare disk."""
        points = self.H2.random_uniform(self.n_samples)
        visualization.plot(points, space='H2_poincare_disk')

    def test_plot_points_h2_poincare_half_plane_ext(self):
        """Extrinsic hyperbolic points can be plotted in the half plane."""
        points = self.H2.random_uniform(self.n_samples)
        visualization.plot(points,
                           space='H2_poincare_half_plane',
                           point_type='extrinsic')

    def test_plot_points_h2_poincare_half_plane_none(self):
        """Half-space points plot without an explicit point type."""
        points = self.H2_half_plane.random_uniform(self.n_samples)
        visualization.plot(points, space='H2_poincare_half_plane')

    def test_plot_points_h2_poincare_half_plane_hs(self):
        """Half-space points plot with ``point_type='half_space'``."""
        points = self.H2_half_plane.random_uniform(self.n_samples)
        visualization.plot(points,
                           space='H2_poincare_half_plane',
                           point_type='half_space')

    def test_plot_points_h2_klein_disk(self):
        """Hyperbolic points can be plotted in the Klein disk."""
        points = self.H2.random_uniform(self.n_samples)
        visualization.plot(points, space='H2_klein_disk')
Ejemplo n.º 23
0
def empirical_frechet_var_bubble(n_samples, theta, dim, n_expectation=1000):
    """Variance of the empirical Fréchet mean for a bubble distribution.

    Draw n_samples points from a bubble distribution, compute their
    empirical Fréchet mean and its squared distance to the asymptotic mean
    (the north pole). This is repeated n_expectation times to approximate
    the expectation of that squared distance (i.e. the variance) by
    sampling.

    The bubble distribution is an isotropic distribution on a Riemannian
    hyper sub-sphere of radius 0 < theta < Pi around the north pole of the
    sphere of dimension dim.

    Parameters
    ----------
    n_samples : int
        Number of samples to draw.
    theta : float
        Radius of the bubble distribution.
    dim : int
        Dimension of the sphere (embedded in R^{dim+1}).
    n_expectation : int, optional (defaults to 1000)
        Number of computations for approximating the expectation.

    Returns
    -------
    tuple (variance, uncertainty)
        Monte-Carlo estimate of the variance, and twice the standard error
        of that estimate (``2 * std / sqrt(n_expectation)``).

    Raises
    ------
    ValueError
        If ``dim <= 1``: a uniform sample on the (dim - 1)-dimensional
        sub-sphere requires dim > 1.
    """
    if dim <= 1:
        raise ValueError(
            'Dim > 1 needed to draw a uniform sample on sub-sphere.')
    var = []
    sphere = Hypersphere(dim=dim)
    bubble = Hypersphere(dim=dim - 1)

    north_pole = gs.zeros(dim + 1)
    north_pole[dim] = 1.0
    for _ in range(n_expectation):
        # Sample n points from the uniform distribution on a sub-sphere
        # of radius theta (i.e cos(theta) in ambient space): the first dim
        # coordinates are sin(theta) * direction, the last one is cos(theta).
        # Built with one concatenate instead of an element-wise Python loop.
        # TODO (nina): Add this code as a method of hypersphere
        directions = bubble.random_uniform(n_samples)
        directions = gs.to_ndarray(directions, to_ndim=2)
        last_col = gs.cos(theta) * gs.ones((n_samples, 1))
        data = gs.concatenate([gs.sin(theta) * directions, last_col], axis=1)
        data = gs.cast(data, gs.float64)

        # TODO (nina): Use FrechetMean here
        current_mean = _adaptive_gradient_descent(data,
                                                  metric=sphere.metric,
                                                  max_iter=32,
                                                  init_point=north_pole)
        var.append(sphere.metric.squared_dist(north_pole, current_mean))
    return gs.mean(var), 2 * gs.std(var) / gs.sqrt(n_expectation)
Ejemplo n.º 24
0
 def test_spherical_to_extrinsic_vectorization(self):
     """Convert two spherical-coordinate points on S^2 in one call."""
     sphere = Hypersphere(2)
     theta_1, phi_1 = gs.pi / 2, 0
     theta_2, phi_2 = gs.pi / 6, gs.pi / 4
     points_spherical = gs.array([[theta_1, phi_1], [theta_2, phi_2]])
     result = sphere.spherical_to_extrinsic(points_spherical)

     half_sqrt_2 = gs.sqrt(2) / 4
     expected = gs.array([
         [1., 0., 0.],
         [half_sqrt_2, half_sqrt_2, gs.sqrt(3) / 2],
     ])
     self.assertAllClose(result, expected)
Ejemplo n.º 25
0
    def __init__(self, k_landmarks, m_ambient):
        # Metric on k x m matrix points, of dimension m * (k - 1) - 1.
        super().__init__(
            dim=m_ambient * (k_landmarks - 1) - 1,
            default_point_type="matrix",
        )

        self.embedding_metric = MatricesMetric(k_landmarks, m_ambient)
        # Hypersphere of dimension k * m - 1, i.e. the unit sphere of
        # R^{k * m}.
        n_coords = m_ambient * k_landmarks
        self.sphere_metric = Hypersphere(n_coords - 1).metric

        self.k_landmarks = k_landmarks
        self.m_ambient = m_ambient
Ejemplo n.º 26
0
 def setup_method(self):
     """Seed the RNG and build one instance of each space under test."""
     gs.random.seed(123)
     self.sphere = Hypersphere(dim=4)
     self.hyperbolic = Hyperboloid(dim=3)
     self.euclidean, self.minkowski = Euclidean(dim=2), Minkowski(dim=2)
     self.so3 = SpecialOrthogonal(n=3, point_type="vector")
     self.so_matrix = SpecialOrthogonal(n=3)
     # Discrete curves in the plane and their elastic metric.
     self.curves_2d = DiscreteCurves(R2)
     self.elastic_metric = ElasticMetric(a=1, b=1, ambient_manifold=R2)
Ejemplo n.º 27
0
    def setUp(self):
        """Build the groups and spaces whose points are plotted."""
        n_group = 3
        self.n_samples = 10
        self.SO3_GROUP = SpecialOrthogonal(n=n_group, point_type='vector')
        self.SE3_GROUP = SpecialEuclidean(n=n_group, point_type='vector')
        self.S1, self.S2 = Hypersphere(dim=1), Hypersphere(dim=2)
        self.H2 = Hyperbolic(dim=2)
        self.H2_half_plane = PoincareHalfSpace(dim=2)

        plt.figure()
Ejemplo n.º 28
0
def empirical_frechet_var_bubble(n_samples, theta, dim, n_expectation=1000):
    """Variance of the empirical Fréchet mean for a bubble distribution.

    Draw n_samples points from a bubble distribution, compute their
    empirical Fréchet mean and its squared distance to the asymptotic mean
    (the north pole). This is repeated n_expectation times to approximate
    the expectation of that squared distance (i.e. the variance) by
    sampling.

    The bubble distribution is an isotropic distribution on a Riemannian
    hyper sub-sphere of radius 0 < theta < Pi around the north pole of the
    sphere of dimension dim.

    Parameters
    ----------
    n_samples : int
        Number of samples to draw.
    theta : float
        Radius of the bubble distribution.
    dim : int
        Dimension of the sphere (embedded in R^{dim+1}).
    n_expectation : int, optional (defaults to 1000)
        Number of computations for approximating the expectation.

    Returns
    -------
    tuple (variance, uncertainty)
        Monte-Carlo estimate of the variance, and twice the standard error
        of that estimate (``2 * std / sqrt(n_expectation)``).

    Raises
    ------
    ValueError
        If ``dim <= 1``: a uniform sample on the (dim - 1)-dimensional
        sub-sphere requires dim > 1.
    """
    if dim <= 1:
        raise ValueError(
            "Dim > 1 needed to draw a uniform sample on sub-sphere.")
    var = []
    sphere = Hypersphere(dim=dim)
    bubble = Hypersphere(dim=dim - 1)

    north_pole = gs.zeros(dim + 1)
    north_pole[dim] = 1.0
    for _ in range(n_expectation):
        # Sample n points from the uniform distribution on a sub-sphere
        # of radius theta (i.e cos(theta) in ambient space): the first dim
        # coordinates are sin(theta) * direction, the last one is cos(theta).
        # TODO (nina): Add this code as a method of hypersphere
        last_col = gs.cos(theta) * gs.ones(n_samples)
        # Only reshape to a column for a batch; with a single sample the
        # directions are presumably a 1-D vector (confirm random_uniform).
        last_col = last_col[:, None] if (n_samples > 1) else last_col

        directions = bubble.random_uniform(n_samples)
        rest_col = gs.sin(theta) * directions
        data = gs.concatenate([rest_col, last_col], axis=-1)

        estimator = FrechetMean(sphere.metric,
                                max_iter=32,
                                method="adaptive",
                                init_point=north_pole)
        estimator.fit(data)
        current_mean = estimator.estimate_
        var.append(sphere.metric.squared_dist(north_pole, current_mean))
    return gs.mean(var), 2 * gs.std(var) / gs.sqrt(n_expectation)
Ejemplo n.º 29
0
    def setup_method(self):
        """Build synthetic data ``y = exp_{intercept}(X * coef)`` on S^4 and SE(2).

        For each space, draw a true intercept (a point), a coefficient and
        centered random inputs ``X``, and compute the targets ``y``. The true
        parameters and an initial guess are stacked as two rows in
        ``param_*_true`` / ``param_*_guess``.

        NOTE(review): every statement after the seed consumes the global RNG
        in a fixed order; reordering them changes the generated data.
        """
        gs.random.seed(1234)
        self.n_samples = 20

        # Set up for hypersphere
        self.dim_sphere = 4
        self.shape_sphere = (self.dim_sphere + 1, )
        self.sphere = Hypersphere(dim=self.dim_sphere)
        # Centered random inputs, shape (n_samples,).
        X = gs.random.rand(self.n_samples)
        self.X_sphere = X - gs.mean(X)
        self.intercept_sphere_true = self.sphere.random_point()
        # NOTE(review): ``projection`` maps the random vector onto the sphere,
        # so this coefficient is not necessarily tangent at the intercept;
        # confirm that ``exp`` (and any loss using it) projects it first.
        self.coef_sphere_true = self.sphere.projection(
            gs.random.rand(self.dim_sphere + 1))

        # (n_samples, 1) * (dim + 1,) -> one ambient vector per input.
        self.y_sphere = self.sphere.metric.exp(
            self.X_sphere[:, None] * self.coef_sphere_true,
            base_point=self.intercept_sphere_true,
        )

        self.param_sphere_true = gs.vstack(
            [self.intercept_sphere_true, self.coef_sphere_true])
        # Initial guess: first target as intercept, random tangent vector at
        # that point as coefficient.
        self.param_sphere_guess = gs.vstack([
            self.y_sphere[0],
            self.sphere.to_tangent(gs.random.normal(size=self.shape_sphere),
                                   self.y_sphere[0]),
        ])

        # Set up for special euclidean
        self.se2 = SpecialEuclidean(n=2)
        # NOTE(review): this mutates the metric object held by ``self.se2``.
        self.metric_se2 = self.se2.left_canonical_metric
        self.metric_se2.default_point_type = "matrix"

        # SE(2) points are handled as 3 x 3 matrices.
        self.shape_se2 = (3, 3)
        X = gs.random.rand(self.n_samples)
        self.X_se2 = X - gs.mean(X)

        self.intercept_se2_true = self.se2.random_point()
        self.coef_se2_true = self.se2.to_tangent(
            5.0 * gs.random.rand(*self.shape_se2), self.intercept_se2_true)

        # (n, 1, 1) * (1, 3, 3) -> one 3 x 3 tangent vector per input.
        self.y_se2 = self.metric_se2.exp(
            self.X_se2[:, None, None] * self.coef_se2_true[None],
            self.intercept_se2_true,
        )

        # Matrices are flattened so intercept and coefficient stack as rows.
        self.param_se2_true = gs.vstack([
            gs.flatten(self.intercept_se2_true),
            gs.flatten(self.coef_se2_true),
        ])
        self.param_se2_guess = gs.vstack([
            gs.flatten(self.y_se2[0]),
            gs.flatten(
                self.se2.to_tangent(gs.random.normal(size=self.shape_se2),
                                    self.y_se2[0])),
        ])
Ejemplo n.º 30
0
    def setup_method(self):
        gs.random.seed(1234)

        self.space_matrix = ProductManifold(
            manifolds=[Hypersphere(dim=2), Hyperboloid(dim=2)],
            default_point_type="matrix",
        )
        self.space_vector = ProductManifold(
            manifolds=[Hypersphere(dim=2), Hyperboloid(dim=3)],
            default_point_type="vector",
        )