Beispiel #1
0
def _predict_sample(
        region_ds_path,
        region_dim,
        modalities,
        fillna_by_zero,
        sample,
        output_path,
        mask_cutoff=0.3,
        chunk_size=100000,
):
    """Predict per-region enhancer probability for one sample and save it as zarr.

    Regions are processed in chunks of ``chunk_size``. Regions whose features are
    NaN for this sample are not passed to the model and get probability 0.
    Probabilities below ``mask_cutoff`` are set to 0.

    Parameters
    ----------
    region_ds_path
        Path of the RegionDS. The trained model is loaded from
        ``{region_ds_path}/model/train-region_model.lib`` (for "query-region") or
        ``{region_ds_path}/model/train-dmr_model.lib`` (for "query-dmr").
    region_dim
        "query-region" or "query-dmr".
    modalities, fillna_by_zero, sample
        Passed to ``_get_data_and_label`` to build the feature table.
    output_path
        Zarr location, written with ``mode="w"``.
    mask_cutoff
        Probabilities below this value are set to 0.
    chunk_size
        Number of regions predicted per chunk.

    Returns
    -------
    output_path

    Raises
    ------
    ValueError
        If ``region_dim`` is not "query-region" or "query-dmr".
    """
    model_file_names = {
        "query-region": "train-region_model.lib",
        "query-dmr": "train-dmr_model.lib",
    }
    if region_dim not in model_file_names:
        raise ValueError(
            f'Only accept ["query-region", "query-dmr"], got {region_dim}'
        )

    # NOTE(review): the synchronous dask scheduler presumably avoids nested
    # parallelism when this function runs inside a worker process — confirm.
    with dask.config.set(scheduler="sync"):
        model = joblib.load(f"{region_ds_path}/model/{model_file_names[region_dim]}")

        # for query, NaN is not filtered up front; NaN regions are dropped per sample
        # before prediction and receive probability 0 in the final table
        region_ds = RegionDS.open(region_ds_path, region_dim=region_dim)
        region_ids = region_ds.get_index(region_dim)

        chunk_probas = []
        for chunk_start in range(0, region_ids.size, chunk_size):
            use_regions = region_ids[chunk_start: chunk_start + chunk_size]
            _region_ds = region_ds.sel({region_dim: use_regions})
            data, _ = _get_data_and_label(
                region_ds=_region_ds,
                modalities=modalities,
                sample=sample,
                fillna_by_zero_list=fillna_by_zero,
            )
            # keep the full index so dropped (NaN) regions can be restored below
            total_index = data.index.copy()

            # sample specific NaN drop
            data = data.dropna()

            if data.empty:
                # every region of this chunk is NaN for this sample; sklearn-style
                # predict_proba rejects 0-sample input, so assign 0 directly
                enhancer_proba = pd.Series(0.0, index=total_index)
            else:
                proba = model.predict_proba(data.astype(np.float64))
                # column 1 of predict_proba is used as the enhancer probability;
                # regions dropped as NaN are restored with probability 0
                enhancer_proba = (
                    pd.Series(proba[:, 1], index=data.index)
                    .reindex(total_index)
                    .fillna(0)
                )
            chunk_probas.append(enhancer_proba)

        if chunk_probas:
            sample_proba = pd.concat(chunk_probas).astype(np.float16)
        else:
            # no query regions at all; pd.concat([]) would raise
            sample_proba = pd.Series([], index=region_ids, dtype=np.float16)

        total_proba = pd.DataFrame({sample: sample_proba})
        # mask small values
        total_proba[total_proba < mask_cutoff] = 0
        total_proba.index.name = region_ds.region_dim
        total_proba.columns.name = "sample"

        total_proba = xr.Dataset({f"{region_dim}_prediction": total_proba})
        RegionDS(total_proba).to_zarr(output_path, mode="w")
    return output_path
Beispiel #2
0
def _create_query_dmr_ds(reptile, dmr_regions_bed_df):
    """Save ``dmr_regions_bed_df`` as the "query-dmr" RegionDS.

    The dataset is built with ``RegionDS.from_bed`` using the chromosome sizes
    and output location from ``reptile``, then saved. Returns None.
    """
    from_bed_options = dict(
        chrom_size_path=reptile.chrom_size_path,
        location=reptile.output_path,
        region_dim="query-dmr",
    )
    RegionDS.from_bed(dmr_regions_bed_df, **from_bed_options).save()
    return None
Beispiel #3
0
def _create_query_region_ds(reptile):
    """Tile the genome into windows and save them as the "query-region" RegionDS.

    Windows of ``reptile.window_size`` bp every ``reptile.step_size`` bp are made
    over the chromosomes in ``reptile.chrom_size_path``. They are written to a
    temporary ``query_region.bed`` in ``reptile.output_path``, loaded with
    ``RegionDS.from_bed``, and the temporary file is removed before saving
    (also if loading fails). Returns None.
    """
    import os

    bed_path = f"{reptile.output_path}/query_region.bed"
    pybedtools.BedTool().makewindows(
        g=reptile.chrom_size_path, s=reptile.step_size, w=reptile.window_size
    ).saveas(bed_path)

    try:
        query_region_ds = RegionDS.from_bed(
            bed_path,
            chrom_size_path=reptile.chrom_size_path,
            location=reptile.output_path,
            region_dim="query-region",
        )
    finally:
        # remove the temporary BED without a shell: an "rm -f" command string
        # would split paths containing spaces into separate arguments
        try:
            os.remove(bed_path)
        except FileNotFoundError:
            pass

    query_region_ds.save()
Beispiel #4
0
def _create_train_dmr_ds(reptile, train_regions_bed, train_label):
    """Build and save the "train-dmr" RegionDS from DMRs overlapping train regions.

    Each DMR overlapping a train region inherits that region's label. DMRs that
    overlap train regions with different labels are dropped.

    Parameters
    ----------
    reptile
        Object providing ``dmr_regions``, ``chrom_size_path`` and ``output_path``.
    train_regions_bed
        pybedtools.BedTool of train regions; the last-but-one column of its
        ``map`` output is used as the train-region name.
    train_label
        Mapping from train-region name to label (indexed with ``[]``).

    Returns
    -------
    pandas.DataFrame
        All DMRs from ``reptile.dmr_regions`` sorted with
        ``g=reptile.chrom_size_path``, as returned by ``BedTool.to_dataframe()``.
    """
    # total DMRs
    dmr_regions_bed = pybedtools.BedTool(reptile.dmr_regions).sort(
        g=reptile.chrom_size_path
    )
    dmr_regions_bed_df = dmr_regions_bed.to_dataframe()

    # for every train region, the names (column 4) of overlapping DMRs are
    # collapsed into one comma-separated field; "." means no overlap
    train_dmr = train_regions_bed.map(dmr_regions_bed, c=4, o="collapse").to_dataframe()

    dmr_label = defaultdict(list)
    for _, row in train_dmr.iterrows():
        *_, train_region, dmrs = row
        # to_dataframe() may parse a column of single numeric names as numbers
        dmrs = str(dmrs)
        if dmrs == ".":
            continue
        for dmr in dmrs.split(","):
            dmr_label[dmr].append(train_label[train_region])

    # some DMRs overlap several train regions; keep only those whose labels agree
    consistent_dmr_label = {
        dmr: labels[0] for dmr, labels in dmr_label.items() if len(set(labels)) == 1
    }
    dmr_label = pd.Series(consistent_dmr_label)
    dmr_label.index.name = "train-dmr"

    # label keys are str (split from text) while to_dataframe() may have parsed
    # DMR names as numbers, so look them up by str name; columns are reordered
    # to chrom, start, end, name
    train_dmr_regions_bed_df = (
        dmr_regions_bed_df.assign(name=dmr_regions_bed_df["name"].astype(str))
        .set_index("name")
        .loc[dmr_label.index]
        .reset_index()
        .iloc[:, [1, 2, 3, 0]]
    )

    # train DMR RegionDS
    train_dmr_ds = RegionDS.from_bed(
        train_dmr_regions_bed_df,
        chrom_size_path=reptile.chrom_size_path,
        location=reptile.output_path,
        region_dim="train-dmr",
    )

    train_dmr_ds.coords["train-dmr_label"] = dmr_label
    train_dmr_ds.save()
    return dmr_regions_bed_df
Beispiel #5
0
def _create_train_region_ds(reptile):
    """Build and save the "train-region" RegionDS together with its labels.

    Parameters
    ----------
    reptile
        Object providing ``train_regions`` (BED path), ``train_region_labels``
        (tab-separated path), ``chrom_size_path`` and ``output_path``.

    Returns
    -------
    tuple
        ``(train_regions_bed, train_label)``: the train regions as a
        pybedtools.BedTool sorted with ``g=reptile.chrom_size_path``, and the
        labels read from ``reptile.train_region_labels`` with index name
        "train-region".
    """
    # train regions
    train_regions_bed = pybedtools.BedTool(reptile.train_regions).sort(
        g=reptile.chrom_size_path
    )

    # first column is the index (region name); a single remaining column becomes
    # a Series. read_csv(squeeze=True) was removed in pandas 2.0 —
    # DataFrame.squeeze("columns") is its replacement with the same result.
    train_label = pd.read_csv(
        reptile.train_region_labels, sep="\t", index_col=0
    ).squeeze("columns")
    train_label.index.name = "train-region"
    train_regions_bed_df = train_regions_bed.to_dataframe()

    # train RegionDS
    train_region_ds = RegionDS.from_bed(
        train_regions_bed_df,
        chrom_size_path=reptile.chrom_size_path,
        location=reptile.output_path,
        region_dim="train-region",
    )

    # NOTE(review): presumably aligned to the region dimension via the index
    # name "train-region" — confirm.
    train_region_ds.coords["train-region_label"] = train_label
    train_region_ds.save()
    return train_regions_bed, train_label
Beispiel #6
0
    def _dump_sample(self, sample, mask_cutoff, bw_bin_size):
        """Export one sample's predictions as a binned bigWig track.

        DMR and region prediction scores strictly above ``mask_cutoff`` are written
        to two temporary bedGraph files in ``self.bigwig_dir``. The files are merged
        with ``bedtools unionbedg``, and each merged interval takes the max score of
        the two tracks. Scores are averaged into ``bw_bin_size`` bins and written to
        ``{self.bigwig_dir}/{sample}_reptile_score.bw``. The temporary bedGraph
        files are removed afterwards, also when an error occurs.

        Parameters
        ----------
        sample
            Sample name, selected along the "sample" dimension.
        mask_cutoff
            Only scores greater than this value are exported.
        bw_bin_size
            Bin width (bp) of the bigWig track.

        Returns
        -------
        str
            Path of the written bigWig file.

        Raises
        ------
        subprocess.CalledProcessError
            If ``bedtools unionbedg`` fails.
        RuntimeError
            If pyBigWig rejects an entry.
        """
        import os

        dmr_bg_path = f"{self.bigwig_dir}/{sample}_dmr_pred.bg"
        region_bg_path = f"{self.bigwig_dir}/{sample}_region_pred.bg"
        bw_path = f"{self.bigwig_dir}/{sample}_reptile_score.bw"

        # NOTE(review): the synchronous dask scheduler presumably avoids nested
        # parallelism when this runs inside a worker process — confirm.
        with dask.config.set(scheduler="sync"):
            try:
                self._write_prediction_bedgraph(
                    sample, "query-dmr", mask_cutoff, dmr_bg_path
                )
                self._write_prediction_bedgraph(
                    sample, "query-region", mask_cutoff, region_bg_path
                )

                # chrom names are read as str so numeric names ("1", "2") sort and
                # match the same way as the chrom column of the bedGraph files.
                # read_csv(squeeze=True) was removed in pandas 2.0.
                chrom_sizes = (
                    pd.read_csv(
                        self.chrom_size_path,
                        sep="\t",
                        index_col=0,
                        header=None,
                        dtype={0: str},
                    )
                    .squeeze("columns")
                    .sort_index()
                )

                # argument list instead of a shell string keeps paths with spaces intact
                union_bedgraph = subprocess.run(
                    ["bedtools", "unionbedg", "-i", dmr_bg_path, region_bg_path],
                    check=True,
                    stdout=subprocess.PIPE,
                    encoding="utf8",
                ).stdout

                with pyBigWig.open(bw_path, "w") as bw:
                    # header and bedGraph are both sorted lexicographically by chrom,
                    # so entries are added in header chromosome order
                    bw.addHeader(
                        [(chrom, int(size)) for chrom, size in chrom_sizes.items()]
                    )
                    for chrom, bin_start, mean_score in self._iter_binned_scores(
                        union_bedgraph, bw_bin_size
                    ):
                        # the last bin of a chromosome must not run past its end
                        span = int(min(bw_bin_size, chrom_sizes[chrom] - bin_start))
                        try:
                            bw.addEntries(
                                chrom, [bin_start], values=[mean_score], span=span
                            )
                        except RuntimeError as e:
                            raise RuntimeError(
                                f"Failed to add bigWig entry {chrom}:{bin_start} "
                                f"(score={mean_score}, span={span})"
                            ) from e
            finally:
                for path in (dmr_bg_path, region_bg_path):
                    try:
                        os.remove(path)
                    except FileNotFoundError:
                        pass
        return bw_path

    def _write_prediction_bedgraph(self, sample, region_dim, mask_cutoff, path):
        """Write ``sample``'s ``{region_dim}_prediction`` scores > ``mask_cutoff`` as a sorted bedGraph."""
        pred_ds = RegionDS.open(self.output_path, region_dim=region_dim)
        bed_df = pred_ds.get_bed(with_id=False)
        bed_df["score"] = pred_ds.get_feature(
            sample, dim="sample", da_name=f"{region_dim}_prediction"
        )
        bed_df = bed_df[bed_df["score"] > mask_cutoff].sort_values(["chrom", "start"])
        bed_df.to_csv(path, sep="\t", index=False, header=None)

    @staticmethod
    def _iter_binned_scores(union_bedgraph, bin_size):
        """Yield ``(chrom, bin_start, mean_score)`` for every bin covered by an interval.

        ``union_bedgraph`` is ``bedtools unionbedg`` text output sorted by chrom and
        start. An interval's score is the max of its score columns; lines ending in
        "." are skipped. A bin's value is the unweighted mean of the scores of all
        intervals overlapping it. Uncovered bins are not yielded, and bins never
        carry over from one chromosome to the next.
        """
        cur_key = None  # (chrom, bin_id) of the bin being accumulated
        cur_scores = []
        for line in union_bedgraph.split("\n"):
            if line == "" or line[-1] == ".":
                # empty line, or no score
                continue
            chrom, start, end, *scores = line.split("\t")
            score = max(map(float, scores))
            first_bin = int(start) // bin_size
            # bedGraph ends are exclusive: the last covered base is end - 1
            last_bin = (int(end) - 1) // bin_size
            for bin_id in range(first_bin, last_bin + 1):
                key = (chrom, bin_id)
                if key == cur_key:
                    # the same bin, take average
                    cur_scores.append(score)
                    continue
                if cur_key is not None:
                    yield (
                        cur_key[0],
                        cur_key[1] * bin_size,
                        sum(cur_scores) / len(cur_scores),
                    )
                cur_key = key
                cur_scores = [score]
        if cur_key is not None:
            yield cur_key[0], cur_key[1] * bin_size, sum(cur_scores) / len(cur_scores)
Beispiel #7
0
 def query_dmr_ds(self):
     """Return the "query-dmr" RegionDS, opening it from ``self.output_path`` on first use and caching it."""
     if self._query_dmr_ds is not None:
         return self._query_dmr_ds
     self._query_dmr_ds = RegionDS.open(self.output_path, region_dim="query-dmr")
     return self._query_dmr_ds
Beispiel #8
0
 def query_region_ds(self):
     """Return the "query-region" RegionDS (zarr engine), opened once and cached on the instance."""
     cached = self._query_region_ds
     if cached is None:
         cached = RegionDS.open(
             self.output_path, region_dim="query-region", engine="zarr"
         )
         self._query_region_ds = cached
     return cached
Beispiel #9
0
 def train_dmr_ds(self):
     """Return the "train-dmr" RegionDS (zarr engine), loading it lazily on first access."""
     if self._train_dmr_ds is None:
         open_kwargs = dict(region_dim="train-dmr", engine="zarr")
         self._train_dmr_ds = RegionDS.open(self.output_path, **open_kwargs)
     return self._train_dmr_ds