示例#1
0
文件: test_scvi.py 项目: tkisss/scVI
def test_multibatches_features():
    """Smoke-test VAE training and imputation on a dataset with four batches.

    Four random count matrices (10 genes each) are loaded as separate batches,
    a VAE is trained for two epochs, and imputation is run on both the test
    and train splits. No values are asserted; the test passes if nothing raises.
    """
    # (exclusive upper bound of the random counts, number of cells) per batch,
    # listed in the same order the batches are generated.
    batch_specs = [(5, 20), (10, 20), (10, 20), (10, 30)]
    data = [
        np.random.randint(1, high, size=(n_cells, 10))
        for high, n_cells in batch_specs
    ]
    dataset = GeneExpressionDataset()
    dataset.populate_from_per_batch_list(data)

    vae = VAE(dataset.nb_genes, dataset.n_batches)
    trainer = UnsupervisedTrainer(vae, dataset, train_size=0.5, use_cuda=use_cuda)
    trainer.train(n_epochs=2)

    # Exercise transform_batch both as a single batch index and as a list of indices.
    trainer.test_set.imputation(n_samples=2, transform_batch=0)
    trainer.train_set.imputation(n_samples=2, transform_batch=[0, 1, 2])
示例#2
0
 def test_populate_from_per_batch_list(self):
     """Each input matrix becomes one batch, indexed 0..k-1 in input order.

     Three random matrices with 7, 5 and 3 cells (10 genes each) are loaded;
     the dataset must report 15 cells, 10 genes, and a column of batch indices
     with seven 0s, five 1s and three 2s.
     """
     batch_sizes = (7, 5, 3)
     data = [np.random.randint(1, 5, size=(n_cells, 10)) for n_cells in batch_sizes]

     dataset = GeneExpressionDataset()
     dataset.populate_from_per_batch_list(data)

     self.assertEqual(dataset.nb_cells, sum(batch_sizes))
     self.assertEqual(dataset.nb_genes, 10)

     # Column vector: batch id i repeated batch_sizes[i] times.
     expected = np.repeat(np.arange(len(batch_sizes)), batch_sizes).reshape(-1, 1)
     self.assertListEqual(expected.tolist(), dataset.batch_indices.tolist())
示例#3
0
    def test_batch_correction(self):
        """Check Seurat-flavored highly-variable-gene selection on three batches.

        Three random 50x25 count matrices are loaded as separate batches. The
        selector is first called without ``n_top_genes`` as a smoke check, then
        with ``n_top_genes`` set to half the genes. The returned frame's
        ``highly_variable`` column must flag at least that many genes.
        """
        data = [
            np.random.randint(1, 5, size=(50, 25)),
            np.random.randint(1, 5, size=(50, 25)),
            np.random.randint(1, 5, size=(50, 25)),
        ]
        dataset = GeneExpressionDataset()
        dataset.populate_from_per_batch_list(data)

        n_top = dataset.nb_genes // 2

        # Smoke check: the call without n_top_genes must not raise (result unused).
        # The original issued this identical call twice; once is sufficient.
        dataset.highly_variable_genes(n_bins=3, flavor="seurat")

        df = dataset.highly_variable_genes(
            n_bins=3, n_top_genes=n_top, flavor="seurat"
        )
        # The check is ">=" rather than "==": the selector is allowed to flag
        # more than n_top genes.
        assert df["highly_variable"].sum() >= n_top
示例#4
0
    def test_dense_subsample_genes(self):
        """subsample_genes must drop genes in both default and "seurat" mode.

        A fresh dataset is built from the same three random 50x26 batches for
        each mode, and half the genes are requested each time.
        """
        data = [np.random.randint(1, 5, size=(50, 26)) for _ in range(3)]

        def fresh_dataset():
            # Build a new dataset from the shared random batches.
            ds = GeneExpressionDataset()
            ds.populate_from_per_batch_list(data)
            return ds

        # Default mode.
        dataset = fresh_dataset()
        n_genes = dataset.nb_genes
        n_top = n_genes // 2
        dataset.subsample_genes(new_n_genes=n_top)
        # The resulting gene count is not always exactly n_top, so only check
        # that genes were removed.
        assert dataset.nb_genes < n_genes

        # Seurat mode.
        dataset = fresh_dataset()
        dataset.subsample_genes(new_n_genes=n_top, mode="seurat")
        assert dataset.nb_genes < n_genes
示例#5
0
    def test_batch_correction(self):
        """Check seurat_v2 highly-variable-gene selection and gene subsampling.

        Three random 50x25 count batches are loaded. ``_highly_variable_genes``
        is called once without and once with ``n_top_genes`` (half the genes).
        The latter must flag at least that many genes. ``subsample_genes`` must
        then reduce the gene count.
        """
        data = [np.random.randint(1, 5, size=(50, 25)) for _ in range(3)]
        dataset = GeneExpressionDataset()
        dataset.populate_from_per_batch_list(data)

        original_gene_count = dataset.nb_genes
        n_top = original_gene_count // 2

        # Smoke check without n_top_genes; its result is not inspected.
        dataset._highly_variable_genes(n_bins=3, flavor="seurat_v2")
        hvg_frame = dataset._highly_variable_genes(
            n_bins=3, n_top_genes=n_top, flavor="seurat_v2"
        )
        assert hvg_frame["highly_variable"].sum() >= n_top

        dataset.subsample_genes(new_n_genes=n_top)
        assert original_gene_count > dataset.nb_genes, "subsample_genes did not filter out genes"
示例#6
0
    def test_dense_subsample_genes(self):
        """Each subsample_genes mode must keep exactly half of the genes.

        For every mode in ("cell_ranger", "seurat_v2", "seurat_v3"), a fresh
        dataset is built from the same three random 50x26 batches. One mode's
        subsampling therefore cannot influence the next. The gene count after
        subsample_genes(new_n_genes=n_top, mode=mode) must equal n_top.
        """
        data = [
            np.random.randint(1, 5, size=(50, 26)),
            np.random.randint(1, 5, size=(50, 26)),
            np.random.randint(1, 5, size=(50, 26)),
        ]

        for mode in ("cell_ranger", "seurat_v2", "seurat_v3"):
            dataset = GeneExpressionDataset()
            dataset.populate_from_per_batch_list(data)
            # Same data every iteration, so n_top is the same for each mode.
            n_top = dataset.nb_genes // 2
            dataset.subsample_genes(new_n_genes=n_top, mode=mode)
            # Name the mode in the message so a failure identifies which one broke.
            assert dataset.nb_genes == n_top, (
                f"mode={mode!r}: expected {n_top} genes, got {dataset.nb_genes}"
            )