Beispiel #1
0
def _test_bounded_multicoal_tree(stree, n, T, nsamples):
    """Sample bounded multicoalescent trees and tabulate their topologies.

    Draws ``nsamples`` trees with ``coal.sample_bounded_multicoal_tree`` and
    groups them by the value of ``phylo.hash_tree``.

    Args:
        stree, n, T: forwarded unchanged to the ``coal`` sampling and
            probability functions (presumably species tree, population
            size and time bound -- confirm against ``coal``).
        nsamples: number of trees to draw; also the denominator of the
            ``percent`` column.

    Returns:
        A ``(tab, tops)`` pair. ``tab`` is a ``Table`` with columns
        ``top``, ``simple_top``, ``percent`` and ``prob``, sorted on
        ``prob`` with ``reverse=True``. ``tops`` maps each topology hash to
        ``[count, tree, recon]``, where ``tree`` and ``recon`` come from the
        first sample that produced that hash.
    """
    tops = {}

    # coal.sample_bounded_multicoal_tree_reject accepts the same arguments
    # and can be swapped in here to use rejection sampling instead.
    for _ in xrange(nsamples):
        sampled_tree, sampled_recon = coal.sample_bounded_multicoal_tree(
            stree, n, T, namefunc=lambda x: x)

        key = phylo.hash_tree(sampled_tree)
        if key not in tops:
            # keep the first tree/recon seen as the representative
            tops[key] = [0, sampled_tree, sampled_recon]
        tops[key][0] += 1

    tab = Table(headers=["top", "simple_top", "percent", "prob"])
    for key, entry in tops.items():
        count, example_tree, example_recon = entry

        # hash a copy with single-child nodes removed; the stored tree
        # itself is left untouched
        simplified = example_tree.copy()
        treelib.remove_single_children(simplified)
        simple_key = phylo.hash_tree(simplified)

        frequency = count / float(nsamples)

        # exp() is applied to the returned value, so it is presumably a
        # log-probability -- confirm against coal
        topology_prob = exp(coal.prob_bounded_multicoal_recon_topology(
            example_tree, example_recon, stree, n, T))

        tab.add(top=key,
                simple_top=simple_key,
                percent=frequency,
                prob=topology_prob)

    tab.sort(col="prob", reverse=True)
    return tab, tops
Beispiel #2
0
    def test_top(self):
        """Compare topology counts from the direct and rejection samplers.

        Draws 4000 paired samples on a fixed five-leaf species tree, one from
        ``coal.sample_bounded_multicoal_tree_reject`` and one from
        ``coal.sample_bounded_multicoal_tree``. It counts topologies per
        sampler and asserts that the ``safelog`` of the counts correlates
        above 0.9. A scatter plot of the two vectors is then written to
        ``plot.png`` in the output directory.
        """
        outdir = 'test/tmp/test_coal/BMC_test_top/'
        make_clean_dir(outdir)

        stree = treelib.parse_newick(
            "(((A:200, E:200):800, B:1000):500, (C:700, D:700):800);")
        n = 500
        T = 2000
        nsamples = 4000

        def entry_for(counts, key, tree, recon):
            # Return the [count, tree, recon] record for key, creating a
            # zero-count record on first sight.
            if key not in counts:
                counts[key] = [0, tree, recon]
            return counts[key]

        # topology hash -> [count, tree, recon], one dict per sampler
        reject_counts = {}
        direct_counts = {}

        for _ in xrange(nsamples):
            # rejection sampler first, then the direct sampler
            reject_tree, reject_recon = (
                coal.sample_bounded_multicoal_tree_reject(
                    stree, n, T, namefunc=lambda x: x))
            direct_tree, direct_recon = coal.sample_bounded_multicoal_tree(
                stree, n, T, namefunc=lambda x: x)

            reject_key = phylo.hash_tree(reject_tree)
            direct_key = phylo.hash_tree(direct_tree)

            # Each dict counts only its own sampler's topology but also
            # registers the other's with count 0, so both dicts end up with
            # the same key set.
            entry_for(reject_counts, reject_key,
                      reject_tree, reject_recon)[0] += 1
            entry_for(reject_counts, direct_key, direct_tree, direct_recon)

            entry_for(direct_counts, direct_key,
                      direct_tree, direct_recon)[0] += 1
            entry_for(direct_counts, reject_key, reject_tree, reject_recon)

        # paired log-counts, in the key order of the rejection-sampler dict
        x = []
        y = []
        for key in reject_counts:
            x.append(safelog(reject_counts[key][0], default=0))
            y.append(safelog(direct_counts[key][0], default=0))

        correlation = stats.corr(x, y)
        self.assertTrue(correlation > .9)

        plot = Gnuplot()
        plot.enableOutput(False)
        plot.plot(x, y)
        # diagonal reference line spanning the range of x
        low, high = min(x), max(x)
        plot.plot([low, high], [low, high], style="lines")
        plot.enableOutput(True)
        plot.save(outdir + 'plot.png')