Exemple #1
0
def test_cortex(save_path):
    """Smoke-test the VAE, SCANVI and classifier pipelines on ``CortexDataset``.

    Every model is trained for a single epoch and its evaluation helpers are
    invoked once; the return values are discarded, so the test only checks
    that nothing raises.

    Args:
        save_path: Location forwarded to ``CortexDataset`` and to
            ``imputation_benchmark``.
    """
    dataset = CortexDataset(save_path=save_path)

    # Unsupervised VAE: train, then evaluate on the training posterior.
    # NOTE: ``train_set`` is re-read from the trainer on every use rather than
    # cached, so the posterior currently registered on the trainer is used.
    vae_model = VAE(dataset.nb_genes, dataset.n_batches)
    vae_trainer = UnsupervisedTrainer(
        vae_model,
        dataset,
        train_size=0.5,
        use_cuda=use_cuda,
    )
    vae_trainer.train(n_epochs=1)
    vae_trainer.train_set.ll()
    vae_trainer.train_set.differential_expression_stats()

    # Corrupt the posteriors (explicit 'binomial', then the default), train
    # again on the corrupted data, and undo the corruption.
    vae_trainer.corrupt_posteriors(corruption='binomial')
    vae_trainer.corrupt_posteriors()
    vae_trainer.train(n_epochs=1)
    vae_trainer.uncorrupt_posteriors()

    vae_trainer.train_set.imputation_benchmark(
        n_samples=1,
        show_plot=False,
        title_plot='imputation',
        save_path=save_path,
    )

    # SCANVI with the joint semi-supervised trainer.
    scanvi_model = SCANVI(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    scanvi_trainer = JointSemiSupervisedTrainer(
        scanvi_model,
        dataset,
        n_labelled_samples_per_class=3,
        use_cuda=use_cuda,
    )
    scanvi_trainer.train(n_epochs=1)
    scanvi_trainer.labelled_set.accuracy()
    scanvi_trainer.full_dataset.ll()

    # A fresh SCANVI with the alternating semi-supervised trainer.
    scanvi_model = SCANVI(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    scanvi_trainer = AlternateSemiSupervisedTrainer(
        scanvi_model,
        dataset,
        n_labelled_samples_per_class=3,
        use_cuda=use_cuda,
    )
    scanvi_trainer.train(n_epochs=1, lr=1e-2)
    scanvi_trainer.unlabelled_set.accuracy()

    # Baseline classifiers fitted on the labelled cells and scored on the
    # unlabelled ones, each with a single-point parameter grid.
    train_x, train_y = scanvi_trainer.labelled_set.raw_data()
    test_x, test_y = scanvi_trainer.unlabelled_set.raw_data()
    split = (train_x, train_y, test_x, test_y)
    compute_accuracy_svc(*split, param_grid=[{'C': [1], 'kernel': ['linear']}])
    compute_accuracy_rf(*split, param_grid=[{'max_depth': [3], 'n_estimators': [10]}])

    # Stand-alone classifier on the same dataset.
    classifier = Classifier(dataset.nb_genes, n_labels=dataset.n_labels)
    classifier_trainer = ClassifierTrainer(classifier, dataset)
    classifier_trainer.train(n_epochs=1)
    classifier_trainer.train_set.accuracy()
# Script fragment: transfer labels using the scVI latent space, then set up a
# scANVI model and its trainer.
# NOTE(review): `labels`, `latent`, `batch_indices`, `gene_dataset` and
# `trainer` are not defined in this excerpt -- presumably they come from the
# earlier scVI training part of the script; confirm before reuse.
print("Transferring labels from scVI")

# Flattened labels alongside the labels produced by `transfer_nn_labels`.
true_labels = labels.ravel()
scVI_labels = transfer_nn_labels(latent, labels, batch_indices)

# train scANVI
print("Training scANVI")
# scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches, gene_dataset.n_labels, n_latent=10)
scanvi = SCANVI(gene_dataset.nb_genes,
                gene_dataset.n_batches,
                gene_dataset.n_labels,
                n_latent=10,
                reconstruction_loss='nb')
# Start from the weights of the already-built `trainer.model`; strict=False
# lets the load proceed even though the two models' parameter sets differ.
scanvi.load_state_dict(trainer.model.state_dict(), strict=False)
trainer_scanvi = AlternateSemiSupervisedTrainer(scanvi,
                                                gene_dataset,
                                                n_epochs_classifier=5,
                                                lr_classification=5 * 1e-3)
# Cells whose batch index is 0 form the labelled set, those with batch index 1
# the unlabelled set. `np.where(...)[0]` keeps the first-axis (row) indices.
labelled = np.where(gene_dataset.batch_indices == 0)[0]
# np.random.shuffle(labelled)
unlabelled = np.where(gene_dataset.batch_indices == 1)[0]
# np.random.shuffle(unlabelled)
# Replace the trainer's labelled/unlabelled posteriors with the batch-based split.
trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=labelled)
trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
    indices=unlabelled)

# NOTE(review): the checkpoint-loading logic below is disabled, and its `else`
# branch is not visible in this excerpt.
# file_name = '%s/scanvi.pkl' % save_path
# if os.path.isfile(file_name):
#     print("loaded model from: " + file_name)
#     trainer_scanvi.model.load_state_dict(torch.load(file_name))
#     trainer_scanvi.model.eval()
# else:
Exemple #3
0
def test_cortex(save_path):
    """Exercise the VAE, SCANVI and classifier pipelines on ``CortexDataset``.

    Each model is trained for one epoch. Most evaluation helpers are only
    called for their side effects; the array shapes returned by
    ``generate_parameters`` and ``generate`` are asserted explicitly.

    Args:
        save_path: Location forwarded to ``CortexDataset`` and to
            ``imputation_benchmark``.
    """
    dataset = CortexDataset(save_path=save_path)

    # Unsupervised VAE: one epoch of training, then the evaluation helpers.
    # NOTE: ``train_set`` / ``test_set`` are re-read from the trainer on every
    # use rather than cached, so the currently registered posterior is used.
    vae_model = VAE(dataset.nb_genes, dataset.n_batches)
    vae_trainer = UnsupervisedTrainer(
        vae_model,
        dataset,
        train_size=0.5,
        use_cuda=use_cuda,
    )
    vae_trainer.train(n_epochs=1)
    vae_trainer.train_set.reconstruction_error()
    vae_trainer.train_set.differential_expression_stats()
    vae_trainer.train_set.generate_feature_correlation_matrix(
        n_samples=2,
        correlation_type="pearson",
    )
    vae_trainer.train_set.generate_feature_correlation_matrix(
        n_samples=2,
        correlation_type="spearman",
    )
    vae_trainer.train_set.imputation(n_samples=1)
    vae_trainer.test_set.imputation(n_samples=5)

    # Corrupt the posteriors (explicit "binomial", then the default), train
    # again on the corrupted data, and undo the corruption.
    vae_trainer.corrupt_posteriors(corruption="binomial")
    vae_trainer.corrupt_posteriors()
    vae_trainer.train(n_epochs=1)
    vae_trainer.uncorrupt_posteriors()

    vae_trainer.train_set.imputation_benchmark(
        n_samples=1,
        show_plot=False,
        title_plot="imputation",
        save_path=save_path,
    )
    vae_trainer.train_set.generate_parameters()

    # Shape checks for generate_parameters under its three calling modes.
    n_cells = len(vae_trainer.train_set.indices)
    n_genes = dataset.nb_genes
    n_samples = 3
    flat_shape = (n_cells, n_genes)
    sampled_shape = (n_samples, n_cells, n_genes)

    dropout, means, dispersions = vae_trainer.train_set.generate_parameters()
    assert dropout.shape == flat_shape
    assert means.shape == flat_shape
    assert dispersions.shape == flat_shape

    dropout, means, dispersions = vae_trainer.train_set.generate_parameters(
        n_samples=n_samples,
    )
    assert dropout.shape == sampled_shape
    assert means.shape == sampled_shape

    dropout, means, dispersions = vae_trainer.train_set.generate_parameters(
        n_samples=n_samples,
        give_mean=True,
    )
    assert dropout.shape == flat_shape
    assert means.shape == flat_shape

    # Posterior over every cell of the dataset: generated samples are stacked
    # on a trailing axis, next to the original counts.
    every_cell = np.arange(len(dataset))
    full_posterior = vae_trainer.create_posterior(
        vae_model,
        dataset,
        indices=every_cell,
    )
    x_new, x_old = full_posterior.generate(n_samples=10)
    assert x_new.shape == (dataset.nb_cells, dataset.nb_genes, 10)
    assert x_old.shape == (dataset.nb_cells, dataset.nb_genes)

    vae_trainer.train_set.imputation_benchmark(
        n_samples=1,
        show_plot=False,
        title_plot="imputation",
        save_path=save_path,
    )

    # SCANVI with the joint semi-supervised trainer.
    scanvi_model = SCANVI(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    scanvi_trainer = JointSemiSupervisedTrainer(
        scanvi_model,
        dataset,
        n_labelled_samples_per_class=3,
        use_cuda=use_cuda,
    )
    scanvi_trainer.train(n_epochs=1)
    scanvi_trainer.labelled_set.accuracy()
    scanvi_trainer.full_dataset.reconstruction_error()

    # A fresh SCANVI with the alternating semi-supervised trainer.
    scanvi_model = SCANVI(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    scanvi_trainer = AlternateSemiSupervisedTrainer(
        scanvi_model,
        dataset,
        n_labelled_samples_per_class=3,
        use_cuda=use_cuda,
    )
    scanvi_trainer.train(n_epochs=1, lr=1e-2)
    scanvi_trainer.unlabelled_set.accuracy()

    # Baseline classifiers fitted on the labelled cells and scored on the
    # unlabelled ones, each with a single-point parameter grid.
    train_x, train_y = scanvi_trainer.labelled_set.raw_data()
    test_x, test_y = scanvi_trainer.unlabelled_set.raw_data()
    split = (train_x, train_y, test_x, test_y)
    compute_accuracy_svc(*split, param_grid=[{"C": [1], "kernel": ["linear"]}])
    compute_accuracy_rf(
        *split,
        param_grid=[{"max_depth": [3], "n_estimators": [10]}],
    )

    # Stand-alone classifier on the same dataset.
    classifier = Classifier(dataset.nb_genes, n_labels=dataset.n_labels)
    classifier_trainer = ClassifierTrainer(classifier, dataset)
    classifier_trainer.train(n_epochs=1)
    classifier_trainer.train_set.accuracy()
Exemple #4
0
                           celltype1)
 g.write('vae' + '\t' + rmCellTypes +
         ("\t%.4f" * 8 + "\t%s" * 8 + "\n") %
         tuple(be + list(cell_type2)))
 plotUMAP(latent, plotname, 'vae', gene_dataset.cell_types,
          rmCellTypes, gene_dataset.batch_indices.ravel())
 scanvi = SCANVI(gene_dataset.nb_genes,
                 2, (gene_dataset.n_labels + 1),
                 n_hidden=128,
                 n_latent=10,
                 n_layers=2,
                 dispersion='gene')
 scanvi.load_state_dict(trainer.model.state_dict(), strict=False)
 trainer_scanvi = AlternateSemiSupervisedTrainer(
     scanvi,
     gene_dataset,
     n_epochs_classifier=10,
     lr_classification=5 * 1e-3)
 trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
     indices=gene_dataset.batch_indices.ravel() == 0)
 trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
     indices=gene_dataset.batch_indices.ravel() == 1)
 trainer_scanvi.train(n_epochs=10)
 scanvi_full = trainer_scanvi.create_posterior(
     trainer_scanvi.model,
     gene_dataset,
     indices=np.arange(len(gene_dataset)))
 latent, _, _ = scanvi_full.sequential().get_latent()
 acc, cell_type = KNNpurity(latent1, latent2, latent,
                            batch_indices.ravel(), labels, keys)
 f.write('scanvi' + '\t' + rmCellTypes +
def trainSCANVI(gene_dataset,
                model_type,
                filename,
                rep,
                nlayers=2,
                reconstruction_loss: str = "zinb"):
    """Build a SCANVI trainer, then load cached weights or train and cache them.

    A VAE is obtained through ``trainVAE`` first and its weights are copied
    into the new SCANVI model (non-strict load). The trainer class and the
    labelled/unlabelled split depend on ``model_type``:

    * ``'scanvi1'``: alternating trainer; cells with batch index 0 are
      labelled (in shuffled order), cells with batch index 1 are unlabelled.
    * ``'scanvi2'``: as above with the two batches swapped.
    * ``'scanvi0'``: ``SemiSupervisedTrainer``; cells with a negative batch
      index are labelled and all others unlabelled (presumably leaving the
      labelled set empty -- verify against the dataset).
    * anything else: ``SemiSupervisedTrainer`` with ``classification_ratio=10``
      and the trainer's own default split.

    Args:
        gene_dataset: Dataset exposing ``nb_genes``, ``n_batches``,
            ``n_labels`` and ``batch_indices``.
        model_type: Variant name, also embedded in the checkpoint file name.
        filename: Directory name (relative to ``..``) holding the checkpoint;
            also forwarded to ``trainVAE``.
        rep: Replicate identifier embedded in the checkpoint file name.
        nlayers: Passed to ``SCANVI`` as ``n_layers``.
        reconstruction_loss: Passed to ``trainVAE`` and ``SCANVI``, and
            embedded in the checkpoint file name.

    Returns:
        The configured trainer, whose model was either restored from the
        checkpoint (and put in eval mode) or trained for 5 epochs and saved.
    """
    vae_posterior = trainVAE(
        gene_dataset, filename, rep, reconstruction_loss=reconstruction_loss)

    # e.g. '../<filename>/<model_type>.<reconstruction_loss>.rep<rep>.pkl'
    checkpoint_path = ''.join([
        '../', filename, '/', model_type, '.', reconstruction_loss,
        '.rep', str(rep), '.pkl',
    ])

    scanvi = SCANVI(gene_dataset.nb_genes,
                    gene_dataset.n_batches,
                    gene_dataset.n_labels,
                    n_layers=nlayers,
                    reconstruction_loss=reconstruction_loss)
    # Non-strict load: only the parameters shared with the VAE are copied.
    scanvi.load_state_dict(vae_posterior.model.state_dict(), strict=False)

    # Classifier settings common to every variant.
    classifier_kwargs = dict(n_epochs_classifier=100,
                             lr_classification=5 * 1e-3)
    # variant -> (batch used as labelled data, batch used as unlabelled data)
    batch_roles = {'scanvi1': (0, 1), 'scanvi2': (1, 0)}

    if model_type in batch_roles:
        labelled_batch, unlabelled_batch = batch_roles[model_type]
        trainer_scanvi = AlternateSemiSupervisedTrainer(
            scanvi, gene_dataset, classification_ratio=0, **classifier_kwargs)
        candidates = np.where(
            gene_dataset.batch_indices.ravel() == labelled_batch)[0]
        # Sampling every candidate without replacement is a random permutation.
        shuffled = np.random.choice(candidates, len(candidates), replace=False)
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=shuffled)
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
            indices=(gene_dataset.batch_indices.ravel() == unlabelled_batch))
    elif model_type == 'scanvi0':
        trainer_scanvi = SemiSupervisedTrainer(
            scanvi, gene_dataset, classification_ratio=0, **classifier_kwargs)
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=(gene_dataset.batch_indices.ravel() < 0))
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
            indices=(gene_dataset.batch_indices.ravel() >= 0))
    else:
        trainer_scanvi = SemiSupervisedTrainer(
            scanvi, gene_dataset, classification_ratio=10, **classifier_kwargs)

    if not os.path.isfile(checkpoint_path):
        # No cached weights yet: train briefly and store them for next time.
        trainer_scanvi.train(n_epochs=5)
        torch.save(trainer_scanvi.model.state_dict(), checkpoint_path)
        return trainer_scanvi

    trainer_scanvi.model.load_state_dict(torch.load(checkpoint_path))
    trainer_scanvi.model.eval()
    return trainer_scanvi