gene_dataset = GeneExpressionDataset.concat_datasets(dataset1, dataset2) gene_dataset.subsample_genes(5000) if model_type == 'vae': vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches, n_labels=gene_dataset.n_labels, n_hidden=128, n_latent=10, n_layers=2, dispersion='gene') infer = VariationalInference(vae, gene_dataset, use_cuda=use_cuda) infer.train(n_epochs=250) data_loader = infer.data_loaders['sequential'] latent, batch_indices, labels = get_latent(vae, data_loader) keys = gene_dataset.cell_types batch_indices = np.concatenate(batch_indices) keys = gene_dataset.cell_types elif model_type == 'svaec': svaec = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches, gene_dataset.n_labels, use_labels_groups=False, n_latent=10, n_layers=2) infer = SemiSupervisedVariationalInference(svaec, gene_dataset) infer.train(n_epochs=50) print('svaec acc =', infer.accuracy('unlabelled')) data_loader = infer.data_loaders['unlabelled'] latent, batch_indices, labels = get_latent(infer.model, infer.data_loaders['unlabelled']) keys = gene_dataset.cell_types batch_indices = np.concatenate(batch_indices) elif model_type == 'Seurat': # SEURAT = SEURAT() # seurat1 = SEURAT.create_seurat(dataset1, 1) # seurat2 = SEURAT.create_seurat(dataset2, 2) # latent, batch_indices,labels,keys = SEURAT.get_cca() latent = np.genfromtxt('../macosko_regev.CCA.txt') label = np.genfromtxt('../macosko_regev.CCA.label.txt',dtype='str')
latent, batch_indices, labels = full.sequential().get_latent() # n_samples_tsne = 4000 # full.show_t_sne(n_samples=n_samples_tsne, color_by='batches and labels', # save_name=os.path.join(save_path, "scVI_tSNE_batches_labels.png")) print("Transferring labels from scVI") true_labels = labels.ravel() scVI_labels = transfer_nn_labels(latent, labels, batch_indices) # train scANVI print("Training scANVI") # scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches, gene_dataset.n_labels, n_latent=10) scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches, gene_dataset.n_labels, n_latent=10, reconstruction_loss='nb') scanvi.load_state_dict(trainer.model.state_dict(), strict=False) trainer_scanvi = AlternateSemiSupervisedTrainer(scanvi, gene_dataset, n_epochs_classifier=5, lr_classification=5 * 1e-3) labelled = np.where(gene_dataset.batch_indices == 0)[0] # np.random.shuffle(labelled) unlabelled = np.where(gene_dataset.batch_indices == 1)[0] # np.random.shuffle(unlabelled) trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=labelled) trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior( indices=unlabelled)
rep='') acc, cell_type = KNNpurity(latent1, latent2, latent, batch_indices.ravel(), labels, keys) f.write('vae' + '\t' + rmCellTypes + ("\t%.4f" * 8 + "\t%s" * 8 + "\n") % tuple(list(acc) + list(cell_type))) be, cell_type2 = BEbyType(keys, latent, labels, batch_indices, celltype1) g.write('vae' + '\t' + rmCellTypes + ("\t%.4f" * 8 + "\t%s" * 8 + "\n") % tuple(be + list(cell_type2))) plotUMAP(latent, plotname, 'vae', gene_dataset.cell_types, rmCellTypes, gene_dataset.batch_indices.ravel()) scanvi = SCANVI(gene_dataset.nb_genes, 2, (gene_dataset.n_labels + 1), n_hidden=128, n_latent=10, n_layers=2, dispersion='gene') scanvi.load_state_dict(trainer.model.state_dict(), strict=False) trainer_scanvi = AlternateSemiSupervisedTrainer( scanvi, gene_dataset, n_epochs_classifier=10, lr_classification=5 * 1e-3) trainer_scanvi.labelled_set = trainer_scanvi.create_posterior( indices=gene_dataset.batch_indices.ravel() == 0) trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior( indices=gene_dataset.batch_indices.ravel() == 1) trainer_scanvi.train(n_epochs=10) scanvi_full = trainer_scanvi.create_posterior( trainer_scanvi.model,
data_loader = infer_vae.data_loaders['sequential'] latent, batch_indices, labels = get_latent(vae, data_loader) keys = gene_dataset.cell_types batch_indices = np.concatenate(batch_indices) keys = gene_dataset.cell_types elif model_type == 'svaec': gene_dataset.subsample_genes(1000) n_epochs_vae = 100 n_epochs_scanvi = 50 vae = VAE(gene_dataset.nb_genes, gene_dataset.n_batches, gene_dataset.n_labels, n_latent=10, n_layers=2) trainer = UnsupervisedTrainer(vae, gene_dataset, train_size=1.0) trainer.train(n_epochs=n_epochs_vae) for i in [0, 1]: scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches, gene_dataset.n_labels, n_layers=2) scanvi.load_state_dict(vae.state_dict(), strict=False) trainer_scanvi = SemiSupervisedTrainer(scanvi, gene_dataset, classification_ratio=1, n_epochs_classifier=1, lr_classification=5 * 1e-3) trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=(gene_dataset.batch_indices == i)) trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior( indices=(gene_dataset.batch_indices == 1 - i) ) trainer_scanvi.model.eval() print('NN: ', trainer_scanvi.nn_latentspace()) trainer_scanvi.unlabelled_set.to_monitor = ['accuracy'] trainer_scanvi.labelled_set.to_monitor = ['accuracy'] trainer_scanvi.full_dataset.to_monitor = ['entropy_batch_mixing'] trainer_scanvi.train(n_epochs=n_epochs_scanvi)
trainer.model.load_state_dict( torch.load(save_path + 'DE.vae.%i.mis%.2f.pkl' % (rep, misprop))) trainer.model.eval() full = trainer.create_posterior(trainer.model, gene_dataset, indices=np.arange(len(gene_dataset))) latent, batch_indices, _ = full.sequential().get_latent() print("Transferring labels from scVI") scVI_labels = transfer_nn_labels(latent, mislabels, batch_indices) # train scANVI print("Training scANVI") scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches, gene_dataset.n_labels, n_latent=10) scanvi.load_state_dict(trainer.model.state_dict(), strict=False) trainer_scanvi = SemiSupervisedTrainer(scanvi, gene_dataset, classification_ratio=50, n_epochs_classifier=1, lr_classification=5 * 1e-3) # trainer_scanvi = AlternateSemiSupervisedTrainer(scanvi, gene_dataset, # n_epochs_classifier=5, lr_classification=5 * 1e-3, kl=1) labelled = np.where(gene_dataset.batch_indices == 0)[0] np.random.shuffle(labelled) unlabelled = np.where(gene_dataset.batch_indices == 1)[0] np.random.shuffle(unlabelled) trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=labelled) trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
def CompareModels(gene_dataset, dataset1, dataset2, plotname, models):
    """Benchmark harmonization models on a merged dataset and write scores to disk.

    Depending on `models`, evaluates either non-scVI baselines
    ('others': scmap, readSeurat, coral, Combat, MNN, PCA), scVI/scANVI
    variants ('scvi' / 'scvi_nb'), or only exports the data ('writedata').

    Output files (under '../<plotname>/'):
      - '<models>.res.txt': one row per model with 61 numeric columns
        (kNN and k-means clustering scores, per-batch partial scores, the
        kNN Jaccard index at 17 values of K, and summary statistics).
      - '<models>.percluster.res.txt': per-cell-type accuracies.

    Parameters
    ----------
    gene_dataset : merged dataset containing both batches.
    dataset1, dataset2 : the two component datasets.
    plotname : output directory name under '../'.
    models : which model family to run ('others', 'scvi', 'scvi_nb',
        or 'writedata').
    """
    # K values at which the kNN Jaccard index is computed, and the integer
    # weights used to collapse the 17 values into one weighted score.
    KNeighbors = np.concatenate(
        [np.arange(10, 100, 10), np.arange(100, 500, 50)])
    K_int = np.concatenate([np.repeat(10, 10), np.repeat(50, 7)])
    f = open('../' + plotname + '/' + models + '.res.txt', "w+")
    # Header row: names for the 61 metric columns written per model below.
    f.write("model_type " +
            "knn_asw knn_nmi knn_ari knn_uca knn_wuca " +
            "p_knn_asw p_knn_nmi p_knn_ari p_knn_uca p_knn_wuca " +
            "p1_knn_asw p1_knn_nmi p1_knn_ari p1_knn_uca p1_knn_wuca " +
            "p2_knn_asw p2_knn_nmi p2_knn_ari p2_knn_uca p2_knn_wuca " +
            "kmeans_asw kmeans_nmi kmeans_ari kmeans_uca kmeans_wuca " +
            "p_kmeans_asw p_kmeans_nmi p_kmeans_ari p_kmeans_uca p_kmeans_wuca " +
            "p1_kmeans_asw p1_kmeans_nmi p1_kmeans_ari p1_kmeans_uca p1_kmeans_wuca " +
            "p2_kmeans_asw p2_kmeans_nmi p2_kmeans_ari p2_kmeans_uca p2_kmeans_wuca " +
            " ".join(['res_jaccard' + x for x in
                      np.concatenate([np.repeat(10, 10),
                                      np.repeat(50, 7)]).astype('str')]) +
            " " +
            'jaccard_score likelihood BE classifier_acc\n')
    g = open('../' + plotname + '/' + models + '.percluster.res.txt', "w+")
    g.write("model_type\tannotation\t" +
            "\t".join(gene_dataset.cell_types) + "\n")
    # A throwaway scANVI trainer is built only so its labelled/unlabelled
    # index split can be reused for the "partial" evaluations below.
    scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches,
                    gene_dataset.n_labels)
    trainer_scanvi = SemiSupervisedTrainer(scanvi, gene_dataset,
                                           classification_ratio=1,
                                           n_epochs_classifier=1,
                                           lr_classification=5 * 1e-3)
    labelled_idx = trainer_scanvi.labelled_set.indices
    unlabelled_idx = trainer_scanvi.unlabelled_set.indices
    if models == 'others':
        # Per-batch Seurat CCA latents, loaded from disk; used as the
        # reference spaces for the kNN Jaccard comparisons.
        latent1 = np.genfromtxt('../harmonization/Seurat_data/' + plotname +
                                '.1.CCA.txt')
        latent2 = np.genfromtxt('../harmonization/Seurat_data/' + plotname +
                                '.2.CCA.txt')
        for model_type in [
                'scmap', 'readSeurat', 'coral', 'Combat', 'MNN', 'PCA'
        ]:
            print(model_type)
            if (model_type == 'scmap') or (model_type == 'coral'):
                # Label-transfer methods: here the run_model return slots
                # carry predictions for each direction rather than a shared
                # latent space -- presumably; confirm run_model's contract.
                latent, batch_indices, labels, keys, stats = run_model(
                    model_type, gene_dataset, dataset1, dataset2,
                    filename=plotname)
                pred1 = latent
                pred2 = stats
                res1 = scmap_eval(pred1, labels[batch_indices == 1], labels)
                res2 = scmap_eval(pred2, labels[batch_indices == 0], labels)
                g.write("%s\t" % (model_type) + "p1\t" +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res1['clusteracc']) + "\n"))
                g.write("%s\t" % (model_type) + "p2\t" +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res2['clusteracc']) + "\n"))
                # Pad with -1 so the row still has the full 61 columns:
                # 10 + (1+4) + (1+4) + 41 = 61.
                res = [-1] * 10 + \
                    [-1] + [res1[x] for x in ['nmi', 'ari', 'ca', 'weighted ca']] + \
                    [-1] + [res2[x] for x in ['nmi', 'ari', 'ca', 'weighted ca']] + \
                    [-1] * 41
                f.write(model_type + (" %.4f" * 61 + "\n") % tuple(res))
            else:
                if model_type == 'readSeurat':
                    # Restrict the datasets to the gene subset used by the
                    # precomputed Seurat latents -- presumably; confirm
                    # against SubsetGenes.
                    dataset1, dataset2, gene_dataset = SubsetGenes(
                        dataset1, dataset2, gene_dataset, plotname)
                latent, batch_indices, labels, keys, stats = run_model(
                    model_type, gene_dataset, dataset1, dataset2,
                    filename=plotname)
                res_jaccard = [
                    KNNJaccardIndex(latent1, latent2, latent, batch_indices,
                                    k)[0] for k in KNeighbors
                ]
                res_jaccard_score = np.sum(res_jaccard * K_int)
                # Full evaluation plus three "partial" splits: the scANVI
                # labelled/unlabelled split, then each batch as the
                # labelled side in turn.
                res_knn, res_knn_partial, res_kmeans, res_kmeans_partial = \
                    eval_latent(batch_indices, labels, latent, keys,
                                labelled_idx, unlabelled_idx,
                                plotname=plotname + '.' + model_type,
                                plotting=False, partial_only=False)
                _, res_knn_partial1, _, res_kmeans_partial1 = \
                    eval_latent(batch_indices, labels, latent, keys,
                                batch_indices == 0, batch_indices == 1,
                                plotname=plotname + '.' + model_type,
                                plotting=False)
                _, res_knn_partial2, _, res_kmeans_partial2 = \
                    eval_latent(batch_indices, labels, latent, keys,
                                batch_indices == 1, batch_indices == 0,
                                plotname=plotname + '.' + model_type,
                                plotting=False)
                # Subsample to the size of the smaller batch before scoring
                # batch mixing, so the larger batch does not dominate.
                sample = select_indices_evenly(
                    np.min(np.unique(batch_indices, return_counts=True)[1]),
                    batch_indices)
                batch_entropy = entropy_batch_mixing(latent[sample, :],
                                                     batch_indices[sample])
                res = [res_knn[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                    [res_knn_partial[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                    [res_knn_partial1[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                    [res_knn_partial2[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                    [res_kmeans[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                    [res_kmeans_partial[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                    [res_kmeans_partial1[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                    [res_kmeans_partial2[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                    res_jaccard + \
                    [res_jaccard_score, -1, batch_entropy, -1]
                f.write(model_type + (" %.4f" * 61 + "\n") % tuple(res))
                g.write("%s\t" % (model_type) + 'all\t' +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res_knn['clusteracc']) + "\n"))
                g.write("%s\t" % (model_type) + 'p\t' +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res_knn_partial['clusteracc']) + "\n"))
                g.write("%s\t" % (model_type) + 'p1\t' +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res_knn_partial1['clusteracc']) + "\n"))
                g.write("%s\t" % (model_type) + 'p2\t' +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res_knn_partial2['clusteracc']) + "\n"))
    elif (models == 'scvi') or (models == 'scvi_nb'):
        dataset1, dataset2, gene_dataset = SubsetGenes(dataset1, dataset2,
                                                       gene_dataset, plotname)
        # Reference latents for the Jaccard comparison: a VAE trained on
        # each batch alone (NB or ZINB likelihood depending on `models`).
        if models == 'scvi_nb':
            latent1, _, _, _, _ = run_model('vae_nb', dataset1, 0, 0,
                                            filename=plotname, rep='vae1_nb')
            latent2, _, _, _, _ = run_model('vae_nb', dataset2, 0, 0,
                                            filename=plotname, rep='vae2_nb')
        else:
            latent1, _, _, _, _ = run_model('vae', dataset1, 0, 0,
                                            filename=plotname, rep='vae1')
            latent2, _, _, _, _ = run_model('vae', dataset2, 0, 0,
                                            filename=plotname, rep='vae2')
        for model_type in [
                'vae', 'scanvi1', 'scanvi2', 'vae_nb', 'scanvi1_nb',
                'scanvi2_nb'
        ]:
            print(model_type)
            latent, batch_indices, labels, keys, stats = run_model(
                model_type, gene_dataset, dataset1, dataset2,
                filename=plotname, rep='0')
            res_jaccard = [
                KNNJaccardIndex(latent1, latent2, latent, batch_indices,
                                k)[0] for k in KNeighbors
            ]
            res_jaccard_score = np.sum(res_jaccard * K_int)
            res_knn, res_knn_partial, res_kmeans, res_kmeans_partial = \
                eval_latent(batch_indices=batch_indices, labels=labels,
                            latent=latent, keys=keys,
                            labelled_idx=labelled_idx,
                            unlabelled_idx=unlabelled_idx,
                            plotname=plotname + '.' + model_type,
                            plotting=False, partial_only=False)
            _, res_knn_partial1, _, res_kmeans_partial1 = \
                eval_latent(batch_indices=batch_indices, labels=labels,
                            latent=latent, keys=keys,
                            labelled_idx=(batch_indices == 0),
                            unlabelled_idx=(batch_indices == 1),
                            plotname=plotname + '.' + model_type,
                            plotting=False)
            _, res_knn_partial2, _, res_kmeans_partial2 = \
                eval_latent(batch_indices=batch_indices, labels=labels,
                            latent=latent, keys=keys,
                            labelled_idx=(batch_indices == 1),
                            unlabelled_idx=(batch_indices == 0),
                            plotname=plotname + '.' + model_type,
                            plotting=False)
            # stats[0:3] fill the likelihood / BE / classifier_acc columns
            # -- presumably (likelihood, batch entropy, accuracy); confirm
            # against run_model.
            res = [res_knn[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                [res_knn_partial[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                [res_knn_partial1[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                [res_knn_partial2[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                [res_kmeans[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                [res_kmeans_partial[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                [res_kmeans_partial1[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                [res_kmeans_partial2[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                res_jaccard + \
                [res_jaccard_score, stats[0], stats[1], stats[2]]
            f.write(model_type + (" %.4f" * 61 + "\n") % tuple(res))
            g.write("%s\t" % (model_type) + 'all\t' +
                    ("%.4f\t" * len(gene_dataset.cell_types) %
                     tuple(res_knn['clusteracc']) + "\n"))
            g.write("%s\t" % (model_type) + 'p\t' +
                    ("%.4f\t" * len(gene_dataset.cell_types) %
                     tuple(res_knn_partial['clusteracc']) + "\n"))
            g.write("%s\t" % (model_type) + 'p1\t' +
                    ("%.4f\t" * len(gene_dataset.cell_types) %
                     tuple(res_knn_partial1['clusteracc']) + "\n"))
            g.write("%s\t" % (model_type) + 'p2\t' +
                    ("%.4f\t" * len(gene_dataset.cell_types) %
                     tuple(res_knn_partial2['clusteracc']) + "\n"))
        # NOTE(review): a commented-out variant that repeated this
        # evaluation for reps 1-3 was removed here for readability.
    elif models == 'writedata':
        # Only export the merged data for external tools; nothing to score.
        _, _, _, _, _ = run_model('writedata', gene_dataset, dataset1,
                                  dataset2, filename=plotname)
    f.close()
    g.close()
def trainSCANVI(gene_dataset, model_type, filename, rep, nlayers=2,
                reconstruction_loss: str = "zinb"):
    """Build and train (or reload) a scANVI model warm-started from a VAE.

    A VAE is first trained via trainVAE and its weights are copied into a
    fresh SCANVI model.  The labelled/unlabelled split then depends on
    `model_type`:
      - 'scanvi1' / 'scanvi2': alternate semi-supervised training with
        batch 0 (resp. batch 1) treated as labelled, the other unlabelled.
      - 'scanvi0': semi-supervised training with an empty labelled set
        (no batch index is < 0).
      - anything else: plain semi-supervised training with
        classification_ratio=10 and the trainer's default split.
    If a checkpoint for this configuration already exists on disk it is
    loaded instead of retraining; otherwise the model is trained for 5
    epochs and saved.

    Returns the trainer holding the (trained or reloaded) scANVI model.
    """
    pretrained = trainVAE(gene_dataset, filename, rep,
                          reconstruction_loss=reconstruction_loss)
    # Checkpoint path for this model/loss/replicate combination.
    ckpt_path = '../{}/{}.{}.rep{}.pkl'.format(
        filename, model_type, reconstruction_loss, rep)
    scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches,
                    gene_dataset.n_labels, n_layers=nlayers,
                    reconstruction_loss=reconstruction_loss)
    # Warm-start from the pretrained VAE; strict=False because SCANVI has
    # extra (classifier) parameters the VAE does not provide.
    scanvi.load_state_dict(pretrained.model.state_dict(), strict=False)
    batches = gene_dataset.batch_indices.ravel()
    if model_type in ('scanvi1', 'scanvi2'):
        # One batch supplies the labels, the other is left unlabelled.
        ref_batch = 0 if model_type == 'scanvi1' else 1
        trainer_scanvi = AlternateSemiSupervisedTrainer(
            scanvi, gene_dataset, classification_ratio=0,
            n_epochs_classifier=100, lr_classification=5 * 1e-3)
        labelled = np.where(batches == ref_batch)[0]
        # Shuffle the labelled indices (draw all of them without replacement).
        labelled = np.random.choice(labelled, len(labelled), replace=False)
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=labelled)
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
            indices=(batches == 1 - ref_batch))
    elif model_type == 'scanvi0':
        trainer_scanvi = SemiSupervisedTrainer(
            scanvi, gene_dataset, classification_ratio=0,
            n_epochs_classifier=100, lr_classification=5 * 1e-3)
        # batch indices are non-negative, so the labelled set is empty and
        # every cell is unlabelled.
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=(batches < 0))
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
            indices=(batches >= 0))
    else:
        trainer_scanvi = SemiSupervisedTrainer(
            scanvi, gene_dataset, classification_ratio=10,
            n_epochs_classifier=100, lr_classification=5 * 1e-3)
    if os.path.isfile(ckpt_path):
        # Reuse the cached weights instead of retraining.
        trainer_scanvi.model.load_state_dict(torch.load(ckpt_path))
        trainer_scanvi.model.eval()
    else:
        trainer_scanvi.train(n_epochs=5)
        torch.save(trainer_scanvi.model.state_dict(), ckpt_path)
    return trainer_scanvi