Beispiel #1
0
def predict(spectral_net, x, y, n_clusters):
    """Embed ``x`` with a trained SpectralNet and derive label predictions.

    Args:
        spectral_net: object exposing ``predict(x)``; its output is used as
            the embedding to cluster.
        x: inputs handed to ``spectral_net.predict``.
        y: reference labels forwarded to ``get_y_preds``.
        n_clusters: number of clusters for k-means and for ``get_y_preds``.

    Returns:
        Tuple ``(embedding, y_pred)`` where ``embedding`` is the SpectralNet
        output and ``y_pred`` is the first value returned by ``get_y_preds``.
    """
    embedding = spectral_net.predict(x)

    # k-means (10 initialisations) on the embedding; the fitted clustering
    # object returned alongside the assignments is not needed here
    cluster_ids, _ = get_cluster_sols(
        embedding,
        ClusterClass=KMeans,
        n_clusters=n_clusters,
        init_args={'n_init': 10},
    )

    # presumably maps cluster ids onto the label space of y -- see get_y_preds
    y_pred, _ = get_y_preds(cluster_ids, y, n_clusters)

    return embedding, y_pred
Beispiel #2
0
    def fit(self, x, y):
        """Fit the clustering object (if not yet created) and map clusters to labels.

        When ``self.cluster_obj`` is ``None`` a new ``self.ClusterClass`` is
        instantiated with ``self.n_clusters`` and ``self.init_args`` and
        fitted on ``x``, retrying up to 10 times if fitting raises.  The
        cluster assignments for ``x`` are then compared against ``y`` to
        populate ``self.kmeans_to_true_cluster_labels``.

        Args:
            x: samples to cluster.
            y: reference labels forwarded to ``get_y_preds``.

        Returns:
            ``self`` on success.  If all 10 fitting attempts fail, a zero
            vector of length ``len(x)`` is returned instead (kept for
            backward compatibility with existing callers).
        """
        if self.cluster_obj is None:
            self.cluster_obj = self.ClusterClass(self.n_clusters,
                                                 **self.init_args)
            for _ in range(10):
                try:
                    self.cluster_obj.fit(x)
                    break
                except Exception:
                    # Catch Exception rather than using a bare ``except`` so
                    # that KeyboardInterrupt / SystemExit still propagate and
                    # the retry loop can be interrupted.
                    print("Unexpected error:", sys.exc_info())
            else:
                # no attempt succeeded
                return np.zeros((len(x), ))

        cluster_assignments = self.cluster_obj.predict(x)
        # only the confusion matrix is needed from the first call; the
        # cluster-to-label mapping is taken from get_y_preds_from_cm below
        _, confusion_matrix, _ = get_y_preds(cluster_assignments, y,
                                             self.n_clusters)
        _, self.kmeans_to_true_cluster_labels = get_y_preds_from_cm(
            cluster_assignments, self.n_clusters, confusion_matrix)
        return self
def run_net(data, params):
    """Train SpectralNet (optionally on a siamese affinity) and evaluate it.

    Args:
        data: nested dict.  ``data['spectral']`` must hold the entries
            ``'train_and_test'``, ``'train_unlabeled_and_labeled'`` and
            ``'val_unlabeled_and_labeled'``.  ``data['siamese']['train_and_test']``
            is read only when ``'siamese'`` occurs in ``params['affinity']``.
        params: hyperparameter dict.  Keys read here: ``affinity``,
            ``n_clusters``, ``batch_size``, ``batch_size_orthonorm``
            (optional), ``arch``, ``siam_reg``/``spec_reg`` (optional),
            ``siam_lr``, ``siam_drop``, ``siam_patience``, ``siam_ne``,
            ``siam_batch_size``, ``scale_nbr``, ``n_nbrs``, ``spec_lr``,
            ``spec_drop``, ``spec_patience``, ``spec_ne`` and
            ``generalization_metrics``.

    Returns:
        Tuple ``(x_spectralnet, y_spectralnet)``: the SpectralNet output for
        the concatenated train/val/test inputs, and the first value returned
        by ``get_y_preds`` for the k-means assignments on that output.
    """
    # ---- unpack data ------------------------------------------------------
    spectral_data = data['spectral']
    x_train, y_train, x_val, y_val, x_test, y_test = spectral_data[
        'train_and_test']
    (x_train_unlabeled, _, x_train_labeled,
     y_train_labeled) = spectral_data['train_unlabeled_and_labeled']
    x_val_unlabeled, _, _, _ = spectral_data['val_unlabeled_and_labeled']

    affinity = params['affinity']
    if 'siamese' in affinity:
        pairs_train, dist_train, pairs_val, dist_val = data['siamese'][
            'train_and_test']

    # full data set, used for the final embedding and its evaluation
    x = np.concatenate((x_train, x_val, x_test), axis=0)
    y = np.concatenate((y_train, y_val, y_test), axis=0)

    # one-hot labels of the labeled training subset (zero rows if there is none)
    if len(x_train_labeled):
        label_column = y_train_labeled.reshape(-1, 1)
        y_train_labeled_onehot = OneHotEncoder().fit_transform(
            label_column).toarray()
    else:
        y_train_labeled_onehot = np.empty((0, len(np.unique(y))))

    # ---- set up inputs ----------------------------------------------------
    n_clusters = params['n_clusters']

    # placeholder for the true labels (not used in unsupervised training)
    y_true = tf.placeholder(tf.float32,
                            shape=(None, n_clusters),
                            name='y_true')

    batch_size = params['batch_size']
    batch_sizes = {
        'Unlabeled': batch_size,
        'Labeled': batch_size,
        'Orthonorm': params.get('batch_size_orthonorm', batch_size),
    }

    # SpectralNet takes three inputs that all share the per-sample shape
    input_shape = x.shape[1:]
    inputs = {
        key: Input(shape=input_shape, name=key + 'Input')
        for key in ('Unlabeled', 'Labeled', 'Orthonorm')
    }

    # ---- siamese net (only for the siamese affinity) ----------------------
    siamese_net = None
    if affinity == 'siamese':
        siamese_net = networks.SiameseNet(inputs, params['arch'],
                                          params.get('siam_reg'), y_true)
        siamese_net.train(pairs_train, dist_train, pairs_val, dist_val,
                          params['siam_lr'], params['siam_drop'],
                          params['siam_patience'], params['siam_ne'],
                          params['siam_batch_size'])

    # ---- SpectralNet ------------------------------------------------------
    spectral_net = networks.SpectralNet(
        inputs, params['arch'], params.get('spec_reg'), y_true,
        y_train_labeled_onehot, n_clusters, affinity, params['scale_nbr'],
        params['n_nbrs'], batch_sizes, siamese_net, x_train,
        len(x_train_labeled))

    spectral_net.train(x_train_unlabeled, x_train_labeled, x_val_unlabeled,
                       params['spec_lr'], params['spec_drop'],
                       params['spec_patience'], params['spec_ne'])

    print("finished training")

    # ---- evaluate ---------------------------------------------------------
    from sklearn.metrics import normalized_mutual_info_score as nmi

    # final embeddings of the whole data set
    x_spectralnet = spectral_net.predict(x)

    # k-means on the embeddings, then accuracy and NMI against y
    kmeans_assignments, _ = get_cluster_sols(x_spectralnet,
                                             ClusterClass=KMeans,
                                             n_clusters=n_clusters,
                                             init_args={'n_init': 10})
    y_spectralnet, _ = get_y_preds(kmeans_assignments, y, n_clusters)
    print_accuracy(kmeans_assignments, y, n_clusters)
    full_nmi = nmi(kmeans_assignments, y)
    print('NMI: ' + str(np.round(full_nmi, 3)))

    if not params['generalization_metrics']:
        return x_spectralnet, y_spectralnet

    # generalization: fit k-means on the train embeddings only and assign
    # every test embedding to its nearest training centroid
    from scipy.spatial.distance import cdist

    train_embedding = spectral_net.predict(x_train_unlabeled)
    test_embedding = spectral_net.predict(x_test)
    km_train = KMeans(n_clusters=n_clusters).fit(train_embedding)
    centroid_dists = cdist(test_embedding, km_train.cluster_centers_)
    closest_cluster = np.argmin(centroid_dists, axis=1)
    print_accuracy(closest_cluster, y_test, n_clusters, ' generalization')
    test_nmi = nmi(closest_cluster, y_test)
    print('generalization NMI: ' + str(np.round(test_nmi, 3)))

    return x_spectralnet, y_spectralnet
Beispiel #4
0
def run_net(data, params):
    """Train the convolutional VAE and write a diagnostic figure per epoch.

    Every epoch trains the VAE on the validation split, clusters the current
    encoder embeddings with k-means, records the metrics and saves a 3x4
    diagnostic figure to ``vae/<epoch>.png``.  Training stops early when the
    loss changes by less than 1e-4 between two consecutive epochs.  After
    training, the loss curve, accuracy/NMI of the VAE's third output and a
    final t-SNE plot are shown.

    Args:
        data: dict whose ``data['spectral']['train_and_test']`` entry unpacks
            into six arrays (train x/y, val x/y, test x/y).
        params: hyperparameter dict.  Keys read here: ``input_shape``,
            ``spec_lr``, ``spec_drop``, ``spec_patience``, ``batch_size``,
            ``n_clusters``, ``cluster_names``, ``scale_nbr``, ``n_nbrs`` and
            ``affinity``.

    Returns:
        None.
    """
    #
    # UNPACK DATA
    #

    x_train_unlabeled, y_train_unlabeled, x_val, y_val, x_test, y_test = data[
        'spectral']['train_and_test']

    print(params['input_shape'])
    inputs_vae = Input(shape=params['input_shape'], name='inputs_vae')
    ConvAE = Conv.ConvAE(inputs_vae, params)
    try:
        ConvAE.vae.load_weights('vae_mnist.h5')
    except OSError:
        print('No pretrained weights available...')

    lh = LearningHandler(lr=params['spec_lr'],
                         drop=params['spec_drop'],
                         lr_tensor=ConvAE.learning_rate,
                         patience=params['spec_patience'])

    lh.on_train_begin()

    n_epochs = 5000
    # np.empty leaves the buffers uninitialised: only the first
    # ``epochs_done`` entries are ever valid
    losses_vae = np.empty((n_epochs, ))
    homo_plot = np.empty((n_epochs, ))
    nmi_plot = np.empty((n_epochs, ))
    ari_plot = np.empty((n_epochs, ))
    epochs_done = 0

    y_val = np.squeeze(np.asarray(y_val).ravel())  # squeeze into 1D array

    # loop-invariant setup: output directories and the label colormap
    os.makedirs('vae', exist_ok=True)
    os.makedirs('vae_umap', exist_ok=True)
    label_cmap = plt.cm.get_cmap("jet", 14)

    start_time = time.time()
    for i in range(n_epochs):
        x_recon, _, x_val_y = ConvAE.vae.predict(x_val)
        losses_vae[i] = ConvAE.train_vae(x_val, x_val_y, params['batch_size'])
        epochs_done = i + 1
        print("Epoch: {}, loss={:2f}".format(i, losses_vae[i]))

        fig, axs = plt.subplots(3, 4, figsize=(25, 18))
        fig.subplots_adjust(wspace=0.25)

        embedding = ConvAE.encoder.predict(x_val)
        kmeans = KMeans(n_clusters=params['n_clusters'], n_init=30)
        predicted_labels = kmeans.fit_predict(
            embedding)  # cluster on current embeddings for metric eval
        _, confusion_matrix = get_y_preds(predicted_labels, y_val,
                                          params['n_clusters'])

        # NOTE(review): homo_plot stores metrics.acc although the panel below
        # is titled 'Homogeneity' -- confirm which metric is intended
        homo_plot[i] = metrics.acc(y_val, predicted_labels)
        nmi_plot[i] = metrics.nmi(y_val, predicted_labels)
        ari_plot[i] = metrics.ari(y_val, predicted_labels)

        # All scatter plots below show embeddings of x_val, so they are
        # coloured with y_val (the labels that belong to those samples).
        tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
        Z_tsne = tsne.fit_transform(embedding)
        sc = axs[1][0].scatter(Z_tsne[:, 0],
                               Z_tsne[:, 1],
                               s=2,
                               c=y_val,
                               cmap=label_cmap)
        axs[1][0].set_title('t-SNE Embeddings')
        axs[1][0].set_xlabel('t-SNE 1')
        axs[1][0].set_ylabel('t-SNE 2')
        axs[1][0].set_xticks([])
        axs[1][0].set_yticks([])
        axs[1][0].spines['right'].set_visible(False)
        axs[1][0].spines['top'].set_visible(False)
        divider = make_axes_locatable(axs[1][0])
        cax = divider.append_axes('right', size='15%', pad=0.05)
        cbar = fig.colorbar(sc,
                            cax=cax,
                            orientation='vertical',
                            ticks=range(params['n_clusters']))
        cbar.ax.set_yticklabels(
            params['cluster_names'])  # vertically oriented colorbar
        # Create an offset transform of -5 points in the y direction
        dx = 0 / 72.
        dy = -5 / 72.
        offset = matplotlib.transforms.ScaledTranslation(
            dx, dy, fig.dpi_scale_trans)

        # apply offset transform to all cluster ticklabels.
        for label in cbar.ax.yaxis.get_majorticklabels():
            label.set_transform(label.get_transform() + offset)

        reducer = umap.UMAP(transform_seed=36, random_state=36)
        matrix_reduce = reducer.fit_transform(embedding)
        sc = axs[1][1].scatter(matrix_reduce[:, 0],
                               matrix_reduce[:, 1],
                               s=2,
                               c=y_val,
                               cmap=label_cmap)
        axs[1][1].set_title('UMAP Embeddings')
        axs[1][1].set_xlabel('UMAP 1')
        axs[1][1].set_ylabel('UMAP 2')
        axs[1][1].set_xticks([])
        axs[1][1].set_yticks([])
        # Hide the right and top spines
        axs[1][1].spines['right'].set_visible(False)
        axs[1][1].spines['top'].set_visible(False)

        im = axs[1][2].imshow(confusion_matrix, cmap='YlOrRd')
        axs[1][2].set_title('Confusion Matrix')
        axs[1][2].set_xticks(range(params['n_clusters']))
        axs[1][2].set_yticks(range(params['n_clusters']))
        axs[1][2].set_xticklabels(params['cluster_names'], fontsize=8)
        axs[1][2].set_yticklabels(params['cluster_names'], fontsize=8)
        divider = make_axes_locatable(axs[1][2])
        cax = divider.append_axes('right', size='10%', pad=0.05)
        cbar = fig.colorbar(im, cax=cax, orientation='vertical', ticks=[])

        # metric curves: plot only the epochs recorded so far
        axs[0][0].plot(losses_vae[:i + 1])
        axs[0][0].set_title('VAE Loss')
        axs[0][0].set_xlabel('epochs')

        axs[0][1].plot(homo_plot[:i + 1])
        axs[0][1].set_title('Homogeneity')
        axs[0][1].set_xlabel('epochs')
        axs[0][1].set_ylim(0, 1)

        axs[0][2].plot(ari_plot[:i + 1])
        axs[0][2].set_title('ARI')
        axs[0][2].set_xlabel('epochs')
        axs[0][2].set_ylim(0, 1)

        axs[0][3].plot(nmi_plot[:i + 1])
        axs[0][3].set_title('NMI')
        axs[0][3].set_xlabel('epochs')
        axs[0][3].set_ylim(0, 1)

        # first validation sample stacked on top of its flipped
        # reconstruction (first 64 columns, channel 0 of each)
        cell_tile = x_val[0, ..., 0]
        cell_tile = cell_tile[:, :64]
        x_recon = x_recon[0, ..., 0]
        reconstructed_cell_tile = x_recon[:, :64]
        reconstructed_cell_tile = np.flipud(reconstructed_cell_tile)
        cell_heatmap = np.vstack((cell_tile, reconstructed_cell_tile))
        axs[1][3].imshow(cell_heatmap, cmap='Reds')
        axs[1][3].set_xticks([])
        axs[1][3].set_yticks([])
        axs[1][3].spines['right'].set_visible(False)
        axs[1][3].spines['top'].set_visible(False)
        axs[1][3].spines['left'].set_visible(False)
        axs[1][3].spines['bottom'].set_visible(False)

        # get eigenvalues and eigenvectors
        scale = get_scale(embedding, params['batch_size'], params['scale_nbr'])
        values, vectors = spectral_clustering(embedding, scale,
                                              params['n_nbrs'],
                                              params['affinity'])

        # sort ascending, then keep the eigenvectors that belong to the
        # n_clusters smallest eigenvalues
        values_idx = np.argsort(values)
        x_spectral_clustering = vectors[:, values_idx[:params['n_clusters']]]

        # do kmeans clustering in this subspace
        y_spectral_clustering = KMeans(
            n_clusters=params['n_clusters']).fit_predict(
                x_spectral_clustering)

        tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
        Z_tsne = tsne.fit_transform(x_spectral_clustering)
        sc = axs[2][0].scatter(Z_tsne[:, 0],
                               Z_tsne[:, 1],
                               s=2,
                               c=y_val,
                               cmap=label_cmap)
        axs[2][0].set_title('Spectral Clusters (t-SNE) True Labels')
        axs[2][0].set_xlabel('t-SNE 1')
        axs[2][0].set_ylabel('t-SNE 2')
        axs[2][0].set_xticks([])
        axs[2][0].set_yticks([])
        axs[2][0].spines['right'].set_visible(False)
        axs[2][0].spines['top'].set_visible(False)

        reducer = umap.UMAP(transform_seed=36, random_state=36)
        matrix_reduce = reducer.fit_transform(x_spectral_clustering)
        axs[2][1].scatter(matrix_reduce[:, 0],
                          matrix_reduce[:, 1],
                          s=2,
                          c=y_spectral_clustering,
                          cmap=label_cmap)
        axs[2][1].set_title('Spectral Clusters (UMAP)')
        axs[2][1].set_xlabel('UMAP 1')
        axs[2][1].set_ylabel('UMAP 2')
        axs[2][1].set_xticks([])
        axs[2][1].set_yticks([])
        # Hide the right and top spines
        axs[2][1].spines['right'].set_visible(False)
        axs[2][1].spines['top'].set_visible(False)

        axs[2][2].scatter(matrix_reduce[:, 0],
                          matrix_reduce[:, 1],
                          s=2,
                          c=y_val,
                          cmap=label_cmap)
        axs[2][2].set_title('True Labels (UMAP)')
        axs[2][2].set_xlabel('UMAP 1')
        axs[2][2].set_ylabel('UMAP 2')
        axs[2][2].set_xticks([])
        axs[2][2].set_yticks([])
        # Hide the right and top spines
        axs[2][2].spines['right'].set_visible(False)
        axs[2][2].spines['top'].set_visible(False)

        axs[2][3].hist(x_spectral_clustering)
        axs[2][3].set_title("histogram of true eigenvectors")

        train_time = str(
            datetime.timedelta(seconds=(int(time.time() - start_time))))
        # NOTE(review): the factor 100 is presumably the number of batches
        # per epoch used by train_vae -- confirm
        n_matrices = (i + 1) * params['batch_size'] * 100
        fig.suptitle('Trained on ' + '{:,}'.format(n_matrices) + ' cells\n' +
                     train_time)

        fig.savefig('vae/%d.png' % i)
        # close the figure so that figures do not pile up across epochs
        plt.close(fig)

        # stop once the loss has effectively stopped changing
        if i > 1:
            if np.abs(losses_vae[i] - losses_vae[i - 1]) < 0.0001:
                print('STOPPING EARLY')
                break

    print("finished training")

    # plot only the epochs that actually ran; the rest of the buffer is
    # uninitialised memory
    plt.plot(losses_vae[:epochs_done])
    plt.title('VAE Loss')
    plt.show()

    # the third output of the VAE is argmax-ed to obtain label predictions
    x_val_y = ConvAE.vae.predict(x_val)[2]
    y_sp = x_val_y.argmax(axis=1)
    print_accuracy(y_sp, y_val, params['n_clusters'])
    from sklearn.metrics import normalized_mutual_info_score as nmi
    print(y_sp.shape, y_val.shape)
    nmi_score1 = nmi(y_sp, y_val)
    print('NMI: ' + str(np.round(nmi_score1, 4)))

    embedding = ConvAE.encoder.predict(x_val)
    tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
    Z_tsne = tsne.fit_transform(embedding)
    plt.figure()
    plt.scatter(Z_tsne[:, 0],
                Z_tsne[:, 1],
                s=2,
                c=y_val,
                cmap=label_cmap)
    plt.colorbar(ticks=range(params['n_clusters']))
    plt.show()
def run_net(data, params):
    """Show a t-SNE preview of the raw data, then train and evaluate SpectralNet.

    Args:
        data: nested dict.  ``data['spectral']`` must hold the entries
            ``'train_and_test'``, ``'train_unlabeled_and_labeled'`` and
            ``'val_unlabeled_and_labeled'``.  ``data['siamese']['train_and_test']``
            is read only when ``'siamese'`` occurs in ``params['affinity']``.
        params: hyperparameter dict.  Keys read here: ``affinity``,
            ``n_clusters``, ``batch_size``, ``batch_size_orthonorm``
            (optional), ``arch``, ``siam_reg``/``spec_reg`` (optional),
            ``siam_lr``, ``siam_drop``, ``siam_patience``, ``siam_ne``,
            ``siam_batch_size``, ``scale_nbr``, ``n_nbrs``, ``spec_lr``,
            ``spec_drop``, ``spec_patience``, ``spec_ne`` and
            ``generalization_metrics``.

    Returns:
        Tuple ``(x_spectralnet, y_spectralnet)``: the SpectralNet output for
        the concatenated train/val/test inputs, and the first value returned
        by ``get_y_preds`` for the k-means assignments on that output.
    """
    # ---- unpack data ------------------------------------------------------
    spectral_data = data['spectral']
    x_train, y_train, x_val, y_val, x_test, y_test = spectral_data[
        'train_and_test']
    (x_train_unlabeled, _, x_train_labeled,
     y_train_labeled) = spectral_data['train_unlabeled_and_labeled']
    x_val_unlabeled, _, _, _ = spectral_data['val_unlabeled_and_labeled']

    affinity = params['affinity']
    if 'siamese' in affinity:
        pairs_train, dist_train, pairs_val, dist_val = data['siamese'][
            'train_and_test']

    # full data set, used for the preview and for the final evaluation
    x = np.concatenate((x_train, x_val, x_test), axis=0)
    y = np.concatenate((y_train, y_val, y_test), axis=0)

    # unpacking into exactly two names requires x to be 2-D
    _n_samples, _n_features = x.shape

    def plot_embedding(X, title=None):
        """Draw every row of ``X`` as a '.' after min-max scaling each column."""
        col_min = np.min(X, 0)
        col_max = np.max(X, 0)
        scaled = (X - col_min) / (col_max - col_min)

        plt.figure(figsize=(12, 12))
        plt.subplot(111)
        dot_color = plt.cm.Set1(1)
        dot_font = {'weight': 'bold', 'size': 20}
        for point in scaled:
            plt.text(point[0], point[1], '.', color=dot_color,
                     fontdict=dot_font)
        plt.xticks([])
        plt.yticks([])
        if title is not None:
            plt.title(title)

    from sklearn import manifold

    # ---- t-SNE preview of the raw inputs ----------------------------------
    tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
    tsne_started = time.time()  # recorded but not used further in this function
    x_tsne = tsne.fit_transform(x)

    plot_embedding(x_tsne)
    plt.show()

    # one-hot labels of the labeled training subset (zero rows if there is none)
    if len(x_train_labeled):
        label_column = y_train_labeled.reshape(-1, 1)
        y_train_labeled_onehot = OneHotEncoder().fit_transform(
            label_column).toarray()
    else:
        y_train_labeled_onehot = np.empty((0, len(np.unique(y))))

    # ---- set up inputs ----------------------------------------------------
    n_clusters = params['n_clusters']

    # placeholder for the true labels (not used in unsupervised training)
    y_true = tf.placeholder(tf.float32,
                            shape=(None, n_clusters),
                            name='y_true')

    batch_size = params['batch_size']
    batch_sizes = {
        'Unlabeled': batch_size,
        'Labeled': batch_size,
        'Orthonorm': params.get('batch_size_orthonorm', batch_size),
    }

    # SpectralNet takes three inputs that all share the per-sample shape
    input_shape = x.shape[1:]
    inputs = {
        key: Input(shape=input_shape, name=key + 'Input')
        for key in ('Unlabeled', 'Labeled', 'Orthonorm')
    }

    # ---- siamese net (only for the siamese affinity) ----------------------
    siamese_net = None
    if affinity == 'siamese':
        siamese_net = networks.SiameseNet(inputs, params['arch'],
                                          params.get('siam_reg'), y_true)
        siamese_net.train(pairs_train, dist_train, pairs_val, dist_val,
                          params['siam_lr'], params['siam_drop'],
                          params['siam_patience'], params['siam_ne'],
                          params['siam_batch_size'])

    # ---- SpectralNet ------------------------------------------------------
    spectral_net = networks.SpectralNet(
        inputs, params['arch'], params.get('spec_reg'), y_true,
        y_train_labeled_onehot, n_clusters, affinity, params['scale_nbr'],
        params['n_nbrs'], batch_sizes, siamese_net, x_train,
        len(x_train_labeled))

    spectral_net.train(x_train_unlabeled, x_train_labeled, x_val_unlabeled,
                       params['spec_lr'], params['spec_drop'],
                       params['spec_patience'], params['spec_ne'])

    print("finished training")

    # ---- evaluate ---------------------------------------------------------
    # final embeddings of the whole data set
    x_spectralnet = spectral_net.predict(x)

    # k-means on the embeddings, then accuracy and NMI against y
    kmeans_assignments, _ = get_cluster_sols(x_spectralnet,
                                             ClusterClass=KMeans,
                                             n_clusters=n_clusters,
                                             init_args={'n_init': 10})
    y_spectralnet, _, _ = get_y_preds(kmeans_assignments, y, n_clusters)
    print_accuracy(kmeans_assignments, y, n_clusters)

    from sklearn.metrics import normalized_mutual_info_score as nmi
    full_nmi = nmi(kmeans_assignments, y)
    print('NMI: ' + str(np.round(full_nmi, 3)))

    if not params['generalization_metrics']:
        return x_spectralnet, y_spectralnet

    # generalization: fit k-means on the train embeddings only and assign
    # every test embedding to its nearest training centroid
    from scipy.spatial.distance import cdist

    train_embedding = spectral_net.predict(x_train_unlabeled)
    test_embedding = spectral_net.predict(x_test)
    km_train = KMeans(n_clusters=n_clusters).fit(train_embedding)
    centroid_dists = cdist(test_embedding, km_train.cluster_centers_)
    closest_cluster = np.argmin(centroid_dists, axis=1)
    print_accuracy(closest_cluster, y_test, n_clusters, ' generalization')
    test_nmi = nmi(closest_cluster, y_test)
    print('generalization NMI: ' + str(np.round(test_nmi, 3)))

    return x_spectralnet, y_spectralnet
Beispiel #6
0
def run_net(data, params):
    """Build, train and evaluate SpectralNet with an explicitly constructed graph.

    Unlike a wrapper around ``networks.SpectralNet``, this function builds
    the (optional) siamese network, the SpectralNet model, the affinity
    matrix, the loss and the optimizer step inline, then runs the training
    loop and the k-means based evaluation.

    Args:
        data: nested dict.  ``data['spectral']`` must hold the entries
            ``'train_and_test'``, ``'train_unlabeled_and_labeled'`` and
            ``'val_unlabeled_and_labeled'``.  ``data['siamese']['train_and_test']``
            is read only when ``'siamese'`` occurs in ``params['affinity']``.
        params: hyperparameter dict.  Keys read here: ``affinity``,
            ``n_clusters``, ``batch_size``, ``batch_size_orthonorm``
            (optional), ``arch``, ``siam_reg``/``spec_reg`` (optional),
            ``siam_lr``, ``siam_drop``, ``siam_patience``, ``siam_ne``,
            ``siam_batch_size``, ``scale_nbr``, ``n_nbrs``, ``spec_lr``,
            ``spec_drop``, ``spec_patience``, ``spec_ne`` and
            ``generalization_metrics``.

    Returns:
        Tuple ``(x_spectralnet, y_spectralnet)``: the SpectralNet output for
        the concatenated train/val/test inputs, and the first value returned
        by ``get_y_preds`` for the k-means assignments on that output.
    """
    #
    # UNPACK DATA
    #

    x_train, y_train, x_val, y_val, x_test, y_test = data['spectral'][
        'train_and_test']
    x_train_unlabeled, y_train_unlabeled, x_train_labeled, y_train_labeled = data[
        'spectral']['train_unlabeled_and_labeled']
    x_val_unlabeled, y_val_unlabeled, x_val_labeled, y_val_labeled = data[
        'spectral']['val_unlabeled_and_labeled']

    if 'siamese' in params['affinity']:
        pairs_train, dist_train, pairs_val, dist_val = data['siamese'][
            'train_and_test']

    # full data set, used for the final embedding and its evaluation
    x = np.concatenate((x_train, x_val, x_test), axis=0)
    y = np.concatenate((y_train, y_val, y_test), axis=0)

    # one-hot labels of the labeled training subset (zero rows if there is none)
    if len(x_train_labeled):
        y_train_labeled_onehot = OneHotEncoder().fit_transform(
            y_train_labeled.reshape(-1, 1)).toarray()
    else:
        y_train_labeled_onehot = np.empty((0, len(np.unique(y))))

    #
    # SET UP INPUTS
    #

    # create true y placeholder (not used in unsupervised training)
    y_true = tf.placeholder(tf.float32,
                            shape=(None, params['n_clusters']),
                            name='y_true')

    batch_sizes = {
        'Unlabeled': params['batch_size'],
        'Labeled': params['batch_size'],
        'Orthonorm': params.get('batch_size_orthonorm', params['batch_size']),
    }

    input_shape = x.shape[1:]

    # spectralnet has three inputs -- they are defined here
    inputs = {
        'Unlabeled': Input(shape=input_shape, name='UnlabeledInput'),
        'Labeled': Input(shape=input_shape, name='LabeledInput'),
        'Orthonorm': Input(shape=input_shape, name='OrthonormInput'),
    }

    #
    # DEFINE SIAMESE NET
    #

    # run only if we are using a siamese network
    if params['affinity'] == 'siamese':
        # set up the siamese network inputs as well
        # ('A' and 'Labeled' reuse the SpectralNet inputs; 'B' is a new input)
        siamese_inputs = {
            'A': inputs['Unlabeled'],
            'B': Input(shape=input_shape),
            'Labeled': inputs['Labeled'],
        }

        # generate layers
        layers = []
        layers += make_layer_list(params['arch'], 'siamese',
                                  params.get('siam_reg'))

        # create the siamese net
        siamese_outputs = stack_layers(siamese_inputs, layers)

        # add the distance layer
        distance = Lambda(costs.euclidean_distance,
                          output_shape=costs.eucl_dist_output_shape)(
                              [siamese_outputs['A'], siamese_outputs['B']])

        # create the distance model for training
        siamese_net_distance = Model(
            [siamese_inputs['A'], siamese_inputs['B']], distance)

        #
        # TRAIN SIAMESE NET
        #

        # compile the siamese network
        siamese_net_distance.compile(loss=costs.contrastive_loss,
                                     optimizer=RMSprop())

        # create handler for early stopping and learning rate scheduling
        siam_lh = LearningHandler(lr=params['siam_lr'],
                                  drop=params['siam_drop'],
                                  lr_tensor=siamese_net_distance.optimizer.lr,
                                  patience=params['siam_patience'])

        # initialize the training generator
        train_gen_ = train_gen(pairs_train, dist_train,
                               params['siam_batch_size'])

        # format the validation data for keras
        validation_data = ([pairs_val[:, 0], pairs_val[:, 1]], dist_val)

        # compute the steps per epoch
        steps_per_epoch = int(len(pairs_train) / params['siam_batch_size'])

        # train the network (the returned history is not used afterwards)
        hist = siamese_net_distance.fit_generator(
            train_gen_,
            epochs=params['siam_ne'],
            validation_data=validation_data,
            steps_per_epoch=steps_per_epoch,
            callbacks=[siam_lh])

        # compute the siamese embeddings of the input data
        all_siam_preds = train.predict(siamese_outputs['A'],
                                       x_unlabeled=x_train,
                                       inputs=inputs,
                                       y_true=y_true,
                                       batch_sizes=batch_sizes)

    #
    # DEFINE SPECTRALNET
    #

    # generate layers: every entry of params['arch'] except the last, followed
    # by a tanh layer of size n_clusters and an orthonormalization layer
    # (the empty-list assignment is immediately overwritten)
    layers = []
    layers = make_layer_list(params['arch'][:-1], 'spectral',
                             params.get('spec_reg'))
    layers += [{
        'type': 'tanh',
        'size': params['n_clusters'],
        'l2_reg': params.get('spec_reg'),
        'name': 'spectral_{}'.format(len(params['arch']) - 1)
    }, {
        'type': 'Orthonorm',
        'name': 'orthonorm'
    }]

    # create spectralnet
    outputs = stack_layers(inputs, layers)
    spectral_net = Model(inputs=inputs['Unlabeled'],
                         outputs=outputs['Unlabeled'])

    #
    # DEFINE SPECTRALNET LOSS
    #

    # generate affinity matrix W according to params
    # NOTE(review): there is no else branch -- for any affinity other than
    # 'siamese', 'knn' or 'full', input_affinity and x_affinity stay unbound
    # and the get_scale call below raises NameError
    if params['affinity'] == 'siamese':
        input_affinity = tf.concat(
            [siamese_outputs['A'], siamese_outputs['Labeled']], axis=0)
        x_affinity = all_siam_preds
    elif params['affinity'] in ['knn', 'full']:
        input_affinity = tf.concat([inputs['Unlabeled'], inputs['Labeled']],
                                   axis=0)
        x_affinity = x_train

    # calculate scale for affinity matrix
    scale = get_scale(x_affinity, batch_sizes['Unlabeled'],
                      params['scale_nbr'])

    # create affinity matrix
    if params['affinity'] == 'full':
        W = costs.full_affinity(input_affinity, scale=scale)
    elif params['affinity'] in ['knn', 'siamese']:
        W = costs.knn_affinity(input_affinity,
                               params['n_nbrs'],
                               scale=scale,
                               scale_nbr=params['scale_nbr'])

    # if we have labels, use them
    if len(x_train_labeled):
        # get true affinities (from labeled data): 1 where the squared
        # distance between two y_true rows is 0, else 0
        W_true = tf.cast(tf.equal(costs.squared_distance(y_true), 0),
                         dtype='float32')

        # replace lower right corner of W with W_true
        unlabeled_end = tf.shape(inputs['Unlabeled'])[0]
        W_u = W[:unlabeled_end, :]  # upper half
        W_ll = W[unlabeled_end:, :unlabeled_end]  # lower left
        W_l = tf.concat((W_ll, W_true), axis=1)  # lower half
        W = tf.concat((W_u, W_l), axis=0)

        # create pairwise batch distance matrix Dy
        Dy = costs.squared_distance(
            tf.concat([outputs['Unlabeled'], outputs['Labeled']], axis=0))
    else:
        Dy = costs.squared_distance(outputs['Unlabeled'])

    # define loss: affinity-weighted sum of pairwise output distances,
    # divided by twice the batch size
    spectral_net_loss = K.sum(W * Dy) / (2 * params['batch_size'])

    # create the train step update
    # (learning_rate starts at 0. and is handed to the LearningHandler below
    # as lr_tensor)
    learning_rate = tf.Variable(0., name='spectral_net_learning_rate')
    train_step = tf.train.RMSPropOptimizer(
        learning_rate=learning_rate).minimize(
            spectral_net_loss, var_list=spectral_net.trainable_weights)

    #
    # TRAIN SPECTRALNET
    #

    # initialize spectralnet variables
    K.get_session().run(
        tf.variables_initializer(spectral_net.trainable_weights))

    # set up validation/test set inputs (no 'Labeled' input)
    inputs_test = {
        'Unlabeled': inputs['Unlabeled'],
        'Orthonorm': inputs['Orthonorm']
    }

    # create handler for early stopping and learning rate scheduling
    spec_lh = LearningHandler(lr=params['spec_lr'],
                              drop=params['spec_drop'],
                              lr_tensor=learning_rate,
                              patience=params['spec_patience'])

    # begin spectralnet training loop
    spec_lh.on_train_begin()
    for i in range(params['spec_ne']):
        # train spectralnet
        loss = train.train_step(return_var=[spectral_net_loss],
                                updates=spectral_net.updates + [train_step],
                                x_unlabeled=x_train_unlabeled,
                                inputs=inputs,
                                y_true=y_true,
                                batch_sizes=batch_sizes,
                                x_labeled=x_train_labeled,
                                y_labeled=y_train_labeled_onehot,
                                batches_per_epoch=100)[0]

        # get validation loss (x[0:0] is an empty slice: no labeled samples
        # are passed for validation)
        val_loss = train.predict_sum(spectral_net_loss,
                                     x_unlabeled=x_val_unlabeled,
                                     inputs=inputs,
                                     y_true=y_true,
                                     x_labeled=x[0:0],
                                     y_labeled=y_train_labeled_onehot,
                                     batch_sizes=batch_sizes)

        # do early stopping if necessary
        if spec_lh.on_epoch_end(i, val_loss):
            print('STOPPING EARLY')
            break

        # print training status
        print("Epoch: {}, loss={:2f}, val_loss={:2f}".format(
            i, loss, val_loss))

    print("finished training")

    #
    # EVALUATE
    #

    # get final embeddings (empty labeled slices: prediction uses unlabeled
    # inputs only)
    x_spectralnet = train.predict(outputs['Unlabeled'],
                                  x_unlabeled=x,
                                  inputs=inputs_test,
                                  y_true=y_true,
                                  x_labeled=x_train_labeled[0:0],
                                  y_labeled=y_train_labeled_onehot[0:0],
                                  batch_sizes=batch_sizes)

    # get accuracy and nmi
    kmeans_assignments, km = get_cluster_sols(x_spectralnet,
                                              ClusterClass=KMeans,
                                              n_clusters=params['n_clusters'],
                                              init_args={'n_init': 10})
    y_spectralnet, _ = get_y_preds(kmeans_assignments, y, params['n_clusters'])
    print_accuracy(kmeans_assignments, y, params['n_clusters'])
    from sklearn.metrics import normalized_mutual_info_score as nmi
    nmi_score = nmi(kmeans_assignments, y)
    print('NMI: ' + str(np.round(nmi_score, 3)))

    # generalization: fit k-means on the train embeddings only and assign
    # every test embedding to its nearest training centroid
    if params['generalization_metrics']:
        x_spectralnet_train = train.predict(
            outputs['Unlabeled'],
            x_unlabeled=x_train_unlabeled,
            inputs=inputs_test,
            y_true=y_true,
            x_labeled=x_train_labeled[0:0],
            y_labeled=y_train_labeled_onehot[0:0],
            batch_sizes=batch_sizes)
        x_spectralnet_test = train.predict(
            outputs['Unlabeled'],
            x_unlabeled=x_test,
            inputs=inputs_test,
            y_true=y_true,
            x_labeled=x_train_labeled[0:0],
            y_labeled=y_train_labeled_onehot[0:0],
            batch_sizes=batch_sizes)
        km_train = KMeans(
            n_clusters=params['n_clusters']).fit(x_spectralnet_train)
        from scipy.spatial.distance import cdist
        dist_mat = cdist(x_spectralnet_test, km_train.cluster_centers_)
        closest_cluster = np.argmin(dist_mat, axis=1)
        print_accuracy(closest_cluster, y_test, params['n_clusters'],
                       ' generalization')
        nmi_score = nmi(closest_cluster, y_test)
        print('generalization NMI: ' + str(np.round(nmi_score, 3)))

    return x_spectralnet, y_spectralnet