Example #1
gene_dataset = GeneExpressionDataset.concat_datasets(dataset1, dataset2)
gene_dataset.subsample_genes(5000)

if model_type == 'vae':
    vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches, n_labels=gene_dataset.n_labels,
              n_hidden=128, n_latent=10, n_layers=2, dispersion='gene')
    infer = VariationalInference(vae, gene_dataset, use_cuda=use_cuda)
    infer.train(n_epochs=250)
    data_loader = infer.data_loaders['sequential']
    latent, batch_indices, labels = get_latent(vae, data_loader)
    keys = gene_dataset.cell_types
    batch_indices = np.concatenate(batch_indices)
elif model_type == 'svaec':
    svaec = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches,
                   gene_dataset.n_labels, use_labels_groups=False,
                   n_latent=10, n_layers=2)
    infer = SemiSupervisedVariationalInference(svaec, gene_dataset)
    infer.train(n_epochs=50)
    print('svaec acc =', infer.accuracy('unlabelled'))
    data_loader = infer.data_loaders['unlabelled']
    latent, batch_indices, labels = get_latent(infer.model, infer.data_loaders['unlabelled'])
    keys = gene_dataset.cell_types
    batch_indices = np.concatenate(batch_indices)
elif model_type == 'Seurat':
    # SEURAT = SEURAT()
    # seurat1 = SEURAT.create_seurat(dataset1, 1)
    # seurat2 = SEURAT.create_seurat(dataset2, 2)
    # latent, batch_indices,labels,keys = SEURAT.get_cca()
    latent = np.genfromtxt('../macosko_regev.CCA.txt')
    label = np.genfromtxt('../macosko_regev.CCA.label.txt', dtype='str')
latent, batch_indices, labels = full.sequential().get_latent()
# n_samples_tsne = 4000
# full.show_t_sne(n_samples=n_samples_tsne, color_by='batches and labels',
#                 save_name=os.path.join(save_path, "scVI_tSNE_batches_labels.png"))

print("Transferring labels from scVI")

true_labels = labels.ravel()
scVI_labels = transfer_nn_labels(latent, labels, batch_indices)

# train scANVI
print("Training scANVI")
# scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches, gene_dataset.n_labels, n_latent=10)
scanvi = SCANVI(gene_dataset.nb_genes,
                gene_dataset.n_batches,
                gene_dataset.n_labels,
                n_latent=10,
                reconstruction_loss='nb')
scanvi.load_state_dict(trainer.model.state_dict(), strict=False)
trainer_scanvi = AlternateSemiSupervisedTrainer(scanvi,
                                                gene_dataset,
                                                n_epochs_classifier=5,
                                                lr_classification=5 * 1e-3)
labelled = np.where(gene_dataset.batch_indices == 0)[0]
# np.random.shuffle(labelled)
unlabelled = np.where(gene_dataset.batch_indices == 1)[0]
# np.random.shuffle(unlabelled)
trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=labelled)
trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
    indices=unlabelled)
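Example #1 is stitched together from fragments of a longer script: the `if model_type == ...` block uses the legacy `VariationalInference` API, while the lines from `full.sequential().get_latent()` onward rely on a `trainer`/`full` pair defined outside the excerpt. A minimal preamble that the first block appears to rely on is sketched below; the module paths follow older scVI releases and the flag values are placeholders, so treat the whole block as an assumption rather than the script's actual header.

# Hedged sketch of the imports and flags the first block of Example #1 appears
# to assume; module paths follow older scVI releases and may differ in your install.
import numpy as np
from scvi.dataset import GeneExpressionDataset        # provides concat_datasets()
from scvi.models import VAE, SCANVI                   # unsupervised / semi-supervised models
from scvi.inference import (VariationalInference,     # legacy inference wrappers
                            SemiSupervisedVariationalInference)

# The later fragment (trainer, full, AlternateSemiSupervisedTrainer) uses the
# newer Trainer API; get_latent() and transfer_nn_labels() are assumed to be
# scvi helpers / the authors' local utilities and are not imported here.
use_cuda = True        # placeholder: set to False on CPU-only machines
model_type = 'vae'     # 'vae', 'svaec', or 'Seurat' in this example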
Example #3
                                                   rep='')
acc, cell_type = KNNpurity(latent1, latent2, latent,
                           batch_indices.ravel(), labels, keys)
f.write('vae' + '\t' + rmCellTypes +
        ("\t%.4f" * 8 + "\t%s" * 8 + "\n") %
        tuple(list(acc) + list(cell_type)))
be, cell_type2 = BEbyType(keys, latent, labels, batch_indices,
                          celltype1)
g.write('vae' + '\t' + rmCellTypes +
        ("\t%.4f" * 8 + "\t%s" * 8 + "\n") %
        tuple(be + list(cell_type2)))
plotUMAP(latent, plotname, 'vae', gene_dataset.cell_types,
         rmCellTypes, gene_dataset.batch_indices.ravel())
scanvi = SCANVI(gene_dataset.nb_genes,
                2, (gene_dataset.n_labels + 1),
                n_hidden=128,
                n_latent=10,
                n_layers=2,
                dispersion='gene')
scanvi.load_state_dict(trainer.model.state_dict(), strict=False)
trainer_scanvi = AlternateSemiSupervisedTrainer(
    scanvi,
    gene_dataset,
    n_epochs_classifier=10,
    lr_classification=5 * 1e-3)
trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
    indices=gene_dataset.batch_indices.ravel() == 0)
trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
    indices=gene_dataset.batch_indices.ravel() == 1)
trainer_scanvi.train(n_epochs=10)
scanvi_full = trainer_scanvi.create_posterior(
    trainer_scanvi.model,
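Example #3 breaks off in the middle of the `scanvi_full = trainer_scanvi.create_posterior(` call. Purely as an assumption, mirroring the full-dataset posterior pattern used in Example #4 below, the continuation would plausibly look like this:

# Hedged sketch only: a plausible completion of the truncated call above,
# following the create_posterior(...).sequential().get_latent() pattern from Example #4.
scanvi_full = trainer_scanvi.create_posterior(
    trainer_scanvi.model,
    gene_dataset,
    indices=np.arange(len(gene_dataset)))
latent, batch_indices, labels = scanvi_full.sequential().get_latent()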
Example #4
        data_loader = infer_vae.data_loaders['sequential']
        latent, batch_indices, labels = get_latent(vae, data_loader)
        keys = gene_dataset.cell_types
        batch_indices = np.concatenate(batch_indices)
    elif model_type == 'svaec':
        gene_dataset.subsample_genes(1000)

        n_epochs_vae = 100
        n_epochs_scanvi = 50
        vae = VAE(gene_dataset.nb_genes, gene_dataset.n_batches, gene_dataset.n_labels, n_latent=10, n_layers=2)
        trainer = UnsupervisedTrainer(vae, gene_dataset, train_size=1.0)
        trainer.train(n_epochs=n_epochs_vae)

        for i in [0, 1]:
            scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches, gene_dataset.n_labels, n_layers=2)
            scanvi.load_state_dict(vae.state_dict(), strict=False)
            trainer_scanvi = SemiSupervisedTrainer(scanvi, gene_dataset, classification_ratio=1,
                                                   n_epochs_classifier=1, lr_classification=5 * 1e-3)

            trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=(gene_dataset.batch_indices == i))
            trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
                indices=(gene_dataset.batch_indices == 1 - i)
            )

            trainer_scanvi.model.eval()
            print('NN: ', trainer_scanvi.nn_latentspace())
            trainer_scanvi.unlabelled_set.to_monitor = ['accuracy']
            trainer_scanvi.labelled_set.to_monitor = ['accuracy']
            trainer_scanvi.full_dataset.to_monitor = ['entropy_batch_mixing']
            trainer_scanvi.train(n_epochs=n_epochs_scanvi)
trainer.model.load_state_dict(
    torch.load(save_path + 'DE.vae.%i.mis%.2f.pkl' % (rep, misprop)))
trainer.model.eval()

full = trainer.create_posterior(trainer.model,
                                gene_dataset,
                                indices=np.arange(len(gene_dataset)))
latent, batch_indices, _ = full.sequential().get_latent()

print("Transferring labels from scVI")
scVI_labels = transfer_nn_labels(latent, mislabels, batch_indices)

# train scANVI
print("Training scANVI")
scanvi = SCANVI(gene_dataset.nb_genes,
                gene_dataset.n_batches,
                gene_dataset.n_labels,
                n_latent=10)
scanvi.load_state_dict(trainer.model.state_dict(), strict=False)
trainer_scanvi = SemiSupervisedTrainer(scanvi,
                                       gene_dataset,
                                       classification_ratio=50,
                                       n_epochs_classifier=1,
                                       lr_classification=5 * 1e-3)
# trainer_scanvi = AlternateSemiSupervisedTrainer(scanvi, gene_dataset,
#                                                 n_epochs_classifier=5, lr_classification=5 * 1e-3, kl=1)
labelled = np.where(gene_dataset.batch_indices == 0)[0]
np.random.shuffle(labelled)
unlabelled = np.where(gene_dataset.batch_indices == 1)[0]
np.random.shuffle(unlabelled)
trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=labelled)
trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
    indices=unlabelled)

def CompareModels(gene_dataset, dataset1, dataset2, plotname, models):
    KNeighbors = np.concatenate(
        [np.arange(10, 100, 10),
         np.arange(100, 500, 50)])
    K_int = np.concatenate([np.repeat(10, 10), np.repeat(50, 7)])
    f = open('../' + plotname + '/' + models + '.res.txt', "w+")
    f.write("model_type " + \
            "knn_asw knn_nmi knn_ari knn_uca knn_wuca " + \
            "p_knn_asw p_knn_nmi p_knn_ari p_knn_uca p_knn_wuca " + \
            "p1_knn_asw p1_knn_nmi p1_knn_ari p1_knn_uca p1_knn_wuca " + \
            "p2_knn_asw p2_knn_nmi p2_knn_ari p2_knn_uca p2_knn_wuca " + \
            "kmeans_asw kmeans_nmi kmeans_ari kmeans_uca kmeans_wuca " + \
            "p_kmeans_asw p_kmeans_nmi p_kmeans_ari p_kmeans_uca p_kmeans_wuca " + \
            "p1_kmeans_asw p1_kmeans_nmi p1_kmeans_ari p1_kmeans_uca p1_kmeans_wuca " + \
            "p2_kmeans_asw p2_kmeans_nmi p2_kmeans_ari p2_kmeans_uca p2_kmeans_wuca " + \
            " ".join(['res_jaccard' + x for x in
                      np.concatenate([np.repeat(10, 10), np.repeat(50, 7)]).astype('str')]) + " " + \
            'jaccard_score likelihood BE classifier_acc\n'
            )
    g = open('../' + plotname + '/' + models + '.percluster.res.txt', "w+")
    g.write("model_type\tannotation\t" + "\t".join(gene_dataset.cell_types) +
            "\n")

    scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches,
                    gene_dataset.n_labels)
    trainer_scanvi = SemiSupervisedTrainer(scanvi,
                                           gene_dataset,
                                           classification_ratio=1,
                                           n_epochs_classifier=1,
                                           lr_classification=5 * 1e-3)
    labelled_idx = trainer_scanvi.labelled_set.indices
    unlabelled_idx = trainer_scanvi.unlabelled_set.indices

    if models == 'others':
        latent1 = np.genfromtxt('../harmonization/Seurat_data/' + plotname +
                                '.1.CCA.txt')
        latent2 = np.genfromtxt('../harmonization/Seurat_data/' + plotname +
                                '.2.CCA.txt')
        for model_type in [
                'scmap', 'readSeurat', 'coral', 'Combat', 'MNN', 'PCA'
        ]:
            print(model_type)
            if (model_type == 'scmap') or (model_type == 'coral'):
                latent, batch_indices, labels, keys, stats = run_model(
                    model_type,
                    gene_dataset,
                    dataset1,
                    dataset2,
                    filename=plotname)
                pred1 = latent
                pred2 = stats
                res1 = scmap_eval(pred1, labels[batch_indices == 1], labels)
                res2 = scmap_eval(pred2, labels[batch_indices == 0], labels)
                g.write("%s\t" % (model_type) + "p1\t" +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res1['clusteracc']) + "\n"))
                g.write("%s\t" % (model_type) + "p2\t" +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res2['clusteracc']) + "\n"))
                res = [-1] * 10 + \
                      [-1] + [res1[x] for x in ['nmi', 'ari', 'ca', 'weighted ca']] + \
                      [-1] + [res2[x] for x in ['nmi', 'ari', 'ca', 'weighted ca']] + \
                      [-1] * 41

                f.write(model_type + (" %.4f" * 61 + "\n") % tuple(res))
            else:
                if model_type == 'readSeurat':
                    dataset1, dataset2, gene_dataset = SubsetGenes(
                        dataset1, dataset2, gene_dataset, plotname)

                latent, batch_indices, labels, keys, stats = run_model(
                    model_type,
                    gene_dataset,
                    dataset1,
                    dataset2,
                    filename=plotname)

                res_jaccard = [
                    KNNJaccardIndex(latent1, latent2, latent, batch_indices,
                                    k)[0] for k in KNeighbors
                ]
                res_jaccard_score = np.sum(res_jaccard * K_int)
                res_knn, res_knn_partial, res_kmeans, res_kmeans_partial = \
                    eval_latent(batch_indices, labels, latent, keys,
                                labelled_idx, unlabelled_idx,
                                plotname=plotname + '.' + model_type, plotting=False, partial_only=False)

                _, res_knn_partial1, _, res_kmeans_partial1 = \
                    eval_latent(batch_indices, labels, latent, keys,
                                batch_indices == 0, batch_indices == 1,
                                plotname=plotname + '.' + model_type, plotting=False)

                _, res_knn_partial2, _, res_kmeans_partial2 = \
                    eval_latent(batch_indices, labels, latent, keys,
                                batch_indices == 1, batch_indices == 0,
                                plotname=plotname + '.' + model_type, plotting=False)

                sample = select_indices_evenly(
                    np.min(np.unique(batch_indices, return_counts=True)[1]),
                    batch_indices)
                batch_entropy = entropy_batch_mixing(latent[sample, :],
                                                     batch_indices[sample])

                res = [res_knn[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                      [res_knn_partial[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                      [res_knn_partial1[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                      [res_knn_partial2[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                      [res_kmeans[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                      [res_kmeans_partial[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                      [res_kmeans_partial1[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                      [res_kmeans_partial2[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                      res_jaccard + \
                      [res_jaccard_score, -1, batch_entropy, -1]

                f.write(model_type + (" %.4f" * 61 + "\n") % tuple(res))
                g.write("%s\t" % (model_type) + 'all\t' +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res_knn['clusteracc']) + "\n"))
                g.write("%s\t" % (model_type) + 'p\t' +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res_knn_partial['clusteracc']) + "\n"))
                g.write("%s\t" % (model_type) + 'p1\t' +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res_knn_partial1['clusteracc']) + "\n"))
                g.write("%s\t" % (model_type) + 'p2\t' +
                        ("%.4f\t" * len(gene_dataset.cell_types) %
                         tuple(res_knn_partial2['clusteracc']) + "\n"))

    elif (models == 'scvi') or (models == 'scvi_nb'):
        dataset1, dataset2, gene_dataset = SubsetGenes(dataset1, dataset2,
                                                       gene_dataset, plotname)
        if models == 'scvi_nb':
            latent1, _, _, _, _ = run_model('vae_nb',
                                            dataset1,
                                            0,
                                            0,
                                            filename=plotname,
                                            rep='vae1_nb')
            latent2, _, _, _, _ = run_model('vae_nb',
                                            dataset2,
                                            0,
                                            0,
                                            filename=plotname,
                                            rep='vae2_nb')
        else:
            latent1, _, _, _, _ = run_model('vae',
                                            dataset1,
                                            0,
                                            0,
                                            filename=plotname,
                                            rep='vae1')
            latent2, _, _, _, _ = run_model('vae',
                                            dataset2,
                                            0,
                                            0,
                                            filename=plotname,
                                            rep='vae2')

        for model_type in [
                'vae', 'scanvi1', 'scanvi2', 'vae_nb', 'scanvi1_nb',
                'scanvi2_nb'
        ]:
            print(model_type)
            latent, batch_indices, labels, keys, stats = run_model(
                model_type,
                gene_dataset,
                dataset1,
                dataset2,
                filename=plotname,
                rep='0')

            res_jaccard = [
                KNNJaccardIndex(latent1, latent2, latent, batch_indices, k)[0]
                for k in KNeighbors
            ]
            res_jaccard_score = np.sum(res_jaccard * K_int)
            res_knn, res_knn_partial, res_kmeans, res_kmeans_partial = \
                eval_latent(batch_indices=batch_indices, labels=labels, latent=latent, keys=keys,
                            labelled_idx=labelled_idx, unlabelled_idx=unlabelled_idx,
                            plotname=plotname + '.' + model_type, plotting=False, partial_only=False)

            _, res_knn_partial1, _, res_kmeans_partial1 = \
                eval_latent(batch_indices=batch_indices, labels=labels, latent=latent, keys=keys,
                            labelled_idx=(batch_indices == 0), unlabelled_idx=(batch_indices == 1),
                            plotname=plotname + '.' + model_type, plotting=False)

            _, res_knn_partial2, _, res_kmeans_partial2 = \
                eval_latent(batch_indices=batch_indices, labels=labels, latent=latent, keys=keys,
                            labelled_idx=(batch_indices == 1), unlabelled_idx=(batch_indices == 0),
                            plotname=plotname + '.' + model_type, plotting=False)

            res = [res_knn[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                  [res_knn_partial[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                  [res_knn_partial1[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                  [res_knn_partial2[x] for x in ['asw', 'nmi', 'ari', 'ca', 'weighted ca']] + \
                  [res_kmeans[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                  [res_kmeans_partial[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                  [res_kmeans_partial1[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                  [res_kmeans_partial2[x] for x in ['asw', 'nmi', 'ari', 'uca', 'weighted uca']] + \
                  res_jaccard + \
                  [res_jaccard_score, stats[0], stats[1], stats[2]]

            f.write(model_type + (" %.4f" * 61 + "\n") % tuple(res))
            g.write("%s\t" % (model_type) + 'all\t' +
                    ("%.4f\t" * len(gene_dataset.cell_types) %
                     tuple(res_knn['clusteracc']) + "\n"))
            g.write("%s\t" % (model_type) + 'p\t' +
                    ("%.4f\t" * len(gene_dataset.cell_types) %
                     tuple(res_knn_partial['clusteracc']) + "\n"))
            g.write("%s\t" % (model_type) + 'p1\t' +
                    ("%.4f\t" * len(gene_dataset.cell_types) %
                     tuple(res_knn_partial1['clusteracc']) + "\n"))
            g.write("%s\t" % (model_type) + 'p2\t' +
                    ("%.4f\t" * len(gene_dataset.cell_types) %
                     tuple(res_knn_partial2['clusteracc']) + "\n"))
            # for i in [1, 2, 3]:
            #     latent, batch_indices, labels, keys, stats = run_model(model_type, gene_dataset, dataset1, dataset2,
            #                                                            filename=plotname, rep=str(i))
            #     res_jaccard, res_jaccard_score = KNNJaccardIndex(latent1, latent2, latent, batch_indices)
            #
            #     res_knn, res_knn_partial, res_kmeans, res_kmeans_partial = \
            #         eval_latent(batch_indices=batch_indices, labels=labels, latent=latent, keys=keys,
            #                     labelled_idx=labelled_idx, unlabelled_idx=unlabelled_idx,
            #                     plotname=plotname + '.' + model_type, plotting=False,partial_only=False)
            #
            #     _, res_knn_partial1, _, res_kmeans_partial1 = \
            #         eval_latent(batch_indices=batch_indices, labels=labels, latent=latent, keys=keys,
            #                     labelled_idx=(batch_indices == 0), unlabelled_idx=(batch_indices == 1),
            #                     plotname=plotname + '.' + model_type, plotting=False)
            #
            #     _, res_knn_partial2, _, res_kmeans_partial2 = \
            #         eval_latent(batch_indices=batch_indices, labels=labels, latent=latent, keys=keys,
            #                     labelled_idx=(batch_indices == 1), unlabelled_idx=(batch_indices == 0),
            #                     plotname=plotname + '.' + model_type, plotting=False)
            #
            #     res = [res_knn[x] for x in res_knn] + \
            #           [res_knn_partial[x] for x in res_knn_partial] + \
            #           [res_knn_partial1[x] for x in res_knn_partial1] + \
            #           [res_knn_partial2[x] for x in res_knn_partial2] + \
            #           [res_kmeans[x] for x in res_kmeans] + \
            #           [res_kmeans_partial[x] for x in res_kmeans_partial] + \
            #           [res_kmeans_partial1[x] for x in res_kmeans_partial1] + \
            #           [res_kmeans_partial2[x] for x in res_kmeans_partial2] + \
            #           res_jaccard + \
            #           [res_jaccard_score,stats[0], stats[1], stats[2]]
            #     f.write(model_type + (" %.4f" * 61 + "\n") % tuple(res))

    elif models == 'writedata':
        _, _, _, _, _ = run_model('writedata',
                                  gene_dataset,
                                  dataset1,
                                  dataset2,
                                  filename=plotname)
    f.close()
    g.close()
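CompareModels takes the merged gene_dataset, the two source datasets, a plotname prefix that doubles as the output directory under `../`, and a models switch ('others', 'scvi', 'scvi_nb', or 'writedata'). A hypothetical invocation is sketched below; the dataset objects and the 'MarrowTM' name are illustrative assumptions, not values from the original script.

# Hypothetical usage sketch for CompareModels; dataset1/dataset2 and the
# 'MarrowTM' plotname/output directory are assumptions for illustration.
import os
gene_dataset = GeneExpressionDataset.concat_datasets(dataset1, dataset2)
os.makedirs('../MarrowTM', exist_ok=True)   # results go to ../<plotname>/<models>.res.txt
CompareModels(gene_dataset, dataset1, dataset2, plotname='MarrowTM', models='scvi')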
def trainSCANVI(gene_dataset,
                model_type,
                filename,
                rep,
                nlayers=2,
                reconstruction_loss: str = "zinb"):
    vae_posterior = trainVAE(gene_dataset,
                             filename,
                             rep,
                             reconstruction_loss=reconstruction_loss)
    filename = '../' + filename + '/' + model_type + '.' + reconstruction_loss + '.rep' + str(
        rep) + '.pkl'
    scanvi = SCANVI(gene_dataset.nb_genes,
                    gene_dataset.n_batches,
                    gene_dataset.n_labels,
                    n_layers=nlayers,
                    reconstruction_loss=reconstruction_loss)
    scanvi.load_state_dict(vae_posterior.model.state_dict(), strict=False)
    if model_type == 'scanvi1':
        trainer_scanvi = AlternateSemiSupervisedTrainer(
            scanvi,
            gene_dataset,
            classification_ratio=0,
            n_epochs_classifier=100,
            lr_classification=5 * 1e-3)
        labelled = np.where(gene_dataset.batch_indices.ravel() == 0)[0]
        labelled = np.random.choice(labelled, len(labelled), replace=False)
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=labelled)
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
            indices=(gene_dataset.batch_indices.ravel() == 1))
    elif model_type == 'scanvi2':
        trainer_scanvi = AlternateSemiSupervisedTrainer(
            scanvi,
            gene_dataset,
            classification_ratio=0,
            n_epochs_classifier=100,
            lr_classification=5 * 1e-3)
        labelled = np.where(gene_dataset.batch_indices.ravel() == 1)[0]
        labelled = np.random.choice(labelled, len(labelled), replace=False)
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=labelled)
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
            indices=(gene_dataset.batch_indices.ravel() == 0))
    elif model_type == 'scanvi0':
        trainer_scanvi = SemiSupervisedTrainer(scanvi,
                                               gene_dataset,
                                               classification_ratio=0,
                                               n_epochs_classifier=100,
                                               lr_classification=5 * 1e-3)
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=(gene_dataset.batch_indices.ravel() < 0))
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
            indices=(gene_dataset.batch_indices.ravel() >= 0))
    else:
        trainer_scanvi = SemiSupervisedTrainer(scanvi,
                                               gene_dataset,
                                               classification_ratio=10,
                                               n_epochs_classifier=100,
                                               lr_classification=5 * 1e-3)

    if os.path.isfile(filename):
        trainer_scanvi.model.load_state_dict(torch.load(filename))
        trainer_scanvi.model.eval()
    else:
        trainer_scanvi.train(n_epochs=5)
        torch.save(trainer_scanvi.model.state_dict(), filename)
    return trainer_scanvi
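trainSCANVI either loads a cached model from `../<filename>/<model_type>.<reconstruction_loss>.rep<rep>.pkl` or trains for five epochs and saves it, then returns the trainer. A typical follow-up, assumed here rather than taken from the original code, is to build a posterior over the whole dataset and pull out the latent space:

# Hedged usage sketch for trainSCANVI; the 'MarrowTM' filename and rep index
# are illustrative assumptions.
import numpy as np
trainer_scanvi = trainSCANVI(gene_dataset, 'scanvi1', filename='MarrowTM', rep=0)
full = trainer_scanvi.create_posterior(trainer_scanvi.model,
                                       gene_dataset,
                                       indices=np.arange(len(gene_dataset)))
latent, batch_indices, labels = full.sequential().get_latent()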