Exemplo n.º 1
0
    def test_fit(self):
        """Check that a one-cluster fit lands on the metric's mean of the data.

        With ``n_clusters=1`` the single fitted center is compared against
        ``self.metric.mean`` of the same samples; their distance must be
        zero up to ``TOLERANCE``.
        """
        samples = self.data
        kmeans = OnlineKMeans(
            metric=self.metric, n_clusters=1, n_repetitions=10)
        kmeans.fit(samples)

        fitted_center = kmeans.cluster_centers_
        reference_mean = self.metric.mean(samples)
        gap = self.metric.dist(fitted_center, reference_mean)
        self.assertAllClose(0., gap, atol=TOLERANCE)
Exemplo n.º 2
0
    def test_predict(self):
        """Check that predicting a training point reproduces its fitted label."""
        samples = self.data
        kmeans = OnlineKMeans(metric=self.metric, n_clusters=3,
                              n_repetitions=1)
        kmeans.fit(samples)

        # Predict first, then read the stored label, as the fit/predict
        # contract is being compared in that order.
        first_point = self.data[0, :]
        predicted_label = kmeans.predict(first_point)

        self.assertAllClose(kmeans.labels_[0], predicted_label)
Exemplo n.º 3
0
    def test_fit(self):
        """Check that a one-cluster fit agrees with a FrechetMean estimate.

        The single center found by ``OnlineKMeans`` must lie within ``1e-3``
        (in ``self.metric.dist``) of ``FrechetMean(lr=1.)`` fitted on the
        same samples.
        """
        samples = self.data
        kmeans = OnlineKMeans(
            metric=self.metric, n_clusters=1, n_repetitions=10)
        kmeans.fit(samples)
        fitted_center = kmeans.cluster_centers_

        estimator = FrechetMean(metric=self.metric, lr=1.)
        estimator.fit(samples)

        gap = self.metric.dist(fitted_center, estimator.estimate_)
        self.assertAllClose(0., gap, atol=1e-3)
Exemplo n.º 4
0
def main():
    """Cluster von Mises-Fisher samples on S2 with OnlineKMeans and plot them.

    Two figures are shown in turn: first the fitted cluster centers in red,
    then a sphere with each cluster's member points drawn on it.
    """
    sphere = Hypersphere(dimension=2)
    samples = sphere.random_von_mises_fisher(kappa=10, n_samples=1000)

    n_clusters = 4
    kmeans = OnlineKMeans(
        metric=sphere.metric, n_clusters=n_clusters).fit(samples)

    # Figure 0: only the fitted centers.
    plt.figure(0)
    centers_ax = plt.subplot(111, projection="3d")
    visualization.plot(points=kmeans.cluster_centers_, ax=centers_ax,
                       space='S2', c='r')
    plt.show()

    # Figure 1: every cluster's members, one draw call per label.
    plt.figure(1)
    clusters_ax = plt.subplot(111, projection="3d")
    sphere_plot = visualization.Sphere()
    sphere_plot.draw(ax=clusters_ax)
    for label in range(n_clusters):
        members = samples[kmeans.labels_ == label, :]
        sphere_plot.draw_points(ax=clusters_ax, points=members)
    plt.show()
Exemplo n.º 5
0
def main():
    """Cluster uniform samples on S1 with OnlineKMeans and plot them.

    Two figures are shown in turn: first the fitted cluster centers in red,
    then a circle with each cluster's member points drawn on it.
    """
    circle = Hypersphere(dimension=1)
    samples = circle.random_uniform(n_samples=1000)

    n_clusters = 5
    kmeans = OnlineKMeans(
        metric=circle.metric, n_clusters=n_clusters).fit(samples)

    # Figure 0: only the fitted centers.
    plt.figure(0)
    visualization.plot(points=kmeans.cluster_centers_, space='S1',
                       color='red')
    plt.show()

    # Figure 1: every cluster's members, one draw call per label.
    plt.figure(1)
    clusters_ax = plt.axes()
    circle_plot = visualization.Circle()
    circle_plot.draw(ax=clusters_ax)
    for label in range(n_clusters):
        members = samples[kmeans.labels_ == label, :]
        circle_plot.draw_points(ax=clusters_ax, points=members)
    plt.show()