コード例 #1
0
ファイル: test_basic.py プロジェクト: bredelings/argweaver
    def test_post(self):

        k = 6
        n = 1e4
        rho = 1.5e-8 * 10
        mu = 2.5e-8 * 10
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)
        print "muts", len(muts)
        print "recombs", len(arglib.get_recomb_pos(arg))

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # remove chrom
        new_name = "n%d" % (k - 1)
        keep = set(arg.leaf_names()) - set([new_name])
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name,
                              times=times, rho=rho, mu=mu)
        print "states", len(model.states[0])

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        for pcol in probs:
            p = sum(map(exp, pcol))
            print p, " ".join("%.3f" % f for f in map(exp, pcol))
            fequal(p, 1.0, rel=1e-2)
コード例 #2
0
ファイル: test_basic.py プロジェクト: swamidass/argweaver
    def test_post_real(self):

        k = 3
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 100000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        #arg = arglib.read_arg("test/data/real.arg")
        #seqs = fasta.read_fasta("test/data/real.fa")

        #arglib.write_arg("test/data/real.arg", arg)
        #fasta.write_fasta("test/data/real.fa", seqs)

        times = arghmm.get_time_points(maxtime=50000, ntimes=20)
        arghmm.discretize_arg(arg, times)

        new_name = "n%d" % (k - 1)
        thread = list(
            arghmm.iter_chrom_thread(arg, arg[new_name], by_block=False))
        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)
        p = plot(cget(thread, 1), style="lines", ymin=10, ylog=10)

        #alignlib.print_align(seqs)

        # remove chrom
        keep = ["n%d" % i for i in range(k - 1)]
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name=new_name,
                              times=times,
                              rho=rho,
                              mu=mu)

        print "states", len(model.states[0])
        #print "muts", len(muts)
        print "recomb", len(model.recomb_pos) - 2, model.recomb_pos[1:-1]

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        high = list(arghmm.iter_posterior_times(model, probs, .95))
        low = list(arghmm.iter_posterior_times(model, probs, .05))
        p.plot(high, style="lines")
        p.plot(low, style="lines")

        pause()
コード例 #3
0
    def test_smcify_arg(self):
        """Every SPR of an smcified ARG has distinct recomb and coal nodes.

        Samples a 6-lineage ARG, converts it with smcify_arg, and asserts
        that for each SPR from iter_arg_sprs the recombination node
        differs from the coalescence node.
        """
        recomb_rate = 1.5e-8   # recomb/site/gen
        locus_len = 100000     # length of locus
        nlineages = 6          # number of lineages
        popsize = 2 * 10000    # effective popsize

        smc_arg = arglib.smcify_arg(
            arglib.sample_arg(nlineages, popsize, recomb_rate, 0, locus_len))

        sprs = arglib.iter_arg_sprs(smc_arg)
        for _pos, (recomb_node, _rtime), (coal_node, _ctime) in sprs:
            self.assertNotEqual(recomb_node, coal_node)
コード例 #4
0
    def test_smcify_arg_remove_thread(self):
        """Remove the last chromosome from a sampled ARG, then smcify the rest.

        Samples a 6-lineage ARG and drops leaf "n<k-1>". Elsewhere in these
        tests, sampled leaves are named "n0".."n<k-1>". The copy is
        restricted to the remaining leaves and converted with smcify_arg.
        """
        rho = 1.5e-8   # recomb/site/gen
        l = 100000     # length of locus
        k = 6          # number of lineages
        n = 2*10000    # effective popsize

        arg = arglib.sample_arg(k, n, rho, 0, l)
        # The name must be wrapped in a list: set("n5") is the set of its
        # characters {"n", "5"}, which matches no leaf name, so nothing
        # would actually be removed.
        remove_chroms = set(["n%d" % (k - 1)])
        keep = [x for x in arg.leaf_names() if x not in remove_chroms]
        # exactly one leaf must have been dropped for the test to mean anything
        assert len(keep) == k - 1, keep
        arg = arg.copy()
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)
コード例 #5
0
ファイル: test_basic.py プロジェクト: bredelings/argweaver
    def test_post_real(self):

        k = 3
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 100000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        #arg = arglib.read_arg("test/data/real.arg")
        #seqs = fasta.read_fasta("test/data/real.fa")

        #arglib.write_arg("test/data/real.arg", arg)
        #fasta.write_fasta("test/data/real.fa", seqs)

        times = arghmm.get_time_points(maxtime=50000, ntimes=20)
        arghmm.discretize_arg(arg, times)

        new_name = "n%d" % (k - 1)
        thread = list(arghmm.iter_chrom_thread(arg, arg[new_name],
                                               by_block=False))
        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)
        p = plot(cget(thread, 1), style="lines", ymin=10, ylog=10)

        #alignlib.print_align(seqs)

        # remove chrom
        keep = ["n%d" % i for i in range(k-1)]
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                              rho=rho, mu=mu)

        print "states", len(model.states[0])
        #print "muts", len(muts)
        print "recomb", len(model.recomb_pos) - 2, model.recomb_pos[1:-1]

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        high = list(arghmm.iter_posterior_times(model, probs, .95))
        low = list(arghmm.iter_posterior_times(model, probs, .05))
        p.plot(high, style="lines")
        p.plot(low, style="lines")

        pause()
コード例 #6
0
ファイル: test_basic.py プロジェクト: swamidass/argweaver
    def test_post2(self):

        k = 2
        n = 1e4
        rho = 1.5e-8 * 10
        mu = 2.5e-8 * 10
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        print "muts", len(muts)

        times = arghmm.get_time_points()
        arghmm.discretize_arg(arg, times)

        thread = list(arghmm.iter_chrom_thread(arg, arg["n1"], by_block=False))
        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)
        p = plot(cget(thread, 1), style="lines", ymin=0)

        #alignlib.print_align(seqs)

        # remove chrom
        keep = ["n0"]
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name="n1",
                              times=times,
                              rho=rho,
                              mu=mu)
        print "states", len(model.states[0])

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        high = list(arghmm.iter_posterior_times(model, probs, .95))
        low = list(arghmm.iter_posterior_times(model, probs, .05))
        p.plot(high, style="lines")
        p.plot(low, style="lines")

        pause()
コード例 #7
0
ファイル: test_basic.py プロジェクト: bredelings/argweaver
    def test_post2(self):

        k = 2
        n = 1e4
        rho = 1.5e-8 * 10
        mu = 2.5e-8 * 10
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        print "muts", len(muts)

        times = arghmm.get_time_points()
        arghmm.discretize_arg(arg, times)

        thread = list(arghmm.iter_chrom_thread(arg, arg["n1"], by_block=False))
        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)
        p = plot(cget(thread, 1), style="lines", ymin=0)

        #alignlib.print_align(seqs)

        # remove chrom
        keep = ["n0"]
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg, seqs, new_name="n1", times=times,
                              rho=rho, mu=mu)
        print "states", len(model.states[0])

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        high = list(arghmm.iter_posterior_times(model, probs, .95))
        low = list(arghmm.iter_posterior_times(model, probs, .05))
        p.plot(high, style="lines")
        p.plot(low, style="lines")

        pause()
コード例 #8
0
ファイル: test_basic.py プロジェクト: swamidass/argweaver
    def test_post(self):

        k = 6
        n = 1e4
        rho = 1.5e-8 * 10
        mu = 2.5e-8 * 10
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)
        print "muts", len(muts)
        print "recombs", len(arglib.get_recomb_pos(arg))

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # remove chrom
        new_name = "n%d" % (k - 1)
        keep = set(arg.leaf_names()) - set([new_name])
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name=new_name,
                              times=times,
                              rho=rho,
                              mu=mu)
        print "states", len(model.states[0])

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        for pcol in probs:
            p = sum(map(exp, pcol))
            print p, " ".join("%.3f" % f for f in map(exp, pcol))
            fequal(p, 1.0, rel=1e-2)
コード例 #9
0
ファイル: test_prog.py プロジェクト: bredelings/argweaver
def show_plots(arg_file, sites_file, stats_file, output_prefix,
               rho, mu, popsize, ntimes=20, maxtime=200000):
    """
    Show plots of convergence.

    The true ARG is read from *arg_file* and the sequences from
    *sites_file*. From these, the true joint probability, likelihood,
    prior probability, number of recombinations, and ARG branch length
    are computed.

    For each statistic, one trace plot is written to
    ``<output_prefix>.trace.<name>.pdf``. It shows the per-iteration
    values from the matching column of *stats_file* as a line, with the
    true value as a gray horizontal reference line.

    rho, mu -- rates passed to the prior and likelihood computations.
    popsize -- passed as ``popsizes`` to calc_prior_prob.
    ntimes, maxtime -- time discretization passed to get_time_points.
    """
    # read true arg and seqs
    times = argweaver.get_time_points(ntimes=ntimes, maxtime=maxtime)
    arg = arglib.read_arg(arg_file)
    argweaver.discretize_arg(arg, times, ignore_top=False, round_age="closer")
    arg = arglib.smcify_arg(arg)
    seqs = argweaver.sites2seqs(argweaver.read_sites(sites_file))

    # compute true stats (arglen must be taken before `arg` is rebound)
    arglen = arglib.arglen(arg)
    arg = argweaverc.arg2ctrees(arg, times)
    nrecombs = argweaverc.get_local_trees_ntrees(arg[0]) - 1
    # NOTE(review): both calls pass delete_arg=False and nothing here frees
    # the converted trees afterwards — confirm whether a release is needed.
    lk = argweaverc.calc_likelihood(
        arg, seqs, mu=mu, times=times, delete_arg=False)
    prior = argweaverc.calc_prior_prob(
        arg, rho=rho, times=times, popsizes=popsize, delete_arg=False)
    joint = lk + prior

    data = read_table(stats_file)

    # (file-name suffix, stats column, true value, title and y-axis label)
    traces = [
        ("joint", "joint", joint, "joint probability"),
        ("lk", "likelihood", lk, "likelihood"),
        ("prior", "prior", prior, "prior probability"),
        ("nrecombs", "recombs", nrecombs, "number of recombinations"),
        ("arglen", "arglen", arglen, "ARG branch length"),
    ]

    for suffix, column, true_value, label in traces:
        y = data.cget(column)
        rplot_start(output_prefix + ".trace." + suffix + ".pdf",
                    width=8, height=5)
        # widen the y range so the true value is always visible
        rp.plot(y, t="l",
                ylim=[min(min(y), true_value), max(max(y), true_value)],
                main=label,
                xlab="iterations",
                ylab=label)
        rp.lines([0, len(y)], [true_value, true_value], col="gray")
        rplot_end(True)
コード例 #10
0
ファイル: compare.py プロジェクト: mdrasmus/smc-analysis
from compbio import arglib

if 1:
    cwr_coals_list = []
    smc_coals_list = []

    for i in range(20):
        k = 10
        n = 10e3
        length = 500e3
        rho = 1.5e-8

        # simulate an ARG from the CwR and convert it into SMC-style
        tic("simulate %d" % i)
        cwr_arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        cwr_arg_converted = arglib.smcify_arg(cwr_arg)
        toc()

        # simulate an ARG directly from SMC process
        smc_arg = arglib.sample_arg_smc(k, n, rho, start=0, end=length)

        # gather all coalescence times
        cwr_coals = [node.age for node in cwr_arg_converted
                     if node.event == 'coal']
        smc_coals = [node.age for node in smc_arg
                     if node.event == 'coal']
        print len(cwr_coals), len(smc_coals)

        cwr_coals_list.append(cwr_coals)
        smc_coals_list.append(smc_coals)
コード例 #11
0
ファイル: test_prog.py プロジェクト: jjberg2/argweaver
def show_plots(arg_file, sites_file, stats_file, output_prefix,
               rho, mu, popsize, ntimes=20, maxtime=200000):
    """
    Show plots of convergence.

    The true ARG is read from *arg_file* and the sequences from
    *sites_file*. From these, the true joint probability, likelihood,
    prior probability, number of recombinations, and ARG branch length
    are computed.

    For each statistic, one trace plot is written to
    ``<output_prefix>.trace.<name>.pdf``. It shows the per-iteration
    values from the matching column of *stats_file* as a line, with the
    true value as a gray horizontal reference line.

    rho, mu -- rates passed to the prior and likelihood computations.
    popsize -- passed as ``popsizes`` to calc_prior_prob.
    ntimes, maxtime -- time discretization passed to get_time_points.
    """
    # read true arg and seqs
    times = argweaver.get_time_points(ntimes=ntimes, maxtime=maxtime)
    arg = arglib.read_arg(arg_file)
    argweaver.discretize_arg(arg, times, ignore_top=False, round_age="closer")
    arg = arglib.smcify_arg(arg)
    seqs = argweaver.sites2seqs(argweaver.read_sites(sites_file))

    # compute true stats (arglen must be taken before `arg` is rebound)
    arglen = arglib.arglen(arg)
    arg = argweaverc.arg2ctrees(arg, times)
    nrecombs = argweaverc.get_local_trees_ntrees(arg[0]) - 1
    # NOTE(review): both calls pass delete_arg=False and nothing here frees
    # the converted trees afterwards — confirm whether a release is needed.
    lk = argweaverc.calc_likelihood(
        arg, seqs, mu=mu, times=times, delete_arg=False)
    prior = argweaverc.calc_prior_prob(
        arg, rho=rho, times=times, popsizes=popsize, delete_arg=False)
    joint = lk + prior

    data = read_table(stats_file)

    # (file-name suffix, stats column, true value, title and y-axis label)
    traces = [
        ("joint", "joint", joint, "joint probability"),
        ("lk", "likelihood", lk, "likelihood"),
        ("prior", "prior", prior, "prior probability"),
        ("nrecombs", "recombs", nrecombs, "number of recombinations"),
        ("arglen", "arglen", arglen, "ARG branch length"),
    ]

    for suffix, column, true_value, label in traces:
        y = data.cget(column)
        rplot_start(output_prefix + ".trace." + suffix + ".pdf",
                    width=8, height=5)
        # widen the y range so the true value is always visible
        rp.plot(y, t="l",
                ylim=[min(min(y), true_value), max(max(y), true_value)],
                main=label,
                xlab="iterations",
                ylab=label)
        rp.lines([0, len(y)], [true_value, true_value], col="gray")
        rplot_end(True)
コード例 #12
0
ファイル: test_basic.py プロジェクト: swamidass/argweaver
    def test_determ(self):

        k = 8
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 100000

        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)
        times = arghmm.get_time_points(maxtime=50000, ntimes=20)
        arghmm.discretize_arg(arg, times)

        new_name = "n%d" % (k - 1)
        thread = list(
            arghmm.iter_chrom_thread(arg, arg[new_name], by_block=False))
        thread_clades = list(
            arghmm.iter_chrom_thread(arg,
                                     arg[new_name],
                                     by_block=True,
                                     use_clades=True))

        # remove chrom
        keep = ["n%d" % i for i in range(k - 1)]
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name=new_name,
                              times=times,
                              rho=rho,
                              mu=mu)

        for i, rpos in enumerate(model.recomb_pos[1:-1]):
            pos = rpos + 1

            model.check_local_tree(pos, force=True)

            #recomb = arghmm.find_tree_next_recomb(arg, pos - 1)
            tree = arg.get_marginal_tree(pos - .5)
            last_tree = arg.get_marginal_tree(pos - 1 - .5)
            states1 = model.states[pos - 1]
            states2 = model.states[pos]
            (recomb_branch, recomb_time), (coal_branch, coal_time) = \
                arghmm.find_recomb_coal(tree, last_tree, pos=rpos)
            recomb_time = times.index(recomb_time)
            coal_time = times.index(coal_time)

            determ = arghmm.get_deterministic_transitions(
                states1, states2, times, tree, last_tree, recomb_branch,
                recomb_time, coal_branch, coal_time)

            leaves1, time1, block1 = thread_clades[i]
            leaves2, time2, block2 = thread_clades[i + 1]
            if new_name in leaves1:
                leaves1.remove(new_name)
            if new_name in leaves2:
                leaves2.remove(new_name)
            node1 = arghmm.arg_lca(arg, leaves1, None, pos - 1).name
            node2 = arghmm.arg_lca(arg, leaves2, None, pos).name

            state1 = (node1, times.index(time1))
            state2 = (node2, times.index(time2))
            print pos, state1, state2
            try:
                statei1 = states1.index(state1)
                statei2 = states2.index(state2)
            except:
                print "states1", states1
                print "states2", states2
                raise

            statei3 = determ[statei1]
            print "  ", statei1, statei2, statei3, states2[statei3]
            if statei2 != statei3 and statei3 != -1:
                tree = tree.get_tree()
                treelib.remove_single_children(tree)
                last_tree = last_tree.get_tree()
                treelib.remove_single_children(last_tree)

                print "block1", block1
                print "block2", block2
                print "r=", (recomb_branch, recomb_time)
                print "c=", (coal_branch, coal_time)

                print "tree"
                treelib.draw_tree_names(tree, minlen=8, maxlen=8)

                print "last_tree"
                treelib.draw_tree_names(last_tree, minlen=8, maxlen=8)
                assert False
コード例 #13
0
ファイル: test_basic.py プロジェクト: bredelings/argweaver
    def test_determ(self):

        k = 8
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 100000

        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)
        times = arghmm.get_time_points(maxtime=50000, ntimes=20)
        arghmm.discretize_arg(arg, times)

        new_name = "n%d" % (k - 1)
        thread = list(arghmm.iter_chrom_thread(arg, arg[new_name],
                                               by_block=False))
        thread_clades = list(arghmm.iter_chrom_thread(
            arg, arg[new_name], by_block=True, use_clades=True))

        # remove chrom
        keep = ["n%d" % i for i in range(k-1)]
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                              rho=rho, mu=mu)


        for i, rpos in enumerate(model.recomb_pos[1:-1]):
            pos = rpos + 1

            model.check_local_tree(pos, force=True)

            #recomb = arghmm.find_tree_next_recomb(arg, pos - 1)
            tree = arg.get_marginal_tree(pos-.5)
            last_tree = arg.get_marginal_tree(pos-1-.5)
            states1 = model.states[pos-1]
            states2 = model.states[pos]
            (recomb_branch, recomb_time), (coal_branch, coal_time) = \
                arghmm.find_recomb_coal(tree, last_tree, pos=rpos)
            recomb_time = times.index(recomb_time)
            coal_time = times.index(coal_time)

            determ = arghmm.get_deterministic_transitions(
                states1, states2, times,
                tree, last_tree,
                recomb_branch, recomb_time,
                coal_branch, coal_time)


            leaves1, time1, block1 = thread_clades[i]
            leaves2, time2, block2 = thread_clades[i+1]
            if new_name in leaves1:
                leaves1.remove(new_name)
            if new_name in leaves2:
                leaves2.remove(new_name)
            node1 = arghmm.arg_lca(arg, leaves1, None, pos-1).name
            node2 = arghmm.arg_lca(arg, leaves2, None, pos).name

            state1 = (node1, times.index(time1))
            state2 = (node2, times.index(time2))
            print pos, state1, state2
            try:
                statei1 = states1.index(state1)
                statei2 = states2.index(state2)
            except:
                print "states1", states1
                print "states2", states2
                raise

            statei3 = determ[statei1]
            print "  ", statei1, statei2, statei3, states2[statei3]
            if statei2 != statei3 and statei3 != -1:
                tree = tree.get_tree()
                treelib.remove_single_children(tree)
                last_tree = last_tree.get_tree()
                treelib.remove_single_children(last_tree)

                print "block1", block1
                print "block2", block2
                print "r=", (recomb_branch, recomb_time)
                print "c=", (coal_branch, coal_time)

                print "tree"
                treelib.draw_tree_names(tree, minlen=8, maxlen=8)

                print "last_tree"
                treelib.draw_tree_names(last_tree, minlen=8, maxlen=8)
                assert False