コード例 #1
0
    def test_hypersphere_riemannian_mean_shift_predict(self):
        """Check that predict returns, for each sample, its closest center.

        The expected value is rebuilt by hand with the metric's
        ``closest_neighbor_index`` against the fitted ``centers``.
        """
        gs.random.seed(1234)
        dim = 2

        sphere = Hypersphere(dim)
        sphere_metric = HypersphereMetric(dim)
        samples = sphere.random_von_mises_fisher(kappa=100, n_samples=10)

        estimator = riemannian_mean_shift(
            manifold=sphere,
            metric=sphere_metric,
            bandwidth=0.6,
            tol=1e-4,
            n_centers=2,
            max_iter=100,
        )
        estimator.fit(samples)
        result = estimator.predict(samples)

        expected = gs.array([
            estimator.centers[
                sphere_metric.closest_neighbor_index(sample, estimator.centers),
                :,
            ]
            for sample in samples
        ])

        self.assertAllClose(expected, result)
コード例 #2
0
    def setUp(self):
        """Build reference spaces/metrics and two hand-written metrics.

        ``new_euc_metric`` and ``new_sphere_metric`` are plain
        ``RiemannianMetric`` objects whose ``metric_matrix`` is replaced by a
        local function: the identity, and the round-sphere metric written in
        spherical coordinates (theta, phi).
        """
        warnings.simplefilter("ignore", category=UserWarning)
        gs.random.seed(0)
        self.dim = 2
        self.euc = Euclidean(dim=self.dim)
        self.sphere = Hypersphere(dim=self.dim)
        self.euc_metric = EuclideanMetric(dim=self.dim)
        self.sphere_metric = HypersphereMetric(dim=self.dim)

        def identity_matrix(base_point):
            """Return matrix of Euclidean inner-product."""
            return gs.eye(base_point.shape[-1])

        def spherical_coords_matrix(base_point):
            """Return sphere's metric in spherical coordinates."""
            sin_theta = gs.sin(base_point[..., 0])
            return gs.array([[1.0, 0.0], [0.0, sin_theta ** 2]])

        self.new_euc_metric = RiemannianMetric(dim=self.dim)
        self.new_euc_metric.metric_matrix = identity_matrix

        self.new_sphere_metric = RiemannianMetric(dim=self.dim)
        self.new_sphere_metric.metric_matrix = spherical_coords_matrix
コード例 #3
0
    def test_single_cluster_riemannian_mean_shift(self):
        """One step with infinite bandwidth must land on the Frechet mean.

        A single tight von Mises-Fisher cluster is fitted with one center and
        ``max_iter=1``; the predicted center is compared with the estimate of
        ``FrechetMean`` on the same samples.
        """
        gs.random.seed(10)

        sphere = Hypersphere(dim=2)
        metric = HypersphereMetric(2)

        samples = sphere.random_von_mises_fisher(kappa=100, n_samples=10)

        estimator = riemannian_mean_shift(
            manifold=sphere,
            metric=metric,
            bandwidth=float("inf"),
            tol=1e-4,
            n_centers=1,
            max_iter=1,
        )
        estimator.fit(samples)
        predicted = estimator.predict(samples)

        frechet = FrechetMean(metric=metric, init_step_size=1.0)
        frechet.fit(samples)

        self.assertAllClose(frechet.estimate_, predicted[0])
コード例 #4
0
    def test_double_cluster_riemannian_mean_shift(self):
        """Fit two centers on two randomly rotated copies of one cluster.

        Each predicted value is matched against the two fitted centers; the
        number of matches for center 0 plus center 1 must equal the number
        of samples.
        """
        gs.random.seed(10)
        n_samples = 20
        sphere = Hypersphere(dim=2)
        metric = HypersphereMetric(2)

        base_cluster = sphere.random_von_mises_fisher(
            kappa=20, n_samples=n_samples)

        rotations = SpecialOrthogonal(3)
        first_rotation = rotations.random_uniform()
        second_rotation = rotations.random_uniform()

        data = gs.concatenate(
            (base_cluster @ first_rotation, base_cluster @ second_rotation))

        estimator = riemannian_mean_shift(
            manifold=sphere,
            metric=metric,
            bandwidth=0.3,
            tol=1e-4,
            n_centers=2,
        )
        estimator.fit(data)
        closest_centers = estimator.predict(data)

        n_first = sum(
            1 for point in closest_centers
            if gs.allclose(point, estimator.centers[0]))
        n_second = sum(
            1 for point in closest_centers
            if gs.allclose(point, estimator.centers[1]))

        self.assertEqual(data.shape[0], n_first + n_second)
コード例 #5
0
class RiemannianMetricTestData(TestData):
    """Test data for generic ``RiemannianMetric`` computations.

    ``new_euc_metric`` and ``new_sphere_metric`` are bare ``RiemannianMetric``
    instances whose ``metric_matrix`` is replaced by a plain function.
    Quantities derived from them (inner products, Christoffel symbols,
    exp, log, ...) are checked against closed forms or against the
    dedicated ``EuclideanMetric`` / ``HypersphereMetric`` classes.
    """

    dim = 2
    euc = Euclidean(dim=dim)
    sphere = Hypersphere(dim=dim)
    euc_metric = EuclideanMetric(dim=dim)
    sphere_metric = HypersphereMetric(dim=dim)

    # NOTE(review): `_euc_metric_matrix` and `_sphere_metric_matrix` are not
    # defined in this view; presumably module-level helpers — confirm.
    new_euc_metric = RiemannianMetric(dim=dim)
    new_euc_metric.metric_matrix = _euc_metric_matrix

    new_sphere_metric = RiemannianMetric(dim=dim)
    new_sphere_metric.metric_matrix = _sphere_metric_matrix

    def cometric_matrix_test_data(self):
        """Euclidean cometric (inverse metric) matrix is the identity."""
        random_data = [
            dict(
                metric=self.euc_metric,
                base_point=self.euc.random_point(),
                expected=gs.eye(self.dim),
            )
        ]
        return self.generate_tests(random_data)

    def inner_coproduct_test_data(self):
        """Inner coproduct of covectors for Euclidean and sphere metrics."""
        base_point = gs.array([0.0, 0.0, 1.0])
        # At the north pole, projecting onto the tangent space leaves these
        # horizontal vectors unchanged, so the expected value is 1 + 2*3 = 7.
        cotangent_vec_a = self.sphere.to_tangent(gs.array([1.0, 2.0, 0.0]),
                                                 base_point)
        cotangent_vec_b = self.sphere.to_tangent(gs.array([1.0, 3.0, 0.0]),
                                                 base_point)

        smoke_data = [
            dict(
                metric=self.euc_metric,
                cotangent_vec_a=gs.array([1.0, 2.0]),
                cotangent_vec_b=gs.array([1.0, 2.0]),
                base_point=self.euc.random_point(),
                expected=5.0,
            ),
            dict(
                metric=self.sphere_metric,
                cotangent_vec_a=cotangent_vec_a,
                cotangent_vec_b=cotangent_vec_b,
                base_point=base_point,
                expected=7.0,
            ),
        ]
        return self.generate_tests(smoke_data)

    def hamiltonian_test_data(self):
        """Hamiltonian evaluated on (position, momentum) states.

        The expected values equal half the squared Euclidean norm of the
        momentum: 0.5 * 5 = 2.5 and 0.5 * 6 = 3.0.
        """
        smoke_data = [
            dict(
                metric=self.euc_metric,
                state=(gs.array([1.0, 2.0]), gs.array([1.0, 2.0])),
                expected=2.5,
            ),
            dict(
                metric=self.sphere_metric,
                state=(gs.array([0.0, 0.0, 1.0]), gs.array([1.0, 2.0, 1.0])),
                expected=3.0,
            ),
        ]
        return self.generate_tests(smoke_data)

    def inner_product_derivative_matrix_test_data(self):
        """Metric-matrix derivative vanishes for constant Euclidean metrics."""
        base_point = self.euc.random_point()
        random_data = [
            dict(
                metric=self.new_euc_metric,
                base_point=base_point,
                expected=gs.zeros((self.dim, ) * 3),
            ),
            dict(
                metric=self.euc_metric,
                base_point=base_point,
                expected=gs.zeros((self.dim, ) * 3),
            ),
        ]
        return self.generate_tests([], random_data)

    def inner_product_test_data(self):
        """Inner products for the hand-built Euclidean and sphere metrics."""
        base_point = self.euc.random_point()
        tangent_vec_a = self.euc.random_point()
        tangent_vec_b = self.euc.random_point()
        random_data = [
            dict(
                metric=self.euc_metric,
                tangent_vec_a=tangent_vec_a,
                tangent_vec_b=tangent_vec_b,
                base_point=base_point,
                expected=gs.dot(tangent_vec_a, tangent_vec_b),
            )
        ]

        # At theta = pi / 3, sin(theta)^2 = 3/4, hence
        # 0.3 * 0.1 + 0.75 * 0.4 * (-0.5) = -0.12.
        smoke_data = [
            dict(
                metric=self.new_sphere_metric,
                tangent_vec_a=gs.array([0.3, 0.4]),
                tangent_vec_b=gs.array([0.1, -0.5]),
                base_point=gs.array([gs.pi / 3.0, gs.pi / 5.0]),
                expected=-0.12,
            )
        ]
        return self.generate_tests(smoke_data, random_data)

    def christoffels_test_data(self):
        """Christoffel symbols of the hand-built metrics.

        The spherical-coordinates metric must agree with
        ``HypersphereMetric.christoffels``; Euclidean ones must vanish.
        """
        base_point = gs.array([gs.pi / 10.0, gs.pi / 9.0])
        smoke_data = [
            dict(
                metric=self.new_sphere_metric,
                base_point=base_point,
                expected=self.sphere_metric.christoffels(base_point),
            )
        ]
        random_data = [
            dict(
                metric=self.new_euc_metric,
                base_point=self.euc.random_point(),
                expected=gs.zeros((self.dim, ) * 3),
            ),
            dict(
                metric=self.euc_metric,
                base_point=self.euc.random_point(),
                expected=gs.zeros((self.dim, ) * 3),
            ),
        ]
        return self.generate_tests(smoke_data, random_data)

    def exp_test_data(self):
        """Exponential map of the hand-built metrics.

        A tangent vector along theta only moves along a meridian, so only
        the theta coordinate changes. In the Euclidean case exp is addition.
        """
        base_point = gs.array([gs.pi / 10.0, gs.pi / 9.0])
        tangent_vec = gs.array([gs.pi / 2.0, 0.0])
        expected = gs.array([gs.pi / 10.0 + gs.pi / 2.0, gs.pi / 9.0])

        euc_base_point = self.euc.random_point()
        euc_tangent_vec = self.euc.random_point()
        euc_expected = euc_base_point + euc_tangent_vec

        smoke_data = [
            dict(
                metric=self.new_sphere_metric,
                tangent_vec=tangent_vec,
                base_point=base_point,
                expected=expected,
            )
        ]
        random_data = [
            dict(
                metric=self.new_euc_metric,
                tangent_vec=euc_tangent_vec,
                base_point=euc_base_point,
                expected=euc_expected,
            )
        ]
        return self.generate_tests(smoke_data, random_data)

    def log_test_data(self):
        """Logarithm map of the hand-built Euclidean metric is subtraction."""
        base_point = self.euc.random_point()
        point = self.euc.random_point()
        expected = point - base_point
        random_data = [
            dict(
                metric=self.new_euc_metric,
                point=point,
                base_point=base_point,
                expected=expected,
            )
        ]
        return self.generate_tests([], random_data)
コード例 #6
0
ファイル: categorical.py プロジェクト: emaignant/geomstats
 def __init__(self, dim):
     """Initialise the metric and keep a sphere metric of the same dim.

     Parameters
     ----------
     dim : int
         Dimension forwarded to the parent metric and to HypersphereMetric.
     """
     parent = super(CategoricalMetric, self)
     parent.__init__(dim=dim)
     self.sphere_metric = HypersphereMetric(dim=dim)
コード例 #7
0
ファイル: categorical.py プロジェクト: emaignant/geomstats
class CategoricalMetric(RiemannianMetric):
    """Class for the Fisher information metric on categorical distributions.

    The Fisher information metric on the $n$-simplex of categorical
    distributions parameters is, up to a constant factor 4, the pullback of
    the metric of the unit $n$-sphere by the componentwise square root:
    ``metric_matrix`` is ``diag(1 / p)``, while the differential of the
    square root scales tangent vectors by ``1 / (2 sqrt(p))``. A constant
    rescaling of the metric leaves its geodesics unchanged, so the
    exponential map, logarithm map and geodesics are computed on the unit
    sphere and mapped back to the simplex.

    Parameters
    ----------
    dim : int
        Dimension of the simplex; points have ``dim + 1`` components.

    References
    ----------
    .. [K1989] R. E. Kass. The Geometry of Asymptotic Inference. Statistical
      Science, 4(3): 188 - 234, 1989.
    """

    def __init__(self, dim):
        super().__init__(dim=dim)
        # Metric of the unit sphere used by exp, log and geodesic.
        self.sphere_metric = HypersphereMetric(dim)

    def metric_matrix(self, base_point=None):
        """Compute the inner-product matrix.

        Compute the inner-product matrix of the Fisher information metric
        at the tangent space at base point, i.e. the diagonal matrix with
        entries ``1 / base_point``.

        Parameters
        ----------
        base_point : array-like, shape=[..., dim + 1]
            Base point.

        Returns
        -------
        mat : array-like, shape=[..., dim + 1, dim + 1]
            Inner-product matrix.

        Raises
        ------
        ValueError
            If no base point is given.
        """
        if base_point is None:
            raise ValueError("A base point must be given to compute the "
                             "metric matrix")
        base_point = gs.to_ndarray(base_point, to_ndim=2)
        mat = from_vector_to_diagonal_matrix(1 / base_point)
        return gs.squeeze(mat)

    @staticmethod
    def simplex_to_sphere(point):
        """Send point of the simplex to the sphere.

        The map takes the square root of each component.

        Parameters
        ----------
        point : array-like, shape=[..., dim + 1]
            Point on the simplex.

        Returns
        -------
        point_sphere : array-like, shape=[..., dim + 1]
            Point on the sphere.
        """
        return point**(1 / 2)

    @staticmethod
    def sphere_to_simplex(point):
        """Send point of the sphere to the simplex.

        The map squares each component.

        Parameters
        ----------
        point : array-like, shape=[..., dim + 1]
            Point on the sphere.

        Returns
        -------
        point_simplex : array-like, shape=[..., dim + 1]
            Point on the simplex.
        """
        return point**2

    def tangent_simplex_to_sphere(self, tangent_vec, base_point):
        """Send tangent vector of the simplex to tangent space of sphere.

        This is the differential of the simplex_to_sphere map: each
        component is divided by ``2 * sqrt(base_point)``.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., dim + 1]
            Tangent vec to the simplex at base point.
        base_point : array-like, shape=[..., dim + 1]
            Point of the simplex.

        Returns
        -------
        tangent_vec_sphere : array-like, shape=[..., dim + 1]
            Tangent vec to the sphere at the image of
            base point by simplex_to_sphere.
        """
        base_point = gs.to_ndarray(base_point, to_ndim=2)
        tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=2)
        tangent_vec_sphere = gs.einsum(
            "...i,...i->...i", tangent_vec,
            1 / (2 * self.simplex_to_sphere(base_point)))
        return gs.squeeze(tangent_vec_sphere)

    @staticmethod
    def tangent_sphere_to_simplex(tangent_vec, base_point):
        """Send tangent vector of the sphere to tangent space of simplex.

        This is the differential of the sphere_to_simplex map: each
        component is multiplied by ``2 * base_point``.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., dim + 1]
            Tangent vec to the sphere at base point.
        base_point : array-like, shape=[..., dim + 1]
            Point of the sphere.

        Returns
        -------
        tangent_vec_simplex : array-like, shape=[..., dim + 1]
            Tangent vec to the simplex at the image of
            base point by sphere_to_simplex.
        """
        base_point = gs.to_ndarray(base_point, to_ndim=2)
        tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=2)
        tangent_vec_simplex = gs.einsum("...i,...i->...i", tangent_vec,
                                        2 * base_point)
        return gs.squeeze(tangent_vec_simplex)

    def exp(self, tangent_vec, base_point):
        """Compute the exponential map.

        Compute the exponential map associated to the Fisher information
        metric by pulling back the exponential map on the sphere by the
        simplex_to_sphere map.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., dim + 1]
            Tangent vector at base point.
        base_point : array-like, shape=[..., dim + 1]
            Base point.

        Returns
        -------
        exp : array-like, shape=[..., dim + 1]
            End point of the geodesic starting at base_point with
            initial velocity tangent_vec and stopping at time 1.
        """
        base_point_sphere = self.simplex_to_sphere(base_point)
        tangent_vec_sphere = self.tangent_simplex_to_sphere(
            tangent_vec, base_point)
        exp_sphere = self.sphere_metric.exp(tangent_vec_sphere,
                                            base_point_sphere)

        return self.sphere_to_simplex(exp_sphere)

    def log(self, point, base_point):
        """Compute the logarithm map.

        Compute the logarithm map associated to the Fisher information
        metric by pulling back the logarithm map on the sphere by
        the simplex_to_sphere map.

        Parameters
        ----------
        point : array-like, shape=[..., dim + 1]
            Point.
        base_point : array-like, shape=[..., dim + 1]
            Base point.

        Returns
        -------
        tangent_vec : array-like, shape=[..., dim + 1]
            Initial velocity of the geodesic starting at base_point and
            reaching point at time 1.
        """
        point_sphere = self.simplex_to_sphere(point)
        base_point_sphere = self.simplex_to_sphere(base_point)
        log_sphere = self.sphere_metric.log(point_sphere, base_point_sphere)

        # The differential of sphere_to_simplex is taken at the sphere point
        # sqrt(base_point), whose image is base_point.
        return self.tangent_sphere_to_simplex(log_sphere, base_point_sphere)

    def geodesic(self,
                 initial_point,
                 end_point=None,
                 initial_tangent_vec=None):
        """Generate parameterized function for the geodesic curve.

        Geodesic curve defined by either:
        - an initial point and an initial tangent vector,
        - an initial point and an end point.

        Parameters
        ----------
        initial_point : array-like, shape=[..., dim + 1]
            Point on the manifold, initial point of the geodesic.
        end_point : array-like, shape=[..., dim + 1]
            Point on the manifold, end point of the geodesic.
            Optional, default: None.
            If None, an initial tangent vector must be given.
        initial_tangent_vec : array-like, shape=[..., dim + 1],
            Tangent vector at base point, the initial speed of the geodesics.
            Optional, default: None.
            If None, an end point must be given and a logarithm is computed.

        Returns
        -------
        path : callable
            Time parameterized geodesic curve. If a batch of initial
            conditions is passed, the output array's first dimension
            represents time, and the second corresponds to the different
            initial conditions.
        """
        initial_point_sphere = self.simplex_to_sphere(initial_point)
        end_point_sphere = None
        vec_sphere = None
        if end_point is not None:
            end_point_sphere = self.simplex_to_sphere(end_point)
        if initial_tangent_vec is not None:
            vec_sphere = self.tangent_simplex_to_sphere(
                initial_tangent_vec, initial_point)
        geodesic_sphere = self.sphere_metric.geodesic(initial_point_sphere,
                                                      end_point_sphere,
                                                      vec_sphere)

        def path(t):
            """Generate parameterized function for geodesic curve.

            Parameters
            ----------
            t : array-like, shape=[n_times,]
                Times at which to compute points of the geodesics.

            Returns
            -------
            geodesic : array-like, shape=[..., n_times, dim + 1]
                Values of the geodesic at times t.
            """
            geod_sphere_at_t = geodesic_sphere(t)
            geod_at_t = self.sphere_to_simplex(geod_sphere_at_t)
            return gs.squeeze(geod_at_t)

        return path
コード例 #8
0
    # Plot a heatmap of `loss_f` on a 3D sphere, mark up to 30 points with
    # label 0 in red, and draw a geodesic from each of them to the closest
    # (by HypersphereMetric.dist) Fibonacci point whose label is non-zero.
    # NOTE(review): s_points, labels, loss_f, f_labels and sphere come from
    # outside this snippet — confirm their definitions.
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    plot_sphere = Sphere()

    n_points = 10000
    # swapaxes(0, 1) presumably turns the grid into one row per point
    # (assumes fibonnaci_points returns shape (3, n_points) — TODO confirm).
    f_points = plot_sphere.fibonnaci_points(n_points).swapaxes(0, 1)
    plot_sphere.plot_heatmap(ax=ax, n_points=n_points, scalar_function=loss_f)
    # First 30 sample points whose label is 0.
    correct_points = s_points[labels == 0][:30, :]
    # NOTE(review): correct_labels is not used anywhere below in this snippet.
    correct_labels = np.ones_like(correct_points)

    ax = visualization.plot(correct_points, ax=ax, space='S2', color='red', s=80)

    # Keep only the first label column, then drop grid points labelled 0.
    f_labels = np.array(f_labels)[:, 0]
    f_points = f_points[f_labels != 0]

    metric = HypersphereMetric(dim=2)
    for k in range(len(correct_points)):
        # Repeat point k once per candidate so dist() compares it with
        # every remaining grid point in one call.
        point_matrix = correct_points[k:k+1, :].repeat(len(f_points), axis=0)
        dist_array = metric.dist(point_matrix, f_points)
        idx_min = np.argmin(dist_array)

        # NOTE(review): uses `sphere.metric` rather than the local `metric`
        # built above — confirm both are the same 2-sphere metric.
        geodesic = sphere.metric.geodesic(
            initial_point=correct_points[k],
            end_point=f_points[idx_min])

        # Sample the geodesic segment at 10 evenly spaced times in [0, 1].
        points_on_geodesic = geodesic(gs.linspace(0., 1., 10))
        plot_sphere.add_points(points_on_geodesic)

    plot_sphere.draw_points(ax=ax, color='black', alpha=0.1)

    plt.show()
コード例 #9
0
class TestRiemannianMetric(geomstats.tests.TestCase):
    """Tests of RiemannianMetric objects built from a custom metric matrix.

    The hand-built metrics are compared with closed forms and with the
    dedicated EuclideanMetric / HypersphereMetric implementations.
    """

    def setUp(self):
        """Create reference spaces/metrics and two hand-built metrics."""
        warnings.simplefilter("ignore", category=UserWarning)
        gs.random.seed(0)
        self.dim = 2
        self.euc = Euclidean(dim=self.dim)
        self.sphere = Hypersphere(dim=self.dim)
        self.euc_metric = EuclideanMetric(dim=self.dim)
        self.sphere_metric = HypersphereMetric(dim=self.dim)

        def flat_matrix(base_point):
            """Return matrix of Euclidean inner-product."""
            return gs.eye(base_point.shape[-1])

        def spherical_matrix(base_point):
            """Return sphere's metric in spherical coordinates."""
            sin_theta = gs.sin(base_point[..., 0])
            return gs.array([[1.0, 0.0], [0.0, sin_theta ** 2]])

        self.new_euc_metric = RiemannianMetric(dim=self.dim)
        self.new_euc_metric.metric_matrix = flat_matrix

        self.new_sphere_metric = RiemannianMetric(dim=self.dim)
        self.new_sphere_metric.metric_matrix = spherical_matrix

    def test_cometric_matrix(self):
        """The Euclidean cometric matrix is the identity."""
        point = self.euc.random_point()
        self.assertAllClose(
            self.euc_metric.metric_inverse_matrix(point), gs.eye(self.dim))

    @geomstats.tests.autograd_and_torch_only
    def test_metric_derivative_euc_metric(self):
        """The Euclidean metric matrix has zero derivative."""
        point = self.euc.random_point()
        self.assertAllClose(
            self.euc_metric.inner_product_derivative_matrix(point),
            gs.zeros((self.dim,) * 3))

    @geomstats.tests.autograd_and_torch_only
    def test_metric_derivative_new_euc_metric(self):
        """The hand-built identity metric has zero derivative."""
        point = self.euc.random_point()
        self.assertAllClose(
            self.new_euc_metric.inner_product_derivative_matrix(point),
            gs.zeros((self.dim,) * 3))

    def test_inner_product_new_euc_metric(self):
        """The hand-built identity metric gives the dot product."""
        point = self.euc.random_point()
        vec_a = self.euc.random_point()
        vec_b = self.euc.random_point()
        reference = gs.dot(vec_a, vec_b)

        computed = self.new_euc_metric.inner_product(
            vec_a, vec_b, base_point=point)

        self.assertAllClose(computed, reference)

    def test_inner_product_new_sphere_metric(self):
        """Spherical-coordinates inner product at theta = pi / 3."""
        point = gs.array([gs.pi / 3.0, gs.pi / 5.0])
        vec_a = gs.array([0.3, 0.4])
        vec_b = gs.array([0.1, -0.5])

        computed = self.new_sphere_metric.inner_product(
            vec_a, vec_b, base_point=point
        )

        self.assertAllClose(computed, -0.12)

    @geomstats.tests.autograd_and_torch_only
    def test_christoffels_eucl_metric(self):
        """Christoffel symbols of the Euclidean metric vanish."""
        point = self.euc.random_point()
        self.assertAllClose(
            self.euc_metric.christoffels(point), gs.zeros((self.dim,) * 3))

    @geomstats.tests.autograd_and_torch_only
    def test_christoffels_new_eucl_metric(self):
        """Christoffel symbols of the hand-built identity metric vanish."""
        point = self.euc.random_point()
        self.assertAllClose(
            self.new_euc_metric.christoffels(point), gs.zeros((self.dim,) * 3))

    @geomstats.tests.autograd_tf_and_torch_only
    def test_christoffels_sphere_metrics(self):
        """Hand-built sphere metric matches HypersphereMetric.christoffels."""
        point = gs.array([gs.pi / 10.0, gs.pi / 9.0])

        reference = self.sphere_metric.christoffels(point)
        computed = self.new_sphere_metric.christoffels(point)

        self.assertAllClose(computed, reference)

    @geomstats.tests.autograd_and_torch_only
    def test_exp_new_eucl_metric(self):
        """Exp of the hand-built identity metric is vector addition."""
        point = self.euc.random_point()
        vec = self.euc.random_point()

        reference = point + vec
        self.assertAllClose(self.new_euc_metric.exp(vec, point), reference)

    @geomstats.tests.autograd_and_torch_only
    def test_log_new_eucl_metric(self):
        """Log of the hand-built identity metric is vector subtraction."""
        point = self.euc.random_point()
        target = self.euc.random_point()

        reference = target - point
        self.assertAllClose(self.new_euc_metric.log(target, point), reference)

    @geomstats.tests.autograd_tf_and_torch_only
    def test_exp_new_sphere_metric(self):
        """A theta-only tangent vector moves along a meridian."""
        point = gs.array([gs.pi / 10.0, gs.pi / 9.0])
        vec = gs.array([gs.pi / 2.0, 0.0])

        reference = gs.array([gs.pi / 10.0 + gs.pi / 2.0, gs.pi / 9.0])
        self.assertAllClose(self.new_sphere_metric.exp(vec, point), reference)