예제 #1
0
def run_scVI(DataPath, LabelsPath, CV_RDataPath, OutputDir, GeneOrderPath = "", NumGenes = 0):
    '''
    run scVI
    Wrapper script to run scVI (SCANVI) on a benchmark dataset with cross validation,
    outputs lists of true and predicted cell labels as csv files, as well as computation time.

    A new SCANVI model and SemiSupervisedTrainer are created for every fold, so the
    model scored on a fold has never been trained with that fold's test cells as
    labelled data (reusing one trainer across folds would keep training the same
    model and leak labels between folds).

    Parameters
    ----------
    DataPath : Data file path (.csv), cells-genes matrix with cell unique barcodes
    as row names and gene names as column names.
    LabelsPath : Cell population annotations file path (.csv).
    CV_RDataPath : Cross validation RData file path (.RData), obtained from Cross_Validation.R function.
    It must define n_folds, Cells_to_Keep, col_Index, Test_Idx and Train_Idx.
    OutputDir : Output directory defining the path of the exported file.
    GeneOrderPath : Gene order file path (.csv) obtained from feature selection,
    defining the genes order for each cross validation fold; only read when
    NumGenes > 0. Default is "".
    NumGenes : Number of genes used in case of feature selection (integer),
    default is 0 (use all genes, no feature file).

    Raises
    ------
    ValueError : if NumGenes is negative.

    Side effects
    ------------
    Writes Data_scvi.csv / Labels_scvi.csv to the current working directory, then
    changes the working directory to OutputDir and writes the result files there.
    '''
    if NumGenes < 0:
        raise ValueError("NumGenes must be >= 0, got {}".format(NumGenes))

    # read the Rdata file
    robjects.r['load'](CV_RDataPath)

    nfolds = np.array(robjects.r['n_folds'], dtype='int')
    tokeep = np.array(robjects.r['Cells_to_Keep'], dtype='bool')
    # R indices are 1-based; pandas positions are 0-based
    col = np.array(robjects.r['col_Index'], dtype='int') - 1
    test_ind = np.array(robjects.r['Test_Idx'])
    train_ind = np.array(robjects.r['Train_Idx'])

    # read the data
    data = pd.read_csv(DataPath, index_col=0, sep=',')
    labels = pd.read_csv(LabelsPath, header=0, index_col=None, sep=',', usecols=col)

    labels = labels.iloc[tokeep]
    data = data.iloc[tokeep]

    if NumGenes > 0:
        # per-fold gene ordering produced by feature selection
        features = pd.read_csv(GeneOrderPath, header=0, index_col=None, sep=',')
    else:
        # Without feature selection every fold uses the same genes, so the
        # dataset only has to be written and loaded once.
        labels.to_csv('Labels_scvi.csv')
        data.to_csv('Data_scvi.csv')
        dataset = CsvDataset('Data_scvi.csv', save_path="", sep=",",
                             labels_file="Labels_scvi.csv", gene_by_cell=False)

    n_epochs = 200

    truelab = []
    pred = []
    tr_time = []
    ts_time = []

    for i in range(np.squeeze(nfolds)):
        # convert this fold's 1-based R indices to 0-based positions
        test_ind_i = np.array(test_ind[i], dtype='int') - 1
        train_ind_i = np.array(train_ind[i], dtype='int') - 1

        if NumGenes > 0:
            feat_to_use = features.iloc[0:NumGenes, i]
            data2 = data.iloc[:, feat_to_use]

            labels.to_csv('Labels_scvi.csv')
            data2.to_csv('Data_scvi.csv')

            # NOTE(review): new_n_genes=False presumably disables CsvDataset's own
            # gene subsampling so exactly the selected genes are kept — confirm.
            dataset = CsvDataset('Data_scvi.csv', save_path="", sep=",",
                                 labels_file="Labels_scvi.csv", gene_by_cell=False,
                                 new_n_genes=False)

        # Fresh, untrained model for every fold (see docstring).
        scanvi = SCANVI(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
        trainer_scanvi = SemiSupervisedTrainer(scanvi, dataset, frequency=5)

        # Replace the trainer's default split with this fold's train/test cells.
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=train_ind_i.ravel(), shuffle=False)
        trainer_scanvi.labelled_set.to_monitor = ['ll', 'accuracy']
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(indices=test_ind_i.ravel(), shuffle=False)
        trainer_scanvi.unlabelled_set.to_monitor = ['ll', 'accuracy']

        start = tm.time()
        trainer_scanvi.train(n_epochs)
        tr_time.append(tm.time() - start)

        ## labels of test set are in y_pred
        ## labels are returned in numbers, should be mapped back to the real labels
        ## indices are permutated
        start = tm.time()
        y_true, y_pred = trainer_scanvi.unlabelled_set.compute_predictions()
        ts_time.append(tm.time() - start)

        truelab.extend(y_true)
        pred.extend(y_pred)

    # write results
    os.chdir(OutputDir)

    truelab = pd.DataFrame(truelab)
    pred = pd.DataFrame(pred)
    tr_time = pd.DataFrame(tr_time)
    ts_time = pd.DataFrame(ts_time)

    # file names gain the gene count when feature selection was used
    prefix = "scVI_" if NumGenes == 0 else "scVI_" + str(NumGenes) + "_"
    truelab.to_csv(prefix + "True_Labels.csv", index=False)
    pred.to_csv(prefix + "Pred_Labels.csv", index=False)
    tr_time.to_csv(prefix + "Training_Time.csv", index=False)
    ts_time.to_csv(prefix + "Testing_Time.csv", index=False)
예제 #2
0
def run_scVI(input_dir, output_dir, datafile, labfile, Rfile):
    '''
    Run scVI (SCANVI) with cross validation and write true/predicted labels and timings.

    Parameters
    ----------
    input_dir : directory of the input files; the intermediate Data_scvi.csv and
        Labels_scvi.csv files are also written there
    output_dir : directory of the output files
    datafile : name of the data file (resolved relative to input_dir)
    labfile : name of the label file (resolved relative to input_dir)
    Rfile : file to read the cross validation indices from (resolved relative to input_dir)

    Notes
    -----
    Changes the process working directory to input_dir and then to output_dir.
    A new SCANVI model and trainer are built for each fold; sharing one trainer
    across folds would keep training the same model, so later folds would be
    scored on cells an earlier fold had used as labelled training data.
    Output file names are "scVI_" + str(col) + suffix, where col is the 0-based
    label-column index array (e.g. "scVI_[0]_true.csv").
    '''
    os.chdir(input_dir)

    # read the Rdata file
    robjects.r['load'](Rfile)

    nfolds = np.array(robjects.r['n_folds'], dtype='int')
    tokeep = np.array(robjects.r['Cells_to_Keep'], dtype='bool')
    # R indices are 1-based; pandas positions are 0-based
    col = np.array(robjects.r['col_Index'], dtype='int') - 1
    test_ind = np.array(robjects.r['Test_Idx'])
    train_ind = np.array(robjects.r['Train_Idx'])

    # read the data (working directory is still input_dir)
    data = pd.read_csv(datafile, index_col=0, sep=',')
    labels = pd.read_csv(labfile,
                         header=0,
                         index_col=None,
                         sep=',',
                         usecols=col)

    labels = labels.iloc[tokeep]
    data = data.iloc[tokeep]

    # save labels as csv file with header and index column
    labels.to_csv('Labels_scvi.csv')
    data.to_csv('Data_scvi.csv')

    # The dataset is identical for every fold, so it is loaded only once.
    train = CsvDataset('Data_scvi.csv',
                       save_path=input_dir,
                       sep=",",
                       labels_file="Labels_scvi.csv",
                       gene_by_cell=False)

    n_epochs = 200

    truelab = []
    pred = []
    tr_time = []
    ts_time = []

    for i in range(np.squeeze(nfolds)):
        # convert this fold's 1-based R indices to 0-based positions
        test_ind_i = np.array(test_ind[i], dtype='int') - 1
        train_ind_i = np.array(train_ind[i], dtype='int') - 1

        # Fresh, untrained model for every fold (see Notes).
        scanvi = SCANVI(train.nb_genes, train.n_batches, train.n_labels)
        trainer_scanvi = SemiSupervisedTrainer(scanvi, train, frequency=5)

        # Replace the trainer's default split with this fold's train/test cells.
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=train_ind_i.ravel(), shuffle=False)
        trainer_scanvi.labelled_set.to_monitor = ['ll', 'accuracy']
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
            indices=test_ind_i.ravel(), shuffle=False)
        trainer_scanvi.unlabelled_set.to_monitor = ['ll', 'accuracy']

        start = tm.time()
        trainer_scanvi.train(n_epochs)
        tr_time.append(tm.time() - start)

        ## labels of test set are in y_pred
        ## labels are returned in numbers, should be mapped back to the real labels
        ## indices are permutated
        start = tm.time()
        y_true, y_pred = trainer_scanvi.unlabelled_set.compute_predictions()
        ts_time.append(tm.time() - start)

        truelab.extend(y_true)
        pred.extend(y_pred)

    # write results
    os.chdir(output_dir)

    truelab = pd.DataFrame(truelab)
    pred = pd.DataFrame(pred)
    tr_time = pd.DataFrame(tr_time)
    ts_time = pd.DataFrame(ts_time)

    prefix = "scVI_" + str(col)
    truelab.to_csv(prefix + "_true.csv", index=False)
    pred.to_csv(prefix + "_pred.csv", index=False)

    tr_time.to_csv(prefix + "_training_time.csv", index=False)
    ts_time.to_csv(prefix + "_test_time.csv", index=False)
def custom_objective_hyperopt(space,
                              is_best_training=False,
                              dataset=None,
                              n_epochs=None):
    """Custom objective function for advanced autotune tutorial.

    Builds a SCANVI model and SemiSupervisedTrainer from the hyperparameters in
    ``space``. Cells with ``dataset.batch_indices == 0`` form the labelled pool;
    cells with ``dataset.batch_indices == 1`` form the unlabelled set.

    Parameters
    ----------
    space : mapping
        Hyperopt sample that may contain ``model_tunable_kwargs``,
        ``trainer_tunable_kwargs`` and ``train_func_tunable_kwargs`` dicts
        (missing entries default to empty dicts). The trainer and train-function
        dicts are updated in place with the hard-coded settings below, which
        take precedence over tuned values.
    is_best_training : bool
        False: evaluate the hyperparameters and return a hyperopt result dict.
        True: train once on the whole labelled pool and return the trainer.
    dataset :
        scvi dataset exposing ``nb_genes``, ``n_batches``, ``n_labels`` and
        ``batch_indices``.
    n_epochs : int
        Number of epochs passed to ``trainer.train``.

    Returns
    -------
    dict
        ``{"loss", "space", "status"}`` when ``is_best_training`` is False, where
        ``loss`` is minus the mean accuracy on held-out labelled cells over 5
        random 80/20 splits of the labelled pool.
    SemiSupervisedTrainer
        The trained trainer when ``is_best_training`` is True.
    """
    space = defaultdict(dict, space)
    model_tunable_kwargs = space["model_tunable_kwargs"]
    trainer_tunable_kwargs = space["trainer_tunable_kwargs"]
    train_func_tunable_kwargs = space["train_func_tunable_kwargs"]

    # Hard-coded parameters override tuned ones: train on GPU when available,
    # disable the scVI progress bar and compute metrics every epoch.
    trainer_tunable_kwargs.update({
        "use_cuda": bool(torch.cuda.device_count()),
        "show_progbar": False,
        "frequency": 1,
    })
    train_func_tunable_kwargs.update({"n_epochs": n_epochs})

    indices_labelled = np.squeeze(dataset.batch_indices == 0)

    def _new_trainer():
        """Return a freshly initialised SCANVI trainer with the unlabelled set attached."""
        scanvi = SCANVI(dataset.nb_genes, dataset.n_batches, dataset.n_labels,
                        **model_tunable_kwargs)
        trainer = SemiSupervisedTrainer(scanvi, dataset,
                                        **trainer_tunable_kwargs)
        trainer.unlabelled_set = trainer.create_posterior(
            indices=np.squeeze(dataset.batch_indices == 1))
        trainer.unlabelled_set.to_monitor = [
            "reconstruction_error", "accuracy"
        ]
        return trainer

    if is_best_training:
        trainer_scanvi = _new_trainer()
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=indices_labelled)
        trainer_scanvi.labelled_set.to_monitor = [
            "reconstruction_error", "accuracy"
        ]
        trainer_scanvi.train(**train_func_tunable_kwargs)
        return trainer_scanvi

    # Repeated random 80/20 hold-out (not strict k-fold) on the labelled pool.
    # Each repeat trains a fresh model, so the validation cells of one repeat
    # were never labelled training cells of the model being scored.
    k = 5
    accuracies = np.zeros(k)
    labelled_pool = indices_labelled.nonzero()[0]
    for i in range(k):
        indices_labelled_train, indices_labelled_val = train_test_split(
            labelled_pool, test_size=0.2)
        trainer_scanvi = _new_trainer()
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=indices_labelled_train)
        trainer_scanvi.labelled_set.to_monitor = [
            "reconstruction_error",
            "accuracy",
        ]
        trainer_scanvi.validation_set = trainer_scanvi.create_posterior(
            indices=indices_labelled_val)
        trainer_scanvi.validation_set.to_monitor = ["accuracy"]
        trainer_scanvi.train(**train_func_tunable_kwargs)
        # Score on the held-out labelled cells, so hyperparameters are never
        # selected using the labels of the unlabelled set.
        accuracies[i] = trainer_scanvi.history["accuracy_validation_set"][-1]
    return {
        "loss": -accuracies.mean(),
        "space": space,
        "status": STATUS_OK
    }
예제 #4
0
# Script fragment (its beginning is not visible here): prepares a reference
# AnnData `adata` and a query AnnData `adata_test` for SCANVI.
# NOTE(review): `adata`, `adata_test`, `sc`, `AnnDatasetFromAnnData`, `SCANVI` and
# `SemiSupervisedTrainer` come from code above this excerpt — not visible here.
adata_test.var_names_make_unique()


# PRE-PROCESS
# Drop genes detected in fewer than 5% of the reference cells (only `adata` is
# filtered), then log1p-transform both the reference and the query.
sc.pp.filter_genes(adata, min_cells=int(0.05 * adata.shape[0]))
sc.pp.log1p(adata)
sc.pp.log1p(adata_test)
# Then find variable genes: top 2000 with the "seurat" flavor, on the reference only
sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor="seurat")
# Label test data: mark which cells belong to the query set
adata.obs["scanvi_test"] = False
adata_test.obs["scanvi_test"] = True

# SELECT SAME GENES
# Keep the reference's highly variable genes that are also present in the query.
genes = adata[:,adata.var.highly_variable.tolist()].var_names
# NOTE(review): `adata_test.var_names.to_list()` is rebuilt for every gene and
# searched linearly; building a set once would avoid the quadratic cost.
genes_shared = [i in adata_test.var_names.to_list() for i in genes]
genes = genes[genes_shared]

adata = adata[:, genes.to_list()]
adata_test = adata_test[:, genes.to_list()]

# Stack reference and query cells into a single AnnData
adata_merged = adata.concatenate(adata_test)

# SCANVI
adata_scanvi = AnnDatasetFromAnnData(adata_merged)
scanvi = SCANVI(adata_scanvi.nb_genes, adata_scanvi.n_batches, adata_scanvi.n_labels)
trainer_scanvi = SemiSupervisedTrainer(scanvi, adata_scanvi, frequency=5)

n_epochs = 200
# NOTE(review): the statement below is incomplete — `indices=adata_scanvi.` names
# no attribute, so this line is a SyntaxError as written. It was presumably meant
# to select the reference cells (scanvi_test == False); confirm and complete.
trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=adata_scanvi.)
예제 #5
0
def run_scVI(trainname, testname, n,
             data_dir='/albona/nobackup/biostat/datasets/singlecell/tabulaMuris_benchmark/',
             out_dir="/dora/nobackup/yuec/scclassify/benchmark/scVI/vary_test"):
    """Train SCANVI on one dataset and predict cell labels for another.

    Reads ``<trainname>.csv`` / ``<trainname>_label.csv`` and
    ``<testname>.csv`` / ``<testname>_label.csv`` from ``data_dir``. It then
    concatenates them with the training cells first, trains SCANVI with the
    training cells as the labelled set and the test cells as the unlabelled set,
    and writes the true and predicted labels of the test cells to ``out_dir``.

    Parameters
    ----------
    trainname, testname : str
        Base file names (without ``.csv``) of the training and test datasets.
    n : str or int
        Prefix for the output files ``<n>_scVI_True.csv`` / ``<n>_scVI_Pred.csv``.
    data_dir : str
        Directory containing the input csv files.
    out_dir : str
        Directory for the intermediate files and the outputs; the process
        working directory is changed to it.

    Returns
    -------
    (mem_train, time_train, mem_test, time_test)
        ``time_*`` are whole seconds (truncated with ``int``) and exclude the
        memory-profiling overhead. ``mem_*`` are whatever ``display_top``
        returns for a tracemalloc snapshot covering only that phase.
    """
    train = pd.read_csv(os.path.join(data_dir, trainname + '.csv'),
                        index_col=0, sep=',')
    test = pd.read_csv(os.path.join(data_dir, testname + '.csv'),
                       index_col=0, sep=',')
    trainlabel = pd.read_csv(os.path.join(data_dir, trainname + '_label.csv'),
                             header=0, index_col=0, sep=',')
    testlabel = pd.read_csv(os.path.join(data_dir, testname + '_label.csv'),
                            header=0, index_col=0, sep=',')

    # Expression tables are joined column-wise and label tables row-wise,
    # consistent with gene_by_cell=True below (cells as columns).
    newdata = pd.concat([train, test], axis=1)
    newlabel = pd.concat([trainlabel, testlabel], axis=0)

    # CsvDataset is given relative paths with save_path="", so the intermediate
    # files are written into, and read from, the working directory out_dir.
    os.chdir(out_dir)
    newdata.to_csv('data_scvi.csv')
    newlabel.to_csv('labels_scvi.csv')
    data = CsvDataset('data_scvi.csv',
                      save_path="",
                      sep=",",
                      labels_file="labels_scvi.csv",
                      gene_by_cell=True)

    n_epochs = 100
    n_train = trainlabel.shape[0]
    n_test = testlabel.shape[0]

    # --- training phase ---------------------------------------------------
    now = time.time()
    tracemalloc.start()
    try:
        scanvi = SCANVI(data.nb_genes, data.n_batches, data.n_labels)
        trainer_scanvi = SemiSupervisedTrainer(scanvi, data, frequency=5)

        # training cells (the first n_train) are labelled, test cells unlabelled
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=list(range(n_train)), shuffle=False)
        trainer_scanvi.labelled_set.to_monitor = ['ll', 'accuracy']
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
            indices=list(range(n_train, n_train + n_test)), shuffle=False)
        trainer_scanvi.unlabelled_set.to_monitor = ['ll', 'accuracy']

        trainer_scanvi.train(n_epochs)

        # time is taken before snapshotting so profiling overhead is excluded
        time_train = int(time.time() - now)
        mem_train = display_top(tracemalloc.take_snapshot())
    finally:
        # stop() also discards the recorded traces, so the prediction phase
        # below only traces allocations it makes itself.
        tracemalloc.stop()

    # --- prediction phase -------------------------------------------------
    now = time.time()
    tracemalloc.start()
    try:
        ## labels are returned in numbers, should be mapped back to the real labels
        y_true, y_pred = trainer_scanvi.unlabelled_set.compute_predictions()
        time_test = int(time.time() - now)
        mem_test = display_top(tracemalloc.take_snapshot())
    finally:
        tracemalloc.stop()

    truelab = pd.DataFrame(list(y_true))
    pred = pd.DataFrame(list(y_pred))

    truelab.to_csv(os.path.join(out_dir, str(n) + "_scVI_True.csv"), index=False)
    pred.to_csv(os.path.join(out_dir, str(n) + "_scVI_Pred.csv"), index=False)

    return mem_train, time_train, mem_test, time_test
예제 #6
0
def runScanvi(adata, batch, labels):
    """Embed ``adata`` with scANVI initialised from a freshly trained scVI model.

    scANVI must be fitted on non-normalized (count) data, so the raw counts are
    taken from ``adata.layers['counts']``.

    Parameters
    ----------
    adata : AnnData with a ``counts`` layer.
    batch : name of the ``adata.obs`` column holding batch annotations.
    labels : name of the ``adata.obs`` column holding cell-type annotations.

    Returns
    -------
    The same ``adata`` object, with the scANVI latent space stored in
    ``adata.obsm['X_emb']``.

    Raises
    ------
    TypeError : if ``adata.layers`` has no ``counts`` entry.
    """
    if 'counts' not in adata.layers:
        raise TypeError(
            'Adata does not contain a `counts` layer in `adata.layers[`counts`]`'
        )

    from scvi.models import VAE, SCANVI
    from scvi.inference import UnsupervisedTrainer, SemiSupervisedTrainer
    from sklearn.preprocessing import LabelEncoder
    from scvi.dataset import AnnDatasetFromAnnData
    import numpy as np

    # Work on a copy whose X is the counts layer; drop the layer and .raw so
    # nothing but raw counts can reach the model (del .raw needs anndata >= 0.7).
    counts_adata = adata.copy()
    counts_adata.X = adata.layers['counts']
    del counts_adata.layers['counts']
    del counts_adata.raw

    # Integer-encode batches, then cell types, into the columns scvi expects.
    encoder = LabelEncoder()
    for target_col, source_col in (('batch_indices', batch), ('labels', labels)):
        counts_adata.obs[target_col] = encoder.fit_transform(
            counts_adata.obs[source_col].values)

    dataset = AnnDatasetFromAnnData(counts_adata)
    print("scANVI dataset object with {} batches and {} cell types".format(
        dataset.n_batches, dataset.n_labels))

    # Defaults from the scVI GitHub tutorials (scanpy_pbmc3k, harmonization):
    # at most 400 scVI epochs, scaled down for large datasets, and a third of
    # that (clamped to [2, 10]) for scANVI.
    shared_arch = {"n_latent": 30, "n_hidden": 128, "n_layers": 2}
    scvi_epochs = np.min([round((20000 / adata.n_obs) * 400), 400])
    scanvi_epochs = int(np.min([10, np.max([2, round(scvi_epochs / 3.)])]))

    # Stage 1: unsupervised scVI pre-training on every cell.
    vae = VAE(dataset.nb_genes,
              reconstruction_loss='nb',
              n_batch=dataset.n_batches,
              **shared_arch)
    vae_trainer = UnsupervisedTrainer(vae, dataset, train_size=1.0, use_cuda=False)
    vae_trainer.train(n_epochs=scvi_epochs, lr=1e-3)

    # Stage 2: scANVI warm-started from the scVI weights (strict=False skips
    # parameters that exist in only one of the two models).
    scanvi = SCANVI(dataset.nb_genes,
                    dataset.n_batches,
                    dataset.n_labels,
                    dispersion='gene',
                    reconstruction_loss='nb',
                    **shared_arch)
    scanvi.load_state_dict(vae_trainer.model.state_dict(), strict=False)

    # Default semi-supervised trainer settings; every cell is labelled and a
    # single cell (index 0) stands in as the unlabelled set.
    scanvi_trainer = SemiSupervisedTrainer(scanvi, dataset)
    scanvi_trainer.labelled_set = scanvi_trainer.create_posterior(
        scanvi_trainer.model, dataset, indices=np.arange(len(dataset)))
    scanvi_trainer.unlabelled_set = scanvi_trainer.create_posterior(indices=[0])
    scanvi_trainer.train(n_epochs=scanvi_epochs)

    # Latent coordinates for all cells, in their original order.
    full_posterior = scanvi_trainer.create_posterior(
        scanvi_trainer.model, dataset, indices=np.arange(len(dataset)))
    embedding, _, _ = full_posterior.sequential().get_latent()

    adata.obsm['X_emb'] = embedding
    return adata
                                indices=np.arange(len(gene_dataset)))
# Continuation of a script whose earlier part is not visible here: `full`,
# `gene_dataset`, `trainer`, `mislabels`, `transfer_nn_labels`, `SCANVI` and
# `SemiSupervisedTrainer` are defined above this excerpt.
# NOTE(review): `full` is presumably a posterior over all cells of gene_dataset
# (the truncated line above ends with indices=np.arange(len(gene_dataset))) — confirm.
latent, batch_indices, _ = full.sequential().get_latent()

print("Transferring labels from scVI")
# NOTE(review): transfer_nn_labels is not shown; presumably a nearest-neighbour
# label transfer in the scVI latent space — verify against its definition.
scVI_labels = transfer_nn_labels(latent, mislabels, batch_indices)

# train scANVI
print("Training scANVI")
scanvi = SCANVI(gene_dataset.nb_genes,
                gene_dataset.n_batches,
                gene_dataset.n_labels,
                n_latent=10)
# Warm-start scANVI from the trained scVI model; strict=False skips parameters
# present in only one of the two state dicts.
scanvi.load_state_dict(trainer.model.state_dict(), strict=False)
trainer_scanvi = SemiSupervisedTrainer(scanvi,
                                       gene_dataset,
                                       classification_ratio=50,
                                       n_epochs_classifier=1,
                                       lr_classification=5 * 1e-3)
# trainer_scanvi = AlternateSemiSupervisedTrainer(scanvi, gene_dataset,
#                                                 n_epochs_classifier=5, lr_classification=5 * 1e-3, kl=1)
# Cells from batch 0 are used as the labelled set and cells from batch 1 as the
# unlabelled set; each index array is shuffled in place with the global NumPy RNG.
labelled = np.where(gene_dataset.batch_indices == 0)[0]
np.random.shuffle(labelled)
unlabelled = np.where(gene_dataset.batch_indices == 1)[0]
np.random.shuffle(unlabelled)
trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=labelled)
trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
    indices=unlabelled)

trainer_scanvi.train(n_epochs=5)

scanvi_labels = trainer_scanvi.full_dataset.sequential().compute_predictions(
예제 #8
0
    pbmc, pbmc2, gene_dataset = SubsetGenes(pbmc, pbmc2, gene_dataset, plotname + rmCellTypes.replace(' ', ''))
    latent1, _, _, _ = trainVAE(pbmc, rmCellTypes, rep='1')
    latent2, _, _, _ = trainVAE(pbmc2, rmCellTypes, rep='2')
    latent, batch_indices, labels, trainer = trainVAE(gene_dataset, rmCellTypes, rep='')
    acc, cell_type = KNNpurity(latent1, latent2,latent,batch_indices.ravel(),labels,keys)
    f.write('vae' + '\t' + rmCellTypes + ("\t%.4f" * 8 + "\t%s" * 8 + "\n") % tuple(list(acc) + list(cell_type)))
    be, cell_type2 = BEbyType(keys, latent, labels, batch_indices)
    g.write('vae' + '\t' + rmCellTypes + ("\t%.4f" * 8 + "\t%s" * 8 + "\n") % tuple(be + list(cell_type2)))
    # plotUMAP(latent, plotname, 'vae', rmCellTypes)

    labelledset = deepcopy(gene_dataset)
    labelledset.update_cells(gene_dataset.batch_indices.ravel() == 0)
    scanvi = SCANVI(labelledset.nb_genes, 2, (labelledset.n_labels + 1),
                    n_hidden=128, n_latent=10, n_layers=2, dispersion='gene')
    scanvi.load_state_dict(trainer.model.state_dict(), strict=False)
    trainer_scanvi = SemiSupervisedTrainer(scanvi, labelledset, n_epochs_classifier=1, lr_classification=5 * 1e-3)
    trainer_scanvi.train(n_epochs=5)

    # scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches, (gene_dataset.n_labels), n_layers=2)
    # scanvi.load_state_dict(trainer.model.state_dict(), strict=False)
    # trainer_scanvi = AlternateSemiSupervisedTrainer(scanvi, gene_dataset, classification_ratio=50,
    #                                                 n_epochs_classifier=100, lr_classification=5 * 1e-3)
    # trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=gene_dataset.batch_indices.ravel() == 0)
    # trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(indices=gene_dataset.batch_indices.ravel() == 1)
    # if os.path.isfile('../NoOverlap/scanvi.%s.pkl' % rmCellTypes):
    #     trainer_scanvi.model.load_state_dict(torch.load('../NoOverlap/scanvi.%s.pkl' % rmCellTypes))
    #     trainer_scanvi.model.eval()
    # else:
    #     trainer_scanvi.train(n_epochs=10)
    #     torch.save(trainer_scanvi.model.state_dict(), '../NoOverlap/scanvi.%s.pkl' % rmCellTypes)
    scanvi_full = trainer_scanvi.create_posterior(trainer_scanvi.model, gene_dataset, indices=np.arange(len(gene_dataset)))
def CompareModels(gene_dataset, dataset1, dataset2, plotname, models):
    KNeighbors = np.concatenate(
        [np.arange(10, 100, 10),
         np.arange(100, 500, 50)])
    K_int = np.concatenate([np.repeat(10, 10), np.repeat(50, 7)])
    f = open('../' + plotname + '/' + models + '.res.txt', "w+")
    f.write("model_type " + \
            "knn_asw knn_nmi knn_ari knn_uca knn_wuca " + \
            "p_knn_asw p_knn_nmi p_knn_ari p_knn_uca p_knn_wuca " + \
            "p1_knn_asw p1_knn_nmi p1_knn_ari p1_knn_uca p1_knn_wuca " + \
            "p2_knn_asw p2_knn_nmi p2_knn_ari p2_knn_uca p2_knn_wuca " + \
            "kmeans_asw kmeans_nmi kmeans_ari kmeans_uca kmeans_wuca " + \
            "p_kmeans_asw p_kmeans_nmi p_kmeans_ari p_kmeans_uca p_kmeans_wuca " + \
            "p1_kmeans_asw p1_kmeans_nmi p1_kmeans_ari p1_kmeans_uca p1_kmeans_wuca " + \
            "p2_kmeans_asw p2_kmeans_nmi p2_kmeans_ari p2_kmeans_uca p2_kmeans_wuca " + \
            " ".join(['res_jaccard' + x for x in
                      np.concatenate([np.repeat(10, 10), np.repeat(50, 7)]).astype('str')]) + " " + \
            'jaccard_score likelihood BE classifier_acc\n'
            )
    g = open('../' + plotname + '/' + models + '.percluster.res.txt', "w+")
    g.write("model_type\tannotation\t" + "\t".join(gene_dataset.cell_types) +
            "\n")

    scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches,
                    gene_dataset.n_labels)
    trainer_scanvi = SemiSupervisedTrainer(scanvi,
                                           gene_dataset,
                                           classification_ratio=1,
                                           n_epochs_classifier=1,
                                           lr_classification=5 * 1e-3)
    labelled_idx = trainer_scanvi.labelled_set.indices
    unlabelled_idx = trainer_scanvi.unlabelled_set.indices

    if models == 'others':
        latent1 = np.genfromtxt('../harmonization/Seurat_data/' + plotname +
                                '.1.CCA.txt')
        latent2 = np.genfromtxt('../harmonization/Seurat_data/' + plotname +
                                '.2.CCA.txt')
        for model_type in [
                'scmap', 'readSeurat', 'coral', 'Combat', 'MNN', 'PCA'
        ]:
            print(model_type)
            if (model_type == 'scmap') or (model_type == 'coral'):
                latent, batch_indices, labels, keys, stats = run_model(
                    model_type,
                    gene_dataset,
                    dataset1,
                    dataset2,
                    filename=plotname)
                pred1 = latent
                pred2 = stats
                res1 = scmap_eval(pred1, labels[batch_indices == 1], labels)
                res2 = scmap_eval(pred2, labels[batch_indices == 0], labels)
                g.write("%s\t" % (model_type) + "p1\t" +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res1['clusteracc']) + "\n"))
                g.write("%s\t" % (model_type) + "p2\t" +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res2['clusteracc']) + "\n"))
                res = [-1] * 10 + \
                      [-1] + [res1[x] for x in ['nmi', 'ari', 'ca', 'weighted ca']] + \
                      [-1] + [res2[x] for x in ['nmi', 'ari', 'ca', 'weighted ca']] + \
                      [-1] * 41

                f.write(model_type + (" %.4f" * 61 + "\n") % tuple(res))
            else:
                if model_type == 'readSeurat':
                    dataset1, dataset2, gene_dataset = SubsetGenes(
                        dataset1, dataset2, gene_dataset, plotname)

                latent, batch_indices, labels, keys, stats = run_model(
                    model_type,
                    gene_dataset,
                    dataset1,
                    dataset2,
                    filename=plotname)

                res_jaccard = [
                    KNNJaccardIndex(latent1, latent2, latent, batch_indices,
                                    k)[0] for k in KNeighbors
                ]
                res_jaccard_score = np.sum(res_jaccard * K_int)
                res_knn, res_knn_partial, res_kmeans, res_kmeans_partial = \
                    eval_latent(batch_indices, labels, latent, keys,
                                labelled_idx, unlabelled_idx,
                                plotname=plotname + '.' + model_type, plotting=False, partial_only=False)

                _, res_knn_partial1, _, res_kmeans_partial1 = \
                    eval_latent(batch_indices, labels, latent, keys,
                                batch_indices == 0, batch_indices == 1,
                                plotname=plotname + '.' + model_type, plotting=False)

                _, res_knn_partial2, _, res_kmeans_partial2 = \
                    eval_latent(batch_indices, labels, latent, keys,
                                batch_indices == 1, batch_indices == 0,
                                plotname=plotname + '.' + model_type, plotting=False)

                sample = select_indices_evenly(
                    np.min(np.unique(batch_indices, return_counts=True)[1]),
                    batch_indices)
                batch_entropy = entropy_batch_mixing(latent[sample, :],
                                                     batch_indices[sample])

                res = [res_knn[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                      [res_knn_partial[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                      [res_knn_partial1[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                      [res_knn_partial2[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                      [res_kmeans[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                      [res_kmeans_partial[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                      [res_kmeans_partial1[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                      [res_kmeans_partial2[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                      res_jaccard + \
                      [res_jaccard_score, -1, batch_entropy, -1]

                f.write(model_type + (" %.4f" * 61 + "\n") % tuple(res))
                g.write("%s\t" % (model_type) + 'all\t' +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res_knn['clusteracc']) + "\n"))
                g.write("%s\t" % (model_type) + 'p\t' +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res_knn_partial['clusteracc']) + "\n"))
                g.write("%s\t" % (model_type) + 'p1\t' +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res_knn_partial1['clusteracc']) + "\n"))
                g.write("%s\t" % (model_type) + 'p2\t' +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res_knn_partial2['clusteracc']) + "\n"))

    elif (models == 'scvi') or (models == 'scvi_nb'):
        dataset1, dataset2, gene_dataset = SubsetGenes(dataset1, dataset2,
                                                       gene_dataset, plotname)
        if models == 'scvi_nb':
            latent1, _, _, _, _ = run_model('vae_nb',
                                            dataset1,
                                            0,
                                            0,
                                            filename=plotname,
                                            rep='vae1_nb')
            latent2, _, _, _, _ = run_model('vae_nb',
                                            dataset2,
                                            0,
                                            0,
                                            filename=plotname,
                                            rep='vae2_nb')
        else:
            latent1, _, _, _, _ = run_model('vae',
                                            dataset1,
                                            0,
                                            0,
                                            filename=plotname,
                                            rep='vae1')
            latent2, _, _, _, _ = run_model('vae',
                                            dataset2,
                                            0,
                                            0,
                                            filename=plotname,
                                            rep='vae2')

        for model_type in [
                'vae', 'scanvi1', 'scanvi2', 'vae_nb', 'scanvi1_nb',
                'scanvi2_nb'
        ]:
            print(model_type)
            latent, batch_indices, labels, keys, stats = run_model(
                model_type,
                gene_dataset,
                dataset1,
                dataset2,
                filename=plotname,
                rep='0')

            res_jaccard = [
                KNNJaccardIndex(latent1, latent2, latent, batch_indices, k)[0]
                for k in KNeighbors
            ]
            res_jaccard_score = np.sum(res_jaccard * K_int)
            res_knn, res_knn_partial, res_kmeans, res_kmeans_partial = \
                eval_latent(batch_indices=batch_indices, labels=labels, latent=latent, keys=keys,
                            labelled_idx=labelled_idx, unlabelled_idx=unlabelled_idx,
                            plotname=plotname + '.' + model_type, plotting=False, partial_only=False)

            _, res_knn_partial1, _, res_kmeans_partial1 = \
                eval_latent(batch_indices=batch_indices, labels=labels, latent=latent, keys=keys,
                            labelled_idx=(batch_indices == 0), unlabelled_idx=(batch_indices == 1),
                            plotname=plotname + '.' + model_type, plotting=False)

            _, res_knn_partial2, _, res_kmeans_partial2 = \
                eval_latent(batch_indices=batch_indices, labels=labels, latent=latent, keys=keys,
                            labelled_idx=(batch_indices == 1), unlabelled_idx=(batch_indices == 0),
                            plotname=plotname + '.' + model_type, plotting=False)

            res = [res_knn[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                  [res_knn_partial[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                  [res_knn_partial1[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                  [res_knn_partial2[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                  [res_kmeans[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                  [res_kmeans_partial[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                  [res_kmeans_partial1[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                  [res_kmeans_partial2[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                  res_jaccard + \
                  [res_jaccard_score, stats[0], stats[1], stats[2]]

            f.write(model_type + (" %.4f" * 61 + "\n") % tuple(res))
            g.write("%s\t" % (model_type) + 'all\t' +
                    ("%.4f\t" * len(gene_dataset.cell_types) %
                     tuple(res_knn['clusteracc']) + "\n"))
            g.write("%s\t" % (model_type) + 'p\t' +
                    ("%.4f\t" * len(gene_dataset.cell_types) %
                     tuple(res_knn_partial['clusteracc']) + "\n"))
            g.write("%s\t" % (model_type) + 'p1\t' +
                    ("%.4f\t" * len(gene_dataset.cell_types) %
                     tuple(res_knn_partial1['clusteracc']) + "\n"))
            g.write("%s\t" % (model_type) + 'p2\t' +
                    ("%.4f\t" * len(gene_dataset.cell_types) %
                     tuple(res_knn_partial2['clusteracc']) + "\n"))
            # for i in [1, 2, 3]:
            #     latent, batch_indices, labels, keys, stats = run_model(model_type, gene_dataset, dataset1, dataset2,
            #                                                            filename=plotname, rep=str(i))
            #     res_jaccard, res_jaccard_score = KNNJaccardIndex(latent1, latent2, latent, batch_indices)
            #
            #     res_knn, res_knn_partial, res_kmeans, res_kmeans_partial = \
            #         eval_latent(batch_indices=batch_indices, labels=labels, latent=latent, keys=keys,
            #                     labelled_idx=labelled_idx, unlabelled_idx=unlabelled_idx,
            #                     plotname=plotname + '.' + model_type, plotting=False,partial_only=False)
            #
            #     _, res_knn_partial1, _, res_kmeans_partial1 = \
            #         eval_latent(batch_indices=batch_indices, labels=labels, latent=latent, keys=keys,
            #                     labelled_idx=(batch_indices == 0), unlabelled_idx=(batch_indices == 1),
            #                     plotname=plotname + '.' + model_type, plotting=False)
            #
            #     _, res_knn_partial2, _, res_kmeans_partial2 = \
            #         eval_latent(batch_indices=batch_indices, labels=labels, latent=latent, keys=keys,
            #                     labelled_idx=(batch_indices == 1), unlabelled_idx=(batch_indices == 0),
            #                     plotname=plotname + '.' + model_type, plotting=False)
            #
            #     res = [res_knn[x] for x in res_knn] + \
            #           [res_knn_partial[x] for x in res_knn_partial] + \
            #           [res_knn_partial1[x] for x in res_knn_partial1] + \
            #           [res_knn_partial2[x] for x in res_knn_partial2] + \
            #           [res_kmeans[x] for x in res_kmeans] + \
            #           [res_kmeans_partial[x] for x in res_kmeans_partial] + \
            #           [res_kmeans_partial1[x] for x in res_kmeans_partial1] + \
            #           [res_kmeans_partial2[x] for x in res_kmeans_partial2] + \
            #           res_jaccard + \
            #           [res_jaccard_score,stats[0], stats[1], stats[2]]
            #     f.write(model_type + (" %.4f" * 61 + "\n") % tuple(res))

    elif models == 'writedata':
        _, _, _, _, _ = run_model('writedata',
                                  gene_dataset,
                                  dataset1,
                                  dataset2,
                                  filename=plotname)
    f.close()
    g.close()
def trainSCANVI(gene_dataset,
                model_type,
                filename,
                rep,
                nlayers=2,
                reconstruction_loss: str = "zinb"):
    '''
    Train (or load from a cached checkpoint) a SCANVI model warm-started from a VAE.

    The VAE returned by ``trainVAE`` provides the initial weights. They are copied
    with ``strict=False``, so SCANVI parameters absent from the VAE keep their
    fresh initialisation.

    Parameters
    ----------
    gene_dataset : scVI dataset; must expose ``nb_genes``, ``n_batches``,
        ``n_labels`` and ``batch_indices``.
    model_type : str
        'scanvi1' -> cells of batch 0 are labelled, batch 1 is unlabelled
        (AlternateSemiSupervisedTrainer).
        'scanvi2' -> cells of batch 1 are labelled, batch 0 is unlabelled
        (AlternateSemiSupervisedTrainer).
        'scanvi0' -> labelled set is cells with batch index < 0, all others
        are unlabelled (SemiSupervisedTrainer, classification_ratio=0).
        Presumably the labelled set is empty, assuming batch indices are
        non-negative; TODO confirm.
        anything else -> SemiSupervisedTrainer with classification_ratio=10
        and the trainer's own labelled/unlabelled split.
    filename : str
        Passed to ``trainVAE``. It is also used as the directory name (relative
        to '../') in which the SCANVI checkpoint is cached.
    rep : replicate tag embedded in the checkpoint file name.
    nlayers : int
        Number of layers passed to SCANVI as ``n_layers``.
    reconstruction_loss : str
        Reconstruction loss passed to both ``trainVAE`` and SCANVI. It is also
        part of the checkpoint file name.

    Returns
    -------
    The SCANVI trainer. Its model is either loaded from
    '../<filename>/<model_type>.<reconstruction_loss>.rep<rep>.pkl' or trained
    for 5 epochs and then saved to that path.
    '''
    vae_posterior = trainVAE(gene_dataset,
                             filename,
                             rep,
                             reconstruction_loss=reconstruction_loss)
    # Keep the caller's `filename` intact; the checkpoint path gets its own name.
    checkpoint_path = '../' + filename + '/' + model_type + '.' + reconstruction_loss + '.rep' + str(
        rep) + '.pkl'
    scanvi = SCANVI(gene_dataset.nb_genes,
                    gene_dataset.n_batches,
                    gene_dataset.n_labels,
                    n_layers=nlayers,
                    reconstruction_loss=reconstruction_loss)
    # strict=False: the VAE has no classifier weights, so only shared layers are copied.
    scanvi.load_state_dict(vae_posterior.model.state_dict(), strict=False)

    if model_type in ('scanvi1', 'scanvi2'):
        # 'scanvi1' labels batch 0 and predicts batch 1; 'scanvi2' is the mirror image.
        labelled_batch = 0 if model_type == 'scanvi1' else 1
        unlabelled_batch = 1 - labelled_batch
        trainer_scanvi = AlternateSemiSupervisedTrainer(
            scanvi,
            gene_dataset,
            classification_ratio=0,
            n_epochs_classifier=100,
            lr_classification=5 * 1e-3)
        batch_indices = gene_dataset.batch_indices.ravel()
        labelled = np.where(batch_indices == labelled_batch)[0]
        # Random permutation of the labelled indices (kept after trainer
        # construction so the global RNG is consumed in the original order).
        labelled = np.random.choice(labelled, len(labelled), replace=False)
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=labelled)
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
            indices=(batch_indices == unlabelled_batch))
    elif model_type == 'scanvi0':
        trainer_scanvi = SemiSupervisedTrainer(scanvi,
                                               gene_dataset,
                                               classification_ratio=0,
                                               n_epochs_classifier=100,
                                               lr_classification=5 * 1e-3)
        batch_indices = gene_dataset.batch_indices.ravel()
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=(batch_indices < 0))
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
            indices=(batch_indices >= 0))
    else:
        trainer_scanvi = SemiSupervisedTrainer(scanvi,
                                               gene_dataset,
                                               classification_ratio=10,
                                               n_epochs_classifier=100,
                                               lr_classification=5 * 1e-3)

    if os.path.isfile(checkpoint_path):
        # Load onto CPU first. load_state_dict then copies the tensors onto the
        # model's own device, so GPU-saved checkpoints also work on CPU-only hosts.
        trainer_scanvi.model.load_state_dict(
            torch.load(checkpoint_path, map_location='cpu'))
        trainer_scanvi.model.eval()
    else:
        trainer_scanvi.train(n_epochs=5)
        torch.save(trainer_scanvi.model.state_dict(), checkpoint_path)
    return trainer_scanvi