示例#1
0
def correct_scvi(Xs, genes):
    """Embed several expression matrices in one scVI latent space.

    Every matrix in ``Xs`` is wrapped as an ``AnnDataset`` (with ``genes``
    passed as the ``var`` annotation), the datasets are concatenated, and a
    VAE whose weights are restored from ``data/harmonization.vae.pkl`` is
    used to compute the latent representation of every cell.  No training
    happens here: the checkpoint file must already exist.

    Args:
        Xs: iterable of expression matrices, one per batch; each must be
            accepted by ``AnnData(X, var=genes)``.
        genes: gene annotation handed to every ``AnnData`` as ``var``.

    Returns:
        The latent array produced by ``get_latent()`` for all cells of the
        concatenated dataset, visited sequentially.
    """
    import torch

    # The original pinned the process to GPU 1 unconditionally, which raises
    # on hosts with fewer than two GPUs.  Only pin when that device exists.
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        torch.cuda.set_device(1)

    from scvi.dataset.dataset import GeneExpressionDataset
    from scvi.inference import UnsupervisedTrainer
    from scvi.models import VAE
    from scvi.dataset.anndata import AnnDataset

    all_ann = [AnnDataset(AnnData(X, var=genes)) for X in Xs]

    all_dataset = GeneExpressionDataset.concat_datasets(*all_ann)

    vae = VAE(all_dataset.nb_genes,
              n_batch=all_dataset.n_batches,
              n_labels=all_dataset.n_labels,
              n_hidden=128,
              n_latent=30,
              n_layers=2,
              dispersion='gene')
    trainer = UnsupervisedTrainer(vae, all_dataset, train_size=0.99999)

    # The checkpoint was produced earlier by calling
    # trainer.train(n_epochs=100) and saving trainer.model.state_dict() to
    # this path; here it is only restored.  Loading onto the CPU first keeps
    # the file usable whatever device it was saved from; load_state_dict
    # then copies the values into the model's own parameters.
    state_dict = torch.load('data/harmonization.vae.pkl', map_location='cpu')
    trainer.model.load_state_dict(state_dict)
    trainer.model.eval()

    # Posterior over every cell (not just a train/test split).
    full = trainer.create_posterior(trainer.model,
                                    all_dataset,
                                    indices=np.arange(len(all_dataset)))
    latent, _batch_indices, _labels = full.sequential().get_latent()

    return latent
示例#2
0
    X2 = np.log(1 + norm_X[index_1])
    from scvi.harmonization.classification.CORAL import CORAL
    coral = CORAL()
    coral1 = coral.fit_predict(X1, labels[index_0], X2)
    coral = CORAL()
    coral2 = coral.fit_predict(X2, labels[index_1], X1)
    return pred1,pred2, coral1, coral2


from scvi.dataset.muris_tabula import TabulaMuris

plotname = 'MarrowTM'

# Load the two Tabula Muris variants ('facs' first, then 'droplet') from the
# same scratch directory.
dataset1, dataset2 = (
    TabulaMuris(protocol,
                save_path='/data/yosef2/scratch/chenling/scanvi_data/')
    for protocol in ('facs', 'droplet')
)

# Each dataset is asked to subsample to its own current gene count.
# NOTE(review): presumably this keeps every gene and only refreshes internal
# state -- confirm against subsample_genes.
dataset1.subsample_genes(dataset1.nb_genes)
dataset2.subsample_genes(dataset2.nb_genes)

gene_dataset = GeneExpressionDataset.concat_datasets(dataset1, dataset2)

# Label prediction across the two datasets; the four results are the two
# prediction sets and the two CORAL outputs returned by labelpred.
pred1, pred2, coral1, coral2 = labelpred(
    gene_dataset, dataset1, dataset2, plotname)
# SCANVI_acc(gene_dataset, plotname, pred1,pred2,coral1,coral2)

#
#
# plotname = 'PBMC8KCITE'
# from scvi.harmonization.utils_chenling import get_matrix_from_dir,assign_label
# from scvi.dataset.pbmc import PbmcDataset
# from scvi.dataset.dataset import GeneExpressionDataset
# dataset1 = PbmcDataset(filter_out_de_genes=False)
# dataset1.update_cells(dataset1.batch_indices.ravel()==0)
# dataset1.subsample_genes(dataset1.nb_genes)
# save_path='/data/yosef2/scratch/chenling/scanvi_data/'
# Make sure the directory that receives the per-batch CSV files exists.
os.makedirs("./data", exist_ok=True)

# NOTE(review): the input file is named "pancreas" while every variable in
# this section says "hemat" -- confirm this is the intended count table.
hemat_full = pd.read_csv("count_data_pancreas_v1.txt", sep=" ", header=None)

# Label rows gene_1..gene_3470 and columns sample_1..sample_4649.
# (The original assigned the index to the misspelled name ``heamt_full``,
# which raised NameError and left the real frame without gene labels.)
hemat_full.index = ["gene_" + str(i) for i in range(1, 3471)]
hemat_full.columns = ["sample_" + str(i) for i in range(1, 4650)]

# Split the samples into two batches by column position and write each batch
# to its own CSV file.
hemat_full.iloc[:, 0:2729].to_csv("./data/count_data_hemat_v1_batch1.csv",
                                  sep=",")
hemat_full.iloc[:, 2729:4649].to_csv("./data/count_data_hemat_v1_batch2.csv",
                                     sep=",")

# NOTE(review): the file names are given without the "./data/" prefix;
# presumably CsvDataset resolves them against its default save path --
# confirm.
hemat_batch_1 = CsvDataset("count_data_hemat_v1_batch1.csv", new_n_genes=3470)
hemat_batch_2 = CsvDataset("count_data_hemat_v1_batch2.csv", new_n_genes=3470)

hemat_data = GeneExpressionDataset.concat_datasets(hemat_batch_1,
                                                   hemat_batch_2)

# VAE sized from the concatenated dataset.
hemat_vae = VAE(hemat_data.nb_genes,
                n_batch=hemat_data.n_batches,
                n_labels=hemat_data.n_labels,
                n_hidden=128,
                n_latent=30,
                n_layers=2,
                dispersion='gene')

hemat_trainer = UnsupervisedTrainer(hemat_vae, hemat_data, train_size=0.9)

hemat_trainer.train(n_epochs=100)

hemat_full = hemat_trainer.create_posterior(hemat_trainer.model,
                                            hemat_data,
# Write columns 600-799 and 800-999 of the simulated counts as batches 3
# and 4.
simulation_full.iloc[:, 600:800].to_csv(
    "./data/count_data_simulation_v1_batch3.csv",
    sep=",",
)
simulation_full.iloc[:, 800:1000].to_csv(
    "./data/count_data_simulation_v1_batch4.csv",
    sep=",",
)

# Load the four per-batch CSV files, in batch order, with 3000 genes each.
(simulation_batch_1, simulation_batch_2,
 simulation_batch_3, simulation_batch_4) = (
    CsvDataset(csv_name, new_n_genes=3000)
    for csv_name in (
        "count_data_simulation_v1_batch1.csv",
        "count_data_simulation_v1_batch2.csv",
        "count_data_simulation_v1_batch3.csv",
        "count_data_simulation_v1_batch4.csv",
    )
)

simulation_data = GeneExpressionDataset.concat_datasets(
    simulation_batch_1,
    simulation_batch_2,
    simulation_batch_3,
    simulation_batch_4,
)

# VAE sized from the concatenated dataset.
simulation_vae = VAE(
    simulation_data.nb_genes,
    n_batch=simulation_data.n_batches,
    n_labels=simulation_data.n_labels,
    n_hidden=128,
    n_latent=30,
    n_layers=2,
    dispersion='gene',
)

simulation_trainer = UnsupervisedTrainer(
    simulation_vae, simulation_data, train_size=0.9)

simulation_trainer.train(n_epochs=100)
示例#5
0
# Set the PostScript font type before pyplot is imported.
matplotlib.rcParams['ps.fonttype'] = 42
import matplotlib.pyplot as plt
import seaborn as sns
from scvi.metrics.clustering import select_indices_evenly
from sklearn.manifold import TSNE

# The configuration file holds whitespace-separated dataset directories.
# Read it inside a ``with`` block so the handle is closed promptly (the
# original left the file object open).
with open('/data/scanorama/conf/panorama.txt') as panorama_conf:
    dirs = panorama_conf.read().rstrip().split()
# ['data/293t_jurkat/293t', 'data/293t_jurkat/jurkat', 'data/293t_jurkat/jurkat_293t_50_50', 'data/293t_jurkat/jurkat_293t_99_1',
#  'data/brain/neuron_9k',
#  'data/macrophage/infected', 'data/macrophage/mixed_infected', 'data/macrophage/uninfected', 'data/macrophage/uninfected_donor2',
#  'data/hsc/hsc_mars', 'data/hsc/hsc_ss2',
#  'data/pancreas/pancreas_inDrop', 'data/pancreas/pancreas_multi_celseq2_expression_matrix', 'data/pancreas/pancreas_multi_celseq_expression_matrix', 'data/pancreas/pancreas_multi_fluidigmc1_expression_matrix', 'data/pancreas/pancreas_multi_smartseq2_expression_matrix',
#  'data/pbmc/10x/68k_pbmc', 'data/pbmc/10x/b_cells', 'data/pbmc/10x/cd14_monocytes', 'data/pbmc/10x/cd4_t_helper', 'data/pbmc/10x/cd56_nk', 'data/pbmc/10x/cytotoxic_t', 'data/pbmc/10x/memory_t', 'data/pbmc/10x/regulatory_t', 'data/pbmc/pbmc_kang', 'data/pbmc/pbmc_10X']

# One dataset per listed directory, concatenated into a single dataset.
datasets = [DatasetSCANORAMA(d) for d in dirs]

all_dataset = GeneExpressionDataset.concat_datasets(*datasets)
# Keeping 5216 genes

# Whitespace-separated cell labels for all cells, again read with the file
# closed afterwards.
with open('/data/scanorama/data/cell_labels/all.txt') as labels_file:
    labels = labels_file.read().rstrip().split()

# Unique label names become the cell types; the inverse indices become the
# integer labels, stored as a single column.
all_dataset.cell_types, all_dataset.labels = np.unique(labels,
                                                       return_inverse=True)
all_dataset.labels = all_dataset.labels.reshape(-1, 1)
all_dataset.batch_indices = all_dataset.batch_indices.astype('int')

from scvi.harmonization.utils_chenling import trainVAE, VAEstats
# Alternative configurations tried previously:
# full = trainVAE(all_dataset, 'scanorama', 1, nlayers=3,n_hidden=256)
full = trainVAE(all_dataset, 'scanorama', 0)  #  nlayers=2,n_hidden=128
# full = trainVAE(all_dataset, 'scanorama', 2, nlayers=3,n_hidden=128)

# for 250 iterations, takes 45:14 to train VAE
示例#6
0
 for celltype2 in dataset2.cell_types[:6]:
     if celltype1 != celltype2:
         print(celltype1 + ' ' + celltype2)
         pbmc = deepcopy(dataset1)
         newCellType = [
             k for i, k in enumerate(dataset1.cell_types)
             if k not in [celltype1, 'Other']
         ]
         pbmc.filter_cell_types(newCellType)
         pbmc2 = deepcopy(dataset2)
         newCellType = [
             k for i, k in enumerate(dataset2.cell_types)
             if k not in [celltype2, 'Other']
         ]
         pbmc2.filter_cell_types(newCellType)
         gene_dataset = GeneExpressionDataset.concat_datasets(pbmc, pbmc2)
         # _,_,_,_,_ = run_model('writedata', gene_dataset, pbmc, pbmc2,filename=plotname+'.'
         #                                                                       +celltype1.replace(' ','')+'.'
         #                                                                       +celltype2.replace(' ',''))
         rmCellTypes = '.' + celltype1.replace(
             ' ', '') + '.' + celltype2.replace(' ', '')
         latent1 = np.genfromtxt('../harmonization/Seurat_data/' +
                                 plotname + rmCellTypes.replace(' ', '') +
                                 '.1.CCA.txt')
         latent2 = np.genfromtxt('../harmonization/Seurat_data/' +
                                 plotname + rmCellTypes.replace(' ', '') +
                                 '.2.CCA.txt')
         latent, batch_indices, labels, keys, stats = run_model(
             'readSeurat',
             gene_dataset,
             pbmc,
示例#7
0
# Simulated inputs: the non-UMI count matrix is stored transposed relative to
# what is used below; the two label vectors are loaded as-is.
countnonUMI = np.load('../sim_data/count.nonUMI.npy').T
labelUMI = np.load('../sim_data/label.UMI.npy')
labelnonUMI = np.load('../sim_data/label.nonUMI.npy')

# Build the UMI dataset from the sparse count matrix and its labels, with
# 2000 generated gene names (gene0..gene1999) and 5 cell types
# (type1..type5).
_umi_attributes = GeneExpressionDataset.get_attributes_from_matrix(
    csr_matrix(countUMI), labels=labelUMI)
UMI = GeneExpressionDataset(
    *_umi_attributes,
    gene_names=['gene'+str(i) for i in range(2000)],
    cell_types=['type'+str(i+1) for i in range(5)])

# Same construction for the non-UMI counts.
_non_umi_attributes = GeneExpressionDataset.get_attributes_from_matrix(
    csr_matrix(countnonUMI), labels=labelnonUMI)
nonUMI = GeneExpressionDataset(
    *_non_umi_attributes,
    gene_names=['gene'+str(i) for i in range(2000)],
    cell_types=['type'+str(i+1) for i in range(5)])

if model_type in ['vae', 'svaec', 'Seurat', 'Combat']:
    gene_dataset = GeneExpressionDataset.concat_datasets(UMI, nonUMI)

    if model_type == 'vae':
        vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches, n_labels=gene_dataset.n_labels,
                  n_hidden=128, n_latent=10, n_layers=2, dispersion='gene')
        infer_vae = VariationalInference(vae, gene_dataset, use_cuda=use_cuda)
        infer_vae.train(n_epochs=250)
        data_loader = infer_vae.data_loaders['sequential']
        latent, batch_indices, labels = get_latent(vae, data_loader)
        keys = gene_dataset.cell_types
        batch_indices = np.concatenate(batch_indices)
        keys = gene_dataset.cell_types
    elif model_type == 'svaec':
        gene_dataset.subsample_genes(1000)

        n_epochs_vae = 100
示例#8
0
 def preprocess(self):
     """Load the ``10X_nuclei_Regev`` data, reusing cached files if present.

     If ``regev_data.svmlight`` exists under ``self.save_path`` the counts,
     labels, cell types, gene names and label groups are read back from the
     cache files written by a previous call.  Otherwise the eight batches
     are read from their h5 files, matched against the cluster membership
     table, concatenated, relabelled by cell-type group, written to the
     cache files and returned.

     All paths are built by plain string concatenation, so
     ``self.save_path`` has to end with a path separator.

     Returns:
         A 5-tuple ``(count, labels, cell_type, gene_names, labels_groups)``.
     """
     # Fast path: reuse the cache written at the end of the slow path below.
     if os.path.isfile(self.save_path + 'regev_data.svmlight'):
         count, labels = load_svmlight_file(self.save_path +
                                            'regev_data.svmlight')
         cell_type = np.load(self.save_path + 'regev_data.celltypes.npy')
         gene_names = np.load(self.save_path + 'regev_data.gene_names.npy')
         labels_groups = np.load(self.save_path +
                                 'regev_data.labels_groups.npy')
         return (count, labels, cell_type, gene_names, labels_groups)
     else:
         # Each batch letter maps to the directory '<letter>1' (see below).
         regev_batches = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
         label = np.genfromtxt(self.save_path +
                               '10X_nuclei_Regev/cluster.membership.csv',
                               dtype='str',
                               delimiter=',')
         # Row 0 is skipped (presumably a header -- confirm).  Judging by
         # the split calls, column 0 holds quoted '"<barcode>-<batch>"'
         # strings and column 1 holds the quoted cluster id.
         label_batch = np.asarray([
             str(int(int(x.split('-')[1].split('"')[0])))
             for x in label[1:, 0]
         ])
         label_barcode = np.asarray(
             [x.split('-')[0].split('"')[1] for x in label[1:, 0]])
         label_cluster = np.asarray([x.split('"')[1] for x in label[1:, 1]])
         # Quoted first column -> quoted second column of the annotation
         # table, with the quotes stripped from both.
         label_map = np.genfromtxt(
             self.save_path + '10X_nuclei_Regev/cluster.annotation.csv',
             dtype='str',
             delimiter=',')
         label_map = dict(
             zip([x.split('"')[1] for x in label_map[:, 0]],
                 [x.split('"')[1] for x in label_map[:, 1]]))
         # Starts as an empty list and becomes the running concatenated
         # dataset after the first batch.
         regev_data = []
         for batch_i, batch in enumerate(regev_batches):
             geneid, cellid, count = get_matrix_from_h5(
                 self.save_path + '10X_nuclei_Regev/' + batch +
                 '1/filtered_gene_bc_matrices_h5.h5', 'mm10-1.2.0_premrna')
             count = count.T.tocsr()
             # Keep only the part of each cell id before the first '-'.
             cellid = [id.split('-')[0] for id in cellid]
             # barcode -> cluster for this batch; batch numbers in the
             # membership table are 1-based.
             label_dict = dict(
                 zip(label_barcode[label_batch == str(batch_i + 1)],
                     label_cluster[label_batch == str(batch_i + 1)]))
             # TryFindCells is defined elsewhere; presumably it restricts
             # the counts to cells found in label_dict -- confirm.
             new_count, matched_label = TryFindCells(
                 label_dict, cellid, count)
             # Re-encode the matched cluster ids as 0..k-1 in the sorted
             # order given by np.unique.
             new_label = np.repeat(0, len(matched_label))
             for i, x in enumerate(np.unique(matched_label)):
                 new_label[matched_label == x] = i
             # Annotation for each unique cluster id, in the same order as
             # the integer codes above.
             cell_type = [label_map[x] for x in np.unique(matched_label)]
             dataset = GeneExpressionDataset(
                 *GeneExpressionDataset.get_attributes_from_matrix(
                     new_count, labels=new_label),
                 gene_names=geneid,
                 cell_types=cell_type)
             print(dataset.X.shape, len(dataset.labels))
             # First batch: regev_data is still the empty list.
             if len(regev_data) > 0:
                 regev_data = GeneExpressionDataset.concat_datasets(
                     regev_data, dataset)
             else:
                 regev_data = dataset
         dataset = regev_data
         cell_type = dataset.cell_types
         groups = [
             'Pvalb', 'L2/3', 'Sst', 'L5 PT', 'L5 IT Tcap', 'L5 IT Aldh1a7',
             'L5 IT Foxp2', 'L5 NP', 'L6 IT', 'L6 CT', 'L6 NP', 'L6b',
             'Lamp5', 'Vip', 'Astro', 'OPC', 'VLMC', 'Oligo', 'Sncg',
             'Endo', 'SMC', 'MICRO'
         ]
         # Upper-case both sides so the prefix matching below ignores case.
         cell_type = [x.upper() for x in cell_type]
         groups = [x.upper() for x in groups]
         # Per-cell cell-type name.
         labels = np.asarray(
             [cell_type[x] for x in np.concatenate(dataset.labels)])
         # Cell types reordered so that types sharing a group prefix are
         # adjacent, following the order of `groups`.
         cell_type_bygroup = np.concatenate(
             [[x for x in cell_type if x.startswith(y)] for y in groups])
         new_labels_dict = dict(
             zip(cell_type_bygroup, np.arange(len(cell_type_bygroup))))
         # NOTE(review): a cell whose type matches no group prefix is absent
         # from new_labels_dict and raises KeyError here.
         new_labels = np.asarray([new_labels_dict[x] for x in labels])
         # For each reordered cell type, the index of the first group whose
         # name is a prefix of it.
         labels_groups = [[
             i for i, x in enumerate(groups) if y.startswith(x)
         ][0] for y in cell_type_bygroup]
         # Write the cache files read by the fast path above.
         dump_svmlight_file(dataset.X, new_labels,
                            self.save_path + 'regev_data.svmlight')
         np.save(self.save_path + 'regev_data.celltypes.npy',
                 cell_type_bygroup)
         np.save(self.save_path + 'regev_data.gene_names.npy',
                 dataset.gene_names)
         np.save(self.save_path + 'regev_data.labels_groups.npy',
                 labels_groups)
         return (dataset.X, new_labels, cell_type_bygroup,
                 dataset.gene_names, labels_groups)
# Name the columns sample_1 .. sample_1401.
LUAD_full.columns = [
    "sample_" + str(sample_no) for sample_no in range(1, 1402)
]

# Write the counts as three per-batch CSV files, split by column position.
LUAD_full.iloc[:, 0:274].to_csv(
    "./data/count_data_LUAD_v1_batch1.csv", sep=",")
LUAD_full.iloc[:, 274:1176].to_csv(
    "./data/count_data_LUAD_v1_batch2.csv", sep=",")
LUAD_full.iloc[:, 1176:1401].to_csv(
    "./data/count_data_LUAD_v1_batch3.csv", sep=",")

# Load the three batches, in order, with 2267 genes each.
LUAD_batch_1, LUAD_batch_2, LUAD_batch_3 = (
    CsvDataset(csv_name, new_n_genes=2267)
    for csv_name in (
        "count_data_LUAD_v1_batch1.csv",
        "count_data_LUAD_v1_batch2.csv",
        "count_data_LUAD_v1_batch3.csv",
    )
)

LUAD_data = GeneExpressionDataset.concat_datasets(
    LUAD_batch_1, LUAD_batch_2, LUAD_batch_3)

# VAE sized from the concatenated dataset.
LUAD_vae = VAE(
    LUAD_data.nb_genes,
    n_batch=LUAD_data.n_batches,
    n_labels=LUAD_data.n_labels,
    n_hidden=128,
    n_latent=30,
    n_layers=2,
    dispersion='gene',
)

LUAD_trainer = UnsupervisedTrainer(LUAD_vae, LUAD_data, train_size=0.9)

LUAD_trainer.train(n_epochs=100)

LUAD_full = LUAD_trainer.create_posterior(LUAD_trainer.model,
                                          LUAD_data,
示例#10
0
# Remap labels: cells whose old label is i receive labels_map[i].  The
# comparison is made against the untouched original so earlier
# reassignments cannot cascade.
labels_new = deepcopy(pbmc_labels)
for i, j in enumerate(labels_map):
    labels_new[pbmc_labels == i] = j

_pbmc_attributes = GeneExpressionDataset.get_attributes_from_matrix(
    pbmc.tocsr(), labels=labels_new)
dataset3 = GeneExpressionDataset(*_pbmc_attributes,
                                 gene_names=genenames,
                                 cell_types=cell_type)

sub_dataset1 = sample_celltype(dataset1, subpop, prop)

# Cell count for every cell type named `subpop`; only the first match is
# reported.
_subpop_counts = [
    np.sum(sub_dataset1.labels == type_index)
    for type_index, type_name in enumerate(sub_dataset1.cell_types)
    if type_name == subpop
]
print('total number of cells =' + str(_subpop_counts[0]))

gene_dataset = GeneExpressionDataset.concat_datasets(
    sub_dataset1, dataset2, dataset3)
gene_dataset.subsample_genes(5000)

# VAE sized from the combined dataset.
vae = VAE(
    gene_dataset.nb_genes,
    n_batch=gene_dataset.n_batches,
    n_labels=gene_dataset.n_labels,
    n_hidden=128,
    n_latent=10,
    n_layers=1,
    dispersion='gene',
)
infer_vae = VariationalInference(vae, gene_dataset, use_cuda=use_cuda)
infer_vae.fit(n_epochs=100)

# Export labels, batch indices and the count matrix next to each other.
np.save("../" + plotname + '.label.npy', gene_dataset.labels)
np.save("../" + plotname + '.batch.npy', gene_dataset.batch_indices)
mmwrite("../" + plotname + '.count.mtx', gene_dataset.X)