def test(self, test_data, fit_model):
        """
        Compare the test-set clustering solution A(X'), produced by
        :class:`reval.relative_validation.RelativeValidation.clust_method`,
        with the (permuted) labels predicted by the classifier fitted on the
        training set, f(X') (see
        :class:`reval.relative_validation.RelativeValidation.class_method`).

        :param test_data: test dataset.
        :type test_data: ndarray, (n_samples, n_features)
        :param fit_model: fitted supervised model.
        :type fit_model: class
        :return: misclassification error and permuted clustering labels,
            or ``None`` when no clusters are found on the test set.
        :rtype: (float, ndarray (n_samples,)) or None
        """
        # A_k(X'): clustering labels for the test set
        cluster_labels = self.clust_method.fit_predict(test_data)
        # Bail out when every label is negative, i.e. the clustering found
        # no clusters at all (negative labels conventionally mark noise).
        if not any(lab >= 0 for lab in cluster_labels):
            logging.info(
                f"No clusters found during testing with {self.clust_method}")
            return None
        # f(X'): classifier predictions on the same test set
        predicted_labels = fit_model.predict(test_data)
        # Align clustering labels to classifier labels with the Hungarian
        # (Kuhn-Munkres) algorithm before scoring the disagreement.
        aligned_labels = kuhn_munkres_algorithm(predicted_labels,
                                                cluster_labels)
        return zero_one_loss(predicted_labels, aligned_labels), aligned_labels
Пример #2
0
    def test_revaltraining(self):
        """Sanity checks for the training step.

        Verifies that ``train`` reaches zero misclassification with
        Kuhn-Munkres-permuted labels on a perfectly separable dataset, and
        returns ``None`` when no clusters can be found.
        """
        # 40 samples alternating between two constant 10-feature rows:
        # trivially separable into two clusters.
        data_tr = np.array([[0] * 10, [1] * 10] * 20)
        misclass, model, labels = self.reval_cls.train(data_tr)
        # Expect zero error and labels equal to the permuted ground truth.
        self.assertSequenceEqual([misclass] + [labels.tolist()], [
            0.0,
            kuhn_munkres_algorithm(labels, np.array([1, 0] * 20)).tolist()
        ])
        # The returned model must be the same estimator type as self.s.
        self.assertEqual(type(model), type(self.s))

        # A single Gaussian blob: clustering should find no valid partition,
        # so train() is expected to return None.
        # NOTE(review): assumes self.reval_cls_new is configured with a
        # clustering method that can report "no clusters" -- confirm in setUp.
        noise = make_blobs(100, 1, centers=1, random_state=42)[0]
        self.assertEqual(self.reval_cls_new.train(noise), None)
Пример #3
0
def example_2():
    """Run relative clustering validation on the sklearn digits dataset.

    Splits the data, embeds it with UMAP, selects the best number of
    KMeans clusters via cross-validated stability, evaluates on the test
    set, and plots true vs. predicted labels.
    """
    digits_dataset = load_digits()

    digits_data = digits_dataset['data']
    digits_target = digits_dataset['target']

    # Stratified 60/40 split keeps the digit class balance in both sets.
    X_tr, X_ts, y_tr, y_ts = train_test_split(digits_data,
                                              digits_target,
                                              test_size=0.40,
                                              random_state=42,
                                              stratify=digits_target)

    # 2-D UMAP embedding; fit on train only, then project the test set.
    transform = UMAP(n_components=2,
                     random_state=42,
                     n_neighbors=30,
                     min_dist=0.0)
    X_tr = transform.fit_transform(X_tr)
    X_ts = transform.transform(X_ts)

    s = KNeighborsClassifier(n_neighbors=30)
    c = KMeans()

    # Search cluster counts in [2, 15] with 5-fold CV and 100 random
    # permutations for stability normalization.
    reval = FindBestClustCV(s=s, c=c, nfold=5, nclust_range=[2, 15], nrand=100)

    metrics, nclustbest, _ = reval.best_nclust(X_tr,
                                               iter_cv=10,
                                               strat_vect=y_tr)

    plot_metrics(metrics, title='Reval performance digits dataset')

    # Fit on the full training set and evaluate on the held-out test set.
    out = reval.evaluate(X_tr, X_ts, nclust=nclustbest)
    # Permute cluster labels to best match the true digit labels.
    perm_lab = kuhn_munkres_algorithm(y_ts, out.test_cllab)

    print(f"Best number of clusters: {nclustbest}")
    print(f"Test set prediction ACC: " f"{1 - zero_one_loss(y_ts, perm_lab)}")
    print(f'AMI (true labels vs predicted labels) = '
          f'{adjusted_mutual_info_score(y_ts, out.test_cllab)}')
    print(f"Validation set normalized stability (misclassification):"
          f"{metrics['val'][nclustbest]}")
    print(f'Test set ACC = {out.test_acc} '
          f'(true labels vs predicted labels)')

    # True labels on the UMAP plane.
    plt.figure(figsize=(6, 4))
    plt.scatter(X_ts[:, 0], X_ts[:, 1], c=y_ts, cmap='rainbow_r')
    plt.title("Test set true labels (digits dataset)")
    plt.show()

    # Clustering labels (permuted) on the same plane for visual comparison.
    plt.figure(figsize=(6, 4))
    plt.scatter(X_ts[:, 0], X_ts[:, 1], c=perm_lab, cmap='rainbow_r')
    plt.title("Test set clustering labels (digits dataset)")
    plt.show()
Пример #4
0
def example_1():
    """Run relative clustering validation on a synthetic blobs dataset.

    Generates 5 Gaussian blobs, selects the best number of KMeans clusters
    via cross-validated stability, evaluates on a held-out split, and plots
    true vs. predicted labels.
    """
    # 1000 2-D samples drawn from 5 blobs (third positional arg = centers).
    data = make_blobs(1000, 2, 5, center_box=(-20, 20), random_state=42)
    plt.figure(figsize=(6, 4))
    plt.scatter(data[0][:, 0], data[0][:, 1], c=data[1], cmap='rainbow_r')
    plt.title("Blobs dataset (N=1000)")
    plt.show()

    # Stratified 70/30 split preserves blob proportions in both sets.
    X_tr, X_ts, y_tr, y_ts = train_test_split(data[0],
                                              data[1],
                                              test_size=0.30,
                                              random_state=42,
                                              stratify=data[1])

    classifier = KNeighborsClassifier(n_neighbors=5)
    clustering = KMeans()

    # Search cluster counts in [2, 7] with 10-fold CV and 100 random
    # permutations for stability normalization.
    findbestclust = FindBestClustCV(nfold=10,
                                    nclust_range=[2, 7],
                                    s=classifier,
                                    c=clustering,
                                    nrand=100)
    metrics, nbest, _ = findbestclust.best_nclust(X_tr,
                                                  iter_cv=10,
                                                  strat_vect=y_tr)
    out = findbestclust.evaluate(X_tr, X_ts, nbest)

    # Permute cluster labels to best match the true labels before scoring.
    perm_lab = kuhn_munkres_algorithm(y_ts, out.test_cllab)

    print(f"Best number of clusters: {nbest}")
    print(f"Test set prediction ACC: " f"{1 - zero_one_loss(y_ts, perm_lab)}")
    print(f'AMI (true labels vs predicted labels) = '
          f'{adjusted_mutual_info_score(y_ts, out.test_cllab)}')
    print(f"Validation set normalized stability (misclassification):"
          f"{metrics['val'][nbest]}")
    print(f'Test set ACC = {out.test_acc} '
          f'(true labels vs predicted labels)')

    plot_metrics(metrics,
                 title="Reval performance blobs dataset",
                 legend_loc=2)

    # True labels on the test split.
    plt.figure(figsize=(6, 4))
    plt.scatter(X_ts[:, 0], X_ts[:, 1], c=y_ts, cmap='rainbow_r')
    plt.title("Test set true labels (blobs dataset)")
    plt.show()

    # Clustering labels (permuted) for visual comparison.
    plt.figure(figsize=(6, 4))
    plt.scatter(X_ts[:, 0], X_ts[:, 1], c=perm_lab, cmap='rainbow_r')
    plt.title("Test set clustering labels (blobs dataset)")
    plt.show()
Пример #5
0
    def _rescale_score_(self, xtr, xts, randlabtr, labts):
        """
        Private method that computes the misclassification error obtained
        when the classifier is trained on random labels and used to predict
        the test set (baseline for stability normalization).

        :param xtr: training dataset
        :type xtr: ndarray, (n_samples, n_features)
        :param xts: test dataset
        :type xts: ndarray, (n_samples, n_features)
        :param randlabtr: random labels
        :type randlabtr: ndarray, (n_samples,)
        :param labts: test set labels
        :type labts: ndarray, (n_samples,)
        :return: misclassification error
        :rtype: float
        """
        # Train the supervised model against the random label assignment.
        self.class_method.fit(xtr, randlabtr)
        predictions = self.class_method.predict(xts)
        # Align the test labels to the predictions (Hungarian algorithm),
        # then score the remaining disagreement.
        aligned = kuhn_munkres_algorithm(predictions, labts)
        return zero_one_loss(predictions, aligned)
Пример #6
0
def test_kuhn_munkres_exceptions(true_lab, pred_lab, exception):
    """Parametrized test: invalid label inputs must raise ``exception``."""
    with pytest.raises(exception):
        kuhn_munkres_algorithm(true_lab, pred_lab)
Пример #7
0
def test_kuhn_munkres_algorithm(true_lab, pred_lab, expected_type,
                                expected_output):
    """Parametrized test: permuted labels have the expected dtype and values."""
    permuted = kuhn_munkres_algorithm(true_lab, pred_lab)
    # dtype check first, then element-wise equality with the expectation
    assert permuted.dtype == expected_type
    np.testing.assert_array_equal(permuted, expected_output)
Пример #8
0
 def test_khun_munkres_algorithm(self):
     """Fully swapped binary labels are permuted back to the ground truth."""
     true_lab = np.array([1, 1, 1, 0, 0, 0])
     pred_lab = np.array([0, 0, 0, 1, 1, 1])
     new_lab = kuhn_munkres_algorithm(true_lab, pred_lab)
     self.assertSequenceEqual(new_lab.tolist(), [1, 1, 1, 0, 0, 0])
Пример #9
0
                                          test_size=0.30, random_state=42,
                                          stratify=data2[1])

# Model selection: best number of clusters in [2, 7) via 10-fold CV with
# 100 random-label repetitions for stability normalization.
# NOTE(review): X_tr/X_ts/y_ts, classifier and clustering come from code
# above this fragment -- confirm against the full script.
findbestclust = FindBestClustCV(nfold=10, nclust_range=list(range(2, 7)),
                                s=classifier, c=clustering, nrand=100)
metrics, nbest = findbestclust.best_nclust(data=X_tr, strat_vect=y_tr)
out = findbestclust.evaluate(X_tr, X_ts, nbest)

plot_metrics(metrics, title="Reval performance for synthetic dataset with 20 features")

# Visualize predicted test labels on the first two features.
plt.scatter(X_ts[:, 0], X_ts[:, 1],
            c=out.test_cllab, cmap='rainbow_r')
plt.title("Predicted labels for 20-feature dataset")

# External validation: AMI is permutation-invariant; accuracy needs the
# Kuhn-Munkres relabeling first.
print(f'AMI test set = {adjusted_mutual_info_score(y_ts, out.test_cllab)}')
relabeling = kuhn_munkres_algorithm(y_ts, out.test_cllab)
print(f'ACC test set = {1 - zero_one_loss(y_ts, relabeling)}')

# Set seed for reproducible examples
np.random.seed(42)

# We generate three random samples from normal distributions
# (100 + 50 + 50 two-dimensional points with different means/spreads).
data1 = np.random.normal(-5, size=(100, 2))
data2 = np.random.normal(12, 2.5, size=(50, 2))
data3 = np.random.normal(6, 2.5, size=(50, 2))
data = np.append(data1, data2, axis=0)
data = np.append(data, data3, axis=0)

# Ground-truth labels matching the concatenation order above.
label = [0] * 100 + [1] * 50 + [2] * 50

plt.scatter(data[:, 0], data[:, 1],
                                nclust_range=list(range(2, 12)),
                                s=classifier,
                                c=clustering,
                                nrand=10,
                                n_jobs=1)

# Best number of clusters on the MNIST training embedding.
# NOTE(review): mnist_tr/mnist_ts/label_tr/label_ts and findbestclust are
# defined earlier in the original script -- confirm against the full source.
metrics, nbest = findbestclust.best_nclust(mnist_tr,
                                           iter_cv=10,
                                           strat_vect=label_tr)
out = findbestclust.evaluate(mnist_tr, mnist_ts, nbest)

plot_metrics(
    metrics,
    title="Relative clustering validation performance on MNIST dataset")

# Permute cluster labels to best match the true digit labels.
perm_lab = kuhn_munkres_algorithm(label_ts.astype(int), out.test_cllab)

# Small marker size (s=0.1) because the MNIST test set is large.
plt.scatter(mnist_ts[:, 0],
            mnist_ts[:, 1],
            c=perm_lab,
            s=0.1,
            cmap='rainbow_r')
plt.title("Predicted labels for MNIST test set")

print(f"Best number of clusters: {nbest}")
print(f"Test set external ACC: "
      f"{1 - zero_one_loss(label_ts.astype(int), perm_lab)}")
print(f'AMI = {adjusted_mutual_info_score(label_ts.astype(int), perm_lab)}')
print(
    f"Validation set normalized stability (misclassification): {metrics['val'][nbest]}"
)
Пример #11
0
def example_3():
    """Benchmark reval over UCI datasets with several classifier/clustering
    pairs.

    For every UCI dataset: UMAP-embed a stratified 60/40 split, then for each
    (classifier, clustering) combination select the best number of clusters
    and log validation stability and test-set scores.
    """
    # Classifiers
    knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean')
    rf = RandomForestClassifier(n_estimators=100, random_state=42)
    svm = SVC(C=1, random_state=42)
    logreg = LogisticRegression(solver='liblinear', random_state=42)

    classifiers = [knn, logreg, svm, rf]

    # Clustering
    hc = AgglomerativeClustering()
    km = KMeans(random_state=42)
    sc = SpectralClustering(random_state=42)

    clustering = [hc, km, sc]

    # scaler = StandardScaler()
    transform = UMAP(n_neighbors=30, min_dist=0.0, random_state=42)

    # Import benchmark datasets
    uci_data = build_ucidatasets()
    # Run ensemble learning algorithm
    for data, name in zip(uci_data, uci_data._fields):
        nclass = len(np.unique(data['target']))
        logging.info(f"Processing dataset {name}")
        logging.info(f"Number of classes: {nclass}\n")
        X_tr, X_ts, y_tr, y_ts = train_test_split(data['data'],
                                                  data['target'],
                                                  test_size=0.40,
                                                  random_state=42,
                                                  stratify=data['target'])
        # Fit UMAP on the training split only, then project the test split.
        X_tr = transform.fit_transform(X_tr)
        X_ts = transform.transform(X_ts)
        for s in classifiers:
            # NOTE(review): isinstance(s, SVC) would be the idiomatic check;
            # gamma = 1/n_samples also differs from sklearn's 'scale'/'auto'
            # (which use n_features) -- confirm this is intentional.
            if type(s) == type(svm):
                svm.gamma = 1 / data['data'].shape[0]
            for c in clustering:
                logging.info(
                    f"Clustering algorithm: {c} -- Classification algorithm {s}"
                )
                # Search up to (number of true classes + 3) clusters.
                reval = FindBestClustCV(s=s,
                                        c=c,
                                        nfold=5,
                                        nclust_range=[2, nclass + 3],
                                        nrand=100)
                metrics, nclustbest, _ = reval.best_nclust(X_tr,
                                                           strat_vect=y_tr)

                out = reval.evaluate(X_tr, X_ts, nclust=nclustbest)
                # Align cluster labels to true labels for external scoring.
                perm_lab = kuhn_munkres_algorithm(y_ts, out.test_cllab)

                logging.info(f"Best number of clusters: {nclustbest}")
                logging.info(f"Test set prediction ACC: "
                             f"{1 - zero_one_loss(y_ts, perm_lab)}")
                logging.info(
                    f'AMI (true labels vs predicted labels) = '
                    f'{adjusted_mutual_info_score(y_ts, out.test_cllab)}')
                logging.info(
                    f"Validation set normalized stability (misclassification):"
                    f"{metrics['val'][nclustbest]}")
                logging.info(f'Test set ACC = {out.test_acc} '
                             f'(true labels vs predicted labels)\n')
        # Visual separator between datasets in the log.
        logging.info('*' * 100)
        logging.info('\n\n')
Пример #12
0
# Stratified 70/30 split of the blobs data generated above.
# NOTE(review): data, classifier and clustering are defined earlier in the
# original script -- confirm against the full source.
X_tr, X_ts, y_tr, y_ts = train_test_split(data[0],
                                          data[1],
                                          test_size=0.30,
                                          random_state=42,
                                          stratify=data[1])

# Quick 2-fold CV run with cluster counts in [2, 7) and 10 random
# repetitions for stability normalization.
findbestclust = FindBestClustCV(nfold=2,
                                nclust_range=list(range(2, 7)),
                                s=classifier,
                                c=clustering,
                                nrand=10)
metrics, nbest = findbestclust.best_nclust(X_tr, iter_cv=10, strat_vect=y_tr)
out = findbestclust.evaluate(X_tr, X_ts, nbest)
plot_metrics(metrics, title="Reval performance")

# Permute cluster labels to best match the true labels before scoring.
perm_lab = kuhn_munkres_algorithm(y_ts, out.test_cllab)

print(f"Best number of clusters: {nbest}")
print(f"Test set external ACC: " f"{1 - zero_one_loss(y_ts, perm_lab)}")
print(f'AMI = {adjusted_mutual_info_score(y_ts, out.test_cllab)}')
print(
    f"Validation set normalized stability (misclassification): {metrics['val'][nbest]}"
)
print(f'Test set ACC = {out.test_acc}')

# NOTE(review): plot_metrics was already called above with the same
# arguments -- likely a copy-paste duplicate.
plot_metrics(metrics, title="Reval performance")

# NOTE(review): no plt.figure()/plt.show() between these two scatters, so
# the second draws over the first -- confirm this is intended.
plt.scatter(X_ts[:, 0], X_ts[:, 1], c=y_ts, cmap='rainbow_r')
plt.title("True labels for test set")

plt.scatter(X_ts[:, 0], X_ts[:, 1], c=perm_lab, cmap='rainbow_r')
def example1():
    """Compare reval against internal validation indices on a blobs dataset.

    Runs relative clustering validation (stability-based model selection),
    then silhouette- and Davies-Bouldin-based selection on the same data,
    logging AMI/accuracy for each approach and plotting true vs. predicted
    test labels.
    """
    # Generate dataset
    data = make_blobs(1000, 2, centers=5,
                      center_box=(-20, 20),
                      random_state=42)

    # Visualize dataset
    plt.figure(figsize=(6, 4))
    for i in range(5):
        plt.scatter(data[0][data[1] == i][:, 0],
                    data[0][data[1] == i][:, 1],
                    label=i, cmap='tab20')
    plt.title("Blobs dataset")
    # plt.savefig('./blobs.png', format='png')
    plt.show()

    # Create training and test sets
    X_tr, X_ts, y_tr, y_ts = train_test_split(data[0],
                                              data[1],
                                              test_size=0.30,
                                              random_state=42,
                                              stratify=data[1])

    # Initialize clustering and classifier
    classifier = KNeighborsClassifier(n_neighbors=15)
    clustering = KMeans()

    # Run relatve validation (repeated CV and testing)
    findbestclust = FindBestClustCV(nfold=2,
                                    nclust_range=list(range(2, 7, 1)),
                                    s=classifier,
                                    c=clustering,
                                    nrand=10,
                                    n_jobs=N_JOBS)
    metrics, nbest = findbestclust.best_nclust(X_tr, iter_cv=10, strat_vect=y_tr)
    out = findbestclust.evaluate(X_tr, X_ts, nclust=nbest)

    # Plot CV metrics
    plot_metrics(metrics, prob_lines=False)
    logging.info(f"Validation stability: {metrics['val'][nbest]}")
    # Align cluster labels to true labels for external scoring.
    perm_lab = kuhn_munkres_algorithm(y_ts, out.test_cllab)

    logging.info(f"Best number of clusters: {nbest}")
    logging.info(f'AMI (true labels vs predicted labels) for test set = '
                 f'{adjusted_mutual_info_score(y_ts, out.test_cllab)}')
    logging.info('\n\n')

    # Compute metrics
    logging.info("Metrics from true label comparisons on test set:")
    class_scores = compute_metrics(y_ts, perm_lab, perm=False)
    # Only F1 and MCC are logged here; other scores are discarded.
    for k, val in class_scores.items():
        if k in ['F1', 'MCC']:
            logging.info(f"{k}, {val}")
    logging.info("\n\n")

    # Internal measures
    # SILHOUETTE
    logging.info("Silhouette score based selection")
    # Select the best cluster count independently on train and test sets
    # (silhouette is maximized).
    sil_score_tr, sil_best_tr, sil_label_tr = select_best(X_tr, clustering, silhouette_score,
                                                          select='max',
                                                          nclust_range=list(range(2, 7, 1)))
    sil_score_ts, sil_best_ts, sil_label_ts = select_best(X_ts, clustering, silhouette_score,
                                                          select='max',
                                                          nclust_range=list(range(2, 7, 1)))

    # Evaluate on the test set the cluster count chosen on the training set.
    sil_eval = evaluate_best(X_ts, clustering, silhouette_score, sil_best_tr)

    logging.info(f"Best number of clusters (and scores) for tr/ts independent runs: "
                 f"{sil_best_tr}({sil_score_tr})/{sil_best_ts}({sil_score_ts})")
    logging.info(f"Test set evaluation {sil_eval}")
    logging.info(f'AMI (true labels vs clustering labels) training = '
                 f'{adjusted_mutual_info_score(y_tr, kuhn_munkres_algorithm(y_tr, sil_label_tr))}')
    logging.info(f'AMI (true labels vs clustering labels) test = '
                 f'{adjusted_mutual_info_score(y_ts, kuhn_munkres_algorithm(y_ts, sil_label_ts))}')
    logging.info('\n\n')

    # DAVIES-BOULDIN
    logging.info("Davies-Bouldin score based selection")
    # Davies-Bouldin is minimized, hence select='min'.
    db_score_tr, db_best_tr, db_label_tr = select_best(X_tr, clustering, davies_bouldin_score,
                                                       select='min', nclust_range=list(range(2, 7, 1)))
    db_score_ts, db_best_ts, db_label_ts = select_best(X_ts, clustering, davies_bouldin_score,
                                                       select='min', nclust_range=list(range(2, 7, 1)))

    db_eval = evaluate_best(X_ts, clustering, davies_bouldin_score, db_best_tr)

    logging.info(f"Best number of clusters (and scores) for tr/ts independent runs: "
                 f"{db_best_tr}({db_score_tr})/{db_best_ts}({db_score_ts})")
    logging.info(f"Test set evaluation {db_eval}")
    logging.info(f'AMI (true labels vs clustering labels) training = '
                 f'{adjusted_mutual_info_score(y_tr, kuhn_munkres_algorithm(y_tr, db_label_tr))}')
    logging.info(f'AMI (true labels vs clustering labels) test = '
                 f'{adjusted_mutual_info_score(y_ts, kuhn_munkres_algorithm(y_ts, db_label_ts))}')
    logging.info('\n\n')

    # Plot true vs predicted labels for test sets
    plt.figure(figsize=(6, 4))
    for i in range(5):
        plt.scatter(X_ts[y_ts == i][:, 0],
                    X_ts[y_ts == i][:, 1],
                    label=str(i),
                    cmap='tab20')
    plt.legend(loc=3)
    plt.title("Test set true labels")
    # plt.savefig('./blobs_true.png', format='png')
    plt.show()

    plt.figure(figsize=(6, 4))
    for i in range(5):
        plt.scatter(X_ts[perm_lab == i][:, 0],
                    X_ts[perm_lab == i][:, 1],
                    label=str(i),
                    cmap='tab20')
    plt.legend(loc=3)
    plt.title("Test set clustering labels")
    # plt.savefig('./blobs_clustering.png', format='png')
    plt.show()
def example2():
    """Run reval with HDBSCAN on UMAP-embedded MNIST and compare against
    silhouette- and Davies-Bouldin-based selection.

    Uses the canonical 60k/10k MNIST split, evaluates clustering stability,
    logs external metrics, and plots the various label assignments.
    """
    mnist = fetch_openml('mnist_784', version=1)
    mnist.target = mnist.target.astype(int)

    # Canonical MNIST split: first 60k train, last 10k test.
    X_tr, y_tr = mnist['data'][:60000], mnist.target[:60000]
    X_ts, y_ts = mnist['data'][60000::], mnist.target[60000::]
    # 2-D UMAP embedding; fit on train only, then project the test set.
    transform = UMAP(n_components=2,
                     random_state=42,
                     n_neighbors=30,
                     min_dist=0.0)
    X_tr = transform.fit_transform(X_tr)
    X_ts = transform.transform(X_ts)

    s = KNeighborsClassifier(n_neighbors=30)
    # HDBSCAN chooses the number of clusters itself (no nclust_range below).
    c = hdbscan.HDBSCAN(min_samples=10,
                        min_cluster_size=200)

    reval = FindBestClustCV(s=s,
                            c=c,
                            nfold=2,
                            nrand=10,
                            n_jobs=N_JOBS)

    metrics, nclustbest, tr_lab = reval.best_nclust(X_tr, iter_cv=10, strat_vect=y_tr)

    plot_metrics(metrics)

    # Reuse the training labels found during CV (tr_lab) for evaluation.
    out = reval.evaluate(X_tr, X_ts, nclust=nclustbest, tr_lab=tr_lab)
    perm_lab = kuhn_munkres_algorithm(y_ts, out.test_cllab)
    logging.info(f"Validation stability: {metrics['val'][nclustbest]}")

    logging.info(f"Best number of clusters during CV: {nclustbest}")
    # Negative labels are HDBSCAN noise points; count only real clusters.
    logging.info(f"Best number of clusters on test set: "
                 f"{len([lab for lab in np.unique(out.test_cllab) if lab >= 0])}")
    logging.info(f'AMI (true labels vs predicted labels) = '
                 f'{adjusted_mutual_info_score(y_ts, out.test_cllab)}')
    logging.info('\n\n')

    logging.info("Metrics from true label comparisons on test set:")
    class_scores = compute_metrics(y_ts, perm_lab)
    for k, val in class_scores.items():
        logging.info(f'{k}, {val}')
    logging.info('\n\n')

    # Visualization
    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(X_tr[:, 0],
                         X_tr[:, 1],
                         c=y_tr, cmap='rainbow_r',
                         s=0.1)
    legend = ax.legend(*scatter.legend_elements())
    ax.add_artist(legend)
    plt.title("Train set true labels (digits dataset)")
    plt.show()

    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(X_tr[:, 0],
                         X_tr[:, 1],
                         c=kuhn_munkres_algorithm(y_tr, tr_lab),
                         cmap='tab20',
                         s=0.1)
    legend = ax.legend(*scatter.legend_elements())
    ax.add_artist(legend)
    plt.title("Train set predicted labels (digits dataset)")
    plt.show()

    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(X_ts[:, 0],
                         X_ts[:, 1],
                         c=y_ts, cmap='tab20',
                         s=0.1)
    legend = ax.legend(*scatter.legend_elements())
    ax.add_artist(legend)
    plt.title("Test set true labels (digits dataset)")
    plt.show()

    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(X_ts[:, 0],
                         X_ts[:, 1],
                         s=0.1,
                         c=perm_lab, cmap='tab20')
    legend = ax.legend(*scatter.legend_elements())
    ax.add_artist(legend)
    plt.title("Test set clustering labels (digits dataset)")
    plt.show()

    # Internal measures
    # SILHOUETTE
    logging.info("Silhouette score based selection")
    # Independent selection on train and test sets (silhouette maximized).
    sil_score_tr, sil_best_tr, sil_label_tr = select_best(X_tr, c, silhouette_score, select='max')
    sil_score_ts, sil_best_ts, sil_label_ts = select_best(X_ts, c, silhouette_score, select='max')
    logging.info(
        f"Best number of clusters (and scores) for tr/ts independent runs: "
        f"{sil_best_tr}({sil_score_tr})/{sil_best_ts}({sil_score_ts})")
    logging.info(f'AMI (true labels vs clustering labels) training = '
                 f'{adjusted_mutual_info_score(y_tr, kuhn_munkres_algorithm(y_tr, sil_label_tr))}')
    logging.info(f'AMI (true labels vs clustering labels) test = '
                 f'{adjusted_mutual_info_score(y_ts, kuhn_munkres_algorithm(y_ts, sil_label_ts))}')
    logging.info('\n\n')

    # DAVIES-BOULDIN
    logging.info("Davies-Bouldin score based selection")
    # Davies-Bouldin is minimized, hence select='min'.
    db_score_tr, db_best_tr, db_label_tr = select_best(X_tr, c, davies_bouldin_score,
                                                       select='min')
    db_score_ts, db_best_ts, db_label_ts = select_best(X_ts, c, davies_bouldin_score,
                                                       select='min')

    logging.info(
        f"Best number of clusters (and scores) for tr/ts independent runs: "
        f"{db_best_tr}({db_score_tr})/{db_best_ts}({db_score_ts})")
    logging.info(f'AMI (true labels vs clustering labels) training = '
                 f'{adjusted_mutual_info_score(y_tr, kuhn_munkres_algorithm(y_tr, db_label_tr))}')
    logging.info(f'AMI (true labels vs clustering labels) test = '
                 f'{adjusted_mutual_info_score(y_ts, kuhn_munkres_algorithm(y_ts, db_label_ts))}')
    logging.info('\n\n')

    # Visualization
    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(X_tr[:, 0],
                         X_tr[:, 1],
                         c=sil_label_tr, cmap='tab20',
                         s=0.1)
    legend = ax.legend(*scatter.legend_elements())
    ax.add_artist(legend)
    plt.title("Train set silhouette labels (digits dataset)")
    plt.show()

    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(X_ts[:, 0],
                         X_ts[:, 1],
                         c=sil_label_ts, cmap='tab20',
                         s=0.1)
    legend = ax.legend(*scatter.legend_elements())
    ax.add_artist(legend)
    # NOTE(review): the next two lines duplicate the legend creation just
    # above -- likely a copy-paste artifact; harmless but redundant.
    legend = ax.legend(*scatter.legend_elements())
    ax.add_artist(legend)
    plt.title("Test set silhouette labels (digits dataset)")
    plt.show()

    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(X_tr[:, 0],
                         X_tr[:, 1],
                         c=db_label_tr, cmap='tab20',
                         s=0.1)
    legend = ax.legend(*scatter.legend_elements())
    ax.add_artist(legend)
    plt.title("Train set Davies-Bouldin labels (digits dataset)")
    plt.show()

    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(X_ts[:, 0],
                         X_ts[:, 1],
                         s=0.1,
                         c=db_label_ts, cmap='tab20')
    legend = ax.legend(*scatter.legend_elements())
    ax.add_artist(legend)
    plt.title("Test set Davies-Bouldin labels (digits dataset)")
    plt.show()