Esempio n. 1
0
def match(args):
    """Report the closest database motif for every sample motif.

    Sample motifs are read from ``args.pwmfile`` and database motifs from
    ``args.dbpwmfile``. One tab-separated line (motif id, id of the match,
    score, p-value) is printed per sample motif. When ``args.img`` is set,
    the aligned motif pairs are handed to ``match_plot`` together with
    ``args.img``.

    Parameters
    ----------
    args : object
        Parsed arguments; the attributes ``pwmfile``, ``dbpwmfile`` and
        ``img`` are read.
    """
    sample = dict((m.id, m) for m in pwmfile_to_motifs(args.pwmfile))
    db = dict((m.id, m) for m in pwmfile_to_motifs(args.dbpwmfile))

    mc = MotifComparer()
    result = mc.get_closest_match(sample.values(), db.values(), "partial", "wic", "mean")

    # print(<single expression>) behaves identically as a Python 2 print
    # statement and as the Python 3 print function.
    print("Motif\tMatch\tScore\tP-value")

    plotdata = []
    for query, hit in result.items():
        motif = sample[query]
        dbmotif = db[hit[0]]
        # Compare each pair once; the result serves both the table and the plot.
        pval, pos, orient = mc.compare_motifs(motif, dbmotif, "partial", "wic", "mean", pval=True)
        print("%s\t%s\t%0.2f\t%0.3e" % (query, hit[0], hit[1][0], pval))

        if not args.img:
            continue

        if orient == -1:
            # Plot the reverse complement, but keep the original id.
            flipped = dbmotif.rc()
            flipped.id = dbmotif.id
            dbmotif = flipped

        # Left-pad one of the two motifs with uniform columns so that both
        # start at the same alignment position.
        if pos < 0:
            padded = Motif([[0.25, 0.25, 0.25, 0.25]] * -pos + motif.pwm)
            padded.id = motif.id
            motif = padded
        elif pos > 0:
            padded = Motif([[0.25, 0.25, 0.25, 0.25]] * pos + dbmotif.pwm)
            padded.id = dbmotif.id
            dbmotif = padded

        plotdata.append((motif, dbmotif, pval))

    # Draw the figure once, after all pairs are collected, instead of
    # re-plotting the growing list on every iteration.
    if args.img and plotdata:
        match_plot(plotdata, args.img)
Esempio n. 2
0
    def create_consensus(self):
        """Merge every motif of this MotifList into a single consensus motif.

        The two motifs picked by ``find_best_pair`` are merged repeatedly
        until one motif is left. Its pwm values are rounded to 5 decimals and
        it is converted with ``gimmemotif_to_onemotif``. The returned object
        is named after the first three member names, followed by "(...)" when
        there are more than three.
        """
        pool = [entry.gimme_obj for entry in self]

        if len(pool) > 1:
            comparer = MotifComparer()

            # Pairwise scores between all motifs in the pool.
            pair_scores = comparer.get_all_scores(
                pool, pool, match="total", metric="pcc", combine="mean")

            while True:
                # Sorted so that deleting the larger index cannot shift the
                # position of the smaller one.
                pair = sorted(find_best_pair(pool, pair_scores))
                keep_idx = pair[0]
                drop_idx = pair[1]

                merged = merge_motifs(pool[keep_idx], pool[drop_idx])

                del pool[drop_idx]
                pool[keep_idx] = merged

                if len(pool) == 1:
                    # Everything has been merged.
                    break

                # Register the merged motif in the score table, both as row
                # and as column.
                pair_scores[merged.id] = pair_scores.get(merged.id, {})
                for other in pool:
                    pair_scores[merged.id][other.id] = comparer.compare_motifs(
                        merged, other, metric="pcc")
                    pair_scores[other.id][merged.id] = comparer.compare_motifs(
                        other, merged, metric="pcc")

        # Round the pwm values of the remaining motif.
        consensus_gimme = pool[0]
        rounded_rows = []
        for row in consensus_gimme.pwm:
            rounded_rows.append([round(value, 5) for value in row])
        consensus_gimme.pwm = rounded_rows

        # Convert back to a OneMotif object, keeping the gimmemotifs object.
        consensus = gimmemotif_to_onemotif(consensus_gimme)
        consensus.gimme_obj = consensus_gimme

        # Name: at most three member names, with a marker if some were cut.
        member_names = [entry.name for entry in self]
        consensus.name = ",".join(member_names[:3])
        if len(member_names) > 3:
            consensus.name += "(...)"

        return consensus
Esempio n. 3
0
def match(args):
    """Report the closest database motif for every sample motif.

    Sample motifs are read from ``args.pwmfile`` and database motifs from
    ``args.dbpwmfile``. One tab-separated line (motif id, id of the match,
    score, p-value) is printed per sample motif. When ``args.img`` is set,
    the aligned motif pairs are handed to ``match_plot`` together with
    ``args.img``.

    Parameters
    ----------
    args : object
        Parsed arguments; the attributes ``pwmfile``, ``dbpwmfile`` and
        ``img`` are read.
    """
    sample = {m.id: m for m in pwmfile_to_motifs(args.pwmfile)}
    db = {m.id: m for m in pwmfile_to_motifs(args.dbpwmfile)}

    mc = MotifComparer()
    result = mc.get_closest_match(sample.values(), db.values(), "partial", "wic", "mean")

    print("Motif\tMatch\tScore\tP-value")

    plotdata = []
    for query, hit in result.items():
        motif = sample[query]
        dbmotif = db[hit[0]]
        # Compare each pair once; the result serves both the table and the plot.
        pval, pos, orient = mc.compare_motifs(motif, dbmotif, "partial", "wic", "mean", pval=True)
        print("%s\t%s\t%0.2f\t%0.3e" % (query, hit[0], hit[1][0], pval))

        if not args.img:
            continue

        if orient == -1:
            # Plot the reverse complement, but keep the original id.
            flipped = dbmotif.rc()
            flipped.id = dbmotif.id
            dbmotif = flipped

        # Left-pad one of the two motifs with uniform columns so that both
        # start at the same alignment position.
        if pos < 0:
            padded = Motif([[0.25, 0.25, 0.25, 0.25]] * -pos + motif.pwm)
            padded.id = motif.id
            motif = padded
        elif pos > 0:
            padded = Motif([[0.25, 0.25, 0.25, 0.25]] * pos + dbmotif.pwm)
            padded.id = dbmotif.id
            dbmotif = padded

        plotdata.append((motif, dbmotif, pval))

    # Draw the figure once, after all pairs are collected, instead of
    # re-plotting the growing list on every iteration.
    if args.img and plotdata:
        match_plot(plotdata, args.img)
Esempio n. 4
0
    def determine_closest_match(self, motifs):
        """Find the closest database motif for each motif in ``motifs``.

        The motif database is the file named by the ``motif_db`` default
        parameter, located in the configured motif directory. Files ending
        in "pwm" or "pfm" are read as pwm format, files ending in
        "transfac" as transfac format.

        Parameters
        ----------
        motifs : iterable
            Motifs to look up; each must have an ``id`` attribute.

        Returns
        -------
        dict
            Maps each motif id to ``[closest database motif, p-value]``.
        """
        self.logger.debug("Determining closest matching motifs in database")
        motif_db = self.config.get_default_params()["motif_db"]
        db = os.path.join(self.config.get_motif_dir(), motif_db)
        db_motifs = []
        if db.endswith(("pwm", "pfm")):
            # Close the database file as soon as the motifs are parsed.
            with open(db) as handle:
                db_motifs = read_motifs(handle, fmt="pwm")
        elif db.endswith("transfac"):
            db_motifs = read_motifs(db, fmt="transfac")

        closest_match = {}
        mc = MotifComparer()
        db_motif_lookup = {m.id: m for m in db_motifs}
        match = mc.get_closest_match(motifs,
                                     db_motifs,
                                     "partial",
                                     "wic",
                                     "mean",
                                     parallel=False)
        for motif in motifs:
            best = db_motif_lookup[match[motif.id][0]]
            # Only the p-value is kept; position and orientation are unused.
            pval, _pos, _orient = mc.compare_motifs(
                motif,
                best,
                "partial",
                "wic",
                "mean",
                pval=True)
            closest_match[motif.id] = [best, pval]
        return closest_match
Esempio n. 5
0
def _create_images(outdir, clusters):
    """Write PNG images for clusters and their members into ``outdir``.

    Every cluster motif is trimmed with a cutoff of 0.2 and drawn. For
    clusters with more than one member, each member is compared to the
    cluster motif, reverse-complemented unless its strand is 1 or "+", and
    drawn with a left offset relative to the smallest reported position.

    Returns a list with one ``[cluster id, {"src": ...}, member entries]``
    item per cluster, where every member entry is a dict with the keys
    "src" and "alt".
    """
    image_index = []
    comparer = MotifComparer()
    ic_cutoff = 0.2

    sys.stderr.write("Creating images\n")
    for consensus, members in clusters:
        consensus.trim(ic_cutoff)
        consensus.to_img(os.path.join(outdir, "%s.png" % consensus.id), fmt="PNG")
        entry = [consensus.id, {"src": "%s.png" % consensus.id}, []]
        image_index.append(entry)

        if len(members) > 1:
            alignments = {
                member: comparer.compare_motifs(
                    consensus, member, "total", "wic", "mean", pval=True
                )
                for member in members
            }
            # Smallest position over all members; offsets are relative to it.
            leftmost = min(alignment[1] for alignment in alignments.values())

            for member in members:
                _, pos, strand = alignments[member]
                shift = pos - leftmost

                if strand not in [1, "+"]:
                    # Draw the reverse complement under the member's own id.
                    flipped = member.rc()
                    flipped.id = member.id
                    member = flipped

                target = os.path.join(outdir, "%s.png" % member.id.replace(" ", "_"))
                member.to_img(target, fmt="PNG", add_left=shift)

        member_entries = []
        for m in members:
            safe_id = m.id.replace(" ", "_")
            member_entries.append({"src": "%s.png" % safe_id, "alt": safe_id})
        entry[2] = member_entries

    return image_index
Esempio n. 6
0
def match(args):
    """Print the best database matches for every sample motif.

    Sample motifs are read from ``args.pfmfile`` and database motifs from
    ``args.dbpfmfile``. For each sample motif the ``args.nmatches`` best
    matches are requested and one tab-separated line (motif id, id of the
    match, score, p-value) is printed per match. When ``args.img`` is set,
    every pair is aligned and padded to equal length, and all pairs are
    passed to ``match_plot`` together with ``args.img``.

    Parameters
    ----------
    args : object
        Parsed arguments; the attributes ``pfmfile``, ``dbpfmfile``,
        ``nmatches`` and ``img`` are read.
    """

    def _padding(n_columns):
        # Uniform columns used to fill positions a motif does not cover.
        return [[0.25, 0.25, 0.25, 0.25] for _ in range(n_columns)]

    def _keep_id(rebuilt, source):
        # Re-attach the id of the motif a rebuilt motif was derived from.
        rebuilt.id = source.id
        return rebuilt

    sample = {m.id: m for m in read_motifs(args.pfmfile)}
    db = {m.id: m for m in read_motifs(args.dbpfmfile)}

    mc = MotifComparer()
    result = mc.get_best_matches(
        sample.values(), args.nmatches, db.values(), "partial", "seqcor", "mean"
    )

    plotdata = []
    print("Motif\tMatch\tScore\tP-value")
    for motif_name, matches in result.items():
        for hit in matches:
            motif = sample[motif_name]
            dbmotif = db[hit[0]]

            pval, pos, orient = mc.compare_motifs(
                motif, dbmotif, "partial", "seqcor", "mean", pval=True
            )
            print("%s\t%s\t%0.2f\t%0.3e" % (motif_name, hit[0], hit[1][0], pval))

            if not args.img:
                continue

            if orient == -1:
                dbmotif = _keep_id(dbmotif.rc(), dbmotif)

            # Left-pad so both motifs start at the same alignment position.
            if pos < 0:
                motif = _keep_id(Motif(_padding(-pos) + motif.pwm), motif)
            elif pos > 0:
                dbmotif = _keep_id(Motif(_padding(pos) + dbmotif.pwm), dbmotif)

            # Right-pad the shorter motif so both have the same length. The
            # id is preserved here as well, consistent with the left-padding
            # above, and nothing is rebuilt when the lengths already agree.
            diff = len(motif) - len(dbmotif)
            if diff > 0:
                dbmotif = _keep_id(Motif(dbmotif.pwm + _padding(diff)), dbmotif)
            elif diff < 0:
                motif = _keep_id(Motif(motif.pwm + _padding(-diff)), motif)

            plotdata.append((motif, dbmotif, pval))

    if args.img:
        match_plot(plotdata, args.img)
Esempio n. 7
0
    def determine_closest_match(self, motifs):
        """Find the closest database motif for each motif in ``motifs``.

        The motif database is the file named by the ``motif_db`` default
        parameter, located in the configured motif directory. Files ending
        in "pwm" or "pfm" are read as pwm format, files ending in
        "transfac" as transfac format.

        Parameters
        ----------
        motifs : iterable
            Motifs to look up; each must have an ``id`` attribute.

        Returns
        -------
        dict
            Maps each motif id to ``[closest database motif, p-value]``.
        """
        self.logger.debug("Determining closest matching motifs in database")
        motif_db = self.config.get_default_params()["motif_db"]
        db = os.path.join(self.config.get_motif_dir(), motif_db)
        db_motifs = []
        if db.endswith(("pwm", "pfm")):
            # Close the database file as soon as the motifs are parsed.
            with open(db) as handle:
                db_motifs = read_motifs(handle, fmt="pwm")
        elif db.endswith("transfac"):
            db_motifs = read_motifs(db, fmt="transfac")

        closest_match = {}
        mc = MotifComparer()
        db_motif_lookup = {m.id: m for m in db_motifs}
        match = mc.get_closest_match(motifs, db_motifs, "partial", "wic", "mean", parallel=False)
        for motif in motifs:
            best = db_motif_lookup[match[motif.id][0]]
            # Only the p-value is kept; position and orientation are unused.
            pval, _pos, _orient = mc.compare_motifs(motif, best, "partial", "wic", "mean", pval=True)
            closest_match[motif.id] = [best, pval]
        return closest_match
Esempio n. 8
0
def merge_motifs(motif_1, motif_2):
	"""Build the consensus of two motifs.

	The motifs are compared with MotifComparer.compare_motifs (metric "pcc");
	the position and orientation it reports are passed to
	motif_1.average_motifs to average the two motifs.

	Parameters
	----------
	motif_1 : Object of class Motif
		First gimmemotifs motif of the pair.
	motif_2 : Object of class Motif
		Second gimmemotifs motif of the pair.

	Returns
	-------
	consensus : Object of class Motif
		Averaged motif whose id is both input ids joined by "+".
	"""
	from gimmemotifs.comparison import MotifComparer

	comparer = MotifComparer()
	# The first element of the comparison result is not needed here.
	_, offset, strand = comparer.compare_motifs(motif_1, motif_2, metric="pcc")

	merged = motif_1.average_motifs(motif_2, pos=offset, orientation=strand)
	merged.id = motif_1.id + "+" + motif_2.id
	return merged
Esempio n. 9
0
def cluster(args):
    """Cluster the motifs in ``args.inputfile`` and write a report.

    The following is written to ``args.outdir`` (created when missing):
    a PNG per cluster motif, aligned PNGs of the members of clusters with
    more than one member, ``cluster_report.html``, ``cluster_key.txt`` and
    ``clustered_motifs.pwm``.

    Parameters
    ----------
    args : object
        Parsed arguments; the attributes ``outdir``, ``inputfile`` and
        ``threshold`` are read.
    """
    # NOTE(review): the previous version computed ``revcomp = not args.single``
    # but never used it, so ``args.single`` has no effect on the clustering
    # here -- confirm whether it was meant to be passed on.

    outdir = os.path.abspath(args.outdir)
    if not os.path.exists(outdir):
        os.mkdir(outdir)

    trim_ic = 0.2
    motifs = pwmfile_to_motifs(args.inputfile)
    if len(motifs) == 1:
        clusters = [[motifs[0], motifs]]
    else:
        tree = cluster_motifs(args.inputfile, "total", "wic", "mean", True, threshold=args.threshold, include_bg=True)
        clusters = tree.getResult()

    ids = []
    mc = MotifComparer()

    sys.stderr.write("Creating images\n")
    # The loop variable is not called ``cluster`` to avoid shadowing this function.
    for cluster_motif, members in clusters:
        cluster_motif.trim(trim_ic)
        cluster_motif.to_img(os.path.join(outdir, "%s.png" % cluster_motif.id), format="PNG")
        ids.append([cluster_motif.id, {"src": "%s.png" % cluster_motif.id}, []])
        if len(members) > 1:
            scores = {}
            for motif in members:
                scores[motif] = mc.compare_motifs(cluster_motif, motif, "total", "wic", "mean", pval=True)
            # Smallest position of all members; min(key=...) replaces the
            # Python 2 only sorted(cmp=...) call.
            add_pos = min(scores.values(), key=lambda x: x[1])[1]
            for motif in members:
                _, pos, strand = scores[motif]
                add = pos - add_pos

                if strand not in [1, "+"]:
                    # Draw the reverse complement under the member's own id.
                    rc = motif.rc()
                    rc.id = motif.id
                    motif = rc
                motif.to_img(os.path.join(outdir, "%s.png" % motif.id.replace(" ", "_")), format="PNG", add_left=add)
        ids[-1][2] = [
            {"src": "%s.png" % motif.id.replace(" ", "_"), "alt": motif.id.replace(" ", "_")}
            for motif in members
        ]

    config = MotifConfig()
    env = jinja2.Environment(loader=jinja2.FileSystemLoader([config.get_template_dir()]))
    template = env.get_template("cluster_template.jinja.html")
    result = template.render(motifs=ids)

    # The report is encoded to bytes, so the file is opened in binary mode.
    with open(os.path.join(outdir, "cluster_report.html"), "wb") as f:
        f.write(result.encode('utf-8'))

    with open(os.path.join(outdir, "cluster_key.txt"), "w") as f:
        for entry in ids:
            f.write("%s\t%s\n" % (entry[0], ",".join([x["alt"] for x in entry[2]])))

    with open(os.path.join(outdir, "clustered_motifs.pwm"), "w") as f:
        if len(clusters) == 1 and len(clusters[0][1]) == 1:
            f.write("%s\n" % clusters[0][0].to_pwm())
        else:
            for motif in tree.get_clustered_motifs():
                f.write("%s\n" % motif.to_pwm())
Esempio n. 10
0
    def _cluster_motifs(self, pfm_file, cluster_pwm, dir, threshold):
        """Cluster the motifs in ``pfm_file`` and write images and a report.

        Images are written to ``self.imgdir``, the rendered HTML report to
        ``self.cluster_report`` and the clustered motifs to ``cluster_pwm``.

        Parameters
        ----------
        pfm_file : str
            Path of the motif file to cluster.
        cluster_pwm : str
            Path of the output file for the clustered motifs.
        dir : object
            Not used by this method.
        threshold : object
            Clustering threshold; converted with ``float()``.

        Returns
        -------
        list
            The ``(cluster motif, members)`` pairs.
        """
        self.logger.info("clustering significant motifs.")

        trim_ic = 0.2
        # Close the input file as soon as the motifs are parsed.
        with open(pfm_file) as handle:
            motifs = read_motifs(handle, fmt="pwm")
        if len(motifs) == 1:
            clusters = [[motifs[0], motifs]]
        else:
            tree = cluster_motifs(pfm_file,
                                  "total",
                                  "wic",
                                  "mean",
                                  True,
                                  threshold=float(threshold),
                                  include_bg=True,
                                  progress=False)
            clusters = tree.getResult()

        ids = []
        mc = MotifComparer()

        for cluster, members in clusters:
            cluster.trim(trim_ic)
            cluster.to_img(os.path.join(self.imgdir, "%s.png" % cluster.id),
                           format="PNG")
            ids.append([cluster.id, {"src": "images/%s.png" % cluster.id}, []])
            if len(members) > 1:
                scores = {}
                for motif in members:
                    scores[motif] = mc.compare_motifs(cluster,
                                                      motif,
                                                      "total",
                                                      "wic",
                                                      "mean",
                                                      pval=True)
                # Smallest position of all members; min(key=...) replaces
                # the Python 2 only sorted(cmp=...) call.
                add_pos = min(scores.values(), key=lambda x: x[1])[1]
                for motif in members:
                    _, pos, strand = scores[motif]
                    add = pos - add_pos

                    if strand not in [1, "+"]:
                        # Draw the reverse complement under the member's id.
                        rc = motif.rc()
                        rc.id = motif.id
                        motif = rc
                    motif.to_img(os.path.join(
                        self.imgdir, "%s.png" % motif.id.replace(" ", "_")),
                                 format="PNG",
                                 add_left=add)
            ids[-1][2] = [
                {"src": "images/%s.png" % motif.id.replace(" ", "_"),
                 "alt": motif.id.replace(" ", "_")}
                for motif in members
            ]

        env = jinja2.Environment(
            loader=jinja2.FileSystemLoader([self.config.get_template_dir()]))
        template = env.get_template("cluster_template.jinja.html")
        result = template.render(expname=self.basename,
                                 motifs=ids,
                                 inputfile=self.inputfile,
                                 date=datetime.today().strftime("%d/%m/%Y"),
                                 version=GM_VERSION)

        # The report is encoded to bytes, so the file is opened in binary mode.
        with open(self.cluster_report, "wb") as f:
            f.write(result.encode('utf-8'))

        with open(cluster_pwm, "w") as f:
            if len(clusters) == 1 and len(clusters[0][1]) == 1:
                f.write("%s\n" % clusters[0][0].to_pwm())
            else:
                for motif in tree.get_clustered_motifs():
                    f.write("%s\n" % motif.to_pwm())

        self.logger.debug("Clustering done. See the result in %s",
                          self.cluster_report)
        return clusters
Esempio n. 11
0
def cluster_motifs(motifs, match="total", metric="wic", combine="mean", pval=True, threshold=0.95, trim_edges=False, edge_ic_cutoff=0.2, include_bg=True, progress=True):
    """
    Clusters a set of sequence motifs. Required arg 'motifs' is a file containing
    positional frequency matrices or a list with motifs.

    Optional args:

    'match', 'metric' and 'combine' specify the method used to compare and score
    the motifs. By default the WIC score is used (metric='wic'), using the
    score over the whole alignment (match='total'), with the total motif score
    calculated as the mean score of all positions (combine='mean').
    'match' can be either 'total' for the total alignment or 'subtotal' for the
    maximum scoring subsequence of the alignment.
    'metric' can be any metric defined in MotifComparer, currently: 'pcc', 'ed',
    'distance', 'wic' or 'chisq'
    'combine' determines how the total score is calculated from the score of
    individual positions and can be either 'sum' or 'mean'

    'pval' can be True or False and determines if the score should be converted to
    an empirical p-value

    'threshold' determines the score (or p-value) cutoff

    If 'trim_edges' is set to True, all motif edges with an IC below
    'edge_ic_cutoff' will be removed before clustering

    When computing the average of two motifs 'include_bg' determines if, at a
    position only present in one motif, the information in that motif should
    be kept, or if it should be averaged with background frequencies. Should
    probably be left set to True.

    'progress' determines if progress information is written to stderr.

    Returns the MotifTree node created last, i.e. the root of the tree.
    """

    # Read the motifs from file unless a list of motifs was passed in
    if not isinstance(motifs, list):
        with open(motifs) as handle:
            motifs = read_motifs(handle, fmt="pwm")

    mc = MotifComparer()

    # Trim edges with low information content
    if trim_edges:
        for motif in motifs:
            motif.trim(edge_ic_cutoff)

    # Make a MotifTree node for every motif
    nodes = [MotifTree(m) for m in motifs]

    # Determine all pairwise scores and maxscore per motif
    scores = {}
    motif_nodes = dict((n.motif.id, n) for n in nodes)
    motifs = [n.motif for n in nodes]

    if progress:
        sys.stderr.write("Calculating initial scores\n")
    result = mc.get_all_scores(motifs, motifs, match, metric, combine, pval, parallel=True)

    for m1, other_motifs in result.items():
        for m2, score in other_motifs.items():
            if m1 == m2:
                # Score of a motif against itself
                if pval:
                    motif_nodes[m1].maxscore = 1 - score[0]
                else:
                    motif_nodes[m1].maxscore = score[0]
            else:
                if pval:
                    # Stored as 1 - p-value, so that higher means more similar
                    score = [1 - score[0]] + score[1:]
                scores[(motif_nodes[m1], motif_nodes[m2])] = score

    cluster_nodes = list(nodes)
    ave_count = 1

    total = len(cluster_nodes)

    while len(cluster_nodes) > 1:
        # Walk the pairs from the highest score down until both nodes of a
        # pair are still unmerged
        ranked = sorted(scores.keys(), key=lambda x: scores[x][0])
        i = -1
        (n1, n2) = ranked[i]
        while n1 not in cluster_nodes or n2 not in cluster_nodes:
            i -= 1
            (n1, n2) = ranked[i]

        (score, pos, orientation) = scores[(n1, n2)]
        ave_motif = n1.motif.average_motifs(n2.motif, pos, orientation, include_bg=include_bg)

        ave_motif.trim(edge_ic_cutoff)
        ave_motif.id = "Average_%s" % ave_count
        ave_count += 1

        new_node = MotifTree(ave_motif)
        self_score = mc.compare_motifs(new_node.motif, new_node.motif, match, metric, combine, pval)[0]
        if pval:
            new_node.maxscore = 1 - self_score
        else:
            new_node.maxscore = self_score

        new_node.mergescore = score
        n1.parent = new_node
        n2.parent = new_node
        new_node.left = n1
        new_node.right = n2

        # new_node is not in nodes yet, so this holds all other unmerged nodes
        cmp_nodes = dict((node.motif, node) for node in nodes if not node.parent)

        if progress:
            # Kept in its own variables: assigning the percentage to
            # 'progress' would overwrite the flag, and '//' keeps the bar
            # width an integer on Python 3 as well
            percent = int((1 - len(cmp_nodes) / float(total)) * 100)
            filled = percent // 10
            sys.stderr.write('\rClustering [{0}{1}] {2}%'.format(
                '#' * filled,
                " " * (10 - filled),
                percent))

        result = mc.get_all_scores(
                [new_node.motif],
                list(cmp_nodes),
                match,
                metric,
                combine,
                pval,
                parallel=True)

        for motif, n in cmp_nodes.items():
            x = result[new_node.motif.id][motif.id]
            if pval:
                x = [1 - x[0]] + x[1:]
            scores[(new_node, n)] = x

        nodes.append(new_node)

        cluster_nodes = [node for node in nodes if not node.parent]

    if progress:
        sys.stderr.write("\n")
    root = nodes[-1]
    for node in [node for node in nodes if not node.left]:
        node.parent.checkMerge(root, threshold)

    return root
Esempio n. 12
0
def cluster_motifs_with_report(infile, outfile, outdir, threshold, title=None):
    """Cluster the motifs in ``infile`` and write images, a report and a motif file.

    Images are written to ``<outdir>/images`` (created when missing), the
    HTML report to ``<outdir>/cluster_report.html`` and the clustered motifs
    to ``outfile``.

    Parameters
    ----------
    infile : str
        Motif file to cluster.
    outfile : str
        Path of the output file for the clustered motifs.
    outdir : str
        Directory that receives the images and the report.
    threshold : object
        Clustering threshold; converted with ``float()``.
    title : str, optional
        Passed to the report template as ``inputfile``; defaults to ``infile``.

    Returns
    -------
    list
        The ``(cluster motif, members)`` pairs, or an empty list when
        ``infile`` contains no motifs (nothing is written in that case).
    """
    if title is None:
        title = infile

    motifs = read_motifs(infile, fmt="pwm")

    trim_ic = 0.2
    if len(motifs) == 0:
        return []
    elif len(motifs) == 1:
        clusters = [[motifs[0], motifs]]
    else:
        logger.info("clustering %d motifs.", len(motifs))
        tree = cluster_motifs(infile,
                              "total",
                              "wic",
                              "mean",
                              True,
                              threshold=float(threshold),
                              include_bg=True,
                              progress=False)
        clusters = tree.getResult()

    ids = []
    mc = MotifComparer()

    img_dir = os.path.join(outdir, "images")

    if not os.path.exists(img_dir):
        os.mkdir(img_dir)

    for cluster, members in clusters:
        cluster.trim(trim_ic)
        png = "images/{}.png".format(cluster.id)
        cluster.to_img(os.path.join(outdir, png), fmt="PNG")
        ids.append([cluster.id, {"src": png}, []])
        if len(members) > 1:
            scores = {}
            for motif in members:
                scores[motif] = mc.compare_motifs(cluster,
                                                  motif,
                                                  "total",
                                                  "wic",
                                                  "mean",
                                                  pval=True)
            # Smallest position of all members; offsets are relative to it.
            add_pos = min(scores.values(), key=lambda x: x[1])[1]
            for motif in members:
                _, pos, strand = scores[motif]
                add = pos - add_pos

                if strand not in [1, "+"]:
                    # Draw the reverse complement under the member's own id.
                    rc = motif.rc()
                    rc.id = motif.id
                    motif = rc
                png = "images/{}.png".format(motif.id.replace(" ", "_"))
                motif.to_img(os.path.join(outdir, png),
                             fmt="PNG",
                             add_left=add)
        ids[-1][2] = [
            {"src": "images/{}.png".format(motif.id.replace(" ", "_")),
             "alt": motif.id.replace(" ", "_")}
            for motif in members
        ]

    config = MotifConfig()
    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader([config.get_template_dir()]))
    template = env.get_template("cluster_template.jinja.html")
    result = template.render(motifs=ids,
                             inputfile=title,
                             date=datetime.today().strftime("%d/%m/%Y"),
                             version=__version__)

    cluster_report = os.path.join(outdir, "cluster_report.html")
    with open(cluster_report, "wb") as f:
        f.write(result.encode('utf-8'))

    # The context manager closes the file even when writing a motif fails.
    with open(outfile, "w") as f:
        if len(clusters) == 1 and len(clusters[0][1]) == 1:
            f.write("%s\n" % clusters[0][0].to_pwm())
        else:
            for motif in tree.get_clustered_motifs():
                f.write("%s\n" % motif.to_pwm())

    logger.debug("Clustering done. See the result in %s", cluster_report)
    return clusters
Esempio n. 13
0
    def _cluster_motifs(self, pfm_file, cluster_pwm, dir, threshold):
        """Cluster the motifs in ``pfm_file`` and write images and a report.

        Images are written to ``self.imgdir``, the rendered HTML report to
        ``self.cluster_report`` and the clustered motifs to ``cluster_pwm``.

        Parameters
        ----------
        pfm_file : str
            Path of the motif file to cluster.
        cluster_pwm : str
            Path of the output file for the clustered motifs.
        dir : object
            Not used by this method.
        threshold : object
            Clustering threshold; converted with ``float()``.

        Returns
        -------
        list
            The ``(cluster motif, members)`` pairs.
        """
        self.logger.info("clustering significant motifs.")

        trim_ic = 0.2
        # Close the input file as soon as the motifs are parsed.
        with open(pfm_file) as handle:
            motifs = read_motifs(handle, fmt="pwm")
        if len(motifs) == 1:
            clusters = [[motifs[0], motifs]]
        else:
            tree = cluster_motifs(
                    pfm_file,
                    "total",
                    "wic",
                    "mean",
                    True,
                    threshold=float(threshold),
                    include_bg=True,
                    progress=False
                    )
            clusters = tree.getResult()

        ids = []
        mc = MotifComparer()

        for cluster, members in clusters:
            cluster.trim(trim_ic)
            cluster.to_img(os.path.join(self.imgdir, "%s.png" % cluster.id), format="PNG")
            ids.append([cluster.id, {"src": "images/%s.png" % cluster.id}, []])
            if len(members) > 1:
                scores = {}
                for motif in members:
                    scores[motif] = mc.compare_motifs(cluster, motif, "total", "wic", "mean", pval=True)
                # Smallest position of all members; min(key=...) replaces
                # the Python 2 only sorted(cmp=...) call.
                add_pos = min(scores.values(), key=lambda x: x[1])[1]
                for motif in members:
                    _, pos, strand = scores[motif]
                    add = pos - add_pos

                    if strand not in [1, "+"]:
                        # Draw the reverse complement under the member's id.
                        rc = motif.rc()
                        rc.id = motif.id
                        motif = rc
                    motif.to_img(os.path.join(self.imgdir, "%s.png" % motif.id.replace(" ", "_")), format="PNG", add_left=add)
            ids[-1][2] = [
                {"src": "images/%s.png" % motif.id.replace(" ", "_"), "alt": motif.id.replace(" ", "_")}
                for motif in members
            ]

        env = jinja2.Environment(loader=jinja2.FileSystemLoader([self.config.get_template_dir()]))
        template = env.get_template("cluster_template.jinja.html")
        result = template.render(expname=self.basename, motifs=ids, inputfile=self.inputfile, date=datetime.today().strftime("%d/%m/%Y"), version=GM_VERSION)

        # The report is encoded to bytes, so the file is opened in binary mode.
        with open(self.cluster_report, "wb") as f:
            f.write(result.encode('utf-8'))

        with open(cluster_pwm, "w") as f:
            if len(clusters) == 1 and len(clusters[0][1]) == 1:
                f.write("%s\n" % clusters[0][0].to_pwm())
            else:
                for motif in tree.get_clustered_motifs():
                    f.write("%s\n" % motif.to_pwm())

        self.logger.debug("Clustering done. See the result in %s",
                self.cluster_report)
        return clusters
Esempio n. 14
0
def cluster_motifs(
    motifs,
    match="total",
    metric="wic",
    combine="mean",
    pval=True,
    threshold=0.95,
    trim_edges=False,
    edge_ic_cutoff=0.2,
    include_bg=True,
    progress=True,
):
    """
    Clusters a set of sequence motifs. Required arg 'motifs' is a file containing
    positional frequency matrices or a list with motifs.

    Optional args:

    'match', 'metric' and 'combine' specify the method used to compare and score
    the motifs. By default the WIC score is used (metric='wic'), using the
    score over the whole alignment (match='total'), with the total motif score
    calculated as the mean score of all positions (combine='mean').
    'match' can be either 'total' for the total alignment or 'subtotal' for the
    maximum scoring subsequence of the alignment.
    'metric' can be any metric defined in MotifComparer, currently: 'pcc', 'ed',
    'distance', 'wic' or 'chisq'
    'combine' determines how the total score is calculated from the score of
    individual positions and can be either 'sum' or 'mean'

    'pval' can be True or False and determines if the score should be converted to
    an empirical p-value

    'threshold' determines the score (or p-value) cutoff

    If 'trim_edges' is set to True, all motif edges with an IC below
    'edge_ic_cutoff' will be removed before clustering

    When computing the average of two motifs 'include_bg' determines if, at a
    position only present in one motif, the information in that motif should
    be kept, or if it should be averaged with background frequencies. Should
    probably be left set to True.

    'progress' determines if progress information is written to stderr.

    Returns the MotifTree node created last, i.e. the root of the tree.
    """

    # Read the motifs from file unless a list of motifs was passed in
    if not isinstance(motifs, list):
        with open(motifs) as handle:
            motifs = read_motifs(handle, fmt="pwm")

    mc = MotifComparer()

    # Trim edges with low information content
    if trim_edges:
        for motif in motifs:
            motif.trim(edge_ic_cutoff)

    # Make a MotifTree node for every motif
    nodes = [MotifTree(m) for m in motifs]

    # Determine all pairwise scores and maxscore per motif
    scores = {}
    motif_nodes = {n.motif.id: n for n in nodes}
    motifs = [n.motif for n in nodes]

    if progress:
        sys.stderr.write("Calculating initial scores\n")
    result = mc.get_all_scores(motifs, motifs, match, metric, combine, pval, parallel=True)

    for m1, other_motifs in result.items():
        for m2, score in other_motifs.items():
            if m1 == m2:
                # Score of a motif against itself
                if pval:
                    motif_nodes[m1].maxscore = 1 - score[0]
                else:
                    motif_nodes[m1].maxscore = score[0]
            else:
                if pval:
                    # Stored as 1 - p-value, so that higher means more similar
                    score = [1 - score[0]] + score[1:]
                scores[(motif_nodes[m1], motif_nodes[m2])] = score

    cluster_nodes = list(nodes)
    ave_count = 1

    total = len(cluster_nodes)

    while len(cluster_nodes) > 1:
        # Walk the pairs from the highest score down until both nodes of a
        # pair are still unmerged
        ranked = sorted(scores.keys(), key=lambda x: scores[x][0])
        i = -1
        (n1, n2) = ranked[i]
        while n1 not in cluster_nodes or n2 not in cluster_nodes:
            i -= 1
            (n1, n2) = ranked[i]

        (score, pos, orientation) = scores[(n1, n2)]
        ave_motif = n1.motif.average_motifs(n2.motif, pos, orientation, include_bg=include_bg)

        ave_motif.trim(edge_ic_cutoff)
        ave_motif.id = "Average_%s" % ave_count
        ave_count += 1

        new_node = MotifTree(ave_motif)
        self_score = mc.compare_motifs(new_node.motif, new_node.motif, match, metric, combine, pval)[0]
        if pval:
            new_node.maxscore = 1 - self_score
        else:
            new_node.maxscore = self_score

        new_node.mergescore = score
        n1.parent = new_node
        n2.parent = new_node
        new_node.left = n1
        new_node.right = n2

        # new_node is not in nodes yet, so this holds all other unmerged nodes
        cmp_nodes = {node.motif: node for node in nodes if not node.parent}

        if progress:
            # Kept in its own variables: assigning the percentage to
            # 'progress' would overwrite the flag, and '//' keeps the bar
            # width an integer on Python 3 as well
            percent = int((1 - len(cmp_nodes) / float(total)) * 100)
            filled = percent // 10
            sys.stderr.write(
                "\rClustering [{0}{1}] {2}%".format(
                    "#" * filled, " " * (10 - filled), percent
                )
            )

        result = mc.get_all_scores([new_node.motif], list(cmp_nodes), match, metric, combine, pval, parallel=True)

        for motif, n in cmp_nodes.items():
            x = result[new_node.motif.id][motif.id]
            if pval:
                x = [1 - x[0]] + x[1:]
            scores[(new_node, n)] = x

        nodes.append(new_node)

        cluster_nodes = [node for node in nodes if not node.parent]

    if progress:
        sys.stderr.write("\n")
    root = nodes[-1]
    for node in [node for node in nodes if not node.left]:
        node.parent.checkMerge(root, threshold)

    return root
Esempio n. 15
0
def cluster_motifs_with_report(infile, outfile, outdir, threshold, title=None):
    """Cluster the motifs in ``infile`` and write an HTML report and a PWM file.

    The motifs are read from ``infile`` (PWM format). When there is more than
    one motif they are clustered with ``cluster_motifs`` using the "total"
    match, "wic" metric and "mean" combine settings. For every cluster a PNG
    of the (trimmed) cluster motif is written to ``<outdir>/images``; for
    clusters with more than one member a PNG of each member is written too,
    reverse-complemented and left-padded according to its comparison with the
    cluster motif.

    Files written:
      * ``<outdir>/images/*.png``        -- motif logos
      * ``<outdir>/cluster_report.html`` -- rendered "cluster_template.jinja.html"
      * ``outfile``                      -- the clustered motifs in PWM format

    Args:
        infile: Path of the PWM file holding the motifs to cluster.
        outfile: Path of the PWM file the clustered motifs are written to.
        outdir: Directory for the report and the ``images`` sub-directory.
            It must already exist; only ``images`` is created here.
        threshold: Clustering threshold; converted with ``float()`` before
            being passed to ``cluster_motifs``.
        title: Value passed to the template as ``inputfile``. Defaults to
            ``infile``.

    Returns:
        A list of ``[cluster_motif, members]`` pairs, or an empty list when
        ``infile`` contains no motifs (in which case nothing is written).
    """
    if title is None:
        title = infile

    motifs = read_motifs(infile, fmt="pwm")

    # Cut-off handed to ``trim()`` for every cluster motif before plotting.
    trim_ic = 0.2

    if len(motifs) == 0:
        return []
    elif len(motifs) == 1:
        # Nothing to cluster: the single motif is its own cluster.
        clusters = [[motifs[0], motifs]]
    else:
        logger.info("clustering %d motifs.", len(motifs))
        # NOTE(review): the fifth positional argument (True) is presumably the
        # p-value switch of cluster_motifs -- confirm against its signature.
        tree = cluster_motifs(
            infile,
            "total",
            "wic",
            "mean",
            True,
            threshold=float(threshold),
            include_bg=True,
            progress=False,
        )
        clusters = tree.getResult()

    ids = []
    mc = MotifComparer()

    img_dir = os.path.join(outdir, "images")
    if not os.path.exists(img_dir):
        os.mkdir(img_dir)

    for cluster, members in clusters:
        cluster.trim(trim_ic)
        cluster_png = "images/{}.png".format(cluster.id)
        cluster.to_img(os.path.join(outdir, cluster_png), fmt="PNG")
        ids.append([cluster.id, {"src": cluster_png}, []])

        if len(members) > 1:
            # Compare every member with the cluster motif; each result
            # unpacks to (score, position, strand).
            scores = {}
            for motif in members:
                scores[motif] = mc.compare_motifs(
                    cluster, motif, "total", "wic", "mean", pval=True
                )

            # The smallest position is the reference: every member is padded
            # on the left by its offset relative to it.
            min_pos = min(result[1] for result in scores.values())

            for motif in members:
                _, pos, strand = scores[motif]
                add = pos - min_pos

                if strand not in (1, "+"):
                    # Plot the reverse complement, but keep the original id
                    # so that the image file name stays the same.
                    rc = motif.rc()
                    rc.id = motif.id
                    motif = rc

                member_png = "images/{}.png".format(motif.id.replace(" ", "_"))
                motif.to_img(
                    os.path.join(outdir, member_png), fmt="PNG", add_left=add
                )

        # Member images are referenced for every cluster, including
        # single-member ones for which no member image is rendered above.
        # NOTE(review): for a single-member cluster this relies on the cluster
        # image having the same file name as the member image -- confirm.
        member_entries = []
        for motif in members:
            safe_id = motif.id.replace(" ", "_")
            member_entries.append(
                {"src": "images/{}.png".format(safe_id), "alt": safe_id}
            )
        ids[-1][2] = member_entries

    config = MotifConfig()
    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader([config.get_template_dir()])
    )
    template = env.get_template("cluster_template.jinja.html")
    result = template.render(
        motifs=ids,
        inputfile=title,
        date=datetime.today().strftime("%d/%m/%Y"),
        version=__version__,
    )

    cluster_report = os.path.join(outdir, "cluster_report.html")
    with open(cluster_report, "wb") as f:
        f.write(result.encode("utf-8"))

    # Use a context manager so the file is closed even if to_pwm() raises.
    with open(outfile, "w") as f:
        if len(clusters) == 1 and len(clusters[0][1]) == 1:
            # Single motif: ``tree`` may not exist, write the motif directly.
            f.write("%s\n" % clusters[0][0].to_pwm())
        else:
            for motif in tree.get_clustered_motifs():
                f.write("%s\n" % motif.to_pwm())

    logger.debug("Clustering done. See the result in %s", cluster_report)
    return clusters