def main():
    """Cluster von Mises-Fisher samples on the 2-sphere and plot them.

    Draws 1000 samples with ``random_von_mises_fisher(kappa=10)``, fits an
    ``OnlineKMeans`` model with 4 clusters, then shows two figures: the
    fitted cluster centers (in red), and the samples drawn cluster by
    cluster on a sphere.
    """
    sphere = Hypersphere(dimension=2)
    samples = sphere.random_von_mises_fisher(kappa=10, n_samples=1000)

    n_clusters = 4
    kmeans = OnlineKMeans(
        metric=sphere.metric, n_clusters=n_clusters).fit(samples)

    # Figure 0: only the estimated cluster centers.
    plt.figure(0)
    centers_ax = plt.subplot(111, projection="3d")
    visualization.plot(
        points=kmeans.cluster_centers_, ax=centers_ax, space='S2', c='r')
    plt.show()

    # Figure 1: every sample, drawn one cluster (label) at a time so each
    # group is its own draw_points call.
    plt.figure(1)
    points_ax = plt.subplot(111, projection="3d")
    sphere_plot = visualization.Sphere()
    sphere_plot.draw(ax=points_ax)
    for label in range(n_clusters):
        members = samples[kmeans.labels_ == label, :]
        sphere_plot.draw_points(ax=points_ax, points=members)
    plt.show()
def test_fit(self):
    """With a single cluster, the fitted center should equal the data mean.

    The mean is computed with ``self.metric.mean``; the geodesic distance
    between it and the K-means center must be zero within ``TOLERANCE``.
    """
    samples = self.data
    kmeans = OnlineKMeans(
        metric=self.metric, n_clusters=1, n_repetitions=10)
    kmeans.fit(samples)

    centre = kmeans.cluster_centers_
    reference = self.metric.mean(samples)
    gap = self.metric.dist(centre, reference)

    self.assertAllClose(0., gap, atol=TOLERANCE)
def test_predict(self):
    """Predicting a training point should reproduce its fitted label.

    Fits a 3-cluster ``OnlineKMeans`` on ``self.data``, then checks that
    ``predict`` on the first sample matches ``labels_[0]``.
    """
    kmeans = OnlineKMeans(
        metric=self.metric, n_clusters=3, n_repetitions=1)
    kmeans.fit(self.data)

    first_sample = self.data[0, :]
    # Call predict before reading labels_, matching the original order.
    predicted_label = kmeans.predict(first_sample)
    fitted_label = kmeans.labels_[0]

    self.assertAllClose(fitted_label, predicted_label)
def test_fit(self):
    """With a single cluster, the fitted center should match the Frechet mean.

    The Frechet mean is estimated with ``FrechetMean(lr=1.)``; its
    ``estimate_`` must lie within 1e-3 geodesic distance of the K-means
    center.
    """
    samples = self.data
    kmeans = OnlineKMeans(
        metric=self.metric, n_clusters=1, n_repetitions=10)
    kmeans.fit(samples)
    centre = kmeans.cluster_centers_

    frechet = FrechetMean(metric=self.metric, lr=1.)
    frechet.fit(samples)

    gap = self.metric.dist(centre, frechet.estimate_)
    self.assertAllClose(0., gap, atol=1e-3)
def main():
    """Cluster uniform samples on the circle and plot them.

    Draws 1000 uniform samples on the 1-sphere, fits an ``OnlineKMeans``
    model with 5 clusters, then shows two figures: the fitted cluster
    centers (in red), and the samples drawn cluster by cluster on a circle.
    """
    circle = Hypersphere(dimension=1)
    samples = circle.random_uniform(n_samples=1000)

    n_clusters = 5
    kmeans = OnlineKMeans(
        metric=circle.metric, n_clusters=n_clusters).fit(samples)

    # Figure 0: only the estimated cluster centers.
    plt.figure(0)
    visualization.plot(
        points=kmeans.cluster_centers_, space='S1', color='red')
    plt.show()

    # Figure 1: every sample, drawn one cluster (label) at a time.
    plt.figure(1)
    axes = plt.axes()
    circle_plot = visualization.Circle()
    circle_plot.draw(ax=axes)
    for label in range(n_clusters):
        members = samples[kmeans.labels_ == label, :]
        circle_plot.draw_points(ax=axes, points=members)
    plt.show()