Example No. 1
def count_reads(regions_list, params):
    """ Count reads from bam within regions (counts position of cutsite to prevent double-counting) """

    bam_f = params.bam
    read_shift = params.read_shift
    bam_obj = pysam.AlignmentFile(bam_f, "rb")

    log_q = params.log_q
    logger = TobiasLogger("", params.verbosity,
                          log_q)  #sending all logger calls to log_q

    #Count per region
    read_count = 0
    logger.spam("Started counting region_chunk ({0} -> {1})".format(
        "_".join([str(element) for element in regions_list[0]]),
        "_".join([str(element) for element in regions_list[-1]])))
    for region in regions_list:
        read_lst = ReadList().from_bam(bam_obj, region)

        for read in read_lst:
            read.get_cutsite(read_shift)
            if read.cutsite > region.start and read.cutsite < region.end:  #only reads within borders
                read_count += 1

    logger.spam("Finished counting region_chunk ({0} -> {1})".format(
        "_".join([str(element) for element in regions_list[0]]),
        "_".join([str(element) for element in regions_list[-1]])))
    bam_obj.close()

    return read_count
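
For reference, the same cutsite-counting idea can be sketched with plain pysam, without the TOBIAS ReadList helper. All names below are illustrative, and the +4/-5 shift is a common ATAC-seq convention assumed here rather than taken from the example above:

import pysam

def count_cutsites(bam_path, chrom, start, end, shift=(4, -5)):
    """Simplified stand-in for count_reads: count Tn5 cutsites strictly inside (start, end)."""
    count = 0
    with pysam.AlignmentFile(bam_path, "rb") as bam:
        for read in bam.fetch(chrom, start, end):
            if read.is_unmapped:
                continue
            #cutsite = shifted 5' end of the read, depending on strand (assumed +4/-5 convention)
            cutsite = read.reference_end + shift[1] if read.is_reverse else read.reference_start + shift[0]
            if start < cutsite < end:  #only cutsites within borders, as in count_reads above
                count += 1
    return count
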
Example No. 2
def run_atacorrect(args):
    """
	Function for bias correction of input .bam files
	Calls functions in ATACorrect_functions and several internal classes
	"""

    #Test if required arguments were given:
    if args.bam is None:
        sys.exit("Error: No .bam-file given")
    if args.genome is None:
        sys.exit("Error: No .fasta-file given")
    if args.peaks is None:
        sys.exit("Error: No .peaks-file given")

    #Adjust some parameters depending on input
    args.prefix = os.path.splitext(os.path.basename(args.bam))[0] if args.prefix is None else args.prefix
    args.outdir = os.path.abspath(args.outdir) if args.outdir is not None else os.path.abspath(os.getcwd())

    #Set output bigwigs based on input
    tracks = ["uncorrected", "bias", "expected", "corrected"]
    tracks = [track for track in tracks if track not in args.track_off]  #remove tracks switched off via --track_off

    if args.split_strands:
        strands = ["forward", "reverse"]
    else:
        strands = ["both"]

    output_bws = {}
    for track in tracks:
        output_bws[track] = {}
        for strand in strands:
            elements = [args.prefix, track] if strand == "both" else [
                args.prefix, track, strand
            ]
            output_bws[track][strand] = {
                "fn":
                os.path.join(args.outdir, "{0}.bw".format("_".join(elements)))
            }
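            #e.g. (hypothetical) prefix "sample1" and track "corrected": strand "both" gives
            #"sample1_corrected.bw", while strand "forward" gives "sample1_corrected_forward.bw"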

    #Set all output files
    bam_out = os.path.join(args.outdir, args.prefix + "_atacorrect.bam")
    bigwigs = [
        output_bws[track][strand]["fn"]
        for (track, strand) in itertools.product(tracks, strands)
    ]
    figures_f = os.path.join(args.outdir,
                             "{0}_atacorrect.pdf".format(args.prefix))

    output_files = bigwigs + [figures_f]
    output_files = list(OrderedDict.fromkeys(
        output_files))  #remove duplicates due to "both" option

    strands = ["forward", "reverse"]

    #----------------------------------------------------------------------------------------------------#
    # Print info on run
    #----------------------------------------------------------------------------------------------------#

    logger = TobiasLogger("ATACorrect", args.verbosity)
    logger.begin()

    parser = add_atacorrect_arguments(argparse.ArgumentParser())
    logger.arguments_overview(parser, args)
    logger.output_files(output_files)

    args.cores = check_cores(args.cores, logger)

    #----------------------------------------------------------------------------------------------------#
    # Test input file availability for reading
    #----------------------------------------------------------------------------------------------------#

    logger.info("----- Processing input data -----")

    logger.debug("Testing input file availability")
    check_files([args.bam, args.genome, args.peaks], "r")

    logger.debug("Testing output directory/file writeability")
    make_directory(args.outdir)
    check_files(output_files, "w")

    #Open pdf for figures
    figure_pdf = PdfPages(figures_f, keep_empty=False)

    #----------------------------------------------------------------------------------------------------#
    # Read information in bam/fasta
    #----------------------------------------------------------------------------------------------------#

    logger.info("Reading info from .bam file")
    bamfile = pysam.AlignmentFile(args.bam, "rb")
    if not bamfile.has_index():
        logger.warning("No index found for bamfile - creating one via pysam.")
        pysam.index(args.bam)

    bam_references = bamfile.references  #chromosomes in correct order
    bam_chrom_info = dict(zip(bamfile.references, bamfile.lengths))
    logger.debug("bam_chrom_info: {0}".format(bam_chrom_info))
    bamfile.close()

    logger.info("Reading info from .fasta file")
    fastafile = pysam.FastaFile(args.genome)
    fasta_chrom_info = dict(zip(fastafile.references, fastafile.lengths))
    logger.debug("fasta_chrom_info: {0}".format(fasta_chrom_info))
    fastafile.close()

    #Compare chrom lengths
    chrom_in_common = set(bam_chrom_info.keys()).intersection(
        fasta_chrom_info.keys())
    for chrom in chrom_in_common:
        bamlen = bam_chrom_info[chrom]
        fastalen = fasta_chrom_info[chrom]
        if bamlen != fastalen:
            logger.warning("(Fastafile)\t{0} has length {1}".format(
                chrom, fasta_chrom_info[chrom]))
            logger.warning("(Bamfile)\t{0} has length {1}".format(
                chrom, bam_chrom_info[chrom]))
            sys.exit(
                "Error: .bam and .fasta have different chromosome lengths. Please make sure the genome file is the same as the one used for mapping."
            )

    #Subset bam_references to those for which there are sequences in fasta
    chrom_not_in_fasta = set(bam_references) - set(fasta_chrom_info.keys())
    if len(chrom_not_in_fasta) > 0:
        logger.warning(
            "The following contigs in --bam did not have sequences in --fasta: {0}. NOTE: These contigs will be skipped in calculation and output."
            .format(chrom_not_in_fasta))

    bam_references = [ref for ref in bam_references if ref in fasta_chrom_info]
    chrom_in_common = [ref for ref in chrom_in_common if ref in bam_references]

    #----------------------------------------------------------------------------------------------------#
    # Read regions from bedfiles
    #----------------------------------------------------------------------------------------------------#

    logger.info("Processing input/output regions")

    #Chromosomes included in analysis
    genome_regions = RegionList().from_list([
        OneRegion([chrom, 0, bam_chrom_info[chrom]])
        for chrom in bam_references if "M" not in chrom
    ])  #full genome length (mitochondrial "M" contigs excluded)
    chrom_in_common = [chrom for chrom in chrom_in_common if "M" not in chrom]
    logger.debug("CHROMS\t{0}".format("; ".join(
        ["{0} ({1})".format(reg.chrom, reg.end) for reg in genome_regions])))
    genome_bp = sum([region.get_length() for region in genome_regions])

    # Process peaks
    peak_regions = RegionList().from_bed(args.peaks)
    peak_regions.merge()
    for i in range(len(peak_regions) - 1, -1, -1):
        region = peak_regions[i]

        peak_regions[i] = region.check_boundary(
            bam_chrom_info, "cut")  #regions are cut/removed from list
        if peak_regions[i] is None:
            logger.warning(
                "Peak region {0} was removed as it is either out of bounds or not in the chromosomes given in genome/bam."
                .format(region.tup()))
            del peak_regions[i]

    nonpeak_regions = deepcopy(genome_regions).subtract(peak_regions)

    # Process specific input regions if given
    if args.regions_in is not None:
        input_regions = RegionList().from_bed(args.regions_in)
        input_regions.merge()
        input_regions.apply_method(OneRegion.check_boundary, bam_chrom_info,
                                   "cut")
    else:
        input_regions = nonpeak_regions

    # Process specific output regions
    if args.regions_out is not None:
        output_regions = RegionList().from_bed(args.regions_out)
    else:
        output_regions = deepcopy(peak_regions)

    #Extend regions to make sure extend + flanking for window/flank are within boundaries
    flank_extend = args.k_flank + int(args.window / 2.0)
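    #e.g. with k_flank=12 and window=100 (illustrative values), flank_extend = 12 + 50 = 62 bp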
    output_regions.apply_method(OneRegion.extend_reg,
                                args.extend + flank_extend)
    output_regions.merge()
    output_regions.apply_method(OneRegion.check_boundary, bam_chrom_info,
                                "cut")
    output_regions.apply_method(
        OneRegion.extend_reg, -flank_extend
    )  #Cut to needed size knowing that the region will be extended in function

    #Remove blacklisted regions and chromosomes not in common
    #Fill in with regions from args.blacklist (if given)
    blacklist_regions = RegionList().from_bed(args.blacklist) if args.blacklist is not None else RegionList([])
    regions_dict = {
        "genome": genome_regions,
        "input_regions": input_regions,
        "output_regions": output_regions,
        "peak_regions": peak_regions,
        "nonpeak_regions": nonpeak_regions,
        "blacklist_regions": blacklist_regions
    }
    for sub in [
            "input_regions", "output_regions", "peak_regions",
            "nonpeak_regions"
    ]:
        regions_sub = regions_dict[sub]
        regions_sub.subtract(blacklist_regions)
        regions_sub = regions_sub.apply_method(OneRegion.split_region, 50000)

        regions_sub.keep_chroms(chrom_in_common)
        regions_dict[sub] = regions_sub

    #write beds to look at in igv
    #input_regions.write_bed(os.path.join(args.outdir, "input_regions.bed"))
    #output_regions.write_bed(os.path.join(args.outdir, "output_regions.bed"))
    #peak_regions.write_bed(os.path.join(args.outdir, "peak_regions.bed"))
    #nonpeak_regions.write_bed(os.path.join(args.outdir, "nonpeak_regions.bed"))

    #Sort according to order in bam_references:
    output_regions.loc_sort(bam_references)
    chrom_order = {bam_references[i]: i for i in range(len(bam_references))}  #for use later when sorting output

    #### Statistics about regions ####
    genome_bp = sum([region.get_length() for region in regions_dict["genome"]])
    for key in regions_dict:
        total_bp = sum([region.get_length() for region in regions_dict[key]])
        logger.stats("{0}: {1} regions | {2} bp | {3:.2f}% coverage".format(
            key, len(regions_dict[key]), total_bp, total_bp / genome_bp * 100))

    #Establish variables for regions to be used
    input_regions = regions_dict["input_regions"]
    output_regions = regions_dict["output_regions"]
    peak_regions = regions_dict["peak_regions"]
    nonpeak_regions = regions_dict["nonpeak_regions"]

    #Exit if no input/output regions were found
    if len(input_regions) == 0 or len(output_regions) == 0 or len(
            peak_regions) == 0 or len(nonpeak_regions) == 0:
        logger.error("No regions found - exiting!")
        sys.exit()

    #----------------------------------------------------------------------------------------------------#
    # Estimate normalization factors
    #----------------------------------------------------------------------------------------------------#

    #Setup logger queue
    logger.debug("Setting up listener for log")
    logger.start_logger_queue()
    args.log_q = logger.queue

    #----------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("----- Estimating normalization factors -----")

    #If normalization is to be calculated
    if not args.norm_off:

        #Reads in peaks/nonpeaks
        logger.info("Counting reads in peak regions")
        peak_region_chunks = peak_regions.chunks(args.split)
        reads_peaks = sum(
            run_parallel(count_reads, peak_region_chunks, [args], args.cores,
                         logger))
        logger.comment("")

        logger.info("Counting reads in nonpeak regions")
        nonpeak_region_chunks = nonpeak_regions.chunks(args.split)
        reads_nonpeaks = sum(
            run_parallel(count_reads, nonpeak_region_chunks, [args],
                         args.cores, logger))

        reads_total = reads_peaks + reads_nonpeaks

        logger.stats("TOTAL_READS\t{0}".format(reads_total))
        logger.stats("PEAK_READS\t{0}".format(reads_peaks))
        logger.stats("NONPEAK_READS\t{0}".format(reads_nonpeaks))

        lib_norm = 10000000 / reads_total
        frip = reads_peaks / reads_total
        correct_factor = lib_norm * (1 / frip)
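        #e.g. (hypothetical counts) 20,000,000 total reads of which 8,000,000 are in peaks:
        #lib_norm = 1e7/2e7 = 0.5, FRiP = 0.4, correct_factor = 0.5/0.4 = 1.25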

        logger.stats("LIB_NORM\t{0:.5f}".format(lib_norm))
        logger.stats("FRiP\t{0:.5f}".format(frip))
    else:
        logger.info("Normalization was switched off")
        correct_factor = 1.0

    logger.stats("CORRECTION_FACTOR:\t{0:.5f}".format(correct_factor))

    #----------------------------------------------------------------------------------------------------#
    # Estimate sequence bias
    #----------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("Started estimation of sequence bias...")

    input_region_chunks = input_regions.chunks(args.split)  #split into args.split chunks (also decides the step of output)
    out_lst = run_parallel(bias_estimation, input_region_chunks, [args],
                           args.cores,
                           logger)  #Output is list of AtacBias objects

    #Join objects
    estimated_bias = out_lst[0]  #initialize object with first output
    for output in out_lst[1:]:
        estimated_bias.join(
            output
        )  #bias object contains bias/background SequenceMatrix objects

    logger.debug("Bias estimated\tno_reads: {0}".format(
        estimated_bias.no_reads))

    #----------------------------------------------------------------------------------------------------#
    # Join estimations from all chunks of regions
    #----------------------------------------------------------------------------------------------------#

    bias_obj = estimated_bias
    bias_obj.correction_factor = correct_factor

    ### Bias motif ###
    logger.info("Finalizing bias motif for scoring")
    for strand in strands:
        bias_obj.bias[strand].prepare_mat()

        logger.debug("Saving pssm to figure pdf")
        fig = plot_pssm(bias_obj.bias[strand].pssm,
                        "Tn5 insertion bias of reads ({0})".format(strand))
        figure_pdf.savefig(fig)

    #Write bias motif to pickle
    out_f = os.path.join(args.outdir, args.prefix + "_AtacBias.pickle")
    logger.debug("Saving bias object to pickle ({0})".format(out_f))
    bias_obj.to_pickle(out_f)

    #----------------------------------------------------------------------------------------------------#
    # Correct read bias and write to bigwig
    #----------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("----- Correcting reads from .bam within output regions -----")

    output_regions.loc_sort(bam_references)  #sort in order of references
    output_regions_chunks = output_regions.chunks(args.split)
    no_tasks = float(len(output_regions_chunks))
    chunk_sizes = [len(chunk) for chunk in output_regions_chunks]
    logger.debug("All regions chunked: {0} ({1})".format(
        len(output_regions), chunk_sizes))

    ### Create key-file linking for bigwigs
    key2file = {}
    for track in output_bws:
        for strand in output_bws[track]:
            filename = output_bws[track][strand]["fn"]
            key = "{}:{}".format(track, strand)
            key2file[key] = filename

    #Start correction/write cores
    n_bigwig = len(key2file.values())
    writer_cores = min(n_bigwig, max(1, int(args.cores * 0.1)))  #at most one core per bigwig or 10% of cores (min. 1)
    worker_cores = max(1, args.cores - writer_cores)
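    #e.g. args.cores=12 and 4 output bigwigs: writer_cores = min(4, max(1, 1)) = 1, worker_cores = 11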
    logger.debug("Worker cores: {0}".format(worker_cores))
    logger.debug("Writer cores: {0}".format(writer_cores))

    worker_pool = mp.Pool(processes=worker_cores)
    writer_pool = mp.Pool(processes=writer_cores)
    manager = mp.Manager()

    #Start bigwig file writers
    writer_tasks = []
    header = [(chrom, bam_chrom_info[chrom]) for chrom in bam_references]
    key_chunks = [
        list(key2file.keys())[i::writer_cores] for i in range(writer_cores)
    ]
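    #round-robin split of bigwig keys across writers, e.g. keys [a, b, c, d] with 2 writer cores -> [[a, c], [b, d]]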
    qs_list = []
    qs = {}
    for chunk in key_chunks:
        logger.debug("Creating writer queue for {0}".format(chunk))

        q = manager.Queue()
        qs_list.append(q)

        files = [key2file[key] for key in chunk]
        writer_tasks.append(
            writer_pool.apply_async(bigwig_writer,
                                    args=(q, dict(zip(chunk, files)), header,
                                          output_regions, args))
        )  #, callback = lambda x: finished.append(x) print("Writing time: {0}".format(x)))
        for key in chunk:
            qs[key] = q

    args.qs = qs
    writer_pool.close()  #no more jobs applied to writer_pool

    #Start correction
    logger.debug("Starting correction")
    task_list = [
        worker_pool.apply_async(bias_correction, args=[chunk, args, bias_obj])
        for chunk in output_regions_chunks
    ]
    worker_pool.close()
    monitor_progress(task_list, logger, "Correction progress:"
                     )  #does not exit until tasks in task_list finished
    results = [task.get() for task in task_list]

    #Get all results
    pre_bias = results[0][0]  #initialize with first result
    post_bias = results[0][1]  #initialize with first result
    for result in results[1:]:
        pre_bias_chunk = result[0]
        post_bias_chunk = result[1]

        for direction in strands:
            pre_bias[direction].add_counts(pre_bias_chunk[direction])
            post_bias[direction].add_counts(post_bias_chunk[direction])

    #Stop all queues for writing
    logger.debug("Stop all queues by inserting None")
    for q in qs_list:
        q.put((None, None, None))

    #Fetch error codes from bigwig writers
    logger.debug("Fetching possible errors from bigwig_writer tasks")
    results = [task.get()
               for task in writer_tasks]  #blocks until writers are finished

    logger.debug("Joining bigwig_writer queues")

    qsum = sum([q.qsize() for q in qs_list])
    while qsum != 0:
        qsum = sum([q.qsize() for q in qs_list])
        logger.spam("- Queue sizes {0}".format([(key, qs[key].qsize())
                                                for key in qs]))
        time.sleep(0.5)

    #Waits until all queues are closed
    writer_pool.join()
    worker_pool.terminate()
    worker_pool.join()

    #Stop multiprocessing logger
    logger.stop_logger_queue()

    #----------------------------------------------------------------------------------------------------#
    # Information and verification of corrected read frequencies
    #----------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("Verifying bias correction")

    #Calculating variance per base
    for strand in strands:

        #Invert negative counts
        abssum = np.abs(np.sum(post_bias[strand].neg_counts, axis=0))
        post_bias[strand].neg_counts = post_bias[strand].neg_counts + abssum

        #Join negative/positive counts
        post_bias[strand].counts += post_bias[strand].neg_counts  #now pos

        pre_bias[strand].prepare_mat()
        post_bias[strand].prepare_mat()

        pre_var = np.mean(np.var(pre_bias[strand].bias_pwm,
                                 axis=1)[:4])  #mean of variance per nucleotide
        post_var = np.mean(np.var(post_bias[strand].bias_pwm, axis=1)[:4])
        logger.stats("BIAS\tpre-bias variance {0}:\t{1:.7f}".format(
            strand, pre_var))
        logger.stats("BIAS\tpost-bias variance {0}:\t{1:.7f}".format(
            strand, post_var))

        #Plot figure
        fig_title = "Nucleotide frequencies in corrected reads\n({0} strand)".format(
            strand)
        figure_pdf.savefig(
            plot_correction(pre_bias[strand].bias_pwm,
                            post_bias[strand].bias_pwm, fig_title))

    #----------------------------------------------------------------------------------------------------#
    # Finish up
    #----------------------------------------------------------------------------------------------------#

    plt.close('all')
    figure_pdf.close()
    logger.end()
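
The correction step above relies on a writer-queue pattern: worker processes push results onto Manager queues while dedicated writer processes drain each queue until a None sentinel arrives. A minimal, generic sketch of that pattern (file names, payloads and helper names are invented for illustration and are not TOBIAS's API):

import multiprocessing as mp

def writer(q, out_path):
    """Drain a queue of (key, text) items into one file until the None sentinel arrives."""
    with open(out_path, "w") as fh:
        while True:
            item = q.get()
            if item is None:  #sentinel: no more data
                break
            key, text = item
            fh.write("{0}\t{1}\n".format(key, text))

def worker(q, chunk):
    """Process one chunk of work and push results to the writer queue."""
    for element in chunk:
        q.put((element, "processed"))
    return len(chunk)

if __name__ == "__main__":
    manager = mp.Manager()
    q = manager.Queue()

    writer_pool = mp.Pool(1)
    writer_task = writer_pool.apply_async(writer, args=(q, "output.txt"))
    writer_pool.close()  #no more jobs applied to writer_pool

    worker_pool = mp.Pool(3)
    chunks = [list(range(i, i + 10)) for i in range(0, 30, 10)]
    tasks = [worker_pool.apply_async(worker, args=(q, chunk)) for chunk in chunks]
    worker_pool.close()
    worker_pool.join()  #wait until all workers are done

    q.put(None)          #stop the writer
    writer_task.get()    #re-raises any exception from the writer
    writer_pool.join()
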
Example No. 3
def run_bindetect(args):
    """ Main function to run bindetect algorithm with input files and parameters given in args """

    #Checking input and setting cond_names
    check_required(args, ["signals", "motifs", "genome", "peaks"])
    args.cond_names = [
        os.path.basename(os.path.splitext(bw)[0]) for bw in args.signals
    ] if args.cond_names is None else args.cond_names
    args.outdir = os.path.abspath(args.outdir)

    #Set output files
    states = ["bound", "unbound"]
    outfiles = [
        os.path.abspath(
            os.path.join(args.outdir, "*", "beds",
                         "*_{0}_{1}.bed".format(condition, state)))
        for (condition, state) in itertools.product(args.cond_names, states)
    ]
    outfiles.append(
        os.path.abspath(os.path.join(args.outdir, "*", "beds", "*_all.bed")))
    outfiles.append(
        os.path.abspath(
            os.path.join(args.outdir, "*", "plots", "*_log2fcs.pdf")))
    outfiles.append(
        os.path.abspath(os.path.join(args.outdir, "*", "*_overview.txt")))
    outfiles.append(
        os.path.abspath(os.path.join(args.outdir, "*", "*_overview.xlsx")))

    outfiles.append(
        os.path.abspath(
            os.path.join(args.outdir, args.prefix + "_distances.txt")))
    outfiles.append(
        os.path.abspath(os.path.join(args.outdir,
                                     args.prefix + "_results.txt")))
    outfiles.append(
        os.path.abspath(
            os.path.join(args.outdir, args.prefix + "_results.xlsx")))
    outfiles.append(
        os.path.abspath(os.path.join(args.outdir,
                                     args.prefix + "_figures.pdf")))

    #-------------------------------------------------------------------------------------------------------------#
    #-------------------------------------------- Setup logger and pool ------------------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    logger = TobiasLogger("BINDetect", args.verbosity)
    logger.begin()

    parser = add_bindetect_arguments(argparse.ArgumentParser())
    logger.arguments_overview(parser, args)
    logger.output_files(outfiles)

    # Setup pool
    args.cores = check_cores(args.cores, logger)
    writer_cores = max(1, int(args.cores * 0.1))
    worker_cores = max(1, args.cores - writer_cores)
    logger.debug("Worker cores: {0}".format(worker_cores))
    logger.debug("Writer cores: {0}".format(writer_cores))

    pool = mp.Pool(processes=worker_cores)
    writer_pool = mp.Pool(processes=writer_cores)

    #-------------------------------------------------------------------------------------------------------------#
    #-------------------------- Pre-processing data: Reading motifs, sequences, peaks ----------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    logger.info("----- Processing input data -----")

    #Check opening/writing of files
    logger.info("Checking reading/writing of files")
    check_files([args.signals, args.motifs, args.genome, args.peaks],
                action="r")
    check_files(outfiles[-3:], action="w")
    make_directory(args.outdir)

    #Comparisons between conditions
    no_conditions = len(args.signals)
    if args.time_series:
        comparisons = list(zip(args.cond_names[:-1], args.cond_names[1:]))
        args.comparisons = comparisons
    else:
        comparisons = list(itertools.combinations(args.cond_names,
                                                  2))  #all-against-all
        args.comparisons = comparisons
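    #e.g. conditions [A, B, C]: --time_series gives (A,B), (B,C); otherwise all-against-all gives (A,B), (A,C), (B,C)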

    #Open figure pdf and write overview
    fig_out = os.path.abspath(
        os.path.join(args.outdir, args.prefix + "_figures.pdf"))
    figure_pdf = PdfPages(fig_out, keep_empty=True)

    plt.figure()
    plt.axis('off')
    plt.text(0.5,
             0.8,
             "BINDETECT FIGURES",
             ha="center",
             va="center",
             fontsize=20)

    #output and order
    titles = []
    titles.append("Raw score distributions")
    titles.append("Normalized score distributions")
    if args.debug:
        for (cond1, cond2) in comparisons:
            titles.append("Background log2FCs ({0} / {1})".format(
                cond1, cond2))

    for (cond1, cond2) in comparisons:
        titles.append("BINDetect plot ({0} / {1})".format(cond1, cond2))

    plt.text(0.1,
             0.6,
             "\n".join([
                 "Page {0}) {1}".format(i + 2, titles[i])
                 for i in range(len(titles))
             ]) + "\n\n",
             va="top")
    figure_pdf.savefig(bbox_inches='tight')
    plt.close()

    ################# Read peaks ################
    #Read peak and peak_header
    logger.info("Reading peaks")
    peaks = RegionList().from_bed(args.peaks)
    logger.info("- Found {0} regions in input peaks".format(len(peaks)))
    peaks = peaks.merge()  #merge overlapping peaks
    logger.info("- Merged to {0} regions".format(len(peaks)))

    if len(peaks) == 0:
        logger.error("Input --peaks file is empty!")
        sys.exit()

    #Read header and check match with number of peak columns
    peak_columns = len(peaks[0])  #number of columns
    if args.peak_header is not None:
        with open(args.peak_header, "r") as header_file:
            content = header_file.read()
        args.peak_header_list = content.split()
        logger.debug("Peak header: {0}".format(args.peak_header_list))

        #Check whether peak header fits with number of peak columns
        if len(args.peak_header_list) != peak_columns:
            logger.error(
                "Length of --peak_header ({0}) does not fit number of columns in --peaks ({1})."
                .format(len(args.peak_header_list), peak_columns))
            sys.exit()
    else:
        args.peak_header_list = ["peak_chr", "peak_start", "peak_end"] + [
            "additional_" + str(num + 1) for num in range(peak_columns - 3)
        ]
    logger.debug("Peak header list: {0}".format(args.peak_header_list))

    ################# Check for match between peaks and fasta/bigwig #################
    logger.info(
        "Checking for match between --peaks and --genome/--signals boundaries")
    logger.info("- Comparing peaks to {0}".format(args.genome))
    fasta_obj = pysam.FastaFile(args.genome)
    fasta_boundaries = dict(zip(fasta_obj.references, fasta_obj.lengths))
    fasta_obj.close()
    logger.debug("Fasta boundaries: {0}".format(fasta_boundaries))
    peaks = peaks.apply_method(OneRegion.check_boundary, fasta_boundaries,
                               "exit")  #will exit if peaks are outside borders

    #Check boundaries of each bigwig signal individually
    for signal in args.signals:
        logger.info("- Comparing peaks to {0}".format(signal))
        pybw_obj = pybw.open(signal)
        pybw_header = pybw_obj.chroms()
        pybw_obj.close()
        logger.debug("Signal boundaries: {0}".format(pybw_header))
        peaks = peaks.apply_method(OneRegion.check_boundary, pybw_header,
                                   "exit")

    ##### GC content for motif scanning ######
    #Make chunks of regions for multiprocessing
    logger.info("Estimating GC content from peak sequences")
    peak_chunks = peaks.chunks(args.split)
    gc_content_pool = pool.starmap(
        get_gc_content, itertools.product(peak_chunks, [args.genome]))
    gc_content = np.mean(gc_content_pool)  #fraction
    args.gc = gc_content
    bg = np.array([(1 - args.gc) / 2.0, args.gc / 2.0, args.gc / 2.0,
                   (1 - args.gc) / 2.0])
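    #background nucleotide frequencies (presumably in A, C, G, T order); e.g. gc=0.4 gives [0.3, 0.2, 0.2, 0.3]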
    logger.info("- GC content estimated at {0:.2f}%".format(gc_content * 100))

    ################ Get motifs ################
    logger.info("Reading motifs from file")
    motif_list = MotifList()
    args.motifs = expand_dirs(args.motifs)
    for f in args.motifs:
        motif_list += MotifList().from_file(f)  #List of OneMotif objects
    no_pfms = len(motif_list)
    logger.info("- Read {0} motifs".format(no_pfms))

    logger.debug("Getting motifs ready")
    motif_list.bg = bg

    logger.debug("Getting reverse motifs")
    motif_list.extend([motif.get_reverse() for motif in motif_list])
    logger.spam(motif_list)

    #Set prefixes
    for motif in motif_list:  #now with reverse motifs as well
        motif.set_prefix(args.naming)
        motif.bg = bg

        logger.spam("Getting pssm for motif {0}".format(motif.name))
        motif.get_pssm()

    motif_names = list(set([motif.prefix for motif in motif_list]))

    #Get threshold for motifs
    logger.debug("Getting match threshold per motif")
    outlist = pool.starmap(OneMotif.get_threshold,
                           itertools.product(motif_list, [args.motif_pvalue]))

    motif_list = MotifList(outlist)
    for motif in motif_list:
        logger.debug("Motif {0}: threshold {1}".format(motif.name,
                                                       motif.threshold))

    logger.info("Creating folder structure for each TF")
    for TF in motif_names:
        logger.spam("Creating directories for {0}".format(TF))
        make_directory(os.path.join(args.outdir, TF))
        make_directory(os.path.join(args.outdir, TF, "beds"))
        make_directory(os.path.join(args.outdir, TF, "plots"))

    #-------------------------------------------------------------------------------------------------------------#
    #----------------------------------------- Plot logos for all motifs -----------------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    plus_motifs = [motif for motif in motif_list if motif.strand == "+"]
    logo_filenames = {
        motif.prefix: os.path.join(args.outdir, motif.prefix,
                                   motif.prefix + ".png")
        for motif in plus_motifs
    }

    logger.info("Plotting sequence logos for each motif")
    task_list = [
        pool.apply_async(OneMotif.logo_to_file, (
            motif,
            logo_filenames[motif.prefix],
        )) for motif in plus_motifs
    ]
    monitor_progress(task_list, logger)
    results = [task.get() for task in task_list]
    logger.comment("")

    logger.debug("Getting base64 strings per motif")
    for motif in motif_list:
        if motif.strand == "+":
            #motif.get_base()
            with open(logo_filenames[motif.prefix], "rb") as png:
                motif.base = base64.b64encode(png.read()).decode("utf-8")

    #-------------------------------------------------------------------------------------------------------------#
    #--------------------- Motif scanning: Find binding sites and match to footprint scores ----------------------#
    #-------------------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.start_logger_queue(
    )  #start process for listening and handling through the main logger queue
    args.log_q = logger.queue  #queue for multiprocessing logging
    manager = mp.Manager()
    logger.info("Scanning for motifs and matching to signals...")

    #Create writer queues for bed-file output
    logger.debug("Setting up writer queues")
    qs_list = []
    writer_qs = {}

    #writer_queue = create_writer_queue(key2file, writer_cores)
    #writer_queue.stop()	#wait until all are done

    manager = mp.Manager()
    TF_names_chunks = [
        motif_names[i::writer_cores] for i in range(writer_cores)
    ]
    for TF_names_sub in TF_names_chunks:
        logger.debug("Creating writer queue for {0}".format(TF_names_sub))
        files = [
            os.path.join(args.outdir, TF, "beds", TF + ".tmp")
            for TF in TF_names_sub
        ]

        q = manager.Queue()
        qs_list.append(q)

        writer_pool.apply_async(
            file_writer, args=(q, dict(zip(TF_names_sub, files)), args)
        )  #, callback = lambda x: finished.append(x) print("Writing time: {0}".format(x)))
        for TF in TF_names_sub:
            writer_qs[TF] = q
    writer_pool.close()  #no more jobs applied to writer_pool

    #todo: use run_parallel
    #Start working on data
    if worker_cores == 1:
        logger.debug("Running with cores = 1")
        results = []
        for chunk in peak_chunks:
            results.append(
                scan_and_score(chunk, motif_list, args, args.log_q, writer_qs))

    else:
        logger.debug("Sending jobs to worker pool")

        task_list = [
            pool.apply_async(scan_and_score, (
                chunk,
                motif_list,
                args,
                args.log_q,
                writer_qs,
            )) for chunk in peak_chunks
        ]
        monitor_progress(task_list, logger)
        results = [task.get() for task in task_list]

    logger.info("Done scanning for TFBS across regions!")
    #logger.stop_logger_queue()	#stop the listening process (wait until all was written)

    #--------------------------------------#
    logger.info("Waiting for bedfiles to write")

    #Stop all queues for writing
    logger.debug("Stop all queues by inserting None")
    for q in qs_list:
        q.put((None, None))

    logger.debug("Joining bed_writer queues")
    for i, q in enumerate(qs_list):
        logger.debug("- Queue {0} (size {1})".format(i, q.qsize()))

    #Waits until all queues are closed
    writer_pool.join()

    #-------------------------------------------------------------------------------------------------------------#
    #---------------------------- Process information on background scores and overlaps --------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    logger.info("Merging results from subsets")
    background = merge_dicts([result[0] for result in results])
    TF_overlaps = merge_dicts([result[1] for result in results])
    results = None

    #Add missing TF overlaps (if some TFs had zero sites)
    for TF1 in plus_motifs:
        if TF1.prefix not in TF_overlaps:
            TF_overlaps[TF1.prefix] = 0
        for TF2 in plus_motifs:
            tup = (TF1.prefix, TF2.prefix)
            if tup not in TF_overlaps:
                TF_overlaps[tup] = 0

    #Collect sampled background values
    for bigwig in args.cond_names:
        background["signal"][bigwig] = np.array(background["signal"][bigwig])

    #Check how many values were fetched from background
    n_bg_values = len(background["signal"][args.cond_names[0]])
    logger.debug("Collected {0} values from background".format(n_bg_values))
    if n_bg_values < 1000:
        err_str = "Number of background values collected from peaks is low (={0}) ".format(
            n_bg_values)
        err_str += "- this affects estimation of the bound/unbound threshold and the normalization between conditions. "
        err_str += "To improve this estimation, please run BINDetect with --peaks = the full peak set across all conditions."
        logger.warning(err_str)

    logger.comment("")
    logger.info("Estimating score distribution per condition")

    fig = plot_score_distribution(
        [background["signal"][bigwig] for bigwig in args.cond_names],
        labels=args.cond_names,
        title="Raw scores per condition")
    figure_pdf.savefig(fig, bbox_inches='tight')
    plt.close()

    logger.info("Normalizing scores")
    list_of_vals = [background["signal"][bigwig] for bigwig in args.cond_names]
    normed, norm_objects = quantile_normalization(list_of_vals)

    args.norm_objects = dict(zip(args.cond_names, norm_objects))
    for bigwig in args.cond_names:
        background["signal"][bigwig] = args.norm_objects[bigwig].normalize(
            background["signal"][bigwig])

    fig = plot_score_distribution(
        [background["signal"][bigwig] for bigwig in args.cond_names],
        labels=args.cond_names,
        title="Normalized scores per condition")
    figure_pdf.savefig(fig, bbox_inches='tight')
    plt.close()

    ###########################################################
    logger.info("Estimating bound/unbound threshold")

    #Prepare scores (remove 0's etc.)
    bg_values = np.array(normed).flatten()
    bg_values = bg_values[np.logical_not(np.isclose(
        bg_values, 0.0))]  #only non-zero counts
    x_max = np.percentile(bg_values, [99])
    bg_values = bg_values[bg_values < x_max]

    if len(bg_values) == 0:
        logger.error(
            "Error processing bigwig scores from background. It could be that there are no scores in the bigwig (=0) assigned for the peaks. Please check your input files."
        )
        sys.exit()

    #Fit mixture of normals
    lowest_bic = np.inf
    for n_components in [2]:  #2 components
        gmm = sklearn.mixture.GaussianMixture(n_components=n_components,
                                              random_state=1)
        gmm.fit(np.log(bg_values).reshape(-1, 1))

        bic = gmm.bic(np.log(bg_values).reshape(-1, 1))
        logger.debug("n_compontents: {0} | bic: {1}".format(n_components, bic))
        if bic < lowest_bic:
            lowest_bic = bic
            best_gmm = gmm
    gmm = best_gmm

    #Extract the right-most gaussian
    means = gmm.means_.flatten()
    sds = np.sqrt(gmm.covariances_).flatten()
    chosen_i = np.argmax(means)  #Mixture with largest mean

    log_params = scipy.stats.lognorm.fit(bg_values[bg_values < x_max],
                                         f0=sds[chosen_i],
                                         fscale=np.exp(means[chosen_i]))
    #all_log_params[bigwig] = log_params

    #Mode of distribution
    mode = scipy.optimize.fmin(
        lambda x: -scipy.stats.lognorm.pdf(x, *log_params), 0, disp=False)[0]
    logger.debug("- Mode estimated at: {0}".format(mode))
    pseudo = mode / 2.0  #pseudo is half the mode
    args.pseudo = pseudo
    logger.debug("Pseudocount estimated at: {0}".format(round(args.pseudo, 5)))

    # Estimate theoretical normal for threshold
    leftside_x = np.linspace(
        scipy.stats.lognorm(*log_params).ppf([0.01]), mode, 100)
    leftside_pdf = scipy.stats.lognorm.pdf(leftside_x, *log_params)

    #Flip over
    mirrored_x = np.concatenate([leftside_x,
                                 np.max(leftside_x) + leftside_x]).flatten()
    mirrored_pdf = np.concatenate([leftside_pdf, leftside_pdf[::-1]]).flatten()
    popt, cov = scipy.optimize.curve_fit(
        lambda x, std, sc: sc * scipy.stats.norm.pdf(x, mode, std), mirrored_x,
        mirrored_pdf)
    norm_params = (mode, popt[0])
    logger.debug("Theoretical normal parameters: {0}".format(norm_params))

    #Set threshold for bound/unbound
    threshold = round(
        scipy.stats.norm.ppf(1 - args.bound_pvalue, *norm_params), 5)
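    #e.g. with --bound_pvalue 0.001 the threshold is the 99.9th percentile of the fitted normal; scores above it are called bound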

    args.thresholds = {bigwig: threshold for bigwig in args.cond_names}
    logger.stats("- Threshold estimated at: {0}".format(threshold))

    #Only plot if args.debug is True
    if args.debug:

        #Plot fit
        fig, ax = plt.subplots(1, 1)
        ax.hist(bg_values[bg_values < x_max],
                bins='auto',
                density=True,
                label="Observed score distribution")

        xvals = np.linspace(0, x_max, 1000)
        log_probas = scipy.stats.lognorm.pdf(xvals, *log_params)
        ax.plot(xvals, log_probas, label="Log-normal fit", color="orange")

        #Theoretical normal
        norm_probas = scipy.stats.norm.pdf(xvals, *norm_params)
        ax.plot(xvals,
                norm_probas * (np.max(log_probas) / np.max(norm_probas)),
                color="grey",
                linestyle="--",
                label="Theoretical normal")

        ax.axvline(threshold, color="black", label="Bound/unbound threshold")
        ymax = plt.ylim()[1]
        ax.text(threshold, ymax, "\n {0:.3f}".format(threshold), va="top")

        #Decorate plot
        plt.title("Score distribution")
        plt.xlabel("Bigwig score")
        plt.ylabel("Density")
        plt.legend(fontsize=8)
        plt.xlim((0, x_max))

        figure_pdf.savefig(fig)
        plt.close(fig)

    ############ Foldchanges between conditions ################
    logger.comment("")
    log2fc_params = {}
    if len(args.signals) > 1:
        logger.info(
            "Calculating background log2 fold-changes between conditions")

        for (bigwig1, bigwig2) in comparisons:  #cond1, cond2
            logger.info("- {0} / {1}".format(bigwig1, bigwig2))

            #Estimate background log2fc
            scores1 = np.copy(background["signal"][bigwig1])
            scores2 = np.copy(background["signal"][bigwig2])

            included = np.logical_or(scores1 > 0, scores2 > 0)
            scores1 = scores1[included]
            scores2 = scores2[included]

            #Calculate background log2fc normal distribution
            log2fcs = np.log2(
                np.true_divide(scores1 + args.pseudo, scores2 + args.pseudo))

            lower, upper = np.percentile(log2fcs, [1, 99])
            log2fcs_fit = log2fcs[np.logical_and(log2fcs >= lower,
                                                 log2fcs <= upper)]

            norm_params = scipy.stats.norm.fit(log2fcs_fit)

            logger.debug(
                "({0} / {1}) Background log2fc normal distribution: {2}".
                format(bigwig1, bigwig2, norm_params))
            log2fc_params[(bigwig1, bigwig2)] = norm_params

            #Plot background log2fc to figures
            fig, ax = plt.subplots(1, 1)
            plt.hist(log2fcs,
                     density=True,
                     bins='auto',
                     label="Background log2fc ({0} / {1})".format(
                         bigwig1, bigwig2))

            xvals = np.linspace(plt.xlim()[0], plt.xlim()[1], 100)
            pdf = scipy.stats.norm.pdf(xvals,
                                       *log2fc_params[(bigwig1, bigwig2)])
            plt.plot(xvals, pdf, label="Normal distribution fit")
            plt.title("Background log2FCs ({0} / {1})".format(
                bigwig1, bigwig2))

            plt.xlabel("Log2 fold change")
            plt.ylabel("Density")
            if args.debug:
                figure_pdf.savefig(fig, bbox_inches='tight')

                f = open(
                    os.path.join(
                        args.outdir,
                        "{0}_{1}_log2fcs.txt".format(bigwig1, bigwig2)), "w")
                f.write("\n".join([str(val) for val in log2fcs]))
                f.close()
            plt.close()

    background = None  #free up space

    #-------------------------------------------------------------------------------------------------------------#
    #----------------------------- Read total sites per TF to estimate bound/unbound -----------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("Processing scanned TFBS individually")

    #Getting bindetect table ready
    info_columns = ["total_tfbs"]
    info_columns.extend([
        "{0}_{1}".format(cond, metric)
        for (cond, metric
             ) in itertools.product(args.cond_names, ["threshold", "bound"])
    ])
    info_columns.extend([
        "{0}_{1}_{2}".format(comparison[0], comparison[1], metric)
        for (comparison,
             metric) in itertools.product(comparisons, ["change", "pvalue"])
    ])

    cols = len(info_columns)
    rows = len(motif_names)
    info_table = pd.DataFrame(np.zeros((rows, cols)),
                              columns=info_columns,
                              index=motif_names)

    #Starting calculations
    results = []
    if args.cores == 1:
        for name in motif_names:
            logger.info("- {0}".format(name))
            results.append(process_tfbs(name, args, log2fc_params))
    else:
        logger.debug("Sending jobs to worker pool")

        task_list = [
            pool.apply_async(process_tfbs, (
                name,
                args,
                log2fc_params,
            )) for name in motif_names
        ]
        monitor_progress(task_list,
                         logger)  #will not exit before all jobs are done
        results = [task.get() for task in task_list]

    logger.info("Concatenating results from subsets")
    info_table = pd.concat(results)  #pandas tables

    pool.terminate()
    pool.join()

    logger.stop_logger_queue()

    #-------------------------------------------------------------------------------------------------------------#
    #------------------------------------------------ Cluster TFBS -----------------------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    clustering = RegionCluster(TF_overlaps)
    clustering.cluster()

    #Convert full ids to alt ids
    convert = {motif.prefix: motif.name for motif in motif_list}
    for cluster in clustering.clusters:
        for name in convert:
            clustering.clusters[cluster]["cluster_name"] = clustering.clusters[
                cluster]["cluster_name"].replace(name, convert[name])

    #Write out distance matrix
    matrix_out = os.path.join(args.outdir, args.prefix + "_distances.txt")
    clustering.write_distance_mat(matrix_out)

    #-------------------------------------------------------------------------------------------------------------#
    #----------------------------------------- Write all_bindetect file ------------------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("Writing all_bindetect files")

    #Add columns of name / motif_id / prefix
    names = []
    ids = []
    for prefix in info_table.index:
        motif = [motif for motif in motif_list if motif.prefix == prefix]
        names.append(motif[0].name)
        ids.append(motif[0].id)

    info_table.insert(0, "output_prefix", info_table.index)
    info_table.insert(1, "name", names)
    info_table.insert(2, "motif_id", ids)

    #info_table.insert(3, "motif_logo", [os.path.join("motif_logos", os.path.basename(logo_filenames[prefix])) for prefix in info_table["output_prefix"]])	#add relative path to logo

    #Add cluster to info_table
    cluster_names = []
    for name in info_table.index:
        for cluster in clustering.clusters:
            if name in clustering.clusters[cluster]["member_names"]:
                cluster_names.append(
                    clustering.clusters[cluster]["cluster_name"])

    info_table.insert(3, "cluster", cluster_names)

    #Cluster table on motif clusters
    info_table_clustered = info_table.groupby(
        "cluster").mean()  #mean of each column
    info_table_clustered.reset_index(inplace=True)

    #Map correct type
    info_table["total_tfbs"] = info_table["total_tfbs"].map(int)
    for condition in args.cond_names:
        info_table[condition + "_bound"] = info_table[condition +
                                                      "_bound"].map(int)

    #### Write excel ###
    bindetect_excel = os.path.join(args.outdir, args.prefix + "_results.xlsx")
    writer = pd.ExcelWriter(bindetect_excel, engine='xlsxwriter')

    #Tables
    info_table.to_excel(writer, index=False, sheet_name="Individual motifs")
    info_table_clustered.to_excel(writer,
                                  index=False,
                                  sheet_name="Motif clusters")

    for sheet in writer.sheets:
        worksheet = writer.sheets[sheet]
        n_rows = worksheet.dim_rowmax
        n_cols = worksheet.dim_colmax
        worksheet.autofilter(0, 0, n_rows, n_cols)
    writer.save()

    #Format comparisons
    for (cond1, cond2) in comparisons:
        base = cond1 + "_" + cond2
        info_table[base + "_change"] = info_table[base + "_change"].round(5)
        info_table[base + "_pvalue"] = info_table[base + "_pvalue"].map(
            "{:.5E}".format, na_action="ignore")

    #Write bindetect results tables
    #info_table.insert(0, "TF_name", info_table.index)	 #Set index as first column
    bindetect_out = os.path.join(args.outdir, args.prefix + "_results.txt")
    info_table.to_csv(bindetect_out,
                      sep="\t",
                      index=False,
                      header=True,
                      na_rep="NA")

    #-------------------------------------------------------------------------------------------------------------#
    #------------------------------------------- Make BINDetect plot ---------------------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    if no_conditions > 1:
        logger.info("Creating BINDetect plot(s)")

        #Fill NAs from info_table to enable plotting of log2fcs (NA -> 0 change)
        change_cols = [col for col in info_table.columns if "_change" in col]
        pvalue_cols = [col for col in info_table.columns if "_pvalue" in col]
        info_table[change_cols] = info_table[change_cols].fillna(0)
        info_table[pvalue_cols] = info_table[pvalue_cols].fillna(1)

        #Plotting bindetect per comparison
        for (cond1, cond2) in comparisons:

            logger.info("- {0} / {1}".format(cond1, cond2))
            base = cond1 + "_" + cond2

            #Define which motifs to show
            xvalues = info_table[base + "_change"].astype(float)
            yvalues = info_table[base + "_pvalue"].astype(float)
            y_min = np.percentile(yvalues[yvalues > 0],
                                  5)  #5% smallest pvalues
            x_min, x_max = np.percentile(
                xvalues, [5, 95])  #5% smallest and largest changes

            #Make copy of motifs and fill in with metadata
            comparison_motifs = [
                motif for motif in motif_list if motif.strand == "+"
            ]  #copy.deepcopy(motif_list) - swig pickle error, just overwrite motif_list
            for motif in comparison_motifs:
                name = motif.prefix
                motif.change = float(info_table.at[name, base + "_change"])
                motif.pvalue = float(info_table.at[name, base + "_pvalue"])
                motif.logpvalue = -np.log10(
                    motif.pvalue) if motif.pvalue > 0 else -np.log10(1e-308)

                #Assign each motif to group
                if motif.change < x_min or motif.change > x_max or motif.pvalue < y_min:
                    if motif.change < 0:
                        motif.group = cond2 + "_up"
                    if motif.change > 0:
                        motif.group = cond1 + "_up"
                else:
                    motif.group = "n.s."

            #Bindetect plot
            fig = plot_bindetect(comparison_motifs, clustering, [cond1, cond2],
                                 args)
            figure_pdf.savefig(fig, bbox_inches='tight')
            plt.close(fig)

            #Interactive BINDetect plot
            html_out = os.path.join(args.outdir, "bindetect_" + base + ".html")
            plot_interactive_bindetect(comparison_motifs, [cond1, cond2],
                                       html_out)

    #-------------------------------------------------------------------------------------------------------------#
    #----------------------------- Make heatmap across conditions (for debugging)---------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    if args.debug:

        mean_columns = [cond + "_mean_score" for cond in args.cond_names]
        heatmap_table = info_table[mean_columns]
        heatmap_table.index = info_table["output_prefix"]

        #Decide fig size
        rows, cols = heatmap_table.shape
        figsize = (7 + cols, max(10, rows / 8.0))
        cm = sns.clustermap(
            heatmap_table,
            figsize=figsize,
            z_score=0,  #zscore for rows
            col_cluster=False,  #do not cluster condition columns
            yticklabels=True,  #show all row annotations
            xticklabels=True,
            cbar_pos=(0, 0, .4, .005),
            dendrogram_ratio=(0.3, 0.01),
            cbar_kws={
                "orientation": "horizontal",
                'label': 'Row z-score'
            },
            method="single")

        #Adjust width of columns
        #hm = cm.ax_heatmap.get_position()
        #cm.ax_heatmap.set_position([hm.x0, hm.y0, cols * 3 * hm.height / rows, hm.height]) 	#aspect should be equal

        plt.setp(cm.ax_heatmap.get_xticklabels(),
                 fontsize=8,
                 rotation=45,
                 ha="right")
        plt.setp(cm.ax_heatmap.get_yticklabels(), fontsize=5)

        cm.ax_col_dendrogram.set_title('Mean scores across conditions',
                                       fontsize=20)
        cm.ax_heatmap.set_ylabel("Transcription factor motifs",
                                 fontsize=15,
                                 rotation=270)
        #cm.ax_heatmap.set_title('Conditions')
        #cm.fig.suptitle('Mean scores across conditions')
        #cm.cax.set_visible(False)

        #Save to output pdf
        plt.tight_layout()
        figure_pdf.savefig(cm.fig, bbox_inches='tight')
        plt.close(cm.fig)

    #-------------------------------------------------------------------------------------------------------------#
    #-------------------------------------------------- Wrap up---------------------------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    figure_pdf.close()
    logger.end()
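
The bound/unbound threshold above comes from fitting a two-component Gaussian mixture to log-transformed background scores and taking a high quantile of the right-most component. A stripped-down sketch of that idea on synthetic data (it skips the mirrored log-normal refit performed in the code above, so treat it as an approximation with invented numbers):

import numpy as np
import scipy.stats
import sklearn.mixture

rng = np.random.default_rng(1)
#synthetic footprint scores: a large "unbound" population plus a smaller high-score "bound" one
scores = np.concatenate([rng.lognormal(0.0, 0.4, 9000), rng.lognormal(1.5, 0.3, 1000)])

gmm = sklearn.mixture.GaussianMixture(n_components=2, random_state=1)
gmm.fit(np.log(scores).reshape(-1, 1))

means = gmm.means_.flatten()
sds = np.sqrt(gmm.covariances_).flatten()
chosen = np.argmax(means)  #right-most (highest-mean) component

#threshold at the (1 - bound_pvalue) quantile of the chosen component, mapped back to the score scale
bound_pvalue = 0.001
threshold = np.exp(scipy.stats.norm.ppf(1 - bound_pvalue, means[chosen], sds[chosen]))
print("bound/unbound threshold ~ {0:.3f}".format(threshold))
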
Example No. 4
def bias_correction(regions_list, params, bias_obj):
    """ Corrects bias in cutsites (from bamfile) using estimated bias """

    logger = TobiasLogger("", params.verbosity, params.log_q)

    bam_f = params.bam
    fasta_f = params.genome
    k_flank = params.k_flank
    read_shift = params.read_shift
    L = 2 * k_flank + 1
    w = params.window
    f = int(w / 2.0)
    qs = params.qs

    f_extend = k_flank + f

    strands = ["forward", "reverse"]
    pre_bias = {strand: SequenceMatrix.create(L, "PWM") for strand in strands}
    post_bias = {strand: SequenceMatrix.create(L, "PWM") for strand in strands}

    #Open bamfile and fasta
    bam_obj = pysam.AlignmentFile(bam_f, "rb")
    fasta_obj = pysam.FastaFile(fasta_f)

    out_signals = {}

    #Go through each region
    for region_obj in regions_list:

        region_obj.extend_reg(f_extend)
        reg_len = region_obj.get_length()  #length including flanking
        reg_key = (region_obj.chrom, region_obj.start + f_extend,
                   region_obj.end - f_extend)  #output region
        out_signals[reg_key] = {
            "uncorrected": {},
            "bias": {},
            "expected": {},
            "corrected": {}
        }

        ################################
        ####### Uncorrected reads ######
        ################################

        #Get cutsite positions for each read
        read_lst = ReadList().from_bam(bam_obj, region_obj)
        for read in read_lst:
            read.get_cutsite(read_shift)
        logger.spam("Read {0} reads from region {1}".format(
            len(read_lst), region_obj))

        #Exclude reads with cutsites outside region
        read_lst = ReadList([
            read for read in read_lst if read.cutsite > region_obj.start
            and read.cutsite < region_obj.end
        ])
        for_lst, rev_lst = read_lst.split_strands()
        read_lst_strand = {"forward": for_lst, "reverse": rev_lst}

        for strand in strands:
            signal = read_lst_strand[strand].signal(region_obj)
            out_signals[reg_key]["uncorrected"][strand] = np.round(signal, 5)

        ################################
        ###### Estimation of bias ######
        ################################

        #Get sequence in this region
        sequence_obj = GenomicSequence(region_obj).from_fasta(fasta_obj)

        #Score sequence using forward/reverse motifs
        for strand in strands:
            if strand == "forward":
                seq = sequence_obj.sequence
                bias = bias_obj.bias[strand].score_sequence(seq)
            elif strand == "reverse":
                seq = sequence_obj.revcomp
                bias = bias_obj.bias[strand].score_sequence(seq)[::-1]  #3'-5'

            out_signals[reg_key]["bias"][strand] = np.nan_to_num(
                bias)  #convert any nans to 0

        #################################
        ###### Correction of reads ######
        #################################

        reg_end = reg_len - k_flank
        step = 10
        overlaps = int(params.window / step)
        window_starts = list(range(k_flank, reg_end - params.window, step))
        window_ends = list(range(k_flank + params.window, reg_end, step))
        window_ends[-1] = reg_len
        windows = list(zip(window_starts, window_ends))
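        #The region is scanned in overlapping windows (width params.window, step of 10 bp),
        #so each position is covered by roughly params.window/step windows; the per-window
        #bias fits below are collected in bias_predictions and averaged per position.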

        for strand in strands:

            ########### Estimate bias threshold ###########
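            #Within each window, the sequence bias is mapped onto the observed cut signal by
            #fitting a ReLU curve (signal as a function of bias); if the fit fails, a simple
            #baseline subtraction of the minimum bias at cut positions is used as fallback.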
            bias_predictions = np.zeros((overlaps, reg_len))
            row = 0

            for window in windows:

                signal_w = out_signals[reg_key]["uncorrected"][strand][
                    window[0]:window[1]]
                bias_w = out_signals[reg_key]["bias"][strand][
                    window[0]:window[1]]

                signalmax = np.max(signal_w)
                #biasmin = np.min(bias_w)
                #biasmax = np.max(bias_w)

                if signalmax > 0:
                    try:
                        popt, pcov = curve_fit(relu, bias_w, signal_w)
                        bias_predict = relu(bias_w, *popt)

                    except (OptimizeWarning, RuntimeError):
                        cut_positions = np.logical_not(np.isclose(signal_w, 0))
                        bias_min = np.min(bias_w[cut_positions])
                        bias_predict = bias_w - bias_min
                        bias_predict[bias_predict < 0] = 0

                    if np.max(bias_predict) > 0:
                        bias_predict = bias_predict / np.max(bias_predict)
                else:
                    bias_predict = np.zeros(window[1] - window[0])

                bias_predictions[row, window[0]:window[1]] = bias_predict
                row = (row + 1) % overlaps  #cycle the row index so overlapping windows land in different rows before averaging

            bias_prediction = np.mean(bias_predictions, axis=0)
            bias = bias_prediction

            ######## Calculate expected signal ######
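            #Expected cuts per position: the rolling sum of the observed signal (width w) is
            #distributed according to each position's share of the rolling bias sum, i.e.
            #expected[i] = rolling_sum(signal, w)[i] * bias[i] / rolling_sum(bias, w)[i].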
            signal_sum = fast_rolling_math(
                out_signals[reg_key]["uncorrected"][strand], w, "sum")
            signal_sum[np.isnan(signal_sum)] = 0  #rolling sum is NaN within f of the region ends

            bias_sum = fast_rolling_math(bias, w, "sum")  #ends of arr are nan
            nulls = np.logical_or(np.isclose(bias_sum, 0), np.isnan(bias_sum))
            bias_sum[nulls] = 1  # N-regions will give stretches of 0-bias
            bias_probas = bias / bias_sum
            bias_probas[nulls] = 0  #nan to 0

            out_signals[reg_key]["expected"][strand] = signal_sum * bias_probas

            ######## Correct signal ########
            out_signals[reg_key]["uncorrected"][
                strand] *= bias_obj.correction_factor
            out_signals[reg_key]["expected"][
                strand] *= bias_obj.correction_factor
            out_signals[reg_key]["corrected"][
                strand] = out_signals[reg_key]["uncorrected"][
                    strand] - out_signals[reg_key]["expected"][strand]

            ######## Rescale signal to fit uncorrected sum ########
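            #Within each window of width w, positive corrected values are scaled up so that the
            #corrected totals (positive part minus the magnitude of the negative part) match the
            #uncorrected rolling sum; values are never scaled down (scale factors below 1 are clipped to 1).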
            uncorrected_sum = fast_rolling_math(
                out_signals[reg_key]["uncorrected"][strand], w, "sum")
            uncorrected_sum[np.isnan(uncorrected_sum)] = 0
            corrected_sum = fast_rolling_math(
                np.abs(out_signals[reg_key]["corrected"][strand]), w,
                "sum")  #negative values count as positive
            corrected_sum[np.isnan(corrected_sum)] = 0

            #Positive signal left after correction
            corrected_pos = np.copy(out_signals[reg_key]["corrected"][strand])
            corrected_pos[corrected_pos < 0] = 0
            corrected_pos_sum = fast_rolling_math(corrected_pos, w, "sum")
            corrected_pos_sum[np.isnan(corrected_pos_sum)] = 0
            corrected_neg_sum = corrected_sum - corrected_pos_sum

            #The corrected sum is less than the signal sum, so scale up positive cuts
            zero_sum = corrected_pos_sum == 0
            corrected_pos_sum[zero_sum] = np.nan  #set to NaN to avoid division-by-zero warnings; handled below
            scale_factor = (uncorrected_sum - corrected_neg_sum) / corrected_pos_sum
            scale_factor[zero_sum] = 1  #windows without positive signal get a neutral scale factor of 1
            scale_factor[scale_factor < 1] = 1  #only scale up, never down
            pos_bool = out_signals[reg_key]["corrected"][strand] > 0
            out_signals[reg_key]["corrected"][strand][pos_bool] *= scale_factor[pos_bool]

        #######################################
        ########   Verify correction   ########
        #######################################

        #Verify correction across all reads
        for strand in strands:
            for idx in range(k_flank + 1, reg_len - k_flank - 1):  #positions with full k-mer context

                orig = out_signals[reg_key]["uncorrected"][strand][idx]
                correct = out_signals[reg_key]["corrected"][strand][idx]

                if orig != 0 or correct != 0:  #if both are 0, don't add to pre/post bias
                    if strand == "forward":
                        kmer = sequence_obj.sequence[idx - k_flank:idx + k_flank + 1]
                    else:
                        kmer = sequence_obj.revcomp[reg_len - idx - k_flank - 1:reg_len - idx + k_flank]

                    #Save kmer for bias correction verification
                    pre_bias[strand].add_sequence(kmer, orig)
                    post_bias[strand].add_sequence(kmer, correct)

        #######################################
        ########    Write to queue    #########
        #######################################

        #Set size back to original
        for track in out_signals[reg_key]:
            for strand in out_signals[reg_key][track]:
                out_signals[reg_key][track][strand] = out_signals[reg_key][track][strand][f_extend:-f_extend]

        #Calculate "both" if strands were not split
        if not params.split_strands:
            for track in out_signals[reg_key]:
                out_signals[reg_key][track]["both"] = (out_signals[reg_key][track]["forward"] +
                                                       out_signals[reg_key][track]["reverse"])

        #Send to queue
        strands_to_write = ["forward", "reverse"] if params.split_strands else ["both"]
        for track in out_signals[reg_key]:

            #Send to writer per strand
            for strand in strands_to_write:
                key = "{0}:{1}".format(track, strand)
                logger.spam("Sending {0} signal from region {1} to writer queue".format(key, reg_key))
                qs[key].put((key, reg_key, out_signals[reg_key][track][strand]))

        #Signals were sent to the writer queues; free the memory held by this process
        out_signals[reg_key] = None

    bam_obj.close()
    fasta_obj.close()

    gc.collect()

    return ([pre_bias, post_bias])
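
# --- Illustrative aside (not part of the TOBIAS source above) ---
# A minimal, self-contained sketch of the core "expected signal" correction for one strand,
# using a plain numpy rolling sum in place of TOBIAS' fast_rolling_math and skipping the
# correction factor and positive-rescaling steps. All names here (rolling_sum, correct_strand)
# are illustrative assumptions, not the TOBIAS API.
import numpy as np

def rolling_sum(arr, w):
    """Centered rolling sum of width w; positions within w//2 of the ends are NaN,
    mimicking the NaN edges produced by fast_rolling_math in the code above."""
    out = np.convolve(arr, np.ones(w), mode="same")
    half = w // 2
    out[:half] = np.nan
    out[len(arr) - half:] = np.nan
    return out

def correct_strand(uncorrected, bias, w):
    """Subtract the bias-expected cutsite signal from the observed signal (one strand)."""
    signal_sum = rolling_sum(uncorrected, w)
    signal_sum[np.isnan(signal_sum)] = 0

    bias_sum = rolling_sum(bias, w)
    nulls = np.logical_or(np.isclose(bias_sum, 0), np.isnan(bias_sum))
    bias_sum[nulls] = 1                   #avoid division by zero in bias-free stretches
    bias_probas = bias / bias_sum
    bias_probas[nulls] = 0

    expected = signal_sum * bias_probas   #observed cuts redistributed by relative bias
    return uncorrected - expected

# Example usage (arrays of equal length): corrected = correct_strand(signal, bias_scores, w=100)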
Example no. 5
0
def bias_estimation(regions_list, params):
    """ Estimates bias of insertions within regions """

    #Unpack parameters
    bam_f = params.bam
    fasta_f = params.genome
    k_flank = params.k_flank
    bg_shift = params.bg_shift
    read_shift = params.read_shift
    L = 2 * k_flank + 1  #k-mer length around each cutsite

    logger = TobiasLogger("", params.verbosity,
                          params.log_q)  #sending all logger calls to log_q

    #Open objects for reading
    bam_obj = pysam.AlignmentFile(bam_f, "rb")
    fasta_obj = pysam.FastaFile(fasta_f)
    chrom_lengths = dict(zip(bam_obj.references, bam_obj.lengths))  #chromosome boundaries from the bam header

    bias_obj = AtacBias(L, params.score_mat)

    strands = ["forward", "reverse"]

    #Estimate bias at each region
    for region in regions_list:

        read_lst = ReadList().from_bam(bam_obj, region)
        for read in read_lst:
            read.get_cutsite(read_shift)

        ## Kmer cutting bias ##
        if len(read_lst) > 0:

            #Extract sequence
            extended_region = region.extend_reg(k_flank + bg_shift)  #extend to allow full k-mers and the background shift
            extended_region.check_boundary(chrom_lengths, "cut")  #trim the extension at chromosome ends
            sequence_obj = GenomicSequence(extended_region).from_fasta(fasta_obj)

            #Split reads forward/reverse
            for_lst, rev_lst = read_lst.split_strands()
            read_lst_strand = {"forward": for_lst, "reverse": rev_lst}
            logger.spam(
                "Region: {0}. Forward reads: {1}. Reverse reads: {2}".format(
                    region, len(for_lst), len(rev_lst)))

            for strand in strands:

                #Map reads to cutsite positions
                read_per_pos = {}
                for read in read_lst_strand[strand]:
                    if read.cigartuples is not None:
                        #The cigar operation at the cutsite end of the read
                        first_tuple = read.cigartuples[-1] if read.is_reverse else read.cigartuples[0]

                        #Only include reads that are matched (cigar "M", not clipped) across the full k-mer plus shift
                        if first_tuple[0] == 0 and first_tuple[1] > params.k_flank + max(np.abs(params.read_shift)):
                            read_per_pos[read.cutsite] = read_per_pos.get(read.cutsite, []) + [read]

                #Get the k-mer around each cutsite position
                for cutsite in read_per_pos:
                    if cutsite > region.start and cutsite < region.end:  #only cutsites within region borders
                        read = read_per_pos[cutsite][0]  #use the first read in the list to establish the k-mer
                        no_cut = min(len(read_per_pos[cutsite]), 10)  #cap the count to limit the influence of outliers

                        #Foreground: k-mer centered on the observed cutsite
                        read.get_kmer(sequence_obj, k_flank)
                        bias_obj.bias[strand].add_sequence(read.kmer, no_cut)

                        #Background: k-mer at a position shifted upstream of the read,
                        #which ensures the background is not taken from within the fragment
                        read.shift_cutsite(-bg_shift)
                        read.get_kmer(sequence_obj, k_flank)  #kmer is updated to the shifted position
                        bias_obj.bias[strand].add_background(read.kmer, no_cut)

                        bias_obj.no_reads += no_cut

    bam_obj.close()
    fasta_obj.close()

    return (bias_obj)  #object containing information collected on bias
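
# --- Illustrative aside (not part of the TOBIAS source above) ---
# A hedged sketch of how chunks of regions might be farmed out to bias_estimation in parallel.
# The chunking helper and the merging comment are assumptions for illustration, not the TOBIAS
# API; params.log_q is assumed to be a multiprocessing.Manager().Queue() so it can be passed
# to worker processes.
import multiprocessing as mp

def split_into_chunks(regions, n_chunks):
    """Split a list of regions into roughly equal chunks (illustrative helper)."""
    size = max(1, len(regions) // n_chunks)
    return [regions[i:i + size] for i in range(0, len(regions), size)]

def estimate_bias_parallel(regions, params, cores=4):
    chunks = split_into_chunks(regions, cores)
    with mp.Pool(cores) as pool:
        results = pool.starmap(bias_estimation, [(chunk, params) for chunk in chunks])
    #Each element of results is the AtacBias object for one chunk; merging them into a single
    #genome-wide estimate (e.g. by summing the underlying count matrices) is left to the caller,
    #since the merge API is not shown in the snippets above.
    return results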