Example #1
    def on_epoch_end(self, epoch, logs=None):
        if epoch % 10 == 0:
            # Define model by using input layer and encoder layer that produces latent feature
            feature_model = Model(
                self.model.input,
                self.model.get_layer('encoder_latent').output)

            # Get latent feature given input
            features = feature_model.predict(self.x)

            # Using kmeans to cluster the data
            km = KMeans(n_clusters=len(np.unique(self.y)), n_init=20)  # note: the n_jobs argument was removed in scikit-learn 1.0
            y_pred = km.fit_predict(features)

            # Logging results
            print(
                ' ' * 8 + '|==>  f1-score: %.4f <==|' %
                (genome_acc(self.groups, y_pred, self.y, self.n_clusters)[2]))


# class TqdmCallbackWithLog(tqdm.keras.TqdmCallback):
#     def __init__(self, x, y, groups, n_clusters):
#         self.x = x
#         self.y = y
#         self.groups = groups
#         self.n_clusters = n_clusters
#         super(TqdmCallbackWithLog, self).__init__()

#     def on_epoch_end(self, epoch, logs=None):
#         if epoch % 10 == 0:
#             # Define model by using input layer and encoder layer that produces latent feature
#             feature_model = Model(self.model.input,
#                                     self.model.get_layer('encoder_latent').output)

#             # Get latent feature given input
#             features = feature_model.predict(self.x)

#             # Using kmeans to cluster the data
#             km = KMeans(n_clusters=len(np.unique(self.y)), n_init=20)
#             y_pred = km.fit_predict(features)

#             # Calculate metrics
#             precision, recall, f1 = genome_acc(self.groups, y_pred, self.y, self.n_clusters)

#             logs['precision'] = precision
#             logs['recall'] = recall
#             logs['f1-measure'] = f1

#             tqdm.tqdm.write(f'Precision: {precision}\nRecall: {recall}\nF1-score: {f1}')
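The snippet above is the `on_epoch_end` hook of a custom Keras callback. For context, here is a minimal sketch of how such a callback could be assembled and attached to training; the `ClusteringMetricsCallback` name and the constructor signature are assumptions inferred from the attributes the snippet uses, and `genome_acc` is the project's own metric helper.

# Minimal sketch (assumed wiring; not the original class definition).
import numpy as np
from sklearn.cluster import KMeans
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import Model

class ClusteringMetricsCallback(Callback):
    def __init__(self, x, y, groups, n_clusters):
        super().__init__()
        self.x = x                    # input features
        self.y = y                    # ground-truth labels
        self.groups = groups          # read groups consumed by genome_acc
        self.n_clusters = n_clusters  # number of clusters to report on

    def on_epoch_end(self, epoch, logs=None):
        if epoch % 10 == 0:
            # Re-cluster the latent features and report F1 every 10 epochs
            feature_model = Model(self.model.input,
                                  self.model.get_layer('encoder_latent').output)
            features = feature_model.predict(self.x)
            y_pred = KMeans(n_clusters=self.n_clusters,
                            n_init=20).fit_predict(features)
            # genome_acc is the project's metric helper (precision, recall, f1)
            print('f1-score: %.4f' %
                  genome_acc(self.groups, y_pred, self.y, self.n_clusters)[2])

# Assumed usage: autoencoder.fit(x, x, epochs=100,
#     callbacks=[ClusteringMetricsCallback(x, y, groups, n_clusters)])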
Example #2
    t0 = time.time()
    optim_lr = 0.01
    dec.compile(optimizer=SGD(optim_lr))
    dec.fit(x=seed_kmer_features,
            y=labels,
            grps=groups,
            tol=TOL,
            maxiter=MAX_ITERS,
            batch_size=BATCH_SIZE,
            update_interval=UPDATE_INTERVAL,
            save_dir=save_dir)  # the returned labels are recomputed below with dec.predict

    if verbose:
        print('...')
    latent = dec.encoder.predict(seed_kmer_features)
    y_pred = dec.predict(seed_kmer_features)

    if verbose:
        print('Saving results...')
    store_results(groups,
                  seed_kmer_features,
                  latent,
                  labels,
                  y_pred,
                  n_clusters,
                  dataset_name,
                  save_dir=os.path.join(LOG_DIR, dataset_name))
    if verbose:
        print(f'Finish clustering for dataset {dataset_name}.')
        print('F1-score:', genome_acc(groups, y_pred, labels, n_clusters)[2])
        print('Clustering time: ', (time.time() - t0))
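This driver references module-level constants that are not shown. A plausible configuration is sketched below; the values mirror the `fit` defaults in Example #3 and are illustrative assumptions, not the original settings.

# Assumed configuration (values borrowed from the fit() defaults in Example #3).
TOL = 1e-3              # stop when fewer than 0.1% of labels change between updates
MAX_ITERS = 20000       # maximum number of training iterations
BATCH_SIZE = 256        # mini-batch size (assumption; Example #3 defaults to 4)
UPDATE_INTERVAL = 140   # iterations between target-distribution updates
LOG_DIR = './logs'      # root directory for per-dataset results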
Example #3
    def fit(self,
            x,
            y=None,
            grps=None,
            tol=1e-3,
            update_interval=140,
            maxiter=2e4,
            batch_size=4,
            save_dir='./results/idec'):

        best_model = self.model  # note: this aliases the live model; the best weights are checkpointed to disk below
        best_acc = 0.0
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        print('Update interval', update_interval)
        save_interval = int(x.shape[0] / batch_size) * 5  # 5 epochs
        if save_interval == 0:
            save_interval = 10
        print('Save interval', save_interval)

        # Step 1: initialize cluster centers using k-means
        t1 = time.time()
        print('Initializing cluster centers with k-means.')
        kmeans = KMeans(n_clusters=self.n_clusters, n_init=20)
        y_pred = kmeans.fit_predict(self.encoder.predict(x))
        y_pred_last = np.copy(y_pred)
        self.model.get_layer(name='clustering').set_weights(
            [kmeans.cluster_centers_])

        # Step 2: deep clustering
        # logging file
        import csv
        logfile = open(save_dir + '/idec_log.csv', 'w')
        logwriter = csv.DictWriter(logfile,
                                   fieldnames=[
                                       'iter', 'precision', 'recall',
                                       'f1_score', 'nmi', 'ari', 'L', 'Lc',
                                       'Lr'
                                   ])
        logwriter.writeheader()

        loss = [0, 0, 0]
        index = 0
        for ite in range(int(maxiter)):
            if ite % update_interval == 0:
                q, _ = self.model.predict(x, verbose=0)
                # Update the auxiliary target distribution p
                p = self.target_distribution(q)

                # evaluate the clustering performance
                y_pred = q.argmax(1)
                delta_label = np.sum(y_pred != y_pred_last).astype(
                    np.float32) / y_pred.shape[0]
                y_pred_last = y_pred
                if y is not None:
                    # acc = np.round(cluster_acc(y, y_pred), 5)
                    # nmi = np.round(metrics.normalized_mutual_info_score(y, y_pred), 5)
                    # ari = np.round(metrics.adjusted_rand_score(y, y_pred), 5)
                    prec, recall, f1_score = genome_acc(
                        grps, y_pred, y, self.n_clusters)
                    if f1_score > best_acc:
                        best_acc = f1_score
                        best_model = self.model
                        # save IDEC model checkpoints
                        print('saving model to:',
                              save_dir + '/IDEC_model_' + str(ite) + '.h5')
                        self.model.save_weights(save_dir + '/IDEC_model_' +
                                                str(ite) + '.h5')
                    loss = np.round(loss, 5)
                    logdict = dict(iter=ite,
                                   precision=prec,
                                   recall=recall,
                                   f1_score=f1_score,
                                   L=loss[0],
                                   Lc=loss[1],
                                   Lr=loss[2])
                    logwriter.writerow(logdict)
                    print('Iter %d: precision = %.5f, recall = %.5f, f1_score = %.5f, nmi = --, ari = --'
                          % (ite, prec, recall, f1_score), '; loss =', loss)

                # check stop criterion
                if ite > 0 and delta_label < tol:
                    print('delta_label ', delta_label, '< tol ', tol)
                    print('Reached tolerance threshold. Stopping training.')
                    logfile.close()
                    break

            # train on batch
            if (index + 1) * batch_size > x.shape[0]:
                loss = self.model.train_on_batch(
                    x=x[index * batch_size::],
                    y=[p[index * batch_size::], x[index * batch_size::]])
                index = 0
            else:
                loss = self.model.train_on_batch(
                    x=x[index * batch_size:(index + 1) * batch_size],
                    y=[
                        p[index * batch_size:(index + 1) * batch_size],
                        x[index * batch_size:(index + 1) * batch_size]
                    ])
                index += 1

            # save intermediate model
            # if ite % save_interval == 0:
            #     # save IDEC model checkpoints
            #     print('saving model to:', save_dir + '/IDEC_model_' + str(ite) + '.h5')
            #     self.model.save_weights(save_dir + '/IDEC_model_' + str(ite) + '.h5')


        # save the trained model
        logfile.close()
        print('saving model to:', save_dir + '/IDEC_model_final.h5')
        self.model.save_weights(save_dir + '/IDEC_model_final.h5')

        return y_pred
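`fit` calls `self.target_distribution(q)` to derive the auxiliary target `p` from the soft assignments `q`, but the method body is not shown. Below is a minimal sketch following the standard DEC/IDEC formulation; assuming this implementation matches it, the target sharpens each soft assignment by squaring it and normalizing by cluster frequency.

import numpy as np

def target_distribution(q):
    # Standard DEC target: p_ij is proportional to q_ij^2 / f_j,
    # where f_j = sum_i q_ij; rows are renormalized to sum to 1.
    weight = q ** 2 / q.sum(0)
    return (weight.T / weight.sum(1)).T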
Example #4
def cluster(model: ADEC,
            seeds,
            groups,
            label,
            batch_size,
            epochs=1000,
            save_interval=200,
            save_path='./images'):
    epoch_iter = tqdm.tqdm_notebook(range(epochs))  # notebook progress bar over epochs
    total_batches = seeds.shape[0] // batch_size
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    groups_label = get_group_label(groups, label)

    # Using kmeans result to initialize clusters for model
    latent = model.encoder.predict(seeds)
    kmeans = KMeans(n_clusters=model.n_clusters, n_init=100)
    init_cluster_pred = kmeans.fit_predict(latent)
    _, _, init_f1 = genome_acc(groups, init_cluster_pred, label,
                               model.n_clusters)
    print('Initialized performance: ', init_f1)

    model.cluster.get_layer(name='clustering').set_weights(
        [kmeans.cluster_centers_])
    # Check model cluster performance
    non_wh_cluster_res = model.cluster.predict(seeds).argmax(1)
    _, _, init_f1_1 = genome_acc(groups, non_wh_cluster_res, label,
                                 model.n_clusters)
    print('Model cluster performance: ', init_f1_1)

    stop = False
    last_cluster_pred = np.copy(non_wh_cluster_res)
    for epoch in epoch_iter:
        offset = 0
        losses = []

        if epoch % save_interval == 0 or (epoch == epochs - 1):
            # Save the visualization of latent space
            latent = model.encoder.predict(seeds)
            latent_space_img = visualize_latent_space(
                latent,
                groups_label,
                model.n_clusters,
                is_save=True,
                save_path=f'{save_path}/latent_{epoch}.png')

            # Log the clustering performance
            cluster_res = model.cluster.predict(seeds)
            y_pred = cluster_res.argmax(1)
            _, _, f1 = genome_acc(groups, y_pred, label, model.n_clusters)

            try:
                wandb.log({
                    'latent_space':
                    [wandb.Image(latent_space_img, caption="Latent space")],
                    'cluster_f1':
                    f1
                })
            except Exception:
                # wandb may not be initialized; fall back to stdout
                print('cluster_f1: ', f1)

            delta_label = np.sum(y_pred != last_cluster_pred).astype(
                np.float32) / y_pred.shape[0]
            if epoch > 0 and delta_label < model.tol:
                stop = True
                break

            # Keep the hard labels for the next delta_label comparison
            last_cluster_pred = np.copy(y_pred)

            # Update the target distribution from the soft assignments
            targ_dist = model.target_distribution(cluster_res)

        is_alternate = False
        for batch_iter in range(total_batches):
            # Take the next sequential batch; the final iteration wraps to the start
            imgs = seeds[offset:offset + batch_size, :] if (
                batch_iter < (total_batches - 1)) else seeds[:batch_size, :]
            y_cluster = targ_dist[offset:offset + batch_size, :] if (
                batch_iter <
                (total_batches - 1)) else targ_dist[:batch_size, :]
            offset += batch_size

            if batch_iter < int(2 * total_batches / 3) and total_batches >= 3:
                is_alternate = True
            else:
                is_alternate = False

            loss = model.train_on_batch(imgs, y_cluster, is_alternate)
            losses.append(loss)

        avg_loss = avg_losses(losses)
        try:
            wandb.log({'clustering_losses': avg_loss})
        except Exception:
            pass  # wandb may not be initialized; skip logging

        if stop:
            # Reach stop condition, stop training
            break
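The `avg_losses` helper used above is not defined in this example. A minimal sketch follows, assuming each element of `losses` is the sequence of per-term loss values returned by `train_on_batch`.

import numpy as np

def avg_losses(losses):
    # Element-wise mean of the per-batch losses collected over one epoch.
    return np.mean(np.asarray(losses, dtype=np.float32), axis=0).tolist()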