Esempio n. 1
0
def test_mean_shift_negative_bandwidth():
    # A negative bandwidth is invalid: fitting must raise a ValueError whose
    # message reports the offending value.
    estimator = MeanShift(bandwidth=-1)
    expected_msg = (r"bandwidth needs to be greater than zero or None,"
                    r" got -1\.000000")
    with pytest.raises(ValueError, match=expected_msg):
        estimator.fit(X)
Esempio n. 2
0
def test_cluster_intensity_tie():
    # Two blobs of three points each, fitted in both input orders. The
    # assertions pin that the blob made of (4, 7), (3, 5), (3, 6) receives
    # label 0 and the other blob label 1, whichever order they are given in.
    low_blob = [[1, 1], [2, 1], [1, 0]]
    high_blob = [[4, 7], [3, 5], [3, 6]]

    first = MeanShift(bandwidth=2).fit(np.array(low_blob + high_blob))
    second = MeanShift(bandwidth=2).fit(np.array(high_blob + low_blob))

    assert_array_equal(first.labels_, [1, 1, 1, 0, 0, 0])
    assert_array_equal(second.labels_, [0, 0, 0, 1, 1, 1])
Esempio n. 3
0
def test_mean_shift(bandwidth, cluster_all, expected, first_cluster_label):
    # Test MeanShift algorithm: the estimator and the functional
    # ``mean_shift`` API, run with identical parameters, must find
    # ``expected`` clusters and report ``first_cluster_label`` as the smallest
    # label (e.g. -1 for orphan points when ``cluster_all`` is False).
    ms = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all)
    labels = ms.fit(X).labels_
    labels_unique = np.unique(labels)
    assert len(labels_unique) == expected
    assert labels_unique[0] == first_cluster_label

    # Pass the same bandwidth to the function so both code paths are
    # exercised with the same configuration; without it ``mean_shift`` would
    # fall back to its own default bandwidth instead of ``bandwidth``.
    _, labels_mean_shift = mean_shift(X, bandwidth=bandwidth,
                                      cluster_all=cluster_all)
    labels_mean_shift_unique = np.unique(labels_mean_shift)
    assert len(labels_mean_shift_unique) == expected
    assert labels_mean_shift_unique[0] == first_cluster_label
Esempio n. 4
0
def test_parallel():
    # Fitting with n_jobs=2 must reproduce the default (single-job) fit:
    # cluster centers equal up to float tolerance and identical labels.
    blob_centers = np.array([[1, 1], [-1, -1], [1, -1]]) + 10
    data, _ = make_blobs(n_samples=50, n_features=2, centers=blob_centers,
                         cluster_std=0.4, shuffle=True, random_state=11)

    parallel_fit = MeanShift(n_jobs=2).fit(data)
    serial_fit = MeanShift().fit(data)

    assert_array_almost_equal(parallel_fit.cluster_centers_,
                              serial_fit.cluster_centers_)
    assert_array_equal(parallel_fit.labels_, serial_fit.labels_)
Esempio n. 5
0
import numpy as np
from mrex.cluster import MeanShift, estimate_bandwidth
from mrex.datasets.samples_generator import make_blobs

# #############################################################################
# Generate sample data: 10000 points drawn around three centers.
centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)

# #############################################################################
# Compute clustering with MeanShift

# Instead of hard-coding a bandwidth, estimate one from the data with
# ``estimate_bandwidth`` (here from a 500-point subsample, quantile 0.2).
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_

# The number of clusters found is the number of distinct labels assigned.
labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)

# #############################################################################
# Plot result
# NOTE(review): this excerpt only creates figure 1; the code that draws the
# points and cluster centers appears to be missing here — confirm against the
# full example.
import matplotlib.pyplot as plt
from itertools import cycle

plt.figure(1)
Esempio n. 6
0
def test_meanshift_all_orphans():
    # Seeds are placed at (-9, -9) and (-10, -10), presumably far from every
    # point of X, so no point lies within bandwidth of any seed. Fitting must
    # then fail with an informative ValueError (an error, not a warning).
    import re

    ms = MeanShift(bandwidth=0.1, seeds=[[-9, -9], [-10, -10]])
    msg = "No point was within bandwidth=0.1"
    # ``match`` is a regex searched in the error message; escape the literal
    # text because it contains ".".
    with pytest.raises(ValueError, match=re.escape(msg)):
        ms.fit(X)
Esempio n. 7
0
def test_meanshift_predict():
    # Test MeanShift.predict: labels from fit_predict must equal the labels
    # that predict assigns to the same data afterwards. Arguments are
    # evaluated left to right, so fit_predict runs before predict.
    estimator = MeanShift(bandwidth=1.2)
    assert_array_equal(estimator.fit_predict(X), estimator.predict(X))
Esempio n. 8
0
def test_unfitted():
    # Non-regression: an estimator that has not been fitted yet must not
    # expose any fitted attributes.
    estimator = MeanShift()
    for fitted_attr in ("cluster_centers_", "labels_"):
        assert not hasattr(estimator, fitted_attr)