def fit(self, X, y):
    """Compute supervised k-means clustering.

    Each distinct label in ``y`` defines one cluster. Its centroid is the
    barycenter (under ``self.metric``) of the time series carrying that label.

    Parameters
    ----------
    X : array-like of shape=(n_ts, sz, d)
        Time series dataset.
    y : array-like of shape=(n_ts,)
        Time series labels to fit.

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If ``self.metric`` is not one of ``"euclidean"``, ``"dtw"`` or
        ``"softdtw"``.
    """
    # Validate before touching any state: an unknown metric used to leave
    # ``centroids`` silently incomplete and ``cluster_centers_`` meaningless.
    if self.metric not in ("euclidean", "dtw", "softdtw"):
        raise ValueError(
            "Unknown metric %r; expected 'euclidean', 'dtw' or 'softdtw'."
            % (self.metric,))

    # labels_[j] is the index of y[j] in the sorted array of distinct labels;
    # that index serves as the cluster id.
    cls, self.labels_ = np.unique(y, return_inverse=True)
    self.n_clusters = len(cls)
    if self.metric_params is None:
        self.metric_params = {}
    self.gamma_sdtw = self.metric_params.get("gamma_sdtw", 1.)

    # Convert once so boolean-mask indexing also works for list input,
    # which the docstring advertises as "array-like".
    X = to_time_series_dataset(X)

    self.Xs_ = []
    self.ys_ = []
    centroids = []
    for i in range(self.n_clusters):
        mask = self.labels_ == i
        X_i = to_time_series_dataset(X[mask])
        self.Xs_.append(X_i)
        self.ys_.append(self.labels_[mask])
        if self.metric == "euclidean":
            barycenter = EuclideanBarycenter()
        elif self.metric == "dtw":
            barycenter = DTWBarycenterAveraging()
        else:  # "softdtw" (validated above)
            # Honour the configured gamma, consistent with _update_centroids.
            barycenter = SoftDTWBarycenter(gamma=self.gamma_sdtw)
        centroids.append(barycenter.fit(X_i))

    # Stack the list itself and do not squeeze. This keeps one leading axis
    # per cluster and preserves every per-barycenter axis, even when there
    # is a single class or univariate series (d == 1).
    self.cluster_centers_ = np.stack(centroids)
    return self
def _update_centroids(self, X):
    """Recompute each cluster centroid from the series assigned to it.

    For every cluster index ``k``, the members ``X[self.labels_ == k]`` are
    fitted with the barycenter estimator matching ``self.metric``:

    - ``"dtw"``: ``DTWBarycenterAveraging``, with the current centroid passed
      as ``init_barycenter``;
    - ``"softdtw"``: ``SoftDTWBarycenter``, with ``gamma=self.gamma_sdtw`` and
      the current centroid passed as ``init``;
    - anything else: ``EuclideanBarycenter``.

    The fitted result overwrites ``self.cluster_centers_[k]`` in place.
    """
    for cluster_idx in range(self.n_clusters):
        members = X[self.labels_ == cluster_idx]
        if self.metric == "softdtw":
            estimator = SoftDTWBarycenter(
                max_iter=self.max_iter_barycenter,
                gamma=self.gamma_sdtw,
                init=self.cluster_centers_[cluster_idx],
            )
        elif self.metric == "dtw":
            estimator = DTWBarycenterAveraging(
                max_iter=self.max_iter_barycenter,
                barycenter_size=None,
                init_barycenter=self.cluster_centers_[cluster_idx],
                verbose=False,
            )
        else:
            estimator = EuclideanBarycenter()
        self.cluster_centers_[cluster_idx] = estimator.fit(members)
from tslearn.barycenters import EuclideanBarycenter, DTWBarycenterAveraging, SoftDTWBarycenter
from tslearn.datasets import CachedDatasets


def _plot_with_barycenter(series, barycenter, title):
    """Draw every series in translucent black, overlay ``barycenter`` in red.

    Draws into the current matplotlib axes and sets that axes' title.
    """
    for ts in series:
        plt.plot(ts.ravel(), "k-", alpha=.2)
    plt.plot(barycenter.ravel(), "r-", linewidth=2)
    plt.title(title)


# Seed NumPy's global RNG so any randomness in this run is reproducible.
numpy.random.seed(0)
X_train, y_train, X_test, y_test = CachedDatasets().load_dataset("Trace")
# Only the training series of class 2 are averaged.
X = X_train[y_train == 2]

plt.figure()

plt.subplot(3, 1, 1)
_plot_with_barycenter(X, EuclideanBarycenter().fit(X), "Euclidean barycenter")

plt.subplot(3, 1, 2)
dba = DTWBarycenterAveraging(max_iter=100, verbose=False)
dba_bar = dba.fit(X)
_plot_with_barycenter(X, dba_bar, "DBA")

plt.subplot(3, 1, 3)
sdtw = SoftDTWBarycenter(gamma=1., max_iter=100)
sdtw_bar = sdtw.fit(X)
# Raw string: "\g" is not a valid escape sequence in a normal string literal
# (SyntaxWarning since Python 3.12); the resulting text is unchanged.
_plot_with_barycenter(X, sdtw_bar, r"Soft-DTW barycenter ($\gamma$=1.)")

plt.tight_layout()