def test_hypersphere_riemannian_mean_shift_predict(self):
    """Check that ``predict`` maps every sample to its nearest fitted center.

    A two-center Riemannian mean shift is fitted on a von Mises-Fisher sample
    drawn on the 2-sphere; its predictions are compared with a brute-force
    nearest-center lookup under the hypersphere metric.
    """
    gs.random.seed(1234)
    dim = 2
    space = Hypersphere(dim)
    metric = HypersphereMetric(dim)
    samples = space.random_von_mises_fisher(kappa=100, n_samples=10)

    rms = riemannian_mean_shift(
        manifold=space,
        metric=metric,
        bandwidth=0.6,
        tol=1e-4,
        n_centers=2,
        max_iter=100,
    )
    rms.fit(samples)
    result = rms.predict(samples)

    centers = rms.centers
    # Brute-force reference: index of the nearest fitted center per sample.
    nearest_indices = [
        metric.closest_neighbor_index(sample, centers) for sample in samples
    ]
    expected = gs.array([centers[index, :] for index in nearest_indices])
    self.assertAllClose(expected, result)
def setUp(self):
    """Build the Euclidean and spherical metrics shared by the tests.

    Besides the library metrics, two generic ``RiemannianMetric`` instances
    are created whose ``metric_matrix`` is overridden: one returns the
    identity (flat metric), the other returns the round-sphere metric
    ``diag(1, sin(theta) ** 2)`` written in spherical coordinates, where
    ``theta`` is the first coordinate of the base point.
    """
    warnings.simplefilter("ignore", category=UserWarning)
    gs.random.seed(0)

    self.dim = 2
    self.euc = Euclidean(dim=self.dim)
    self.sphere = Hypersphere(dim=self.dim)
    self.euc_metric = EuclideanMetric(dim=self.dim)
    self.sphere_metric = HypersphereMetric(dim=self.dim)

    def identity_matrix(base_point):
        """Return the Euclidean inner-product matrix (the identity)."""
        return gs.eye(base_point.shape[-1])

    def spherical_matrix(base_point):
        """Return the sphere metric in spherical coordinates."""
        sin_theta = gs.sin(base_point[..., 0])
        return gs.array([[1.0, 0.0], [0.0, sin_theta ** 2]])

    flat_metric = RiemannianMetric(dim=self.dim)
    flat_metric.metric_matrix = identity_matrix
    round_metric = RiemannianMetric(dim=self.dim)
    round_metric.metric_matrix = spherical_matrix

    self.new_euc_metric, self.new_sphere_metric = flat_metric, round_metric
def test_single_cluster_riemannian_mean_shift(self):
    """Compare a one-iteration, infinite-bandwidth mean shift to the Frechet mean.

    With ``bandwidth=inf``, ``n_centers=1`` and ``max_iter=1``, the test
    asserts that the predicted center coincides with the Frechet mean of the
    von Mises-Fisher sample.
    """
    gs.random.seed(10)
    sphere = Hypersphere(dim=2)
    metric = HypersphereMetric(2)
    samples = sphere.random_von_mises_fisher(kappa=100, n_samples=10)

    shifter = riemannian_mean_shift(
        manifold=sphere,
        metric=metric,
        bandwidth=float("inf"),
        tol=1e-4,
        n_centers=1,
        max_iter=1,
    )
    shifter.fit(samples)
    predicted = shifter.predict(samples)

    frechet = FrechetMean(metric=metric, init_step_size=1.0)
    frechet.fit(samples)

    self.assertAllClose(frechet.estimate_, predicted[0])
def test_double_cluster_riemannian_mean_shift(self):
    """Check that every sample of two rotated clusters maps to a fitted center.

    One von Mises-Fisher cluster is rotated by two random rotations of SO(3)
    to build a two-cluster dataset; after fitting two centers, the number of
    predictions equal to the first or to the second center must add up to
    the number of samples.
    """
    gs.random.seed(10)
    n_samples = 20
    sphere = Hypersphere(dim=2)
    metric = HypersphereMetric(2)
    base_cluster = sphere.random_von_mises_fisher(kappa=20, n_samples=n_samples)

    rotations = SpecialOrthogonal(3)
    first_rotation = rotations.random_uniform()
    second_rotation = rotations.random_uniform()
    data = gs.concatenate(
        (base_cluster @ first_rotation, base_cluster @ second_rotation)
    )

    rms = riemannian_mean_shift(
        manifold=sphere, metric=metric, bandwidth=0.3, tol=1e-4, n_centers=2
    )
    rms.fit(data)
    assigned = rms.predict(data)

    n_first = sum(1 for center in assigned if gs.allclose(center, rms.centers[0]))
    n_second = sum(1 for center in assigned if gs.allclose(center, rms.centers[1]))
    self.assertEqual(data.shape[0], n_first + n_second)
class RiemannianMetricTestData(TestData):
    """Parametrized test data for ``RiemannianMetric``.

    Each ``*_test_data`` method returns ``self.generate_tests(smoke, random)``
    built from hand-computed (smoke) cases and/or cases drawn from random
    points.
    """

    dim = 2
    euc = Euclidean(dim=dim)
    sphere = Hypersphere(dim=dim)
    euc_metric = EuclideanMetric(dim=dim)
    sphere_metric = HypersphereMetric(dim=dim)

    # Generic metrics whose ``metric_matrix`` is overridden; the helpers
    # ``_euc_metric_matrix`` / ``_sphere_metric_matrix`` are presumably
    # module-level functions defined elsewhere in this file.
    new_euc_metric = RiemannianMetric(dim=dim)
    new_euc_metric.metric_matrix = _euc_metric_matrix
    new_sphere_metric = RiemannianMetric(dim=dim)
    new_sphere_metric.metric_matrix = _sphere_metric_matrix

    def cometric_matrix_test_data(self):
        """Cometric (inverse metric matrix) of the Euclidean metric is the identity."""
        random_data = [
            dict(
                metric=self.euc_metric,
                base_point=self.euc.random_point(),
                expected=gs.eye(self.dim),
            )
        ]
        # Cases built from random points go in the random slot, consistently
        # with the other ``*_test_data`` methods.
        return self.generate_tests([], random_data)

    def inner_coproduct_test_data(self):
        """Inner coproduct for the Euclidean metric and the sphere metric."""
        base_point = gs.array([0.0, 0.0, 1.0])
        cotangent_vec_a = self.sphere.to_tangent(gs.array([1.0, 2.0, 0.0]), base_point)
        cotangent_vec_b = self.sphere.to_tangent(gs.array([1.0, 3.0, 0.0]), base_point)
        smoke_data = [
            dict(
                metric=self.euc_metric,
                cotangent_vec_a=gs.array([1.0, 2.0]),
                cotangent_vec_b=gs.array([1.0, 2.0]),
                base_point=self.euc.random_point(),
                expected=5.0,
            ),
            dict(
                metric=self.sphere_metric,
                cotangent_vec_a=cotangent_vec_a,
                cotangent_vec_b=cotangent_vec_b,
                base_point=base_point,
                expected=7.0,
            ),
        ]
        return self.generate_tests(smoke_data)

    def hamiltonian_test_data(self):
        """Hamiltonian of (position, momentum) states for two metrics."""
        smoke_data = [
            dict(
                metric=self.euc_metric,
                state=(gs.array([1.0, 2.0]), gs.array([1.0, 2.0])),
                expected=2.5,
            ),
            dict(
                metric=self.sphere_metric,
                state=(gs.array([0.0, 0.0, 1.0]), gs.array([1.0, 2.0, 1.0])),
                expected=3.0,
            ),
        ]
        return self.generate_tests(smoke_data)

    def inner_product_derivative_matrix_test_data(self):
        """Derivative of a flat metric matrix vanishes."""
        base_point = self.euc.random_point()
        random_data = [
            dict(
                metric=self.new_euc_metric,
                base_point=base_point,
                expected=gs.zeros((self.dim,) * 3),
            ),
            dict(
                metric=self.euc_metric,
                base_point=base_point,
                expected=gs.zeros((self.dim,) * 3),
            ),
        ]
        return self.generate_tests([], random_data)

    def inner_product_test_data(self):
        """Inner products for the Euclidean metric and spherical coordinates."""
        base_point = self.euc.random_point()
        tangent_vec_a = self.euc.random_point()
        tangent_vec_b = self.euc.random_point()
        random_data = [
            dict(
                metric=self.euc_metric,
                tangent_vec_a=tangent_vec_a,
                tangent_vec_b=tangent_vec_b,
                base_point=base_point,
                expected=gs.dot(tangent_vec_a, tangent_vec_b),
            )
        ]
        # 0.3 * 0.1 + sin(pi / 3) ** 2 * 0.4 * (-0.5) = 0.03 - 0.15 = -0.12
        smoke_data = [
            dict(
                metric=self.new_sphere_metric,
                tangent_vec_a=gs.array([0.3, 0.4]),
                tangent_vec_b=gs.array([0.1, -0.5]),
                base_point=gs.array([gs.pi / 3.0, gs.pi / 5.0]),
                expected=-0.12,
            )
        ]
        return self.generate_tests(smoke_data, random_data)

    def christoffels_test_data(self):
        """Christoffel symbols: spherical coordinates vs library, flat metrics vanish."""
        base_point = gs.array([gs.pi / 10.0, gs.pi / 9.0])
        smoke_data = [
            dict(
                metric=self.new_sphere_metric,
                base_point=base_point,
                expected=self.sphere_metric.christoffels(base_point),
            )
        ]
        random_data = [
            dict(
                metric=self.new_euc_metric,
                base_point=self.euc.random_point(),
                expected=gs.zeros((self.dim,) * 3),
            ),
            dict(
                metric=self.euc_metric,
                base_point=self.euc.random_point(),
                expected=gs.zeros((self.dim,) * 3),
            ),
        ]
        return self.generate_tests(smoke_data, random_data)

    def exp_test_data(self):
        """Exponential map along a meridian and for the flat metric."""
        base_point = gs.array([gs.pi / 10.0, gs.pi / 9.0])
        tangent_vec = gs.array([gs.pi / 2.0, 0.0])
        expected = gs.array([gs.pi / 10.0 + gs.pi / 2.0, gs.pi / 9.0])
        euc_base_point = self.euc.random_point()
        euc_tangent_vec = self.euc.random_point()
        euc_expected = euc_base_point + euc_tangent_vec

        smoke_data = [
            dict(
                metric=self.new_sphere_metric,
                tangent_vec=tangent_vec,
                base_point=base_point,
                expected=expected,
            )
        ]
        random_data = [
            dict(
                metric=self.new_euc_metric,
                tangent_vec=euc_tangent_vec,
                base_point=euc_base_point,
                expected=euc_expected,
            )
        ]
        return self.generate_tests(smoke_data, random_data)

    def log_test_data(self):
        """Logarithm map of the flat metric is the difference of points."""
        base_point = self.euc.random_point()
        point = self.euc.random_point()
        expected = point - base_point
        random_data = [
            dict(
                metric=self.new_euc_metric,
                point=point,
                base_point=base_point,
                expected=expected,
            )
        ]
        return self.generate_tests([], random_data)
def __init__(self, dim):
    """Initialize the metric and store a companion hypersphere metric.

    Parameters
    ----------
    dim : int
        Dimension forwarded to the parent constructor and to the
        ``HypersphereMetric`` stored as ``self.sphere_metric``.
    """
    parent = super(CategoricalMetric, self)
    parent.__init__(dim=dim)
    # The sphere metric is what CategoricalMetric's exp/log/geodesic pull back.
    sphere_metric = HypersphereMetric(dim)
    self.sphere_metric = sphere_metric
class CategoricalMetric(RiemannianMetric):
    """Class for the Fisher information metric on categorical distributions.

    The Fisher information metric on the $n$-simplex of categorical
    distributions parameters is, up to a constant factor 4, the pullback
    metric of the $n$-sphere by the componentwise square root. Rescaling a
    metric by a constant does not change its geodesics, so the exponential
    map, the logarithm map and the geodesics are computed on the sphere and
    transported back to the simplex.

    Parameters
    ----------
    dim : int
        Dimension of the simplex; points have ``dim + 1`` components.

    References
    ----------
    .. [K1989] R. E. Kass. The Geometry of Asymptotic Inference.
        Statistical Science, 4(3): 188 - 234, 1989.
    """

    def __init__(self, dim):
        super(CategoricalMetric, self).__init__(dim=dim)
        self.sphere_metric = HypersphereMetric(dim)

    def metric_matrix(self, base_point=None):
        """Compute the inner-product matrix.

        Compute the inner-product matrix of the Fisher information metric
        on the tangent space at base point, i.e. the diagonal matrix with
        entries ``1 / base_point``.

        Parameters
        ----------
        base_point : array-like, shape=[..., dim + 1]
            Base point.

        Returns
        -------
        mat : array-like, shape=[..., dim + 1, dim + 1]
            Inner-product matrix.

        Raises
        ------
        ValueError
            If no base point is given.
        """
        if base_point is None:
            raise ValueError("A base point must be given to compute the "
                             "metric matrix")
        base_point = gs.to_ndarray(base_point, to_ndim=2)
        mat = from_vector_to_diagonal_matrix(1 / base_point)
        return gs.squeeze(mat)

    @staticmethod
    def simplex_to_sphere(point):
        """Send point of the simplex to the sphere.

        The map takes the square root of each component.

        Parameters
        ----------
        point : array-like, shape=[..., dim + 1]
            Point on the simplex.

        Returns
        -------
        point_sphere : array-like, shape=[..., dim + 1]
            Point on the sphere.
        """
        return point**(1 / 2)

    @staticmethod
    def sphere_to_simplex(point):
        """Send point of the sphere to the simplex.

        The map squares each component.

        Parameters
        ----------
        point : array-like, shape=[..., dim + 1]
            Point on the sphere.

        Returns
        -------
        point_simplex : array-like, shape=[..., dim + 1]
            Point on the simplex.
        """
        return point**2

    def tangent_simplex_to_sphere(self, tangent_vec, base_point):
        """Send tangent vector of the simplex to tangent space of sphere.

        This is the differential of the simplex_to_sphere map: each
        component is divided by ``2 * sqrt(base_point)``.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., dim + 1]
            Tangent vec to the simplex at base point.
        base_point : array-like, shape=[..., dim + 1]
            Point of the simplex.

        Returns
        -------
        tangent_vec_sphere : array-like, shape=[..., dim + 1]
            Tangent vec to the sphere at the image of
            base point by simplex_to_sphere.
        """
        base_point = gs.to_ndarray(base_point, to_ndim=2)
        tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=2)
        # Componentwise product with 1 / (2 * sqrt(p)).
        tangent_vec_sphere = gs.einsum(
            "...i,...i->...i", tangent_vec,
            1 / (2 * self.simplex_to_sphere(base_point)))
        return gs.squeeze(tangent_vec_sphere)

    @staticmethod
    def tangent_sphere_to_simplex(tangent_vec, base_point):
        """Send tangent vector of the sphere to tangent space of simplex.

        This is the differential of the sphere_to_simplex map: each
        component is multiplied by ``2 * base_point``.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., dim + 1]
            Tangent vec to the sphere at base point.
        base_point : array-like, shape=[..., dim + 1]
            Point of the sphere.

        Returns
        -------
        tangent_vec_simplex : array-like, shape=[..., dim + 1]
            Tangent vec to the simplex at the image of
            base point by sphere_to_simplex.
        """
        base_point = gs.to_ndarray(base_point, to_ndim=2)
        tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=2)
        tangent_vec_simplex = gs.einsum("...i,...i->...i",
                                        tangent_vec, 2 * base_point)
        return gs.squeeze(tangent_vec_simplex)

    def exp(self, tangent_vec, base_point):
        """Compute the exponential map.

        Compute the exponential map associated to the Fisher information
        metric by pulling back the exponential map on the sphere by the
        simplex_to_sphere map.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., dim + 1]
            Tangent vector at base point.
        base_point : array-like, shape=[..., dim + 1]
            Base point.

        Returns
        -------
        exp : array-like, shape=[..., dim + 1]
            End point of the geodesic starting at base_point with
            initial velocity tangent_vec and stopping at time 1.
        """
        base_point_sphere = self.simplex_to_sphere(base_point)
        tangent_vec_sphere = self.tangent_simplex_to_sphere(
            tangent_vec, base_point)
        exp_sphere = self.sphere_metric.exp(tangent_vec_sphere,
                                            base_point_sphere)
        return self.sphere_to_simplex(exp_sphere)

    def log(self, point, base_point):
        """Compute the logarithm map.

        Compute the logarithm map associated to the Fisher information
        metric by pulling back the logarithm map on the sphere by the
        simplex_to_sphere map.

        Parameters
        ----------
        point : array-like, shape=[..., dim + 1]
            Point.
        base_point : array-like, shape=[..., dim + 1]
            Base point.

        Returns
        -------
        tangent_vec : array-like, shape=[..., dim + 1]
            Initial velocity of the geodesic starting at base_point and
            reaching point at time 1.
        """
        point_sphere = self.simplex_to_sphere(point)
        base_point_sphere = self.simplex_to_sphere(base_point)
        log_sphere = self.sphere_metric.log(point_sphere, base_point_sphere)
        # The differential of squaring is evaluated at the sphere base point.
        return self.tangent_sphere_to_simplex(log_sphere, base_point_sphere)

    def geodesic(self, initial_point,
                 end_point=None, initial_tangent_vec=None):
        """Generate parameterized function for the geodesic curve.

        Geodesic curve defined by either:
        - an initial point and an initial tangent vector,
        - an initial point and an end point.

        Parameters
        ----------
        initial_point : array-like, shape=[..., dim + 1]
            Point on the manifold, initial point of the geodesic.
        end_point : array-like, shape=[..., dim + 1]
            Point on the manifold, end point of the geodesic.
            Optional, default: None.
            If None, an initial tangent vector must be given.
        initial_tangent_vec : array-like, shape=[..., dim + 1],
            Tangent vector at base point, the initial speed of the geodesics.
            Optional, default: None.
            If None, an end point must be given and a logarithm is computed.

        Returns
        -------
        path : callable
            Time parameterized geodesic curve. If a batch of initial
            conditions is passed, the output array's first dimension
            represents time, and the second corresponds to the different
            initial conditions.
        """
        initial_point_sphere = self.simplex_to_sphere(initial_point)
        end_point_sphere = None
        vec_sphere = None
        if end_point is not None:
            end_point_sphere = self.simplex_to_sphere(end_point)
        if initial_tangent_vec is not None:
            vec_sphere = self.tangent_simplex_to_sphere(
                initial_tangent_vec, initial_point)
        geodesic_sphere = self.sphere_metric.geodesic(initial_point_sphere,
                                                      end_point_sphere,
                                                      vec_sphere)

        def path(t):
            """Generate parameterized function for geodesic curve.

            Parameters
            ----------
            t : array-like, shape=[n_times,]
                Times at which to compute points of the geodesics.

            Returns
            -------
            geodesic : array-like, shape=[..., n_times, dim + 1]
                Values of the geodesic at times t.
            """
            geod_sphere_at_t = geodesic_sphere(t)
            geod_at_t = self.sphere_to_simplex(geod_sphere_at_t)
            return gs.squeeze(geod_at_t)

        return path
# Plot a heatmap of ``loss_f`` on the sphere, the first samples of class 0,
# and, for each of them, the geodesic to the nearest grid point whose label
# is non-zero. ``loss_f``, ``s_points``, ``labels``, ``f_labels`` and
# ``sphere`` are expected to be defined earlier in the script.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

plot_sphere = Sphere()
n_points = 10000
f_points = plot_sphere.fibonnaci_points(n_points).swapaxes(0, 1)
plot_sphere.plot_heatmap(ax=ax, n_points=n_points, scalar_function=loss_f)

# First 30 samples whose label is 0.
correct_points = s_points[labels == 0][:30, :]
ax = visualization.plot(correct_points, ax=ax, space='S2', color='red', s=80)

# Keep only the grid points with a non-zero label.
f_labels = np.array(f_labels)[:, 0]
f_points = f_points[f_labels != 0]

metric = HypersphereMetric(dim=2)
geodesic_times = gs.linspace(0., 1., 10)  # loop invariant, built once
n_candidates = len(f_points)
for point in correct_points:
    # Distances from this sample to every candidate grid point.
    point_matrix = point[None, :].repeat(n_candidates, axis=0)
    dist_array = metric.dist(point_matrix, f_points)
    nearest = f_points[np.argmin(dist_array)]
    geodesic = sphere.metric.geodesic(initial_point=point, end_point=nearest)
    plot_sphere.add_points(geodesic(geodesic_times))

plot_sphere.draw_points(ax=ax, color='black', alpha=0.1)
plt.show()
class TestRiemannianMetric(geomstats.tests.TestCase):
    """Check generic ``RiemannianMetric`` computations on known metrics.

    The flat metric and the round-sphere metric in spherical coordinates are
    rebuilt from a user-supplied ``metric_matrix`` and compared with the
    library metrics or closed-form results.
    """

    def setUp(self):
        """Create the library metrics and the two overridden generic metrics."""
        warnings.simplefilter("ignore", category=UserWarning)
        gs.random.seed(0)

        self.dim = 2
        self.euc = Euclidean(dim=self.dim)
        self.sphere = Hypersphere(dim=self.dim)
        self.euc_metric = EuclideanMetric(dim=self.dim)
        self.sphere_metric = HypersphereMetric(dim=self.dim)

        def flat_matrix(base_point):
            """Return the Euclidean inner-product matrix (the identity)."""
            return gs.eye(base_point.shape[-1])

        def round_matrix(base_point):
            """Return diag(1, sin(theta) ** 2), the sphere metric in spherical coordinates."""
            sin_theta = gs.sin(base_point[..., 0])
            return gs.array([[1.0, 0.0], [0.0, sin_theta ** 2]])

        flat = RiemannianMetric(dim=self.dim)
        flat.metric_matrix = flat_matrix
        round_ = RiemannianMetric(dim=self.dim)
        round_.metric_matrix = round_matrix
        self.new_euc_metric, self.new_sphere_metric = flat, round_

    def test_cometric_matrix(self):
        point = self.euc.random_point()
        cometric = self.euc_metric.metric_inverse_matrix(point)
        self.assertAllClose(cometric, gs.eye(self.dim))

    @geomstats.tests.autograd_and_torch_only
    def test_metric_derivative_euc_metric(self):
        point = self.euc.random_point()
        derivative = self.euc_metric.inner_product_derivative_matrix(point)
        self.assertAllClose(derivative, gs.zeros((self.dim,) * 3))

    @geomstats.tests.autograd_and_torch_only
    def test_metric_derivative_new_euc_metric(self):
        point = self.euc.random_point()
        derivative = self.new_euc_metric.inner_product_derivative_matrix(point)
        self.assertAllClose(derivative, gs.zeros((self.dim,) * 3))

    def test_inner_product_new_euc_metric(self):
        point = self.euc.random_point()
        vec_a = self.euc.random_point()
        vec_b = self.euc.random_point()
        reference = gs.dot(vec_a, vec_b)
        computed = self.new_euc_metric.inner_product(vec_a, vec_b, base_point=point)
        self.assertAllClose(computed, reference)

    def test_inner_product_new_sphere_metric(self):
        # 0.3 * 0.1 + sin(pi / 3) ** 2 * 0.4 * (-0.5) = -0.12
        point = gs.array([gs.pi / 3.0, gs.pi / 5.0])
        vec_a = gs.array([0.3, 0.4])
        vec_b = gs.array([0.1, -0.5])
        computed = self.new_sphere_metric.inner_product(
            vec_a, vec_b, base_point=point
        )
        self.assertAllClose(computed, -0.12)

    @geomstats.tests.autograd_and_torch_only
    def test_christoffels_eucl_metric(self):
        point = self.euc.random_point()
        symbols = self.euc_metric.christoffels(point)
        self.assertAllClose(symbols, gs.zeros((self.dim,) * 3))

    @geomstats.tests.autograd_and_torch_only
    def test_christoffels_new_eucl_metric(self):
        point = self.euc.random_point()
        symbols = self.new_euc_metric.christoffels(point)
        self.assertAllClose(symbols, gs.zeros((self.dim,) * 3))

    @geomstats.tests.autograd_tf_and_torch_only
    def test_christoffels_sphere_metrics(self):
        point = gs.array([gs.pi / 10.0, gs.pi / 9.0])
        reference = self.sphere_metric.christoffels(point)
        computed = self.new_sphere_metric.christoffels(point)
        self.assertAllClose(computed, reference)

    @geomstats.tests.autograd_and_torch_only
    def test_exp_new_eucl_metric(self):
        point = self.euc.random_point()
        vec = self.euc.random_point()
        reference = point + vec
        computed = self.new_euc_metric.exp(vec, point)
        self.assertAllClose(computed, reference)

    @geomstats.tests.autograd_and_torch_only
    def test_log_new_eucl_metric(self):
        origin = self.euc.random_point()
        target = self.euc.random_point()
        reference = target - origin
        computed = self.new_euc_metric.log(target, origin)
        self.assertAllClose(computed, reference)

    @geomstats.tests.autograd_tf_and_torch_only
    def test_exp_new_sphere_metric(self):
        # Moving along theta only follows a meridian, which is a geodesic.
        point = gs.array([gs.pi / 10.0, gs.pi / 9.0])
        vec = gs.array([gs.pi / 2.0, 0.0])
        reference = gs.array([gs.pi / 10.0 + gs.pi / 2.0, gs.pi / 9.0])
        computed = self.new_sphere_metric.exp(vec, point)
        self.assertAllClose(computed, reference)