def run_atacorrect(args):
    """
    Function for bias correction of input .bam files.
    Calls functions in ATACorrect_functions and several internal classes.
    """

    #Test if required arguments were given:
    if args.bam is None:
        sys.exit("Error: No .bam-file given")
    if args.genome is None:
        sys.exit("Error: No .fasta-file given")
    if args.peaks is None:
        sys.exit("Error: No .peaks-file given")

    #Adjust some parameters depending on input
    args.prefix = os.path.splitext(os.path.basename(args.bam))[0] if args.prefix is None else args.prefix
    args.outdir = os.path.abspath(args.outdir) if args.outdir is not None else os.path.abspath(os.getcwd())

    #Set output bigwigs based on input
    tracks = ["uncorrected", "bias", "expected", "corrected"]
    tracks = [track for track in tracks if track not in args.track_off]  #tracks in args.track_off are switched off

    if args.split_strands:
        strands = ["forward", "reverse"]
    else:
        strands = ["both"]

    output_bws = {}
    for track in tracks:
        output_bws[track] = {}
        for strand in strands:
            elements = [args.prefix, track] if strand == "both" else [args.prefix, track, strand]
            output_bws[track][strand] = {"fn": os.path.join(args.outdir, "{0}.bw".format("_".join(elements)))}

    #Set all output files
    bam_out = os.path.join(args.outdir, args.prefix + "_atacorrect.bam")
    bigwigs = [output_bws[track][strand]["fn"] for (track, strand) in itertools.product(tracks, strands)]
    figures_f = os.path.join(args.outdir, "{0}_atacorrect.pdf".format(args.prefix))

    output_files = bigwigs + [figures_f]
    output_files = list(OrderedDict.fromkeys(output_files))  #remove duplicates due to "both" option

    strands = ["forward", "reverse"]  #internally, bias is always estimated and corrected per strand

    #----------------------------------------------------------------------------------------------------#
    # Print info on run
    #----------------------------------------------------------------------------------------------------#

    logger = TobiasLogger("ATACorrect", args.verbosity)
    logger.begin()

    parser = add_atacorrect_arguments(argparse.ArgumentParser())
    logger.arguments_overview(parser, args)
    logger.output_files(output_files)

    args.cores = check_cores(args.cores, logger)

    #----------------------------------------------------------------------------------------------------#
    # Test input file availability for reading
    #----------------------------------------------------------------------------------------------------#

    logger.info("----- Processing input data -----")

    logger.debug("Testing input file availability")
    check_files([args.bam, args.genome, args.peaks], "r")

    logger.debug("Testing output directory/file writeability")
    make_directory(args.outdir)
    check_files(output_files, "w")

    #Open pdf for figures
    figure_pdf = PdfPages(figures_f, keep_empty=False)

    #----------------------------------------------------------------------------------------------------#
    # Read information in bam/fasta
    #----------------------------------------------------------------------------------------------------#

    logger.info("Reading info from .bam file")
    bamfile = pysam.AlignmentFile(args.bam, "rb")
    if not bamfile.has_index():
        logger.warning("No index found for bamfile - creating one via pysam.")
        pysam.index(args.bam)

    bam_references = bamfile.references  #chromosomes in correct order
    bam_chrom_info = dict(zip(bamfile.references, bamfile.lengths))
    logger.debug("bam_chrom_info: {0}".format(bam_chrom_info))
    bamfile.close()

    logger.info("Reading info from .fasta file")
    fastafile = pysam.FastaFile(args.genome)
    fasta_chrom_info = dict(zip(fastafile.references, fastafile.lengths))
    logger.debug("fasta_chrom_info: {0}".format(fasta_chrom_info))
    fastafile.close()
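
    # bam_chrom_info and fasta_chrom_info map chromosome name -> length as read from
    # the .bam header and the .fasta index; they are compared below so that genome
    # version mismatches are caught before any computation starts.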

    #Compare chrom lengths
    chrom_in_common = set(bam_chrom_info.keys()).intersection(fasta_chrom_info.keys())
    for chrom in chrom_in_common:
        bamlen = bam_chrom_info[chrom]
        fastalen = fasta_chrom_info[chrom]
        if bamlen != fastalen:
            logger.warning("(Fastafile)\t{0} has length {1}".format(chrom, fasta_chrom_info[chrom]))
            logger.warning("(Bamfile)\t{0} has length {1}".format(chrom, bam_chrom_info[chrom]))
            sys.exit("Error: .bam and .fasta have different chromosome lengths. Please make sure the genome file matches the one used for mapping.")

    #Subset bam_references to those for which there are sequences in fasta
    chrom_not_in_fasta = set(bam_references) - set(fasta_chrom_info.keys())
    if len(chrom_not_in_fasta) > 0:  #warn if any contig is missing from the fasta
        logger.warning("The following contigs in --bam did not have sequences in --fasta: {0}. NOTE: These contigs will be skipped in calculation and output.".format(chrom_not_in_fasta))

    bam_references = [ref for ref in bam_references if ref in fasta_chrom_info]
    chrom_in_common = [ref for ref in chrom_in_common if ref in bam_references]

    #----------------------------------------------------------------------------------------------------#
    # Read regions from bedfiles
    #----------------------------------------------------------------------------------------------------#

    logger.info("Processing input/output regions")

    #Chromosomes included in analysis
    genome_regions = RegionList().from_list([OneRegion([chrom, 0, bam_chrom_info[chrom]]) for chrom in bam_references if "M" not in chrom])  #full genome length; contigs with "M" in the name (e.g. chrM) are excluded
    chrom_in_common = [chrom for chrom in chrom_in_common if "M" not in chrom]
    logger.debug("CHROMS\t{0}".format("; ".join(["{0} ({1})".format(reg.chrom, reg.end) for reg in genome_regions])))
    genome_bp = sum([region.get_length() for region in genome_regions])

    # Process peaks
    peak_regions = RegionList().from_bed(args.peaks)
    peak_regions.merge()
    for i in range(len(peak_regions) - 1, -1, -1):  #iterate in reverse to allow deletion
        region = peak_regions[i]
        peak_regions[i] = region.check_boundary(bam_chrom_info, "cut")  #regions are cut/removed from list
        if peak_regions[i] is None:
            logger.warning("Peak region {0} was removed as it is either out of bounds or not in the chromosomes given in genome/bam.".format(region.tup()))
            del peak_regions[i]

    nonpeak_regions = deepcopy(genome_regions).subtract(peak_regions)

    # Process specific input regions if given
    if args.regions_in is not None:
        input_regions = RegionList().from_bed(args.regions_in)
        input_regions.merge()
        input_regions.apply_method(OneRegion.check_boundary, bam_chrom_info, "cut")
    else:
        input_regions = nonpeak_regions

    # Process specific output regions
    if args.regions_out is not None:
        output_regions = RegionList().from_bed(args.regions_out)
    else:
        output_regions = deepcopy(peak_regions)

    #Extend regions to make sure extend + flanking for window/flank are within boundaries
    flank_extend = args.k_flank + int(args.window / 2.0)
    output_regions.apply_method(OneRegion.extend_reg, args.extend + flank_extend)
    output_regions.merge()
    output_regions.apply_method(OneRegion.check_boundary, bam_chrom_info, "cut")
    output_regions.apply_method(OneRegion.extend_reg, -flank_extend)  #cut back to needed size, knowing that the region will be extended again during correction

    #Remove blacklisted regions and chromosomes not in common
    blacklist_regions = RegionList().from_bed(args.blacklist) if args.blacklist is not None else RegionList([])  #fill in with regions from args.blacklist

    regions_dict = {"genome": genome_regions,
                    "input_regions": input_regions,
                    "output_regions": output_regions,
                    "peak_regions": peak_regions,
                    "nonpeak_regions": nonpeak_regions,
                    "blacklist_regions": blacklist_regions}

    for sub in ["input_regions", "output_regions", "peak_regions", "nonpeak_regions"]:
        regions_sub = regions_dict[sub]
        regions_sub.subtract(blacklist_regions)
        regions_sub = regions_sub.apply_method(OneRegion.split_region, 50000)
        regions_sub.keep_chroms(chrom_in_common)
        regions_dict[sub] = regions_sub

    #Write beds to look at in igv
    #input_regions.write_bed(os.path.join(args.outdir, "input_regions.bed"))
    #output_regions.write_bed(os.path.join(args.outdir, "output_regions.bed"))
    #peak_regions.write_bed(os.path.join(args.outdir, "peak_regions.bed"))
    #nonpeak_regions.write_bed(os.path.join(args.outdir, "nonpeak_regions.bed"))

    #Sort according to order in bam_references:
    output_regions.loc_sort(bam_references)
    chrom_order = {bam_references[i]: i for i in range(len(bam_references))}  #for use later when sorting output

    #### Statistics about regions ####
    genome_bp = sum([region.get_length() for region in regions_dict["genome"]])
    for key in regions_dict:
        total_bp = sum([region.get_length() for region in regions_dict[key]])
        logger.stats("{0}: {1} regions | {2} bp | {3:.2f}% coverage".format(key, len(regions_dict[key]), total_bp, total_bp / genome_bp * 100))

    #Establish variables for regions to be used
    input_regions = regions_dict["input_regions"]
    output_regions = regions_dict["output_regions"]
    peak_regions = regions_dict["peak_regions"]
    nonpeak_regions = regions_dict["nonpeak_regions"]

    #Exit if no input/output regions were found
    if len(input_regions) == 0 or len(output_regions) == 0 or len(peak_regions) == 0 or len(nonpeak_regions) == 0:
        logger.error("No regions found - exiting!")
        sys.exit()

    #----------------------------------------------------------------------------------------------------#
    # Estimate normalization factors
    #----------------------------------------------------------------------------------------------------#

    #Setup logger queue
    logger.debug("Setting up listener for log")
    logger.start_logger_queue()
    args.log_q = logger.queue

    #----------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("----- Estimating normalization factors -----")
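
    # Note on the normalization below: lib_norm scales counts to a fixed library size
    # of 10 million reads, and dividing by FRiP (fraction of reads in peaks) gives
    # correct_factor = (1e7 / reads_total) * (reads_total / reads_peaks) = 1e7 / reads_peaks.
    # As far as the code shows, the intent is to make corrected signal comparable
    # between samples of different sequencing depth and signal-to-noise.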
factors -----") #If normalization is to be calculated if not args.norm_off: #Reads in peaks/nonpeaks logger.info("Counting reads in peak regions") peak_region_chunks = peak_regions.chunks(args.split) reads_peaks = sum( run_parallel(count_reads, peak_region_chunks, [args], args.cores, logger)) logger.comment("") logger.info("Counting reads in nonpeak regions") nonpeak_region_chunks = nonpeak_regions.chunks(args.split) reads_nonpeaks = sum( run_parallel(count_reads, nonpeak_region_chunks, [args], args.cores, logger)) reads_total = reads_peaks + reads_nonpeaks logger.stats("TOTAL_READS\t{0}".format(reads_total)) logger.stats("PEAK_READS\t{0}".format(reads_peaks)) logger.stats("NONPEAK_READS\t{0}".format(reads_nonpeaks)) lib_norm = 10000000 / reads_total frip = reads_peaks / reads_total correct_factor = lib_norm * (1 / frip) logger.stats("LIB_NORM\t{0:.5f}".format(lib_norm)) logger.stats("FRiP\t{0:.5f}".format(frip)) else: logger.info("Normalization was switched off") correct_factor = 1.0 logger.stats("CORRECTION_FACTOR:\t{0:.5f}".format(correct_factor)) #----------------------------------------------------------------------------------------------------# # Estimate sequence bias #----------------------------------------------------------------------------------------------------# logger.comment("") logger.info("Started estimation of sequence bias...") input_region_chunks = input_regions.chunks( args.split) #split to 100 chunks (also decides the step of output) out_lst = run_parallel(bias_estimation, input_region_chunks, [args], args.cores, logger) #Output is list of AtacBias objects #Join objects estimated_bias = out_lst[0] #initialize object with first output for output in out_lst[1:]: estimated_bias.join( output ) #bias object contains bias/background SequenceMatrix objects logger.debug("Bias estimated\tno_reads: {0}".format( estimated_bias.no_reads)) #----------------------------------------------------------------------------------------------------# # Join estimations from all chunks of regions #----------------------------------------------------------------------------------------------------# bias_obj = estimated_bias bias_obj.correction_factor = correct_factor ### Bias motif ### logger.info("Finalizing bias motif for scoring") for strand in strands: bias_obj.bias[strand].prepare_mat() logger.debug("Saving pssm to figure pdf") fig = plot_pssm(bias_obj.bias[strand].pssm, "Tn5 insertion bias of reads ({0})".format(strand)) figure_pdf.savefig(fig) #Write bias motif to pickle out_f = os.path.join(args.outdir, args.prefix + "_AtacBias.pickle") logger.debug("Saving bias object to pickle ({0})".format(out_f)) bias_obj.to_pickle(out_f) #----------------------------------------------------------------------------------------------------# # Correct read bias and write to bigwig #----------------------------------------------------------------------------------------------------# logger.comment("") logger.info("----- Correcting reads from .bam within output regions -----") output_regions.loc_sort(bam_references) #sort in order of references output_regions_chunks = output_regions.chunks(args.split) no_tasks = float(len(output_regions_chunks)) chunk_sizes = [len(chunk) for chunk in output_regions_chunks] logger.debug("All regions chunked: {0} ({1})".format( len(output_regions), chunk_sizes)) ### Create key-file linking for bigwigs key2file = {} for track in output_bws: for strand in output_bws[track]: filename = output_bws[track][strand]["fn"] key = "{}:{}".format(track, strand) key2file[key] = 

    #Split cores between correction workers and bigwig writers
    n_bigwig = len(key2file)
    writer_cores = min(n_bigwig, max(1, int(args.cores * 0.1)))  #at most one core per bigwig and at most 10% of cores (but at least 1)
    worker_cores = max(1, args.cores - writer_cores)
    logger.debug("Worker cores: {0}".format(worker_cores))
    logger.debug("Writer cores: {0}".format(writer_cores))

    worker_pool = mp.Pool(processes=worker_cores)
    writer_pool = mp.Pool(processes=writer_cores)
    manager = mp.Manager()

    #Start bigwig file writers
    writer_tasks = []
    header = [(chrom, bam_chrom_info[chrom]) for chrom in bam_references]
    key_chunks = [list(key2file.keys())[i::writer_cores] for i in range(writer_cores)]
    qs_list = []
    qs = {}
    for chunk in key_chunks:
        logger.debug("Creating writer queue for {0}".format(chunk))

        q = manager.Queue()
        qs_list.append(q)

        files = [key2file[key] for key in chunk]
        writer_tasks.append(writer_pool.apply_async(bigwig_writer, args=(q, dict(zip(chunk, files)), header, output_regions, args)))
        for key in chunk:
            qs[key] = q
    args.qs = qs
    writer_pool.close()  #no more jobs will be applied to writer_pool

    #Start correction
    logger.debug("Starting correction")
    task_list = [worker_pool.apply_async(bias_correction, args=[chunk, args, bias_obj]) for chunk in output_regions_chunks]
    worker_pool.close()
    monitor_progress(task_list, logger, "Correction progress:")  #does not return until all tasks in task_list are finished
    results = [task.get() for task in task_list]  #get all results

    pre_bias = results[0][0]  #initialize with first result
    post_bias = results[0][1]  #initialize with first result
    for result in results[1:]:
        pre_bias_chunk = result[0]
        post_bias_chunk = result[1]
        for direction in strands:
            pre_bias[direction].add_counts(pre_bias_chunk[direction])
            post_bias[direction].add_counts(post_bias_chunk[direction])

    #Stop all queues for writing
    logger.debug("Stop all queues by inserting None")
    for q in qs_list:
        q.put((None, None, None))

    #Fetch error codes from bigwig writers
    logger.debug("Fetching possible errors from bigwig_writer tasks")
    results = [task.get() for task in writer_tasks]  #blocks until writers are finished

    logger.debug("Joining bigwig_writer queues")
    qsum = sum([q.qsize() for q in qs_list])
    while qsum != 0:  #wait until all queues have been emptied
        qsum = sum([q.qsize() for q in qs_list])
        logger.spam("- Queue sizes {0}".format([(key, qs[key].qsize()) for key in qs]))
        time.sleep(0.5)

    #Wait until all writer/worker processes are done
    writer_pool.join()
    worker_pool.terminate()
    worker_pool.join()

    #Stop multiprocessing logger
    logger.stop_logger_queue()

    #----------------------------------------------------------------------------------------------------#
    # Information and verification of corrected read frequencies
    #----------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("Verifying bias correction")

    #Calculate variance per base
    for strand in strands:

        #Invert negative counts
        abssum = np.abs(np.sum(post_bias[strand].neg_counts, axis=0))
        post_bias[strand].neg_counts = post_bias[strand].neg_counts + abssum

        #Join negative/positive counts
        post_bias[strand].counts += post_bias[strand].neg_counts  #now positive

        pre_bias[strand].prepare_mat()
        post_bias[strand].prepare_mat()

        pre_var = np.mean(np.var(pre_bias[strand].bias_pwm, axis=1)[:4])  #mean of variance per nucleotide
        post_var = np.mean(np.var(post_bias[strand].bias_pwm, axis=1)[:4])
        logger.stats("BIAS\tpre-bias variance {0}:\t{1:.7f}".format(strand, pre_var))
        logger.stats("BIAS\tpost-bias variance {0}:\t{1:.7f}".format(strand, post_var))
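
        # A lower post-correction variance means the per-nucleotide frequencies are
        # nearly flat across positions, i.e. the Tn5 sequence preference has been
        # removed from the corrected signal.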
{0}:\t{1:.7f}".format( strand, post_var)) #Plot figure fig_title = "Nucleotide frequencies in corrected reads\n({0} strand)".format( strand) figure_pdf.savefig( plot_correction(pre_bias[strand].bias_pwm, post_bias[strand].bias_pwm, fig_title)) #----------------------------------------------------------------------------------------------------# # Finish up #----------------------------------------------------------------------------------------------------# plt.close('all') figure_pdf.close() logger.end()

#--------------------------------------------------------------------------------------------------------#

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser = add_atacorrect_arguments(parser)
    args = parser.parse_args()

    if len(sys.argv[1:]) == 0:
        parser.print_help()
        sys.exit()

    run_atacorrect(args)
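
# Example invocation (a sketch; flag names follow the args.* attributes used in
# run_atacorrect, but the authoritative set is defined by add_atacorrect_arguments):
#
#   python atacorrect.py --bam sample.bam --genome hg38.fa --peaks peaks.bed \
#       --outdir atacorrect_output --cores 8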