Esempio n. 1
0
    def _test_branch_prior_predup(self):
        """Sample the branch prior of the pre-duplication fly gene tree.

        Reads the gene tree, species tree, gene-to-species map and parameter
        file from test/data, reconciles the gene tree, and calls
        spidir.branch_prior 30 times with the last positional flag True and
        30 times with it False.  A single tab-separated line (tree id, then
        mean and sdev of each batch) is written to stderr.  Nothing is
        asserted.
        """

        prep_dir("test/output/branch_prior_predup")
        out = sys.stderr
        treeid = "predup"

        tree = read_tree("test/data/flies.predup.tree")
        drawTree(tree)

        stree = read_tree("test/data/flies.stree")
        gene2species = phylo.read_gene2species("test/data/flies.smap")
        params = spidir.read_params("test/data/flies.param")
        birth, death = 0.4, 0.39
        pretime = 1.0
        nsamples = 100

        recon = phylo.reconcile(tree, stree, gene2species)
        events = phylo.label_events(tree, recon)

        def run_batch(flag):
            # NOTE(review): `flag` is the last positional argument of
            # branch_prior; presumably the approx switch -- confirm.
            values = []
            for _ in xrange(30):
                values.append(
                    spidir.branch_prior(tree, stree, recon, events, params,
                                        birth, death, pretime, nsamples,
                                        flag))
            return values

        p = run_batch(True)
        p2 = run_batch(False)

        fields = [treeid, mean(p), sdev(p), mean(p2), sdev(p2)]
        print >> out, "\t".join([str(field) for field in fields])
Esempio n. 2
0
    def _test_branch_prior_predup(self):
        """Print branch-prior statistics for the "predup" fly gene tree.

        The gene tree is reconciled against the fly species tree and
        spidir.branch_prior is evaluated in two batches of 30 calls that
        differ only in the final positional flag (True, then False).  The
        tree id followed by mean/sdev of each batch is written to stderr as
        one tab-separated line.  The method makes no assertions.
        """

        prep_dir("test/output/branch_prior_predup")
        out = sys.stderr
        treeid = "predup"

        tree = read_tree("test/data/flies.predup.tree")
        drawTree(tree)

        stree = read_tree("test/data/flies.stree")
        gene2species = phylo.read_gene2species("test/data/flies.smap")
        params = spidir.read_params("test/data/flies.param")
        birth = .4
        death = .39
        pretime = 1.0
        nsamples = 100

        recon = phylo.reconcile(tree, stree, gene2species)
        events = phylo.label_events(tree, recon)

        # first batch: final flag True
        p = []
        for _ in xrange(30):
            p.append(spidir.branch_prior(tree, stree, recon, events, params,
                                         birth, death, pretime, nsamples,
                                         True))

        # second batch: final flag False
        p2 = []
        for _ in xrange(30):
            p2.append(spidir.branch_prior(tree, stree, recon, events, params,
                                          birth, death, pretime, nsamples,
                                          False))

        row = [treeid, mean(p), sdev(p), mean(p2), sdev(p2)]
        line = "\t".join(map(str, row))
        print >> out, line
Esempio n. 3
0
def calc_birth_death_prior(tree, stree, recon, birth, death,
                           events=None):
    """Returns the topology prior of a gene tree

    tree   -- gene tree
    stree  -- species tree
    recon  -- reconciliation, passed to make_recon_array (and to
              phylo.label_events when events is None)
    birth  -- birth rate handed to calcDoomTable / birthDeathTreePriorFull
    death  -- death rate handed to calcDoomTable / birthDeathTreePriorFull
    events -- event labels; computed with phylo.label_events when None

    Returns whatever birthDeathTreePriorFull returns.  The C tree structures
    built with tree2ctree are released even when an intermediate step raises.
    """

    from rasmus.bio import phylo

    if events is None:
        events = phylo.label_events(tree, recon)

    # only the gene node order and the species node lookup are used below
    _ptree, nodes, _nodelookup = make_ptree(tree)
    _pstree, _snodes, snodelookup = make_ptree(stree)

    ctree = tree2ctree(tree)
    try:
        cstree = tree2ctree(stree)
        try:
            recon2 = make_recon_array(tree, recon, nodes, snodelookup)
            events2 = make_events_array(nodes, events)

            # one slot per species node, filled in by calcDoomTable
            doomtable = c_list(c_double, [0] * len(stree.nodes))
            calcDoomTable(cstree, birth, death, doomtable)

            p = birthDeathTreePriorFull(ctree, cstree,
                                        c_list(c_int, recon2),
                                        c_list(c_int, events2),
                                        birth, death, doomtable)
        finally:
            deleteTree(cstree)
    finally:
        deleteTree(ctree)

    return p
Esempio n. 4
0
    def _test_branch_prior_samples(self):
        """Test branch prior"""

        prep_dir("test/output/branch_prior")

        treeids = os.listdir("test/data/flies")
        treeids = ["3"]

        for treeid in treeids:

            tree = read_tree("test/data/flies-duploss/%s/%s.tree" % (treeid, treeid))

            print treeid
            draw_tree(tree)

            stree = read_tree("test/data/flies.stree")
            gene2species = phylo.read_gene2species("test/data/flies.smap")
            params = spidir.read_params("test/data/flies.param")
            birth = 0.0012
            death = 0.0013
            pretime = 1.0
            nsamples = 100

            recon = phylo.reconcile(tree, stree, gene2species)
            events = phylo.label_events(tree, recon)

            p = [
                spidir.branch_prior(tree, stree, recon, events, params, birth, death, nsamples=nsamples, approx=True)
                for i in xrange(30)
            ]

            # row = [treeid,
            #       mean(p), exc_default(lambda: sdev(p), INF)]
            print treeid, p
Esempio n. 5
0
def calc_birth_death_prior(tree, stree, recon, birth, death, events=None):
    """Returns the topology prior of a gene tree

    Arguments:
    tree   -- gene tree
    stree  -- species tree
    recon  -- reconciliation used to build the recon array (and to label
              events when `events` is None)
    birth  -- birth rate passed to the C routines
    death  -- death rate passed to the C routines
    events -- event labels; derived via phylo.label_events when None

    The value returned by birthDeathTreePriorFull is passed through
    unchanged.  Both C trees created by tree2ctree are freed with deleteTree
    on every exit path, including exceptions.
    """

    from rasmus.bio import phylo

    if events is None:
        events = phylo.label_events(tree, recon)

    # make_ptree returns a triple; only these two pieces are needed here
    _ptree, nodes, _nodelookup = make_ptree(tree)
    _pstree, _snodes, snodelookup = make_ptree(stree)

    ctree = tree2ctree(tree)
    try:
        cstree = tree2ctree(stree)
        try:
            recon2 = make_recon_array(tree, recon, nodes, snodelookup)
            events2 = make_events_array(nodes, events)

            # buffer with one double per species node for calcDoomTable
            doomtable = c_list(c_double, [0] * len(stree.nodes))
            calcDoomTable(cstree, birth, death, doomtable)

            return birthDeathTreePriorFull(ctree, cstree,
                                           c_list(c_int, recon2),
                                           c_list(c_int, events2),
                                           birth, death, doomtable)
        finally:
            deleteTree(cstree)
    finally:
        deleteTree(ctree)
Esempio n. 6
0
    def test_branch_prior_simple1(self):
        """Test branch prior"""

        tree = treelib.parse_newick("((a1:1, b1:1):2, c1:3);")
        stree = treelib.parse_newick("((A:2, B:2):1, C:3);")

        gene2species = lambda x: x[0].upper()

        params = {
            "A": (1.0, 1.0),
            "B": (3.0, 3.0),
            "C": (4, 3.5),
            2: (2.0, 2.0),
            1: (1.0, 1.0),
            "baserate": (11.0, 10.0),
        }
        birth = 0.01
        death = 0.02
        pretime = 1.0
        nsamples = 1

        recon = phylo.reconcile(tree, stree, gene2species)
        events = phylo.label_events(tree, recon)
        # pd(mapdict(recon, key=lambda x: x.name, val=lambda x: x.name))
        # pd(mapdict(events, key=lambda x: x.name))

        p = spidir.branch_prior(
            tree, stree, recon, events, params, birth, death, nsamples=nsamples, approx=False, generate=1
        )

        tot = 0.0
        gs = list(frange(0.0001, 4, 0.01))
        gs = list(frange(1, 1.01, 0.01))
        for g in gs:
            pg = invgammaPdf(g, params["baserate"])
            pa = gammaPdf(tree.nodes["a1"].dist, [params["A"][0], params["A"][1] / (g * stree.nodes["A"].dist)])

            pb = gammaPdf(tree.nodes["b1"].dist, [params["B"][0], params["B"][1] / (g * stree.nodes["B"].dist)])

            pc = spidir.gammaSumPdf(
                tree.nodes["c1"].dist + tree.nodes[2].dist,
                2,
                [params["C"][0], params[2][0]],
                [params["C"][1] / (g * stree.nodes["C"].dist), params[2][1] / (g * stree.nodes[2].dist)],
                0.001,
            )

            print g, pg, pa, pb, pc
            tot += pg * pa * pb * pc
        tot /= len(gs)

        print (
            tree.nodes["c1"].dist + tree.nodes[2].dist,
            [params["C"][0], params[2][0]],
            [params["C"][1], params[2][1]],
        )

        print "C", p
        print "P", log(tot)
Esempio n. 7
0
    def test_all_terms(self):
        """Write the joint-probability terms of up to 100 fly trees to a file.

        For each of the first 100 entries returned by
        os.listdir("test/data/flies") the tree and alignment are read, the
        tree is reconciled against the normalised species tree, and
        spidir.calc_joint_prob(..., terms=True) is called.  The three
        returned terms and their sum are written to
        test/output/all_terms/flies.txt.  Nothing is asserted.
        """

        prep_dir("test/output/all_terms")
        out = open("test/output/all_terms/flies.txt", "w")

        # try/finally: the report file is closed even if a tree fails
        try:
            treeids = os.listdir("test/data/flies")[:100]

            for treeid in treeids:

                tree = read_tree("test/data/flies/%s/%s.nt.tree" %
                                 (treeid, treeid))
                align = read_fasta("test/data/flies/%s/%s.nt.align" %
                                   (treeid, treeid))

                print >> out, treeid
                draw_tree(tree, out=out)

                # inputs are re-read per tree on purpose: whether
                # calc_joint_prob mutates them is not visible from here
                stree = read_tree("test/data/flies.norm.stree")
                gene2species = phylo.read_gene2species("test/data/flies.smap")
                params = spidir.read_params("test/data/flies.nt.param")
                birth = .4
                death = .39
                pretime = 1.0
                maxdoom = 20
                bgfreq = [.258, .267, .266, .209]
                kappa = 1.59

                recon = phylo.reconcile(tree, stree, gene2species)
                events = phylo.label_events(tree, recon)

                branchp, topp, seqlk = spidir.calc_joint_prob(align,
                                                              tree,
                                                              stree,
                                                              recon,
                                                              events,
                                                              params,
                                                              birth,
                                                              death,
                                                              pretime,
                                                              bgfreq,
                                                              kappa,
                                                              maxdoom=maxdoom,
                                                              terms=True)
                joint = topp + branchp + seqlk

                print >> out, "topp   ", topp
                print >> out, "branchp", branchp
                print >> out, "seqlk  ", seqlk
                print >> out, "joint  ", joint
        finally:
            out.close()
Esempio n. 8
0
 def evalUserTree(tree):
     """Score a user-supplied tree and record it in `visited`.

     Sets the tree's distances, computes its log likelihood, and stores
     [logl, copy of tree, visit count] under the tree's hash.  When
     DEBUG_LOW is active the reconciled tree is also drawn.
     """
     setTreeDistances(conf, tree, distmat, labels)
     logl = treeLogLikelihood(conf, tree, stree, gene2species, params)

     # bump the visit counter if this topology was already seen
     key = phylo.hash_tree(tree)
     seen = 0
     if key in visited:
         _old_logl, _old_tree, seen = visited[key]
     visited[key] = [logl, tree.copy(), seen + 1]

     if not isDebug(DEBUG_LOW):
         return

     debug("\nuser given tree:")
     recon = phylo.reconcile(tree, stree, gene2species)
     events = phylo.label_events(tree, recon)
     drawTreeLogl(tree, events=events)
Esempio n. 9
0
    def test_all_terms(self):
        """Dump topology, branch and sequence terms for up to 100 fly trees.

        Each tree id (first 100 entries of os.listdir("test/data/flies")) is
        loaded with its alignment, reconciled against the normalised species
        tree, and scored with spidir.calc_joint_prob(..., terms=True).  The
        three terms and their sum are written to
        test/output/all_terms/flies.txt.  The method makes no assertions.
        """

        prep_dir("test/output/all_terms")
        out = open("test/output/all_terms/flies.txt", "w")

        # guarantee the output file is closed when any step below raises
        try:
            treeids = os.listdir("test/data/flies")[:100]

            for treeid in treeids:

                tree = read_tree("test/data/flies/%s/%s.nt.tree" % (treeid, treeid))
                align = read_fasta("test/data/flies/%s/%s.nt.align" % (treeid, treeid))

                print >>out, treeid
                draw_tree(tree, out=out)

                # re-read for every tree: it is not visible here whether
                # calc_joint_prob modifies these objects
                stree = read_tree("test/data/flies.norm.stree")
                gene2species = phylo.read_gene2species("test/data/flies.smap")
                params = spidir.read_params("test/data/flies.nt.param")
                birth = .4
                death = .39
                pretime = 1.0
                maxdoom = 20
                bgfreq = [.258, .267, .266, .209]
                kappa = 1.59

                recon = phylo.reconcile(tree, stree, gene2species)
                events = phylo.label_events(tree, recon)

                branchp, topp, seqlk = spidir.calc_joint_prob(
                    align, tree, stree, recon, events, params,
                    birth, death, pretime,
                    bgfreq, kappa, maxdoom=maxdoom, terms=True)
                joint = topp + branchp + seqlk

                print >>out, "topp   ", topp
                print >>out, "branchp", branchp
                print >>out, "seqlk  ", seqlk
                print >>out, "joint  ", joint
        finally:
            out.close()
Esempio n. 10
0
    def _test_branch_prior_samples(self):
        """Test branch prior"""

        prep_dir("test/output/branch_prior")

        treeids = os.listdir("test/data/flies")
        treeids = ["3"]

        for treeid in treeids:

            tree = read_tree("test/data/flies-duploss/%s/%s.tree" %
                             (treeid, treeid))

            print treeid
            draw_tree(tree)

            stree = read_tree("test/data/flies.stree")
            gene2species = phylo.read_gene2species("test/data/flies.smap")
            params = spidir.read_params("test/data/flies.param")
            birth = .0012
            death = .0013
            pretime = 1.0
            nsamples = 100

            recon = phylo.reconcile(tree, stree, gene2species)
            events = phylo.label_events(tree, recon)

            p = [
                spidir.branch_prior(tree,
                                    stree,
                                    recon,
                                    events,
                                    params,
                                    birth,
                                    death,
                                    nsamples=nsamples,
                                    approx=True) for i in xrange(30)
            ]

            #row = [treeid,
            #       mean(p), exc_default(lambda: sdev(p), INF)]
            print treeid, p
Esempio n. 11
0
    def _test_branch_prior_approx(self):
        """Test branch prior"""

        prep_dir("test/output/branch_prior")
        out = open("test/output/branch_prior/flies.approx.txt", "w")
        out = sys.stderr

        treeids = os.listdir("test/data/flies")

        for treeid in treeids:

            tree = read_tree("test/data/flies-duploss/%s/%s.nt.tree" % (treeid, treeid))

            print treeid
            draw_tree(tree)

            stree = read_tree("test/data/flies.stree")
            gene2species = phylo.read_gene2species("test/data/flies.smap")
            params = spidir.read_params("test/data/flies.param")
            birth = 0.0012
            death = 0.0013
            pretime = 1.0
            nsamples = 100

            recon = phylo.reconcile(tree, stree, gene2species)
            events = phylo.label_events(tree, recon)
            p = [
                spidir.branch_prior(tree, stree, recon, events, params, birth, death, nsamples=nsamples, approx=False)
                for i in xrange(30)
            ]
            p2 = [
                spidir.branch_prior(tree, stree, recon, events, params, birth, death, nsamples=nsamples, approx=True)
                for i in xrange(30)
            ]

            row = [treeid, mean(p), exc_default(lambda: sdev(p), INF), mean(p2), exc_default(lambda: sdev(p2), INF)]

            print >> out, "\t".join(map(str, row))
            self.assert_(INF not in row and -INF not in row)

        out.close()
Esempio n. 12
0
def treeLogLikelihood_python(conf, tree, stree, gene2species, params,
                             baserate=None, integration="fastsampling"):
    """Compute the log likelihood of a gene tree in pure Python.

    conf         -- configuration mapping; reads "famprob" (required) and
                    "errorcost" (optional, default 0.0)
    tree         -- gene tree; its data dict is cleared of previous results
                    and receives "eventlogl", "errorlogl", "baserate", "logl"
    stree        -- species tree
    gene2species -- gene-to-species mapping used for reconciliation
    params       -- parameter mapping; NOTE: the entry for the species root
                    is overwritten with [0, 0] in place
    baserate     -- base rate; estimated with getBaserate when None
    integration  -- integration method name forwarded to subtreeLikelihood

    Returns the total log likelihood (also stored in tree.data["logl"]).
    """

    # debug info
    if isDebug(DEBUG_MED):
        util.tic("find logl")

    # derive relative branch lengths
    tree.clear_data("logl", "extra", "fracs", "params", "unfold")
    recon = phylo.reconcile(tree, stree, gene2species)
    events = phylo.label_events(tree, recon)

    # determine if top branch unfolds
    if recon[tree.root] == stree.root and \
       events[tree.root] == "dup":
        for child in tree.root.children:
            if recon[child] != stree.root:
                child.data["unfold"] = True

    # identity test: None is a singleton and must not go through __eq__
    if baserate is None:
        baserate = getBaserate(tree, stree, params, recon=recon)

    phylo.midroot_recon(tree, stree, recon, events, params, baserate)

    # top branch is "free"
    params[stree.root.name] = [0, 0]
    # mutable holder so the nested walk() can accumulate into it
    this = util.Bundle(logl=0.0)

    # recurse through indep sub-trees
    def walk(node):
        if events[node] == "spec" or \
           node == tree.root:
            this.logl += subtreeLikelihood(conf, node, recon, events,
                                           stree, params, baserate,
                                           integration=integration)
        node.recurse(walk)
    walk(tree.root)

    # calc probability of rare events
    tree.data["eventlogl"] = rareEventsLikelihood(conf, tree, stree, recon, events)
    this.logl += tree.data["eventlogl"]

    # calc penality of error
    tree.data["errorlogl"] = tree.data.get("error", 0.0) * \
                             conf.get("errorcost", 0.0)
    this.logl += tree.data["errorlogl"]

    # family rate likelihood
    if conf["famprob"]:
        this.logl += log(stats.gammaPdf(baserate, params["baserate"]))

    tree.data["baserate"] = baserate
    tree.data["logl"] = this.logl

    if isDebug(DEBUG_MED):
        util.toc()
        debug("\n\n")
        drawTreeLogl(tree, events=events)

    return this.logl
Esempio n. 13
0
def spidir(conf, distmat, labels, stree, gene2species, params):
    """Main function for the SPIDIR algorithm"""
    
    setDebug(conf["debug"])

    if isDebug(DEBUG_HIGH) and pyspidir:
        pyspidir.set_log(3, "")
        
    
    if "out" in conf:
        # create debug table
        conf["debugtab_file"] = file(conf["out"] + ".debug.tab", "w")
        
        debugtab = tablelib.Table(headers=["correct",
                                           "logl", "treelen", "baserate", 
                                           "error", "errorlogl", 
                                           "eventlogl", "tree",
                                           "topology", "species_hash"],
                                  types={"correct": bool,
                                         "logl": float, 
                                         "treelen": float, 
                                         "baserate": float, 
                                         "error": float, 
                                         "errorlogl": float,
                                         "eventlogl": float, 
                                         "tree": str,
                                         "topology": str,
                                         "species_hash": str})
        debugtab.writeHeader(conf["debugtab_file"])
        conf["debugtab"] = debugtab
    else:
        conf["debugfile"] = None
    
    
    trees = []
    logls = []
    tree = None
    visited = {}
    
    util.tic("SPIDIR")
    
    # do auto searches
    for search in conf["search"]:
        util.tic("Search by %s" % search)
        
        if search == "greedy":
            tree, logl = Search.searchGreedy(conf, distmat, labels, stree, 
                                      gene2species, params,
                                      visited=visited)
            
        elif search == "mcmc":
            tree, logl = Search.searchMCMC(conf, distmat, labels, stree, 
                                    gene2species, params, initTree=tree,
                                    visited=visited)
                                    
        elif search == "regraft":
            tree, logl = Search.searchRegraft(conf, distmat, labels, stree, 
                                    gene2species, params, initTree=tree,
                                    visited=visited, proposeFunc=Search.proposeTree3)
                                    
        elif search == "exhaustive":
            if tree == None:
                tree = phylo.neighborjoin(distmat, labels)
                tree = phylo.recon_root(tree, stree, gene2species)
            
            tree, logl = Search.searchExhaustive(conf, distmat, labels, tree, stree, 
                                          gene2species, params, 
                                          depth=conf["depth"],
                                          visited=visited)
        elif search == "hillclimb":
            tree, logl = Search.searchHillClimb(conf, distmat, labels, stree, 
                                         gene2species, params, initTree=tree,
                                         visited=visited)
        
        elif search == "none":
            break
        else:
            raise SindirError("unknown search '%s'" % search)
        
        util.toc()
        
        Search.printMCMC(conf, "N/A", tree, stree, gene2species, visited)
        
        printVisitedTrees(visited)
        

    def evalUserTree(tree):        
        setTreeDistances(conf, tree, distmat, labels)
        logl = treeLogLikelihood(conf, tree, stree, gene2species, params)
        
        thash = phylo.hash_tree(tree)
        if thash in visited:
            a, b, count = visited[thash]
        else:
            count = 0
        visited[thash] = [logl, tree.copy(), count+1]
        
        if isDebug(DEBUG_LOW):
            debug("\nuser given tree:")
            recon = phylo.reconcile(tree, stree, gene2species)
            events = phylo.label_events(tree, recon)
            drawTreeLogl(tree, events=events)        
    
    # eval the user given trees
    for treefile in conf["tree"]:
        tree = treelib.read_tree(treefile)
        evalUserTree(tree)
    
    for topfile in conf["tops"]:
        infile = file(topfile)
        strees = []
        
        while True:
            try:
                strees.append(treelib.read_tree(infile))
            except:
                break
        
        print len(strees)
        
        for top in strees:
            tree = phylo.stree2gtree(top, labels, gene2species)
            evalUserTree(tree)    
    
    if len(conf["tops"]) > 0:
        printVisitedTrees(visited)    
    
    
    
    # eval correcttree for debug only
    if "correcttree" in conf:
        tree = conf["correcttree"]
        setTreeDistances(conf, tree, distmat, labels)
        logl = treeLogLikelihood(conf, tree, stree, gene2species, params)
        
        if isDebug(DEBUG_LOW):
            debug("\ncorrect tree:")
            recon = phylo.reconcile(tree, stree, gene2species)
            events = phylo.label_events(tree, recon)
            drawTreeLogl(tree, events=events)
    
    
    util.toc()
    
    if len(visited) == 0:
        raise SindirError("No search or tree topologies given")
    
    
    if "correcthash" in conf:
        if conf["correcthash"] in visited:
            debug("SEARCH: visited correct tree")
        else:
            debug("SEARCH: NEVER saw correct tree")

    
    # return ML tree
    trees = [x[1] for x in visited.itervalues()]
    i = util.argmax([x.data["logl"] for x in trees])
    return trees[i], trees[i].data["logl"]
Esempio n. 14
0
def treeLogLikelihood(conf, tree, stree, gene2species, params, baserate=None):
    """Compute the log likelihood of a gene tree, using the C code if present.

    Falls back to Likelihood.treeLogLikelihood_python when the pyspidir
    module is unavailable or conf["python_only"] is set.  Otherwise the tree
    is reconciled, mid-rooted, and scored by treeLikelihood_C.

    conf         -- configuration mapping; "bestlogl" is defaulted to
                    -util.INF; reads "python_only", "generate_int" and
                    "errorcost" (all optional)
    tree         -- gene tree; tree.data receives "eventlogl", "errorlogl",
                    "baserate" and "logl"
    stree        -- species tree
    gene2species -- gene-to-species mapping used for reconciliation
    params       -- parameter mapping; NOTE: the entry for the species root
                    is overwritten with [0, 0] in place
    baserate     -- base rate; estimated with Likelihood.getBaserate when
                    None (ignored while conf["generate_int"] is set)

    Returns the log likelihood (also stored in tree.data["logl"]).
    """
    # sentinel passed to the C code: integrate over gene rates
    integrate_rate = -99.0

    conf.setdefault("bestlogl", -util.INF)

    if pyspidir is None or conf.get("python_only", False):
        return Likelihood.treeLogLikelihood_python(conf, tree, stree,
                                                   gene2species, params,
                                                   baserate=baserate,
                                                   integration="fastsampling")

    # debug info
    if isDebug(DEBUG_MED):
        util.tic("find logl")

    # derive relative branch lengths
    #tree.clearData("logl", "extra", "fracs", "params", "unfold")
    recon = phylo.reconcile(tree, stree, gene2species)
    events = phylo.label_events(tree, recon)

    # determine if top branch unfolds
    if recon[tree.root] == stree.root and \
       events[tree.root] == "dup":
        for child in tree.root.children:
            if recon[child] != stree.root:
                child.data["unfold"] = True

    # top branch is "free"
    params[stree.root.name] = [0, 0]
    this = util.Bundle(logl=0.0)

    if conf.get("generate_int", False):
        baserate = integrate_rate
    elif baserate is None:
        baserate = Likelihood.getBaserate(tree, stree, params, recon=recon)

    phylo.midroot_recon(tree, stree, recon, events, params, baserate)

    # calc likelihood in C
    this.logl = treeLikelihood_C(conf, tree, recon, events, stree, params,
                                 baserate, gene2species)

    # calc probability of rare events
    # NOTE(review): unlike the errorlogl term below, eventlogl is stored but
    # never added to this.logl -- confirm whether that is intended.
    tree.data["eventlogl"] = Likelihood.rareEventsLikelihood(conf, tree, stree, recon, events)

    # calc penality of error
    tree.data["errorlogl"] = tree.data.get("error", 0.0) * \
                             conf.get("errorcost", 0.0)
    this.logl += tree.data["errorlogl"]

    # add logl of sequence evolution
    this.logl += tree.data.get("distlogl", 0.0)

    # after integrating, report an estimated base rate instead of the sentinel
    if baserate == integrate_rate:
        baserate = Likelihood.getBaserate(tree, stree, params, recon=recon)

    tree.data["baserate"] = baserate
    tree.data["logl"] = this.logl

    if isDebug(DEBUG_MED):
        util.toc()
        debug("\n\n")
        drawTreeLogl(tree, events=events)

    return this.logl
Esempio n. 15
0
    def test_branch_prior_simple1(self):
        """Test branch prior"""

        tree = treelib.parse_newick("((a1:1, b1:1):2, c1:3);")
        stree = treelib.parse_newick("((A:2, B:2):1, C:3);")

        gene2species = lambda x: x[0].upper()

        params = {
            "A": (1.0, 1.0),
            "B": (3.0, 3.0),
            "C": (4, 3.5),
            2: (2.0, 2.0),
            1: (1.0, 1.0),
            "baserate": (11.0, 10.0)
        }
        birth = .01
        death = .02
        pretime = 1.0
        nsamples = 1

        recon = phylo.reconcile(tree, stree, gene2species)
        events = phylo.label_events(tree, recon)
        #pd(mapdict(recon, key=lambda x: x.name, val=lambda x: x.name))
        #pd(mapdict(events, key=lambda x: x.name))

        p = spidir.branch_prior(tree,
                                stree,
                                recon,
                                events,
                                params,
                                birth,
                                death,
                                nsamples=nsamples,
                                approx=False,
                                generate=1)

        tot = 0.0
        gs = list(frange(.0001, 4, .01))
        gs = list(frange(1, 1.01, .01))
        for g in gs:
            pg = invgammaPdf(g, params["baserate"])
            pa = gammaPdf(
                tree.nodes["a1"].dist,
                [params["A"][0], params["A"][1] / (g * stree.nodes["A"].dist)])

            pb = gammaPdf(
                tree.nodes["b1"].dist,
                [params["B"][0], params["B"][1] / (g * stree.nodes["B"].dist)])

            pc = spidir.gammaSumPdf(
                tree.nodes["c1"].dist + tree.nodes[2].dist, 2,
                [params["C"][0], params[2][0]], [
                    params["C"][1] / (g * stree.nodes["C"].dist),
                    params[2][1] / (g * stree.nodes[2].dist)
                ], .001)

            print g, pg, pa, pb, pc
            tot += pg * pa * pb * pc
        tot /= len(gs)

        print(tree.nodes["c1"].dist + tree.nodes[2].dist,
              [params["C"][0], params[2][0]], [params["C"][1], params[2][1]])

        print "C", p
        print "P", log(tot)
Esempio n. 16
0
    def _test_branch_prior_approx(self):
        """Test branch prior"""

        prep_dir("test/output/branch_prior")
        out = open("test/output/branch_prior/flies.approx.txt", "w")
        out = sys.stderr

        treeids = os.listdir("test/data/flies")

        for treeid in treeids:

            tree = read_tree("test/data/flies-duploss/%s/%s.nt.tree" %
                             (treeid, treeid))

            print treeid
            draw_tree(tree)

            stree = read_tree("test/data/flies.stree")
            gene2species = phylo.read_gene2species("test/data/flies.smap")
            params = spidir.read_params("test/data/flies.param")
            birth = .0012
            death = .0013
            pretime = 1.0
            nsamples = 100

            recon = phylo.reconcile(tree, stree, gene2species)
            events = phylo.label_events(tree, recon)
            p = [
                spidir.branch_prior(tree,
                                    stree,
                                    recon,
                                    events,
                                    params,
                                    birth,
                                    death,
                                    nsamples=nsamples,
                                    approx=False) for i in xrange(30)
            ]
            p2 = [
                spidir.branch_prior(tree,
                                    stree,
                                    recon,
                                    events,
                                    params,
                                    birth,
                                    death,
                                    nsamples=nsamples,
                                    approx=True) for i in xrange(30)
            ]

            row = [
                treeid,
                mean(p),
                exc_default(lambda: sdev(p), INF),
                mean(p2),
                exc_default(lambda: sdev(p2), INF)
            ]

            print >> out, "\t".join(map(str, row))
            self.assert_(INF not in row and -INF not in row)

        out.close()
Esempio n. 17
0
    def test_branch_prior_simple2(self):
        """Test branch prior 2"""

        tree = treelib.parse_newick("((a1:2, a2:3):.4, b1:2);")
        stree = treelib.parse_newick("(A:2, B:2);")

        gene2species = lambda x: x[0].upper()

        params = {
            "A": (1.0, 1.0),
            "B": (3.0, 3.0),
            1: (1.0, 1.0),
            "baserate": (11.0, 10.0)
        }
        birth = .01
        death = .02
        pretime = 1.0
        nsamples = 100

        recon = phylo.reconcile(tree, stree, gene2species)
        events = phylo.label_events(tree, recon)
        #pd(mapdict(recon, key=lambda x: x.name, val=lambda x: x.name))
        #pd(mapdict(events, key=lambda x: x.name))

        p = spidir.branch_prior(tree,
                                stree,
                                recon,
                                events,
                                params,
                                birth,
                                death,
                                nsamples=nsamples,
                                approx=False)

        tot = 0.0

        gstart = 0.01
        gend = 3.0
        step = (gend - gstart) / 20.0
        s2 = step / 2.0
        gs = list(frange(gstart + s2, gend + s2, step))
        for g in gs:
            pg = invgammaPdf(g, params["baserate"])

            pa = 0.0

            for i in range(nsamples):

                t = birthdeath.sample_birth_wait_time(1, stree.nodes["A"].dist,
                                                      birth, death)
                #print t

                t2 = stree.nodes["A"].dist - t

                pa1 = gammaPdf(tree.nodes["a1"].dist,
                               [params["A"][0], params["A"][1] / (g * t2)])

                pa2 = gammaPdf(tree.nodes["a2"].dist,
                               [params["A"][0], params["A"][1] / (g * t2)])

                pb = spidir.gammaSumPdf(
                    tree.nodes["b1"].dist + tree.nodes[2].dist, 2,
                    [params["B"][0], params["A"][0]], [
                        params["B"][1] /
                        (g * stree.nodes["B"].dist), params["A"][1] / (g * t)
                    ], .001)

                if "nan" not in map(str, [pa1, pa2, pb]):
                    pa += pa1 * pa2 * pb / nsamples

            tot += pg * pa * step
        #tot /= len(gs)

        print "unfold", (tree.nodes["b1"].dist + tree.nodes[2].dist,
                         [params["B"][0], params["A"][0]], [
                             params["B"][1] / (g * stree.nodes["B"].dist),
                             params["A"][1] / (g * t)
                         ])

        print "C", p
        print "P", log(tot)
Esempio n. 18
0
def draw_tree(tree, labels=None, xscale=100, yscale=20, canvas=None,
              leafPadding=10,
              labelOffset=None, fontSize=10, labelSize=None,
              minlen=1, maxlen=util.INF, filename=sys.stdout,
              rmargin=150, lmargin=10, tmargin=0, bmargin=None,
              colormap=None,
              stree=None,
              layout=None,
              gene2species=None,
              lossColor=(0, 0, 1),
              dupColor=(1, 0, 0),
              eventSize=4,
              legendScale=False, autoclose=None):
    """Draw `tree` onto an SVG canvas and return the canvas.

    Arguments:
    tree         -- tree to draw; iterated for its nodes, and `tree.root`
                    and `tree.nodes` are read
    labels       -- mapping from node name to a label drawn along that node's
                    branch; the label is converted with str() and split on
                    newlines into stacked lines (default: no labels)
    xscale,
    yscale,
    minlen,
    maxlen       -- forwarded to treelib.layout_tree() when `layout` is None;
                    `xscale` is also used to size the legend
    canvas       -- existing canvas to draw on; when None a new svg.Svg
                    writing to `filename` is created and sized from the
                    layout plus the margins
    leafPadding  -- horizontal gap between a leaf's branch tip and its name
    labelOffset  -- vertical offset applied to branch labels (default -1)
    fontSize     -- size of leaf-name text
    labelSize    -- size of branch-label text (default 0.7 * fontSize)
    filename     -- output stream or filename used only when a new canvas is
                    created
    rmargin, lmargin, tmargin, bmargin
                 -- margins around the drawing (bmargin defaults to yscale)
    colormap     -- callable invoked as colormap(tree); presumably it assigns
                    `node.color` on every node (verify against callers).
                    When None every node is colored black.
    stree,
    gene2species -- when both are given the tree is reconciled against
                    `stree` and duplication/loss events are drawn
    layout       -- precomputed mapping node -> (x, y); skips layout_tree()
    lossColor, dupColor, eventSize
                 -- forwarded to draw_events()
    legendScale  -- False: no legend; True: pick a power-of-ten scale length
                    automatically; any other truthy value: use it as the
                    scale length
    autoclose    -- whether to call canvas.endSvg() before returning;
                    defaults to True when the canvas was created here and
                    False when the caller supplied it
    """

    # set defaults
    fontRatio = 8. / 11.

    # Avoid a shared mutable default; an absent mapping means "no labels".
    if labels is None:
        labels = {}

    if labelSize is None:
        labelSize = .7 * fontSize

    if labelOffset is None:
        labelOffset = -1

    if bmargin is None:
        bmargin = yscale

    # A tree with no branch lengths cannot have a meaningful scale bar;
    # give every branch a fixed drawn length instead.
    if sum(x.dist for x in tree.nodes.values()) == 0:
        legendScale = False
        minlen = xscale

    if colormap is None:
        for node in tree:
            node.color = (0, 0, 0)
    else:
        colormap(tree)

    show_events = bool(stree and gene2species)
    if show_events:
        recon = phylo.reconcile(tree, stree, gene2species)
        events = phylo.label_events(tree, recon)
        losses = phylo.find_loss(tree, stree, recon)
    else:
        events = None
        losses = None

    # layout tree
    if layout is None:
        coords = treelib.layout_tree(tree, xscale, yscale, minlen, maxlen)
    else:
        coords = layout

    xcoords, ycoords = zip(*coords.values())
    maxwidth = max(xcoords)
    maxheight = max(ycoords) + labelOffset

    # initialize canvas
    created_canvas = canvas is None
    if created_canvas:
        canvas = svg.Svg(util.open_stream(filename, "w"))
        width = int(rmargin + maxwidth + lmargin)
        height = int(tmargin + maxheight + bmargin)

        canvas.beginSvg(width, height)

    # Only close the SVG by default when this function opened it.
    if autoclose is None:
        autoclose = created_canvas

    # draw tree
    def walk(node):
        x, y = coords[node]
        if node.parent:
            parentx = coords[node.parent][0]
        else:
            parentx = 0

        # draw branch
        canvas.line(parentx, y, x, y, color=node.color)
        if node.name in labels:
            branchlen = x - parentx
            lines = str(labels[node.name]).split("\n")
            labelwidth = max(map(len, lines))
            # Estimated text width, capped so it never exceeds the branch.
            labellen = min(labelwidth * fontRatio * fontSize,
                           max(int(branchlen - 1), 0))

            # Stack the lines upwards so the last line sits at labelOffset.
            for i, line in enumerate(lines):
                canvas.text(line,
                            parentx + (branchlen - labellen) / 2.,
                            y + labelOffset
                            + (-len(lines) + 1 + i) * (labelSize + 1),
                            labelSize)

        if node.isLeaf():
            canvas.text(str(node.name),
                        x + leafPadding, y + fontSize / 2., fontSize,
                        fillColor=node.color)
        else:
            top = coords[node.children[0]][1]
            bot = coords[node.children[-1]][1]

            # vertical connector spanning first to last child
            canvas.line(x, top, x, bot, color=node.color)

            for child in node.children:
                walk(child)

    canvas.beginTransform(("translate", lmargin, tmargin))
    walk(tree.root)

    if show_events:
        draw_events(canvas, tree, coords, events, losses,
                    lossColor=lossColor,
                    dupColor=dupColor,
                    size=eventSize)
    canvas.endTransform()

    # draw legend
    if legendScale:
        # `== True` (not `is True`) is kept deliberately so that values equal
        # to True, such as 1, still select the automatic scale as before.
        if legendScale == True:
            # automatically choose a scale: largest power of ten that fits
            length = maxwidth / float(xscale)
            order = math.floor(math.log10(length))
            length = 10 ** order
        else:
            # An explicit value is taken as the scale length.  Previously
            # `length` was left unassigned here and drawScale() raised
            # NameError.
            length = legendScale

        drawScale(lmargin, tmargin + maxheight + bmargin - fontSize,
                  length, xscale, fontSize, canvas=canvas)

    if autoclose:
        canvas.endSvg()

    return canvas
Esempio n. 19
0
    def test_branch_prior_simple2(self):
        """Test branch prior 2"""

        tree = treelib.parse_newick("((a1:2, a2:3):.4, b1:2);")
        stree = treelib.parse_newick("(A:2, B:2);")

        gene2species = lambda x: x[0].upper()

        params = {"A": (1.0, 1.0), "B": (3.0, 3.0), 1: (1.0, 1.0), "baserate": (11.0, 10.0)}
        birth = 0.01
        death = 0.02
        pretime = 1.0
        nsamples = 100

        recon = phylo.reconcile(tree, stree, gene2species)
        events = phylo.label_events(tree, recon)
        # pd(mapdict(recon, key=lambda x: x.name, val=lambda x: x.name))
        # pd(mapdict(events, key=lambda x: x.name))

        p = spidir.branch_prior(tree, stree, recon, events, params, birth, death, nsamples=nsamples, approx=False)

        tot = 0.0

        gstart = 0.01
        gend = 3.0
        step = (gend - gstart) / 20.0
        s2 = step / 2.0
        gs = list(frange(gstart + s2, gend + s2, step))
        for g in gs:
            pg = invgammaPdf(g, params["baserate"])

            pa = 0.0

            for i in range(nsamples):

                t = birthdeath.sample_birth_wait_time(1, stree.nodes["A"].dist, birth, death)
                # print t

                t2 = stree.nodes["A"].dist - t

                pa1 = gammaPdf(tree.nodes["a1"].dist, [params["A"][0], params["A"][1] / (g * t2)])

                pa2 = gammaPdf(tree.nodes["a2"].dist, [params["A"][0], params["A"][1] / (g * t2)])

                pb = spidir.gammaSumPdf(
                    tree.nodes["b1"].dist + tree.nodes[2].dist,
                    2,
                    [params["B"][0], params["A"][0]],
                    [params["B"][1] / (g * stree.nodes["B"].dist), params["A"][1] / (g * t)],
                    0.001,
                )

                if "nan" not in map(str, [pa1, pa2, pb]):
                    pa += pa1 * pa2 * pb / nsamples

            tot += pg * pa * step
        # tot /= len(gs)

        print "unfold", (
            tree.nodes["b1"].dist + tree.nodes[2].dist,
            [params["B"][0], params["A"][0]],
            [params["B"][1] / (g * stree.nodes["B"].dist), params["A"][1] / (g * t)],
        )

        print "C", p
        print "P", log(tot)