Example #1
0
    def test_trans_single(self):
        """
        Calculate transition probabilities

        Only calculate a single matrix
        """

        k = 4
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(10)
        arghmm.discretize_arg(arg, times)
        print "recomb", arglib.get_recomb_pos(arg)

        new_name = "n%d" % (k-1)
        arg = arghmm.remove_arg_thread(arg, new_name)
        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times)

        pos = 10
        tree = arg.get_marginal_tree(pos)
        mat = arghmm.calc_transition_probs(
            tree, model.states[pos], model.nlineages,
            model.times, model.time_steps, model.popsizes, rho)
        print model.states[pos]
        pc(mat)

        for row in mat:
            print sum(map(exp, row))
Example #2
0
    def test_prior(self):
        """
        Calculate state priors
        """

        k = 10
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points()
        arghmm.discretize_arg(arg, times)
        new_name = "n%d" % (k - 1)
        arg = arghmm.remove_arg_thread(arg, new_name)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times)

        prior = [
            model.prob_prior(0, j) for j in xrange(model.get_num_states(0))
        ]
        print prior
        print sum(map(exp, prior))
        fequal(sum(map(exp, prior)), 1.0, rel=.01)
Example #3
0
    def test_prior(self):
        """
        Calculate state priors
        """

        k = 10
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points()
        arghmm.discretize_arg(arg, times)
        new_name = "n%d" % (k-1)
        arg = arghmm.remove_arg_thread(arg, new_name)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times)

        prior = [model.prob_prior(0, j)
                 for j in xrange(model.get_num_states(0))]
        print prior
        print sum(map(exp, prior))
        fequal(sum(map(exp, prior)), 1.0, rel=.01)
Example #4
0
    def test_trans_single(self):
        """
        Calculate transition probabilities

        Only calculate a single matrix
        """

        k = 4
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(10)
        arghmm.discretize_arg(arg, times)
        print "recomb", arglib.get_recomb_pos(arg)

        new_name = "n%d" % (k - 1)
        arg = arghmm.remove_arg_thread(arg, new_name)
        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times)

        pos = 10
        tree = arg.get_marginal_tree(pos)
        mat = arghmm.calc_transition_probs(tree, model.states[pos],
                                           model.nlineages, model.times,
                                           model.time_steps, model.popsizes,
                                           rho)
        print model.states[pos]
        pc(mat)

        for row in mat:
            print sum(map(exp, row))
Example #5
0
    def test_post(self):

        k = 6
        n = 1e4
        rho = 1.5e-8 * 10
        mu = 2.5e-8 * 10
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)
        print "muts", len(muts)
        print "recombs", len(arglib.get_recomb_pos(arg))

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # remove chrom
        new_name = "n%d" % (k - 1)
        keep = set(arg.leaf_names()) - set([new_name])
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name,
                              times=times, rho=rho, mu=mu)
        print "states", len(model.states[0])

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        for pcol in probs:
            p = sum(map(exp, pcol))
            print p, " ".join("%.3f" % f for f in map(exp, pcol))
            fequal(p, 1.0, rel=1e-2)
Example #6
0
    def test_post_plot(self):

        k = 6
        n = 1e4
        rho = 1.5e-8 * 50
        mu = 2.5e-8 * 50
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(ntimes=30)
        arghmm.discretize_arg(arg, times)

        pause()

        # save
        #arglib.write_arg("test/data/k4.arg", arg)
        #fasta.write_fasta("test/data/k4.fa", seqs)

        new_name = "n%d" % (k - 1)
        thread = list(
            arghmm.iter_chrom_thread(arg, arg[new_name], by_block=False))
        p = plot(cget(thread, 1), style="lines", ymin=times[1], ylog=10)

        # remove chrom
        new_name = "n%d" % (k - 1)
        arg = arghmm.remove_arg_thread(arg, new_name)

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name=new_name,
                              times=times,
                              rho=rho,
                              mu=mu)
        print "states", len(model.states[0])
        print "muts", len(muts)
        print "recomb", len(model.recomb_pos) - 2, model.recomb_pos[1:-1]

        p.plot(model.recomb_pos, [10000] * len(model.recomb_pos),
               style="points")

        probs = arghmm.get_posterior_probs(model, length, verbose=True)
        print "done"

        high = list(arghmm.iter_posterior_times(model, probs, .95))
        low = list(arghmm.iter_posterior_times(model, probs, .05))
        p.gnuplot("set linestyle 2")
        p.plot(high, style="lines")
        p.gnuplot("set linestyle 2")
        p.plot(low, style="lines")

        #write_list("test/data/post_real.txt", cget(thread, 1))
        #write_list("test/data/post_high.txt", high)
        #write_list("test/data/post_low.txt", low)

        pause()
Example #7
0
    def test_post_plot(self):

        k = 6
        n = 1e4
        rho = 1.5e-8 * 50
        mu = 2.5e-8 * 50
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(ntimes=30)
        arghmm.discretize_arg(arg, times)

        pause()

        # save
        #arglib.write_arg("test/data/k4.arg", arg)
        #fasta.write_fasta("test/data/k4.fa", seqs)

        new_name = "n%d" % (k-1)
        thread = list(arghmm.iter_chrom_thread(arg, arg[new_name],
                                               by_block=False))
        p = plot(cget(thread, 1), style="lines", ymin=times[1],
                 ylog=10)

        # remove chrom
        new_name = "n%d" % (k-1)
        arg = arghmm.remove_arg_thread(arg, new_name)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                              rho=rho, mu=mu)
        print "states", len(model.states[0])
        print "muts", len(muts)
        print "recomb", len(model.recomb_pos) - 2, model.recomb_pos[1:-1]

        p.plot(model.recomb_pos, [10000] * len(model.recomb_pos),
               style="points")

        probs = arghmm.get_posterior_probs(model, length, verbose=True)
        print "done"

        high = list(arghmm.iter_posterior_times(model, probs, .95))
        low = list(arghmm.iter_posterior_times(model, probs, .05))
        p.gnuplot("set linestyle 2")
        p.plot(high, style="lines")
        p.gnuplot("set linestyle 2")
        p.plot(low, style="lines")


        #write_list("test/data/post_real.txt", cget(thread, 1))
        #write_list("test/data/post_high.txt", high)
        #write_list("test/data/post_low.txt", low)

        pause()
Example #8
0
    def test_norecomb_plot(self):

        k = 50
        n = 1e4
        rho = 1.5e-8 * .0001
        rho2 = 1.5e-8 * 10
        mu = 2.5e-8 * 100
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(ntimes=20)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # get thread
        new_name = "n%d" % (k - 1)
        keep = ["n%d" % i for i in range(k - 1)]
        arglib.subarg_by_leaf_names(arg, keep)
        arg.set_ancestral()
        arg.prune()

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name=new_name,
                              times=times,
                              rho=rho2,
                              mu=mu)
        print "states", len(model.states[0])
        print "muts", len(muts)

        # simulate a new thread
        states = list(islice(hmm.sample_hmm_states(model), 0, arg.end))
        data = list(hmm.sample_hmm_data(model, states))

        seqs[new_name] = "".join(data)
        #alignlib.print_align(seqs)

        thread = [
            model.times[model.states[i][s][1]] for i, s in enumerate(states)
        ]
        p = plot(thread, style="lines")

        probs = arghmm.get_posterior_probs(model, length, verbose=True)
        print "done"

        high = list(arghmm.iter_posterior_times(model, probs, .75))
        low = list(arghmm.iter_posterior_times(model, probs, .25))
        p.plot(high, style="lines")
        p.plot(low, style="lines")

        pause()
Example #9
0
    def test_post_real(self):

        k = 3
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 100000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        #arg = arglib.read_arg("test/data/real.arg")
        #seqs = fasta.read_fasta("test/data/real.fa")

        #arglib.write_arg("test/data/real.arg", arg)
        #fasta.write_fasta("test/data/real.fa", seqs)

        times = arghmm.get_time_points(maxtime=50000, ntimes=20)
        arghmm.discretize_arg(arg, times)

        new_name = "n%d" % (k - 1)
        thread = list(
            arghmm.iter_chrom_thread(arg, arg[new_name], by_block=False))
        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)
        p = plot(cget(thread, 1), style="lines", ymin=10, ylog=10)

        #alignlib.print_align(seqs)

        # remove chrom
        keep = ["n%d" % i for i in range(k - 1)]
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name=new_name,
                              times=times,
                              rho=rho,
                              mu=mu)

        print "states", len(model.states[0])
        #print "muts", len(muts)
        print "recomb", len(model.recomb_pos) - 2, model.recomb_pos[1:-1]

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        high = list(arghmm.iter_posterior_times(model, probs, .95))
        low = list(arghmm.iter_posterior_times(model, probs, .05))
        p.plot(high, style="lines")
        p.plot(low, style="lines")

        pause()
Example #10
0
    def test_trans_switch_single(self):
        """
        Calculate transitions probabilities for switching between blocks

        Only calculate a single matrix
        """

        k = 5
        n = 1e4
        rho = 1.5e-8 * 100
        mu = 2.5e-8
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        #arglib.write_arg("tmp/a.arg", arg)
        #arg = arglib.read_arg("tmp/a.arg")
        #arg.set_ancestral()


        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(5)
        arghmm.discretize_arg(arg, times)

        new_name = "n%d" % (k-1)
        arg = arghmm.remove_arg_thread(arg, new_name)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times)

        # get recombs
        recombs = list(x.pos for x in arghmm.iter_visible_recombs(arg))
        print "recomb", recombs

        pos = recombs[0] + 1
        tree = arg.get_marginal_tree(pos-.5)
        last_tree = arg.get_marginal_tree(pos-1-.5)

        print "states1>>", model.states[pos-1]
        print "states2>>", model.states[pos]

        treelib.draw_tree_names(last_tree.get_tree(), minlen=5, maxlen=5)
        treelib.draw_tree_names(tree.get_tree(), minlen=5, maxlen=5)

        print "pos>>", pos
        recomb = [x for x in tree
                  if x.event == "recomb" and x.pos+1 == pos][0]
        mat = arghmm.calc_transition_probs_switch(
            tree, last_tree, recomb.name,
            model.states[pos-1], model.states[pos],
            model.nlineages, model.times,
            model.time_steps, model.popsizes, rho)
        pc(mat)
Example #11
0
    def test_norecomb_plot(self):

        k = 50
        n = 1e4
        rho = 1.5e-8 * .0001
        rho2 = 1.5e-8 * 10
        mu = 2.5e-8 * 100
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)


        times = arghmm.get_time_points(ntimes=20)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # get thread
        new_name = "n%d" % (k-1)
        keep = ["n%d" % i for i in range(k-1)]
        arglib.subarg_by_leaf_names(arg, keep)
        arg.set_ancestral()
        arg.prune()

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                              rho=rho2, mu=mu)
        print "states", len(model.states[0])
        print "muts", len(muts)

        # simulate a new thread
        states = list(islice(hmm.sample_hmm_states(model), 0, arg.end))
        data = list(hmm.sample_hmm_data(model, states))

        seqs[new_name] = "".join(data)
        #alignlib.print_align(seqs)

        thread = [model.times[model.states[i][s][1]]
                  for i, s in enumerate(states)]
        p = plot(thread, style="lines")


        probs = arghmm.get_posterior_probs(model, length, verbose=True)
        print "done"

        high = list(arghmm.iter_posterior_times(model, probs, .75))
        low = list(arghmm.iter_posterior_times(model, probs, .25))
        p.plot(high, style="lines")
        p.plot(low, style="lines")

        pause()
Example #12
0
    def test_trans_switch_single(self):
        """
        Calculate transitions probabilities for switching between blocks

        Only calculate a single matrix
        """

        k = 5
        n = 1e4
        rho = 1.5e-8 * 100
        mu = 2.5e-8
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        #arglib.write_arg("tmp/a.arg", arg)
        #arg = arglib.read_arg("tmp/a.arg")
        #arg.set_ancestral()

        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(5)
        arghmm.discretize_arg(arg, times)

        new_name = "n%d" % (k - 1)
        arg = arghmm.remove_arg_thread(arg, new_name)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times)

        # get recombs
        recombs = list(x.pos for x in arghmm.iter_visible_recombs(arg))
        print "recomb", recombs

        pos = recombs[0] + 1
        tree = arg.get_marginal_tree(pos - .5)
        last_tree = arg.get_marginal_tree(pos - 1 - .5)

        print "states1>>", model.states[pos - 1]
        print "states2>>", model.states[pos]

        treelib.draw_tree_names(last_tree.get_tree(), minlen=5, maxlen=5)
        treelib.draw_tree_names(tree.get_tree(), minlen=5, maxlen=5)

        print "pos>>", pos
        recomb = [x for x in tree
                  if x.event == "recomb" and x.pos + 1 == pos][0]
        mat = arghmm.calc_transition_probs_switch(tree, last_tree, recomb.name,
                                                  model.states[pos - 1],
                                                  model.states[pos],
                                                  model.nlineages, model.times,
                                                  model.time_steps,
                                                  model.popsizes, rho)
        pc(mat)
Example #13
0
    def test_post_c(self):

        k = 3
        n = 1e4
        rho = 1.5e-8 * 30
        mu = 2.5e-8 * 100
        length = 100
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        arg.prune()
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        print arglib.get_recomb_pos(arg)
        print "muts", len(muts)
        print "recomb", len(arglib.get_recomb_pos(arg))

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # remove chrom
        keep = ["n%d" % i for i in range(k - 1)]
        arglib.subarg_by_leaf_names(arg, keep)

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name="n%d" % (k - 1),
                              times=times,
                              rho=rho,
                              mu=mu)
        print "states", len(model.states[0])

        util.tic("C")
        probs1 = list(arghmm.get_posterior_probs(model, length, verbose=True))
        util.toc()

        util.tic("python")
        probs2 = list(hmm.get_posterior_probs(model, length, verbose=True))
        util.toc()

        print "probs1"
        pc(probs1)

        print "probs2"
        pc(probs2)

        for col1, col2 in izip(probs1, probs2):
            for a, b in izip(col1, col2):
                fequal(a, b)
Example #14
0
    def test_post_real(self):

        k = 3
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 100000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        #arg = arglib.read_arg("test/data/real.arg")
        #seqs = fasta.read_fasta("test/data/real.fa")

        #arglib.write_arg("test/data/real.arg", arg)
        #fasta.write_fasta("test/data/real.fa", seqs)

        times = arghmm.get_time_points(maxtime=50000, ntimes=20)
        arghmm.discretize_arg(arg, times)

        new_name = "n%d" % (k - 1)
        thread = list(arghmm.iter_chrom_thread(arg, arg[new_name],
                                               by_block=False))
        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)
        p = plot(cget(thread, 1), style="lines", ymin=10, ylog=10)

        #alignlib.print_align(seqs)

        # remove chrom
        keep = ["n%d" % i for i in range(k-1)]
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                              rho=rho, mu=mu)

        print "states", len(model.states[0])
        #print "muts", len(muts)
        print "recomb", len(model.recomb_pos) - 2, model.recomb_pos[1:-1]

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        high = list(arghmm.iter_posterior_times(model, probs, .95))
        low = list(arghmm.iter_posterior_times(model, probs, .05))
        p.plot(high, style="lines")
        p.plot(low, style="lines")

        pause()
Example #15
0
    def test_post_c(self):

        k = 3
        n = 1e4
        rho = 1.5e-8 * 30
        mu = 2.5e-8 * 100
        length = 100
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        arg.prune()
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        print arglib.get_recomb_pos(arg)
        print "muts", len(muts)
        print "recomb", len(arglib.get_recomb_pos(arg))

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # remove chrom
        keep = ["n%d" % i for i in range(k-1)]
        arglib.subarg_by_leaf_names(arg, keep)

        model = arghmm.ArgHmm(arg, seqs, new_name="n%d" % (k-1), times=times,
                              rho=rho, mu=mu)
        print "states", len(model.states[0])

        util.tic("C")
        probs1 = list(arghmm.get_posterior_probs(model, length, verbose=True))
        util.toc()

        util.tic("python")
        probs2 = list(hmm.get_posterior_probs(model, length, verbose=True))
        util.toc()

        print "probs1"
        pc(probs1)

        print "probs2"
        pc(probs2)


        for col1, col2 in izip(probs1, probs2):
            for a, b in izip(col1, col2):
                fequal(a, b)
Example #16
0
    def test_post3(self):

        k = 3
        n = 1e4
        rho = 1.5e-8 * 3
        mu = 2.5e-8 * 100
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        arg.prune()
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        thread = list(arghmm.iter_chrom_thread(arg, arg["n2"], by_block=False))
        p = plot(cget(thread, 1), style="lines", ymin=0)

        # remove chrom
        keep = ["n0", "n1"]
        arglib.subarg_by_leaf_names(arg, keep)
        arg.set_ancestral()
        arg.prune()

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name="n2",
                              times=times,
                              rho=rho,
                              mu=mu)
        print "states", len(model.states[0])
        print "muts", len(muts)
        print "recomb", len(model.recomb_pos) - 2, model.recomb_pos[1:-1]

        p.plot(model.recomb_pos, [1000] * len(model.recomb_pos),
               style="points")

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        high = list(arghmm.iter_posterior_times(model, probs, .95))
        low = list(arghmm.iter_posterior_times(model, probs, .05))
        p.plot(high, style="lines")
        p.plot(low, style="lines")

        pause()
Example #17
0
    def test_post3(self):

        k = 3
        n = 1e4
        rho = 1.5e-8 * 3
        mu = 2.5e-8 * 100
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        arg.prune()
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)


        tree = arg.get_marginal_tree(0)
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        thread = list(arghmm.iter_chrom_thread(arg, arg["n2"], by_block=False))
        p = plot(cget(thread, 1), style="lines", ymin=0)

        # remove chrom
        keep = ["n0", "n1"]
        arglib.subarg_by_leaf_names(arg, keep)
        arg.set_ancestral()
        arg.prune()


        model = arghmm.ArgHmm(arg, seqs, new_name="n2", times=times,
                              rho=rho, mu=mu)
        print "states", len(model.states[0])
        print "muts", len(muts)
        print "recomb", len(model.recomb_pos) - 2, model.recomb_pos[1:-1]


        p.plot(model.recomb_pos, [1000] * len(model.recomb_pos),
               style="points")

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        high = list(arghmm.iter_posterior_times(model, probs, .95))
        low = list(arghmm.iter_posterior_times(model, probs, .05))
        p.plot(high, style="lines")
        p.plot(low, style="lines")

        pause()
Example #18
0
    def test_post2(self):

        k = 2
        n = 1e4
        rho = 1.5e-8 * 10
        mu = 2.5e-8 * 10
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        print "muts", len(muts)

        times = arghmm.get_time_points()
        arghmm.discretize_arg(arg, times)

        thread = list(arghmm.iter_chrom_thread(arg, arg["n1"], by_block=False))
        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)
        p = plot(cget(thread, 1), style="lines", ymin=0)

        #alignlib.print_align(seqs)

        # remove chrom
        keep = ["n0"]
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name="n1",
                              times=times,
                              rho=rho,
                              mu=mu)
        print "states", len(model.states[0])

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        high = list(arghmm.iter_posterior_times(model, probs, .95))
        low = list(arghmm.iter_posterior_times(model, probs, .05))
        p.plot(high, style="lines")
        p.plot(low, style="lines")

        pause()
Example #19
0
    def test_thread(self):
        """
        Test thread retrieval
        """

        k = 10
        n = 1e4
        rho = 1.5e-8 * 10
        mu = 2.5e-8 * 100
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        for (block, tree), threadi in izip(
            arglib.iter_tree_tracks(arg),
            arghmm.iter_chrom_thread(arg, arg["n9"], by_block=True)):
            print block
            print threadi
            treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)
Example #20
0
    def test_thread(self):
        """
        Test thread retrieval
        """

        k = 10
        n = 1e4
        rho = 1.5e-8 * 10
        mu = 2.5e-8 * 100
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        for (block, tree), threadi in izip(
                arglib.iter_tree_tracks(arg),
                arghmm.iter_chrom_thread(arg, arg["n9"], by_block=True)):
            print block
            print threadi
            treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)
Example #21
0
    def test_post2(self):

        k = 2
        n = 1e4
        rho = 1.5e-8 * 10
        mu = 2.5e-8 * 10
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        print "muts", len(muts)

        times = arghmm.get_time_points()
        arghmm.discretize_arg(arg, times)

        thread = list(arghmm.iter_chrom_thread(arg, arg["n1"], by_block=False))
        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)
        p = plot(cget(thread, 1), style="lines", ymin=0)

        #alignlib.print_align(seqs)

        # remove chrom
        keep = ["n0"]
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg, seqs, new_name="n1", times=times,
                              rho=rho, mu=mu)
        print "states", len(model.states[0])

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        high = list(arghmm.iter_posterior_times(model, probs, .95))
        low = list(arghmm.iter_posterior_times(model, probs, .05))
        p.plot(high, style="lines")
        p.plot(low, style="lines")

        pause()
Example #22
0
    def test_backward(self):
        """
        Run backward algorithm
        """

        k = 3
        n = 1e4
        rho = 1.5e-8 * 100
        mu = 2.5e-8 * 100
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # remove chrom
        new_name = "n%d" % (k - 1)
        arg = arghmm.remove_arg_thread(arg, new_name)

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name=new_name,
                              times=times,
                              rho=rho,
                              mu=mu)
        print "states", len(model.states[0])
        print "recomb", model.recomb_pos
        print "muts", len(muts)

        probs = hmm.backward_algorithm(model, length, verbose=True)

        for pcol in probs:
            p = sum(map(exp, pcol))
            print p, " ".join("%.3f" % f for f in map(exp, pcol))
Example #23
0
    def test_emit_argmax(self):
        """
        Calculate emission probabilities
        """

        k = 10
        n = 1e4
        rho = 0.0
        mu = 2.5e-8 * 100
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(10)
        arghmm.discretize_arg(arg, times)

        new_name = "n%d" % (k - 1)
        thread = list(arghmm.iter_chrom_thread(arg, arg[new_name]))
        arg = arghmm.remove_arg_thread(arg, new_name)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times)

        nstates = model.get_num_states(1)
        probs = [0.0 for j in xrange(nstates)]
        for i in xrange(1, length):
            if i % 100 == 0:
                print i
            for j in xrange(nstates):
                probs[j] += model.prob_emission(i, j)
        print

        # is the maximum likelihood emission matching truth
        data = sorted(zip(probs, model.states[0]), reverse=True)
        pc(data[:20])

        state = (thread[0][0], times.index(thread[0][1]))

        print data[0][1], state
        assert data[0][1] == state
Example #24
0
    def test_emit_argmax(self):
        """
        Calculate emission probabilities
        """

        k = 10
        n = 1e4
        rho = 0.0
        mu = 2.5e-8 * 100
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(10)
        arghmm.discretize_arg(arg, times)

        new_name = "n%d" % (k-1)
        thread = list(arghmm.iter_chrom_thread(arg, arg[new_name]))
        arg = arghmm.remove_arg_thread(arg, new_name)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times)

        nstates = model.get_num_states(1)
        probs = [0.0 for j in xrange(nstates)]
        for i in xrange(1, length):
            if i % 100 == 0:
                print i
            for j in xrange(nstates):
                probs[j] += model.prob_emission(i, j)
        print

        # is the maximum likelihood emission matching truth
        data = sorted(zip(probs, model.states[0]), reverse=True)
        pc(data[:20])

        state = (thread[0][0], times.index(thread[0][1]))

        print data[0][1], state
        assert data[0][1] == state
Example #25
0
    def test_post(self):

        k = 6
        n = 1e4
        rho = 1.5e-8 * 10
        mu = 2.5e-8 * 10
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)
        print "muts", len(muts)
        print "recombs", len(arglib.get_recomb_pos(arg))

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # remove chrom
        new_name = "n%d" % (k - 1)
        keep = set(arg.leaf_names()) - set([new_name])
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name=new_name,
                              times=times,
                              rho=rho,
                              mu=mu)
        print "states", len(model.states[0])

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        for pcol in probs:
            p = sum(map(exp, pcol))
            print p, " ".join("%.3f" % f for f in map(exp, pcol))
            fequal(p, 1.0, rel=1e-2)
Example #26
0
def test_forward():

    k = 4
    n = 1e4
    rho = 1.5e-8 * 20
    mu = 2.5e-8 * 20
    length = int(100e3 / 20)
    times = argweaver.get_time_points(ntimes=100)

    arg = arglib.sample_arg_smc(k, 2 * n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    print "muts", len(muts)
    print "recomb", len(arglib.get_recomb_pos(arg))

    argweaver.discretize_arg(arg, times)

    # remove chrom
    new_name = "n%d" % (k - 1)
    arg = argweaver.remove_arg_thread(arg, new_name)

    carg = argweaverc.arg2ctrees(arg, times)

    util.tic("C fast")
    probs1 = argweaverc.argweaver_forward_algorithm(carg, seqs, times=times)
    util.toc()

    util.tic("C slow")
    probs2 = argweaverc.argweaver_forward_algorithm(carg,
                                                    seqs,
                                                    times=times,
                                                    slow=True)
    util.toc()

    for i, (col1, col2) in enumerate(izip(probs1, probs2)):
        for a, b in izip(col1, col2):
            fequal(a, b, rel=.0001)
Example #27
0
    def test_pars_seq(self):
        """
        Test parsimony ancestral sequence inference
        """

        k = 10
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8 * 100
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        pos = int(muts[0][2])
        tree = arg.get_marginal_tree(pos)

        print "pos =", pos
        treelib.draw_tree_names(tree.get_tree(), scale=4e-4, minlen=5)

        arglib.remove_single_lineages(tree)
        ancestral = arghmm.emit.parsimony_ancestral_seq(tree, seqs, pos)
        util.print_dict(ancestral)
Example #28
0
    def test_pars_seq(self):
        """
        Test parsimony ancestral sequence inference
        """

        k = 10
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8 * 100
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        pos = int(muts[0][2])
        tree = arg.get_marginal_tree(pos)

        print "pos =", pos
        treelib.draw_tree_names(tree.get_tree(), scale=4e-4, minlen=5)

        arglib.remove_single_lineages(tree)
        ancestral = arghmm.emit.parsimony_ancestral_seq(tree, seqs, pos)
        util.print_dict(ancestral)
Example #29
0
def test_forward():

    k = 4
    n = 1e4
    rho = 1.5e-8 * 20
    mu = 2.5e-8 * 20
    length = int(100e3 / 20)
    times = argweaver.get_time_points(ntimes=100)

    arg = arglib.sample_arg_smc(k, 2*n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    print "muts", len(muts)
    print "recomb", len(arglib.get_recombs(arg))

    argweaver.discretize_arg(arg, times)

    # remove chrom
    new_name = "n%d" % (k - 1)
    arg = argweaver.remove_arg_thread(arg, new_name)

    carg = argweaverc.arg2ctrees(arg, times)

    util.tic("C fast")
    probs1 = argweaverc.argweaver_forward_algorithm(carg, seqs, times=times)
    util.toc()

    util.tic("C slow")
    probs2 = argweaverc.argweaver_forward_algorithm(carg, seqs, times=times,
                                                    slow=True)
    util.toc()

    for i, (col1, col2) in enumerate(izip(probs1, probs2)):
        for a, b in izip(col1, col2):
            fequal(a, b, rel=.0001)
Example #30
0
    def test_backward(self):
        """
        Run backward algorithm
        """

        k = 3
        n = 1e4
        rho = 1.5e-8 * 100
        mu = 2.5e-8 * 100
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # remove chrom
        new_name = "n%d" % (k-1)
        arg = arghmm.remove_arg_thread(arg, new_name)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                              rho=rho, mu=mu)
        print "states", len(model.states[0])
        print "recomb", model.recomb_pos
        print "muts", len(muts)

        probs = hmm.backward_algorithm(model, length, verbose=True)

        for pcol in probs:
            p = sum(map(exp, pcol))
            print p, " ".join("%.3f" % f for f in map(exp, pcol))
Example #31
0
        move_layout(layout2, x=(i+1)*(nleaves+2))
        mapping = get_mapping(tree1, tree2, pos, times)
        draw_mapping(tree1, tree2, layout1, layout2, times, mapping)
        

#=============================================================================
if 0:
    # Disabled scratch code: the `if 0:` guard means nothing below runs.
    # When enabled, it samples a small ARG, saves it to tmp/a.arg,
    # discretizes it, and fetches the marginal tree just after the first
    # visible recombination.
    k = 10
    n = 1e4
    rho = 1.5e-8 * 100
    mu = 2.5e-8
    length = 1000
    
    arg = arglib.sample_arg(k, n, rho, start=0, end=length)
    arglib.write_arg("tmp/a.arg", arg)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    # commented-out alternatives, e.g. reloading the ARG written above
    #arg = arglib.read_arg("tmp/a.arg")
    #arg.set_ancestral()
    #find_recomb_coal(tree, last_tree, recomb_name=None, pos=None)

    times = arghmm.get_time_points(30, maxtime=60e3)
    arghmm.discretize_arg(arg, times)

    # get recombs: positions of the recombinations yielded by
    # arghmm.iter_visible_recombs
    recombs = list(x.pos for x in arghmm.iter_visible_recombs(arg))
    print "recomb", recombs

    # NOTE(review): recombs[0] raises IndexError if no visible
    # recombination was sampled. The tree is queried at pos - .5,
    # presumably to land between sites rather than on a boundary -- TODO
    # confirm against get_marginal_tree. `tree` is not used in this
    # visible fragment.
    pos = recombs[0] + 1
    tree = arg.get_marginal_tree(pos-.5)
Example #32
0
        move_layout(layout1, x=i * (nleaves + 2))
        move_layout(layout2, x=(i + 1) * (nleaves + 2))
        mapping = get_mapping(tree1, tree2, pos, times)
        draw_mapping(tree1, tree2, layout1, layout2, times, mapping)

#=============================================================================
if 0:
    # Disabled scratch code: the `if 0:` guard means nothing below runs.
    # When enabled, it samples a small ARG, saves it to tmp/a.arg,
    # discretizes it, and fetches the marginal tree just after the first
    # visible recombination.
    k = 10
    n = 1e4
    rho = 1.5e-8 * 100
    mu = 2.5e-8
    length = 1000

    arg = arglib.sample_arg(k, n, rho, start=0, end=length)
    arglib.write_arg("tmp/a.arg", arg)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    # commented-out alternatives, e.g. reloading the ARG written above
    #arg = arglib.read_arg("tmp/a.arg")
    #arg.set_ancestral()
    #find_recomb_coal(tree, last_tree, recomb_name=None, pos=None)

    times = arghmm.get_time_points(30, maxtime=60e3)
    arghmm.discretize_arg(arg, times)

    # get recombs: positions of the recombinations yielded by
    # arghmm.iter_visible_recombs
    recombs = list(x.pos for x in arghmm.iter_visible_recombs(arg))
    print "recomb", recombs

    # NOTE(review): recombs[0] raises IndexError if no visible
    # recombination was sampled. The tree is queried at pos - .5,
    # presumably to land between sites rather than on a boundary -- TODO
    # confirm against get_marginal_tree. `tree` is not used in this
    # visible fragment.
    pos = recombs[0] + 1
    tree = arg.get_marginal_tree(pos - .5)
Example #33
0
    def test_trans2(self):
        """
        Calculate transition probabilities for k=2

        Only calculate a single matrix
        """

        k = 2
        n = 1e4
        rho = 1.5e-8 * 20
        mu = 2.5e-8 * 20
        length = 1000
        times = arghmm.get_time_points(ntimes=5, maxtime=200000)

        arg = arglib.sample_arg(k, 2*n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        arghmm.discretize_arg(arg, times)
        print "recomb", arglib.get_recomb_pos(arg)

        new_name = "n%d" % (k-1)
        arg = arghmm.make_trunk_arg(0, length, "n0")
        model = arghmm.ArgHmm(arg, seqs, new_name=new_name,
                              popsize=n, rho=rho, mu=mu,
                              times=times)

        pos = 10
        tree = arg.get_marginal_tree(pos)
        model.check_local_tree(pos, force=True)
        mat = arghmm.calc_transition_probs(
            tree, model.states[pos], model.nlineages,
            model.times, model.time_steps, model.popsizes, rho)

        states = model.states[pos]
        nstates = len(states)

        def coal(j):
            return 1.0 - exp(-model.time_steps[j]/(2.0 * n))

        def recoal2(k, j):
            p = coal(j)
            for m in range(k, j):
                p *= 1.0 - coal(m)
            return p

        def recoal(k, j):
            if j == nstates-1:
                return exp(- sum(model.time_steps[m] / (2.0 * n)
                              for m in range(k, j)))
            else:
                return ((1.0 - exp(-model.time_steps[j]/(2.0 * n))) *
                        exp(- sum(model.time_steps[m] / (2.0 * n)
                                  for m in range(k, j))))

        def isrecomb(i):
            return 1.0 - exp(-max(rho * 2.0 * model.times[i], rho))

        def recomb(i, k):
            treelen = 2*model.times[i] + model.time_steps[i]
            if k < i:
                return 2.0 * model.time_steps[k] / treelen / 2.0
            else:
                return model.time_steps[k] / treelen / 2.0

        def trans(i, j):
            a = states[i][1]
            b = states[j][1]

            p = sum(recoal(k, b) * recomb(a, k)
                    for k in range(0, min(a, b)+1))
            p += sum(recoal(k, b) * recomb(a, k)
                     for k in range(0, min(a, b)+1))
            p *= isrecomb(a)
            if i == j:
                p += 1.0 - isrecomb(a)
            return p


        for i in range(len(states)):
            for j in range(len(states)):
                print isrecomb(states[i][1])
                print states[i], states[j], mat[i][j], log(trans(i, j))
                fequal(mat[i][j], log(trans(i, j)))


            # recombs add up to 1
            fequal(sum(recomb(i, k) for k in range(i+1)), 0.5)

            # recoal add up to 1
            fequal(sum(recoal(i, j) for j in range(i, nstates)), 1.0)

            # recomb * recoal add up to .5
            fequal(sum(sum(recoal(k, j) * recomb(i, k)
                           for k in range(0, min(i, j)+1))
                       for j in range(0, nstates)), 0.5)

            fequal(sum(trans(i, j) for j in range(len(states))), 1.0)
Example #34
0
    def test_determ(self):

        k = 8
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 100000

        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)
        times = arghmm.get_time_points(maxtime=50000, ntimes=20)
        arghmm.discretize_arg(arg, times)

        new_name = "n%d" % (k - 1)
        thread = list(arghmm.iter_chrom_thread(arg, arg[new_name],
                                               by_block=False))
        thread_clades = list(arghmm.iter_chrom_thread(
            arg, arg[new_name], by_block=True, use_clades=True))

        # remove chrom
        keep = ["n%d" % i for i in range(k-1)]
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                              rho=rho, mu=mu)


        for i, rpos in enumerate(model.recomb_pos[1:-1]):
            pos = rpos + 1

            model.check_local_tree(pos, force=True)

            #recomb = arghmm.find_tree_next_recomb(arg, pos - 1)
            tree = arg.get_marginal_tree(pos-.5)
            last_tree = arg.get_marginal_tree(pos-1-.5)
            states1 = model.states[pos-1]
            states2 = model.states[pos]
            (recomb_branch, recomb_time), (coal_branch, coal_time) = \
                arghmm.find_recomb_coal(tree, last_tree, pos=rpos)
            recomb_time = times.index(recomb_time)
            coal_time = times.index(coal_time)

            determ = arghmm.get_deterministic_transitions(
                states1, states2, times,
                tree, last_tree,
                recomb_branch, recomb_time,
                coal_branch, coal_time)


            leaves1, time1, block1 = thread_clades[i]
            leaves2, time2, block2 = thread_clades[i+1]
            if new_name in leaves1:
                leaves1.remove(new_name)
            if new_name in leaves2:
                leaves2.remove(new_name)
            node1 = arghmm.arg_lca(arg, leaves1, None, pos-1).name
            node2 = arghmm.arg_lca(arg, leaves2, None, pos).name

            state1 = (node1, times.index(time1))
            state2 = (node2, times.index(time2))
            print pos, state1, state2
            try:
                statei1 = states1.index(state1)
                statei2 = states2.index(state2)
            except:
                print "states1", states1
                print "states2", states2
                raise

            statei3 = determ[statei1]
            print "  ", statei1, statei2, statei3, states2[statei3]
            if statei2 != statei3 and statei3 != -1:
                tree = tree.get_tree()
                treelib.remove_single_children(tree)
                last_tree = last_tree.get_tree()
                treelib.remove_single_children(last_tree)

                print "block1", block1
                print "block2", block2
                print "r=", (recomb_branch, recomb_time)
                print "c=", (coal_branch, coal_time)

                print "tree"
                treelib.draw_tree_names(tree, minlen=8, maxlen=8)

                print "last_tree"
                treelib.draw_tree_names(last_tree, minlen=8, maxlen=8)
                assert False
Example #35
0
    def test_determ(self):

        k = 8
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 100000

        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)
        times = arghmm.get_time_points(maxtime=50000, ntimes=20)
        arghmm.discretize_arg(arg, times)

        new_name = "n%d" % (k - 1)
        thread = list(
            arghmm.iter_chrom_thread(arg, arg[new_name], by_block=False))
        thread_clades = list(
            arghmm.iter_chrom_thread(arg,
                                     arg[new_name],
                                     by_block=True,
                                     use_clades=True))

        # remove chrom
        keep = ["n%d" % i for i in range(k - 1)]
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name=new_name,
                              times=times,
                              rho=rho,
                              mu=mu)

        for i, rpos in enumerate(model.recomb_pos[1:-1]):
            pos = rpos + 1

            model.check_local_tree(pos, force=True)

            #recomb = arghmm.find_tree_next_recomb(arg, pos - 1)
            tree = arg.get_marginal_tree(pos - .5)
            last_tree = arg.get_marginal_tree(pos - 1 - .5)
            states1 = model.states[pos - 1]
            states2 = model.states[pos]
            (recomb_branch, recomb_time), (coal_branch, coal_time) = \
                arghmm.find_recomb_coal(tree, last_tree, pos=rpos)
            recomb_time = times.index(recomb_time)
            coal_time = times.index(coal_time)

            determ = arghmm.get_deterministic_transitions(
                states1, states2, times, tree, last_tree, recomb_branch,
                recomb_time, coal_branch, coal_time)

            leaves1, time1, block1 = thread_clades[i]
            leaves2, time2, block2 = thread_clades[i + 1]
            if new_name in leaves1:
                leaves1.remove(new_name)
            if new_name in leaves2:
                leaves2.remove(new_name)
            node1 = arghmm.arg_lca(arg, leaves1, None, pos - 1).name
            node2 = arghmm.arg_lca(arg, leaves2, None, pos).name

            state1 = (node1, times.index(time1))
            state2 = (node2, times.index(time2))
            print pos, state1, state2
            try:
                statei1 = states1.index(state1)
                statei2 = states2.index(state2)
            except:
                print "states1", states1
                print "states2", states2
                raise

            statei3 = determ[statei1]
            print "  ", statei1, statei2, statei3, states2[statei3]
            if statei2 != statei3 and statei3 != -1:
                tree = tree.get_tree()
                treelib.remove_single_children(tree)
                last_tree = last_tree.get_tree()
                treelib.remove_single_children(last_tree)

                print "block1", block1
                print "block2", block2
                print "r=", (recomb_branch, recomb_time)
                print "c=", (coal_branch, coal_time)

                print "tree"
                treelib.draw_tree_names(tree, minlen=8, maxlen=8)

                print "last_tree"
                treelib.draw_tree_names(last_tree, minlen=8, maxlen=8)
                assert False
Example #36
0
    def test_trans2(self):
        """
        Calculate transition probabilities for k=2

        Only calculate a single matrix
        """

        k = 2
        n = 1e4
        rho = 1.5e-8 * 20
        mu = 2.5e-8 * 20
        length = 1000
        times = arghmm.get_time_points(ntimes=5, maxtime=200000)

        arg = arglib.sample_arg(k, 2 * n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        arghmm.discretize_arg(arg, times)
        print "recomb", arglib.get_recomb_pos(arg)

        new_name = "n%d" % (k - 1)
        arg = arghmm.make_trunk_arg(0, length, "n0")
        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name=new_name,
                              popsize=n,
                              rho=rho,
                              mu=mu,
                              times=times)

        pos = 10
        tree = arg.get_marginal_tree(pos)
        model.check_local_tree(pos, force=True)
        mat = arghmm.calc_transition_probs(tree, model.states[pos],
                                           model.nlineages, model.times,
                                           model.time_steps, model.popsizes,
                                           rho)

        states = model.states[pos]
        nstates = len(states)

        def coal(j):
            return 1.0 - exp(-model.time_steps[j] / (2.0 * n))

        def recoal2(k, j):
            p = coal(j)
            for m in range(k, j):
                p *= 1.0 - coal(m)
            return p

        def recoal(k, j):
            if j == nstates - 1:
                return exp(-sum(model.time_steps[m] / (2.0 * n)
                                for m in range(k, j)))
            else:
                return ((1.0 - exp(-model.time_steps[j] / (2.0 * n))) *
                        exp(-sum(model.time_steps[m] / (2.0 * n)
                                 for m in range(k, j))))

        def isrecomb(i):
            return 1.0 - exp(-max(rho * 2.0 * model.times[i], rho))

        def recomb(i, k):
            treelen = 2 * model.times[i] + model.time_steps[i]
            if k < i:
                return 2.0 * model.time_steps[k] / treelen / 2.0
            else:
                return model.time_steps[k] / treelen / 2.0

        def trans(i, j):
            a = states[i][1]
            b = states[j][1]

            p = sum(
                recoal(k, b) * recomb(a, k) for k in range(0,
                                                           min(a, b) + 1))
            p += sum(
                recoal(k, b) * recomb(a, k) for k in range(0,
                                                           min(a, b) + 1))
            p *= isrecomb(a)
            if i == j:
                p += 1.0 - isrecomb(a)
            return p

        for i in range(len(states)):
            for j in range(len(states)):
                print isrecomb(states[i][1])
                print states[i], states[j], mat[i][j], log(trans(i, j))
                fequal(mat[i][j], log(trans(i, j)))

            # recombs add up to 1
            fequal(sum(recomb(i, k) for k in range(i + 1)), 0.5)

            # recoal add up to 1
            fequal(sum(recoal(i, j) for j in range(i, nstates)), 1.0)

            # recomb * recoal add up to .5
            fequal(
                sum(
                    sum(
                        recoal(k, j) * recomb(i, k)
                        for k in range(0,
                                       min(i, j) + 1))
                    for j in range(0, nstates)), 0.5)

            fequal(sum(trans(i, j) for j in range(len(states))), 1.0)