def calculate_distances(data, centroids, data_dots):
    """Return the squared Euclidean distance from every point to every centroid.

    Uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x . y so that the
    bulk of the work is a single matrix product.

    Args:
        data: 2-D array of points, one per row.
        centroids: 2-D array of centroids, one per row, with the same number
            of columns as ``data``.
        data_dots: 1-D array that is added to each row's result as the
            ||x||^2 term (presumably the squared L2 norm of each row of
            ``data`` -- the caller is responsible for that).

    Returns:
        Array of shape (len(data_dots), number of centroids) whose [i, j]
        entry is ``data_dots[i] + ||centroids[j]||^2 - 2 * data[i] . centroids[j]``.
    """
    # ||y||^2 for every centroid.
    centroid_sq_norms = np.square(np.linalg.norm(centroids, ord=2, axis=1))

    # Broadcast to the full grid holding ||x||^2 + ||y||^2.
    result = data_dots[:, None] + centroid_sq_norms[None, :]

    # Cross term x . y for all pairs (a gemm); it enters scaled by -2.
    cross_terms = np.dot(data, centroids.T)
    result -= 2.0 * cross_terms
    return result
def run_kmeans(C, D, T, I, N, S, benchmarking):  # noqa: E741
    """Run k-means clustering and return the elapsed wall-clock time.

    Args:
        C: Number of centroids.
        D: Number of dimensions of each data point.
        T: Forwarded unchanged to ``initialize`` (presumably the element
            type of the generated data -- confirm against ``initialize``).
        I: Maximum number of iterations.
        N: Number of data points.
        S: Convergence is only tested on iterations that are a non-zero
            multiple of ``S``; must itself be non-zero.
        benchmarking: When true, the convergence threshold never stops the
            loop early, so up to ``I`` iterations are always executed.

    Returns:
        Elapsed time in milliseconds (float), covering initialization and
        all iterations.
    """
    print("Running kmeans...")
    print("Number of data points: " + str(N))
    print("Number of dimensions: " + str(D))
    print("Number of centroids: " + str(C))
    print("Max iterations: " + str(I))

    start = datetime.datetime.now()
    data, centroids = initialize(N, D, C, T)

    # ||x||^2 for every point is loop-invariant, so compute it once.
    data_dots = np.square(np.linalg.norm(data, ord=2, axis=1))
    # Indices 0..N-1. ``np.int`` was removed in NumPy 1.24; the builtin
    # ``int`` is the type it used to alias.
    data_index = np.arange(N, dtype=int)

    labels = None
    iteration = 0
    prior_distance_sum = None
    # We run for max iterations or until we converge
    # We only test convergence every S iterations
    while iteration < I:
        pairwise_distances = calculate_distances(data, centroids, data_dots)
        new_labels, distances = relabel(pairwise_distances, data_index)
        distance_sum = np.sum(distances)
        centroids = find_centroids(data, new_labels, C, D)

        if iteration > 0 and iteration % S == 0:
            changes = np.not_equal(labels, new_labels)
            total_changes = np.sum(changes)
            # Ratio close to 1 means the total distance barely improved.
            delta = distance_sum / prior_distance_sum
            print(
                "Iteration "
                + str(iteration)
                + " produced "
                + str(total_changes)
                + " changes, and total distance is "
                + str(distance_sum)
            )
            # We ignore the result of the threshold test in the case that we
            # are running performance benchmarks to measure performance for a
            # certain number of iterations
            if delta > 1 - 0.000001 and not benchmarking:
                print("Threshold triggered, terminating iterations early")
                break

        prior_distance_sum = distance_sum
        labels = new_labels
        iteration += 1

    # This final distance sum also synchronizes the results
    print(
        "Final distance sum at iteration "
        + str(iteration)
        + ": "
        + str(prior_distance_sum)
    )
    stop = datetime.datetime.now()
    elapsed = stop - start
    total = elapsed.total_seconds() * 1000.0
    print("Elapsed Time: " + str(total) + " ms")
    return total