def run_bindetect(args): """ Main function to run bindetect algorithm with input files and parameters given in args """ #Checking input and setting cond_names check_required(args, ["signals", "motifs", "genome", "peaks"]) args.cond_names = [ os.path.basename(os.path.splitext(bw)[0]) for bw in args.signals ] if args.cond_names is None else args.cond_names args.outdir = os.path.abspath(args.outdir) #Set output files states = ["bound", "unbound"] outfiles = [ os.path.abspath( os.path.join(args.outdir, "*", "beds", "*_{0}_{1}.bed".format(condition, state))) for (condition, state) in itertools.product(args.cond_names, states) ] outfiles.append( os.path.abspath(os.path.join(args.outdir, "*", "beds", "*_all.bed"))) outfiles.append( os.path.abspath( os.path.join(args.outdir, "*", "plots", "*_log2fcs.pdf"))) outfiles.append( os.path.abspath(os.path.join(args.outdir, "*", "*_overview.txt"))) outfiles.append( os.path.abspath(os.path.join(args.outdir, "*", "*_overview.xlsx"))) outfiles.append( os.path.abspath( os.path.join(args.outdir, args.prefix + "_distances.txt"))) outfiles.append( os.path.abspath(os.path.join(args.outdir, args.prefix + "_results.txt"))) outfiles.append( os.path.abspath( os.path.join(args.outdir, args.prefix + "_results.xlsx"))) outfiles.append( os.path.abspath(os.path.join(args.outdir, args.prefix + "_figures.pdf"))) #-------------------------------------------------------------------------------------------------------------# #-------------------------------------------- Setup logger and pool ------------------------------------------# #-------------------------------------------------------------------------------------------------------------# logger = TobiasLogger("BINDetect", args.verbosity) logger.begin() parser = add_bindetect_arguments(argparse.ArgumentParser()) logger.arguments_overview(parser, args) logger.output_files(outfiles) # Setup pool args.cores = check_cores(args.cores, logger) writer_cores = max(1, int(args.cores * 0.1)) worker_cores = max(1, args.cores - writer_cores) logger.debug("Worker cores: {0}".format(worker_cores)) logger.debug("Writer cores: {0}".format(writer_cores)) pool = mp.Pool(processes=worker_cores) writer_pool = mp.Pool(processes=writer_cores) #-------------------------------------------------------------------------------------------------------------# #-------------------------- Pre-processing data: Reading motifs, sequences, peaks ----------------------------# #-------------------------------------------------------------------------------------------------------------# logger.info("----- Processing input data -----") #Check opening/writing of files logger.info("Checking reading/writing of files") check_files([args.signals, args.motifs, args.genome, args.peaks], action="r") check_files(outfiles[-3:], action="w") make_directory(args.outdir) #Comparisons between conditions no_conditions = len(args.signals) if args.time_series: comparisons = list(zip(args.cond_names[:-1], args.cond_names[1:])) args.comparisons = comparisons else: comparisons = list(itertools.combinations(args.cond_names, 2)) #all-against-all args.comparisons = comparisons #Open figure pdf and write overview fig_out = os.path.abspath( os.path.join(args.outdir, args.prefix + "_figures.pdf")) figure_pdf = PdfPages(fig_out, keep_empty=True) plt.figure() plt.axis('off') plt.text(0.5, 0.8, "BINDETECT FIGURES", ha="center", va="center", fontsize=20) #output and order titles = [] titles.append("Raw score distributions") titles.append("Normalized score distributions") if args.debug: for 
(cond1, cond2) in comparisons: titles.append("Background log2FCs ({0} / {1})".format( cond1, cond2)) for (cond1, cond2) in comparisons: titles.append("BINDetect plot ({0} / {1})".format(cond1, cond2)) plt.text(0.1, 0.6, "\n".join([ "Page {0}) {1}".format(i + 2, titles[i]) for i in range(len(titles)) ]) + "\n\n", va="top") figure_pdf.savefig(bbox_inches='tight') plt.close() ################# Read peaks ################ #Read peak and peak_header logger.info("Reading peaks") peaks = RegionList().from_bed(args.peaks) logger.info("- Found {0} regions in input peaks".format(len(peaks))) peaks = peaks.merge() #merge overlapping peaks logger.info("- Merged to {0} regions".format(len(peaks))) if len(peaks) == 0: logger.error("Input --peaks file is empty!") sys.exit() #Read header and check match with number of peak columns peak_columns = len(peaks[0]) #number of columns if args.peak_header != None: content = open(args.peak_header, "r").read() args.peak_header_list = content.split() logger.debug("Peak header: {0}".format(args.peak_header_list)) #Check whether peak header fits with number of peak columns if len(args.peak_header_list) != peak_columns: logger.error( "Length of --peak_header ({0}) does not fit number of columns in --peaks ({1})." .format(len(args.peak_header_list), peak_columns)) sys.exit() else: args.peak_header_list = ["peak_chr", "peak_start", "peak_end"] + [ "additional_" + str(num + 1) for num in range(peak_columns - 3) ] logger.debug("Peak header list: {0}".format(args.peak_header_list)) ################# Check for match between peaks and fasta/bigwig ################# logger.info( "Checking for match between --peaks and --fasta/--signals boundaries") logger.info("- Comparing peaks to {0}".format(args.genome)) fasta_obj = pysam.FastaFile(args.genome) fasta_boundaries = dict(zip(fasta_obj.references, fasta_obj.lengths)) fasta_obj.close() logger.debug("Fasta boundaries: {0}".format(fasta_boundaries)) peaks = peaks.apply_method(OneRegion.check_boundary, fasta_boundaries, "exit") #will exit if peaks are outside borders #Check boundaries of each bigwig signal individually for signal in args.signals: logger.info("- Comparing peaks to {0}".format(signal)) pybw_obj = pybw.open(signal) pybw_header = pybw_obj.chroms() pybw_obj.close() logger.debug("Signal boundaries: {0}".format(pybw_header)) peaks = peaks.apply_method(OneRegion.check_boundary, pybw_header, "exit") ##### GC content for motif scanning ###### #Make chunks of regions for multiprocessing logger.info("Estimating GC content from peak sequences") peak_chunks = peaks.chunks(args.split) gc_content_pool = pool.starmap( get_gc_content, itertools.product(peak_chunks, [args.genome])) gc_content = np.mean(gc_content_pool) #fraction args.gc = gc_content bg = np.array([(1 - args.gc) / 2.0, args.gc / 2.0, args.gc / 2.0, (1 - args.gc) / 2.0]) logger.info("- GC content estimated at {0:.2f}%".format(gc_content * 100)) ################ Get motifs ################ logger.info("Reading motifs from file") motif_list = MotifList() args.motifs = expand_dirs(args.motifs) for f in args.motifs: motif_list += MotifList().from_file(f) #List of OneMotif objects no_pfms = len(motif_list) logger.info("- Read {0} motifs".format(no_pfms)) logger.debug("Getting motifs ready") motif_list.bg = bg logger.debug("Getting reverse motifs") motif_list.extend([motif.get_reverse() for motif in motif_list]) logger.spam(motif_list) #Set prefixes for motif in motif_list: #now with reverse motifs as well motif.set_prefix(args.naming) motif.bg = bg logger.spam("Getting pssm 
for motif {0}".format(motif.name)) motif.get_pssm() motif_names = list(set([motif.prefix for motif in motif_list])) #Get threshold for motifs logger.debug("Getting match threshold per motif") outlist = pool.starmap(OneMotif.get_threshold, itertools.product(motif_list, [args.motif_pvalue])) motif_list = MotifList(outlist) for motif in motif_list: logger.debug("Motif {0}: threshold {1}".format(motif.name, motif.threshold)) logger.info("Creating folder structure for each TF") for TF in motif_names: logger.spam("Creating directories for {0}".format(TF)) make_directory(os.path.join(args.outdir, TF)) make_directory(os.path.join(args.outdir, TF, "beds")) make_directory(os.path.join(args.outdir, TF, "plots")) #-------------------------------------------------------------------------------------------------------------# #----------------------------------------- Plot logos for all motifs -----------------------------------------# #-------------------------------------------------------------------------------------------------------------# plus_motifs = [motif for motif in motif_list if motif.strand == "+"] logo_filenames = { motif.prefix: os.path.join(args.outdir, motif.prefix, motif.prefix + ".png") for motif in plus_motifs } logger.info("Plotting sequence logos for each motif") task_list = [ pool.apply_async(OneMotif.logo_to_file, ( motif, logo_filenames[motif.prefix], )) for motif in plus_motifs ] monitor_progress(task_list, logger) results = [task.get() for task in task_list] logger.comment("") logger.debug("Getting base64 strings per motif") for motif in motif_list: if motif.strand == "+": #motif.get_base() with open(logo_filenames[motif.prefix], "rb") as png: motif.base = base64.b64encode(png.read()).decode("utf-8") #-------------------------------------------------------------------------------------------------------------# #--------------------- Motif scanning: Find binding sites and match to footprint scores ----------------------# #-------------------------------------------------------------------------------------------------------------# logger.comment("") logger.start_logger_queue( ) #start process for listening and handling through the main logger queue args.log_q = logger.queue #queue for multiprocessing logging manager = mp.Manager() logger.info("Scanning for motifs and matching to signals...") #Create writer queues for bed-file output logger.debug("Setting up writer queues") qs_list = [] writer_qs = {} #writer_queue = create_writer_queue(key2file, writer_cores) #writer_queue.stop() #wait until all are done manager = mp.Manager() TF_names_chunks = [ motif_names[i::writer_cores] for i in range(writer_cores) ] for TF_names_sub in TF_names_chunks: logger.debug("Creating writer queue for {0}".format(TF_names_sub)) files = [ os.path.join(args.outdir, TF, "beds", TF + ".tmp") for TF in TF_names_sub ] q = manager.Queue() qs_list.append(q) writer_pool.apply_async( file_writer, args=(q, dict(zip(TF_names_sub, files)), args) ) #, callback = lambda x: finished.append(x) print("Writing time: {0}".format(x))) for TF in TF_names_sub: writer_qs[TF] = q writer_pool.close() #no more jobs applied to writer_pool #todo: use run_parallel #Start working on data if worker_cores == 1: logger.debug("Running with cores = 1") results = [] for chunk in peak_chunks: results.append( scan_and_score(chunk, motif_list, args, args.log_q, writer_qs)) else: logger.debug("Sending jobs to worker pool") task_list = [ pool.apply_async(scan_and_score, ( chunk, motif_list, args, args.log_q, writer_qs, )) for chunk in 
peak_chunks ] monitor_progress(task_list, logger) results = [task.get() for task in task_list] logger.info("Done scanning for TFBS across regions!") #logger.stop_logger_queue() #stop the listening process (wait until all was written) #--------------------------------------# logger.info("Waiting for bedfiles to write") #Stop all queues for writing logger.debug("Stop all queues by inserting None") for q in qs_list: q.put((None, None)) logger.debug("Joining bed_writer queues") for i, q in enumerate(qs_list): logger.debug("- Queue {0} (size {1})".format(i, q.qsize())) #Waits until all queues are closed writer_pool.join() #-------------------------------------------------------------------------------------------------------------# #---------------------------- Process information on background scores and overlaps --------------------------# #-------------------------------------------------------------------------------------------------------------# logger.info("Merging results from subsets") background = merge_dicts([result[0] for result in results]) TF_overlaps = merge_dicts([result[1] for result in results]) results = None #Add missing TF overlaps (if some TFs had zero sites) for TF1 in plus_motifs: if TF1.prefix not in TF_overlaps: TF_overlaps[TF1.prefix] = 0 for TF2 in plus_motifs: tup = (TF1.prefix, TF2.prefix) if tup not in TF_overlaps: TF_overlaps[tup] = 0 #Collect sampled background values for bigwig in args.cond_names: background["signal"][bigwig] = np.array(background["signal"][bigwig]) #Check how many values were fetched from background n_bg_values = len(background["signal"][args.cond_names[0]]) logger.debug("Collected {0} values from background".format(n_bg_values)) if n_bg_values < 1000: err_str = "Number of background values collected from peaks is low (={0}) ".format( n_bg_values) err_str += "- this affects estimation of the bound/unbound threshold and the normalization between conditions. " err_str += "To improve this estimation, please run BINDetect with --peaks = the full peak set across all conditions." logger.warning(err_str) logger.comment("") logger.info("Estimating score distribution per condition") fig = plot_score_distribution( [background["signal"][bigwig] for bigwig in args.cond_names], labels=args.cond_names, title="Raw scores per condition") figure_pdf.savefig(fig, bbox_inches='tight') plt.close() logger.info("Normalizing scores") list_of_vals = [background["signal"][bigwig] for bigwig in args.cond_names] normed, norm_objects = quantile_normalization(list_of_vals) args.norm_objects = dict(zip(args.cond_names, norm_objects)) for bigwig in args.cond_names: background["signal"][bigwig] = args.norm_objects[bigwig].normalize( background["signal"][bigwig]) fig = plot_score_distribution( [background["signal"][bigwig] for bigwig in args.cond_names], labels=args.cond_names, title="Normalized scores per condition") figure_pdf.savefig(fig, bbox_inches='tight') plt.close() ########################################################### logger.info("Estimating bound/unbound threshold") #Prepare scores (remove 0's etc.) bg_values = np.array(normed).flatten() bg_values = bg_values[np.logical_not(np.isclose( bg_values, 0.0))] #only non-zero counts x_max = np.percentile(bg_values, [99]) bg_values = bg_values[bg_values < x_max] if len(bg_values) == 0: logger.error( "Error processing bigwig scores from background. It could be that there are no scores in the bigwig (=0) assigned for the peaks. Please check your input files." 
        )
        sys.exit()

    #Fit mixture of normals
    lowest_bic = np.inf
    for n_components in [2]:  #2 components
        gmm = sklearn.mixture.GaussianMixture(n_components=n_components, random_state=1)
        gmm.fit(np.log(bg_values).reshape(-1, 1))

        bic = gmm.bic(np.log(bg_values).reshape(-1, 1))
        logger.debug("n_components: {0} | bic: {1}".format(n_components, bic))
        if bic < lowest_bic:
            lowest_bic = bic
            best_gmm = gmm
    gmm = best_gmm

    #Extract most-right gaussian
    means = gmm.means_.flatten()
    sds = np.sqrt(gmm.covariances_).flatten()
    chosen_i = np.argmax(means)  #Mixture with largest mean

    log_params = scipy.stats.lognorm.fit(bg_values[bg_values < x_max],
                                         f0=sds[chosen_i],
                                         fscale=np.exp(means[chosen_i]))
    #all_log_params[bigwig] = log_params

    #Mode of distribution
    mode = scipy.optimize.fmin(
        lambda x: -scipy.stats.lognorm.pdf(x, *log_params), 0, disp=False)[0]
    logger.debug("- Mode estimated at: {0}".format(mode))
    pseudo = mode / 2.0  #pseudo is half the mode
    args.pseudo = pseudo
    logger.debug("Pseudocount estimated at: {0}".format(round(args.pseudo, 5)))

    # Estimate theoretical normal for threshold
    leftside_x = np.linspace(
        scipy.stats.lognorm(*log_params).ppf([0.01]), mode, 100)
    leftside_pdf = scipy.stats.lognorm.pdf(leftside_x, *log_params)

    #Flip over
    mirrored_x = np.concatenate([leftside_x,
                                 np.max(leftside_x) + leftside_x]).flatten()
    mirrored_pdf = np.concatenate([leftside_pdf, leftside_pdf[::-1]]).flatten()
    popt, cov = scipy.optimize.curve_fit(
        lambda x, std, sc: sc * scipy.stats.norm.pdf(x, mode, std), mirrored_x,
        mirrored_pdf)
    norm_params = (mode, popt[0])
    logger.debug("Theoretical normal parameters: {0}".format(norm_params))

    #Set threshold for bound/unbound
    threshold = round(
        scipy.stats.norm.ppf(1 - args.bound_pvalue, *norm_params), 5)
    args.thresholds = {bigwig: threshold for bigwig in args.cond_names}
    logger.stats("- Threshold estimated at: {0}".format(threshold))

    #Only plot if args.debug is True
    if args.debug:
        #Plot fit
        fig, ax = plt.subplots(1, 1)
        ax.hist(bg_values[bg_values < x_max],
                bins='auto',
                density=True,
                label="Observed score distribution")

        xvals = np.linspace(0, x_max, 1000)
        log_probas = scipy.stats.lognorm.pdf(xvals, *log_params)
        ax.plot(xvals, log_probas, label="Log-normal fit", color="orange")

        #Theoretical normal
        norm_probas = scipy.stats.norm.pdf(xvals, *norm_params)
        ax.plot(xvals,
                norm_probas * (np.max(log_probas) / np.max(norm_probas)),
                color="grey",
                linestyle="--",
                label="Theoretical normal")

        ax.axvline(threshold, color="black", label="Bound/unbound threshold")
        ymax = plt.ylim()[1]
        ax.text(threshold, ymax, "\n {0:.3f}".format(threshold), va="top")

        #Decorate plot
        plt.title("Score distribution")
        plt.xlabel("Bigwig score")
        plt.ylabel("Density")
        plt.legend(fontsize=8)
        plt.xlim((0, x_max))

        figure_pdf.savefig(fig)
        plt.close(fig)

    ############ Foldchanges between conditions ################
    logger.comment("")
    log2fc_params = {}
    if len(args.signals) > 1:
        logger.info("Calculating background log2 fold-changes between conditions")

        for (bigwig1, bigwig2) in comparisons:  #cond1, cond2
            logger.info("- {0} / {1}".format(bigwig1, bigwig2))

            #Estimate background log2fc
            scores1 = np.copy(background["signal"][bigwig1])
            scores2 = np.copy(background["signal"][bigwig2])

            included = np.logical_or(scores1 > 0, scores2 > 0)
            scores1 = scores1[included]
            scores2 = scores2[included]

            #Calculate background log2fc normal distribution
            log2fcs = np.log2(
                np.true_divide(scores1 + args.pseudo, scores2 + args.pseudo))

            lower, upper = np.percentile(log2fcs, [1, 99])
            log2fcs_fit = log2fcs[np.logical_and(log2fcs >= lower,
                                                 log2fcs <= upper)]
norm_params = scipy.stats.norm.fit(log2fcs_fit) logger.debug( "({0} / {1}) Background log2fc normal distribution: {2}". format(bigwig1, bigwig2, norm_params)) log2fc_params[(bigwig1, bigwig2)] = norm_params #Plot background log2fc to figures fig, ax = plt.subplots(1, 1) plt.hist(log2fcs, density=True, bins='auto', label="Background log2fc ({0} / {1})".format( bigwig1, bigwig2)) xvals = np.linspace(plt.xlim()[0], plt.xlim()[1], 100) pdf = scipy.stats.norm.pdf(xvals, *log2fc_params[(bigwig1, bigwig2)]) plt.plot(xvals, pdf, label="Normal distribution fit") plt.title("Background log2FCs ({0} / {1})".format( bigwig1, bigwig2)) plt.xlabel("Log2 fold change") plt.ylabel("Density") if args.debug: figure_pdf.savefig(fig, bbox_inches='tight') f = open( os.path.join( args.outdir, "{0}_{1}_log2fcs.txt".format(bigwig1, bigwig2)), "w") f.write("\n".join([str(val) for val in log2fcs])) f.close() plt.close() background = None #free up space #-------------------------------------------------------------------------------------------------------------# #----------------------------- Read total sites per TF to estimate bound/unbound -----------------------------# #-------------------------------------------------------------------------------------------------------------# logger.comment("") logger.info("Processing scanned TFBS individually") #Getting bindetect table ready info_columns = ["total_tfbs"] info_columns.extend([ "{0}_{1}".format(cond, metric) for (cond, metric ) in itertools.product(args.cond_names, ["threshold", "bound"]) ]) info_columns.extend([ "{0}_{1}_{2}".format(comparison[0], comparison[1], metric) for (comparison, metric) in itertools.product(comparisons, ["change", "pvalue"]) ]) cols = len(info_columns) rows = len(motif_names) info_table = pd.DataFrame(np.zeros((rows, cols)), columns=info_columns, index=motif_names) #Starting calculations results = [] if args.cores == 1: for name in motif_names: logger.info("- {0}".format(name)) results.append(process_tfbs(name, args, log2fc_params)) else: logger.debug("Sending jobs to worker pool") task_list = [ pool.apply_async(process_tfbs, ( name, args, log2fc_params, )) for name in motif_names ] monitor_progress(task_list, logger) #will not exit before all jobs are done results = [task.get() for task in task_list] logger.info("Concatenating results from subsets") info_table = pd.concat(results) #pandas tables pool.terminate() pool.join() logger.stop_logger_queue() #-------------------------------------------------------------------------------------------------------------# #------------------------------------------------ Cluster TFBS -----------------------------------------------# #-------------------------------------------------------------------------------------------------------------# clustering = RegionCluster(TF_overlaps) clustering.cluster() #Convert full ids to alt ids convert = {motif.prefix: motif.name for motif in motif_list} for cluster in clustering.clusters: for name in convert: clustering.clusters[cluster]["cluster_name"] = clustering.clusters[ cluster]["cluster_name"].replace(name, convert[name]) #Write out distance matrix matrix_out = os.path.join(args.outdir, args.prefix + "_distances.txt") clustering.write_distance_mat(matrix_out) #-------------------------------------------------------------------------------------------------------------# #----------------------------------------- Write all_bindetect file ------------------------------------------# 
#-------------------------------------------------------------------------------------------------------------# logger.comment("") logger.info("Writing all_bindetect files") #Add columns of name / motif_id / prefix names = [] ids = [] for prefix in info_table.index: motif = [motif for motif in motif_list if motif.prefix == prefix] names.append(motif[0].name) ids.append(motif[0].id) info_table.insert(0, "output_prefix", info_table.index) info_table.insert(1, "name", names) info_table.insert(2, "motif_id", ids) #info_table.insert(3, "motif_logo", [os.path.join("motif_logos", os.path.basename(logo_filenames[prefix])) for prefix in info_table["output_prefix"]]) #add relative path to logo #Add cluster to info_table cluster_names = [] for name in info_table.index: for cluster in clustering.clusters: if name in clustering.clusters[cluster]["member_names"]: cluster_names.append( clustering.clusters[cluster]["cluster_name"]) info_table.insert(3, "cluster", cluster_names) #Cluster table on motif clusters info_table_clustered = info_table.groupby( "cluster").mean() #mean of each column info_table_clustered.reset_index(inplace=True) #Map correct type info_table["total_tfbs"] = info_table["total_tfbs"].map(int) for condition in args.cond_names: info_table[condition + "_bound"] = info_table[condition + "_bound"].map(int) #### Write excel ### bindetect_excel = os.path.join(args.outdir, args.prefix + "_results.xlsx") writer = pd.ExcelWriter(bindetect_excel, engine='xlsxwriter') #Tables info_table.to_excel(writer, index=False, sheet_name="Individual motifs") info_table_clustered.to_excel(writer, index=False, sheet_name="Motif clusters") for sheet in writer.sheets: worksheet = writer.sheets[sheet] n_rows = worksheet.dim_rowmax n_cols = worksheet.dim_colmax worksheet.autofilter(0, 0, n_rows, n_cols) writer.save() #Format comparisons for (cond1, cond2) in comparisons: base = cond1 + "_" + cond2 info_table[base + "_change"] = info_table[base + "_change"].round(5) info_table[base + "_pvalue"] = info_table[base + "_pvalue"].map( "{:.5E}".format, na_action="ignore") #Write bindetect results tables #info_table.insert(0, "TF_name", info_table.index) #Set index as first column bindetect_out = os.path.join(args.outdir, args.prefix + "_results.txt") info_table.to_csv(bindetect_out, sep="\t", index=False, header=True, na_rep="NA") #-------------------------------------------------------------------------------------------------------------# #------------------------------------------- Make BINDetect plot ---------------------------------------------# #-------------------------------------------------------------------------------------------------------------# if no_conditions > 1: logger.info("Creating BINDetect plot(s)") #Fill NAs from info_table to enable plotting of log2fcs (NA -> 0 change) change_cols = [col for col in info_table.columns if "_change" in col] pvalue_cols = [col for col in info_table.columns if "_pvalue" in col] info_table[change_cols] = info_table[change_cols].fillna(0) info_table[pvalue_cols] = info_table[pvalue_cols].fillna(1) #Plotting bindetect per comparison for (cond1, cond2) in comparisons: logger.info("- {0} / {1}".format(cond1, cond2)) base = cond1 + "_" + cond2 #Define which motifs to show xvalues = info_table[base + "_change"].astype(float) yvalues = info_table[base + "_pvalue"].astype(float) y_min = np.percentile(yvalues[yvalues > 0], 5) #5% smallest pvalues x_min, x_max = np.percentile( xvalues, [5, 95]) #5% smallest and largest changes #Make copy of motifs and fill in with metadata 
comparison_motifs = [ motif for motif in motif_list if motif.strand == "+" ] #copy.deepcopy(motif_list) - swig pickle error, just overwrite motif_list for motif in comparison_motifs: name = motif.prefix motif.change = float(info_table.at[name, base + "_change"]) motif.pvalue = float(info_table.at[name, base + "_pvalue"]) motif.logpvalue = -np.log10( motif.pvalue) if motif.pvalue > 0 else -np.log10(1e-308) #Assign each motif to group if motif.change < x_min or motif.change > x_max or motif.pvalue < y_min: if motif.change < 0: motif.group = cond2 + "_up" if motif.change > 0: motif.group = cond1 + "_up" else: motif.group = "n.s." #Bindetect plot fig = plot_bindetect(comparison_motifs, clustering, [cond1, cond2], args) figure_pdf.savefig(fig, bbox_inches='tight') plt.close(fig) #Interactive BINDetect plot html_out = os.path.join(args.outdir, "bindetect_" + base + ".html") plot_interactive_bindetect(comparison_motifs, [cond1, cond2], html_out) #-------------------------------------------------------------------------------------------------------------# #----------------------------- Make heatmap across conditions (for debugging)---------------------------------# #-------------------------------------------------------------------------------------------------------------# if args.debug: mean_columns = [cond + "_mean_score" for cond in args.cond_names] heatmap_table = info_table[mean_columns] heatmap_table.index = info_table["output_prefix"] #Decide fig size rows, cols = heatmap_table.shape figsize = (7 + cols, max(10, rows / 8.0)) cm = sns.clustermap( heatmap_table, figsize=figsize, z_score=0, #zscore for rows col_cluster=False, #do not cluster condition columns yticklabels=True, #show all row annotations xticklabels=True, cbar_pos=(0, 0, .4, .005), dendrogram_ratio=(0.3, 0.01), cbar_kws={ "orientation": "horizontal", 'label': 'Row z-score' }, method="single") #Adjust width of columns #hm = cm.ax_heatmap.get_position() #cm.ax_heatmap.set_position([hm.x0, hm.y0, cols * 3 * hm.height / rows, hm.height]) #aspect should be equal plt.setp(cm.ax_heatmap.get_xticklabels(), fontsize=8, rotation=45, ha="right") plt.setp(cm.ax_heatmap.get_yticklabels(), fontsize=5) cm.ax_col_dendrogram.set_title('Mean scores across conditions', fontsize=20) cm.ax_heatmap.set_ylabel("Transcription factor motifs", fontsize=15, rotation=270) #cm.ax_heatmap.set_title('Conditions') #cm.fig.suptitle('Mean scores across conditions') #cm.cax.set_visible(False) #Save to output pdf plt.tight_layout() figure_pdf.savefig(cm.fig, bbox_inches='tight') plt.close(cm.fig) #-------------------------------------------------------------------------------------------------------------# #-------------------------------------------------- Wrap up---------------------------------------------------# #-------------------------------------------------------------------------------------------------------------# figure_pdf.close() logger.end()
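
#The bound/unbound threshold in run_bindetect is derived by fitting a two-component
#Gaussian mixture to log-transformed background scores, fitting a log-normal to the
#right-most component, mirroring the left flank of that fit around its mode and taking
#a normal quantile on the mirrored ("background") distribution.
#The helper below is a minimal, self-contained sketch of that idea for a 1D score array;
#the function name is illustrative only and it is not part of the original TOBIAS code.
def _sketch_bound_threshold(scores, bound_pvalue=0.001):
    """ Estimate a bound/unbound score threshold from an array of footprint scores """
    import numpy as np
    import scipy.stats
    import scipy.optimize
    import sklearn.mixture

    scores = np.asarray(scores, dtype=float)
    scores = scores[~np.isclose(scores, 0.0)]  #only non-zero scores
    x_max = np.percentile(scores, 99)
    scores = scores[scores < x_max]  #trim the extreme right tail

    #Two-component mixture on log-scores; the component with the largest mean is kept
    gmm = sklearn.mixture.GaussianMixture(n_components=2, random_state=1)
    gmm.fit(np.log(scores).reshape(-1, 1))
    means = gmm.means_.flatten()
    sds = np.sqrt(gmm.covariances_).flatten()
    chosen = np.argmax(means)

    #Log-normal fit constrained to the chosen component
    log_params = scipy.stats.lognorm.fit(scores, f0=sds[chosen], fscale=np.exp(means[chosen]))
    mode = scipy.optimize.fmin(lambda x: -scipy.stats.lognorm.pdf(x, *log_params), 0, disp=False)[0]

    #Mirror the left flank of the pdf around the mode and fit a normal to it
    left_x = np.linspace(scipy.stats.lognorm.ppf(0.01, *log_params), mode, 100)
    left_pdf = scipy.stats.lognorm.pdf(left_x, *log_params)
    mirrored_x = np.concatenate([left_x, np.max(left_x) + left_x])
    mirrored_pdf = np.concatenate([left_pdf, left_pdf[::-1]])
    popt, _ = scipy.optimize.curve_fit(
        lambda x, std, sc: sc * scipy.stats.norm.pdf(x, mode, std), mirrored_x, mirrored_pdf)

    #Threshold is the upper quantile of the mirrored background normal
    return scipy.stats.norm.ppf(1 - bound_pvalue, mode, popt[0])
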
def run_atacorrect(args): """ Function for bias correction of input .bam files Calls functions in ATACorrect_functions and several internal classes """ #Test if required arguments were given: if args.bam == None: sys.exit("Error: No .bam-file given") if args.genome == None: sys.exit("Error: No .fasta-file given") if args.peaks == None: sys.exit("Error: No .peaks-file given") #Adjust some parameters depending on input args.prefix = os.path.splitext(os.path.basename( args.bam))[0] if args.prefix == None else args.prefix args.outdir = os.path.abspath( args.outdir) if args.outdir != None else os.path.abspath(os.getcwd()) #Set output bigwigs based on input tracks = ["uncorrected", "bias", "expected", "corrected"] tracks = [track for track in tracks if track not in args.track_off] # switch off printing if args.split_strands == True: strands = ["forward", "reverse"] else: strands = ["both"] output_bws = {} for track in tracks: output_bws[track] = {} for strand in strands: elements = [args.prefix, track] if strand == "both" else [ args.prefix, track, strand ] output_bws[track][strand] = { "fn": os.path.join(args.outdir, "{0}.bw".format("_".join(elements))) } #Set all output files bam_out = os.path.join(args.outdir, args.prefix + "_atacorrect.bam") bigwigs = [ output_bws[track][strand]["fn"] for (track, strand) in itertools.product(tracks, strands) ] figures_f = os.path.join(args.outdir, "{0}_atacorrect.pdf".format(args.prefix)) output_files = bigwigs + [figures_f] output_files = list(OrderedDict.fromkeys( output_files)) #remove duplicates due to "both" option strands = ["forward", "reverse"] #----------------------------------------------------------------------------------------------------# # Print info on run #----------------------------------------------------------------------------------------------------# logger = TobiasLogger("ATACorrect", args.verbosity) logger.begin() parser = add_atacorrect_arguments(argparse.ArgumentParser()) logger.arguments_overview(parser, args) logger.output_files(output_files) args.cores = check_cores(args.cores, logger) #----------------------------------------------------------------------------------------------------# # Test input file availability for reading #----------------------------------------------------------------------------------------------------# logger.info("----- Processing input data -----") logger.debug("Testing input file availability") check_files([args.bam, args.genome, args.peaks], "r") logger.debug("Testing output directory/file writeability") make_directory(args.outdir) check_files(output_files, "w") #Open pdf for figures figure_pdf = PdfPages(figures_f, keep_empty=False) #----------------------------------------------------------------------------------------------------# # Read information in bam/fasta #----------------------------------------------------------------------------------------------------# logger.info("Reading info from .bam file") bamfile = pysam.AlignmentFile(args.bam, "rb") if bamfile.has_index() == False: logger.warning("No index found for bamfile - creating one via pysam.") pysam.index(args.bam) bam_references = bamfile.references #chromosomes in correct order bam_chrom_info = dict(zip(bamfile.references, bamfile.lengths)) logger.debug("bam_chrom_info: {0}".format(bam_chrom_info)) bamfile.close() logger.info("Reading info from .fasta file") fastafile = pysam.FastaFile(args.genome) fasta_chrom_info = dict(zip(fastafile.references, fastafile.lengths)) logger.debug("fasta_chrom_info: {0}".format(fasta_chrom_info)) 
    fastafile.close()

    #Compare chrom lengths
    chrom_in_common = set(bam_chrom_info.keys()).intersection(fasta_chrom_info.keys())
    for chrom in chrom_in_common:
        bamlen = bam_chrom_info[chrom]
        fastalen = fasta_chrom_info[chrom]
        if bamlen != fastalen:
            logger.warning("(Fastafile)\t{0} has length {1}".format(chrom, fasta_chrom_info[chrom]))
            logger.warning("(Bamfile)\t{0} has length {1}".format(chrom, bam_chrom_info[chrom]))
            sys.exit("Error: .bam and .fasta have different chromosome lengths. Please make sure the genome file is similar to the one used in mapping.")

    #Subset bam_references to those for which there are sequences in fasta
    chrom_not_in_fasta = set(bam_references) - set(fasta_chrom_info.keys())
    if len(chrom_not_in_fasta) > 0:  #warn if any contigs are missing from fasta
        logger.warning(
            "The following contigs in --bam did not have sequences in --fasta: {0}. NOTE: These contigs will be skipped in calculation and output."
            .format(chrom_not_in_fasta))

    bam_references = [ref for ref in bam_references if ref in fasta_chrom_info]
    chrom_in_common = [ref for ref in chrom_in_common if ref in bam_references]

    #----------------------------------------------------------------------------------------------------#
    # Read regions from bedfiles
    #----------------------------------------------------------------------------------------------------#

    logger.info("Processing input/output regions")

    #Chromosomes included in analysis
    genome_regions = RegionList().from_list([
        OneRegion([chrom, 0, bam_chrom_info[chrom]])
        for chrom in bam_references if "M" not in chrom
    ])  #full genome length
    chrom_in_common = [chrom for chrom in chrom_in_common if "M" not in chrom]
    logger.debug("CHROMS\t{0}".format("; ".join(
        ["{0} ({1})".format(reg.chrom, reg.end) for reg in genome_regions])))
    genome_bp = sum([region.get_length() for region in genome_regions])

    # Process peaks
    peak_regions = RegionList().from_bed(args.peaks)
    peak_regions.merge()
    for i in range(len(peak_regions) - 1, -1, -1):
        region = peak_regions[i]
        peak_regions[i] = region.check_boundary(bam_chrom_info, "cut")  #regions are cut/removed from list
        if peak_regions[i] is None:
            logger.warning(
                "Peak region {0} was removed as it is either out of bounds or not in the chromosomes given in genome/bam."
.format(region.tup(), i + 1)) del peak_regions[i] nonpeak_regions = deepcopy(genome_regions).subtract(peak_regions) # Process specific input regions if given if args.regions_in != None: input_regions = RegionList().from_bed(args.regions_in) input_regions.merge() input_regions.apply_method(OneRegion.check_boundary, bam_chrom_info, "cut") else: input_regions = nonpeak_regions # Process specific output regions if args.regions_out != None: output_regions = RegionList().from_bed(args.regions_out) else: output_regions = deepcopy(peak_regions) #Extend regions to make sure extend + flanking for window/flank are within boundaries flank_extend = args.k_flank + int(args.window / 2.0) output_regions.apply_method(OneRegion.extend_reg, args.extend + flank_extend) output_regions.merge() output_regions.apply_method(OneRegion.check_boundary, bam_chrom_info, "cut") output_regions.apply_method( OneRegion.extend_reg, -flank_extend ) #Cut to needed size knowing that the region will be extended in function #Remove blacklisted regions and chromosomes not in common blacklist_regions = RegionList().from_bed( args.blacklist) if args.blacklist != None else RegionList( []) #fill in with regions from args.blacklist regions_dict = { "genome": genome_regions, "input_regions": input_regions, "output_regions": output_regions, "peak_regions": peak_regions, "nonpeak_regions": nonpeak_regions, "blacklist_regions": blacklist_regions } for sub in [ "input_regions", "output_regions", "peak_regions", "nonpeak_regions" ]: regions_sub = regions_dict[sub] regions_sub.subtract(blacklist_regions) regions_sub = regions_sub.apply_method(OneRegion.split_region, 50000) regions_sub.keep_chroms(chrom_in_common) regions_dict[sub] = regions_sub #write beds to look at in igv #input_regions.write_bed(os.path.join(args.outdir, "input_regions.bed")) #output_regions.write_bed(os.path.join(args.outdir, "output_regions.bed")) #peak_regions.write_bed(os.path.join(args.outdir, "peak_regions.bed")) #nonpeak_regions.write_bed(os.path.join(args.outdir, "nonpeak_regions.bed")) #Sort according to order in bam_references: output_regions.loc_sort(bam_references) chrom_order = {bam_references[i]: i for i in range(len(bam_references)) } #for use later when sorting output #### Statistics about regions #### genome_bp = sum([region.get_length() for region in regions_dict["genome"]]) for key in regions_dict: total_bp = sum([region.get_length() for region in regions_dict[key]]) logger.stats("{0}: {1} regions | {2} bp | {3:.2f}% coverage".format( key, len(regions_dict[key]), total_bp, total_bp / genome_bp * 100)) #Estallish variables for regions to be used input_regions = regions_dict["input_regions"] output_regions = regions_dict["output_regions"] peak_regions = regions_dict["peak_regions"] nonpeak_regions = regions_dict["nonpeak_regions"] #Exit if no input/output regions were found if len(input_regions) == 0 or len(output_regions) == 0 or len( peak_regions) == 0 or len(nonpeak_regions) == 0: logger.error("No regions found - exiting!") sys.exit() #----------------------------------------------------------------------------------------------------# # Estimate normalization factors #----------------------------------------------------------------------------------------------------# #Setup logger queue logger.debug("Setting up listener for log") logger.start_logger_queue() args.log_q = logger.queue #----------------------------------------------------------------------------------------------------# logger.comment("") logger.info("----- Estimating normalization 
factors -----") #If normalization is to be calculated if not args.norm_off: #Reads in peaks/nonpeaks logger.info("Counting reads in peak regions") peak_region_chunks = peak_regions.chunks(args.split) reads_peaks = sum( run_parallel(count_reads, peak_region_chunks, [args], args.cores, logger)) logger.comment("") logger.info("Counting reads in nonpeak regions") nonpeak_region_chunks = nonpeak_regions.chunks(args.split) reads_nonpeaks = sum( run_parallel(count_reads, nonpeak_region_chunks, [args], args.cores, logger)) reads_total = reads_peaks + reads_nonpeaks logger.stats("TOTAL_READS\t{0}".format(reads_total)) logger.stats("PEAK_READS\t{0}".format(reads_peaks)) logger.stats("NONPEAK_READS\t{0}".format(reads_nonpeaks)) lib_norm = 10000000 / reads_total frip = reads_peaks / reads_total correct_factor = lib_norm * (1 / frip) logger.stats("LIB_NORM\t{0:.5f}".format(lib_norm)) logger.stats("FRiP\t{0:.5f}".format(frip)) else: logger.info("Normalization was switched off") correct_factor = 1.0 logger.stats("CORRECTION_FACTOR:\t{0:.5f}".format(correct_factor)) #----------------------------------------------------------------------------------------------------# # Estimate sequence bias #----------------------------------------------------------------------------------------------------# logger.comment("") logger.info("Started estimation of sequence bias...") input_region_chunks = input_regions.chunks( args.split) #split to 100 chunks (also decides the step of output) out_lst = run_parallel(bias_estimation, input_region_chunks, [args], args.cores, logger) #Output is list of AtacBias objects #Join objects estimated_bias = out_lst[0] #initialize object with first output for output in out_lst[1:]: estimated_bias.join( output ) #bias object contains bias/background SequenceMatrix objects logger.debug("Bias estimated\tno_reads: {0}".format( estimated_bias.no_reads)) #----------------------------------------------------------------------------------------------------# # Join estimations from all chunks of regions #----------------------------------------------------------------------------------------------------# bias_obj = estimated_bias bias_obj.correction_factor = correct_factor ### Bias motif ### logger.info("Finalizing bias motif for scoring") for strand in strands: bias_obj.bias[strand].prepare_mat() logger.debug("Saving pssm to figure pdf") fig = plot_pssm(bias_obj.bias[strand].pssm, "Tn5 insertion bias of reads ({0})".format(strand)) figure_pdf.savefig(fig) #Write bias motif to pickle out_f = os.path.join(args.outdir, args.prefix + "_AtacBias.pickle") logger.debug("Saving bias object to pickle ({0})".format(out_f)) bias_obj.to_pickle(out_f) #----------------------------------------------------------------------------------------------------# # Correct read bias and write to bigwig #----------------------------------------------------------------------------------------------------# logger.comment("") logger.info("----- Correcting reads from .bam within output regions -----") output_regions.loc_sort(bam_references) #sort in order of references output_regions_chunks = output_regions.chunks(args.split) no_tasks = float(len(output_regions_chunks)) chunk_sizes = [len(chunk) for chunk in output_regions_chunks] logger.debug("All regions chunked: {0} ({1})".format( len(output_regions), chunk_sizes)) ### Create key-file linking for bigwigs key2file = {} for track in output_bws: for strand in output_bws[track]: filename = output_bws[track][strand]["fn"] key = "{}:{}".format(track, strand) key2file[key] = 
filename #Start correction/write cores n_bigwig = len(key2file.values()) writer_cores = min(n_bigwig, max( 1, int(args.cores * 0.1))) #at most one core per bigwig or 10% of cores (or 1) worker_cores = max(1, args.cores - writer_cores) logger.debug("Worker cores: {0}".format(worker_cores)) logger.debug("Writer cores: {0}".format(writer_cores)) worker_pool = mp.Pool(processes=worker_cores) writer_pool = mp.Pool(processes=writer_cores) manager = mp.Manager() #Start bigwig file writers writer_tasks = [] header = [(chrom, bam_chrom_info[chrom]) for chrom in bam_references] key_chunks = [ list(key2file.keys())[i::writer_cores] for i in range(writer_cores) ] qs_list = [] qs = {} for chunk in key_chunks: logger.debug("Creating writer queue for {0}".format(chunk)) q = manager.Queue() qs_list.append(q) files = [key2file[key] for key in chunk] writer_tasks.append( writer_pool.apply_async(bigwig_writer, args=(q, dict(zip(chunk, files)), header, output_regions, args)) ) #, callback = lambda x: finished.append(x) print("Writing time: {0}".format(x))) for key in chunk: qs[key] = q args.qs = qs writer_pool.close() #no more jobs applied to writer_pool #Start correction logger.debug("Starting correction") task_list = [ worker_pool.apply_async(bias_correction, args=[chunk, args, bias_obj]) for chunk in output_regions_chunks ] worker_pool.close() monitor_progress(task_list, logger, "Correction progress:" ) #does not exit until tasks in task_list finished results = [task.get() for task in task_list] #Get all results pre_bias = results[0][0] #initialize with first result post_bias = results[0][1] #initialize with first result for result in results[1:]: pre_bias_chunk = result[0] post_bias_chunk = result[1] for direction in strands: pre_bias[direction].add_counts(pre_bias_chunk[direction]) post_bias[direction].add_counts(post_bias_chunk[direction]) #Stop all queues for writing logger.debug("Stop all queues by inserting None") for q in qs_list: q.put((None, None, None)) #Fetch error codes from bigwig writers logger.debug("Fetching possible errors from bigwig_writer tasks") results = [task.get() for task in writer_tasks] #blocks until writers are finished logger.debug("Joining bigwig_writer queues") qsum = sum([q.qsize() for q in qs_list]) while qsum != 0: qsum = sum([q.qsize() for q in qs_list]) logger.spam("- Queue sizes {0}".format([(key, qs[key].qsize()) for key in qs])) time.sleep(0.5) #Waits until all queues are closed writer_pool.join() worker_pool.terminate() worker_pool.join() #Stop multiprocessing logger logger.stop_logger_queue() #----------------------------------------------------------------------------------------------------# # Information and verification of corrected read frequencies #----------------------------------------------------------------------------------------------------# logger.comment("") logger.info("Verifying bias correction") #Calculating variance per base for strand in strands: #Invert negative counts abssum = np.abs(np.sum(post_bias[strand].neg_counts, axis=0)) post_bias[strand].neg_counts = post_bias[strand].neg_counts + abssum #Join negative/positive counts post_bias[strand].counts += post_bias[strand].neg_counts #now pos pre_bias[strand].prepare_mat() post_bias[strand].prepare_mat() pre_var = np.mean(np.var(pre_bias[strand].bias_pwm, axis=1)[:4]) #mean of variance per nucleotide post_var = np.mean(np.var(post_bias[strand].bias_pwm, axis=1)[:4]) logger.stats("BIAS\tpre-bias variance {0}:\t{1:.7f}".format( strand, pre_var)) logger.stats("BIAS\tpost-bias variance 
{0}:\t{1:.7f}".format( strand, post_var)) #Plot figure fig_title = "Nucleotide frequencies in corrected reads\n({0} strand)".format( strand) figure_pdf.savefig( plot_correction(pre_bias[strand].bias_pwm, post_bias[strand].bias_pwm, fig_title)) #----------------------------------------------------------------------------------------------------# # Finish up #----------------------------------------------------------------------------------------------------# plt.close('all') figure_pdf.close() logger.end()
def run_motifclust(args): ###### Check input arguments ###### check_required(args, ["motifs"]) #Check input arguments check_files([args.motifs]) #Check if files exist out_cons_img = os.path.join(args.outdir, "consensus_motifs_img") make_directory(out_cons_img) out_prefix = os.path.join(args.outdir, args.prefix) ###### Create logger and write argument overview ###### logger = TobiasLogger("ClusterMotifs", args.verbosity) logger.begin() parser = add_motifclust_arguments(argparse.ArgumentParser()) logger.arguments_overview(parser, args) #logger.output_files([]) out_prefix = os.path.join(args.outdir, args.prefix) #----------------------------------------- Check for gimmemotifs ----------------------------------------# try: from gimmemotifs.motif import Motif from gimmemotifs.comparison import MotifComparer sns.set_style( "ticks" ) #set style back to ticks, as this is set globally during gimmemotifs import except ModuleNotFoundError: logger.error( "MotifClust requires the python package 'gimmemotifs'. You can install it using 'pip install gimmemotifs' or 'conda install gimmemotifs'." ) sys.exit(1) except ImportError as e: #gimmemotifs was installed, but there was an error during import pandas_version = pd.__version__ python_version = platform.python_version() if e.name == "collections" and ( version.parse(python_version) >= version.parse("3.10.0") ): #collections error from norns=0.1.5 and from other packages on python=3.10 logger.error( "Due to package dependency errors, 'TOBIAS ClusterMotifs' is not available for python>=3.10. Current python version is '{0}'. Please downgrade python in order to use this tool." .format(python_version)) sys.exit(1) elif e.name == "pandas.core.indexing" and ( version.parse(pandas_version) >= version.parse("1.3.0")): logger.error( "Package 'gimmemotifs' version < 0.17.0 requires 'pandas' version < 1.3.0. Current pandas version is {0}." .format(pandas_version)) sys.exit(1) else: #other import error logger.error( "Tried to import package 'gimmemotifs' but failed with error: '{0}'" .format(repr(e))) logger.error("Traceback:") raise e except Exception as e: logger.error( "Tried to import package 'gimmemotifs' but failed with error: '{0}'" .format(repr(e))) logger.error( "Please check that 'gimmemotifs' was successfully installed.") sys.exit(1) #Check gimmemotifs version vs. metric given import gimmemotifs gimme_version = gimmemotifs.__version__ if gimme_version == "0.17.0" and args.dist_method in ["pcc", "akl"]: logger.warning( "The dist_method given ('{0}') is invalid for gimmemotifs version 0.17.0. Please choose another --dist_method. 
See also: https://github.com/vanheeringen-lab/gimmemotifs/issues/243" .format(args.dist_method)) sys.exit(1) #---------------------------------------- Reading motifs from file(s) -----------------------------------# logger.info("Reading input file(s)") motif_list = MotifList() #list containing OneMotif objects motif_dict = {} #dictionary containing separate motif lists per file if sys.version_info < ( 3, 7, 0): # workaround for deepcopy with python version < 3.5 copy._deepcopy_dispatch[type(re.compile(''))] = lambda r, _: r for f in args.motifs: logger.debug("Reading {0}".format(f)) motif_format = get_motif_format(open(f).read()) sub_motif_list = MotifList().from_file(f) #MotifList object logger.stats("- Read {0} motifs from {1} (format: {2})".format( len(sub_motif_list), f, motif_format)) motif_list.extend(sub_motif_list) motif_dict[f] = sub_motif_list #Check whether ids are unique #TODO #---------------------------------------- Motif stats ---------------------------------------------------# logger.info("Creating matrix statistics") gimmemotifs_list = [ motif.get_gimmemotif().gimme_obj for motif in motif_list ] #Stats for all motifs full_motifs_out = out_prefix + "_stats_motifs.txt" motifs_stats = get_motif_stats(gimmemotifs_list) write_motif_stats(motifs_stats, full_motifs_out) #---------------------------------------- Motif clustering ----------------------------------------------# logger.info("Clustering motifs") clusters = motif_list.cluster(threshold=args.threshold, metric=args.dist_method, clust_method=args.clust_method) logger.stats("- Identified {0} clusters".format(len(clusters))) #Write out overview of clusters cluster_dict = { cluster_id: [ motif.get_gimmemotif().gimme_obj.id for motif in clusters[cluster_id] ] for cluster_id in clusters } cluster_f = out_prefix + "_" + "clusters.yml" logger.info("- Writing clustering to {0}".format(cluster_f)) write_yaml(cluster_dict, cluster_f) # Save similarity matrix to file matrix_out = out_prefix + "_matrix.txt" logger.info("- Saving similarity matrix to the file: " + str(matrix_out)) motif_list.similarity_matrix.to_csv(matrix_out, sep='\t') #Plot dendrogram logger.info("Plotting clustering dendrogram") dendrogram_f = out_prefix + "_" + "dendrogram." + args.type #plot format pdf/png plot_dendrogram(motif_list.similarity_matrix.columns, motif_list.linkage_mat, 12, dendrogram_f, "Clustering", args.threshold, args.dpi) #---------------------------------------- Consensus motif -----------------------------------------------# logger.info("Building consensus motif for each cluster") consensus_motifs = MotifList() for cluster_id in clusters: consensus = clusters[cluster_id].create_consensus( metric=args.dist_method ) #MotifList object with create_consensus method consensus.id = cluster_id if len( clusters[cluster_id]) > 1 else clusters[cluster_id][ 0].id #set original motif id if cluster length = 1 consensus_motifs.append(consensus) #Write out consensus motif file out_f = out_prefix + "_consensus_motifs." + args.cons_format logger.info("- Writing consensus motifs to: {0}".format(out_f)) consensus_motifs.to_file(out_f, args.cons_format) #Create logo plots out_cons_img = os.path.join(args.outdir, "consensus_motifs_img") logger.info( "- Making logo plots for consensus motifs (output folder: {0})".format( out_cons_img)) for motif in consensus_motifs: filename = os.path.join(out_cons_img, motif.id + "_consensus." 
+ args.type) motif.logo_to_file(filename) #---------------------------------------- Plot heatmap --------------------------------------------------# logger.info("Plotting similarity heatmap") logger.info( "Note: Can take a while for --type=pdf. Try \"--type png\" for speed up." ) args.nrc = False args.ncc = False args.zscore = "None" clust_linkage = motif_list.linkage_mat similarity_matrix = motif_list.similarity_matrix pdf_out = out_prefix + "_heatmap_all." + args.type x_label = "All motifs" y_label = "All motifs" plot_heatmap(similarity_matrix, pdf_out, clust_linkage, clust_linkage, args.dpi, x_label, y_label, args.color, args.ncc, args.nrc, args.zscore) # Plot heatmaps for each combination of motif files comparisons = itertools.combinations(args.motifs, 2) for i, (motif_file_1, motif_file_2) in enumerate(comparisons): pdf_out = out_prefix + "_heatmap" + str(i) + "." + args.type logger.info("Plotting comparison of {0} and {1} motifs to the file ". format(motif_file_1, motif_file_2) + str(pdf_out)) x_label, y_label = motif_file_1, motif_file_2 #Create subset of matrices for row/col clustering motif_names_1 = [ motif.get_gimmemotif().gimme_obj.id for motif in motif_dict[motif_file_1] ] motif_names_2 = [ motif.get_gimmemotif().gimme_obj.id for motif in motif_dict[motif_file_2] ] m1_matrix, m2_matrix, similarity_matrix_sub = subset_matrix( similarity_matrix, motif_names_1, motif_names_2) col_linkage = linkage(ssd.squareform(m1_matrix)) if ( len(motif_names_1) > 1 and len(motif_names_2) > 1) else None row_linkage = linkage(ssd.squareform(m2_matrix)) if ( len(motif_names_1) > 1 and len(motif_names_2) > 1) else None #Plot similarity heatmap between file1 and file2 plot_heatmap(similarity_matrix_sub, pdf_out, col_linkage, row_linkage, args.dpi, x_label, y_label, args.color, args.ncc, args.nrc, args.zscore) # ClusterMotifs finished logger.end()
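
#run_motifclust delegates the actual clustering to MotifList.cluster(); the underlying
#idea is agglomerative clustering on a motif-motif distance matrix, cut at a fixed
#distance threshold. The sketch below shows that generic pattern with SciPy on a
#symmetric distance matrix with zero diagonal; it is not the gimmemotifs-based
#implementation used above, and the function name is illustrative only.
def _sketch_cluster_from_distances(dist_matrix, labels, threshold=0.3, method="average"):
    """ Group labels into clusters given a symmetric distance matrix """
    import scipy.spatial.distance as ssd
    from scipy.cluster.hierarchy import linkage, fcluster

    condensed = ssd.squareform(dist_matrix)  #square matrix -> condensed form
    linkage_mat = linkage(condensed, method=method)
    assignments = fcluster(linkage_mat, t=threshold, criterion="distance")

    clusters = {}
    for label, cluster_id in zip(labels, assignments):
        clusters.setdefault(cluster_id, []).append(label)
    return clusters
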
def run_aggregate(args): """ Function to make aggregate plot given input from args """ ######################################################################################### ############################## Setup logger/input/output ################################ ######################################################################################### logger = TobiasLogger("PlotAggregate", args.verbosity) logger.begin() parser = add_aggregate_arguments(argparse.ArgumentParser()) logger.arguments_overview(parser, args) logger.output_files([args.output]) #Check input parameters check_required(args, ["TFBS", "signals"]) check_files([ args.TFBS, args.signals, args.regions, args.whitelist, args.blacklist ], action="r") check_files([args.output], action="w") #### Test input #### if args.TFBS_labels != None and (len(args.TFBS) != len(args.TFBS_labels)): logger.error( "ERROR --TFBS and --TFBS-labels have different lengths ({0} vs. {1})" .format(len(args.TFBS), len(args.TFBS_labels))) sys.exit(1) if args.region_labels != None and (len(args.regions) != len( args.region_labels)): logger.error( "ERROR: --regions and --region-labels have different lengths ({0} vs. {1})" .format(len(args.regions), len(args.region_labels))) sys.exit(1) if args.signal_labels != None and (len(args.signals) != len( args.signal_labels)): logger.error( "ERROR: --signals and --signal-labels have different lengths ({0} vs. {1})" .format(len(args.signals), len(args.signal_labels))) sys.exit(1) #### Format input #### args.TFBS_labels = [ os.path.splitext(os.path.basename(f))[0] for f in args.TFBS ] if args.TFBS_labels == None else args.TFBS_labels args.region_labels = [ os.path.splitext(os.path.basename(f))[0] for f in args.regions ] if args.region_labels == None else args.region_labels args.signal_labels = [ os.path.splitext(os.path.basename(f))[0] for f in args.signals ] if args.signal_labels == None else args.signal_labels #TFBS labels cannot be the same if len(set(args.TFBS_labels)) < len( args.TFBS_labels): #this indicates duplicates logger.error( "ERROR: --TFBS-labels are not allowed to contain duplicates. Note that '--TFBS-labels' are created automatically from the '--TFBS'-files if no input was given." 
+ "Please check that neither contain duplicate names") sys.exit(1) ######################################################################################### ############################ Get input regions/signals ready ############################ ######################################################################################### logger.info("---- Processing input ----") logger.info("Reading information from .bed-files") #Make combinations of TFBS / regions region_names = [] if len(args.regions) > 0: logger.info("Overlapping sites to --regions") regions_dict = {} combis = itertools.product(range(len(args.TFBS)), range(len(args.regions))) for (i, j) in combis: TFBS_f = args.TFBS[i] region_f = args.regions[j] #Make overlap pb_tfbs = pb.BedTool(TFBS_f) pb_region = pb.BedTool(region_f) #todo: write out lengths overlap = pb_tfbs.intersect(pb_region, u=True) #todo: length after overlap name = args.TFBS_labels[i] + " <OVERLAPPING> " + args.region_labels[ j] #name for column region_names.append(name) regions_dict[name] = RegionList().from_bed(overlap.fn) if args.negate == True: overlap_neg = pb_tfbs.intersect(pb_region, v=True) name = args.TFBS_labels[ i] + " <NOT OVERLAPPING> " + args.region_labels[j] region_names.append(name) regions_dict[name] = RegionList().from_bed(overlap_neg.fn) else: region_names = args.TFBS_labels regions_dict = { args.TFBS_labels[i]: RegionList().from_bed(args.TFBS[i]) for i in range(len(args.TFBS)) } for name in region_names: logger.stats("COUNT {0}: {1} sites".format( name, len(regions_dict[name]))) #length of RegionList obj #-------- Do overlap of regions if whitelist / blacklist -------# if len(args.whitelist) > 0 or len(args.blacklist) > 0: logger.info("Subsetting regions on whitelist/blacklist") for regions_id in regions_dict: sites = pb.BedTool(regions_dict[regions_id].as_bed(), from_string=True) logger.stats("Found {0} sites in {1}".format( len(regions_dict[regions_id]), regions_id)) if len(args.whitelist) > 0: for whitelist_f in args.whitelist: whitelist = pb.BedTool(whitelist_f) sites_tmp = sites.intersect(whitelist, u=True) sites = sites_tmp logger.stats("Overlapped to whitelist -> {0}".format( len(sites))) if len(args.blacklist) > 0: for blacklist_f in args.blacklist: blacklist = pb.BedTool(blacklist_f) sites_tmp = sites.intersect(blacklist, v=True) sites = sites_tmp logger.stats("Removed blacklist -> {0}".format( format(len(sites)))) regions_dict[regions_id] = RegionList().from_bed(sites.fn) # Estimate motif width per --TFBS motif_widths = {} for regions_id in regions_dict: site_list = regions_dict[regions_id] if len(site_list) > 0: motif_widths[regions_id] = site_list[0].get_width() else: motif_widths[regions_id] = 0 ######################################################################################### ############################ Read signal for bigwig per site ############################ ######################################################################################### logger.info("Reading signal from bigwigs") args.width = args.flank * 2 #output regions will be of args.width signal_dict = {} for i, signal_f in enumerate(args.signals): signal_name = args.signal_labels[i] signal_dict[signal_name] = {} #Open pybw to read signal pybw = pyBigWig.open(signal_f) boundaries = pybw.chroms() #dictionary of {chrom: length} logger.info("- Reading signal from {0}".format(signal_name)) for regions_id in regions_dict: original = copy.deepcopy(regions_dict[regions_id]) # Set width (centered on mid) 
regions_dict[regions_id].apply_method(OneRegion.set_width, args.width) #Check that regions are within boundaries and remove if not invalid = [ i for i, region in enumerate(regions_dict[regions_id]) if region.check_boundary(boundaries, action="remove") == None ] for invalid_idx in invalid[::-1]: #idx from higher to lower logger.warning( "Region '{reg}' ('{orig}' before flank extension) from bed regions '{id}' is out of chromosome boundaries. This region will be excluded from output." .format(reg=regions_dict[regions_id][invalid_idx].pretty(), orig=original[invalid_idx].pretty(), id=regions_id)) del regions_dict[regions_id][invalid_idx] #Get signal from remaining regions for one_region in regions_dict[regions_id]: tup = one_region.tup() #(chr, start, end, strand) if tup not in signal_dict[ signal_name]: #only get signal if it was not already read previously signal_dict[signal_name][tup] = one_region.get_signal( pybw, logger=logger, key=signal_name) #returns signal pybw.close() ######################################################################################### ################################## Calculate aggregates ################################# ######################################################################################### signal_names = args.signal_labels #Calculate aggregate per signal/region comparison logger.info("Calculating aggregate signals") aggregate_dict = { signal_name: {region_name: [] for region_name in regions_dict} for signal_name in signal_names } for row, signal_name in enumerate(signal_names): for col, region_name in enumerate(region_names): signalmat = np.array([ signal_dict[signal_name][reg.tup()] for reg in regions_dict[region_name] ]) #Check shape of signalmat if signalmat.shape[0] == 0: #no regions logger.warning( "No regions left for '{0}'. The aggregate for this signal will be set to 0." 
.format(signal_name)) aggregate = np.zeros(args.width) else: #Exclude outlier rows max_values = np.max(signalmat, axis=1) upper_limit = np.percentile( max_values, [100 * args.remove_outliers ])[0] #remove-outliers is a fraction logical = max_values <= upper_limit logger.debug( "{0}:{1}\tUpper limit: {2} (regions removed: {3})".format( signal_name, region_name, upper_limit, len(signalmat) - sum(logical))) signalmat = signalmat[logical] #Log-transform values before aggregating if args.log_transform: signalmat_abs = np.abs(signalmat) signalmat_log = np.log2(signalmat_abs + 1) signalmat_log[ signalmat < 0] *= -1 #original negatives back to <0 signalmat = signalmat_log aggregate = np.nanmean(signalmat, axis=0) #normalize between 0-1 if args.normalize: aggregate = preprocessing.minmax_scale(aggregate) if args.smooth > 1: aggregate_extend = np.pad(aggregate, args.smooth, "edge") aggregate_smooth = fast_rolling_math( aggregate_extend.astype('float64'), args.smooth, "mean") aggregate = aggregate_smooth[args.smooth:-args.smooth] aggregate_dict[signal_name][region_name] = aggregate signalmat = None #free up space signal_dict = None #free up space ######################################################################################### ############################## Write aggregates to file ################################# ######################################################################################### if args.output_txt is not None: #Open file for writing f_out = open(args.output_txt, "w") f_out.write("### AGGREGATE\n") f_out.write("# Signal\tRegions\tAggregate\n") for row, signal_name in enumerate(signal_names): for col, region_name in enumerate(region_names): agg = aggregate_dict[signal_name][region_name] agg_txt = ",".join(["{:.4f}".format(val) for val in agg]) f_out.write("{0}\t{1}\t{2}\n".format(signal_name, region_name, agg_txt)) f_out.close() ######################################################################################### ################################## Footprint measures ################################### ######################################################################################### logger.comment("") logger.info("---- Analysis ----") #Measure of footprint depth in comparison to baseline logger.info("Calculating footprint depth measure") logger.info("FPD (signal,regions): footprint_width baseline middle FPD") for row, signal_name in enumerate(signal_names): for col, region_name in enumerate(region_names): agg = aggregate_dict[signal_name][region_name] motif_width = motif_widths[region_name] #Estimation of possible footprint width FPD_results = [] for fp_flank in range(int(motif_width / 2), min([25, args.flank ])): #motif width for this bed #Baseline level baseline_indices = list(range( 0, args.flank - fp_flank)) + list( range(args.flank + fp_flank, len(agg))) baseline = np.mean(agg[baseline_indices]) #Footprint level middle_indices = list( range(args.flank - fp_flank, args.flank + fp_flank)) middle = np.mean(agg[middle_indices]) #within the motif #Footprint depth depth = middle - baseline FPD_results.append([fp_flank * 2, baseline, middle, depth]) #Estimation of possible footprint width all_fpds = [result[-1] for result in FPD_results] FPD_results_best = FPD_results #[result + [" "] if result[-1] != min(all_fpds) else result + ["*"] for result in FPD_results] for result in FPD_results_best: logger.stats( "FPD ({0},{1}): {2} {3:.5f} {4:.5f} {5:.5f}".format( signal_name, region_name, result[0], result[1], result[2], result[3])) #Compare pairwise to calculate 
correlation of signals logger.comment("") logger.info("Calculating measures for comparing pairwise aggregates") logger.info( "CORRELATION (signal1,region1) VS (signal2,region2): PEARSONR\tSUM_DIFF" ) plots = itertools.product(signal_names, region_names) combis = itertools.combinations(plots, 2) for ax1, ax2 in combis: signal1, region1 = ax1 signal2, region2 = ax2 agg1 = aggregate_dict[signal1][region1] agg2 = aggregate_dict[signal2][region2] pearsonr, pval = scipy.stats.pearsonr(agg1, agg2) diff = np.sum(np.abs(agg1 - agg2)) #Sum of difference between agg1 and agg2 logger.stats( "CORRELATION ({0},{1}) VS ({2},{3}): {4:.5f}\t{5:.5f}".format( signal1, region1, signal2, region2, pearsonr, diff)) ######################################################################################### ################################ Set up plotting grid ################################### ######################################################################################### logger.comment("") logger.info("---- Plotting aggregates ----") logger.info("Setting up plotting grid") n_signals = len(signal_names) n_regions = len(region_names) #regions are set of sites signal_compare = True if n_signals > 1 else False region_compare = True if n_regions > 1 else False #Define whether signal is on x/y if args.signal_on_x: #x-axis n_cols = n_signals col_compare = signal_compare col_names = signal_names #y-axis n_rows = n_regions row_compare = region_compare row_names = region_names else: #x-axis n_cols = n_regions col_compare = region_compare col_names = region_names #y-axis n_rows = n_signals row_compare = signal_compare row_names = signal_names #Compare across rows/cols? if row_compare: n_rows += 1 row_names += ["Comparison"] if col_compare: n_cols += 1 col_names += ["Comparison"] #Set grid fig, axarr = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5), constrained_layout=True) axarr = np.array(axarr).reshape( (-1, 1)) if n_cols == 1 else axarr #Fix indexing for one column figures axarr = np.array(axarr).reshape( (1, -1)) if n_rows == 1 else axarr #Fix indexing for one row figures #X axis / Y axis labels #mainax = fig.add_subplot(111, frameon=False) #mainax.set_xlabel("X label", labelpad=30, fontsize=16) #mainax.set_ylabel("Y label", labelpad=30, fontsize=16) #mainax.xaxis.set_label_position('top') #Title of plot and grid plt.suptitle( " " * 7 + args.title, fontsize=16 ) #Add a little whitespace to center the title on the plot; not the frame #Titles per column for col in range(n_cols): title = col_names[col].replace(" ", "\n") l = max([len(line) for line in title.split("\n") ]) #length of longest line in title s = fontsize_func(l) #decide fontsize based on length axarr[0, col].set_title(title, fontsize=s) #Titles (ylabels) per row for row in range(n_rows): label = row_names[row] l = max([len(line) for line in label.split("\n")]) axarr[row, 0].set_ylabel(label, fontsize=fontsize_func(l)) #Colors colors = mpl.cm.brg( np.linspace(0, 1, len(signal_names) + len(region_names))) #xvals flank = int(args.width / 2.0) xvals = np.arange(-flank, flank + 1) xvals = np.delete(xvals, flank) #Settings for each subplot for row in range(n_rows): for col in range(n_cols): axarr[row, col].set_xlim(-flank, flank) axarr[row, col].set_xlabel('bp from center') #axarr[row, col].set_ylabel('Mean aggregated signal') minor_ticks = np.arange(-flank, flank, args.width / 10.0) #Settings for comparison plots a = [ axarr[-1, col].set_facecolor("0.9") if row_compare == True else 0 for col in range(n_cols) ] a = [ axarr[row, 
-1].set_facecolor("0.9") if col_compare == True else 0 for row in range(n_rows) ] ######################################################################################### ####################### Fill in grid with aggregate bigwig scores ####################### ######################################################################################### for si in range(n_signals): signal_name = signal_names[si] for ri in range(n_regions): region_name = region_names[ri] logger.info("Plotting regions {0} from signal {1}".format( region_name, signal_name)) row, col = (ri, si) if args.signal_on_x else (si, ri) #If there are any regions: if len(regions_dict[region_name]) > 0: #Signal in region aggregate = aggregate_dict[signal_name][region_name] axarr[row, col].plot(xvals, aggregate, color=colors[col + row], linewidth=1, label=signal_name) #Compare across rows and cols if col_compare: #compare between different columns by adding one more column axarr[row, -1].plot(xvals, aggregate, color=colors[row + col], linewidth=1, alpha=0.8, label=col_names[col]) s = min([ ax.title.get_fontproperties()._size for ax in axarr[0, :] ]) #smallest fontsize of all columns axarr[row, -1].legend(loc="lower right", fontsize=s) if row_compare: #compare between different rows by adding one more row axarr[-1, col].plot(xvals, aggregate, color=colors[row + col], linewidth=1, alpha=0.8, label=row_names[row]) s = min([ ax.yaxis.label.get_fontproperties()._size for ax in axarr[:, 0] ]) #smallest fontsize of all rows axarr[-1, col].legend(loc="lower right", fontsize=s) #Diagonal comparison if n_rows == n_cols and col_compare and row_compare and col == row: axarr[-1, -1].plot(xvals, aggregate, color=colors[row + col], linewidth=1, alpha=0.8) #Add number of sites to plot axarr[row, col].text(0.98, 0.98, str(len(regions_dict[region_name])), transform=axarr[row, col].transAxes, fontsize=12, va="top", ha="right") #Motif boundaries (can only be compared across sites) if args.plot_boundaries: #Get motif width for this list of TFBS width = motif_widths[region_names[min( row, n_rows - 2)]] if args.signal_on_x else motif_widths[ region_names[min(col, n_cols - 2)]] mstart = -np.floor(width / 2.0) mend = np.ceil( width / 2.0) - 1 #as it spans the "0" nucleotide axarr[row, col].axvline(mstart, color="grey", linestyle="dashed", linewidth=1) axarr[row, col].axvline(mend, color="grey", linestyle="dashed", linewidth=1) #------------- Finishing up plots ---------------# logger.info("Adjusting final details") #remove lower-right corner if not applicable if n_rows != n_cols and n_rows > 1 and n_cols > 1: axarr[-1, -1].axis('off') axarr[-1, -1] = None #Check whether share_y is set if args.share_y == "none": pass #Comparable rows (rowcompare = same across all) elif (args.share_y == "signals" and args.signal_on_x == False) or (args.share_y == "sites" and args.signal_on_x == True): for col in range(n_cols): lims = np.array( [ax.get_ylim() for ax in axarr[:, col] if ax is not None]) ymin, ymax = np.min(lims), np.max(lims) #Set limit across rows for this col for row in range(n_rows): if axarr[row, col] is not None: axarr[row, col].set_ylim(ymin, ymax) #Comparable columns (colcompare = same across all) elif (args.share_y == "sites" and args.signal_on_x == False) or (args.share_y == "signals" and args.signal_on_x == True): for row in range(n_rows): lims = np.array( [ax.get_ylim() for ax in axarr[row, :] if ax is not None]) ymin, ymax = np.min(lims), np.max(lims) #Set limit across cols for this row for col in range(n_cols): if axarr[row, col] is not None: 
axarr[row, col].set_ylim(ymin, ymax) #Comparable on both rows/columns elif args.share_y == "both": global_ymin, global_ymax = np.inf, -np.inf for row in range(n_rows): for col in range(n_cols): if axarr[row, col] is not None: local_ymin, local_ymax = axarr[row, col].get_ylim() global_ymin = local_ymin if local_ymin < global_ymin else global_ymin global_ymax = local_ymax if local_ymax > global_ymax else global_ymax for row in range(n_rows): for col in range(n_cols): if axarr[row, col] is not None: axarr[row, col].set_ylim(global_ymin, global_ymax) #Force plots to be square for row in range(n_rows): for col in range(n_cols): forceSquare(axarr[row, col]) plt.savefig(args.output, bbox_inches='tight') plt.close() logger.end()
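#---------------------------------------------------------------------------------------#
# Illustrative sketch (not part of the tool): the core of the aggregate calculation in
# run_aggregate above, reduced to plain numpy. The helper name 'aggregate_signal' and its
# defaults are hypothetical; run_aggregate itself uses sklearn's preprocessing.minmax_scale
# and the compiled fast_rolling_math helper, which are approximated here with numpy
# equivalents (min/max scaling and np.convolve).
#---------------------------------------------------------------------------------------#
import numpy as np

def aggregate_signal(signalmat, remove_outliers=1.0, log_transform=False, normalize=False, smooth=1):
	""" Aggregate a (n_sites x n_positions) signal matrix into one per-position profile """

	#Drop outlier rows: sites whose maximum exceeds the given percentile of all row maxima
	max_values = np.max(signalmat, axis=1)
	upper_limit = np.percentile(max_values, 100 * remove_outliers)	#remove_outliers is a fraction
	signalmat = signalmat[max_values <= upper_limit]

	#Signed log2-transform: log2(|x|+1), keeping the sign of negative values
	if log_transform:
		logmat = np.log2(np.abs(signalmat) + 1)
		logmat[signalmat < 0] *= -1
		signalmat = logmat

	aggregate = np.nanmean(signalmat, axis=0)	#mean per position across sites

	#Scale profile to the range 0-1
	if normalize:
		span = np.max(aggregate) - np.min(aggregate)
		aggregate = (aggregate - np.min(aggregate)) / span if span > 0 else np.zeros_like(aggregate)

	#Rolling-mean smoothing with edge-padding, mirroring the padding used in run_aggregate
	if smooth > 1:
		padded = np.pad(aggregate, smooth, "edge")
		kernel = np.ones(smooth) / smooth
		aggregate = np.convolve(padded, kernel, mode="same")[smooth:-smooth]

	return aggregate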
def run_scorebigwig(args):

	check_required(args, ["signal", "output", "regions"])
	check_files([args.signal, args.regions], "r")
	check_files([args.output], "w")

	#---------------------------------------------------------------------------------------#
	# Create logger and write info to log
	#---------------------------------------------------------------------------------------#

	logger = TobiasLogger("ScoreBigwig", args.verbosity)
	logger.begin()

	parser = add_scorebigwig_arguments(argparse.ArgumentParser())
	logger.arguments_overview(parser, args)
	logger.output_files([args.output])

	logger.debug("Setting up listener for log")
	logger.start_logger_queue()
	args.log_q = logger.queue

	#---------------------------------------------------------------------------------------#
	#----------------------- I/O - get regions/bigwig ready --------------------------------#
	#---------------------------------------------------------------------------------------#

	logger.info("Processing input files")

	logger.info("- Opening input cutsite bigwig")
	pybw_signal = pyBigWig.open(args.signal)
	pybw_header = pybw_signal.chroms()
	chrom_info = {chrom: int(pybw_header[chrom]) for chrom in pybw_header}
	logger.debug("Chromosome lengths from input bigwig: {0}".format(chrom_info))

	#Decide regions
	logger.info("- Getting output regions ready")
	if args.regions:
		regions = RegionList().from_bed(args.regions)

		#Check whether regions are available in input bigwig
		not_in_bigwig = list(set(regions.get_chroms()) - set(chrom_info.keys()))
		if len(not_in_bigwig) > 0:
			logger.warning("Contigs {0} were found in input --regions, but were not found in input --signal. These regions cannot be scored and will therefore be excluded from output.".format(not_in_bigwig))
			regions = regions.remove_chroms(not_in_bigwig)

		regions.apply_method(OneRegion.extend_reg, args.extend)
		regions.merge()
		regions.apply_method(OneRegion.check_boundary, chrom_info, "cut")
	else:
		regions = RegionList().from_chrom_lengths(chrom_info)

	#Set flank to enable scoring in ends of regions
	if args.score == "sum":
		args.region_flank = int(args.window / 2.0)
	elif args.score == "footprint" or args.score == "FOS":
		args.region_flank = int(args.flank_max)
	else:
		args.region_flank = 0

	#Go through each region
	for i, region in enumerate(regions):
		region.extend_reg(args.region_flank)
		region = region.check_boundary(chrom_info, "cut")
		region.extend_reg(-args.region_flank)

	#Information for output bigwig
	reference_chroms = sorted(list(chrom_info.keys()))
	header = [(chrom, chrom_info[chrom]) for chrom in reference_chroms]
	regions.loc_sort(reference_chroms)

	#---------------------------------------------------------------------------------------#
	#------------------------ Calculating footprints and writing out -----------------------#
	#---------------------------------------------------------------------------------------#

	logger.info("Calculating footprints in regions...")
	regions_chunks = regions.chunks(args.split)

	#Setup pools
	args.cores = check_cores(args.cores, logger)
	writer_cores = 1
	worker_cores = max(1, args.cores - writer_cores)
	logger.debug("Worker cores: {0}".format(worker_cores))
	logger.debug("Writer cores: {0}".format(writer_cores))

	worker_pool = mp.Pool(processes=worker_cores)
	writer_pool = mp.Pool(processes=writer_cores)
	manager = mp.Manager()

	#Start bigwig file writers
	q = manager.Queue()
	writer_pool.apply_async(bigwig_writer, args=(q, {"scores": args.output}, header, regions, args))
	writer_pool.close()	#no more jobs applied to writer_pool
	writer_qs = {"scores": q}

	args.writer_qs = writer_qs

	#Start calculating scores
	pool = mp.Pool(processes=args.cores)
	task_list = [pool.apply_async(calculate_scores, args=[chunk, args]) for chunk in regions_chunks]
	no_tasks = len(task_list)
	pool.close()
	monitor_progress(task_list, logger)
	results = [task.get() for task in task_list]

	#Stop all queues for writing
	logger.debug("Stop all queues by inserting None")
	for q in writer_qs.values():
		q.put((None, None, None))

	#Done computing
	writer_pool.join()
	worker_pool.terminate()
	worker_pool.join()

	logger.stop_logger_queue()

	#Finished scoring
	logger.end()
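#---------------------------------------------------------------------------------------#
# Illustrative sketch (not part of the tool): the worker/writer pattern used in
# run_scorebigwig, in miniature. Per-chunk results from a worker pool are pushed to a
# single writer process through a managed queue, and the writer is stopped with a
# sentinel value. The toy functions 'score_chunk' and 'writer' and the output filename
# are hypothetical stand-ins for calculate_scores and bigwig_writer.
#---------------------------------------------------------------------------------------#
import multiprocessing as mp

def score_chunk(chunk, q):
	""" Compute a score per element and push the result to the writer queue """
	q.put([x * 2 for x in chunk])
	return len(chunk)

def writer(q, outfile):
	""" Single consumer: write results as they arrive until the sentinel (None) is seen """
	with open(outfile, "w") as f:
		for result in iter(q.get, None):	#loop until sentinel
			f.write("\t".join(str(x) for x in result) + "\n")

if __name__ == "__main__":
	manager = mp.Manager()
	q = manager.Queue()	 #manager queues (unlike plain mp.Queue) can be passed to pool workers

	writer_pool = mp.Pool(processes=1)
	writer_pool.apply_async(writer, args=(q, "scores.txt"))
	writer_pool.close()	 #no more jobs applied to writer_pool

	worker_pool = mp.Pool(processes=2)
	chunks = [[1, 2, 3], [4, 5, 6]]
	tasks = [worker_pool.apply_async(score_chunk, args=(chunk, q)) for chunk in chunks]
	worker_pool.close()
	results = [task.get() for task in tasks]	#wait for workers before stopping the writer

	q.put(None)	 #sentinel stops the writer loop
	writer_pool.join()
	worker_pool.join()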
def run_motifclust(args): ###### Check input arguments ###### check_required(args, ["motifs"]) #Check input arguments check_files([args.motifs]) #Check if files exist out_cons_img = os.path.join(args.outdir, "consensus_motifs_img") make_directory(out_cons_img) out_prefix = os.path.join(args.outdir, args.prefix) ###### Create logger and write argument overview ###### logger = TobiasLogger("ClusterMotifs", args.verbosity) logger.begin() parser = add_motifclust_arguments(argparse.ArgumentParser()) logger.arguments_overview(parser, args) #logger.output_files([]) out_prefix = os.path.join(args.outdir, args.prefix) #----------------------------------------- Check for gimmemotifs ----------------------------------------# try: from gimmemotifs.motif import Motif from gimmemotifs.comparison import MotifComparer sns.set_style( "ticks" ) #set style back to ticks, as this is set globally during gimmemotifs import except: logger.error( "MotifClust requires the python package 'gimmemotifs'. You can install it using 'pip install gimmemotifs' or 'conda install gimmemotifs'." ) sys.exit() #---------------------------------------- Reading motifs from file(s) -----------------------------------# logger.info("Reading input file(s)") motif_list = MotifList() #list containing OneMotif objects motif_dict = {} #dictionary containing separate motif lists per file if sys.version_info < ( 3, 7, 0): # workaround for deepcopy with python version < 3.5 copy._deepcopy_dispatch[type(re.compile(''))] = lambda r, _: r for f in args.motifs: logger.debug("Reading {0}".format(f)) motif_format = get_motif_format(open(f).read()) sub_motif_list = MotifList().from_file(f) #MotifList object logger.stats("- Read {0} motifs from {1} (format: {2})".format( len(sub_motif_list), f, motif_format)) motif_list.extend(sub_motif_list) motif_dict[f] = sub_motif_list #Check whether ids are unique #TODO #---------------------------------------- Motif stats ---------------------------------------------------# logger.info("Creating matrix statistics") gimmemotifs_list = [ motif.get_gimmemotif().gimme_obj for motif in motif_list ] #Stats for all motifs full_motifs_out = out_prefix + "_stats_motifs.txt" motifs_stats = get_motif_stats(gimmemotifs_list) write_motif_stats(motifs_stats, full_motifs_out) #---------------------------------------- Motif clustering ----------------------------------------------# logger.info("Clustering motifs") clusters = motif_list.cluster(threshold=args.threshold, metric=args.dist_method, clust_method=args.clust_method) logger.stats("- Identified {0} clusters".format(len(clusters))) #Write out overview of clusters cluster_dict = { cluster_id: [ motif.get_gimmemotif().gimme_obj.id for motif in clusters[cluster_id] ] for cluster_id in clusters } cluster_f = out_prefix + "_" + "clusters.yml" logger.info("- Writing clustering to {0}".format(cluster_f)) write_yaml(cluster_dict, cluster_f) # Save similarity matrix to file matrix_out = out_prefix + "_matrix.txt" logger.info("- Saving similarity matrix to the file: " + str(matrix_out)) motif_list.similarity_matrix.to_csv(matrix_out, sep='\t') #Plot dendrogram logger.info("Plotting clustering dendrogram") dendrogram_f = out_prefix + "_" + "dendrogram." 
+ args.type #plot format pdf/png plot_dendrogram(motif_list.similarity_matrix.columns, motif_list.linkage_mat, 12, dendrogram_f, "Clustering", args.threshold, args.dpi) #---------------------------------------- Consensus motif -----------------------------------------------# logger.info("Building consensus motif for each cluster") consensus_motifs = MotifList() for cluster_id in clusters: consensus = clusters[cluster_id].create_consensus( ) #MotifList object with create_consensus method consensus.id = cluster_id if len( clusters[cluster_id]) > 1 else clusters[cluster_id][ 0].id #set original motif id if cluster length = 1 consensus_motifs.append(consensus) #Write out consensus motif file out_f = out_prefix + "_consensus_motifs." + args.cons_format logger.info("- Writing consensus motifs to: {0}".format(out_f)) consensus_motifs.to_file(out_f, args.cons_format) #Create logo plots out_cons_img = os.path.join(args.outdir, "consensus_motifs_img") logger.info( "- Making logo plots for consensus motifs (output folder: {0})".format( out_cons_img)) for motif in consensus_motifs: filename = os.path.join(out_cons_img, motif.id + "_consensus." + args.type) motif.logo_to_file(filename) #---------------------------------------- Plot heatmap --------------------------------------------------# logger.info("Plotting similarity heatmap") logger.info( "Note: Can take a while for --type=pdf. Try \"--type png\" for speed up." ) args.nrc = False args.ncc = False args.zscore = "None" clust_linkage = motif_list.linkage_mat similarity_matrix = motif_list.similarity_matrix pdf_out = out_prefix + "_heatmap_all." + args.type x_label = "All motifs" y_label = "All motifs" plot_heatmap(similarity_matrix, pdf_out, clust_linkage, clust_linkage, args.dpi, x_label, y_label, args.color, args.ncc, args.nrc, args.zscore) # Plot heatmaps for each combination of motif files comparisons = itertools.combinations(args.motifs, 2) for i, (motif_file_1, motif_file_2) in enumerate(comparisons): pdf_out = out_prefix + "_heatmap" + str(i) + "." + args.type logger.info("Plotting comparison of {0} and {1} motifs to the file ". format(motif_file_1, motif_file_2) + str(pdf_out)) x_label, y_label = motif_file_1, motif_file_2 #Create subset of matrices for row/col clustering motif_names_1 = [ motif.get_gimmemotif().gimme_obj.id for motif in motif_dict[motif_file_1] ] motif_names_2 = [ motif.get_gimmemotif().gimme_obj.id for motif in motif_dict[motif_file_2] ] m1_matrix, m2_matrix, similarity_matrix_sub = subset_matrix( similarity_matrix, motif_names_1, motif_names_2) col_linkage = linkage(ssd.squareform(m1_matrix)) if ( len(motif_names_1) > 1 and len(motif_names_2) > 1) else None row_linkage = linkage(ssd.squareform(m2_matrix)) if ( len(motif_names_1) > 1 and len(motif_names_2) > 1) else None #Plot similarity heatmap between file1 and file2 plot_heatmap(similarity_matrix_sub, pdf_out, col_linkage, row_linkage, args.dpi, x_label, y_label, args.color, args.ncc, args.nrc, args.zscore) # ClusterMotifs finished logger.end()
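#---------------------------------------------------------------------------------------#
# Illustrative sketch (not part of the tool): hierarchical clustering of motifs from a
# similarity matrix, the general approach behind MotifList.cluster used in run_motifclust
# (its exact metric, linkage method and defaults are not reproduced here). The helper
# name 'cluster_from_similarity' and the example ids are hypothetical; only the scipy
# calls mirror the linkage/threshold logic used for the dendrogram above.
#---------------------------------------------------------------------------------------#
import numpy as np
import scipy.spatial.distance as ssd
from scipy.cluster.hierarchy import linkage, fcluster

def cluster_from_similarity(similarity, ids, threshold=0.3, method="average"):
	""" Group motif ids by cutting a hierarchical clustering of (1 - similarity) at 'threshold' """
	distance = 1 - np.array(similarity, dtype=float)	#similarity in [0,1] -> distance
	np.fill_diagonal(distance, 0)
	condensed = ssd.squareform(distance, checks=False)	 #condensed form required by linkage
	linkage_mat = linkage(condensed, method=method)
	labels = fcluster(linkage_mat, t=threshold, criterion="distance")

	clusters = {}
	for motif_id, label in zip(ids, labels):
		clusters.setdefault("cluster_{0}".format(label), []).append(motif_id)
	return clusters

#Example: three motifs where the first two are nearly identical -> two clusters
#sim = [[1.0, 0.9, 0.1], [0.9, 1.0, 0.2], [0.1, 0.2, 1.0]]
#cluster_from_similarity(sim, ["motif_A", "motif_B", "motif_C"])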
def run_formatmotifs(args):

	check_required(args, ["input", "output"])	#Check input arguments
	motif_files = expand_dirs(args.input)	#Expand any dirs in input
	check_files(motif_files + [args.filter])	#Check if files exist

	# Create logger and write argument overview
	logger = TobiasLogger("FormatMotifs", args.verbosity)
	logger.begin()

	parser = add_formatmotifs_arguments(argparse.ArgumentParser())
	logger.arguments_overview(parser, args)
	logger.output_files([args.output])

	####### Getting ready #######
	if args.task == "split":
		logger.info("Making directory {0} if not existing".format(args.output))
		make_directory(args.output)	#Check and create output directory

	### Read motifs from files ###
	logger.info("Reading input files...")
	motif_list = MotifList()
	for f in motif_files:
		logger.debug("- {0}".format(f))
		motif_list.extend(MotifList().from_file(f))
	logger.info("Read {0} motifs\n".format(len(motif_list)))

	#Sort out duplicate motifs
	all_motif_ids = [motif.id for motif in motif_list]
	unique_motif_ids = set(all_motif_ids)
	if len(all_motif_ids) != len(unique_motif_ids):
		logger.info("Found duplicate motif ids in input - keeping the first motif for each unique id.")
		motif_list = MotifList([motif_list[all_motif_ids.index(motifid)] for motifid in unique_motif_ids])
		logger.info("Reduced to {0} unique motif ids".format(len(motif_list)))

	### Filter motif list ###
	if args.filter != None:

		#Read filter
		logger.info("Reading entries in {0}".format(args.filter))
		entries = open(args.filter, "r").read().split()
		logger.info("Read {0} unique filter values".format(len(set(entries))))

		#Match motifs to filter; a motif is kept if any filter value is a substring of its name or id
		logger.info("Matching motifs to filter")
		used_filters = []
		filtered_list = MotifList()
		for input_motif in motif_list:
			found_in_filter = 0
			i = -1
			while found_in_filter == 0 and i < len(entries) - 1:
				i += 1
				if entries[i].lower() in input_motif.name.lower() or entries[i].lower() in input_motif.id.lower():
					filtered_list.append(input_motif)
					logger.debug("Selected motif {0} ({1}) due to filter value {2}".format(input_motif.name, input_motif.id, entries[i]))
					found_in_filter = 1
					used_filters.append(entries[i])

			if found_in_filter == 0:
				logger.debug("Could not find any match to motif {0} ({1}) in filter".format(input_motif.name, input_motif.id))

		logger.info("Filtered number of motifs from {0} to {1}".format(len(motif_list), len(filtered_list)))
		motif_list = filtered_list

		logger.debug("Filters not used: {0}".format(list(set(entries) - set(used_filters))))

	#### Write out results ####
	if args.task == "split":
		logger.info("Writing individual files to directory {0}".format(args.output))

		for motif in motif_list:
			motif_string = MotifList([motif]).as_string(args.format)

			#Open file and write
			out_path = os.path.join(args.output, motif.id + "." + args.format)
			logger.info("- {0}".format(out_path))
			f_out = open(out_path, "w")
			f_out.write(motif_string)
			f_out.close()

	elif args.task == "join":
		logger.info("Writing converted motifs to file {0}".format(args.output))

		f_out = open(args.output, "w")
		motif_string = motif_list.as_string(args.format)
		f_out.write(motif_string)
		f_out.close()

	logger.end()
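#---------------------------------------------------------------------------------------#
# Illustrative sketch (not part of the tool): the filter matching used in run_formatmotifs,
# reduced to plain Python. A motif is kept if any filter entry is a case-insensitive
# substring of either its name or its id. The helper name 'filter_motifs' and the example
# tuples are hypothetical; the tool applies the same logic to OneMotif objects instead of
# (id, name) tuples.
#---------------------------------------------------------------------------------------#
def filter_motifs(motifs, entries):
	""" Keep (id, name) pairs matching any filter entry; also report entries that never matched """
	kept, used = [], set()
	for motif_id, name in motifs:
		for entry in entries:
			if entry.lower() in name.lower() or entry.lower() in motif_id.lower():
				kept.append((motif_id, name))
				used.add(entry)
				break	 #first matching entry is enough, as in the while-loop above
	return kept, sorted(set(entries) - used)

#Example:
#filter_motifs([("motif_1", "CTCF"), ("motif_2", "FOXA1")], ["ctcf"])
#-> ([("motif_1", "CTCF")], [])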