def square_root_velocity(self, curve):
    """Compute the square root velocity representation of a curve.

    The velocity is computed using the log map of the ambient metric.
    The case of several curves is handled through vectorization: all
    curves are stacked into one long array of sampling points, the
    discrete velocity is computed between consecutive points, and an
    index selection procedure then gets rid of the spurious log between
    the end point of ``curve[k, :, :]`` and the starting point of
    ``curve[k + 1, :, :]``.

    Parameters
    ----------
    curve : array-like
        Discrete curve(s); promoted to ndim 3, i.e. shape
        [n_curves, n_sampling_points, n_coords].
        # assumes all curves share the same number of sampling points
        # — TODO confirm against callers

    Returns
    -------
    srv : array-like, shape=[n_curves, n_sampling_points - 1, n_coords]
        Square root velocity representation of the curve(s).
    """
    curve = gs.to_ndarray(curve, to_ndim=3)
    n_curves, n_sampling_points, n_coords = curve.shape
    srv_shape = (n_curves, n_sampling_points - 1, n_coords)

    # Flatten all curves into one (n_curves * n_sampling_points, n_coords)
    # array so the pairwise log below is a single vectorized call.
    curve = gs.reshape(curve, (n_curves * n_sampling_points, n_coords))

    # Scale by the number of intervals: the discrete velocity is the
    # log between consecutive points divided by the step 1 / (n - 1).
    coef = gs.cast(gs.array(n_sampling_points - 1), gs.float32)
    velocity = coef * self.ambient_metric.log(point=curve[1:, :],
                                              base_point=curve[:-1, :])
    velocity_norm = self.ambient_metric.norm(velocity, curve[:-1, :])
    # SRV transform: q = v / sqrt(|v|).
    # NOTE(review): relies on the backend broadcasting velocity_norm
    # against velocity's last axis — confirm norm's output shape.
    srv = velocity / gs.sqrt(velocity_norm)

    # Drop the velocities computed across curve boundaries, i.e. every
    # entry whose "next point" index is the first point of another curve
    # (positions where (index + 1) % n_sampling_points == 0).
    index = gs.arange(n_curves * n_sampling_points - 1)
    mask = ~gs.equal((index + 1) % n_sampling_points, 0)
    index_select = gs.gather(index, gs.squeeze(gs.where(mask)))
    srv = gs.reshape(gs.gather(srv, index_select), srv_shape)

    return srv
def online_kmeans(X, metric, n_clusters, n_repetitions=20,
                  tolerance=1e-5, n_max_iterations=5e4):
    """Perform online K-means clustering.

    Perform online version of k-means algorithm on data contained in X.
    The data points are treated sequentially and the cluster centers are
    updated one at a time. This version of k-means avoids computing the
    mean of each cluster at each iteration and is therefore less
    computationally intensive than the offline version.

    In the setting of quantization of probability distributions, this
    algorithm is also known as Competitive Learning Riemannian
    Quantization. It computes the closest approximation of the empirical
    distribution of data by a discrete distribution supported by a
    smaller number of points with respect to the Wasserstein distance.
    This smaller number of points is n_clusters.

    Parameters
    ----------
    X : array-like, shape=[n_samples, n_features]
        Input data. It is treated sequentially by the algorithm, i.e.
        one datum is chosen randomly at each iteration.
    metric : object
        Metric of the space in which the data lives. At each iteration,
        one of the cluster centers is moved in the direction of the new
        datum, according the exponential map of the underlying space,
        which is a method of metric.
    n_clusters : int
        Number of clusters of the k-means clustering, or number of
        desired atoms of the quantized distribution.
    n_repetitions : int, default=20
        The cluster centers are updated using decreasing step sizes,
        each of which stays constant for n_repetitions iterations to
        allow a better exploration of the data points.
    tolerance : float, default=1e-5
        Convergence threshold: the algorithm stops as soon as the
        distance between the successive positions of an updated cluster
        center falls below this value.
    n_max_iterations : int, default=5e4
        Maximum number of iterations. If it is reached, the quantization
        may be inacurate.

    Returns
    -------
    cluster_centers : array, shape=[n_clusters, n_features]
        Coordinates of cluster centers.
    labels : array, shape=[n_samples]
        Cluster labels for each point.
    """
    n_samples = X.shape[0]

    # Initialize the centers with n_clusters points drawn uniformly at
    # random (with replacement) from the data.
    random_indices = gs.random.randint(low=0, high=n_samples,
                                       size=(n_clusters,))
    cluster_centers = gs.gather(X, gs.cast(random_indices, gs.int32),
                                axis=0)

    gap = 1.0
    iteration = 0

    while iteration < n_max_iterations:
        iteration += 1
        # The step size decreases every n_repetitions iterations, so
        # each step size is used long enough to explore the data.
        step_size = gs.floor(gs.array(iteration / n_repetitions)) + 1

        # Draw one datum at random and move its closest center towards
        # it along the geodesic, by a fraction 1 / (step_size + 1).
        random_index = gs.random.randint(low=0, high=n_samples, size=(1,))
        point = gs.gather(X, gs.cast(random_index, gs.int32), axis=0)

        index_to_update = metric.closest_neighbor_index(
            point, cluster_centers)
        center_to_update = gs.copy(
            gs.gather(cluster_centers, index_to_update, axis=0))

        tangent_vec_update = metric.log(
            point=point, base_point=center_to_update) / (step_size + 1)
        new_center = metric.exp(
            tangent_vec=tangent_vec_update, base_point=center_to_update)

        # Distance moved by the updated center: convergence criterion.
        gap = metric.dist(center_to_update, new_center)
        if gap == 0 and iteration == 1:
            # Do not stop at the very first iteration just because the
            # drawn point happened to coincide with its closest center.
            gap = gs.array(1.0)

        cluster_centers[index_to_update, :] = new_center

        if gs.isclose(gap, 0, atol=tolerance):
            break

    # Fix: the original checked `iteration == n_max_iterations - 1`
    # inside the loop (penultimate iteration) and its message had a
    # missing space ("Theclustering").
    if iteration == n_max_iterations:
        print('Maximum number of iterations {} reached. The '
              'clustering may be inaccurate'.format(n_max_iterations))

    # Assign each datum to its closest cluster center.
    labels = gs.zeros(n_samples)
    for i in range(n_samples):
        labels[i] = int(metric.closest_neighbor_index(X[i], cluster_centers))

    return cluster_centers, labels