Example #1
0
def mod_align(indirs,
              ref,
              outdir,
              threads,
              recursive=False,
              bamdir="minimap2"):
    """Align *.fq.gz files of every input directory and store BAMs in outdir/bamdir.

    One BAM is produced per input directory; directories whose BAM file
    already exists are skipped. With recursive=True, FastQ files are also
    collected from subdirectories.
    """
    logger("Aligning FastQ files from %s directories...\n" % len(indirs))
    # make sure the directory for BAM files exists
    bam_root = os.path.join(outdir, bamdir)
    if not os.path.isdir(bam_root):
        os.makedirs(bam_root)
    # assert same setting has been used between different runs
    compare_info(indirs)
    for indir in indirs:
        # BAM is named after the second-to-last "/"-separated element of indir
        bam_path = os.path.join(bam_root, indir.split("/")[-2] + ".bam")
        if os.path.isfile(bam_path):
            logger(" %s already present" % bam_path)
            continue
        logger("  > %s\n" % bam_path)
        # descend into subdirectories only if requested
        find = Path(indir).rglob if recursive else Path(indir).glob
        fastqs = sorted(str(path) for path in find('*.fq.gz'))
        run_minimap2(ref, fastqs, bam_path, threads, spliced=is_rna(indir))
        # update dump info
        dump_updated_info(indir, outdir, bam_path)
Example #2
0
def plot_venn(outfn, beds, names=[], title=""):
    """Plot a Venn diagram for 2-6 BED files.

    The plot is saved to outfn, or shown interactively if outfn is empty.
    File names are used as set names unless exactly one name per BED file
    is given. The process exits with status 1 for any other number of files.
    """
    import venn
    nsets = len(beds)
    if not 2 <= nsets <= 6:
        logger("[ERROR] Please provide between 2 and 6 BED files\n")
        sys.exit(1)
    # venn provides one plotting function per number of sets: venn2 ... venn6
    func = getattr(venn, "venn%s" % nsets)
    # use fnames as names if names not given
    if len(names) != nsets:
        names = beds
    # load positions & plot
    positions = [get_positions_from_bed(bed) for bed in beds]
    labels = venn.get_labels(positions)
    fig, ax = func(labels, names=names)
    if title:
        plt.title(title)
    # save or visualise plot
    if outfn:
        plt.savefig(outfn)
    else:
        plt.show()
Example #3
0
def main():
    """Parse command-line options and run mod_correlation with them."""
    import argparse

    parser = argparse.ArgumentParser(
        description=desc,
        epilog=epilog,
        formatter_class=argparse.RawTextHelpFormatter)
    add = parser.add_argument
    add('--version', action='version', version=VERSION)
    add("-v", "--verbose", action="store_true", help="verbose")
    add("-i", "--input", default="modPhred/mod.gz",
        help="input file [%(default)s]")
    add("-m", "--mapq", default=15, type=int,
        help="min mapping quality [%(default)s]")
    add("-d", "--minDepth", default=25, type=int,
        help="min depth of coverage [%(default)s]")
    add("--minModFreq", default=0.20, type=float,
        help="min modification frequency per position [%(default)s]")
    add("--minModProb", default=0.50, type=float,
        help="min modification probability per base [%(default)s]")
    # strand selection ("-s/--strand") is currently not exposed
    add("--mod", default="", help="filter only 1 modification [analyse all]")
    add("-r", "--regions", nargs="+", default=[],
        help="regions to process [all chromosomes]")
    add("-w", "--overwrite", action="store_true",
        help="overwrite existing output")
    add("-e", "--ext", default="svg",
        help="figure format/extension [%(default)s]")

    opts = parser.parse_args()
    if opts.verbose:
        sys.stderr.write("Options: %s\n" % str(opts))

    if not opts.regions:
        logger(
            "Processing entire chromosomes - consider narrowing to certain regions!"
        )

    mod_correlation(opts.input,
                    ext=opts.ext,
                    mapq=opts.mapq,
                    overwrite=opts.overwrite,
                    regions=opts.regions,
                    mod=opts.mod,
                    minfreq=opts.minModFreq,
                    mindepth=opts.minDepth,
                    minModProb=opts.minModProb)
    logger("Finished\n")
Example #4
0
def plot_heatmap(corrs,
                 chrdata,
                 ref,
                 outfn,
                 figsize=(12, 10),
                 dim=100000,
                 ext="svg",
                 simpleY=True):
    """Plot a heatmap of correlations topped by a bar plot of modification frequency.

    The figure is saved as ``outfn + "." + ext`` and closed afterwards, so
    repeated calls (one per region) do not accumulate open figures.
    The caller's data frame is not modified.

    Args:
        corrs: matrix of correlations; it is drawn on a fixed -1..1 colour scale.
        chrdata: data frame with pos, mod and strand columns and with columns
            whose names end with "depth" and "mod_frequency".
        ref: reference/region name used in the title and in the log message.
        outfn: output path without the figure extension.
        figsize: figure size passed to matplotlib.
        dim: max number of leading rows of chrdata used for axes labels.
        ext: figure format/extension.
        simpleY: if True, a Y label is shown only where it differs from the
            label of the preceding row.
    """
    # axes labels: positions on X and "mod strand" on Y
    xlab = chrdata.pos[:dim].to_numpy()
    ylab = [
        "%s %s" % (m, s)
        for m, s in zip(chrdata["mod"][:dim], chrdata.strand[:dim])
    ]
    mod2count = Counter(chrdata["mod"])
    # collapse_axes is defined elsewhere; pos is later used to select
    # the rows of chrdata that are shown in the bar plot
    xlab, ylab, pos = collapse_axes(xlab, ylab)
    # use unique names on Y
    if simpleY:
        _ylab = [
            cur if cur != prev else ""
            for prev, cur in zip(ylab[:-1], ylab[1:])
        ]
        _ylab.insert(0, ylab[0])
        ylab = _ylab
    logger(" Plotting %s modified positions in %s:%s-%s" %
           (len(chrdata), ref, xlab[0], xlab[-1]))
    fig = plt.figure(figsize=figsize)
    try:
        # add title
        mcounts = "; ".join("%s: %s" % (m, c) for m, c in mod2count.items())
        fig.suptitle("\n%s\n%s modifications: %s" %
                     (ref, len(chrdata), mcounts))
        fig.subplots_adjust(top=0.75)
        with sns.axes_style("white"):
            ax = sns.heatmap(corrs,
                             vmin=-1,
                             vmax=1,
                             center=0,
                             cmap="RdBu_r",
                             xticklabels=xlab,
                             yticklabels=ylab)
        ax.set_xlabel("Modified positions")
        ax.set_ylabel("Modifications at those positions [with +/- strand]")
        # calculate global frequency: modified reads over all samples
        # divided by depth over all samples
        depthcols = [c for c in chrdata.columns if c.endswith("depth")]
        mfreqcols = [c for c in chrdata.columns if c.endswith("mod_frequency")]
        mcountcols = ["%s modcount" % c.split()[0] for c in depthcols]
        # work on a copy, so helper columns are not written into caller's
        # frame (which is typically a slice of a larger frame)
        chrdata = chrdata.copy()
        for countcol, depthcol, freqcol in zip(mcountcols, depthcols,
                                               mfreqcols):
            chrdata[countcol] = chrdata[depthcol] * chrdata[freqcol]
        chrdata["avgfreq"] = chrdata[mcountcols].sum(
            axis=1) / chrdata[depthcols].sum(axis=1)
        # bar plot above the heatmap sharing its X axis
        ax2 = fig.add_axes([.125, 0.77, .62, .10], anchor="N", sharex=ax)
        ax2.bar(
            np.arange(len(corrs)) + 0.5, chrdata["avgfreq"].to_numpy()[pos])
        ax2.set_ylim((0, 1))
        ax2.set_ylabel("mod\nfreq")
        ax2.xaxis.tick_top()
        ax2.set_xticklabels(xlab, rotation=90)
        fig.savefig(outfn + ".%s" % ext)
    finally:
        # release the figure even if plotting or saving failed
        plt.close(fig)
Example #5
0
def mod_correlation(infn,
                    ext="png",
                    logger=logger,
                    data=False,
                    overwrite=False,
                    regions=[],
                    samples=[],
                    minfreq=0.20,
                    mindepth=10,
                    minModProb=0.5,
                    mapq=15,
                    strand=None,
                    mod=None):
    """Compute and plot correlations between modified positions.

    Results are written to the "correlations" directory next to infn:
    one .csv.gz matrix and one heatmap per region/chromosome. An existing
    matrix is reused unless overwrite is set.

    Args:
        infn: tab-delimited file with modified positions.
        ext: figure format/extension.
        logger: function used for reporting progress.
        data: already loaded data frame; infn is read only if this is a bool.
        overwrite: recompute correlations even if the matrix file exists.
        regions: regions to process; all chromosomes in data if empty.
        samples: accepted for interface compatibility, not used here.
        minfreq: positions need max modification frequency above this value.
        mindepth: positions need max depth above this value.
        minModProb: min modification probability passed to chr2modcorr.
        mapq: min mapping quality passed to chr2modcorr.
        strand: accepted for interface compatibility, not used here.
        mod: if given, only positions with this modification are analysed.
    """
    indir = os.path.dirname(infn)
    outdir = os.path.join(indir, "correlations")
    os.makedirs(outdir, exist_ok=True)
    # load info
    moddata = load_info(indir)
    MaxPhredProb = moddata["MaxPhredProb"]
    bamfiles = sorted(moddata["bam"])

    # parse data
    if isinstance(data, bool):
        logger("Loading %s ...\n" % infn)
        data = pd.read_csv(infn,
                           sep="\t",
                           header=len(HEADER.split('\n')) - 2,
                           index_col=False,
                           dtype={
                               "chr": object,
                               "pos": int
                           })
    # filter by min freq and depth
    mfreqcols = [c for c in data.columns if c.endswith('mod_frequency')]
    depthcols = [c for c in data.columns if c.endswith('depth')]
    filters = [
        data.loc[:, mfreqcols].max(axis=1) > minfreq,
        data.loc[:, depthcols].max(axis=1) > mindepth
    ]
    # add filter for modification
    if mod:
        filters.append(data["mod"] == mod)
    data = data[np.all(filters, axis=0)]
    # nothing to do if no position passed the filters
    if data.shape[0] < 1:
        logger("[mod_plot][ERROR]  %s row(s) found in %s\n" %
               (data.shape[0], infn))
        return
    # limit by region
    if regions:
        regionsData = get_data_for_regions(data, regions)
        what = "region(s)"
    else:
        # get chromosomes
        regions = data.chr.unique()
        regionsData = (data[data.chr == ref] for ref in regions)
        what = "chromosome(s)"
    # report names of the first 3 regions (not the first 3 characters)
    logger("Processing %s %s: %s ...\n" %
           (len(regions), what, ",".join(map(str, list(regions)[:3]))))
    # process regions/chromosomes
    for ref, chrdata in zip(regions, regionsData):
        # define output
        fn = "%s.%s.csv.gz" % (ref, mod) if mod else "%s.csv.gz" % ref
        outfn = os.path.join(outdir, fn)
        if overwrite or not os.path.isfile(outfn):
            # generate data
            corrs = chr2modcorr(outfn, bamfiles, ref, chrdata, mapq, mindepth,
                                minModProb, MaxPhredProb)
        else:
            # load data
            corrs = np.loadtxt(outfn, delimiter=",")
        # plot
        plot_heatmap(corrs, chrdata, ref, outfn, ext=ext)
Example #6
0
def chr2modcorr(outfn,
                bams,
                region,
                chrdata,
                mapq,
                mindepth,
                minModProb,
                MaxPhredProb,
                minmodreads=10):
    """Calculate correlation between modifications at pairs of positions.

    The matrix is saved to outfn as comma-separated text with positions in
    the header and "mod+strand" labels in the footer, and it is returned.

    Args:
        outfn: path of the output matrix.
        bams: BAM files; calls from all of them are stacked together.
        region: region name, used only for logging.
        chrdata: data frame with chr, pos, mod and strand columns.
        mapq, minModProb, MaxPhredProb: passed on to bam2calls.
        mindepth: min number of reads with a non-zero call for a position
            to be compared with others.
        minmodreads: min number of selected reads for a pair of positions.

    Returns:
        float32 array of size positions x positions; pairs that were not
        evaluated are NaN.
    """
    # get chr and positions
    ref = chrdata.chr.unique()[0]
    positions = np.unique(chrdata.pos.to_numpy())
    corrs = np.full((len(positions), len(positions)),
                    np.nan,
                    dtype="float32")
    logger(" %s with %s modified positions > %s" %
           (region, len(positions), outfn))
    parsers = [
        bam2calls(bam, ref, positions, mapq, minModProb, MaxPhredProb)
        for bam in bams
    ]
    for i, calls in enumerate(zip(*parsers)):
        # stack all reads - those are already prefiltered for only those
        # modified for given position
        # NOTE(review): calls is indexed as [position, read] below; unmodified
        # bases are presumably encoded as 255 and missing calls as 0 - confirm
        calls = np.hstack(calls)
        sys.stderr.write(" %s  \r" % i)
        # get positions with mindepth
        enoughdepth = np.where(np.sum(calls > 0, axis=1) >= mindepth)[0]
        # store correlations between this position and all following ones
        for j in enoughdepth[enoughdepth >= i]:
            # reads modified in i (or j) that have a call in the other
            # position; take balanced number of those for each position
            modi = np.argwhere(
                np.all((calls[i] > 0, calls[i] < 255, calls[j] > 0), axis=0))
            modj = np.argwhere(
                np.all((calls[j] > 0, calls[j] < 255, calls[i] > 0), axis=0))
            lessmod = min(len(modi), len(modj))
            sel = np.unique(list(modi[:lessmod]) + list(modj[:lessmod]))
            if len(sel) < minmodreads:
                continue
            # fraction of selected reads with identical calls in both
            # positions, rescaled from 0..1 to -1..1
            same = calls[i, sel] == calls[j, sel]
            corr = 2 * (np.mean(same) - 0.5) if np.any(same) else -1
            corrs[i, j] = corrs[j, i] = corr
    # store array
    np.savetxt(outfn,
               corrs,
               fmt='%.3f',
               header=",".join(map(str, positions)),
               delimiter=',',
               footer=",".join(
                   "%s%s" % (m, s)
                   for m, s in zip(chrdata["mod"], chrdata.strand)))
    return corrs
Example #7
0
def bam2calls(bams,
              chrdata,
              mapq,
              minModProb,
              MaxPhredProb,
              can2mods,
              maxDepth=100,
              minStored=0,
              minAlgFrac=0.3,
              strand=None):
    """Return data frame of modification calls: reads in rows x positions in columns.

    Args:
        bams: paths of BAM files; sample name is the file name without
            its last 4 characters (the extension).
        chrdata: data frame with chr, pos, mod, strand and ref_base columns;
            the first and last pos define the fetched region.
        mapq: min mapping quality passed to is_qcfail.
        minModProb, MaxPhredProb: passed on to store_blocks.
        can2mods: dict of canonical base -> list of its modifications.
        maxDepth: max number of reads stored per BAM file.
        minStored: min value reported by store_blocks for a read to be kept.
        minAlgFrac: min fraction of the region a read has to overlap.
        strand: "+" or "-" to keep only reads from that strand.

    Returns:
        Data frame indexed by read names, with (pos, mod, strand) columns;
        the "sample" and "read_strand" columns are appended last.
    """
    ref = chrdata.chr.iloc[0]
    start, end = chrdata.pos.iloc[0] - 1, chrdata.pos.iloc[-1] + 1
    # positions whose reference base has any known modification
    pos2base = {
        pos: base
        for pos, base in zip(chrdata.pos, chrdata.ref_base) if can2mods[base]
    }
    # prepare data frame with reads in rows x modified positions in columns
    columns = pd.MultiIndex.from_frame(chrdata[['pos', 'mod', "strand"]])
    df = pd.DataFrame(np.zeros((maxDepth * len(bams), chrdata.shape[0])),
                      columns=columns)
    # add sample and read strand as last columns
    df["sample"] = df["read_strand"] = ""
    logger(" %s:%s-%s %s with %s modified positions" %
           (ref, start, end, strand, len(pos2base)))
    # process bam files
    reads = []
    for bi, bam in enumerate(bams, 1):
        # get sample name modPhred/curlcakes/minimap2/RNA010220191_m5C.bam -> RNA010220191_m5C
        sample = os.path.basename(bam)[:-4]
        sys.stderr.write("  %s / %s %s > %s              \r" %
                         (bi, len(bams), bam, sample))
        readi = 0
        # context manager closes the BAM handle also on break or error
        with pysam.AlignmentFile(bam) as sam:
            for a in sam.fetch(ref, start, end):
                # skip low quality alignments, not primary, QC fails,
                # duplicates or supplementary algs
                # or algs overlapping less than minAlgFrac of the region
                s, e = max(start, a.pos), min(end, a.aend)
                if is_qcfail(a, mapq) or (e - s) < minAlgFrac * (end - start):
                    continue
                # skip reads from wrong strand
                if (strand == "+" and a.is_reverse) or \
                   (strand == "-" and not a.is_reverse):
                    continue
                # store modifications from alignment blocks
                df, stored = store_blocks(a, start, end, df, reads,
                                          minModProb, MaxPhredProb, can2mods,
                                          pos2base)
                # keep read if enough mods
                if stored >= minStored:
                    df.loc[len(reads),
                           "read_strand"] = "-" if a.is_reverse else "+"
                    reads.append(a.qname)
                    readi += 1
                else:
                    # reset the row, so it can be reused by the next read
                    df.loc[len(reads)] = 0
                # process up to maxDepth of reads per BAM file
                if readi == maxDepth:
                    break
        # store sample for rows of this BAM; .loc slices include the end
        # label, hence -1 to label exactly the readi rows added above
        if readi:
            df.loc[len(reads) - readi:len(reads) - 1, "sample"] = sample
    # strip rows that have no reads & rename rows
    df = df.iloc[:len(reads)]
    df.index = reads
    return df
Example #8
0
def main():
    """Parse command-line options and run mod_cluster with them."""
    import argparse

    parser = argparse.ArgumentParser(
        description=desc,
        epilog=epilog,
        formatter_class=argparse.RawTextHelpFormatter)
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html
    methods = [
        "single", "ward", "complete", "average", "weighted", "centroid",
        "median"
    ]
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.pdist.html
    metrics = [
        "braycurtis", "canberra", "chebyshev", "cityblock", "correlation",
        "cosine", "dice", "euclidean", "hamming", "jaccard", "jensenshannon",
        "kulsinski", "mahalanobis", "matching", "minkowski", "rogerstanimoto",
        "russellrao", "seuclidean", "sokalmichener", "sokalsneath",
        "sqeuclidean", "yule"
    ]
    add = parser.add_argument
    add('--version', action='version', version=VERSION)
    add("-v", "--verbose", action="store_true", help="verbose")
    add("-i", "--input", default="modPhred/mod.gz",
        help="input file [%(default)s]")
    add("-m", "--mapq", default=15, type=int,
        help="min mapping quality [%(default)s]")
    add("-d", "--mindepth", default=25, type=int,
        help="min depth of coverage [%(default)s]")
    add("--maxDepth", default=100, type=int,
        help="max depth of coverage [%(default)s]")
    add("--minfreq", "--minModFreq", default=0.20, type=float,
        help="min modification frequency per position [%(default)s]")
    add("--minModProb", default=0.50, type=float,
        help="min modification probability per base [%(default)s]")
    add("--minAlgFrac", default=0.80, type=float,
        help="min fraction of read aligned to the region [%(default)s]")
    add("--col_cluster", action="store_true", help="cluster also columns")
    add("--colors", default="cmybrgwk",
        help="colors for plotting [%(default)s]")
    # strand selection ("-s/--strand") is currently not exposed
    add("--mod", default="", help="filter only 1 modification [analyse all]")
    add("--method", default="ward", choices=methods,
        help="method used for clustering [%(default)s]")
    add("--metric", default="euclidean", choices=metrics,
        help="metric used for clustering [%(default)s]")
    add("-r", "--regions", nargs="+", default=[],
        help="regions to process [all chromosomes]")
    add("-w", "--overwrite", action="store_true",
        help="overwrite existing output")
    add("-e", "--ext", default="svg",
        help="figure format/extension [%(default)s]")

    opts = parser.parse_args()
    if opts.verbose:
        sys.stderr.write("Options: %s\n" % str(opts))

    if not opts.regions:
        logger(
            "Processing entire chromosomes - consider narrowing to certain regions!"
        )

    mod_cluster(opts.input, args=opts, ext=opts.ext)
    logger("Finished\n")
Example #9
0
def mod_cluster(infn,
                args,
                data=False,
                ext="png",
                logger=logger,
                read_strand_lut={
                    "+": "r",
                    "-": "b"
                }):
    """Cluster reads based on their modification profiles.

    For every region/chromosome a clustermap and two scree plots (full and
    limited to the first 10 components) are saved in the "clusters"
    directory next to infn. All figures are closed after saving.

    Args:
        infn: tab-delimited file with modified positions.
        args: options object; regions, minfreq, mindepth, mod, mapq,
            minModProb, maxDepth, minAlgFrac, colors, method, metric and
            col_cluster attributes are read.
        data: already loaded data frame; infn is read only if this is a bool.
        ext: figure format/extension.
        logger: function used for reporting progress.
        read_strand_lut: colour for each read strand.
    """
    regions = args.regions
    indir = os.path.dirname(infn)
    outdir = os.path.join(indir, "clusters")
    os.makedirs(outdir, exist_ok=True)
    # load info
    moddata = load_info(indir)
    MaxPhredProb = moddata["MaxPhredProb"]
    bamfiles = sorted(moddata["bam"])
    # get can2mods ie {'A': ['6mA'], 'C': ['5mC'], 'G': [], 'T': []}
    can2mods = {
        b: [moddata["symbol2modbase"][m] for m in mods]
        for b, mods in moddata["canonical2mods"].items()
    }
    # modifications of U are looked up under T as well
    if "U" in can2mods:
        can2mods["T"] = can2mods["U"]

    # parse data
    if isinstance(data, bool):
        logger("Loading %s ...\n" % infn)
        data = pd.read_csv(infn,
                           sep="\t",
                           header=len(HEADER.split('\n')) - 2,
                           index_col=False,
                           dtype={
                               "chr": object,
                               "pos": int
                           })
    # filter by min freq and depth
    mfreqcols = [c for c in data.columns if c.endswith('mod_frequency')]
    depthcols = [c for c in data.columns if c.endswith('depth')]
    filters = [
        data.loc[:, mfreqcols].max(axis=1) > args.minfreq,
        data.loc[:, depthcols].max(axis=1) > args.mindepth
    ]
    # add filter for modification
    if args.mod:
        filters.append(data["mod"] == args.mod)
    data = data[np.all(filters, axis=0)]
    # nothing to do if no position passed the filters
    if data.shape[0] < 1:
        logger("[mod_plot][ERROR]  %s row(s) found in %s\n" %
               (data.shape[0], infn))
        return
    # limit by region
    if regions:
        regions, regionsData, strands = get_data_for_regions(data, regions)
        what = "region(s)"
    else:
        # get chromosomes
        regions = data.chr.unique()
        strands = [None] * len(regions)
        regionsData = (data[data.chr == ref] for ref in regions)
        what = "chromosome(s)"
    # report names of the first 3 regions (not the first 3 characters)
    logger("Processing %s %s: %s ...\n" %
           (len(regions), what, ",".join(map(str, list(regions)[:3]))))

    # process regions/chromosomes
    for ref, chrdata, strand in zip(regions, regionsData, strands):
        # define output
        fn = "%s.%s" % (ref, args.mod) if args.mod else "%s" % ref
        outfn = os.path.join(outdir, fn)
        # get data frame
        df = bam2calls(bamfiles,
                       chrdata,
                       args.mapq,
                       args.minModProb,
                       MaxPhredProb,
                       can2mods,
                       maxDepth=args.maxDepth,
                       minAlgFrac=args.minAlgFrac,
                       strand=strand)
        # get colors for rows: sample and, if both are present, read strand
        samples = df.pop("sample")
        pal = sns.cubehelix_palette(samples.unique().size,
                                    light=.9,
                                    dark=.1,
                                    reverse=True,
                                    start=1,
                                    rot=-2)
        lut = dict(
            zip(sorted(samples.unique(), key=lambda x: x.split("_")[-1]), pal))
        row_colors = [samples.map(lut)]
        read_strand = df.pop("read_strand")
        both_strands = len(read_strand.unique()) > 1
        if both_strands:
            row_colors.append(read_strand.map(read_strand_lut))
        # get colors for columns: one per modification
        mods = df.columns.get_level_values("mod")
        mods_lut = dict(zip(sorted(mods.unique()), args.colors))
        col_colors = mods.map(mods_lut)
        # https://seaborn.pydata.org/generated/seaborn.clustermap.html
        g = sns.clustermap(
            df,
            cmap="Blues",
            method=args.method,
            metric=args.metric,
            figsize=(16, 10),
            col_cluster=args.col_cluster,
            col_colors=col_colors,
            row_colors=row_colors,
            xticklabels=False,
            yticklabels=False,
            cbar_kws={
                'label': 'Modification probability',
                'orientation': 'horizontal'
            },
        )
        # add legends
        g.fig.legend(
            loc='lower left',
            bbox_to_anchor=(0.01, 0.8),
            frameon=True,
            title="Sample",
            handles=[mpatches.Patch(color=c, label=l) for l, c in lut.items()])
        g.fig.legend(loc='lower right',
                     bbox_to_anchor=(1.0, 0.8),
                     frameon=True,
                     ncol=10,
                     title="Modification",
                     handles=[
                         mpatches.Patch(color=c, label=l)
                         for l, c in mods_lut.items()
                     ])
        if both_strands:
            g.fig.legend(loc='upper left',
                         bbox_to_anchor=(0.01, 0.8),
                         frameon=True,
                         ncol=1,
                         title="Read strand",
                         handles=[
                             mpatches.Patch(color=c, label=l)
                             for l, c in read_strand_lut.items()
                         ])
        # make the colour bar horizontal in the bottom left corner
        g.cax.set_position([.01, .05, .15, .03])
        # set title and labels for axes
        g.fig.suptitle("Clustermap for %s %s %s" %
                       (ref, args.method, args.mod))
        g.ax_heatmap.set_xlabel("Modified positions")
        g.ax_heatmap.set_ylabel("Reads")
        # save
        g.savefig("%s.clustermap.%s.%s" % (outfn, args.method, ext))

        # pca
        pc_op = PCA().fit(g.data)
        fig, ax = plt.subplots(1, figsize=(6, 5))
        # plot explained variance as a fraction of the total explained variance
        ax.plot(np.arange(1,
                          len(pc_op.explained_variance_) + 1),
                pc_op.explained_variance_ / pc_op.explained_variance_.sum())
        ax.set_xlabel('Component number')
        ax.set_ylabel('Fraction of explained variance')
        ax.set_title('Scree plot for %s' % ref)
        fig.savefig("%s.scree.%s.%s" % (outfn, args.method, ext))
        ax.set_xlim((0, 10))
        fig.savefig("%s.scree_10.%s.%s" % (outfn, args.method, ext))
        # clean-up: close both figures, otherwise the clustermap figure
        # stays registered in pyplot and memory grows with every region
        plt.close(g.fig)
        plt.close(fig)
Example #10
0
def plot_scatter(infn,
                 ext="png",
                 logger=logger,
                 data=False,
                 region="",
                 samples=[],
                 features=[
                     "depth", "basecall_accuracy", "mod_frequency",
                     "median_mod_prob"
                 ]):
    """Plot pairwise scatter plots between samples using seaborn.

    One figure per feature is saved in the "plots" directory next to infn.
    The process exits with status 1 if infn does not exist.
    A data frame passed as data is not modified.

    Args:
        infn: tab-delimited file with modified positions.
        ext: figure format/extension.
        logger: function used for reporting progress.
        data: already loaded data frame; infn is read only if this is a bool.
        region: "chr" or "chr:start-end" to limit plotted positions
            (start and end are inclusive).
        samples: if given, only columns matching any of these names are used.
        features: features to plot; columns are matched by feature name.
    """
    # make sure infn exists
    if not os.path.isfile(infn):
        logger(
            "[mod_plot][ERROR] File %s does not exists! Have you run mod_report.py?\n"
            % infn)
        sys.exit(1)
    # get outdir
    outdir = os.path.join(os.path.dirname(infn), "plots")
    os.makedirs(outdir, exist_ok=True)
    logger("Saving plots for %s to %s ...\n" % (", ".join(features), outdir))

    # parse data
    if isinstance(data, bool):
        logger("Loading %s ...\n" % infn)
        data = pd.read_csv(infn,
                           sep="\t",
                           header=len(HEADER.split('\n')) - 2,
                           index_col=False,
                           dtype={
                               "chr": object,
                               "pos": int
                           })
    # limit by region
    if region:
        logger(" limiting to %s ...\n" % region)
        chrom, s, e = region, 0, 0
        # coordinates follow the last ":", so "-" within a chromosome name
        # is not mistaken for a coordinate range
        name, sep, coords = region.rpartition(":")
        if sep and "-" in coords:
            chrom = name
            s, e = map(int, coords.split("-"))
        data = data[data.chr == chrom]
        if e:
            # chained comparison (s <= pos <= e) is not defined for a Series
            data = data[(data.pos >= s) & (data.pos <= e)]
    if data.shape[0] < 1:
        logger("[mod_plot][ERROR]  %s row(s) found in %s\n" %
               (data.shape[0], infn))
        return
    # rename .bam columns to basename and split it at _ with \n;
    # rename returns a new frame, so caller's columns stay untouched
    data = data.rename(columns=lambda c: os.path.basename(c).replace(
        "_", "\n") if ".bam" in c else c)
    # plot features
    for feature in features:
        # match feature by replacing _ with \n
        pattern = feature.replace("_", "\n")
        cols = [c for c in data.columns if pattern in c]
        # limit by sample
        if samples:
            wanted = [s.replace("_", "\n") for s in samples]
            cols = sorted(
                set(c for c in cols if any(w in c for w in wanted)))
        g = sns.pairplot(data,
                         vars=cols,
                         height=4,
                         hue='mod',
                         diag_kind='kde',
                         plot_kws={
                             'alpha': 0.1,
                             's': 3,
                         })
        outfn = os.path.join(outdir, "%s.%s" % (feature, ext))
        # add figure title
        g.fig.suptitle("%s\n%s" % (feature, infn), size=16)
        g.fig.subplots_adjust(top=.90, right=0.95)
        # make legend in top right corner and increase marker size
        legend = g._legend
        legend.set_bbox_to_anchor((0.10, 0.95))
        legend.set_title("")
        # legendHandles has been renamed to legend_handles in newer matplotlib
        handles = getattr(legend, "legend_handles", None)
        if handles is None:
            handles = legend.legendHandles
        for lh in handles:
            lh._sizes = [20]
        # set axes limits 0-1
        if feature != "depth":
            for row in g.axes:
                for ax in row:
                    ax.set_xlim((0, 1))
                    ax.set_ylim((0, 1))
        # save
        g.fig.savefig(outfn)
Example #11
0
def mod_plot_bases(infn, ext="svg", logger=logger, data=False):
    """Draw violin plots of the report metrics split by reference base.

    The figure has one column of panels per reference base (A, C, G, T)
    and one row per metric (depth, basecall_accuracy, mod_frequency,
    median_mod_prob), with one violin per sample column in every panel.
    It is saved as `infn` + "." + `ext`.

    If `data` is anything other than a bool it is used as the already
    loaded table and `infn` is not parsed again. The process exits with
    status 1 if `infn` does not exist; nothing is plotted (and None is
    returned) if the table has fewer than 10 rows.
    """
    # the report file has to exist even if data was handed over
    if not os.path.isfile(infn):
        logger(
            "[mod_plot][ERROR] File %s does not exists! Have you run mod_report.py?\n"
            % infn)
        sys.exit(1)

    # load the table unless the caller supplied it
    if isinstance(data, bool):
        logger("Loading %s ...\n" % infn)
        header_row = len(HEADER.split('\n')) - 2
        data = pd.read_csv(infn,
                           sep="\t",
                           header=header_row,
                           index_col=False)
    n_rows = data.shape[0]
    if n_rows < 10:
        logger("[mod_plot][ERROR]  %s row(s) found in %s\n" % (n_rows, infn))
        return

    # figure layout: metrics in rows, reference bases in columns
    bases = 'ACGT'
    metrics = [
        'depth', 'basecall_accuracy', 'mod_frequency', 'median_mod_prob'
    ]
    sample_names = [
        get_sample_name(c) for c in data.columns if c.endswith(metrics[0])
    ]
    width = 2 + 1.5 * len(bases) * len(sample_names)
    fig, axes = plt.subplots(nrows=len(metrics),
                             ncols=len(bases),
                             sharex="col",
                             sharey="row",
                             figsize=(width, 5 * len(metrics)))  #6, 20
    fig.suptitle(infn, fontsize=12)
    # placeholder drawn for a sample without any usable value
    nans = [float('nan'), float('nan')]
    # largest 2*median depth seen so far
    maxYdepth = 0
    for bi, b in enumerate(bases):
        # rows for this reference base
        base_rows = data[data.ref_base == b]
        # mask nan before plotting https://stackoverflow.com/a/44306965/632242
        # the mask comes from median_mod_prob only and is applied to every
        # metric (assumes all metrics list their sample columns in the
        # same order)
        prob_cols = [c for c in data.columns if c.endswith(metrics[-1])]
        mask = ~np.isnan(base_rows.loc[:, prob_cols].to_numpy())
        for mi, m in enumerate(metrics):
            cols = [c for c in data.columns if c.endswith(m)]
            ax = axes[mi, bi]
            _data = base_rows.loc[:, cols].to_numpy()
            # one violin per sample column
            per_sample = []
            for values, keep in zip(_data.T, mask.T):
                valid = values[keep]
                per_sample.append(valid if valid.any() else nans)
            ax.violinplot(per_sample,
                          points=20,
                          widths=0.7,
                          bw_method=0.5,
                          showmeans=True,
                          showextrema=True,
                          showmedians=True)
            ax.set_xticks(range(1, len(cols) + 1))
            ax.set_xticklabels([" "] * len(cols))
            if mi == 0:
                ax.set_title("%s (%s positions)" % (b, base_rows.shape[0]))
                # depth Y range follows the largest 2*median depth so far
                depth_cap = 2 * np.nanmedian(_data, axis=0).max()
                if depth_cap > maxYdepth:
                    maxYdepth = depth_cap
                    ax.set_ylim(0, maxYdepth)
            else:
                ax.set_ylim(0, 1)
            # metric name only on the left-most column
            if bi == 0:
                ax.set_ylabel(m)
            # sample names only on the bottom row
            if mi == len(metrics) - 1:
                ax.set_xticklabels(sample_names)
    # narrow the Y range of the last panel drawn (bottom-right); its row
    # shares the Y axis
    axes[-1, -1].set_ylim(0.5, 1)
    #fig.show()
    fig.savefig(infn + ".%s" % ext)
Example #12
0
def mod_plot(infn,
             ext="svg",
             logger=logger,
             data=False,
             colors="brcmyg"):  #png
    """Draw violin plots of the report metrics split by modification.

    The figure has one column of panels per unique value of the `mod`
    column and one row per metric (depth, basecall_accuracy,
    mod_frequency, median_mod_prob), with one violin per sample column in
    every panel. Violin bodies are coloured per sample by cycling through
    `colors`. The figure is saved as `infn` + "." + `ext`.

    If `data` is anything other than a bool it is used as the already
    loaded table and `infn` is not parsed again. The process exits with
    status 1 if `infn` does not exist; nothing is plotted (and None is
    returned) if the table has fewer than 10 rows.
    """
    # the report file has to exist even if data was handed over
    if not os.path.isfile(infn):
        logger(
            "[mod_plot][ERROR] File %s does not exists! Have you run mod_report.py?\n"
            % infn)
        sys.exit(1)

    # load the table unless the caller supplied it
    if isinstance(data, bool):
        logger("Loading %s ...\n" % infn)
        header_row = len(HEADER.split('\n')) - 2
        data = pd.read_csv(infn,
                           sep="\t",
                           header=header_row,
                           index_col=False)
    n_rows = data.shape[0]
    if n_rows < 10:
        logger("[mod_plot][ERROR]  %s row(s) found in %s\n" % (n_rows, infn))
        return

    # figure layout: metrics in rows, modifications in columns
    bases = data["mod"].unique()  #; print(bases)
    metrics = [
        'depth', 'basecall_accuracy', 'mod_frequency', 'median_mod_prob'
    ]
    metrics_names = [
        "Number of reads", "Agreement with reference",
        "Frequency of modification", "Median modification probability"
    ]
    sample_names = [
        get_sample_name(c) for c in data.columns if c.endswith(metrics[0])
    ]
    width = 1.5 * len(bases) * len(sample_names)
    fig, axes = plt.subplots(nrows=len(metrics),
                             ncols=len(bases),
                             sharex="col",
                             sharey="row",
                             figsize=(width, 2 + 3 * len(metrics)))  #6, 20
    fig.suptitle(infn, fontsize=12)
    # placeholder drawn for a sample without any usable value
    nans = [float('nan'), float('nan')]
    # largest 2*median depth seen so far
    maxYdepth = 0
    # with a single modification plt.subplots returns a 1D array of axes
    single_column = not len(bases) > 1
    for bi, b in enumerate(bases):
        # rows for this modification
        mod_rows = data[data["mod"] == b]
        # mask nan before plotting https://stackoverflow.com/a/44306965/632242
        # the mask comes from median_mod_prob only and is applied to every
        # metric (assumes all metrics list their sample columns in the
        # same order)
        prob_cols = [c for c in data.columns if c.endswith(metrics[-1])]
        mask = ~np.isnan(mod_rows.loc[:, prob_cols].to_numpy())
        for mi, m in enumerate(metrics):
            cols = [c for c in data.columns if c.endswith(m)]
            if single_column:
                ax = axes[mi]
            else:
                ax = axes[mi, bi]
            _data = mod_rows.loc[:, cols].to_numpy()
            # one violin per sample column
            per_sample = []
            for values, keep in zip(_data.T, mask.T):
                valid = values[keep]
                per_sample.append(valid if valid.any() else nans)
            violins = ax.violinplot(per_sample,
                                    points=20,
                                    widths=0.7,
                                    bw_method=0.5,
                                    showextrema=True,
                                    showmedians=True)  #showmeans=True,
            # color samples differently
            for pci, body in enumerate(violins['bodies']):
                body.set_facecolor(colors[pci % len(colors)])
            ax.set_xticks(range(1, len(cols) + 1))
            ax.set_xticklabels([" "] * len(cols))
            if mi == 0:
                ax.set_title("%s\n%s positions" % (b, mod_rows.shape[0]))
                # depth Y range follows the largest 2*median depth so far
                depth_cap = 2 * np.nanmedian(_data, axis=0).max()
                if depth_cap > maxYdepth:
                    maxYdepth = depth_cap
                    ax.set_ylim(0, maxYdepth)
            elif mi in (1, 3):
                # basecall_accuracy and median_mod_prob rows
                ax.set_ylim(0.5, 1)
            else:
                ax.set_ylim(0, 1)
            # metric name only on the left-most column
            if bi == 0:
                ax.set_ylabel(metrics_names[mi])
            # sample names only on the bottom row
            if mi == len(metrics) - 1:
                ax.set_xticklabels(sample_names)
            ax.grid(axis="y", which="both")
    #fig.show()
    fig.savefig(infn + ".%s" % ext)
Example #13
0
def plot_regions(infn,
                 bed,
                 ext="svg",
                 logger=logger,
                 data=False,
                 colors="brcmyg"):
    """Generate modification frequency plots for given regions.

    For every (ref, start, end) region returned by load_bed(bed) a figure
    is saved as <dirname(infn)>/plots/<ref>:<start>-<end>.<ext>. Each
    figure has one panel per sample; bars show the `mod_frequency` value
    at each position, drawn upwards for the + strand and downwards for
    the - strand, and coloured by modification.

    Args:
        infn: path of the tab-separated report; must exist.
        bed: passed to load_bed() to obtain the regions to plot.
        ext: extension (format) of the saved figures.
        logger: callable receiving log messages.
        data: already loaded table; if a bool, the table is read from
            infn instead.
        colors: sequence of matplotlib colours; modifications cycle
            through it in the order they appear in the region.

    Returns None. Exits with status 1 if infn does not exist; returns
    early if the table is empty; regions without any row are reported
    and skipped.
    """
    # the report file has to exist even if data was handed over
    if not os.path.isfile(infn):
        logger(
            "[mod_plot][ERROR] File %s does not exists! Have you run mod_report.py?\n"
            % infn)
        sys.exit(1)
    # get outdir
    outdir = os.path.join(os.path.dirname(infn), "plots")
    if not os.path.isdir(outdir):
        os.makedirs(outdir)
    # load regions to plot
    regions = load_bed(bed)  #; print(regions)
    logger("Saving plots for %s region(s) to %s ...\n" %
           (len(regions), outdir))

    # parse data
    if isinstance(data, bool):
        logger("Loading %s ...\n" % infn)
        data = pd.read_csv(infn,
                           sep="\t",
                           header=len(HEADER.split('\n')) - 2,
                           index_col=False,
                           dtype={
                               "chr": object,
                               "pos": int
                           })
    if data.shape[0] < 1:
        logger("[mod_plot][ERROR]  %s row(s) found in %s\n" %
               (data.shape[0], infn))
        return
    # get uniq mods
    mods = data["mod"].unique()
    # sample names come from the depth columns, plotted values from the
    # mod_frequency columns (assumes both list samples in the same order)
    metric = 'mod_frequency'
    sample_names = [
        get_sample_name(n) for n in data.columns if n.endswith('depth')
    ]
    metric_cols = [c for c in data.columns if c.endswith(metric)]
    logger(" %s samples and %s modifications: %s\n" %
           (len(sample_names), len(mods), ", ".join(mods)))
    # plot regions
    for ref, s, e in regions:
        df = data[np.all((data.chr == ref, data.pos >= s, data.pos <= e),
                         axis=0)]
        if df.shape[0] < 1:
            logger("[mod_plot][ERROR] No modifications in %s:%s-%s\n" %
                   (ref, s, e))
            continue
        mods = df["mod"].unique()
        logger(" %s:%s-%s with %s modifications: %s\n" %
               (ref, s, e, len(mods), ", ".join(mods)))
        # squeeze=False always yields a 2D array of axes, so a single
        # sample no longer produces a bare (non-iterable) Axes object
        fig, axes = plt.subplots(
            nrows=len(sample_names),
            ncols=1,
            sharex="all",
            sharey="all",
            squeeze=False,
            figsize=(20, 2 + 1 * len(sample_names)))  #20, 12 for 2 samples
        axes = axes[:, 0]
        fig.suptitle("%s:%s-%s" % (ref, s, e), fontsize=12)
        for ax, col, name in zip(axes, metric_cols, sample_names):
            ax.set_title(name.replace('\n', '_'))  #col)
            ax.set_ylabel("%s\n[on +/- strand]" % metric)
            # + strand bars point up, - strand bars point down
            for strand, norm in zip("+-", (1, -1)):
                for mi, mod in enumerate(mods):
                    sel = df[(df.strand == strand) & (df["mod"] == mod)]
                    # cycle colours so that no modification is dropped
                    # when there are more modifications than colours
                    ax.bar(sel.pos,
                           norm * sel[col],
                           color=colors[mi % len(colors)],
                           label=mod)

        # set limits; axes are shared, so the bottom panel drives them all
        bottom = axes[-1]
        bottom.set_xlim(s, e + 1)
        bottom.set_ylim(-1, 1)
        bottom.set_xlabel("%s position" % ref)
        # one legend entry per modification
        #https://stackoverflow.com/a/13589144/632242
        handles, labels = bottom.get_legend_handles_labels()
        by_label = OrderedDict(zip(labels, handles))
        fig.legend(handles=list(by_label.values()),
                   labels=list(by_label.keys()))
        #fig.tight_layout()
        #plt.show()
        outfn = os.path.join(outdir, "%s:%s-%s.%s" % (ref, s, e, ext))
        fig.savefig(outfn)
        # release the figure, otherwise one figure per region stays open
        plt.close(fig)