def test_mean_shift_negative_bandwidth():
    """Fitting with ``bandwidth=-1`` must raise a ``ValueError``.

    The error text is matched as a regular expression, hence the escaped
    dot in the expected pattern.
    """
    estimator = MeanShift(bandwidth=-1)
    expected_pattern = (
        r"bandwidth needs to be greater than zero or None,"
        r" got -1\.000000"
    )
    # X is the module-level sample data defined outside this block.
    with pytest.raises(ValueError, match=expected_pattern):
        estimator.fit(X)
def test_cluster_intensity_tie():
    """Check the labels of the same six points fitted in two orders.

    The points form two groups of three. Whichever group comes first in
    the input, the assertions expect the group around (3..4, 5..7) to be
    labelled 0 and the group around (1..2, 0..1) to be labelled 1.
    """
    low_group = [[1, 1], [2, 1], [1, 0]]
    high_group = [[4, 7], [3, 5], [3, 6]]

    low_first = MeanShift(bandwidth=2).fit(np.array(low_group + high_group))
    high_first = MeanShift(bandwidth=2).fit(np.array(high_group + low_group))

    assert_array_equal(low_first.labels_, [1, 1, 1, 0, 0, 0])
    assert_array_equal(high_first.labels_, [0, 0, 0, 1, 1, 1])
def test_mean_shift(bandwidth, cluster_all, expected, first_cluster_label):
    """Check cluster count and smallest label for both MeanShift entry points.

    The same two checks run against the ``MeanShift`` estimator and the
    ``mean_shift`` function: the number of distinct labels must equal
    ``expected`` and the smallest label must equal ``first_cluster_label``.

    NOTE(review): the arguments are presumably supplied by a pytest
    parametrization that sits outside this block -- confirm.
    """
    # Estimator interface, using the requested bandwidth.
    estimator = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all)
    estimator_labels = np.unique(estimator.fit(X).labels_)
    assert len(estimator_labels) == expected
    assert estimator_labels[0] == first_cluster_label

    # Function interface. Note that `bandwidth` is not forwarded here, so
    # this call relies on whatever mean_shift does when none is given.
    _, raw_function_labels = mean_shift(X, cluster_all=cluster_all)
    function_labels = np.unique(raw_function_labels)
    assert len(function_labels) == expected
    assert function_labels[0] == first_cluster_label
def test_parallel():
    """``n_jobs=2`` must yield the same centers and labels as the default."""
    blob_centers = np.array([[1, 1], [-1, -1], [1, -1]]) + 10
    X, _ = make_blobs(
        n_samples=50,
        n_features=2,
        centers=blob_centers,
        cluster_std=0.4,
        shuffle=True,
        random_state=11,
    )

    two_jobs = MeanShift(n_jobs=2)
    default_jobs = MeanShift()
    # Fit in the same order as before: the n_jobs=2 estimator first.
    for estimator in (two_jobs, default_jobs):
        estimator.fit(X)

    assert_array_almost_equal(
        two_jobs.cluster_centers_, default_jobs.cluster_centers_
    )
    assert_array_equal(two_jobs.labels_, default_jobs.labels_)
import numpy as np

from mrex.cluster import MeanShift, estimate_bandwidth
from mrex.datasets.samples_generator import make_blobs

# #############################################################################
# Sample data: 10000 points generated by make_blobs around three centers
centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)

# #############################################################################
# MeanShift clustering

# Rather than hard-coding a bandwidth, derive it from the data with
# estimate_bandwidth (quantile=0.2, n_samples=500).
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)

# Keep the fitted results under module-level names for the rest of the script.
labels = ms.labels_
cluster_centers = ms.cluster_centers_

# One cluster per distinct label value.
labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)

# #############################################################################
# Plotting
import matplotlib.pyplot as plt
from itertools import cycle

plt.figure(1)
def test_meanshift_all_orphans():
    """Seeds with no point inside the bandwidth must raise ``ValueError``.

    The seeds are meant to lie away from the data (X is defined outside
    this block), so fitting is expected to fail with a clear message
    instead of crashing obscurely.
    """
    distant_seeds = [[-9, -9], [-10, -10]]
    estimator = MeanShift(bandwidth=0.1, seeds=distant_seeds)
    expected_message = "No point was within bandwidth=0.1"
    assert_raise_message(ValueError, expected_message, estimator.fit, X)
def test_meanshift_predict():
    """``predict`` on the training data must reproduce ``fit_predict``."""
    estimator = MeanShift(bandwidth=1.2)
    fitted_labels = estimator.fit_predict(X)
    # Predict on the same samples the estimator was just fitted on.
    assert_array_equal(fitted_labels, estimator.predict(X))
def test_unfitted():
    """Non-regression: an unfitted estimator has no fitted attributes."""
    estimator = MeanShift()
    # Neither attribute may exist before fit() has been called.
    for fitted_attribute in ("cluster_centers_", "labels_"):
        assert not hasattr(estimator, fitted_attribute)