-
Notifications
You must be signed in to change notification settings - Fork 0
/
prof_kmeans.py
54 lines (43 loc) · 2 KB
/
prof_kmeans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import utils
import numpy as np
from scipy.cluster.vq import kmeans
from sklearn.cluster import k_means
from sklearn.cluster.k_means_ import _labels_inertia
from sklearn.datasets import make_blobs
# Having unit standard deviations is very important for a fair
# test between scipy and sklearn: the stopping threshold is computed
# differently in the two implementations (the two thresholds are identical
# if all variances are equal to one).
# --- Benchmark configuration -------------------------------------------
# Number of clusters requested from the generator and from k-means.
n_clusters = 10
# Seed used for the clustering run further down.
random_state = 4
# Convergence tolerance used for the clustering run further down.
tol = 1e-4

# Shape of the synthetic dataset.
n_samples = 100000
n_features = 200
# Alternative shapes (disabled):
# n_samples, n_features = 10000, 200
# n_samples, n_features = 200, 10000
# n_samples, n_features = 200, 100000

# Only the first element of make_blobs' return value (the data) is kept.
# The generator uses its own fixed seed (0), independent of `random_state`.
X = make_blobs(
    n_samples=n_samples,
    n_features=n_features,
    centers=n_clusters,
    random_state=0,
)[0]
# Alternative data source (disabled):
# X = np.random.normal(size=(n_samples, n_features))
## print("\n-- scipy.cluster.vq")
## ratio = 1.
## np.random.seed(random_state)
## sc, _ = utils.timeit(profile(kmeans))(X, n_clusters, iter=2,
## thresh=tol / ratio)
## ## utils.cache_value(sc, 'prof_kmeans/scipy_kmeans_%d_%d'
## ## % (n_samples, n_features))
## inertia1 = _labels_inertia(X, (X ** 2).sum(axis=-1), sc)[1]
## print('scipy inertia: %.1f' % np.sqrt(inertia1))
# `profile` is neither defined nor imported in this script; it is expected to
# be injected into builtins by line_profiler's `kernprof` launcher. Fall back
# to a pass-through decorator so the script also runs under a plain
# interpreter instead of failing with a NameError.
try:
    profile
except NameError:
    def profile(func):
        """Return *func* unchanged (stand-in when no profiler is active)."""
        return func

print("\n-- sklearn.cluster")
# Scaling factor applied to the tolerance; fixed to 1 here. The disabled
# alternative was: np.mean(np.var(X, axis=0))  (just to make the comparison
# fair).
ratio = 1.
# Seed NumPy's global RNG in addition to passing random_state explicitly.
np.random.seed(random_state)
# Run k_means wrapped by the profiler decorator and by utils.timeit.
# NOTE(review): `sk` is presumably the array of cluster centres (first value
# returned by k_means) -- confirm that utils.timeit passes the wrapped
# function's return value through unchanged.
sk, _, _ = utils.timeit(profile(k_means))(
    X,
    n_clusters,
    n_init=2,
    tol=tol / ratio,
    init="random",
    random_state=random_state,
)
## utils.cache_value(sk, 'prof_kmeans/sklearn_kmeans_%d_%d' %
##                   (n_samples, n_features))
# Element 1 of _labels_inertia's result is used as the inertia; the second
# argument is the per-row squared norm of X.
inertia2 = _labels_inertia(X, (X ** 2).sum(axis=-1), sk)[1]
# The reported figure is the square root of that inertia.
print('inertia: %.1f' % np.sqrt(inertia2))
## print ('\nsklearn - scipy inertia: %.1f. Relative variation: %.1e' %
## ((inertia2 - inertia1), (inertia2 - inertia1) / (
## 2. * (inertia1 + inertia2))))