Esempio n. 1
0
def test_synthetic_1():
    """Smoke-test SCANVI joint semi-supervised training on synthetic data,
    then exercise the posterior visualization and scoring API."""
    dataset = SyntheticDataset()
    dataset.cell_types = np.array(['A', 'B', 'C'])
    model = SCANVI(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    trainer = JointSemiSupervisedTrainer(model, dataset, use_cuda=use_cuda)
    trainer.train(n_epochs=1)

    labelled = trainer.labelled_set
    unlabelled = trainer.unlabelled_set

    labelled.entropy_batch_mixing()
    trainer.full_dataset.knn_purity(verbose=True)
    labelled.show_t_sne(n_samples=5)
    unlabelled.show_t_sne(n_samples=5, color_by='labels')
    labelled.show_t_sne(n_samples=5, color_by='batches and labels')
    labelled.clustering_scores()
    labelled.clustering_scores(prediction_algorithm='gmm')
    unlabelled.unsupervised_classification_accuracy()
    unlabelled.differential_expression_score(
        'B', 'C', genes=['2', '4'], M_sampling=2, M_permutation=10)
    unlabelled.differential_expression_table(M_sampling=2, M_permutation=10)
Esempio n. 2
0
def test_synthetic_1():
    """SCANVI joint semi-supervised smoke test, including label-mask DE
    scoring and one-vs-all differentially expressed genes."""
    dataset = SyntheticDataset()
    dataset.cell_types = np.array(["A", "B", "C"])
    model = SCANVI(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    trainer = JointSemiSupervisedTrainer(model, dataset, use_cuda=use_cuda)
    trainer.train(n_epochs=1)

    labelled = trainer.labelled_set
    unlabelled = trainer.unlabelled_set

    labelled.entropy_batch_mixing()
    trainer.full_dataset.knn_purity()
    labelled.show_t_sne(n_samples=5)
    unlabelled.show_t_sne(n_samples=5, color_by="labels")
    labelled.show_t_sne(n_samples=5, color_by="batches and labels")
    labelled.clustering_scores()
    labelled.clustering_scores(prediction_algorithm="gmm")
    unlabelled.unsupervised_classification_accuracy()
    unlabelled.differential_expression_score(
        dataset.labels.ravel() == 1,
        dataset.labels.ravel() == 2,
        n_samples=2,
        M_permutation=10,
    )
    unlabelled.one_vs_all_degenes(n_samples=2, M_permutation=10)
Esempio n. 3
0
def test_totalvi(save_path):
    """Run the totalVI benchmark for one epoch on one- and two-batch
    synthetic datasets."""
    for n_batches in (1, 2):
        dataset = SyntheticDataset(n_batches=n_batches)
        totalvi_benchmark(dataset, n_epochs=1, use_cuda=use_cuda)
Esempio n. 4
0
def test_synthetic_1():
    """Train SCANVI jointly on synthetic data, round-trip the labelled
    posterior through save/load, then exercise the posterior analyses."""
    synthetic_dataset = SyntheticDataset()
    synthetic_dataset.cell_types = np.array(["A", "B", "C"])
    svaec = SCANVI(
        synthetic_dataset.nb_genes,
        synthetic_dataset.n_batches,
        synthetic_dataset.n_labels,
    )
    trainer_synthetic_svaec = JointSemiSupervisedTrainer(svaec,
                                                         synthetic_dataset,
                                                         use_cuda=use_cuda)
    trainer_synthetic_svaec.train(n_epochs=1)
    trainer_synthetic_svaec.labelled_set.entropy_batch_mixing()

    # Persist the (sequential) labelled posterior and reload it into a fresh
    # SCANVI model; the temp dir is deleted once loading has completed.
    with tempfile.TemporaryDirectory() as temp_dir:
        posterior_save_path = os.path.join(temp_dir, "posterior_data")
        original_post = trainer_synthetic_svaec.labelled_set.sequential()
        original_post.save_posterior(posterior_save_path)
        new_svaec = SCANVI(
            synthetic_dataset.nb_genes,
            synthetic_dataset.n_batches,
            synthetic_dataset.n_labels,
        )
        new_post = load_posterior(posterior_save_path,
                                  model=new_svaec,
                                  use_cuda=False)
    # Indices, expression matrix and labels must survive the round trip.
    assert np.array_equal(new_post.indices, original_post.indices)
    assert np.array_equal(new_post.gene_dataset.X,
                          original_post.gene_dataset.X)
    assert np.array_equal(new_post.gene_dataset.labels,
                          original_post.gene_dataset.labels)

    trainer_synthetic_svaec.full_dataset.knn_purity()
    trainer_synthetic_svaec.labelled_set.show_t_sne(n_samples=5)
    trainer_synthetic_svaec.unlabelled_set.show_t_sne(n_samples=5,
                                                      color_by="labels")
    trainer_synthetic_svaec.labelled_set.show_t_sne(
        n_samples=5, color_by="batches and labels")
    trainer_synthetic_svaec.labelled_set.clustering_scores()
    trainer_synthetic_svaec.labelled_set.clustering_scores(
        prediction_algorithm="gmm")
    trainer_synthetic_svaec.unlabelled_set.unsupervised_classification_accuracy(
    )
    # Differential expression between label groups 1 and 2 (boolean masks).
    trainer_synthetic_svaec.unlabelled_set.differential_expression_score(
        synthetic_dataset.labels.ravel() == 1,
        synthetic_dataset.labels.ravel() == 2,
        n_samples=2,
        M_permutation=10,
    )
    trainer_synthetic_svaec.unlabelled_set.one_vs_all_degenes(n_samples=2,
                                                              M_permutation=10)
Esempio n. 5
0
def load_datasets(dataset_name, save_path='data/', url=None):
    """Factory returning the gene-expression dataset for *dataset_name*.

    Parameters
    ----------
    dataset_name : str
        A known dataset key ('synthetic', 'cortex', 'brain_large', 'retina',
        'cbmc', 'brain_small', 'hemato', 'pbmc'), a filename ending in
        ``.loom`` or ``.h5ad``, or a name containing ``.csv``.
    save_path : str
        Directory where downloaded data is cached.
    url : str, optional
        Download URL, forwarded to the file-based loaders.

    Raises
    ------
    ValueError
        If the name matches no known dataset.
    """
    if dataset_name == 'synthetic':
        gene_dataset = SyntheticDataset()
    elif dataset_name == 'cortex':
        gene_dataset = CortexDataset()
    elif dataset_name == 'brain_large':
        gene_dataset = BrainLargeDataset(save_path=save_path)
    elif dataset_name == 'retina':
        gene_dataset = RetinaDataset(save_path=save_path)
    elif dataset_name == 'cbmc':
        gene_dataset = CbmcDataset(save_path=save_path)
    elif dataset_name == 'brain_small':
        gene_dataset = BrainSmallDataset(save_path=save_path)
    elif dataset_name == 'hemato':
        gene_dataset = HematoDataset(save_path='data/HEMATO/')
    elif dataset_name == 'pbmc':
        gene_dataset = PbmcDataset(save_path=save_path)
    elif dataset_name[-5:] == ".loom":
        gene_dataset = LoomDataset(filename=dataset_name,
                                   save_path=save_path,
                                   url=url)
    elif dataset_name[-5:] == ".h5ad":
        gene_dataset = AnnDataset(dataset_name, save_path=save_path, url=url)
    elif ".csv" in dataset_name:
        gene_dataset = CsvDataset(dataset_name, save_path=save_path)
    else:
        # BUG FIX: raising a plain string is itself a TypeError in Python 3
        # ("exceptions must derive from BaseException"); raise a proper
        # exception with the same message instead.
        raise ValueError("No such dataset available")
    return gene_dataset
Esempio n. 6
0
def test_synthetic_3():
    """AdapterTrainer smoke test layered on top of a base benchmark run."""
    dataset = SyntheticDataset()
    base_trainer = base_benchmark(dataset)
    adapter = AdapterTrainer(base_trainer.model, dataset,
                             base_trainer.train_set, frequency=1)
    adapter.train(n_path=1, n_epochs=1)
Esempio n. 7
0
def load_datasets(dataset_name, save_path="data/", url=None):
    """Resolve *dataset_name* to a concrete dataset instance.

    Known names map directly to dataset classes; otherwise the name is
    treated as a ``.loom`` / ``.h5ad`` / ``.csv`` file to load.
    Raises ``Exception("No such dataset available")`` on no match.
    """
    # Lazily-evaluated constructors for the named datasets.
    by_name = {
        "synthetic": lambda: SyntheticDataset(),
        "cortex": lambda: CortexDataset(),
        "brain_large": lambda: BrainLargeDataset(save_path=save_path),
        "retina": lambda: RetinaDataset(save_path=save_path),
        "cbmc": lambda: CbmcDataset(save_path=save_path),
        "brain_small": lambda: BrainSmallDataset(save_path=save_path),
        "hemato": lambda: HematoDataset(save_path="data/HEMATO/"),
        "pbmc": lambda: PbmcDataset(save_path=save_path),
    }
    if dataset_name in by_name:
        return by_name[dataset_name]()
    if dataset_name[-5:] == ".loom":
        return LoomDataset(filename=dataset_name, save_path=save_path, url=url)
    if dataset_name[-5:] == ".h5ad":
        return AnnDataset(dataset_name, save_path=save_path, url=url)
    if ".csv" in dataset_name:
        return CsvDataset(dataset_name, save_path=save_path)
    raise Exception("No such dataset available")
Esempio n. 8
0
def test_synthetic_2():
    """Train VAEC jointly with log-likelihood early stopping on the
    labelled set."""
    dataset = SyntheticDataset()
    model = VAEC(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    stopping = {
        'early_stopping_metric': 'll',
        'on': 'labelled_set',
        'save_best_state_metric': 'll',
    }
    trainer = JointSemiSupervisedTrainer(model, dataset, use_cuda=use_cuda,
                                         frequency=1,
                                         early_stopping_kwargs=stopping)
    trainer.train(n_epochs=2)
Esempio n. 9
0
def test_filter_and_concat_datasets():
    """Check gene subsetting, cell-type filtering and dataset concatenation."""
    cortex_a = CortexDataset()
    cortex_a.subsample_genes(subset_genes=np.arange(0, 300))
    cortex_a.filter_cell_types(["microglia", "oligodendrocytes"])

    cortex_b = CortexDataset()
    cortex_b.subsample_genes(subset_genes=np.arange(100, 400))
    cortex_b.filter_cell_types(
        ["endothelial-mural", "interneurons", "microglia", "oligodendrocytes"]
    )
    cortex_b.filter_cell_types([2, 0])

    # Only the 200 overlapping genes survive the merge.
    cortex_merged = GeneExpressionDataset.concat_datasets(cortex_a, cortex_b)
    assert cortex_merged.nb_genes == 200

    synth_a = SyntheticDataset(n_batches=2, n_labels=5)
    synth_b = SyntheticDataset(n_batches=3, n_labels=3)

    merged_shared = GeneExpressionDataset.concat_datasets(synth_a, synth_b)
    assert merged_shared.n_batches == 5
    assert merged_shared.n_labels == 5

    merged_unshared = GeneExpressionDataset.concat_datasets(
        synth_a, synth_b, shared_labels=False
    )
    assert merged_unshared.n_batches == 5
    assert merged_unshared.n_labels == 8

    synth_a.filter_cell_types([0, 1, 2, 3])
    assert synth_a.n_labels == 4

    synth_a.subsample_cells(50)
    assert len(synth_a) == 50
Esempio n. 10
0
def test_nb_not_zinb():
    """SVAEC should train with a plain negative-binomial reconstruction
    loss instead of the default ZINB."""
    dataset = SyntheticDataset()
    model = SVAEC(dataset.nb_genes, dataset.n_batches, dataset.n_labels,
                  reconstruction_loss="nb")
    inference = JointSemiSupervisedVariationalInference(
        model, dataset, use_cuda=use_cuda)
    inference.train(n_epochs=1)
Esempio n. 11
0
def test_nb_not_zinb():
    """SCANVI with grouped labels should train under the NB loss."""
    dataset = SyntheticDataset()
    model = SCANVI(
        dataset.nb_genes,
        dataset.n_batches,
        dataset.n_labels,
        labels_groups=[0, 0, 1],
        reconstruction_loss="nb",
    )
    trainer = JointSemiSupervisedTrainer(model, dataset, use_cuda=use_cuda)
    trainer.train(n_epochs=1)
Esempio n. 12
0
    def test_batch_correction(self):
        """Highly-variable-gene selection (seurat_v2) and per-batch filtering."""
        dataset = SyntheticDataset(batch_size=100, nb_genes=100, n_batches=3)

        n_genes = dataset.nb_genes
        n_top = n_genes // 2
        # First call exercises the path without n_top_genes; the second
        # returns the HVG table used for the assertion.
        dataset._highly_variable_genes(n_bins=3, flavor="seurat_v2")
        df = dataset._highly_variable_genes(n_bins=3,
                                            n_top_genes=n_top,
                                            flavor="seurat_v2")
        assert df["highly_variable"].sum() >= n_top

        # Filtering by count and then subsampling should strictly reduce
        # the number of genes.
        dataset.filter_genes_by_count(2, per_batch=True)
        dataset.subsample_genes(new_n_genes=n_top)
        new_genes = dataset.nb_genes
        assert n_genes > new_genes, "subsample_genes did not filter out genes"
Esempio n. 13
0
def test_totalvi(save_path):
    """totalVI benchmarks (one and two batches), adversarial-loss training,
    and a posterior save/load round-trip check."""
    synthetic_dataset_one_batch = SyntheticDataset(n_batches=1)
    totalvi_benchmark(synthetic_dataset_one_batch,
                      n_epochs=1,
                      use_cuda=use_cuda)
    synthetic_dataset_two_batches = SyntheticDataset(n_batches=2)
    totalvi_benchmark(synthetic_dataset_two_batches,
                      n_epochs=1,
                      use_cuda=use_cuda)

    # adversarial testing
    dataset = synthetic_dataset_two_batches
    totalvae = TOTALVI(dataset.nb_genes,
                       len(dataset.protein_names),
                       n_batch=dataset.n_batches)
    trainer = TotalTrainer(
        totalvae,
        dataset,
        train_size=0.5,
        use_cuda=use_cuda,
        early_stopping_kwargs=None,
        use_adversarial_loss=True,
    )
    trainer.train(n_epochs=1)

    # Save a posterior over the full dataset, reload it into a freshly
    # constructed model, and check the type tag plus protein data survive.
    with tempfile.TemporaryDirectory() as temp_dir:
        posterior_save_path = os.path.join(temp_dir, "posterior_data")
        original_post = trainer.create_posterior(
            totalvae,
            dataset,
            indices=np.arange(len(dataset)),
            type_class=TotalPosterior,
        )
        original_post.save_posterior(posterior_save_path)
        new_totalvae = TOTALVI(dataset.nb_genes,
                               len(dataset.protein_names),
                               n_batch=dataset.n_batches)
        new_post = load_posterior(posterior_save_path,
                                  model=new_totalvae,
                                  use_cuda=False)
        assert new_post.posterior_type == "TotalPosterior"
        assert np.array_equal(new_post.gene_dataset.protein_expression,
                              dataset.protein_expression)
Esempio n. 14
0
def test_totalvi(save_path):
    """totalVI benchmark on one- and two-batch data, plus adversarial-loss
    training on the two-batch dataset."""
    one_batch = SyntheticDataset(n_batches=1)
    totalvi_benchmark(one_batch, n_epochs=1, use_cuda=use_cuda)
    two_batches = SyntheticDataset(n_batches=2)
    totalvi_benchmark(two_batches, n_epochs=1, use_cuda=use_cuda)

    # adversarial testing
    model = TOTALVI(
        two_batches.nb_genes,
        len(two_batches.protein_names),
        n_batch=two_batches.n_batches,
    )
    trainer = TotalTrainer(
        model,
        two_batches,
        train_size=0.5,
        use_cuda=use_cuda,
        early_stopping_kwargs=None,
        use_adversarial_loss=True,
    )
    trainer.train(n_epochs=1)
Esempio n. 15
0
def test_synthetic_2():
    """VAEC joint inference with accuracy-based model selection, then the
    SVC/random-forest comparison."""
    dataset = SyntheticDataset()
    model = VAEC(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    inference = JointSemiSupervisedVariationalInference(
        model,
        dataset,
        use_cuda=use_cuda,
        early_stopping_metric='ll',
        frequency=1,
        save_best_state_metric='accuracy',
        on='labelled',
    )
    inference.train(n_epochs=20)
    inference.svc_rf(unit_test=True)
Esempio n. 16
0
def test_synthetic_1():
    """SVAEC joint inference smoke test: batch mixing, t-SNE plots and
    clustering scores on the labelled/unlabelled splits."""
    dataset = SyntheticDataset()
    model = SVAEC(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    inference = JointSemiSupervisedVariationalInference(
        model, dataset, use_cuda=use_cuda)
    inference.train(n_epochs=1)
    inference.entropy_batch_mixing('labelled')
    inference.show_t_sne('labelled', n_samples=50)
    inference.show_t_sne('unlabelled', n_samples=50, color_by='labels')
    inference.show_t_sne('labelled', n_samples=50,
                         color_by='batches and labels')
    inference.clustering_scores('labelled')
Esempio n. 17
0
def test_autozi(save_path):
    """Train AutoZI with matching dispersion / zero-inflation granularities
    and evaluate the held-out metrics."""
    data = SyntheticDataset(n_batches=1)

    for granularity in ("gene", "gene-label"):
        model = AutoZIVAE(
            n_input=data.nb_genes,
            dispersion=granularity,
            zero_inflation=granularity,
            n_labels=data.n_labels,
        )
        trainer = UnsupervisedTrainer(model=model, gene_dataset=data,
                                      train_size=0.5)
        trainer.train(n_epochs=2, lr=1e-2)
        trainer.test_set.elbo()
        trainer.test_set.reconstruction_error()
        trainer.test_set.marginal_ll()
Esempio n. 18
0
def test_hierarchy():
    """SCANVI configured with a two-level label ontology trains for one
    epoch."""
    dataset = SyntheticDataset()
    ontology = [
        np.array([[1, 1, 0], [0, 0, 1]]),
        np.array([[1, 0, 1, 0], [0, 0, 1, 0], [0, 0, 1, 1]]),
    ]
    model = SCANVI(
        dataset.nb_genes,
        dataset.n_batches,
        dataset.n_labels,
        ontology=ontology,
        use_ontology=True,
        reconstruction_loss="zinb",
        n_layers=3,
    )
    trainer = JointSemiSupervisedTrainer(model, dataset, use_cuda=use_cuda)
    trainer.train(n_epochs=1)
Esempio n. 19
0
def test_synthetic_2():
    """VAEC joint training with reconstruction-error early stopping on the
    labelled set."""
    dataset = SyntheticDataset()
    model = VAEC(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    stopping = {
        "early_stopping_metric": "reconstruction_error",
        "on": "labelled_set",
        "save_best_state_metric": "reconstruction_error",
    }
    trainer = JointSemiSupervisedTrainer(
        model,
        dataset,
        use_cuda=use_cuda,
        frequency=1,
        early_stopping_kwargs=stopping,
    )
    trainer.train(n_epochs=2)
Esempio n. 20
0
def test_synthetic_1():
    """Fit SVAEC then VAEC on the same synthetic data; probe batch mixing,
    SVC/RF classification and t-SNE."""
    dataset = SyntheticDataset()

    svaec = SVAEC(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    svaec_inference = JointSemiSupervisedVariationalInference(
        svaec, dataset, use_cuda=use_cuda)
    svaec_inference.fit(n_epochs=1)
    svaec_inference.entropy_batch_mixing('labelled')

    vaec = VAEC(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    vaec_inference = JointSemiSupervisedVariationalInference(
        vaec,
        dataset,
        use_cuda=use_cuda,
        early_stopping_metric='ll',
        frequency=1,
        save_best_state_metric='accuracy',
        on='labelled',
    )
    vaec_inference.fit(n_epochs=20)
    vaec_inference.svc_rf(unit_test=True)
    vaec_inference.show_t_sne('labelled', n_samples=50)
Esempio n. 21
0
    def test_dense_subsample_genes(self):
        """subsample_genes honours each HVG flavor; constant genes score low."""
        dataset = SyntheticDataset(batch_size=100, nb_genes=100, n_batches=3)

        n_genes = dataset.nb_genes
        n_top = n_genes // 2
        dataset.subsample_genes(new_n_genes=n_top, mode="cell_ranger")
        assert dataset.nb_genes == n_top

        # With Seurat v2
        dataset = SyntheticDataset(batch_size=100, nb_genes=100, n_batches=3)
        dataset.subsample_genes(new_n_genes=n_top, mode="seurat_v2")
        assert dataset.nb_genes == n_top

        # With Seurat v3
        dataset = SyntheticDataset(batch_size=100, nb_genes=100, n_batches=3)
        dataset.subsample_genes(new_n_genes=n_top, mode="seurat_v3")
        assert dataset.nb_genes == n_top

        # make sure constant genes have low scores
        dataset = SyntheticDataset(batch_size=100, nb_genes=100, n_batches=3)
        # Zero out the last gene so it carries no variance signal.
        dataset.X[:, -1] = np.zeros_like(dataset.X[:, -1])
        df = dataset._highly_variable_genes(n_top_genes=n_top,
                                            flavor="seurat_v3")

        # The zeroed gene (looked up by its string name) must have zero
        # median variance under the seurat_v3 flavor.
        assert df.loc[str(dataset.nb_genes -
                          1)]["highly_variable_median_variance"] == 0
Esempio n. 22
0
def test_LDVAE(save_path):
    """Linearly-decoded VAE benchmark on one- and two-batch synthetic
    datasets (CPU only)."""
    for n_batches in (1, 2):
        dataset = SyntheticDataset(n_batches=n_batches)
        ldvae_benchmark(dataset, n_epochs=1, use_cuda=False)
Esempio n. 23
0
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Test improved calculation of log_zinb_positive and log_nb_positive for better performance on brain_large."""

from scvi.dataset import SyntheticDataset
import torch
import torch.nn.functional as F

# Module-level fixture sizes shared by the tests below.
batch_size = 128
nb_genes = 720
n_batches = 1


# Shared test inputs: a synthetic count matrix plus random tensors standing
# in for the distribution parameters (names suggest mean mu, dispersion
# theta and mixture logit pi — confirm against scvi's log_zinb_positive).
synthetic_dataset = SyntheticDataset(batch_size=batch_size, nb_genes=nb_genes,
                                     n_batches=n_batches)
x = torch.from_numpy(synthetic_dataset.X)
mu = torch.rand(batch_size, nb_genes)
theta = torch.rand(batch_size, nb_genes)
pi = torch.rand(batch_size, nb_genes)
eps = 1e-2


def test_log_zinb_positive():
    """Test that the new method to compute log_zinb_positive is the same as the existing."""

    def existing_method(x, mu, theta, pi, eps=1e-8):
        case_zero = (F.softplus((- pi + theta * torch.log(theta + eps) - theta * torch.log(theta + mu + eps))) -
                     F.softplus(-pi))

        case_non_zero = - pi - F.softplus(-pi) + \
Esempio n. 24
0
 def test_train_one(self):
     dataset = SyntheticDataset(batch_size=10, nb_genes=10)
     unsupervised_training_one_epoch(dataset)
    def test_dense_subsample_genes(self):
        """subsample_genes keeps exactly n_top genes for each HVG mode."""
        dataset = SyntheticDataset(batch_size=100, nb_genes=100, n_batches=3)

        n_genes = dataset.nb_genes
        n_top = n_genes // 2
        dataset.subsample_genes(new_n_genes=n_top, mode="cell_ranger")
        assert dataset.nb_genes == n_top

        # With Seurat v2
        dataset = SyntheticDataset(batch_size=100, nb_genes=100, n_batches=3)
        dataset.subsample_genes(new_n_genes=n_top, mode="seurat_v2")
        assert dataset.nb_genes == n_top

        # With Seurat v3
        dataset = SyntheticDataset(batch_size=100, nb_genes=100, n_batches=3)
        dataset.subsample_genes(new_n_genes=n_top, mode="seurat_v3")
        assert dataset.nb_genes == n_top
Esempio n. 26
0
def test_filter_and_concat_datasets():
    """Gene/cell-type filtering, concatenation, and cell-type remapping.

    Fix: ``np.str`` was a deprecated alias for the builtin ``str`` and was
    removed in NumPy 1.24; ``astype(str)`` is the drop-in replacement.
    """
    cortex_dataset_1 = CortexDataset(save_path='tests/data/')
    cortex_dataset_1.subsample_genes(subset_genes=np.arange(0, 3))
    cortex_dataset_1.filter_cell_types(["microglia", "oligodendrocytes"])
    cortex_dataset_2 = CortexDataset(save_path='tests/data/')
    cortex_dataset_2.subsample_genes(subset_genes=np.arange(1, 4))
    cortex_dataset_2.filter_cell_types(["endothelial-mural", "interneurons", "microglia", "oligodendrocytes"])
    cortex_dataset_2.filter_cell_types([2, 0])
    # Only the overlapping genes (indices 1-2) survive the merge.
    cortex_dataset_merged = GeneExpressionDataset.concat_datasets(cortex_dataset_1, cortex_dataset_2)
    assert cortex_dataset_merged.nb_genes == 2

    synthetic_dataset_1 = SyntheticDataset(n_batches=2, n_labels=5)
    synthetic_dataset_2 = SyntheticDataset(n_batches=3, n_labels=3)
    synthetic_merged_1 = GeneExpressionDataset.concat_datasets(synthetic_dataset_1, synthetic_dataset_2)
    assert synthetic_merged_1.n_batches == 5
    assert synthetic_merged_1.n_labels == 5

    synthetic_merged_2 = GeneExpressionDataset.concat_datasets(synthetic_dataset_1, synthetic_dataset_2,
                                                               shared_labels=False)
    assert synthetic_merged_2.n_batches == 5
    assert synthetic_merged_2.n_labels == 8

    synthetic_dataset_1.filter_cell_types([0, 1, 2, 3])
    assert synthetic_dataset_1.n_labels == 4

    synthetic_dataset_1.subsample_cells(50)
    assert len(synthetic_dataset_1) == 50

    synthetic_dataset_3 = SyntheticDataset(n_labels=6)
    # BUG FIX: np.str was removed in NumPy 1.24; the builtin str is the
    # equivalent dtype argument.
    synthetic_dataset_3.cell_types = np.arange(6).astype(str)
    synthetic_dataset_3.map_cell_types({"2": "9", ("4", "3"): "8"})
Esempio n. 27
0
def test_particular_benchmark():
    """One-epoch CPU benchmark on a fresh synthetic dataset."""
    benchmark(SyntheticDataset(), n_epochs=1, use_cuda=False)
Esempio n. 28
0
    def test_concatenate_from_scvi_to_loom(self):
        """Concatenating two synthetic datasets via the scvi->loom path must
        match an in-memory concatenation (in either input order)."""
        try:
            random_seed = 0
            dset1_args = {
                "batch_size": 10,
                "nb_genes": 4,
                "n_proteins": 4,
                "n_batches": 4,
                "n_labels": 3,
                "seed": random_seed
            }
            dset2_args = {
                "batch_size": 30,
                "nb_genes": 2,
                "n_proteins": 6,
                "n_batches": 2,
                "n_labels": 4,
                "seed": random_seed
            }
            dset1, dset2 = (SyntheticDataset(**dset1_args),
                            SyntheticDataset(**dset2_args))

            # Concatenate the datasets in memory first as reference
            union_from_mem_to_mem = UnionDataset(save_path=save_path,
                                                 low_memory=True,
                                                 ignore_batch_annotation=False)
            union_from_mem_to_mem.build_genemap(data_source="memory",
                                                gene_datasets=[dset1, dset2])
            union_from_mem_to_mem.join_datasets(data_source='memory',
                                                data_target='memory',
                                                gene_datasets=[dset1, dset2])

            # Second reference with the inputs swapped, so the checks below
            # tolerate either concatenation order.
            union_from_mem_to_mem_perturb = UnionDataset(
                save_path=save_path,
                low_memory=True,
                ignore_batch_annotation=False)
            union_from_mem_to_mem_perturb.build_genemap(
                data_source="memory", gene_datasets=[dset1, dset2])
            union_from_mem_to_mem_perturb.join_datasets(
                data_source='memory',
                data_target='memory',
                gene_datasets=[dset2, dset1])

            # Load datasets from scvi and concatenate them in memory
            union_from_scvi_to_loom = UnionDataset(
                save_path=save_path,
                low_memory=True,
                ignore_batch_annotation=False)
            union_from_scvi_to_loom.build_genemap(data_source="memory",
                                                  gene_datasets=[dset1, dset2])
            union_from_scvi_to_loom.join_datasets(
                data_source='scvi',
                data_target='loom',
                dataset_classes=[SyntheticDataset, SyntheticDataset],
                dataset_args=[dset1_args, dset2_args],
                out_filename="test_concat.loom")

            self.assertTrue(
                len(union_from_scvi_to_loom) == (len(dset1) + len(dset2)))

            # Spot-check ~20% of the rows instead of the full matrix.
            random_indices = np.sort(
                np.random.choice(np.arange(len(union_from_scvi_to_loom)),
                                 size=int(len(union_from_scvi_to_loom) / 5),
                                 replace=False))

            self.assertTrue(
                (union_from_scvi_to_loom.X[random_indices]
                 == union_from_mem_to_mem.X[random_indices].toarray()).all() or
                (union_from_scvi_to_loom.X[random_indices]
                 == union_from_mem_to_mem_perturb.X[random_indices].toarray()
                 ).all())

            self.assertTrue((union_from_scvi_to_loom.gene_names ==
                             union_from_mem_to_mem.gene_names).all())
            self.assertTrue((union_from_scvi_to_loom.cell_types ==
                             union_from_mem_to_mem.cell_types).all())
            self.assertTrue(
                (union_from_scvi_to_loom.batch_indices
                 == union_from_mem_to_mem.batch_indices).all()
                or (union_from_scvi_to_loom.batch_indices
                    == union_from_mem_to_mem_perturb.batch_indices).all())
            self.assertTrue((union_from_scvi_to_loom.labels
                             == union_from_mem_to_mem.labels).all()
                            or (union_from_scvi_to_loom.labels
                                == union_from_mem_to_mem_perturb.labels).all())

            # The concatenated dataset must still be trainable.
            unsupervised_training_one_epoch(union_from_scvi_to_loom)

        except Exception as e:
            # Remove the loom artifact on failure, then re-raise the error.
            if os.path.exists(os.path.join(save_path, "test_concat.loom")):
                os.remove(os.path.join(save_path, "test_concat.loom"))
            raise e
Esempio n. 29
0
    def test_batch_correction(self):
        """HVG selection with and without batch correction (seurat v2/v3)."""
        dataset = SyntheticDataset(batch_size=100, nb_genes=100, n_batches=3)

        n_genes = dataset.nb_genes
        n_top = n_genes // 2
        dataset._highly_variable_genes(n_bins=3, flavor="seurat_v2")
        df = dataset._highly_variable_genes(n_bins=3,
                                            n_top_genes=n_top,
                                            flavor="seurat_v2")
        assert df["highly_variable"].sum() >= n_top

        dataset.filter_genes_by_count(2, per_batch=True)
        dataset.subsample_genes(new_n_genes=n_top)
        new_genes = dataset.nb_genes
        assert n_genes > new_genes, "subsample_genes did not filter out genes"

        # Fresh dataset: batch_correction=False must not emit per-batch
        # columns; batch_correction=True must emit them, for both flavors.
        dataset = SyntheticDataset(batch_size=100, nb_genes=100, n_batches=3)
        n_genes = dataset.nb_genes
        n_top = n_genes // 2
        df = dataset._highly_variable_genes(n_bins=3,
                                            flavor="seurat_v2",
                                            batch_correction=False,
                                            n_top_genes=n_top)
        assert ("highly_variable_nbatches" not in df.columns
                ), "HVG dataframe should not contain batch information"
        df = dataset._highly_variable_genes(n_bins=3,
                                            flavor="seurat_v2",
                                            batch_correction=True)
        assert "highly_variable_nbatches" in df.columns
        assert "highly_variable_intersection" in df.columns
        df = dataset._highly_variable_genes(n_bins=3,
                                            flavor="seurat_v3",
                                            batch_correction=False,
                                            n_top_genes=n_top)
        assert ("highly_variable_nbatches" not in df.columns
                ), "HVG dataframe should not contain batch information"
        df = dataset._highly_variable_genes(n_bins=3,
                                            flavor="seurat_v3",
                                            batch_correction=True,
                                            n_top_genes=n_top)
        assert "highly_variable_nbatches" in df.columns
        assert "highly_variable_intersection" in df.columns
Esempio n. 30
0
def test_synthetic_3():
    """Adapt the encoder of a base-benchmark model for one path/epoch."""
    trainer = base_benchmark(SyntheticDataset())
    adapt_encoder(trainer, n_path=1, n_epochs=1, frequency=1)