Esempio n. 1
0
    def test_populate_from_datasets_with_measurments(self):
        """Merging datasets keeps only the "dev" columns shared by every input.

        Both inputs carry a "dev" CellMeasurement; "gayoso" exists only in the
        first one.  The merged dataset must expose the shared column names in
        sorted order, with each cell's values permuted to follow that order.
        """
        counts = np.random.randint(1, 5, size=(5, 10))
        genes = np.array(["gene_%d" % g for g in range(10)])

        def dev_measurement(column_names):
            # Every one of the 5 cells holds the value j in column j.
            width = len(column_names)
            return CellMeasurement(
                name="dev",
                data=np.ones((5, width)) * np.arange(0, width),
                columns_attr_name="dev_names",
                columns=column_names,
            )

        first = dev_measurement(["gabou", "achille", "pedro", "oclivio", "gayoso"])
        second = dev_measurement(["gabou", "oclivio", "achille", "pedro"])

        sources = [GeneExpressionDataset(), GeneExpressionDataset()]
        for source, measurement in zip(sources, (first, second)):
            source.populate_from_data(counts, Ys=[measurement], gene_names=genes)

        merged = GeneExpressionDataset()
        merged.populate_from_datasets(sources)

        for attribute in ("dev", "dev_names"):
            self.assertTrue(hasattr(merged, attribute))

        self.assertListEqual(merged.dev_names.tolist(),
                             ["achille", "gabou", "oclivio", "pedro"])
        # First input's cell: gabou=0, achille=1, pedro=2, oclivio=3.
        self.assertListEqual(merged.dev[0].tolist(), [1, 0, 3, 2])
        # Presumably the first cell of the second input: gabou=0, oclivio=1,
        # achille=2, pedro=3.
        self.assertListEqual(merged.dev[5].tolist(), [2, 0, 1, 3])
Esempio n. 2
0
    def test_populate_from_datasets_with_measurments(self):
        """Merging "dev" measurements: intersection by default, union on request.

        The first input has an extra "gayoso" column.  A default merge keeps
        only the shared columns.  Passing
        ``cell_measurement_intersection={"dev": False}`` keeps all of them, and
        the batch mask marks "gayoso" as missing for the second batch.
        """
        counts = np.random.randint(1, 5, size=(5, 10))
        genes = np.array(["gene_%d" % g for g in range(10)])

        def dev_measurement(column_names):
            # Every one of the 5 cells holds the value j in column j.
            width = len(column_names)
            return CellMeasurement(
                name="dev",
                data=np.ones((5, width)) * np.arange(0, width),
                columns_attr_name="dev_names",
                columns=column_names,
            )

        first = dev_measurement(["gabou", "achille", "pedro", "oclivio", "gayoso"])
        second = dev_measurement(["gabou", "oclivio", "achille", "pedro"])

        dataset1, dataset2 = GeneExpressionDataset(), GeneExpressionDataset()
        dataset1.populate_from_data(counts, Ys=[first], gene_names=genes)
        dataset2.populate_from_data(counts, Ys=[second], gene_names=genes)

        def merge_copies(**merge_options):
            # Merge deep copies so both merges see untouched inputs (presumably
            # because populate_from_datasets may modify the datasets it gets).
            result = GeneExpressionDataset()
            result.populate_from_datasets(
                [copy.deepcopy(dataset1), copy.deepcopy(dataset2)],
                **merge_options
            )
            return result

        shared = merge_copies()
        for attribute in ("dev", "dev_names"):
            self.assertTrue(hasattr(shared, attribute))
        self.assertListEqual(shared.dev_names.tolist(),
                             ["achille", "gabou", "oclivio", "pedro"])
        self.assertListEqual(shared.dev[0].tolist(), [1, 0, 3, 2])
        self.assertListEqual(shared.dev[5].tolist(), [2, 0, 1, 3])

        # Union of dev columns; entries missing from a batch are zero-filled.
        union = merge_copies(cell_measurement_intersection={"dev": False})
        self.assertListEqual(
            union.dev_names.tolist(),
            ["achille", "gabou", "gayoso", "oclivio", "pedro"],
        )
        batch_mask = union.get_batch_mask_cell_measurement("dev")
        # Column 2 ("gayoso") is absent from the second batch.
        self.assertEqual(batch_mask[1][2].astype(int), 0)
Esempio n. 3
0
    def test_collate_add(self):
        """``collate_fn_builder`` appends the requested extra attributes.

        Cell i has expression ``[i, i]``, x coordinate ``i`` and protein
        counts ``[1 + i, 2 + i, 3 + i]``.  Requesting ``x_coords`` and
        ``proteins`` (as np.float32) must append them after the five default
        collate outputs.
        """
        n_cells = 25

        def cell_index_column():
            # A fresh (n_cells, 1) array of 0..n_cells-1 on every call, so the
            # attributes below never share the same array object.
            return np.arange(0, n_cells).reshape((-1, 1))

        data = np.ones((n_cells, 2)) * cell_index_column()
        batch_indices = cell_index_column()
        x_coords = cell_index_column()
        proteins = np.ones((n_cells, 3)) + cell_index_column() + np.arange(0, 3)
        protein_columns = ["A", "B", "C"]

        dataset = GeneExpressionDataset()
        dataset.populate_from_data(
            data,
            batch_indices=batch_indices,
            cell_attributes_dict={"x_coords": x_coords},
            Ys=[
                CellMeasurement(
                    name="proteins",
                    data=proteins,
                    columns_attr_name="protein_names",
                    columns=protein_columns,
                )
            ],
        )

        collate_fn = dataset.collate_fn_builder(
            add_attributes_and_types=dict.fromkeys(("x_coords", "proteins"),
                                                   np.float32))
        (_x, _local_mean, _local_var, _batch, _labels,
         x_coords_tensor, proteins_tensor) = collate_fn([1, 2])

        self.assertListEqual(x_coords_tensor.tolist(), [[1.0], [2.0]])
        self.assertListEqual(proteins_tensor.tolist(),
                             [[2.0, 3.0, 4.0], [3.0, 4.0, 5.0]])
Esempio n. 4
0
    def __init__(
        self,
        ad: anndata.AnnData,
        batch_label: str = "batch_indices",
        ctype_label: str = "cell_types",
        class_label: str = "labels",
        use_raw: bool = False,
        cell_measurements_col_mappings: Optional[Dict[str, str]] = None,
    ):
        """Build the dataset from an in-memory ``anndata.AnnData`` object.

        :param ad: the AnnData object to read from.
        :param batch_label: forwarded to ``extract_data_from_anndata`` as the
            batch label.
        :param ctype_label: forwarded to ``extract_data_from_anndata`` as the
            cell-type label.
        :param class_label: forwarded to ``extract_data_from_anndata`` as the
            class label.
        :param use_raw: forwarded to ``extract_data_from_anndata``.
        :param cell_measurements_col_mappings: optional mapping
            ``{obsm_key: columns_attr_name}``.  Each entry turns
            ``obsm[obsm_key]`` into a ``CellMeasurement`` named ``obsm_key``.
            The column names come from ``uns[columns_attr_name]``, or from
            ``uns[obsm_key]`` when the former key is absent (the lookup used
            previously).
        """
        super().__init__()
        (
            X,
            batch_indices,
            labels,
            gene_names,
            cell_types,
            obs,
            obsm,
            var,
            _,
            uns,
        ) = extract_data_from_anndata(
            ad,
            batch_label=batch_label,
            ctype_label=ctype_label,
            class_label=class_label,
            use_raw=use_raw,
        )

        # Dataset API takes a dict as input
        obs = obs.to_dict(orient="list")
        var = var.to_dict(orient="list")

        # add external cell measurements
        Ys = []
        if cell_measurements_col_mappings is not None:
            for name, attr_name in cell_measurements_col_mappings.items():
                # Column names are stored under the columns attribute name,
                # matching the file-based loader. Fall back to the obsm key so
                # objects that relied on the old lookup keep working.
                columns = uns[attr_name] if attr_name in uns else uns[name]
                measurement = CellMeasurement(
                    name=name,
                    data=obsm[name],
                    columns_attr_name=attr_name,
                    columns=columns,
                )
                Ys.append(measurement)

        self.populate_from_data(
            X=X,
            Ys=Ys,
            labels=labels,
            batch_indices=batch_indices,
            gene_names=gene_names,
            cell_types=cell_types,
            cell_attributes_dict=obs,
            gene_attributes_dict=var,
        )
        self.filter_cells_by_count()
Esempio n. 5
0
    def populate(self):
        """Read ``self.filenames[0]`` as an .h5ad file and populate the dataset.

        Expression data and annotations are extracted with ``self.batch_label``,
        ``self.ctype_label``, ``self.class_label`` and ``self.use_raw``.  Each
        ``{obsm_key: columns_attr_name}`` entry of
        ``self.cell_measurements_col_mappings_temp`` becomes a
        ``CellMeasurement``: data comes from ``obsm[obsm_key]`` and column names
        from ``uns[columns_attr_name]``.  Cells are then filtered by count, and
        the temporary mapping attribute is deleted at the end.
        """
        h5ad_path = os.path.join(self.save_path, self.filenames[0])
        adata = anndata.read_h5ad(h5ad_path)  # obs = cells, var = genes

        # Pull out the GeneExpressionDataset inputs plus the raw annotation
        # containers of the underlying AnnData object.
        (X, batch_indices, labels, gene_names, cell_types,
         obs_frame, obsm, var_frame, _unused, uns) = extract_data_from_anndata(
            adata,
            batch_label=self.batch_label,
            ctype_label=self.ctype_label,
            class_label=self.class_label,
            use_raw=self.use_raw,
        )

        # populate_from_data expects plain dicts of per-column lists.
        cell_attributes = obs_frame.to_dict(orient="list")
        gene_attributes = var_frame.to_dict(orient="list")

        mappings = self.cell_measurements_col_mappings_temp
        measurements = []
        if mappings is not None:
            measurements = [
                CellMeasurement(
                    columns=uns[attr_name],
                    name=name,
                    data=obsm[name],
                    columns_attr_name=attr_name,
                )
                for name, attr_name in mappings.items()
            ]

        self.populate_from_data(
            X=X,
            Ys=measurements,
            labels=labels,
            batch_indices=batch_indices,
            gene_names=gene_names,
            cell_types=cell_types,
            cell_attributes_dict=cell_attributes,
            gene_attributes_dict=gene_attributes,
        )
        self.filter_cells_by_count()

        del self.cell_measurements_col_mappings_temp
Esempio n. 6
0
    def __init__(
        self,
        batch_size: int = 200,
        nb_genes: int = 100,
        n_proteins: int = 100,
        n_batches: int = 2,
        n_labels: int = 3,
    ):
        """Generate a random synthetic dataset with a paired protein measurement.

        Gene counts are generated per batch as negative-binomial(5, 0.3) draws.
        Each draw is multiplied by a Bernoulli(p=0.7) mask, so roughly 30% of
        entries are zeroed.  Labels are drawn uniformly from
        ``[0, n_labels)``.  After population, unused cell types are dropped.
        ``n_proteins`` negative-binomial(5, 0.3) columns are then attached as
        the ``protein_expression`` cell measurement, with column names in
        ``protein_names``.

        :param batch_size: number of cells generated for each batch.
        :param nb_genes: number of genes.
        :param n_proteins: number of protein columns.
        :param n_batches: number of batches.
        :param n_labels: number of distinct labels / cell types.
        """
        super().__init__()
        # Generating samples according to a ZINB process: NB counts times a
        # Bernoulli keep-mask. The batch index is the leading axis.
        per_batch_shape = (n_batches, batch_size, nb_genes)
        data = np.random.negative_binomial(5, 0.3, size=per_batch_shape)
        mask = np.random.binomial(n=1, p=0.7, size=per_batch_shape)
        data = data * mask
        labels = np.random.randint(0, n_labels, size=(n_batches, batch_size, 1))
        cell_types = ["undefined_%d" % i for i in range(n_labels)]

        self.populate_from_per_batch_list(
            data,
            labels_per_batch=labels,
            # np.str was only an alias of the builtin str and was removed in
            # NumPy 1.24; the builtin gives the same result on all versions.
            gene_names=np.arange(nb_genes).astype(str),
            cell_types=cell_types,
        )
        # clear potentially unused cell_types
        self.remap_categorical_attributes()

        # Protein measurements: one row per cell kept after population.
        p_data = np.random.negative_binomial(5,
                                             0.3,
                                             size=(self.X.shape[0],
                                                   n_proteins))
        protein_data = CellMeasurement(
            name="protein_expression",
            data=p_data,
            columns_attr_name="protein_names",
            columns=np.arange(p_data.shape[1]),
        )
        self.initialize_cell_measurement(protein_data)
Esempio n. 7
0
    def test_populate_from_data_with_measurements(self):
        """A single dataset populated with a "dev" measurement exposes it as-is."""
        expression = np.ones((25, 10)) * 100
        dev_columns = ["gabou", "achille", "pedro", "oclivio"]
        n_dev = len(dev_columns)
        dev_values = np.ones((25, n_dev)) * np.arange(0, n_dev)
        dev = CellMeasurement(
            name="dev",
            data=dev_values,
            columns_attr_name="dev_names",
            columns=dev_columns,
        )

        dataset = GeneExpressionDataset()
        dataset.populate_from_data(expression, Ys=[dev])

        self.assertEqual(dataset.nb_genes, 10)
        self.assertEqual(dataset.nb_cells, 25)

        for attribute in ("dev", "dev_names"):
            self.assertTrue(hasattr(dataset, attribute))

        # Column order is preserved when only one dataset is involved.
        self.assertListEqual(dataset.dev_names.tolist(), dev_columns)
        self.assertListEqual(dataset.dev[0].tolist(), [0, 1, 2, 3])
Esempio n. 8
0
    def populate(self):
        """Load a 10X Genomics count matrix and populate the dataset from it.

        1. If ``self.filenames`` is non-empty and its archive has not been
           extracted yet, extract it into ``self.save_path``.
        2. Locate the files via ``self.find_path_to_data()`` and read the
           matrix, transposed so rows are cells.  It is densified when
           ``self.dense`` is set and converted to CSR otherwise.
        3. Read the features table.  If it has a third (feature type) column,
           split the matrix by type:
           - "Gene Expression" becomes ``X``.
           - Every other type becomes a ``CellMeasurement``.
           - "Antibody Capture" is stored as ``protein_expression`` /
             ``protein_names`` and is always dense.
        4. If a barcodes file exists, store barcodes as a cell attribute.
           Batch indices are derived from the trailing ``-<n>`` suffix,
           shifted to start at 0.
        5. Populate, filter cells by count, and optionally remove the
           extracted directory.

        :raises ValueError: if the features table has feature types but none
            is "Gene Expression".
        """
        logger.info("Preprocessing dataset")

        was_extracted = False
        extracted_dir = None
        if len(self.filenames) > 0:
            file_path = os.path.join(self.save_path, self.filenames[0])
            # Extraction directory = archive path minus its last 7 characters
            # (presumably the ".tar.gz" suffix).
            extracted_dir = file_path[:-7]
            if not os.path.exists(extracted_dir):  # nothing extracted yet
                if tarfile.is_tarfile(file_path):
                    logger.info("Extracting tar file")
                    # The context manager closes the archive even if
                    # extraction fails.
                    # NOTE(review): extractall trusts member paths inside the
                    # archive; only use it on archives from a trusted source.
                    with tarfile.open(file_path, "r:gz") as tar:
                        tar.extractall(path=self.save_path)
                    was_extracted = True

        # get exact path of the extract, for robustness to changes in the 10X storage logic
        path_to_data, suffix = self.find_path_to_data()

        # get filenames, according to 10X storage logic
        measurements_filename = "genes.tsv" if suffix == "" else "features.tsv.gz"
        barcode_filename = "barcodes.tsv" + suffix
        matrix_filename = "matrix.mtx" + suffix

        # Transpose so rows are cells (one per barcode).
        expression_data = sp_io.mmread(os.path.join(path_to_data, matrix_filename)).T
        if self.dense:
            # .toarray() instead of .A: the .A shortcut was removed from
            # SciPy sparse matrices in SciPy 1.14.
            expression_data = expression_data.toarray()
        else:
            expression_data = csr_matrix(expression_data)

        # group measurements by type (e.g gene, protein)
        # in case there are multiple measurements, e.g protein
        # they are indicated in the third column
        gene_expression_data = expression_data
        measurements_info = pd.read_csv(
            os.path.join(path_to_data, measurements_filename), sep="\t", header=None
        )
        Ys = None
        if measurements_info.shape[1] < 3:
            # np.str was removed in NumPy 1.24; it was an alias of str.
            gene_names = measurements_info[self.measurement_names_column].astype(str)
        else:
            gene_names = None
            for measurement_type in np.unique(measurements_info[2]):
                # .values required to work with sparse matrices
                measurement_mask = (measurements_info[2] == measurement_type).values
                measurement_data = expression_data[:, measurement_mask]
                measurement_names = measurements_info[self.measurement_names_column][
                    measurement_mask
                ].astype(str)
                if measurement_type == "Gene Expression":
                    gene_expression_data = measurement_data
                    gene_names = measurement_names
                    continue

                Ys = [] if Ys is None else Ys
                if measurement_type == "Antibody Capture":
                    measurement_type = "protein_expression"
                    columns_attr_name = "protein_names"
                    # protein counts do not have many zeros so always make dense
                    if self.dense is not True:
                        measurement_data = measurement_data.toarray()
                else:
                    measurement_type = measurement_type.lower().replace(" ", "_")
                    columns_attr_name = measurement_type + "_names"
                Ys.append(
                    CellMeasurement(
                        name=measurement_type,
                        data=measurement_data,
                        columns_attr_name=columns_attr_name,
                        columns=measurement_names,
                    )
                )
            if gene_names is None:
                raise ValueError(
                    "When loading measurements, no 'Gene Expression' category was found."
                )

        batch_indices, cell_attributes_dict = None, None
        barcode_path = os.path.join(path_to_data, barcode_filename)
        if os.path.exists(barcode_path):
            barcodes = pd.read_csv(barcode_path, sep="\t", header=None)
            cell_attributes_dict = {
                "barcodes": np.squeeze(np.asarray(barcodes, dtype=str))
            }
            # As of 07/01, 10X barcodes have format "%s-%d" where the digit is a batch index starting at 1
            batch_indices = np.asarray(
                [barcode.split("-")[-1] for barcode in cell_attributes_dict["barcodes"]]
            )
            batch_indices = batch_indices.astype(np.int64) - 1

        logger.info("Finished preprocessing dataset")

        self.populate_from_data(
            X=gene_expression_data,
            batch_indices=batch_indices,
            gene_names=gene_names,
            cell_attributes_dict=cell_attributes_dict,
            Ys=Ys,
        )
        self.filter_cells_by_count()

        # cleanup if required
        if was_extracted and self.remove_extracted_data:
            logger.info("Removing extracted data at {}".format(extracted_dir))
            shutil.rmtree(extracted_dir)
Esempio n. 9
0
    def populate(self):
        """Load paired RNA-seq / ATAC-seq data and populate the dataset.

        RNA-seq counts become ``X`` (names in ``gene_names``).  The filtered
        ATAC-seq counts are attached as the ``atac_expression`` cell
        measurement (column names in ``atac_names``).

        1. Extract every tar archive in ``self.save_path_list`` into
           ``self.save_path``.
        2. Read each file in ``self.save_path_list``.  The entry's key is
           looked up in ``available_suffix`` using the path part after the
           last ``"_"``.
           - ``.mtx`` files are read as transposed count matrices: dense when
             ``self.dense`` is set, CSR otherwise.
           - Other files are read as tab-separated tables.
        3. With exactly two files, each is a table whose columns are barcodes
           and whose index holds feature names.  It is transposed so rows are
           cells.
        4. Reorder both modalities by sorted barcode so rows line up.  A
           warning is logged if the sorted barcodes differ.
        5. Optionally binarize ATAC counts (``self.is_binary``).  Keep only
           peaks whose count is nonzero in between 0.1% and 10% of cells.
        """
        logger.info("Preprocessing dataset")

        for file_path in self.save_path_list:
            if tarfile.is_tarfile(file_path):
                logger.info("Extracting tar file")
                # NOTE(review): extractall trusts member paths inside the
                # archive; only use it on archives from a trusted source.
                with tarfile.open(file_path, "r:gz") as tar:
                    tar.extractall(path=self.save_path)

        two_tables = len(self.save_path_list) == 2
        data_dict = {}
        for file_path in self.save_path_list:
            file_tag = file_path.split("_")[-1]
            key = available_suffix[file_tag]
            if file_tag.split(".")[-1] == "mtx":
                matrix = sp_io.mmread(file_path).T
                if self.dense:
                    # mmread returns a COO matrix for coordinate .mtx files.
                    # COO does not support the row indexing used below, so
                    # densify here.
                    matrix = matrix.toarray() if issparse(matrix) else np.asarray(matrix)
                else:
                    matrix = csr_matrix(matrix)
                data_dict[key] = matrix
            elif two_tables:
                data_dict[key] = pd.read_csv(file_path, sep="\t", header=0, index_col=0)
            else:
                data_dict[key] = pd.read_csv(file_path, sep="\t", header=None)

        if two_tables:
            for modality in ("gene", "atac"):
                table = data_dict[modality + "_expression"]
                data_dict[modality + "_barcodes"] = pd.DataFrame(table.columns.values)
                # .index replaces the private pandas attribute _stat_axis,
                # which is the row index for a DataFrame.
                data_dict[modality + "_names"] = pd.DataFrame(table.index.values)
                data_dict[modality + "_expression"] = np.array(table).T

        def _sorted_order(frame):
            # Stable argsort of the rows of `frame`, compared as Python lists.
            rows = frame.values.tolist()
            return sorted(range(len(rows)), key=rows.__getitem__)

        gene_order = _sorted_order(data_dict["gene_barcodes"])
        data_dict["gene_barcodes"] = data_dict["gene_barcodes"].iloc[gene_order]
        gene_expression = data_dict["gene_expression"][gene_order, :]
        data_dict["gene_expression"] = (
            gene_expression.toarray() if issparse(gene_expression) else gene_expression
        )

        atac_order = _sorted_order(data_dict["atac_barcodes"])
        data_dict["atac_barcodes"] = data_dict["atac_barcodes"].iloc[atac_order]
        data_dict["atac_expression"] = data_dict["atac_expression"][atac_order, :]

        # Both modalities are paired row-by-row below, which is only correct
        # when the sorted barcode lists match.
        if not np.array_equal(data_dict["gene_barcodes"].values,
                              data_dict["atac_barcodes"].values):
            logger.warning(
                "RNA-seq and ATAC-seq barcodes differ after sorting; "
                "cells may be misaligned."
            )

        # filter the atac data
        atac = data_dict["atac_expression"]
        if self.is_binary:
            # Binary distribution: clip every count above 1 down to 1.
            atac[atac > 1] = 1
        n_cells = atac.shape[0]
        # np.asarray(...).ravel() gives a flat 1-D count for dense and sparse
        # input alike (sparse .sum returns a 1 x n matrix).
        cells_with_signal = np.asarray((atac > 0).sum(axis=0)).ravel()
        kept_peaks = np.flatnonzero(
            (cells_with_signal >= 0.001 * n_cells)
            & (cells_with_signal <= 0.1 * n_cells)
        )
        atac = atac[:, kept_peaks]
        if issparse(atac):
            atac = atac.toarray()
        data_dict["atac_expression"] = atac
        data_dict["atac_names"] = data_dict["atac_names"].iloc[kept_peaks]
        logger.debug(
            "ATAC entries > 1: %d, < 0: %d",
            int((atac > 1).sum()),
            int((atac < 0).sum()),
        )

        # RNA-seq as the key: gene expression is X, ATAC-seq is a measurement.
        Ys = [
            CellMeasurement(
                name="atac_expression",
                data=data_dict["atac_expression"],
                columns_attr_name="atac_names",
                columns=data_dict["atac_names"].astype(str),
            )
        ]

        cell_attributes_dict = {
            "barcodes":
            np.squeeze(np.asarray(data_dict["gene_barcodes"], dtype=str))
        }

        logger.info("Finished preprocessing dataset")

        self.populate_from_data(
            X=data_dict["gene_expression"],
            batch_indices=None,
            gene_names=data_dict["gene_names"].astype(str),
            cell_attributes_dict=cell_attributes_dict,
            Ys=Ys,
        )
        self.filter_cells_by_count(datatype=self.datatype)
Esempio n. 10
0
    def populate(self):
        """Read RNA and ADT (protein) CSV files and populate the dataset.

        RNA counts come from ``self.filenames.rna`` and are transposed so rows
        are cells.  Only genes whose name starts with "HUMAN" are kept, and
        each name is shortened to the part after its last "_".  Raw protein
        counts (``self.filenames.adt``) and centered counts
        (``self.filenames.adt_centered``) are attached as the
        ``protein_expression`` and ``protein_expression_clr`` measurements.

        :raises ValueError: if the raw and centered protein tables list
            different protein names.
        """
        logger.info("Preprocessing data")
        self.expression = pd.read_csv(
            os.path.join(self.save_path, self.filenames.rna),
            index_col=0,
            compression="gzip",
        ).T

        # process protein measurements
        adt = pd.read_csv(
            os.path.join(self.save_path, self.filenames.adt),
            index_col=0,
            compression="gzip",
        )
        # np.str / np.bool were aliases of the builtins and were removed in
        # NumPy 1.24; the builtins behave identically.
        protein_names = np.asarray(adt.index).astype(str)
        protein_measurement = CellMeasurement(
            name="protein_expression",
            data=adt.T.values,
            columns_attr_name="protein_names",
            columns=protein_names,
        )
        adt_centered = pd.read_csv(
            os.path.join(self.save_path, self.filenames.adt_centered),
            index_col=0,
            compression="gzip",
        )
        if not np.array_equal(
                np.asarray(adt_centered.index).astype(str), protein_names):
            raise ValueError(
                "Protein names are not the same for raw and centered counts.")
        protein_measurement_centered = CellMeasurement(
            name="protein_expression_clr",
            data=adt_centered.T.values,
            columns_attr_name="protein_names_clr",
            columns=protein_names,
        )

        # keep only human genes (there are also mouse genes)
        gene_names = np.asarray(self.expression.columns, dtype=str)
        human_filter = np.asarray(
            [name.startswith("HUMAN") for name in gene_names], dtype=bool)
        logger.info("Selecting only HUMAN genes ({} / {})".format(
            human_filter.sum(), len(human_filter)))
        X = self.expression.values[:, human_filter]
        # Keep the part after the last "_" (the whole name if there is none).
        gene_names = np.asarray(
            [name.rpartition("_")[2] for name in gene_names[human_filter]],
            dtype=str,
        )

        logger.info("Finish preprocessing data")

        self.populate_from_data(
            X=X,
            gene_names=gene_names,
            Ys=[protein_measurement, protein_measurement_centered],
        )

        self.filter_cells_by_count()
Esempio n. 11
0
    def populate(self):
        """Load paired RNA-seq / ATAC-seq data, align cells and populate.

        RNA-seq counts become ``X`` (names in ``gene_names``).  The filtered
        ATAC-seq counts are attached as the ``atac_expression`` cell
        measurement (column names in ``atac_names``).

        1. Extract every tar archive in ``self.save_path_list`` into
           ``self.save_path``.
        2. Read each file in ``self.save_path_list``.  The entry's key is
           looked up in ``available_suffix`` using the path part after the
           last ``"_"``.
           - ``.mtx`` files are read as transposed count matrices: dense when
             ``self.dense`` is set, CSR otherwise.
           - With exactly two files, the others are tab-separated
             feature-by-barcode tables.
           - Otherwise they are comma-separated tables with a header.
        3. In the multi-file layout, the barcode tables provide:
           - a "sample" column (barcodes),
           - an optional label column ("group", then "replicate", then
             "cell_name", falling back to "sample"),
           - an optional ATAC "source" column.
           The peak table provides a "peak" column.
        4. Keep only barcodes present in both modalities, in
           ``np.intersect1d`` (sorted) order.
        5. Optionally binarize ATAC counts (``self.is_binary``).  Keep only
           peaks whose count is nonzero in between 0.1% and 10% of cells.
        6. Integer-encode labels and populate with ``remap_attributes=False``.
        """
        logger.info("Preprocessing dataset")

        for file_path in self.save_path_list:
            if tarfile.is_tarfile(file_path):
                logger.info("Extracting tar file")
                # NOTE(review): extractall trusts member paths inside the
                # archive; only use it on archives from a trusted source.
                with tarfile.open(file_path, "r:gz") as tar:
                    tar.extractall(path=self.save_path)

        two_tables = len(self.save_path_list) == 2
        data_dict = {}
        for file_path in self.save_path_list:
            file_tag = file_path.split("_")[-1]
            key = available_suffix[file_tag]
            if file_tag.split(".")[-1] == "mtx":
                matrix = sp_io.mmread(file_path).T
                if self.dense:
                    # mmread returns a COO matrix for coordinate .mtx files.
                    # COO does not support the row indexing used below, so
                    # densify here.
                    matrix = matrix.toarray() if issparse(matrix) else np.asarray(matrix)
                else:
                    matrix = csr_matrix(matrix)
                data_dict[key] = matrix
            elif two_tables:
                data_dict[key] = pd.read_csv(file_path, sep="\t", header=0, index_col=0)
            else:
                data_dict[key] = pd.read_csv(file_path, sep=",", header=0)

        label_columns = ("group", "replicate", "cell_name")

        def _label_column(frame):
            # First available annotation column, falling back to "sample".
            for column in label_columns:
                if column in frame.columns:
                    return frame[column]
            return frame["sample"]

        if two_tables:
            for modality in ("gene", "atac"):
                table = data_dict[modality + "_expression"]
                data_dict[modality + "_barcodes"] = pd.DataFrame(table.columns.values)
                # .index replaces the private pandas attribute _stat_axis,
                # which is the row index for a DataFrame.
                data_dict[modality + "_names"] = pd.DataFrame(table.index.values)
                data_dict[modality + "_expression"] = np.array(table).T
        else:
            gene_table = data_dict["gene_barcodes"]
            data_dict["gene_barcodes"] = gene_table["sample"]
            data_dict["gene_label"] = _label_column(gene_table)

            atac_table = data_dict["atac_barcodes"]
            data_dict["atac_barcodes"] = atac_table["sample"]
            data_dict["atac_label"] = _label_column(atac_table)
            if "source" in atac_table.columns:
                data_dict["atac_source"] = atac_table["source"]

            data_dict["atac_names"] = data_dict["atac_names"]["peak"]

        # Keep only cells present in both modalities; the returned index
        # arrays put both sides in the same (sorted barcode) order.
        _, gene_idx, atac_idx = np.intersect1d(
            data_dict["gene_barcodes"].values,
            data_dict["atac_barcodes"].values,
            return_indices=True,
        )

        # Label/source columns only exist in the multi-file layout; the
        # two-table layout has no annotations, so guard each lookup.
        for key in ("gene_barcodes", "gene_label"):
            if key in data_dict:
                data_dict[key] = data_dict[key].iloc[gene_idx]
        gene_expression = data_dict["gene_expression"][gene_idx, :]
        data_dict["gene_expression"] = (
            gene_expression.toarray() if issparse(gene_expression) else gene_expression
        )

        for key in ("atac_barcodes", "atac_label", "atac_source"):
            if key in data_dict:
                data_dict[key] = data_dict[key].iloc[atac_idx]
        data_dict["atac_expression"] = data_dict["atac_expression"][atac_idx, :]

        # filter the atac data
        atac = data_dict["atac_expression"]
        if self.is_binary:
            # Binary distribution: clip every count above 1 down to 1.
            atac[atac > 1] = 1
        n_cells = atac.shape[0]
        # np.asarray(...).ravel() gives a flat 1-D count for dense and sparse
        # input alike (sparse .sum returns a 1 x n matrix).
        cells_with_signal = np.asarray((atac > 0).sum(axis=0)).ravel()
        kept_peaks = np.flatnonzero(
            (cells_with_signal >= 0.001 * n_cells)
            & (cells_with_signal <= 0.1 * n_cells)
        )
        atac = atac[:, kept_peaks]
        if issparse(atac):
            atac = atac.toarray()
        data_dict["atac_expression"] = atac
        # Positional selection works for both a Series and a DataFrame.
        # ``.loc[mask, :]`` raised "Too many indexers" on a Series.
        data_dict["atac_names"] = data_dict["atac_names"].iloc[kept_peaks]
        logger.debug(
            "ATAC entries > 1: %d, < 0: %d",
            int((atac > 1).sum()),
            int((atac < 0).sum()),
        )

        def _first_seen_codes(values):
            # Integer-encode values in order of first appearance (0, 1, ...).
            # Returns the float codes and the number of distinct values.
            mapping = {}
            codes = [mapping.setdefault(value, len(mapping)) for value in values]
            return np.asarray(codes, dtype=float), len(mapping)

        # RNA-seq as the key
        if "atac_label" in data_dict:
            label, n_label_values = _first_seen_codes(
                data_dict["atac_label"].values.tolist())
            if "atac_source" in data_dict:
                source_codes, _ = _first_seen_codes(
                    data_dict["atac_source"].values.tolist())
                # Source codes are offset by the number of distinct labels and
                # then multiplied into the label code.
                # NOTE(review): this product is not injective. Label code 0
                # maps to 0 for every source, and e.g. 1 * 6 == 2 * 3, so
                # distinct (label, source) pairs can collide. It is kept as-is
                # to preserve existing label values; confirm the intended
                # encoding.
                label *= source_codes + n_label_values
        elif "gene_label" in data_dict:
            label, _ = _first_seen_codes(data_dict["gene_label"].values.tolist())
        else:
            # No annotations (two-table layout): every cell gets label 0.
            label = np.zeros(data_dict["gene_expression"].shape[0])
        data_dict["label"] = label

        Ys = [
            CellMeasurement(
                name="atac_expression",
                data=data_dict["atac_expression"],
                columns_attr_name="atac_names",
                columns=data_dict["atac_names"].astype(str),
            )
        ]

        cell_attributes_dict = {
            "barcodes":
            np.squeeze(np.asarray(data_dict["gene_barcodes"], dtype=str))
        }

        logger.info("Finished preprocessing dataset")

        self.populate_from_data(X=data_dict["gene_expression"],
                                batch_indices=None,
                                gene_names=data_dict["gene_names"].astype(str),
                                cell_attributes_dict=cell_attributes_dict,
                                Ys=Ys,
                                labels=label,
                                remap_attributes=False)
        self.filter_cells_by_count(datatype=self.datatype)