def run_bindetect(args): """ Main function to run bindetect algorithm with input files and parameters given in args """ #Checking input and setting cond_names check_required(args, ["signals", "motifs", "genome", "peaks"]) args.cond_names = [ os.path.basename(os.path.splitext(bw)[0]) for bw in args.signals ] if args.cond_names is None else args.cond_names args.outdir = os.path.abspath(args.outdir) #Set output files states = ["bound", "unbound"] outfiles = [ os.path.abspath( os.path.join(args.outdir, "*", "beds", "*_{0}_{1}.bed".format(condition, state))) for (condition, state) in itertools.product(args.cond_names, states) ] outfiles.append( os.path.abspath(os.path.join(args.outdir, "*", "beds", "*_all.bed"))) outfiles.append( os.path.abspath( os.path.join(args.outdir, "*", "plots", "*_log2fcs.pdf"))) outfiles.append( os.path.abspath(os.path.join(args.outdir, "*", "*_overview.txt"))) outfiles.append( os.path.abspath(os.path.join(args.outdir, "*", "*_overview.xlsx"))) outfiles.append( os.path.abspath( os.path.join(args.outdir, args.prefix + "_distances.txt"))) outfiles.append( os.path.abspath(os.path.join(args.outdir, args.prefix + "_results.txt"))) outfiles.append( os.path.abspath( os.path.join(args.outdir, args.prefix + "_results.xlsx"))) outfiles.append( os.path.abspath(os.path.join(args.outdir, args.prefix + "_figures.pdf"))) #-------------------------------------------------------------------------------------------------------------# #-------------------------------------------- Setup logger and pool ------------------------------------------# #-------------------------------------------------------------------------------------------------------------# logger = TobiasLogger("BINDetect", args.verbosity) logger.begin() parser = add_bindetect_arguments(argparse.ArgumentParser()) logger.arguments_overview(parser, args) logger.output_files(outfiles) # Setup pool args.cores = check_cores(args.cores, logger) writer_cores = max(1, int(args.cores * 0.1)) worker_cores = max(1, args.cores - writer_cores) logger.debug("Worker cores: {0}".format(worker_cores)) logger.debug("Writer cores: {0}".format(writer_cores)) pool = mp.Pool(processes=worker_cores) writer_pool = mp.Pool(processes=writer_cores) #-------------------------------------------------------------------------------------------------------------# #-------------------------- Pre-processing data: Reading motifs, sequences, peaks ----------------------------# #-------------------------------------------------------------------------------------------------------------# logger.info("----- Processing input data -----") #Check opening/writing of files logger.info("Checking reading/writing of files") check_files([args.signals, args.motifs, args.genome, args.peaks], action="r") check_files(outfiles[-3:], action="w") make_directory(args.outdir) #Comparisons between conditions no_conditions = len(args.signals) if args.time_series: comparisons = list(zip(args.cond_names[:-1], args.cond_names[1:])) args.comparisons = comparisons else: comparisons = list(itertools.combinations(args.cond_names, 2)) #all-against-all args.comparisons = comparisons #Open figure pdf and write overview fig_out = os.path.abspath( os.path.join(args.outdir, args.prefix + "_figures.pdf")) figure_pdf = PdfPages(fig_out, keep_empty=True) plt.figure() plt.axis('off') plt.text(0.5, 0.8, "BINDETECT FIGURES", ha="center", va="center", fontsize=20) #output and order titles = [] titles.append("Raw score distributions") titles.append("Normalized score distributions") if args.debug: for 
(cond1, cond2) in comparisons: titles.append("Background log2FCs ({0} / {1})".format( cond1, cond2)) for (cond1, cond2) in comparisons: titles.append("BINDetect plot ({0} / {1})".format(cond1, cond2)) plt.text(0.1, 0.6, "\n".join([ "Page {0}) {1}".format(i + 2, titles[i]) for i in range(len(titles)) ]) + "\n\n", va="top") figure_pdf.savefig(bbox_inches='tight') plt.close() ################# Read peaks ################ #Read peak and peak_header logger.info("Reading peaks") peaks = RegionList().from_bed(args.peaks) logger.info("- Found {0} regions in input peaks".format(len(peaks))) peaks = peaks.merge() #merge overlapping peaks logger.info("- Merged to {0} regions".format(len(peaks))) if len(peaks) == 0: logger.error("Input --peaks file is empty!") sys.exit() #Read header and check match with number of peak columns peak_columns = len(peaks[0]) #number of columns if args.peak_header != None: content = open(args.peak_header, "r").read() args.peak_header_list = content.split() logger.debug("Peak header: {0}".format(args.peak_header_list)) #Check whether peak header fits with number of peak columns if len(args.peak_header_list) != peak_columns: logger.error( "Length of --peak_header ({0}) does not fit number of columns in --peaks ({1})." .format(len(args.peak_header_list), peak_columns)) sys.exit() else: args.peak_header_list = ["peak_chr", "peak_start", "peak_end"] + [ "additional_" + str(num + 1) for num in range(peak_columns - 3) ] logger.debug("Peak header list: {0}".format(args.peak_header_list)) ################# Check for match between peaks and fasta/bigwig ################# logger.info( "Checking for match between --peaks and --fasta/--signals boundaries") logger.info("- Comparing peaks to {0}".format(args.genome)) fasta_obj = pysam.FastaFile(args.genome) fasta_boundaries = dict(zip(fasta_obj.references, fasta_obj.lengths)) fasta_obj.close() logger.debug("Fasta boundaries: {0}".format(fasta_boundaries)) peaks = peaks.apply_method(OneRegion.check_boundary, fasta_boundaries, "exit") #will exit if peaks are outside borders #Check boundaries of each bigwig signal individually for signal in args.signals: logger.info("- Comparing peaks to {0}".format(signal)) pybw_obj = pybw.open(signal) pybw_header = pybw_obj.chroms() pybw_obj.close() logger.debug("Signal boundaries: {0}".format(pybw_header)) peaks = peaks.apply_method(OneRegion.check_boundary, pybw_header, "exit") ##### GC content for motif scanning ###### #Make chunks of regions for multiprocessing logger.info("Estimating GC content from peak sequences") peak_chunks = peaks.chunks(args.split) gc_content_pool = pool.starmap( get_gc_content, itertools.product(peak_chunks, [args.genome])) gc_content = np.mean(gc_content_pool) #fraction args.gc = gc_content bg = np.array([(1 - args.gc) / 2.0, args.gc / 2.0, args.gc / 2.0, (1 - args.gc) / 2.0]) logger.info("- GC content estimated at {0:.2f}%".format(gc_content * 100)) ################ Get motifs ################ logger.info("Reading motifs from file") motif_list = MotifList() args.motifs = expand_dirs(args.motifs) for f in args.motifs: motif_list += MotifList().from_file(f) #List of OneMotif objects no_pfms = len(motif_list) logger.info("- Read {0} motifs".format(no_pfms)) logger.debug("Getting motifs ready") motif_list.bg = bg logger.debug("Getting reverse motifs") motif_list.extend([motif.get_reverse() for motif in motif_list]) logger.spam(motif_list) #Set prefixes for motif in motif_list: #now with reverse motifs as well motif.set_prefix(args.naming) motif.bg = bg logger.spam("Getting pssm 
for motif {0}".format(motif.name)) motif.get_pssm() motif_names = list(set([motif.prefix for motif in motif_list])) #Get threshold for motifs logger.debug("Getting match threshold per motif") outlist = pool.starmap(OneMotif.get_threshold, itertools.product(motif_list, [args.motif_pvalue])) motif_list = MotifList(outlist) for motif in motif_list: logger.debug("Motif {0}: threshold {1}".format(motif.name, motif.threshold)) logger.info("Creating folder structure for each TF") for TF in motif_names: logger.spam("Creating directories for {0}".format(TF)) make_directory(os.path.join(args.outdir, TF)) make_directory(os.path.join(args.outdir, TF, "beds")) make_directory(os.path.join(args.outdir, TF, "plots")) #-------------------------------------------------------------------------------------------------------------# #----------------------------------------- Plot logos for all motifs -----------------------------------------# #-------------------------------------------------------------------------------------------------------------# plus_motifs = [motif for motif in motif_list if motif.strand == "+"] logo_filenames = { motif.prefix: os.path.join(args.outdir, motif.prefix, motif.prefix + ".png") for motif in plus_motifs } logger.info("Plotting sequence logos for each motif") task_list = [ pool.apply_async(OneMotif.logo_to_file, ( motif, logo_filenames[motif.prefix], )) for motif in plus_motifs ] monitor_progress(task_list, logger) results = [task.get() for task in task_list] logger.comment("") logger.debug("Getting base64 strings per motif") for motif in motif_list: if motif.strand == "+": #motif.get_base() with open(logo_filenames[motif.prefix], "rb") as png: motif.base = base64.b64encode(png.read()).decode("utf-8") #-------------------------------------------------------------------------------------------------------------# #--------------------- Motif scanning: Find binding sites and match to footprint scores ----------------------# #-------------------------------------------------------------------------------------------------------------# logger.comment("") logger.start_logger_queue( ) #start process for listening and handling through the main logger queue args.log_q = logger.queue #queue for multiprocessing logging manager = mp.Manager() logger.info("Scanning for motifs and matching to signals...") #Create writer queues for bed-file output logger.debug("Setting up writer queues") qs_list = [] writer_qs = {} #writer_queue = create_writer_queue(key2file, writer_cores) #writer_queue.stop() #wait until all are done manager = mp.Manager() TF_names_chunks = [ motif_names[i::writer_cores] for i in range(writer_cores) ] for TF_names_sub in TF_names_chunks: logger.debug("Creating writer queue for {0}".format(TF_names_sub)) files = [ os.path.join(args.outdir, TF, "beds", TF + ".tmp") for TF in TF_names_sub ] q = manager.Queue() qs_list.append(q) writer_pool.apply_async( file_writer, args=(q, dict(zip(TF_names_sub, files)), args) ) #, callback = lambda x: finished.append(x) print("Writing time: {0}".format(x))) for TF in TF_names_sub: writer_qs[TF] = q writer_pool.close() #no more jobs applied to writer_pool #todo: use run_parallel #Start working on data if worker_cores == 1: logger.debug("Running with cores = 1") results = [] for chunk in peak_chunks: results.append( scan_and_score(chunk, motif_list, args, args.log_q, writer_qs)) else: logger.debug("Sending jobs to worker pool") task_list = [ pool.apply_async(scan_and_score, ( chunk, motif_list, args, args.log_q, writer_qs, )) for chunk in 
peak_chunks ] monitor_progress(task_list, logger) results = [task.get() for task in task_list] logger.info("Done scanning for TFBS across regions!") #logger.stop_logger_queue() #stop the listening process (wait until all was written) #--------------------------------------# logger.info("Waiting for bedfiles to write") #Stop all queues for writing logger.debug("Stop all queues by inserting None") for q in qs_list: q.put((None, None)) logger.debug("Joining bed_writer queues") for i, q in enumerate(qs_list): logger.debug("- Queue {0} (size {1})".format(i, q.qsize())) #Waits until all queues are closed writer_pool.join() #-------------------------------------------------------------------------------------------------------------# #---------------------------- Process information on background scores and overlaps --------------------------# #-------------------------------------------------------------------------------------------------------------# logger.info("Merging results from subsets") background = merge_dicts([result[0] for result in results]) TF_overlaps = merge_dicts([result[1] for result in results]) results = None #Add missing TF overlaps (if some TFs had zero sites) for TF1 in plus_motifs: if TF1.prefix not in TF_overlaps: TF_overlaps[TF1.prefix] = 0 for TF2 in plus_motifs: tup = (TF1.prefix, TF2.prefix) if tup not in TF_overlaps: TF_overlaps[tup] = 0 #Collect sampled background values for bigwig in args.cond_names: background["signal"][bigwig] = np.array(background["signal"][bigwig]) #Check how many values were fetched from background n_bg_values = len(background["signal"][args.cond_names[0]]) logger.debug("Collected {0} values from background".format(n_bg_values)) if n_bg_values < 1000: err_str = "Number of background values collected from peaks is low (={0}) ".format( n_bg_values) err_str += "- this affects estimation of the bound/unbound threshold and the normalization between conditions. " err_str += "To improve this estimation, please run BINDetect with --peaks = the full peak set across all conditions." logger.warning(err_str) logger.comment("") logger.info("Estimating score distribution per condition") fig = plot_score_distribution( [background["signal"][bigwig] for bigwig in args.cond_names], labels=args.cond_names, title="Raw scores per condition") figure_pdf.savefig(fig, bbox_inches='tight') plt.close() logger.info("Normalizing scores") list_of_vals = [background["signal"][bigwig] for bigwig in args.cond_names] normed, norm_objects = quantile_normalization(list_of_vals) args.norm_objects = dict(zip(args.cond_names, norm_objects)) for bigwig in args.cond_names: background["signal"][bigwig] = args.norm_objects[bigwig].normalize( background["signal"][bigwig]) fig = plot_score_distribution( [background["signal"][bigwig] for bigwig in args.cond_names], labels=args.cond_names, title="Normalized scores per condition") figure_pdf.savefig(fig, bbox_inches='tight') plt.close() ########################################################### logger.info("Estimating bound/unbound threshold") #Prepare scores (remove 0's etc.) bg_values = np.array(normed).flatten() bg_values = bg_values[np.logical_not(np.isclose( bg_values, 0.0))] #only non-zero counts x_max = np.percentile(bg_values, [99]) bg_values = bg_values[bg_values < x_max] if len(bg_values) == 0: logger.error( "Error processing bigwig scores from background. It could be that there are no scores in the bigwig (=0) assigned for the peaks. Please check your input files." 
)
		sys.exit()

	#Fit mixture of normals
	lowest_bic = np.inf
	for n_components in [2]:	#2 components
		gmm = sklearn.mixture.GaussianMixture(n_components=n_components, random_state=1)
		gmm.fit(np.log(bg_values).reshape(-1, 1))

		bic = gmm.bic(np.log(bg_values).reshape(-1, 1))
		logger.debug("n_components: {0} | bic: {1}".format(n_components, bic))
		if bic < lowest_bic:
			lowest_bic = bic
			best_gmm = gmm
	gmm = best_gmm

	#Extract right-most gaussian
	means = gmm.means_.flatten()
	sds = np.sqrt(gmm.covariances_).flatten()
	chosen_i = np.argmax(means)	#mixture with largest mean

	log_params = scipy.stats.lognorm.fit(bg_values[bg_values < x_max], f0=sds[chosen_i], fscale=np.exp(means[chosen_i]))
	#all_log_params[bigwig] = log_params

	#Mode of distribution
	mode = scipy.optimize.fmin(lambda x: -scipy.stats.lognorm.pdf(x, *log_params), 0, disp=False)[0]
	logger.debug("- Mode estimated at: {0}".format(mode))
	pseudo = mode / 2.0		#pseudo is half the mode
	args.pseudo = pseudo
	logger.debug("Pseudocount estimated at: {0}".format(round(args.pseudo, 5)))

	#Estimate theoretical normal for threshold
	leftside_x = np.linspace(scipy.stats.lognorm(*log_params).ppf([0.01]), mode, 100)
	leftside_pdf = scipy.stats.lognorm.pdf(leftside_x, *log_params)

	#Flip over
	mirrored_x = np.concatenate([leftside_x, np.max(leftside_x) + leftside_x]).flatten()
	mirrored_pdf = np.concatenate([leftside_pdf, leftside_pdf[::-1]]).flatten()
	popt, cov = scipy.optimize.curve_fit(lambda x, std, sc: sc * scipy.stats.norm.pdf(x, mode, std), mirrored_x, mirrored_pdf)
	norm_params = (mode, popt[0])
	logger.debug("Theoretical normal parameters: {0}".format(norm_params))

	#Set threshold for bound/unbound
	threshold = round(scipy.stats.norm.ppf(1 - args.bound_pvalue, *norm_params), 5)
	args.thresholds = {bigwig: threshold for bigwig in args.cond_names}
	logger.stats("- Threshold estimated at: {0}".format(threshold))

	#Only plot if args.debug is True
	if args.debug:

		#Plot fit
		fig, ax = plt.subplots(1, 1)
		ax.hist(bg_values[bg_values < x_max], bins='auto', density=True, label="Observed score distribution")

		xvals = np.linspace(0, x_max, 1000)
		log_probas = scipy.stats.lognorm.pdf(xvals, *log_params)
		ax.plot(xvals, log_probas, label="Log-normal fit", color="orange")

		#Theoretical normal
		norm_probas = scipy.stats.norm.pdf(xvals, *norm_params)
		ax.plot(xvals, norm_probas * (np.max(log_probas) / np.max(norm_probas)), color="grey", linestyle="--", label="Theoretical normal")

		ax.axvline(threshold, color="black", label="Bound/unbound threshold")
		ymax = plt.ylim()[1]
		ax.text(threshold, ymax, "\n {0:.3f}".format(threshold), va="top")

		#Decorate plot
		plt.title("Score distribution")
		plt.xlabel("Bigwig score")
		plt.ylabel("Density")
		plt.legend(fontsize=8)
		plt.xlim((0, x_max))

		figure_pdf.savefig(fig)
		plt.close(fig)

	############ Foldchanges between conditions ################
	logger.comment("")
	log2fc_params = {}
	if len(args.signals) > 1:
		logger.info("Calculating background log2 fold-changes between conditions")

		for (bigwig1, bigwig2) in comparisons:	#cond1, cond2
			logger.info("- {0} / {1}".format(bigwig1, bigwig2))

			#Estimate background log2fc
			scores1 = np.copy(background["signal"][bigwig1])
			scores2 = np.copy(background["signal"][bigwig2])

			included = np.logical_or(scores1 > 0, scores2 > 0)
			scores1 = scores1[included]
			scores2 = scores2[included]

			#Calculate background log2fc normal distribution
			log2fcs = np.log2(np.true_divide(scores1 + args.pseudo, scores2 + args.pseudo))

			lower, upper = np.percentile(log2fcs, [1, 99])
			log2fcs_fit = log2fcs[np.logical_and(log2fcs >= lower, log2fcs <= upper)]
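			#Note: the 1-99 percentile trim above keeps a handful of extreme log2FCs from dominating
			#the normal fit below; the fitted (mean, std) is stored in log2fc_params per comparison and
			#passed to process_tfbs later, where per-TF changes/p-values are scored against this background.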
norm_params = scipy.stats.norm.fit(log2fcs_fit) logger.debug( "({0} / {1}) Background log2fc normal distribution: {2}". format(bigwig1, bigwig2, norm_params)) log2fc_params[(bigwig1, bigwig2)] = norm_params #Plot background log2fc to figures fig, ax = plt.subplots(1, 1) plt.hist(log2fcs, density=True, bins='auto', label="Background log2fc ({0} / {1})".format( bigwig1, bigwig2)) xvals = np.linspace(plt.xlim()[0], plt.xlim()[1], 100) pdf = scipy.stats.norm.pdf(xvals, *log2fc_params[(bigwig1, bigwig2)]) plt.plot(xvals, pdf, label="Normal distribution fit") plt.title("Background log2FCs ({0} / {1})".format( bigwig1, bigwig2)) plt.xlabel("Log2 fold change") plt.ylabel("Density") if args.debug: figure_pdf.savefig(fig, bbox_inches='tight') f = open( os.path.join( args.outdir, "{0}_{1}_log2fcs.txt".format(bigwig1, bigwig2)), "w") f.write("\n".join([str(val) for val in log2fcs])) f.close() plt.close() background = None #free up space #-------------------------------------------------------------------------------------------------------------# #----------------------------- Read total sites per TF to estimate bound/unbound -----------------------------# #-------------------------------------------------------------------------------------------------------------# logger.comment("") logger.info("Processing scanned TFBS individually") #Getting bindetect table ready info_columns = ["total_tfbs"] info_columns.extend([ "{0}_{1}".format(cond, metric) for (cond, metric ) in itertools.product(args.cond_names, ["threshold", "bound"]) ]) info_columns.extend([ "{0}_{1}_{2}".format(comparison[0], comparison[1], metric) for (comparison, metric) in itertools.product(comparisons, ["change", "pvalue"]) ]) cols = len(info_columns) rows = len(motif_names) info_table = pd.DataFrame(np.zeros((rows, cols)), columns=info_columns, index=motif_names) #Starting calculations results = [] if args.cores == 1: for name in motif_names: logger.info("- {0}".format(name)) results.append(process_tfbs(name, args, log2fc_params)) else: logger.debug("Sending jobs to worker pool") task_list = [ pool.apply_async(process_tfbs, ( name, args, log2fc_params, )) for name in motif_names ] monitor_progress(task_list, logger) #will not exit before all jobs are done results = [task.get() for task in task_list] logger.info("Concatenating results from subsets") info_table = pd.concat(results) #pandas tables pool.terminate() pool.join() logger.stop_logger_queue() #-------------------------------------------------------------------------------------------------------------# #------------------------------------------------ Cluster TFBS -----------------------------------------------# #-------------------------------------------------------------------------------------------------------------# clustering = RegionCluster(TF_overlaps) clustering.cluster() #Convert full ids to alt ids convert = {motif.prefix: motif.name for motif in motif_list} for cluster in clustering.clusters: for name in convert: clustering.clusters[cluster]["cluster_name"] = clustering.clusters[ cluster]["cluster_name"].replace(name, convert[name]) #Write out distance matrix matrix_out = os.path.join(args.outdir, args.prefix + "_distances.txt") clustering.write_distance_mat(matrix_out) #-------------------------------------------------------------------------------------------------------------# #----------------------------------------- Write all_bindetect file ------------------------------------------# 
	#-------------------------------------------------------------------------------------------------------------#

	logger.comment("")
	logger.info("Writing all_bindetect files")

	#Add columns of name / motif_id / prefix
	names = []
	ids = []
	for prefix in info_table.index:
		motif = [motif for motif in motif_list if motif.prefix == prefix]
		names.append(motif[0].name)
		ids.append(motif[0].id)

	info_table.insert(0, "output_prefix", info_table.index)
	info_table.insert(1, "name", names)
	info_table.insert(2, "motif_id", ids)
	#info_table.insert(3, "motif_logo", [os.path.join("motif_logos", os.path.basename(logo_filenames[prefix])) for prefix in info_table["output_prefix"]])	#add relative path to logo

	#Add cluster to info_table
	cluster_names = []
	for name in info_table.index:
		for cluster in clustering.clusters:
			if name in clustering.clusters[cluster]["member_names"]:
				cluster_names.append(clustering.clusters[cluster]["cluster_name"])
	info_table.insert(3, "cluster", cluster_names)

	#Cluster table on motif clusters
	info_table_clustered = info_table.groupby("cluster").mean()	#mean of each column
	info_table_clustered.reset_index(inplace=True)

	#Map correct type
	info_table["total_tfbs"] = info_table["total_tfbs"].map(int)
	for condition in args.cond_names:
		info_table[condition + "_bound"] = info_table[condition + "_bound"].map(int)

	#### Write excel ###
	bindetect_excel = os.path.join(args.outdir, args.prefix + "_results.xlsx")
	writer = pd.ExcelWriter(bindetect_excel, engine='xlsxwriter')

	#Tables
	info_table.to_excel(writer, index=False, sheet_name="Individual motifs")
	info_table_clustered.to_excel(writer, index=False, sheet_name="Motif clusters")

	for sheet in writer.sheets:
		worksheet = writer.sheets[sheet]
		n_rows = worksheet.dim_rowmax
		n_cols = worksheet.dim_colmax
		worksheet.autofilter(0, 0, n_rows, n_cols)
	writer.save()

	#Format comparisons
	for (cond1, cond2) in comparisons:
		base = cond1 + "_" + cond2
		info_table[base + "_change"] = info_table[base + "_change"].round(5)
		info_table[base + "_pvalue"] = info_table[base + "_pvalue"].map("{:.5E}".format, na_action="ignore")

	#Write bindetect results tables
	#info_table.insert(0, "TF_name", info_table.index)	#Set index as first column
	bindetect_out = os.path.join(args.outdir, args.prefix + "_results.txt")
	info_table.to_csv(bindetect_out, sep="\t", index=False, header=True, na_rep="NA")

	#-------------------------------------------------------------------------------------------------------------#
	#------------------------------------------- Make BINDetect plot ---------------------------------------------#
	#-------------------------------------------------------------------------------------------------------------#

	if no_conditions > 1:
		logger.info("Creating BINDetect plot(s)")

		#Fill NAs from info_table to enable plotting of log2fcs (NA -> 0 change)
		change_cols = [col for col in info_table.columns if "_change" in col]
		pvalue_cols = [col for col in info_table.columns if "_pvalue" in col]
		info_table[change_cols] = info_table[change_cols].fillna(0)
		info_table[pvalue_cols] = info_table[pvalue_cols].fillna(1)

		#Plotting bindetect per comparison
		for (cond1, cond2) in comparisons:
			logger.info("- {0} / {1}".format(cond1, cond2))
			base = cond1 + "_" + cond2

			#Define which motifs to show
			xvalues = info_table[base + "_change"].astype(float)
			yvalues = info_table[base + "_pvalue"].astype(float)

			y_min = np.percentile(yvalues[yvalues > 0], 5)	#5% smallest pvalues
			x_min, x_max = np.percentile(xvalues, [5, 95])	#5% smallest and largest changes

			#Make copy of motifs and fill in with metadata
comparison_motifs = [ motif for motif in motif_list if motif.strand == "+" ] #copy.deepcopy(motif_list) - swig pickle error, just overwrite motif_list for motif in comparison_motifs: name = motif.prefix motif.change = float(info_table.at[name, base + "_change"]) motif.pvalue = float(info_table.at[name, base + "_pvalue"]) motif.logpvalue = -np.log10( motif.pvalue) if motif.pvalue > 0 else -np.log10(1e-308) #Assign each motif to group if motif.change < x_min or motif.change > x_max or motif.pvalue < y_min: if motif.change < 0: motif.group = cond2 + "_up" if motif.change > 0: motif.group = cond1 + "_up" else: motif.group = "n.s." #Bindetect plot fig = plot_bindetect(comparison_motifs, clustering, [cond1, cond2], args) figure_pdf.savefig(fig, bbox_inches='tight') plt.close(fig) #Interactive BINDetect plot html_out = os.path.join(args.outdir, "bindetect_" + base + ".html") plot_interactive_bindetect(comparison_motifs, [cond1, cond2], html_out) #-------------------------------------------------------------------------------------------------------------# #----------------------------- Make heatmap across conditions (for debugging)---------------------------------# #-------------------------------------------------------------------------------------------------------------# if args.debug: mean_columns = [cond + "_mean_score" for cond in args.cond_names] heatmap_table = info_table[mean_columns] heatmap_table.index = info_table["output_prefix"] #Decide fig size rows, cols = heatmap_table.shape figsize = (7 + cols, max(10, rows / 8.0)) cm = sns.clustermap( heatmap_table, figsize=figsize, z_score=0, #zscore for rows col_cluster=False, #do not cluster condition columns yticklabels=True, #show all row annotations xticklabels=True, cbar_pos=(0, 0, .4, .005), dendrogram_ratio=(0.3, 0.01), cbar_kws={ "orientation": "horizontal", 'label': 'Row z-score' }, method="single") #Adjust width of columns #hm = cm.ax_heatmap.get_position() #cm.ax_heatmap.set_position([hm.x0, hm.y0, cols * 3 * hm.height / rows, hm.height]) #aspect should be equal plt.setp(cm.ax_heatmap.get_xticklabels(), fontsize=8, rotation=45, ha="right") plt.setp(cm.ax_heatmap.get_yticklabels(), fontsize=5) cm.ax_col_dendrogram.set_title('Mean scores across conditions', fontsize=20) cm.ax_heatmap.set_ylabel("Transcription factor motifs", fontsize=15, rotation=270) #cm.ax_heatmap.set_title('Conditions') #cm.fig.suptitle('Mean scores across conditions') #cm.cax.set_visible(False) #Save to output pdf plt.tight_layout() figure_pdf.savefig(cm.fig, bbox_inches='tight') plt.close(cm.fig) #-------------------------------------------------------------------------------------------------------------# #-------------------------------------------------- Wrap up---------------------------------------------------# #-------------------------------------------------------------------------------------------------------------# figure_pdf.close() logger.end()
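#The bound/unbound threshold estimation inside run_bindetect above is compact and easy to misread.
#The helper below is a minimal, self-contained sketch of the same idea on a plain array of background
#scores: fit a log-normal, mirror its left flank around the mode into a symmetric "noise" curve, fit a
#normal to that curve, and take the (1 - bound_pvalue) quantile as the threshold. It is illustrative
#only and not used by the pipeline; the name `estimate_bound_threshold` is not part of the TOBIAS API.
def estimate_bound_threshold(bg_values, bound_pvalue=0.001):
	"""Sketch: estimate a bound/unbound score threshold from background footprint scores."""
	import numpy as np
	import scipy.stats
	import scipy.optimize

	#Fit a log-normal to the (non-zero) background scores
	log_params = scipy.stats.lognorm.fit(bg_values)

	#Mode of the fitted distribution (numerical minimization of the negative pdf)
	mode = scipy.optimize.fmin(lambda x: -scipy.stats.lognorm.pdf(x, *log_params), 0, disp=False)[0]

	#Mirror the left flank of the pdf around the mode to build a symmetric "noise" curve
	leftside_x = np.linspace(scipy.stats.lognorm(*log_params).ppf(0.01), mode, 100)
	leftside_pdf = scipy.stats.lognorm.pdf(leftside_x, *log_params)
	mirrored_x = np.concatenate([leftside_x, np.max(leftside_x) + leftside_x])
	mirrored_pdf = np.concatenate([leftside_pdf, leftside_pdf[::-1]])

	#Fit a scaled normal centered on the mode to the mirrored curve
	popt, _ = scipy.optimize.curve_fit(lambda x, std, sc: sc * scipy.stats.norm.pdf(x, mode, std), mirrored_x, mirrored_pdf)

	#Threshold is the upper quantile of the fitted "noise" normal
	return scipy.stats.norm.ppf(1 - bound_pvalue, mode, popt[0])
	#Example usage (synthetic scores): estimate_bound_threshold(np.random.lognormal(0.5, 0.4, 10000), 0.001)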
def run_aggregate(args): """ Function to make aggregate plot given input from args """ ######################################################################################### ############################## Setup logger/input/output ################################ ######################################################################################### logger = TobiasLogger("PlotAggregate", args.verbosity) logger.begin() parser = add_aggregate_arguments(argparse.ArgumentParser()) logger.arguments_overview(parser, args) logger.output_files([args.output]) #Check input parameters check_required(args, ["TFBS", "signals"]) check_files([ args.TFBS, args.signals, args.regions, args.whitelist, args.blacklist ], action="r") check_files([args.output], action="w") #### Test input #### if args.TFBS_labels != None and (len(args.TFBS) != len(args.TFBS_labels)): logger.error( "ERROR --TFBS and --TFBS-labels have different lengths ({0} vs. {1})" .format(len(args.TFBS), len(args.TFBS_labels))) sys.exit(1) if args.region_labels != None and (len(args.regions) != len( args.region_labels)): logger.error( "ERROR: --regions and --region-labels have different lengths ({0} vs. {1})" .format(len(args.regions), len(args.region_labels))) sys.exit(1) if args.signal_labels != None and (len(args.signals) != len( args.signal_labels)): logger.error( "ERROR: --signals and --signal-labels have different lengths ({0} vs. {1})" .format(len(args.signals), len(args.signal_labels))) sys.exit(1) #### Format input #### args.TFBS_labels = [ os.path.splitext(os.path.basename(f))[0] for f in args.TFBS ] if args.TFBS_labels == None else args.TFBS_labels args.region_labels = [ os.path.splitext(os.path.basename(f))[0] for f in args.regions ] if args.region_labels == None else args.region_labels args.signal_labels = [ os.path.splitext(os.path.basename(f))[0] for f in args.signals ] if args.signal_labels == None else args.signal_labels #TFBS labels cannot be the same if len(set(args.TFBS_labels)) < len( args.TFBS_labels): #this indicates duplicates logger.error( "ERROR: --TFBS-labels are not allowed to contain duplicates. Note that '--TFBS-labels' are created automatically from the '--TFBS'-files if no input was given." 
+ "Please check that neither contain duplicate names") sys.exit(1) ######################################################################################### ############################ Get input regions/signals ready ############################ ######################################################################################### logger.info("---- Processing input ----") logger.info("Reading information from .bed-files") #Make combinations of TFBS / regions region_names = [] if len(args.regions) > 0: logger.info("Overlapping sites to --regions") regions_dict = {} combis = itertools.product(range(len(args.TFBS)), range(len(args.regions))) for (i, j) in combis: TFBS_f = args.TFBS[i] region_f = args.regions[j] #Make overlap pb_tfbs = pb.BedTool(TFBS_f) pb_region = pb.BedTool(region_f) #todo: write out lengths overlap = pb_tfbs.intersect(pb_region, u=True) #todo: length after overlap name = args.TFBS_labels[i] + " <OVERLAPPING> " + args.region_labels[ j] #name for column region_names.append(name) regions_dict[name] = RegionList().from_bed(overlap.fn) if args.negate == True: overlap_neg = pb_tfbs.intersect(pb_region, v=True) name = args.TFBS_labels[ i] + " <NOT OVERLAPPING> " + args.region_labels[j] region_names.append(name) regions_dict[name] = RegionList().from_bed(overlap_neg.fn) else: region_names = args.TFBS_labels regions_dict = { args.TFBS_labels[i]: RegionList().from_bed(args.TFBS[i]) for i in range(len(args.TFBS)) } for name in region_names: logger.stats("COUNT {0}: {1} sites".format( name, len(regions_dict[name]))) #length of RegionList obj #-------- Do overlap of regions if whitelist / blacklist -------# if len(args.whitelist) > 0 or len(args.blacklist) > 0: logger.info("Subsetting regions on whitelist/blacklist") for regions_id in regions_dict: sites = pb.BedTool(regions_dict[regions_id].as_bed(), from_string=True) logger.stats("Found {0} sites in {1}".format( len(regions_dict[regions_id]), regions_id)) if len(args.whitelist) > 0: for whitelist_f in args.whitelist: whitelist = pb.BedTool(whitelist_f) sites_tmp = sites.intersect(whitelist, u=True) sites = sites_tmp logger.stats("Overlapped to whitelist -> {0}".format( len(sites))) if len(args.blacklist) > 0: for blacklist_f in args.blacklist: blacklist = pb.BedTool(blacklist_f) sites_tmp = sites.intersect(blacklist, v=True) sites = sites_tmp logger.stats("Removed blacklist -> {0}".format( format(len(sites)))) regions_dict[regions_id] = RegionList().from_bed(sites.fn) # Estimate motif width per --TFBS motif_widths = {} for regions_id in regions_dict: site_list = regions_dict[regions_id] if len(site_list) > 0: motif_widths[regions_id] = site_list[0].get_width() else: motif_widths[regions_id] = 0 ######################################################################################### ############################ Read signal for bigwig per site ############################ ######################################################################################### logger.info("Reading signal from bigwigs") args.width = args.flank * 2 #output regions will be of args.width signal_dict = {} for i, signal_f in enumerate(args.signals): signal_name = args.signal_labels[i] signal_dict[signal_name] = {} #Open pybw to read signal pybw = pyBigWig.open(signal_f) boundaries = pybw.chroms() #dictionary of {chrom: length} logger.info("- Reading signal from {0}".format(signal_name)) for regions_id in regions_dict: original = copy.deepcopy(regions_dict[regions_id]) # Set width (centered on mid) 
regions_dict[regions_id].apply_method(OneRegion.set_width, args.width) #Check that regions are within boundaries and remove if not invalid = [ i for i, region in enumerate(regions_dict[regions_id]) if region.check_boundary(boundaries, action="remove") == None ] for invalid_idx in invalid[::-1]: #idx from higher to lower logger.warning( "Region '{reg}' ('{orig}' before flank extension) from bed regions '{id}' is out of chromosome boundaries. This region will be excluded from output." .format(reg=regions_dict[regions_id][invalid_idx].pretty(), orig=original[invalid_idx].pretty(), id=regions_id)) del regions_dict[regions_id][invalid_idx] #Get signal from remaining regions for one_region in regions_dict[regions_id]: tup = one_region.tup() #(chr, start, end, strand) if tup not in signal_dict[ signal_name]: #only get signal if it was not already read previously signal_dict[signal_name][tup] = one_region.get_signal( pybw, logger=logger, key=signal_name) #returns signal pybw.close() ######################################################################################### ################################## Calculate aggregates ################################# ######################################################################################### signal_names = args.signal_labels #Calculate aggregate per signal/region comparison logger.info("Calculating aggregate signals") aggregate_dict = { signal_name: {region_name: [] for region_name in regions_dict} for signal_name in signal_names } for row, signal_name in enumerate(signal_names): for col, region_name in enumerate(region_names): signalmat = np.array([ signal_dict[signal_name][reg.tup()] for reg in regions_dict[region_name] ]) #Check shape of signalmat if signalmat.shape[0] == 0: #no regions logger.warning( "No regions left for '{0}'. The aggregate for this signal will be set to 0." 
.format(signal_name)) aggregate = np.zeros(args.width) else: #Exclude outlier rows max_values = np.max(signalmat, axis=1) upper_limit = np.percentile( max_values, [100 * args.remove_outliers ])[0] #remove-outliers is a fraction logical = max_values <= upper_limit logger.debug( "{0}:{1}\tUpper limit: {2} (regions removed: {3})".format( signal_name, region_name, upper_limit, len(signalmat) - sum(logical))) signalmat = signalmat[logical] #Log-transform values before aggregating if args.log_transform: signalmat_abs = np.abs(signalmat) signalmat_log = np.log2(signalmat_abs + 1) signalmat_log[ signalmat < 0] *= -1 #original negatives back to <0 signalmat = signalmat_log aggregate = np.nanmean(signalmat, axis=0) #normalize between 0-1 if args.normalize: aggregate = preprocessing.minmax_scale(aggregate) if args.smooth > 1: aggregate_extend = np.pad(aggregate, args.smooth, "edge") aggregate_smooth = fast_rolling_math( aggregate_extend.astype('float64'), args.smooth, "mean") aggregate = aggregate_smooth[args.smooth:-args.smooth] aggregate_dict[signal_name][region_name] = aggregate signalmat = None #free up space signal_dict = None #free up space ######################################################################################### ############################## Write aggregates to file ################################# ######################################################################################### if args.output_txt is not None: #Open file for writing f_out = open(args.output_txt, "w") f_out.write("### AGGREGATE\n") f_out.write("# Signal\tRegions\tAggregate\n") for row, signal_name in enumerate(signal_names): for col, region_name in enumerate(region_names): agg = aggregate_dict[signal_name][region_name] agg_txt = ",".join(["{:.4f}".format(val) for val in agg]) f_out.write("{0}\t{1}\t{2}\n".format(signal_name, region_name, agg_txt)) f_out.close() ######################################################################################### ################################## Footprint measures ################################### ######################################################################################### logger.comment("") logger.info("---- Analysis ----") #Measure of footprint depth in comparison to baseline logger.info("Calculating footprint depth measure") logger.info("FPD (signal,regions): footprint_width baseline middle FPD") for row, signal_name in enumerate(signal_names): for col, region_name in enumerate(region_names): agg = aggregate_dict[signal_name][region_name] motif_width = motif_widths[region_name] #Estimation of possible footprint width FPD_results = [] for fp_flank in range(int(motif_width / 2), min([25, args.flank ])): #motif width for this bed #Baseline level baseline_indices = list(range( 0, args.flank - fp_flank)) + list( range(args.flank + fp_flank, len(agg))) baseline = np.mean(agg[baseline_indices]) #Footprint level middle_indices = list( range(args.flank - fp_flank, args.flank + fp_flank)) middle = np.mean(agg[middle_indices]) #within the motif #Footprint depth depth = middle - baseline FPD_results.append([fp_flank * 2, baseline, middle, depth]) #Estimation of possible footprint width all_fpds = [result[-1] for result in FPD_results] FPD_results_best = FPD_results #[result + [" "] if result[-1] != min(all_fpds) else result + ["*"] for result in FPD_results] for result in FPD_results_best: logger.stats( "FPD ({0},{1}): {2} {3:.5f} {4:.5f} {5:.5f}".format( signal_name, region_name, result[0], result[1], result[2], result[3])) #Compare pairwise to calculate 
correlation of signals logger.comment("") logger.info("Calculating measures for comparing pairwise aggregates") logger.info( "CORRELATION (signal1,region1) VS (signal2,region2): PEARSONR\tSUM_DIFF" ) plots = itertools.product(signal_names, region_names) combis = itertools.combinations(plots, 2) for ax1, ax2 in combis: signal1, region1 = ax1 signal2, region2 = ax2 agg1 = aggregate_dict[signal1][region1] agg2 = aggregate_dict[signal2][region2] pearsonr, pval = scipy.stats.pearsonr(agg1, agg2) diff = np.sum(np.abs(agg1 - agg2)) #Sum of difference between agg1 and agg2 logger.stats( "CORRELATION ({0},{1}) VS ({2},{3}): {4:.5f}\t{5:.5f}".format( signal1, region1, signal2, region2, pearsonr, diff)) ######################################################################################### ################################ Set up plotting grid ################################### ######################################################################################### logger.comment("") logger.info("---- Plotting aggregates ----") logger.info("Setting up plotting grid") n_signals = len(signal_names) n_regions = len(region_names) #regions are set of sites signal_compare = True if n_signals > 1 else False region_compare = True if n_regions > 1 else False #Define whether signal is on x/y if args.signal_on_x: #x-axis n_cols = n_signals col_compare = signal_compare col_names = signal_names #y-axis n_rows = n_regions row_compare = region_compare row_names = region_names else: #x-axis n_cols = n_regions col_compare = region_compare col_names = region_names #y-axis n_rows = n_signals row_compare = signal_compare row_names = signal_names #Compare across rows/cols? if row_compare: n_rows += 1 row_names += ["Comparison"] if col_compare: n_cols += 1 col_names += ["Comparison"] #Set grid fig, axarr = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5), constrained_layout=True) axarr = np.array(axarr).reshape( (-1, 1)) if n_cols == 1 else axarr #Fix indexing for one column figures axarr = np.array(axarr).reshape( (1, -1)) if n_rows == 1 else axarr #Fix indexing for one row figures #X axis / Y axis labels #mainax = fig.add_subplot(111, frameon=False) #mainax.set_xlabel("X label", labelpad=30, fontsize=16) #mainax.set_ylabel("Y label", labelpad=30, fontsize=16) #mainax.xaxis.set_label_position('top') #Title of plot and grid plt.suptitle( " " * 7 + args.title, fontsize=16 ) #Add a little whitespace to center the title on the plot; not the frame #Titles per column for col in range(n_cols): title = col_names[col].replace(" ", "\n") l = max([len(line) for line in title.split("\n") ]) #length of longest line in title s = fontsize_func(l) #decide fontsize based on length axarr[0, col].set_title(title, fontsize=s) #Titles (ylabels) per row for row in range(n_rows): label = row_names[row] l = max([len(line) for line in label.split("\n")]) axarr[row, 0].set_ylabel(label, fontsize=fontsize_func(l)) #Colors colors = mpl.cm.brg( np.linspace(0, 1, len(signal_names) + len(region_names))) #xvals flank = int(args.width / 2.0) xvals = np.arange(-flank, flank + 1) xvals = np.delete(xvals, flank) #Settings for each subplot for row in range(n_rows): for col in range(n_cols): axarr[row, col].set_xlim(-flank, flank) axarr[row, col].set_xlabel('bp from center') #axarr[row, col].set_ylabel('Mean aggregated signal') minor_ticks = np.arange(-flank, flank, args.width / 10.0) #Settings for comparison plots a = [ axarr[-1, col].set_facecolor("0.9") if row_compare == True else 0 for col in range(n_cols) ] a = [ axarr[row, 
-1].set_facecolor("0.9") if col_compare == True else 0 for row in range(n_rows) ] ######################################################################################### ####################### Fill in grid with aggregate bigwig scores ####################### ######################################################################################### for si in range(n_signals): signal_name = signal_names[si] for ri in range(n_regions): region_name = region_names[ri] logger.info("Plotting regions {0} from signal {1}".format( region_name, signal_name)) row, col = (ri, si) if args.signal_on_x else (si, ri) #If there are any regions: if len(regions_dict[region_name]) > 0: #Signal in region aggregate = aggregate_dict[signal_name][region_name] axarr[row, col].plot(xvals, aggregate, color=colors[col + row], linewidth=1, label=signal_name) #Compare across rows and cols if col_compare: #compare between different columns by adding one more column axarr[row, -1].plot(xvals, aggregate, color=colors[row + col], linewidth=1, alpha=0.8, label=col_names[col]) s = min([ ax.title.get_fontproperties()._size for ax in axarr[0, :] ]) #smallest fontsize of all columns axarr[row, -1].legend(loc="lower right", fontsize=s) if row_compare: #compare between different rows by adding one more row axarr[-1, col].plot(xvals, aggregate, color=colors[row + col], linewidth=1, alpha=0.8, label=row_names[row]) s = min([ ax.yaxis.label.get_fontproperties()._size for ax in axarr[:, 0] ]) #smallest fontsize of all rows axarr[-1, col].legend(loc="lower right", fontsize=s) #Diagonal comparison if n_rows == n_cols and col_compare and row_compare and col == row: axarr[-1, -1].plot(xvals, aggregate, color=colors[row + col], linewidth=1, alpha=0.8) #Add number of sites to plot axarr[row, col].text(0.98, 0.98, str(len(regions_dict[region_name])), transform=axarr[row, col].transAxes, fontsize=12, va="top", ha="right") #Motif boundaries (can only be compared across sites) if args.plot_boundaries: #Get motif width for this list of TFBS width = motif_widths[region_names[min( row, n_rows - 2)]] if args.signal_on_x else motif_widths[ region_names[min(col, n_cols - 2)]] mstart = -np.floor(width / 2.0) mend = np.ceil( width / 2.0) - 1 #as it spans the "0" nucleotide axarr[row, col].axvline(mstart, color="grey", linestyle="dashed", linewidth=1) axarr[row, col].axvline(mend, color="grey", linestyle="dashed", linewidth=1) #------------- Finishing up plots ---------------# logger.info("Adjusting final details") #remove lower-right corner if not applicable if n_rows != n_cols and n_rows > 1 and n_cols > 1: axarr[-1, -1].axis('off') axarr[-1, -1] = None #Check whether share_y is set if args.share_y == "none": pass #Comparable rows (rowcompare = same across all) elif (args.share_y == "signals" and args.signal_on_x == False) or (args.share_y == "sites" and args.signal_on_x == True): for col in range(n_cols): lims = np.array( [ax.get_ylim() for ax in axarr[:, col] if ax is not None]) ymin, ymax = np.min(lims), np.max(lims) #Set limit across rows for this col for row in range(n_rows): if axarr[row, col] is not None: axarr[row, col].set_ylim(ymin, ymax) #Comparable columns (colcompare = same across all) elif (args.share_y == "sites" and args.signal_on_x == False) or (args.share_y == "signals" and args.signal_on_x == True): for row in range(n_rows): lims = np.array( [ax.get_ylim() for ax in axarr[row, :] if ax is not None]) ymin, ymax = np.min(lims), np.max(lims) #Set limit across cols for this row for col in range(n_cols): if axarr[row, col] is not None: 
axarr[row, col].set_ylim(ymin, ymax) #Comparable on both rows/columns elif args.share_y == "both": global_ymin, global_ymax = np.inf, -np.inf for row in range(n_rows): for col in range(n_cols): if axarr[row, col] is not None: local_ymin, local_ymax = axarr[row, col].get_ylim() global_ymin = local_ymin if local_ymin < global_ymin else global_ymin global_ymax = local_ymax if local_ymax > global_ymax else global_ymax for row in range(n_rows): for col in range(n_cols): if axarr[row, col] is not None: axarr[row, col].set_ylim(global_ymin, global_ymax) #Force plots to be square for row in range(n_rows): for col in range(n_cols): forceSquare(axarr[row, col]) plt.savefig(args.output, bbox_inches='tight') plt.close() logger.end()
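#The footprint-depth (FPD) numbers logged in run_aggregate compare the mean aggregate signal inside a
#candidate footprint window against the baseline outside it. The helper below is a minimal sketch of
#that calculation for a single aggregate profile; the name `footprint_depth` is illustrative only and
#not part of the TOBIAS API.
def footprint_depth(agg, flank, fp_flank):
	"""Sketch: FPD = mean(aggregate inside the footprint window) - mean(aggregate outside)."""
	import numpy as np
	positions = np.arange(len(agg))
	in_footprint = (positions >= flank - fp_flank) & (positions < flank + fp_flank)
	middle = np.mean(agg[in_footprint])		#signal within the motif/footprint window
	baseline = np.mean(agg[~in_footprint])	#flanking signal outside the window
	return middle - baseline
	#Example usage: footprint_depth(aggregate_dict[signal_name][region_name], args.flank, fp_flank=10)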