Example n. 1
0
 def show_t_sne(self, name, n_samples=1000, color_by='', save_name=''):
     """Plot a t-SNE (or native 2-D latent) embedding for data loader `name`.

     Parameters
     ----------
     name: key into self.data_loaders selecting the cells to embed.
     n_samples: number of cells to randomly subsample; falsy keeps all.
     color_by: '' (no coloring), 'labels', or 'batches'.
     save_name: if non-empty, the figure is saved to this path.
     """
     latent, batch_indices, labels = get_latent(self.model,
                                                self.data_loaders[name],
                                                use_cuda=self.use_cuda)
     # Random subsample for t-SNE speed; a falsy n_samples keeps everything.
     idx_t_sne = np.random.permutation(
         len(latent))[:n_samples] if n_samples else np.arange(len(latent))
     # Subsample all three arrays together so the coloring masks below stay
     # the same length as `latent` (previously only the TSNE input was
     # subsampled, so `labels`/`batch_indices` were misaligned with it and
     # the boolean indexing crashed).
     latent = latent[idx_t_sne]
     batch_indices = batch_indices[idx_t_sne]
     labels = labels[idx_t_sne]
     if latent.shape[1] != 2:
         latent = TSNE().fit_transform(latent)
     plt.figure(figsize=(10, 10))
     if not color_by:
         plt.scatter(latent[:, 0], latent[:, 1], edgecolors='none')
     else:
         if color_by == 'labels':
             indices = labels.ravel()
         elif color_by == 'batches':
             indices = batch_indices.ravel()
         else:
             # Previously fell through to a NameError on `indices`.
             raise ValueError(
                 "color_by must be '', 'labels' or 'batches', got %r" %
                 color_by)
         for i in range(len(np.unique(indices))):
             plt.scatter(latent[indices == i, 0],
                         latent[indices == i, 1],
                         label=str(i),
                         edgecolors='none')
     plt.axis("off")
     plt.tight_layout()
     if save_name:
         plt.savefig(save_name)
Example n. 2
0
    def clustering_scores(self,
                          name,
                          verbose=True,
                          prediction_algorithm='knn'):
        """Score cluster assignments of the latent space against true labels.

        Parameters
        ----------
        name: key into self.data_loaders selecting the cells to score.
        verbose: print the scores when True.
        prediction_algorithm: 'knn' (fits KMeans, name kept for backward
            compatibility) or 'gmm'.

        Returns (silhouette, NMI, ARI, unsupervised clustering accuracy),
        or None when the dataset has a single label (scores meaningless).
        """
        if self.gene_dataset.n_labels <= 1:
            return None
        latent, _, labels = get_latent(self.model, self.data_loaders[name])
        if prediction_algorithm == 'knn':
            # Despite the flag name this fits KMeans on the latent space.
            labels_pred = KMeans(self.gene_dataset.n_labels,
                                 n_init=200).fit_predict(
                                     latent)  # n_jobs>1 ?
        elif prediction_algorithm == 'gmm':
            gmm = GMM(self.gene_dataset.n_labels)
            gmm.fit(latent)
            labels_pred = gmm.predict(latent)
        else:
            # Previously an unknown value crashed later with a NameError
            # on labels_pred; fail fast with a clear message instead.
            raise ValueError(
                "prediction_algorithm must be 'knn' or 'gmm', got %r" %
                prediction_algorithm)

        asw_score = silhouette_score(latent, labels)
        nmi_score = NMI(labels, labels_pred)
        ari_score = ARI(labels, labels_pred)
        uca_score = unsupervised_clustering_accuracy(labels,
                                                     labels_pred)[0]
        if verbose:
            print(
                "Clustering Scores for %s:\nSilhouette: %.4f\nNMI: %.4f\nARI: %.4f\nUCA: %.4f"
                % (name, asw_score, nmi_score, ari_score, uca_score))
        return asw_score, nmi_score, ari_score, uca_score
 def entropy_batch_mixing(self, name, verbose=False, **kwargs):
     """Return the batch-mixing entropy for data loader `name`.

     Only defined for datasets with exactly two batches; returns None
     otherwise. Extra kwargs are forwarded to the module-level
     entropy_batch_mixing function.
     """
     if self.gene_dataset.n_batches != 2:
         return None
     latent, batch_indices, _ = get_latent(self.model,
                                           self.data_loaders[name])
     score = entropy_batch_mixing(latent, batch_indices, **kwargs)
     if verbose:
         print("Entropy batch mixing :", score)
     return score
    def show_t_sne(self, name, n_samples=1000, color_by='', save_name=''):
        """Plot a t-SNE (or native 2-D latent) embedding of loader `name`.

        Parameters
        ----------
        name: key into self.data_loaders selecting the cells to embed.
        n_samples: number of cells to randomly subsample; falsy keeps all.
        color_by: '' (no coloring), 'batches', 'labels', or
            'batches and labels' (two side-by-side panels).
        save_name: if non-empty, the figure is saved to this path.
        """
        latent, batch_indices, labels = get_latent(self.model,
                                                   self.data_loaders[name])
        # Random subsample for t-SNE speed; falsy n_samples keeps all cells.
        idx_t_sne = np.random.permutation(
            len(latent))[:n_samples] if n_samples else np.arange(len(latent))
        # Subsample the latent up front so it stays aligned with the
        # subsampled batch_indices/labels below (previously a latent that
        # was already 2-D was kept at full length, so the boolean coloring
        # masks had the wrong length).
        latent = latent[idx_t_sne]
        if latent.shape[1] != 2:
            latent = TSNE().fit_transform(latent)
        if not color_by:
            plt.figure(figsize=(10, 10))
            plt.scatter(latent[:, 0], latent[:, 1])
        else:
            batch_indices = batch_indices[idx_t_sne].ravel()
            labels = labels[idx_t_sne].ravel()
            if color_by == 'batches' or color_by == 'labels':
                indices = batch_indices if color_by == 'batches' else labels
                n = self.gene_dataset.n_batches if color_by == 'batches' else self.gene_dataset.n_labels
                # Prefer human-readable cell-type names when available.
                if hasattr(self.gene_dataset,
                           'cell_types') and color_by == 'labels':
                    plt_labels = self.gene_dataset.cell_types
                else:
                    plt_labels = [
                        str(i) for i in range(len(np.unique(indices)))
                    ]
                plt.figure(figsize=(10, 10))
                for i, label in zip(range(n), plt_labels):
                    plt.scatter(latent[indices == i, 0],
                                latent[indices == i, 1],
                                label=label)
                plt.legend()
            elif color_by == 'batches and labels':
                # Two panels: left colored by batch, right by cell label.
                fig, axes = plt.subplots(1, 2, figsize=(14, 7))
                for i in range(self.gene_dataset.n_batches):
                    axes[0].scatter(latent[batch_indices == i, 0],
                                    latent[batch_indices == i, 1],
                                    label=str(i))
                axes[0].set_title("batch coloring")
                axes[0].axis("off")
                axes[0].legend()

                indices = labels
                if hasattr(self.gene_dataset, 'cell_types'):
                    plt_labels = self.gene_dataset.cell_types
                else:
                    plt_labels = [
                        str(i) for i in range(len(np.unique(indices)))
                    ]
                for i, cell_type in zip(range(self.gene_dataset.n_labels),
                                        plt_labels):
                    axes[1].scatter(latent[indices == i, 0],
                                    latent[indices == i, 1],
                                    label=cell_type)
                axes[1].set_title("label coloring")
                axes[1].axis("off")
                axes[1].legend()
        plt.axis("off")
        plt.tight_layout()
        if save_name:
            plt.savefig(save_name)
 def nn_overlap_score(self, name='sequential', verbose=True, **kwargs):
     """Compare latent-space and protein-space nearest-neighbour overlap.

     Returns (spearman_correlation, fold_enrichment), or None when the
     dataset carries no CLR-transformed protein expression.
     """
     if not hasattr(self.gene_dataset, 'adt_expression_clr'):
         return None
     # only works for the sequential data_loader (mapping indices)
     assert name == 'sequential'
     latent, _, _ = get_latent(self.model, self.data_loaders[name])
     protein_data = self.gene_dataset.adt_expression_clr
     spearman, fold = nn_overlap(latent, protein_data, **kwargs)
     if verbose:
         print(
             "Overlap Scores for %s:\nSpearman Correlation: %.4f\nFold Enrichment: %.4f"
             % (name, spearman, fold))
     return spearman, fold
 def clustering_scores(self, name, verbose=True):
     """Score KMeans clusters of the latent space against true labels.

     Returns (silhouette, NMI, ARI), or None when the dataset has a
     single label (the scores would be meaningless).
     """
     if self.gene_dataset.n_labels <= 1:
         return None
     latent, _, labels = get_latent(self.model, self.data_loaders[name])
     predictions = KMeans(self.gene_dataset.n_labels,
                          n_init=200).fit_predict(latent)  # n_jobs>1 ?
     scores = (silhouette_score(latent, labels),
               NMI(labels, predictions),
               ARI(labels, predictions))
     if verbose:
         print(
             "Clustering Scores for %s:\nSilhouette: %.4f\nNMI: %.4f\nARI: %.4f"
             % ((name,) + scores))
     return scores
Example n. 7
0
def harmonization_stat(model, data_loader, keys, pop1, pop2):
    """Report batch-mixing entropy between two batches and knn purity.

    Parameters
    ----------
    model, data_loader: passed to get_latent to obtain the latent space.
    keys: label names forwarded to knn_purity_avg.
    pop1, pop2: the two batch indices whose mixing is measured.

    Returns (batch_entropy, knn_purity_results).
    """
    latent, batch_indices, labels = get_latent(model, data_loader)
    batch_indices = np.concatenate(batch_indices)
    sample = sample_by_batch(batch_indices, 2000)
    # Restrict the sample to cells from the two batches of interest;
    # `|` is the element-wise logical OR on boolean masks (the original
    # used `+`, which happens to work on bools but obscures the intent).
    sample_2batch = sample[(batch_indices[sample] == pop1) |
                           (batch_indices[sample] == pop2)]
    batch_entropy = entropy_batch_mixing(latent[sample_2batch, :],
                                         batch_indices[sample_2batch])
    # Fixed format spec: '%f.3' printed the full float followed by a
    # literal '.3'; '%.3f' is the intended 3-decimal formatting.
    print("Entropy batch mixing : %.3f" % batch_entropy)
    sample = sample_by_batch(labels, 200)
    res = knn_purity_avg(latent[sample, :],
                         labels.astype('int')[sample],
                         keys,
                         acc=True)
    print("Average knn purity : %.3f" % np.mean([x[1] for x in res]))
    return (batch_entropy, res)
Example n. 8
0
# Harmonization benchmark script: trains the model named on the command
# line on the concatenated Macosko + Regev retina datasets and extracts
# the latent representation plus batch indices for downstream scoring.
import sys

# Model choice comes from the command line: 'vae', 'svaec', or 'Seurat'.
model_type = str(sys.argv[1])
plotname = 'Macosko_Regev'
dataset1 = MacoskoDataset()
dataset2 = RegevDataset()
# Merge both datasets into one object carrying per-cell batch indices.
gene_dataset = GeneExpressionDataset.concat_datasets(dataset1, dataset2)
gene_dataset.subsample_genes(5000)

if model_type == 'vae':
    vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches, n_labels=gene_dataset.n_labels,
              n_hidden=128, n_latent=10, n_layers=2, dispersion='gene')
    # NOTE(review): `use_cuda` is not defined in this fragment — presumably
    # set earlier in the file; confirm before running.
    infer = VariationalInference(vae, gene_dataset, use_cuda=use_cuda)
    infer.train(n_epochs=250)
    data_loader = infer.data_loaders['sequential']
    latent, batch_indices, labels = get_latent(vae, data_loader)
    keys = gene_dataset.cell_types
    batch_indices = np.concatenate(batch_indices)
    # NOTE(review): duplicate of the `keys` assignment two lines above —
    # likely redundant.
    keys = gene_dataset.cell_types
elif model_type == 'svaec':
    svaec = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches,
                   gene_dataset.n_labels, use_labels_groups=False,
                   n_latent=10, n_layers=2)
    infer = SemiSupervisedVariationalInference(svaec, gene_dataset)
    infer.train(n_epochs=50)
    print('svaec acc =', infer.accuracy('unlabelled'))
    data_loader = infer.data_loaders['unlabelled']
    latent, batch_indices, labels = get_latent(infer.model, infer.data_loaders['unlabelled'])
    keys = gene_dataset.cell_types
    batch_indices = np.concatenate(batch_indices)
# NOTE(review): this branch is truncated in the visible source (no body).
elif model_type == 'Seurat':
Example n. 9
0
def run_benchmarks(dataset_name,
                   model=VAE,
                   n_epochs=1000,
                   lr=1e-3,
                   use_batches=False,
                   use_cuda=True,
                   show_batch_mixing=True,
                   benchmark=False,
                   tt_split=0.9,
                   unit_test=False):
    """Train `model` on `dataset_name` and print the standard benchmarks.

    Benchmarks printed (nothing is returned):
      - train/test log-likelihood at the best epoch
      - imputation MAE on the test split
      - entropy of batch mixing (plus optional t-SNE plot) when the
        dataset has exactly two batches
      - differential-expression statistics, for CortexDataset only

    tt_split: fraction of cells used for training (default 0.9);
    use_batches: when False the model is built with 0 batches.
    """
    # options:
    # - gene_dataset: a GeneExpressionDataset object
    # call each of the 4 benchmarks:
    # - log-likelihood
    # - imputation
    # - batch mixing
    # - cluster scores
    gene_dataset = load_datasets(dataset_name, unit_test=unit_test)
    example_indices = np.random.permutation(len(gene_dataset))
    # NOTE(review): `tt_split` is rebound here from a fraction to an
    # absolute cell index.
    tt_split = int(tt_split * len(gene_dataset))  # 90%/10% train/test split

    data_loader_train = DataLoader(gene_dataset,
                                   batch_size=128,
                                   pin_memory=use_cuda,
                                   sampler=SubsetRandomSampler(
                                       example_indices[:tt_split]),
                                   collate_fn=gene_dataset.collate_fn)
    data_loader_test = DataLoader(gene_dataset,
                                  batch_size=128,
                                  pin_memory=use_cuda,
                                  sampler=SubsetRandomSampler(
                                      example_indices[tt_split:]),
                                  collate_fn=gene_dataset.collate_fn)
    # n_batches * use_batches evaluates to 0 when use_batches is False,
    # i.e. the model ignores batch information entirely.
    vae = model(gene_dataset.nb_genes,
                n_batch=gene_dataset.n_batches * use_batches,
                n_labels=gene_dataset.n_labels,
                use_cuda=use_cuda)
    stats = train(vae,
                  data_loader_train,
                  data_loader_test,
                  n_epochs=n_epochs,
                  lr=lr,
                  benchmark=benchmark)

    if isinstance(vae, VAE):
        # Fine-tune only the encoder on the test split to report the best
        # achievable held-out likelihood.
        best_ll = adapt_encoder(vae,
                                data_loader_test,
                                n_path=1,
                                n_epochs=1,
                                record_freq=1)
        print("Best ll was :", best_ll)

    # - log-likelihood
    print("Log-likelihood Train:", stats.history["LL_train"][stats.best_index])
    print("Log-likelihood Test:", stats.history["LL_test"][stats.best_index])

    # - imputation
    imputation_test = imputation(vae, data_loader_test)
    print("Imputation score on test (MAE) is:", imputation_test.item())

    # - batch mixing (only meaningful with exactly two batches)
    if gene_dataset.n_batches == 2:
        latent, batch_indices, labels = get_latent(vae, data_loader_train)
        print(
            "Entropy batch mixing :",
            entropy_batch_mixing(latent.cpu().numpy(),
                                 batch_indices.cpu().numpy()))
        if show_batch_mixing:
            show_t_sne(
                latent.cpu().numpy(),
                np.array([batch[0] for batch in batch_indices.cpu().numpy()]))

    # - differential expression
    if type(gene_dataset) == CortexDataset:
        get_statistics(vae, data_loader_train, M_sampling=1,
                       M_permutation=1)  # 200 - 100000
Example n. 10
0
    def show_t_sne(self,
                   name,
                   n_samples=1000,
                   color_by='',
                   save_name='',
                   latent=None,
                   batch_indices=None,
                   labels=None,
                   n_batch=None):
        """Plot a t-SNE embedding of the latent space.

        Parameters
        ----------
        name: key into self.data_loaders, used when `latent` is None.
        n_samples: subsample size forwarded to apply_t_sne.
        color_by: '' (no coloring), 'scalar' (continuous labels),
            'batches', 'labels', or 'batches and labels' (two panels).
        save_name: if non-empty, the figure is saved to this path.
        latent/batch_indices/labels: precomputed 2-D embedding and its
            aligned annotations; when `latent` is given, no subsampling
            or t-SNE is applied.
        n_batch: number of batches; defaults to gene_dataset.n_batches.
        """
        # If no latent representation is given, compute and subsample one.
        if latent is None:
            latent, batch_indices, labels = get_latent(self.model,
                                                       self.data_loaders[name])
            latent, idx_t_sne = self.apply_t_sne(latent, n_samples)
            batch_indices = batch_indices[idx_t_sne].ravel()
            labels = labels[idx_t_sne].ravel()
        if not color_by:
            plt.figure(figsize=(10, 10))
            plt.scatter(latent[:, 0], latent[:, 1])
        elif color_by == 'scalar':
            # 'elif' here (was a separate 'if'): previously an empty
            # color_by also fell through into the 'else' branch below.
            plt.figure(figsize=(10, 10))
            plt.scatter(latent[:, 0], latent[:, 1], c=labels.ravel())
        else:
            if n_batch is None:
                n_batch = self.gene_dataset.n_batches
            if color_by == 'batches' or color_by == 'labels':
                indices = batch_indices if color_by == 'batches' else labels
                n = n_batch if color_by == 'batches' else self.gene_dataset.n_labels
                # Prefer human-readable cell-type names when available.
                if hasattr(self.gene_dataset,
                           'cell_types') and color_by == 'labels':
                    plt_labels = self.gene_dataset.cell_types
                else:
                    plt_labels = [
                        str(i) for i in range(len(np.unique(indices)))
                    ]
                plt.figure(figsize=(10, 10))
                for i, label in zip(range(n), plt_labels):
                    plt.scatter(latent[indices == i, 0],
                                latent[indices == i, 1],
                                label=label)
                plt.legend()
            elif color_by == 'batches and labels':
                # Two panels: left colored by batch, right by cell label.
                fig, axes = plt.subplots(1, 2, figsize=(14, 7))
                for i in range(n_batch):
                    axes[0].scatter(latent[batch_indices == i, 0],
                                    latent[batch_indices == i, 1],
                                    label=str(i))
                axes[0].set_title("batch coloring")
                axes[0].axis("off")
                axes[0].legend()

                indices = labels
                if hasattr(self.gene_dataset, 'cell_types'):
                    plt_labels = self.gene_dataset.cell_types
                else:
                    plt_labels = [
                        str(i) for i in range(len(np.unique(indices)))
                    ]
                for i, cell_type in zip(range(self.gene_dataset.n_labels),
                                        plt_labels):
                    axes[1].scatter(latent[indices == i, 0],
                                    latent[indices == i, 1],
                                    label=cell_type)
                axes[1].set_title("label coloring")
                axes[1].axis("off")
                axes[1].legend()
        plt.axis("off")
        plt.tight_layout()
        if save_name:
            plt.savefig(save_name)
Example n. 11
0
          n_latent=10,
          n_layers=1,
          dispersion='gene')
# Train the VAE, export ground truth + raw counts for external tools,
# then score batch mixing between two populations in the latent space.
infer_vae = VariationalInference(vae, gene_dataset, use_cuda=use_cuda)
infer_vae.fit(n_epochs=100)

# Save labels, batch indices and the count matrix (e.g. for a
# Seurat-side comparison run).
np.save("../" + plotname + '.label.npy', gene_dataset.labels)
np.save("../" + plotname + '.batch.npy', gene_dataset.batch_indices)
mmwrite("../" + plotname + '.count.mtx', gene_dataset.X)

# Sequential (unshuffled) loader so latent rows align with dataset order.
data_loader = DataLoader(gene_dataset,
                         batch_size=128,
                         pin_memory=use_cuda,
                         shuffle=False,
                         collate_fn=gene_dataset.collate_fn)
latent, batch_indices, labels = get_latent(infer_vae.model, data_loader)
keys = gene_dataset.cell_types
batch_indices = np.concatenate(batch_indices)

n_plotcells = 6000
pop1 = 0
pop2 = 1
nbatches = len(np.unique(batch_indices))
_, cell_count = np.unique(batch_indices, return_counts=True)

# Equal-size sample per batch, then keep only cells from pop1/pop2.
sample = sample_by_batch(batch_indices, int(n_plotcells / nbatches))
# NOTE(review): `+` on the boolean masks acts as element-wise OR here;
# `|` would state the intent more clearly.
sample_2batch = sample[(batch_indices[sample] == pop1) +
                       (batch_indices[sample] == pop2)]
batch_entropy = entropy_batch_mixing(latent[sample_2batch, :],
                                     batch_indices[sample_2batch])
print("Entropy batch mixing :", batch_entropy)