def match(args):
    """Print the closest database match for every motif in ``args.pwmfile``.

    Query motifs are read from ``args.pwmfile`` and database motifs from
    ``args.dbpwmfile``. The best match per query is found with
    ``MotifComparer.get_closest_match`` (partial alignment, WIC metric, mean
    combination). A tab-separated table (Motif, Match, Score, P-value) is
    printed to stdout. If ``args.img`` is set, every query/match pair is
    aligned and handed to ``match_plot`` together with ``args.img``.
    """
    sample = dict((m.id, m) for m in pwmfile_to_motifs(args.pwmfile))
    db = dict((m.id, m) for m in pwmfile_to_motifs(args.dbpwmfile))

    mc = MotifComparer()
    result = mc.get_closest_match(sample.values(), db.values(), "partial", "wic", "mean")

    # compare_motifs(..., pval=True) is the expensive step: run it once per
    # query and reuse its (pval, pos, orient) for the optional plot below.
    comparisons = {}
    print("Motif\tMatch\tScore\tP-value")
    for motif_id, match in result.items():
        pval, pos, orient = mc.compare_motifs(
            sample[motif_id], db[match[0]], "partial", "wic", "mean", pval=True
        )
        comparisons[motif_id] = (pval, pos, orient)
        print("%s\t%s\t%0.2f\t%0.3e" % (motif_id, match[0], match[1][0], pval))

    if args.img:
        plotdata = []
        for query, match in result.items():
            motif = sample[query]
            dbmotif = db[match[0]]
            pval, pos, orient = comparisons[query]
            if orient == -1:
                # Use the reverse complement of the database motif, keeping its id.
                tmp = dbmotif.id
                dbmotif = dbmotif.rc()
                dbmotif.id = tmp
            # A negative offset left-pads the query, a positive one the
            # database motif, with uniform (0.25) columns; ids are kept.
            if pos < 0:
                tmp = motif.id
                motif = Motif([[0.25, 0.25, 0.25, 0.25]] * -pos + motif.pwm)
                motif.id = tmp
            elif pos > 0:
                tmp = dbmotif.id
                dbmotif = Motif([[0.25, 0.25, 0.25, 0.25]] * pos + dbmotif.pwm)
                dbmotif.id = tmp

            plotdata.append((motif, dbmotif, pval))
        match_plot(plotdata, args.img)
def create_consensus(self):
    """Build one consensus motif out of the motifs in this MotifList.

    The gimmemotifs objects (``.gimme_obj``) of the members are merged
    pairwise. In each round, ``find_best_pair`` picks the pair to merge based
    on PCC similarity scores, and ``merge_motifs`` combines the pair. This
    repeats until a single motif is left. The result's PWM values are rounded
    to 5 decimals. The motif is then converted back with
    ``gimmemotif_to_onemotif`` and named after the first three member names,
    with "(...)" appended if there are more.
    """
    gimme_motifs = [member.gimme_obj for member in self]

    if len(gimme_motifs) > 1:
        comparer = MotifComparer()
        # Pairwise similarity table, indexed as scores[id_a][id_b].
        scores = comparer.get_all_scores(gimme_motifs, gimme_motifs,
                                         match="total", metric="pcc", combine="mean")

        while True:
            # Indices (ascending) of the most similar pair of motifs.
            pair = sorted(find_best_pair(gimme_motifs, scores))
            lower, upper = pair[0], pair[1]

            merged = merge_motifs(gimme_motifs[lower], gimme_motifs[upper])
            del gimme_motifs[upper]
            gimme_motifs[lower] = merged

            if len(gimme_motifs) == 1:
                break

            # Register the freshly merged motif against all remaining motifs.
            row = scores.setdefault(merged.id, {})
            for other in gimme_motifs:
                row[other.id] = comparer.compare_motifs(merged, other, metric="pcc")
                scores[other.id][merged.id] = comparer.compare_motifs(other, merged, metric="pcc")

    consensus = gimme_motifs[0]
    consensus.pwm = [[round(value, 5) for value in row] for row in consensus.pwm]

    one_motif = gimmemotif_to_onemotif(consensus)
    one_motif.gimme_obj = consensus

    names = [member.name for member in self]
    suffix = "(...)" if len(names) > 3 else ""
    one_motif.name = ",".join(names[:3]) + suffix
    return one_motif
def match(args):
    """Print the closest database match for every motif in ``args.pwmfile``.

    Query motifs are read from ``args.pwmfile`` and database motifs from
    ``args.dbpwmfile``. The best match per query is found with
    ``MotifComparer.get_closest_match`` (partial alignment, WIC metric, mean
    combination). A tab-separated table (Motif, Match, Score, P-value) is
    printed to stdout. If ``args.img`` is set, every query/match pair is
    aligned and handed to ``match_plot`` together with ``args.img``.
    """
    sample = dict((m.id, m) for m in pwmfile_to_motifs(args.pwmfile))
    db = dict((m.id, m) for m in pwmfile_to_motifs(args.dbpwmfile))

    mc = MotifComparer()
    result = mc.get_closest_match(sample.values(), db.values(), "partial", "wic", "mean")

    # compare_motifs(..., pval=True) is the expensive step: run it once per
    # query and reuse its (pval, pos, orient) for the optional plot below.
    comparisons = {}
    print("Motif\tMatch\tScore\tP-value")
    for motif_id, match in result.items():
        pval, pos, orient = mc.compare_motifs(
            sample[motif_id], db[match[0]], "partial", "wic", "mean", pval=True
        )
        comparisons[motif_id] = (pval, pos, orient)
        print("%s\t%s\t%0.2f\t%0.3e" % (motif_id, match[0], match[1][0], pval))

    if args.img:
        plotdata = []
        for query, match in result.items():
            motif = sample[query]
            dbmotif = db[match[0]]
            pval, pos, orient = comparisons[query]
            if orient == -1:
                # Use the reverse complement of the database motif, keeping its id.
                tmp = dbmotif.id
                dbmotif = dbmotif.rc()
                dbmotif.id = tmp
            # A negative offset left-pads the query, a positive one the
            # database motif, with uniform (0.25) columns; ids are kept.
            if pos < 0:
                tmp = motif.id
                motif = Motif([[0.25, 0.25, 0.25, 0.25]] * -pos + motif.pwm)
                motif.id = tmp
            elif pos > 0:
                tmp = dbmotif.id
                dbmotif = Motif([[0.25, 0.25, 0.25, 0.25]] * pos + dbmotif.pwm)
                dbmotif.id = tmp

            plotdata.append((motif, dbmotif, pval))
        match_plot(plotdata, args.img)
def determine_closest_match(self, motifs):
    """Find the closest database motif for each motif in ``motifs``.

    The database file is the ``motif_db`` default parameter, resolved inside
    the configured motif directory. Names ending in ``pwm``/``pfm`` are read
    in PWM format and names ending in ``transfac`` in TRANSFAC format. Any
    other name leaves the database empty.

    Returns a dict mapping each motif id to ``[closest_db_motif, pval]``,
    where ``pval`` comes from a partial/WIC/mean ``compare_motifs`` call.
    """
    self.logger.debug("Determining closest matching motifs in database")
    motif_db = self.config.get_default_params()["motif_db"]
    db = os.path.join(self.config.get_motif_dir(), motif_db)

    db_motifs = []
    if db.endswith(("pwm", "pfm")):
        # Close the handle as soon as the motifs have been parsed.
        with open(db) as fh:
            db_motifs = read_motifs(fh, fmt="pwm")
    elif db.endswith("transfac"):
        db_motifs = read_motifs(db, fmt="transfac")

    mc = MotifComparer()
    db_motif_lookup = dict((m.id, m) for m in db_motifs)
    match = mc.get_closest_match(motifs, db_motifs, "partial", "wic", "mean", parallel=False)

    closest_match = {}
    for motif in motifs:
        best = db_motif_lookup[match[motif.id][0]]
        # Calculate p-value; alignment position/orientation are not needed here.
        pval, _, _ = mc.compare_motifs(motif, best, "partial", "wic", "mean", pval=True)
        closest_match[motif.id] = [best, pval]

    return closest_match
def _create_images(outdir, clusters):
    """Write PNG logos for every cluster and its members into ``outdir``.

    Each cluster motif is trimmed (IC cutoff 0.2) and drawn as
    ``<cluster id>.png``. For clusters with more than one member, every
    member is aligned to the cluster motif. It is drawn as
    ``<member id>.png`` (spaces replaced by ``_``), reverse-complemented when
    it aligned on the other strand, and shifted by its offset relative to the
    smallest offset in the cluster.

    Returns a list with one ``[cluster_id, {"src": png}, members]`` entry per
    cluster, where ``members`` is a list of ``{"src": png, "alt": id}`` dicts.
    """
    report = []
    comparer = MotifComparer()
    ic_cutoff = 0.2
    sys.stderr.write("Creating images\n")

    for cluster, members in clusters:
        cluster.trim(ic_cutoff)
        cluster_png = "%s.png" % cluster.id
        cluster.to_img(os.path.join(outdir, cluster_png), fmt="PNG")
        entry = [cluster.id, {"src": cluster_png}, []]
        report.append(entry)

        if len(members) > 1:
            alignment = {
                member: comparer.compare_motifs(cluster, member, "total", "wic", "mean", pval=True)
                for member in members
            }
            leftmost = min(aligned[1] for aligned in alignment.values())

            for member in members:
                _, offset, strand = alignment[member]
                drawn = member
                if strand not in [1, "+"]:
                    drawn = member.rc()
                    drawn.id = member.id
                drawn.to_img(
                    os.path.join(outdir, "%s.png" % drawn.id.replace(" ", "_")),
                    fmt="PNG",
                    add_left=offset - leftmost,
                )

        entry[2] = [
            {"src": "%s.png" % m.id.replace(" ", "_"), "alt": m.id.replace(" ", "_")}
            for m in members
        ]

    return report
def match(args):
    """Print the best database matches for every motif in ``args.pfmfile``.

    For each query motif, the ``args.nmatches`` best matches among the motifs
    in ``args.dbpfmfile`` are found with ``MotifComparer.get_best_matches``
    (partial alignment, ``seqcor`` metric, mean combination). A tab-separated
    table (Motif, Match, Score, P-value) is printed to stdout. If
    ``args.img`` is set, every query/match pair is aligned, padded to equal
    length and plotted with ``match_plot``.
    """

    def _pad(motif, left=0, right=0):
        # Return a copy of ``motif`` with uniform (0.25) columns added on the
        # left/right. The original id is kept so the plot stays labelled.
        padded = Motif(
            [[0.25, 0.25, 0.25, 0.25] for _ in range(left)]
            + motif.pwm
            + [[0.25, 0.25, 0.25, 0.25] for _ in range(right)]
        )
        padded.id = motif.id
        return padded

    sample = dict((m.id, m) for m in read_motifs(args.pfmfile))
    db = dict((m.id, m) for m in read_motifs(args.dbpfmfile))

    mc = MotifComparer()
    result = mc.get_best_matches(
        sample.values(), args.nmatches, db.values(), "partial", "seqcor", "mean"
    )

    plotdata = []
    print("Motif\tMatch\tScore\tP-value")
    for motif_name, matches in result.items():
        for match in matches:
            pval, pos, orient = mc.compare_motifs(
                sample[motif_name], db[match[0]], "partial", "seqcor", "mean", pval=True
            )
            print("%s\t%s\t%0.2f\t%0.3e" % (motif_name, match[0], match[1][0], pval))

            if not args.img:
                continue

            motif = sample[motif_name]
            dbmotif = db[match[0]]
            if orient == -1:
                tmp = dbmotif.id
                dbmotif = dbmotif.rc()
                dbmotif.id = tmp
            # Align the starts: negative offset pads the query, positive the db motif.
            if pos < 0:
                motif = _pad(motif, left=-pos)
            elif pos > 0:
                dbmotif = _pad(dbmotif, left=pos)
            # Equalize lengths by right-padding whichever motif is shorter.
            # Previously the ids were lost here, and an equal-length query was
            # rebuilt (and lost its id) for no reason.
            diff = len(motif) - len(dbmotif)
            if diff > 0:
                dbmotif = _pad(dbmotif, right=diff)
            elif diff < 0:
                motif = _pad(motif, right=-diff)

            plotdata.append((motif, dbmotif, pval))

    if args.img:
        match_plot(plotdata, args.img)
def determine_closest_match(self, motifs):
    """Find the closest database motif for each motif in ``motifs``.

    The database file is the ``motif_db`` default parameter, resolved inside
    the configured motif directory. Names ending in ``pwm``/``pfm`` are read
    in PWM format and names ending in ``transfac`` in TRANSFAC format. Any
    other name leaves the database empty.

    Returns a dict mapping each motif id to ``[closest_db_motif, pval]``,
    where ``pval`` comes from a partial/WIC/mean ``compare_motifs`` call.
    """
    self.logger.debug("Determining closest matching motifs in database")
    motif_db = self.config.get_default_params()["motif_db"]
    db = os.path.join(self.config.get_motif_dir(), motif_db)

    db_motifs = []
    if db.endswith(("pwm", "pfm")):
        # Close the handle as soon as the motifs have been parsed.
        with open(db) as fh:
            db_motifs = read_motifs(fh, fmt="pwm")
    elif db.endswith("transfac"):
        db_motifs = read_motifs(db, fmt="transfac")

    mc = MotifComparer()
    db_motif_lookup = dict((m.id, m) for m in db_motifs)
    match = mc.get_closest_match(motifs, db_motifs, "partial", "wic", "mean", parallel=False)

    closest_match = {}
    for motif in motifs:
        best = db_motif_lookup[match[motif.id][0]]
        # Calculate p-value; alignment position/orientation are not needed here.
        pval, _, _ = mc.compare_motifs(motif, best, "partial", "wic", "mean", pval=True)
        closest_match[motif.id] = [best, pval]

    return closest_match
def merge_motifs(motif_1, motif_2):
    """Average two gimmemotifs motifs into a single consensus motif.

    The relative offset and orientation of the two motifs are taken from
    ``MotifComparer.compare_motifs`` with ``metric="pcc"``. They are then
    passed to ``motif_1.average_motifs``.

    Parameters
    ----------
    motif_1 : gimmemotifs Motif
        First motif of the pair.
    motif_2 : gimmemotifs Motif
        Second motif of the pair.

    Returns
    -------
    gimmemotifs Motif
        The averaged motif. Its id is ``motif_1.id + "+" + motif_2.id``.
    """
    from gimmemotifs.comparison import MotifComparer

    comparison = MotifComparer().compare_motifs(motif_1, motif_2, metric="pcc")
    _score, offset, strand = comparison

    merged = motif_1.average_motifs(motif_2, pos=offset, orientation=strand)
    merged.id = motif_1.id + "+" + motif_2.id
    return merged
def cluster(args):
    """Cluster the motifs in ``args.inputfile`` and write a report to ``args.outdir``.

    ``args.outdir`` is created if it does not exist. It receives:
      * ``<id>.png`` logos for every cluster and every member motif,
      * ``cluster_report.html`` rendered from ``cluster_template.jinja.html``,
      * ``cluster_key.txt`` with each cluster id and its comma-separated members,
      * ``clustered_motifs.pwm`` with the cluster motifs in PWM format.
    """
    outdir = os.path.abspath(args.outdir)
    if not os.path.exists(outdir):
        os.mkdir(outdir)

    trim_ic = 0.2
    motifs = pwmfile_to_motifs(args.inputfile)
    if len(motifs) == 1:
        # A single motif is its own cluster; no tree is built.
        clusters = [[motifs[0], motifs]]
    else:
        tree = cluster_motifs(args.inputfile, "total", "wic", "mean", True,
                              threshold=args.threshold, include_bg=True)
        clusters = tree.getResult()

    ids = []
    mc = MotifComparer()
    sys.stderr.write("Creating images\n")
    for cluster, members in clusters:
        cluster.trim(trim_ic)
        cluster.to_img(os.path.join(outdir, "%s.png" % cluster.id), format="PNG")
        ids.append([cluster.id, {"src": "%s.png" % cluster.id}, []])
        if len(members) > 1:
            scores = {}
            for motif in members:
                scores[motif] = mc.compare_motifs(cluster, motif, "total", "wic", "mean", pval=True)
            # Smallest alignment offset (replaces the Python 2-only sorted(cmp=...)).
            add_pos = min(s[1] for s in scores.values())
            for motif in members:
                _, pos, strand = scores[motif]
                add = pos - add_pos
                if strand not in [1, "+"]:
                    # Members aligned on the other strand are drawn reverse-complemented.
                    rc = motif.rc()
                    rc.id = motif.id
                    motif = rc
                motif.to_img(os.path.join(outdir, "%s.png" % motif.id.replace(" ", "_")),
                             format="PNG", add_left=add)
        ids[-1][2] = [
            {"src": "%s.png" % m.id.replace(" ", "_"), "alt": m.id.replace(" ", "_")}
            for m in members
        ]

    config = MotifConfig()
    env = jinja2.Environment(loader=jinja2.FileSystemLoader([config.get_template_dir()]))
    template = env.get_template("cluster_template.jinja.html")
    result = template.render(motifs=ids)

    # The rendered text is encoded explicitly, so the file must be opened in binary mode.
    with open(os.path.join(outdir, "cluster_report.html"), "wb") as f:
        f.write(result.encode("utf-8"))

    with open(os.path.join(outdir, "cluster_key.txt"), "w") as f:
        for cluster_id, _, member_imgs in ids:
            f.write("%s\t%s\n" % (cluster_id, ",".join([x["alt"] for x in member_imgs])))

    with open(os.path.join(outdir, "clustered_motifs.pwm"), "w") as f:
        if len(clusters) == 1 and len(clusters[0][1]) == 1:
            f.write("%s\n" % clusters[0][0].to_pwm())
        else:
            for motif in tree.get_clustered_motifs():
                f.write("%s\n" % motif.to_pwm())
def _cluster_motifs(self, pfm_file, cluster_pwm, dir, threshold):
    """Cluster the motifs in ``pfm_file`` and write report, logos and PWMs.

    Logos are written to ``self.imgdir``, the HTML report to
    ``self.cluster_report``, and the cluster motifs (PWM format) to
    ``cluster_pwm``. ``dir`` is not used; it is kept for call compatibility.

    Returns the list of ``[cluster_motif, member_motifs]`` pairs.
    """
    self.logger.info("clustering significant motifs.")

    trim_ic = 0.2
    # Parse eagerly and close the file handle right away.
    with open(pfm_file) as fh:
        motifs = read_motifs(fh, fmt="pwm")
    if len(motifs) == 1:
        clusters = [[motifs[0], motifs]]
    else:
        tree = cluster_motifs(
            pfm_file, "total", "wic", "mean", True,
            threshold=float(threshold), include_bg=True, progress=False,
        )
        clusters = tree.getResult()

    ids = []
    mc = MotifComparer()
    for cluster, members in clusters:
        cluster.trim(trim_ic)
        cluster.to_img(os.path.join(self.imgdir, "%s.png" % cluster.id), format="PNG")
        ids.append([cluster.id, {"src": "images/%s.png" % cluster.id}, []])
        if len(members) > 1:
            scores = {}
            for motif in members:
                scores[motif] = mc.compare_motifs(cluster, motif, "total", "wic", "mean", pval=True)
            # Smallest alignment offset (replaces the Python 2-only sorted(cmp=...)).
            add_pos = min(s[1] for s in scores.values())
            for motif in members:
                _, pos, strand = scores[motif]
                add = pos - add_pos
                if strand not in [1, "+"]:
                    # Members aligned on the other strand are drawn reverse-complemented.
                    rc = motif.rc()
                    rc.id = motif.id
                    motif = rc
                motif.to_img(
                    os.path.join(self.imgdir, "%s.png" % motif.id.replace(" ", "_")),
                    format="PNG", add_left=add,
                )
        ids[-1][2] = [
            {"src": "images/%s.png" % m.id.replace(" ", "_"), "alt": m.id.replace(" ", "_")}
            for m in members
        ]

    env = jinja2.Environment(loader=jinja2.FileSystemLoader([self.config.get_template_dir()]))
    template = env.get_template("cluster_template.jinja.html")
    result = template.render(
        expname=self.basename, motifs=ids, inputfile=self.inputfile,
        date=datetime.today().strftime("%d/%m/%Y"), version=GM_VERSION,
    )

    # The rendered text is encoded explicitly, so the file must be opened in binary mode.
    with open(self.cluster_report, "wb") as f:
        f.write(result.encode("utf-8"))

    with open(cluster_pwm, "w") as f:
        if len(clusters) == 1 and len(clusters[0][1]) == 1:
            f.write("%s\n" % clusters[0][0].to_pwm())
        else:
            for motif in tree.get_clustered_motifs():
                f.write("%s\n" % motif.to_pwm())

    self.logger.debug("Clustering done. See the result in %s", self.cluster_report)
    return clusters
def cluster_motifs(motifs, match="total", metric="wic", combine="mean", pval=True,
                   threshold=0.95, trim_edges=False, edge_ic_cutoff=0.2,
                   include_bg=True, progress=True):
    """
    Clusters a set of sequence motifs. Required arg 'motifs' is a file containing
    positional frequency matrices or an array with motifs.

    Optional args:

    'match', 'metric' and 'combine' specify the method used to compare and score
    the motifs. By default the WIC score is used (metric='wic'), using the score
    over the whole alignment (match='total'), with the total motif score
    calculated as the mean score of all positions (combine='mean').
    'match' can be either 'total' for the total alignment or 'subtotal' for the
    maximum scoring subsequence of the alignment.
    'metric' can be any metric defined in MotifComparer, currently: 'pcc', 'ed',
    'distance', 'wic' or 'chisq'
    'combine' determines how the total score is calculated from the score of
    individual positions and can be either 'sum' or 'mean'

    'pval' can be True or False and determines if the score should be converted
    to an empirical p-value

    'threshold' determines the score (or p-value) cutoff

    If 'trim_edges' is set to True, all motif edges with an IC below
    'edge_ic_cutoff' will be removed before clustering

    When computing the average of two motifs 'include_bg' determines if, at a
    position only present in one motif, the information in that motif should
    be kept, or if it should be averaged with background frequencies. Should
    probably be left set to True.

    If 'progress' is true, progress messages are written to stderr.

    Returns the root MotifTree node of the resulting hierarchy.
    """
    # First read pfm or pfm formatted motiffile (and close the handle).
    if not isinstance(motifs, list):
        with open(motifs) as fh:
            motifs = read_motifs(fh, fmt="pwm")

    mc = MotifComparer()

    # Trim edges with low information content
    if trim_edges:
        for motif in motifs:
            motif.trim(edge_ic_cutoff)

    # Make a MotifTree node for every motif
    nodes = [MotifTree(m) for m in motifs]

    # Determine all pairwise scores and maxscore per motif
    scores = {}
    motif_nodes = dict((n.motif.id, n) for n in nodes)
    motifs = [n.motif for n in nodes]

    if progress:
        sys.stderr.write("Calculating initial scores\n")
    result = mc.get_all_scores(motifs, motifs, match, metric, combine, pval, parallel=True)

    for m1, other_motifs in result.items():
        for m2, score in other_motifs.items():
            if m1 == m2:
                # Self-comparison gives the node's maximum attainable score.
                motif_nodes[m1].maxscore = 1 - score[0] if pval else score[0]
            else:
                if pval:
                    # Store 1 - p so that the merge loop, which takes the
                    # largest value, prefers the smallest p-value.
                    score = [1 - score[0]] + score[1:]
                scores[(motif_nodes[m1], motif_nodes[m2])] = score

    cluster_nodes = list(nodes)
    ave_count = 1

    total = len(cluster_nodes)

    while len(cluster_nodes) > 1:
        # Highest-scoring pair whose nodes are both still unmerged.
        ranked = sorted(scores.keys(), key=lambda x: scores[x][0])
        i = -1
        (n1, n2) = ranked[i]
        while n1 not in cluster_nodes or n2 not in cluster_nodes:
            i -= 1
            (n1, n2) = ranked[i]

        (score, pos, orientation) = scores[(n1, n2)]
        ave_motif = n1.motif.average_motifs(n2.motif, pos, orientation, include_bg=include_bg)

        ave_motif.trim(edge_ic_cutoff)
        ave_motif.id = "Average_%s" % ave_count
        ave_count += 1

        new_node = MotifTree(ave_motif)
        self_score = mc.compare_motifs(new_node.motif, new_node.motif,
                                       match, metric, combine, pval)[0]
        new_node.maxscore = 1 - self_score if pval else self_score

        new_node.mergescore = score
        n1.parent = new_node
        n2.parent = new_node
        new_node.left = n1
        new_node.right = n2

        cmp_nodes = dict((node.motif, node) for node in nodes if not node.parent)

        if progress:
            # Use a separate name so the `progress` flag is not overwritten,
            # and integer division so the repeat counts stay ints on Python 3.
            pct = int((1 - len(cmp_nodes) / float(total)) * 100)
            sys.stderr.write('\rClustering [{0}{1}] {2}%'.format(
                '#' * (pct // 10), " " * (10 - pct // 10), pct))

        result = mc.get_all_scores([new_node.motif], list(cmp_nodes.keys()),
                                   match, metric, combine, pval, parallel=True)

        for motif, n in cmp_nodes.items():
            x = result[new_node.motif.id][motif.id]
            if pval:
                x = [1 - x[0]] + x[1:]
            scores[(new_node, n)] = x

        nodes.append(new_node)

        cluster_nodes = [node for node in nodes if not node.parent]

    if progress:
        sys.stderr.write("\n")
    root = nodes[-1]
    for node in [node for node in nodes if not node.left]:
        node.parent.checkMerge(root, threshold)

    return root
def cluster_motifs_with_report(infile, outfile, outdir, threshold, title=None):
    """Cluster the motifs in ``infile`` and write images, an HTML report and PWMs.

    Parameters
    ----------
    infile : str
        Path of the motif file (read in PWM format).
    outfile : str
        Path the cluster motifs are written to, in PWM format.
    outdir : str
        Directory for ``cluster_report.html`` and the ``images/`` logos.
        ``images/`` is created with ``os.mkdir``, so ``outdir`` itself must
        already exist.
    threshold : float or str
        Clustering threshold, converted with ``float()``.
    title : str, optional
        Passed to the report template as ``inputfile``. Defaults to ``infile``.

    Returns
    -------
    list
        ``[cluster_motif, member_motifs]`` pairs. If ``infile`` contains no
        motifs, an empty list is returned and nothing is written.
    """
    # Cluster significant motifs
    if title is None:
        title = infile

    motifs = read_motifs(infile, fmt="pwm")

    trim_ic = 0.2
    if len(motifs) == 0:
        return []
    if len(motifs) == 1:
        clusters = [[motifs[0], motifs]]
    else:
        logger.info("clustering %d motifs.", len(motifs))
        tree = cluster_motifs(
            infile, "total", "wic", "mean", True,
            threshold=float(threshold), include_bg=True, progress=False,
        )
        clusters = tree.getResult()

    ids = []
    mc = MotifComparer()

    img_dir = os.path.join(outdir, "images")
    if not os.path.exists(img_dir):
        os.mkdir(img_dir)

    for cluster, members in clusters:
        cluster.trim(trim_ic)
        png = "images/{}.png".format(cluster.id)
        cluster.to_img(os.path.join(outdir, png), fmt="PNG")
        ids.append([cluster.id, {"src": png}, []])
        if len(members) > 1:
            scores = {}
            for motif in members:
                scores[motif] = mc.compare_motifs(cluster, motif, "total", "wic", "mean", pval=True)
            # Smallest alignment offset; member logos are shifted relative to it.
            add_pos = min(s[1] for s in scores.values())
            for motif in members:
                _, pos, strand = scores[motif]
                add = pos - add_pos
                if strand not in [1, "+"]:
                    # Members aligned on the other strand are drawn reverse-complemented.
                    rc = motif.rc()
                    rc.id = motif.id
                    motif = rc
                png = "images/{}.png".format(motif.id.replace(" ", "_"))
                motif.to_img(os.path.join(outdir, png), fmt="PNG", add_left=add)
        ids[-1][2] = [
            {"src": "images/{}.png".format(m.id.replace(" ", "_")), "alt": m.id.replace(" ", "_")}
            for m in members
        ]

    config = MotifConfig()
    env = jinja2.Environment(loader=jinja2.FileSystemLoader([config.get_template_dir()]))
    template = env.get_template("cluster_template.jinja.html")
    result = template.render(
        motifs=ids,
        inputfile=title,
        date=datetime.today().strftime("%d/%m/%Y"),
        version=__version__,
    )
    cluster_report = os.path.join(outdir, "cluster_report.html")
    with open(cluster_report, "wb") as f:
        f.write(result.encode("utf-8"))

    # Use a context manager so the file is closed even if to_pwm() raises.
    with open(outfile, "w") as f:
        if len(clusters) == 1 and len(clusters[0][1]) == 1:
            f.write("%s\n" % clusters[0][0].to_pwm())
        else:
            for motif in tree.get_clustered_motifs():
                f.write("%s\n" % motif.to_pwm())

    logger.debug("Clustering done. See the result in %s", cluster_report)
    return clusters
def _cluster_motifs(self, pfm_file, cluster_pwm, dir, threshold):
    """Cluster the motifs in ``pfm_file`` and write report, logos and PWMs.

    Logos are written to ``self.imgdir``, the HTML report to
    ``self.cluster_report``, and the cluster motifs (PWM format) to
    ``cluster_pwm``. ``dir`` is not used; it is kept for call compatibility.

    Returns the list of ``[cluster_motif, member_motifs]`` pairs.
    """
    self.logger.info("clustering significant motifs.")

    trim_ic = 0.2
    # Parse eagerly and close the file handle right away.
    with open(pfm_file) as fh:
        motifs = read_motifs(fh, fmt="pwm")
    if len(motifs) == 1:
        clusters = [[motifs[0], motifs]]
    else:
        tree = cluster_motifs(
            pfm_file, "total", "wic", "mean", True,
            threshold=float(threshold), include_bg=True, progress=False,
        )
        clusters = tree.getResult()

    ids = []
    mc = MotifComparer()
    for cluster, members in clusters:
        cluster.trim(trim_ic)
        cluster.to_img(os.path.join(self.imgdir, "%s.png" % cluster.id), format="PNG")
        ids.append([cluster.id, {"src": "images/%s.png" % cluster.id}, []])
        if len(members) > 1:
            scores = {}
            for motif in members:
                scores[motif] = mc.compare_motifs(cluster, motif, "total", "wic", "mean", pval=True)
            # Smallest alignment offset (replaces the Python 2-only sorted(cmp=...)).
            add_pos = min(s[1] for s in scores.values())
            for motif in members:
                _, pos, strand = scores[motif]
                add = pos - add_pos
                if strand not in [1, "+"]:
                    # Members aligned on the other strand are drawn reverse-complemented.
                    rc = motif.rc()
                    rc.id = motif.id
                    motif = rc
                motif.to_img(
                    os.path.join(self.imgdir, "%s.png" % motif.id.replace(" ", "_")),
                    format="PNG", add_left=add,
                )
        ids[-1][2] = [
            {"src": "images/%s.png" % m.id.replace(" ", "_"), "alt": m.id.replace(" ", "_")}
            for m in members
        ]

    env = jinja2.Environment(loader=jinja2.FileSystemLoader([self.config.get_template_dir()]))
    template = env.get_template("cluster_template.jinja.html")
    result = template.render(
        expname=self.basename, motifs=ids, inputfile=self.inputfile,
        date=datetime.today().strftime("%d/%m/%Y"), version=GM_VERSION,
    )

    # The rendered text is encoded explicitly, so the file must be opened in binary mode.
    with open(self.cluster_report, "wb") as f:
        f.write(result.encode("utf-8"))

    with open(cluster_pwm, "w") as f:
        if len(clusters) == 1 and len(clusters[0][1]) == 1:
            f.write("%s\n" % clusters[0][0].to_pwm())
        else:
            for motif in tree.get_clustered_motifs():
                f.write("%s\n" % motif.to_pwm())

    self.logger.debug("Clustering done. See the result in %s", self.cluster_report)
    return clusters
def cluster_motifs(
    motifs,
    match="total",
    metric="wic",
    combine="mean",
    pval=True,
    threshold=0.95,
    trim_edges=False,
    edge_ic_cutoff=0.2,
    include_bg=True,
    progress=True,
):
    """
    Clusters a set of sequence motifs. Required arg 'motifs' is a file containing
    positional frequency matrices or an array with motifs.

    Optional args:

    'match', 'metric' and 'combine' specify the method used to compare and score
    the motifs. By default the WIC score is used (metric='wic'), using the score
    over the whole alignment (match='total'), with the total motif score
    calculated as the mean score of all positions (combine='mean').
    'match' can be either 'total' for the total alignment or 'subtotal' for the
    maximum scoring subsequence of the alignment.
    'metric' can be any metric defined in MotifComparer, currently: 'pcc', 'ed',
    'distance', 'wic' or 'chisq'
    'combine' determines how the total score is calculated from the score of
    individual positions and can be either 'sum' or 'mean'

    'pval' can be True or False and determines if the score should be converted
    to an empirical p-value

    'threshold' determines the score (or p-value) cutoff

    If 'trim_edges' is set to True, all motif edges with an IC below
    'edge_ic_cutoff' will be removed before clustering

    When computing the average of two motifs 'include_bg' determines if, at a
    position only present in one motif, the information in that motif should
    be kept, or if it should be averaged with background frequencies. Should
    probably be left set to True.

    If 'progress' is true, progress messages are written to stderr.

    Returns the root MotifTree node of the resulting hierarchy.
    """
    # First read pfm or pfm formatted motiffile (and close the handle).
    if not isinstance(motifs, list):
        with open(motifs) as fh:
            motifs = read_motifs(fh, fmt="pwm")

    mc = MotifComparer()

    # Trim edges with low information content
    if trim_edges:
        for motif in motifs:
            motif.trim(edge_ic_cutoff)

    # Make a MotifTree node for every motif
    nodes = [MotifTree(m) for m in motifs]

    # Determine all pairwise scores and maxscore per motif
    scores = {}
    motif_nodes = dict((n.motif.id, n) for n in nodes)
    motifs = [n.motif for n in nodes]

    if progress:
        sys.stderr.write("Calculating initial scores\n")
    result = mc.get_all_scores(motifs, motifs, match, metric, combine, pval, parallel=True)

    for m1, other_motifs in result.items():
        for m2, score in other_motifs.items():
            if m1 == m2:
                # Self-comparison gives the node's maximum attainable score.
                motif_nodes[m1].maxscore = 1 - score[0] if pval else score[0]
            else:
                if pval:
                    # Store 1 - p so that the merge loop, which takes the
                    # largest value, prefers the smallest p-value.
                    score = [1 - score[0]] + score[1:]
                scores[(motif_nodes[m1], motif_nodes[m2])] = score

    cluster_nodes = list(nodes)
    ave_count = 1

    total = len(cluster_nodes)

    while len(cluster_nodes) > 1:
        # Highest-scoring pair whose nodes are both still unmerged.
        ranked = sorted(scores.keys(), key=lambda x: scores[x][0])
        i = -1
        (n1, n2) = ranked[i]
        while n1 not in cluster_nodes or n2 not in cluster_nodes:
            i -= 1
            (n1, n2) = ranked[i]

        (score, pos, orientation) = scores[(n1, n2)]
        ave_motif = n1.motif.average_motifs(n2.motif, pos, orientation, include_bg=include_bg)

        ave_motif.trim(edge_ic_cutoff)
        ave_motif.id = "Average_%s" % ave_count
        ave_count += 1

        new_node = MotifTree(ave_motif)
        self_score = mc.compare_motifs(new_node.motif, new_node.motif,
                                       match, metric, combine, pval)[0]
        new_node.maxscore = 1 - self_score if pval else self_score

        new_node.mergescore = score
        n1.parent = new_node
        n2.parent = new_node
        new_node.left = n1
        new_node.right = n2

        cmp_nodes = dict((node.motif, node) for node in nodes if not node.parent)

        if progress:
            # Use a separate name so the `progress` flag is not overwritten,
            # and integer division so the repeat counts stay ints on Python 3.
            pct = int((1 - len(cmp_nodes) / float(total)) * 100)
            sys.stderr.write(
                "\rClustering [{0}{1}] {2}%".format(
                    "#" * (pct // 10), " " * (10 - pct // 10), pct
                )
            )

        result = mc.get_all_scores([new_node.motif], list(cmp_nodes.keys()),
                                   match, metric, combine, pval, parallel=True)

        for motif, n in cmp_nodes.items():
            x = result[new_node.motif.id][motif.id]
            if pval:
                x = [1 - x[0]] + x[1:]
            scores[(new_node, n)] = x

        nodes.append(new_node)

        cluster_nodes = [node for node in nodes if not node.parent]

    if progress:
        sys.stderr.write("\n")
    root = nodes[-1]
    for node in [node for node in nodes if not node.left]:
        node.parent.checkMerge(root, threshold)

    return root
def cluster_motifs_with_report(infile, outfile, outdir, threshold, title=None):
    """Cluster the motifs in ``infile`` and write images, an HTML report and PWMs.

    Parameters
    ----------
    infile : str
        Path of the motif file (read in PWM format).
    outfile : str
        Path the cluster motifs are written to, in PWM format.
    outdir : str
        Directory for ``cluster_report.html`` and the ``images/`` logos.
        ``images/`` is created with ``os.mkdir``, so ``outdir`` itself must
        already exist.
    threshold : float or str
        Clustering threshold, converted with ``float()``.
    title : str, optional
        Passed to the report template as ``inputfile``. Defaults to ``infile``.

    Returns
    -------
    list
        ``[cluster_motif, member_motifs]`` pairs. If ``infile`` contains no
        motifs, an empty list is returned and nothing is written.
    """
    # Cluster significant motifs
    if title is None:
        title = infile

    motifs = read_motifs(infile, fmt="pwm")

    trim_ic = 0.2
    if len(motifs) == 0:
        return []
    if len(motifs) == 1:
        clusters = [[motifs[0], motifs]]
    else:
        logger.info("clustering %d motifs.", len(motifs))
        tree = cluster_motifs(
            infile, "total", "wic", "mean", True,
            threshold=float(threshold), include_bg=True, progress=False,
        )
        clusters = tree.getResult()

    ids = []
    mc = MotifComparer()

    img_dir = os.path.join(outdir, "images")
    if not os.path.exists(img_dir):
        os.mkdir(img_dir)

    for cluster, members in clusters:
        cluster.trim(trim_ic)
        png = "images/{}.png".format(cluster.id)
        cluster.to_img(os.path.join(outdir, png), fmt="PNG")
        ids.append([cluster.id, {"src": png}, []])
        if len(members) > 1:
            scores = {}
            for motif in members:
                scores[motif] = mc.compare_motifs(cluster, motif, "total", "wic", "mean", pval=True)
            # Smallest alignment offset; member logos are shifted relative to it.
            add_pos = min(s[1] for s in scores.values())
            for motif in members:
                _, pos, strand = scores[motif]
                add = pos - add_pos
                if strand not in [1, "+"]:
                    # Members aligned on the other strand are drawn reverse-complemented.
                    rc = motif.rc()
                    rc.id = motif.id
                    motif = rc
                png = "images/{}.png".format(motif.id.replace(" ", "_"))
                motif.to_img(os.path.join(outdir, png), fmt="PNG", add_left=add)
        ids[-1][2] = [
            {"src": "images/{}.png".format(m.id.replace(" ", "_")), "alt": m.id.replace(" ", "_")}
            for m in members
        ]

    config = MotifConfig()
    env = jinja2.Environment(loader=jinja2.FileSystemLoader([config.get_template_dir()]))
    template = env.get_template("cluster_template.jinja.html")
    result = template.render(
        motifs=ids,
        inputfile=title,
        date=datetime.today().strftime("%d/%m/%Y"),
        version=__version__,
    )
    cluster_report = os.path.join(outdir, "cluster_report.html")
    with open(cluster_report, "wb") as f:
        f.write(result.encode("utf-8"))

    # Use a context manager so the file is closed even if to_pwm() raises.
    with open(outfile, "w") as f:
        if len(clusters) == 1 and len(clusters[0][1]) == 1:
            f.write("%s\n" % clusters[0][0].to_pwm())
        else:
            for motif in tree.get_clustered_motifs():
                f.write("%s\n" % motif.to_pwm())

    logger.debug("Clustering done. See the result in %s", cluster_report)
    return clusters