예제 #1
0
    def _fit(self, X: np.ndarray) -> np.ndarray:
        """Cluster ``X`` at this node, then recurse into each resulting cluster.

        Parameters
        ----------
        X : np.ndarray
            Samples handled by this node.  Rows are selected with boolean
            masks, so the first axis indexes samples.

        Returns
        -------
        np.ndarray
            2-D label array with one row per sample.  Column 0 holds the
            labels returned by ``self._cluster_and_decide``; the following
            columns hold the labels returned by recursive child fits and
            stay at zero for rows whose cluster was not split further.
        """
        node_pred = self._cluster_and_decide(X)
        # Reset this node's children to an empty tuple before (re)building.
        self.children: Tuple["DivisiveCluster"] = cast(
            Tuple["DivisiveCluster"], tuple()
        )

        labels = node_pred.reshape((-1, 1)).copy()
        distinct = np.unique(node_pred)
        if len(distinct) <= 1:
            # Only one label at this level: nothing to recurse into.
            return labels

        for label in distinct:
            members = node_pred == label
            subset = X[members]
            child = DivisiveCluster(
                cluster_method=self.cluster_method,
                max_components=self.max_components,
                min_split=self.min_split,
                max_level=self.max_level,
                cluster_kws=self.cluster_kws,
                delta_criter=self.delta_criter,
            )
            child.parent = self

            n_members = len(subset)
            should_recurse = (
                n_members > self.max_components
                and n_members >= self.min_split
                and self.depth + 1 < self.max_level
            )
            if should_recurse:
                sub_labels = child._fit(subset)
                n_sub_cols = sub_labels.shape[1]
                # Widen ``labels`` with zero columns so that the child's
                # label columns fit right after column 0.
                n_missing = n_sub_cols + 1 - labels.shape[1]
                if n_missing > 0:
                    labels = np.column_stack(
                        (labels, np.zeros((len(X), n_missing), dtype=int))
                    )
                labels[members, 1:n_sub_cols + 1] = sub_labels
                continue

            if self.cluster_method != "gmm":
                continue

            # The cluster was not fitted further: give the child a
            # one-component "GaussianMixture" assembled from the parent
            # model's parameters for this cluster.
            # NOTE(review): the component index is the position of the newest
            # entry in the parent's ``children``.  This presumably relies on
            # ``child.parent = self`` registering the child there, and on the
            # labels being consecutive component indices -- confirm.
            parent_node = child.parent
            cluster_idx = len(parent_node.children) - 1
            parent_model = parent_node.model_

            leaf_model = GaussianMixture()
            leaf_model.weights_ = np.array([1])
            leaf_model.means_ = parent_model.means_[cluster_idx].reshape(1, -1)
            leaf_model.covariance_type = parent_model.covariance_type

            param_names = (
                "covariances_",
                "precisions_",
                "precisions_cholesky_",
            )
            if leaf_model.covariance_type == "tied":
                # "tied" parameters are taken over from the parent unchanged.
                for attr in param_names:
                    setattr(leaf_model, attr, getattr(parent_model, attr))
            else:
                n_features = leaf_model.means_.shape[-1]
                # Position in this tuple == number of feature axes:
                # (1,), (1, n_features), (1, n_features, n_features).
                n_feature_axes = ("spherical", "diag", "full").index(
                    leaf_model.covariance_type
                )
                target_shape = (1,) + (n_features,) * n_feature_axes
                for attr in param_names:
                    component = getattr(parent_model, attr)[cluster_idx]
                    setattr(leaf_model, attr, component.reshape(target_shape))

            child.model_ = leaf_model

        return labels
예제 #2
0
    x1 = x1.dot(m)
    x2 = multivariate_normal(mean=(-1, 10), cov=cov1, size=N2)
    x = np.vstack((x1, x2))
    y = np.array([True] * 500 + [False] * 300)
    return x, y


# Script entry point: fit a 2-component GaussianMixture once per covariance
# type, record an error rate and the BIC for each, then draw both as bars.
if __name__ == '__main__':
    x, y = get_data()
    # Covariance types to compare; err/bic get one slot per type and are
    # filled inside the loop (np.empty leaves them uninitialised until then).
    types = ('spherical', 'diag', 'tied', 'full')
    err = np.empty(len(types))
    bic = np.empty(len(types))

    # NOTE(review): the loop variable `type` shadows the builtin of the same
    # name for the rest of this block.
    for i, type in enumerate(types):
        model = GaussianMixture(n_components=2, random_state=0)
        # The covariance type is set as an attribute after construction
        # rather than passed to the constructor.
        model.covariance_type = type
        model.fit(x)
        y_hat = model.predict(x)
        accuracy = accuracy_score(y_hat.ravel(), y.ravel())
        # Store the smaller of `accuracy` and `1 - accuracy`; presumably
        # because the predicted component ids may be swapped relative to the
        # boolean labels in y -- confirm.
        if accuracy > 0.5:
            err[i] = 1 - accuracy
        else:
            err[i] = accuracy
        bic[i] = model.bic(x)
    # '错误率' is Chinese for "error rate" (runtime string, left unchanged).
    print('错误率:', err.ravel())
    print('BIC:', bic.ravel())
    # One bar position per covariance type (4 == len(types)).
    xpos = np.arange(4)
    plt.figure(facecolor='w')
    ax = plt.axes()
    # Error bars are shifted left by one bar width; the BIC bars are drawn
    # on a twin y-axis sharing the same x positions.
    b1 = ax.bar(xpos - 0.3, err, width=0.3, color='#77E0A0')
    b2 = ax.twinx().bar(xpos, bic, width=0.3, color='#FF8080')