def test_trim_dense(self):
        """Check ``trim_genes`` on a dense-backed ``InferelatorData``.

        The first trim keeps constant genes and is expected to leave
        exactly the ``CORRECT_GENES_INTERSECT`` columns. The second trim
        runs on the already-trimmed object with ``remove_constant_genes=True``
        and is expected to leave ``CORRECT_GENES_NZ_VAR``.
        """
        # Copy the shared class-level fixture before re-indexing it; assigning
        # ``.index`` on the original object would leak into other tests.
        gene_data = TestDataSingleCellLike.gene_metadata.copy()
        gene_data.index = gene_data.iloc[:, 0]

        adata = InferelatorData(self.expr, gene_data=gene_data)
        adata.trim_genes(remove_constant_genes=False)

        # The expected frame is cast to int32 to match the dtype that
        # ``adata._adata.to_df()`` returns here (presumably the dtype
        # InferelatorData stores for this integer input -- confirm).
        pdt.assert_frame_equal(
            self.expr.reindex(CORRECT_GENES_INTERSECT,
                              axis=1).astype(np.int32), adata._adata.to_df())

        # Trims the same object again, so this step stacks on the first.
        adata.trim_genes(remove_constant_genes=True)
        pdt.assert_frame_equal(
            self.expr.reindex(CORRECT_GENES_NZ_VAR, axis=1).astype(np.int32),
            adata._adata.to_df())

    def test_trim_sparse(self):
        """Check ``trim_genes`` on a CSR-sparse-backed ``InferelatorData``.

        This mirrors the dense test. The first trim keeps constant genes
        and is expected to leave ``CORRECT_GENES_INTERSECT``. The second trim
        runs on the already-trimmed object with ``remove_constant_genes=True``
        and is expected to leave ``CORRECT_GENES_NZ_VAR``.
        """
        # Copy the shared class-level fixture before re-indexing it, the same
        # way ``meta_data`` is copied below, so no state leaks between tests.
        gene_data = TestDataSingleCellLike.gene_metadata.copy()
        gene_data.index = gene_data.iloc[:, 0]

        # The fixture's expression matrix is transposed before being wrapped
        # in a CSR matrix; its index supplies the gene (column) names.
        adata_sparse = InferelatorData(
            sparse.csr_matrix(
                TestDataSingleCellLike.expression_matrix.values.T),
            gene_names=TestDataSingleCellLike.expression_matrix.index,
            meta_data=TestDataSingleCellLike.meta_data.copy(),
            gene_data=gene_data)

        adata_sparse.trim_genes(remove_constant_genes=False)
        pdt.assert_frame_equal(
            self.expr.reindex(CORRECT_GENES_INTERSECT, axis=1),
            adata_sparse._adata.to_df())

        # Trims the same object again, so this step stacks on the first.
        adata_sparse.trim_genes(remove_constant_genes=True)
        pdt.assert_frame_equal(self.expr.reindex(CORRECT_GENES_NZ_VAR, axis=1),
                               adata_sparse._adata.to_df())