# Example #1
# 0
def _main(args):
    """Write a ground-truth CSV of unmodified/modified sites.

    Coverage is parsed from ``args.bed_methyl_files``. Every site whose
    coverage is at least ``args.coverage_threshold`` is labelled:

    * ``False`` when its percent modification is ``<=
      args.pct_mod_thresholds[0]``;
    * ``True`` when it is ``>= args.pct_mod_thresholds[1]``;
    * otherwise the site is not written.

    Each labelled site produces a ``chrom,strand,pos,label`` row. When
    ``args.strand_offset`` is not ``None`` a second row with the same label
    is written at ``pos + args.strand_offset``.
    """
    samp_cov, samp_mod_cov = mh.parse_bed_methyls(
        args.bed_methyl_files, strand_offset=args.strand_offset)
    thresholds = args.pct_mod_thresholds
    with open(args.out_csv, 'w') as gt_fp:
        for ctg_key, ctg_cov in samp_cov.items():
            chrom, strand = ctg_key
            ctg_mod_cov = samp_mod_cov[ctg_key]
            for pos, cov in ctg_cov.items():
                if cov < args.coverage_threshold:
                    continue
                pct_mod = 100 * ctg_mod_cov[pos] / cov
                if pct_mod <= thresholds[0]:
                    label = 'False'
                elif pct_mod >= thresholds[1]:
                    label = 'True'
                else:
                    # sites between the two thresholds are left out
                    continue
                out_positions = [pos]
                if args.strand_offset is not None:
                    out_positions.append(pos + args.strand_offset)
                for out_pos in out_positions:
                    row = (chrom, mh.int_strand_to_str(strand), out_pos,
                           label)
                    gt_fp.write(','.join(map(str, row)) + '\n')
# Example #2
# 0
def write_unsorted_merge(in_fns, out_fp, bar):
    """Aggregate bedMethyl inputs and write one merged record per site.

    Coverage and modified coverage are summed by ``mh.parse_bed_methyls``.
    Contigs are emitted in ``mh.RefName`` order. Within a contig, sites are
    emitted in ``(pos, strand)`` order. ``bar.update()`` is called once per
    written record.

    Args:
        in_fns: bedMethyl file names to merge.
        out_fp: writable text handle receiving the merged records.
        bar: progress bar exposing an ``update()`` method.
    """
    cov, mod_cov = mh.parse_bed_methyls(in_fns)
    ctg_names = set(ctg for ctg, _ in cov)
    for ref_name in sorted(map(mh.RefName, ctg_names)):
        # RefName is only the sort key; keys in cov are plain strings
        chrm = str(ref_name)
        strand_poss = [
            (pos, strand)
            for strand in (1, -1)
            if (chrm, strand) in cov
            for pos in cov[(chrm, strand)]
        ]
        for pos, strand in sorted(strand_poss):
            site_key = (chrm, strand)
            pcov = cov[site_key][pos]
            record = mods.BEDMETHYL_TMPLT.format(
                chrom=chrm,
                pos=pos,
                end=pos + 1,
                strand=mh.int_strand_to_str(strand),
                cov=pcov,
                # score column is capped at 1000
                score=min(int(pcov), 1000),
                perc=np.around(mod_cov[site_key][pos] / pcov * 100, 1),
            )
            out_fp.write(record + "\n")
            bar.update()
# Example #3
# 0
    def write_alignment(map_res):
        """Write one mapping, its summary line and, optionally, its reference.

        ``map_res`` arrives as a plain tuple (presumably so it can cross a
        process boundary) and is rebuilt into a ``MAP_RES`` namedtuple. The
        mapping goes to ``map_fp``, a ``MAP_SUMM`` summary to ``summ_fp``,
        and, when per-read reference output is enabled, the read passes the
        filters and it is the primary mapping (``map_num == 0``), the
        reference sequence goes to ``pr_ref_fp`` as FASTA.
        """
        map_res = MAP_RES(*map_res)
        cigar = map_res.cigar

        # CIGAR op codes: 0=M, 1=I, 2=D, 3=N, 4=S, 5=H, 7='='
        nalign = sum(op_len for op_len, op in cigar if op not in (4, 5))
        nmatch = sum(op_len for op_len, op in cigar if op in (0, 7))
        ndel = sum(op_len for op_len, op in cigar if op in (2, 3))
        nins = sum(op_len for op_len, op in cigar if op == 1)

        full_seq = map_res.q_seq
        bc_len = len(full_seq)
        q_seq = full_seq[map_res.q_st:map_res.q_en]
        out_seq = q_seq if map_res.strand == 1 else mh.revcomp(q_seq)

        map_fp.write(prepare_mapping(
            map_res.read_id,
            out_seq,
            flag=get_map_flag(map_res.strand, map_res.map_num),
            ref_id=map_fp.get_tid(map_res.ctg),
            ref_st=map_res.r_st,
            map_qual=map_res.mapq,
            # stored as (op_len, op); pysam expects (op, op_len)
            cigartuples=[(op, op_len) for op_len, op in cigar],
            tags=[('NM', nalign - nmatch)]))

        aligned_query_len = map_res.q_en - map_res.q_st
        summ_fp.write(MAP_SUMM_TMPLT.format(MAP_SUMM(
            read_id=map_res.read_id,
            pct_identity=100 * nmatch / float(nalign),
            num_align=nalign,
            num_match=nmatch,
            num_del=ndel,
            num_ins=nins,
            read_pct_coverage=aligned_query_len * 100 / float(bc_len),
            chrom=map_res.ctg,
            strand=mh.int_strand_to_str(map_res.strand),
            start=map_res.r_st,
            # reference end excludes inserted bases
            end=map_res.r_st + nalign - nins,
            query_start=map_res.q_st,
            query_end=map_res.q_en,
            map_sig_start=map_res.map_sig_start,
            map_sig_end=map_res.map_sig_end,
            sig_len=map_res.sig_len,
            map_num=map_res.map_num)))

        write_pr_ref = (
            ref_out_info.do_output.pr_refs
            and read_passes_filters(ref_out_info.filt_params, bc_len,
                                    map_res.q_st, map_res.q_en, cigar)
            and map_res.map_num == 0)
        if write_pr_ref:
            pr_ref_fp.write('>{}\n{}\n'.format(map_res.read_id,
                                               map_res.ref_seq))
# Example #4
# 0
def _main(args):
    """Export per-read modified base scores to a tab-delimited text file.

    Scores are read from the per-read mods database in
    ``args.megalodon_results_dir``. Output goes to ``args.out_filename``,
    or to the default per-read mods text file name when that is ``None``.
    The first line is a header of ``mods_db.text_field_names``. Each
    following line holds
    ``uuid, chrm, strand, pos, mod_log_prob, can_log_prob, mod_base``
    for one read / modified base call at one position.
    """
    mods_db = mods.ModsDb(
        mh.get_megalodon_fn(args.megalodon_results_dir, mh.PR_MOD_NAME),
        in_mem_dbid_to_uuid=True,
    )
    out_fn = (
        mh.get_megalodon_fn(args.megalodon_results_dir, mh.PR_MOD_TXT_NAME)
        if args.out_filename is None else args.out_filename)
    rec_tmplt = "\t".join("{}" for _ in mods_db.text_field_names) + "\n"

    def format_read_recs(read_dbid, chrm, str_strand, pos, mod_bs, r_lps):
        """Return the text records for one read's calls at one position."""
        uuid = mods_db.get_uuid(read_dbid)
        # canonical log prob = log(1 - sum(mod probs)); log(0) when the
        # mod probs sum to 1 is expected, so that warning is silenced
        with np.errstate(divide="ignore"):
            can_lp = np.log1p(-np.exp(r_lps).sum())
        return [
            rec_tmplt.format(uuid, chrm, str_strand, pos, r_lp, can_lp, mod_b)
            for mod_b, r_lp in zip(mod_bs, r_lps)
        ]

    # `with` guarantees the output is flushed/closed even on error
    with open(out_fn, "w") as mods_txt_fp:
        mods_txt_fp.write("\t".join(mods_db.text_field_names) + "\n")
        bar = tqdm(
            desc="Processing Per-read Data",
            unit="per-read calls",
            total=mods_db.get_num_uniq_stats(),
            smoothing=0,
            dynamic_ncols=True,
        )
        try:
            for (chrm, strand, pos), pos_lps in mods_db.iter_pos_scores(
                    convert_pos=True):
                bar.update(len(pos_lps))
                if len(pos_lps) == 0:
                    # nothing to write; avoids looking up a None read id
                    continue
                str_strand = mh.int_strand_to_str(strand)
                pos_recs = []
                prev_dbid = None
                mod_bs, r_lps = [], []
                # sorting groups all calls from the same read together
                for read_dbid, mod_dbid, lp in sorted(pos_lps):
                    if prev_dbid is not None and prev_dbid != read_dbid:
                        pos_recs.extend(format_read_recs(
                            prev_dbid, chrm, str_strand, pos, mod_bs, r_lps))
                        mod_bs, r_lps = [], []
                    prev_dbid = read_dbid
                    mod_bs.append(mods_db.get_mod_base(mod_dbid))
                    r_lps.append(lp)
                # flush the final read at this position
                pos_recs.extend(format_read_recs(
                    prev_dbid, chrm, str_strand, pos, mod_bs, r_lps))
                mods_txt_fp.write("".join(pos_recs))
        finally:
            bar.close()
# Example #5
# 0
    def write_alignment(read_id, q_seq, chrm, strand, r_st, q_st, q_en, cigar):
        """Write one alignment to ``map_fp`` and its summary to ``summ_fp``.

        Args:
            read_id: read identifier, used as the SAM query name.
            q_seq: full read sequence; only ``q_seq[q_st:q_en]`` is written.
            chrm: reference name, resolved via ``map_fp.get_tid``.
            strand: ``1`` writes the sequence as-is with flag 0; any other
                value writes the reverse complement with flag 16.
            r_st: reference start, stored as pysam ``reference_start``.
            q_st, q_en: aligned interval of ``q_seq``.
            cigar: sequence of ``(op_len, op)`` pairs (iterated twice). The
                pair order is swapped relative to pysam ``cigartuples``.
        """
        nalign = nmatch = ndel = nins = 0
        # CIGAR op codes: 0=M, 1=I, 2=D, 3=N, 4=S, 5=H, 7='='
        for op_len, op in cigar:
            if op not in (4, 5):
                nalign += op_len
            if op in (0, 7):
                nmatch += op_len
            elif op in (2, 3):
                ndel += op_len
            elif op == 1:
                nins += op_len
        bc_len = len(q_seq)
        q_seq = q_seq[q_st:q_en]

        a = pysam.AlignedSegment()
        a.query_name = read_id
        a.query_sequence = q_seq if strand == 1 else mh.revcomp(q_seq)
        a.flag = 0 if strand == 1 else 16
        a.reference_id = map_fp.get_tid(chrm)
        a.reference_start = r_st
        a.cigartuples = [(op, op_l) for op_l, op in cigar]
        # NOTE(review): SAM TLEN is normally a reference-span (or 0 for
        # unpaired reads); this stores the aligned query length — confirm
        a.template_length = q_en - q_st
        # NM tag: non-clipped bases not counted as matches. Mismatches hidden
        # inside M (op 0) are counted as matches, so this is a lower bound.
        # `AlignedSegment.tags` is deprecated in pysam; use set_tags.
        a.set_tags([('NM', nalign - nmatch)])
        map_fp.write(a)

        # compute alignment stats
        r_map_summ = MAP_SUMM(read_id=read_id,
                              pct_identity=100 * nmatch / float(nalign),
                              num_align=nalign,
                              num_match=nmatch,
                              num_del=ndel,
                              num_ins=nins,
                              read_pct_coverage=(q_en - q_st) * 100 /
                              float(bc_len),
                              chrom=chrm,
                              strand=mh.int_strand_to_str(strand),
                              start=r_st,
                              # reference end excludes inserted bases
                              end=r_st + nalign - nins)
        summ_fp.write(MAP_SUMM_TMPLT.format(r_map_summ))
# Example #6
# 0
def extract_threshs_worker(site_batches_q, thresh_q, low_cov_q, mod_db_fn,
                           gt_cov_min, np_cov_min, target_mod_bases,
                           strand_offset, valid_sites_used):
    """Compute per-site modified base score thresholds for queued batches.

    Batches ``(chrm, pos_range, cov, mod_cov)`` are pulled from
    ``site_batches_q`` until a ``None`` sentinel arrives. For each site in
    the batch range that has at least one read with a call to one of
    ``target_mod_bases``:

    * if ground truth coverage is below ``gt_cov_min`` or nanopore coverage
      is below ``np_cov_min``, a BED record noting the coverage goes to the
      low coverage output;
    * otherwise a BED record scored with the computed threshold goes to
      the threshold output.

    After each batch, the joined threshold records are put on ``thresh_q``
    and the joined low coverage records on ``low_cov_q``.
    """
    mods_db = mods.ModsDb(mod_db_fn)
    target_mod_dbids = set(
        mods_db.get_mod_base_dbid(tmb) for tmb in target_mod_bases)

    def lookup_gt_cov(batch_cov, batch_mod_cov, strand, pos):
        """Return ground truth ``(cov, mod_cov)``; KeyError if site absent."""
        if strand_offset is None:
            gt_key = (pos, strand)
        else:
            # strands are combined: minus strand sites are shifted back by
            # strand_offset and looked up with a None strand
            gt_key = (pos - strand_offset if strand == -1 else pos, None)
        return batch_cov[gt_key], batch_mod_cov[gt_key]

    def compute_thresh(pos_cov, pos_mod_cov, pos_mod_data):
        """Return the formatted score threshold for one site."""
        # TODO add new method here to determine the most likely threshold
        # taking the combination of the ground truth and calibrated modified
        # base scores (estimated from a sample of this data)
        if pos_mod_cov == 0:
            return '{:.4f}'.format(-MAX_MOD_SCORE)
        if pos_mod_cov == pos_cov:
            return '{:.4f}'.format(MAX_MOD_SCORE)

        # fractional case: gather each read's log probs keyed by mod base
        read_mod_lps = defaultdict(dict)
        for read_dbid, mod_dbid, lp in pos_mod_data:
            read_mod_lps[read_dbid][mod_dbid] = lp
        llrs = []
        with np.errstate(divide='ignore'):
            for mod_lps in read_mod_lps.values():
                can_lp = np.log1p(-np.exp(list(mod_lps.values())).sum())
                target_lps = [lp for dbid, lp in mod_lps.items()
                              if dbid in target_mod_dbids]
                if target_lps:
                    # max target mod log prob mirrors mods.annotate_all_mods
                    llrs.append(can_lp - max(target_lps))
        return '{:.4f}'.format(
            np.percentile(llrs, 100.0 * pos_mod_cov / pos_cov))

    while True:
        try:
            batch = site_batches_q.get(block=True, timeout=0.1)
        except queue.Empty:
            continue
        if batch is None:
            break

        batch_chrm, pos_range, cov, mod_cov = batch
        thresh_recs, low_cov_recs = [], []
        site_iter = mods_db.iter_pos_scores(
            convert_pos=True, pos_range=(batch_chrm, *pos_range))
        for (chrm, strand, pos), pos_mod_data in site_iter:
            # number of distinct reads with a call to a target mod base
            target_mod_cov = len({
                read_id for read_id, mod_dbid, _ in pos_mod_data
                if mod_dbid in target_mod_dbids})
            if not target_mod_cov:
                continue
            str_strand = mh.int_strand_to_str(strand)
            try:
                pos_cov, pos_mod_cov = lookup_gt_cov(cov, mod_cov, strand,
                                                     pos)
            except KeyError:
                # with valid sites, every valid site is present in the batch
                # dicts (possibly with 0 coverage), so a miss means invalid
                if valid_sites_used:
                    continue
                pos_cov = pos_mod_cov = 0

            if pos_cov < gt_cov_min:
                out_recs = low_cov_recs
                score = 'GT_COV:{}'.format(pos_cov)
            elif target_mod_cov < np_cov_min:
                out_recs = low_cov_recs
                score = 'NP_COV:{}'.format(target_mod_cov)
            else:
                out_recs = thresh_recs
                score = compute_thresh(pos_cov, pos_mod_cov, pos_mod_data)
            out_recs.append(BED_TMPLT.format(chrom=chrm,
                                             pos=pos,
                                             end=pos + 1,
                                             strand=str_strand,
                                             score=score))

        thresh_q.put(''.join(thresh_recs))
        low_cov_q.put(''.join(low_cov_recs))