コード例 #1
0
ファイル: test_basic.py プロジェクト: swamidass/argweaver
    def test_post_c(self):

        k = 3
        n = 1e4
        rho = 1.5e-8 * 30
        mu = 2.5e-8 * 100
        length = 100
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        arg.prune()
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        print arglib.get_recomb_pos(arg)
        print "muts", len(muts)
        print "recomb", len(arglib.get_recomb_pos(arg))

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # remove chrom
        keep = ["n%d" % i for i in range(k - 1)]
        arglib.subarg_by_leaf_names(arg, keep)

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name="n%d" % (k - 1),
                              times=times,
                              rho=rho,
                              mu=mu)
        print "states", len(model.states[0])

        util.tic("C")
        probs1 = list(arghmm.get_posterior_probs(model, length, verbose=True))
        util.toc()

        util.tic("python")
        probs2 = list(hmm.get_posterior_probs(model, length, verbose=True))
        util.toc()

        print "probs1"
        pc(probs1)

        print "probs2"
        pc(probs2)

        for col1, col2 in izip(probs1, probs2):
            for a, b in izip(col1, col2):
                fequal(a, b)
コード例 #2
0
ファイル: test_basic.py プロジェクト: bredelings/argweaver
    def test_post_c(self):

        k = 3
        n = 1e4
        rho = 1.5e-8 * 30
        mu = 2.5e-8 * 100
        length = 100
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        arg.prune()
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        print arglib.get_recomb_pos(arg)
        print "muts", len(muts)
        print "recomb", len(arglib.get_recomb_pos(arg))

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        print tree.root.age
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # remove chrom
        keep = ["n%d" % i for i in range(k-1)]
        arglib.subarg_by_leaf_names(arg, keep)

        model = arghmm.ArgHmm(arg, seqs, new_name="n%d" % (k-1), times=times,
                              rho=rho, mu=mu)
        print "states", len(model.states[0])

        util.tic("C")
        probs1 = list(arghmm.get_posterior_probs(model, length, verbose=True))
        util.toc()

        util.tic("python")
        probs2 = list(hmm.get_posterior_probs(model, length, verbose=True))
        util.toc()

        print "probs1"
        pc(probs1)

        print "probs2"
        pc(probs2)


        for col1, col2 in izip(probs1, probs2):
            for a, b in izip(col1, col2):
                fequal(a, b)
コード例 #3
0
ファイル: test_basic.py プロジェクト: swamidass/argweaver
    def test_trans_single(self):
        """
        Calculate transition probabilities

        Only calculate a single matrix
        """

        k = 4
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(10)
        arghmm.discretize_arg(arg, times)
        print "recomb", arglib.get_recomb_pos(arg)

        new_name = "n%d" % (k - 1)
        arg = arghmm.remove_arg_thread(arg, new_name)
        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times)

        pos = 10
        tree = arg.get_marginal_tree(pos)
        mat = arghmm.calc_transition_probs(tree, model.states[pos],
                                           model.nlineages, model.times,
                                           model.time_steps, model.popsizes,
                                           rho)
        print model.states[pos]
        pc(mat)

        for row in mat:
            print sum(map(exp, row))
コード例 #4
0
ファイル: test_basic.py プロジェクト: bredelings/argweaver
    def test_post(self):

        k = 6
        n = 1e4
        rho = 1.5e-8 * 10
        mu = 2.5e-8 * 10
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)
        print "muts", len(muts)
        print "recombs", len(arglib.get_recomb_pos(arg))

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # remove chrom
        new_name = "n%d" % (k - 1)
        keep = set(arg.leaf_names()) - set([new_name])
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg, seqs, new_name=new_name,
                              times=times, rho=rho, mu=mu)
        print "states", len(model.states[0])

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        for pcol in probs:
            p = sum(map(exp, pcol))
            print p, " ".join("%.3f" % f for f in map(exp, pcol))
            fequal(p, 1.0, rel=1e-2)
コード例 #5
0
ファイル: test_basic.py プロジェクト: bredelings/argweaver
    def test_trans_single(self):
        """
        Calculate transition probabilities

        Only calculate a single matrix
        """

        k = 4
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        times = arghmm.get_time_points(10)
        arghmm.discretize_arg(arg, times)
        print "recomb", arglib.get_recomb_pos(arg)

        new_name = "n%d" % (k-1)
        arg = arghmm.remove_arg_thread(arg, new_name)
        model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times)

        pos = 10
        tree = arg.get_marginal_tree(pos)
        mat = arghmm.calc_transition_probs(
            tree, model.states[pos], model.nlineages,
            model.times, model.time_steps, model.popsizes, rho)
        print model.states[pos]
        pc(mat)

        for row in mat:
            print sum(map(exp, row))
コード例 #6
0
ファイル: test_basic.py プロジェクト: swamidass/argweaver
    def test_post(self):

        k = 6
        n = 1e4
        rho = 1.5e-8 * 10
        mu = 2.5e-8 * 10
        length = 10000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)
        print "muts", len(muts)
        print "recombs", len(arglib.get_recomb_pos(arg))

        times = arghmm.get_time_points(ntimes=10)
        arghmm.discretize_arg(arg, times)

        tree = arg.get_marginal_tree(0)
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

        # remove chrom
        new_name = "n%d" % (k - 1)
        keep = set(arg.leaf_names()) - set([new_name])
        arglib.subarg_by_leaf_names(arg, keep)
        arg = arglib.smcify_arg(arg)

        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name=new_name,
                              times=times,
                              rho=rho,
                              mu=mu)
        print "states", len(model.states[0])

        probs = arghmm.get_posterior_probs(model, length, verbose=True)

        for pcol in probs:
            p = sum(map(exp, pcol))
            print p, " ".join("%.3f" % f for f in map(exp, pcol))
            fequal(p, 1.0, rel=1e-2)
コード例 #7
0
def test_forward():

    k = 4
    n = 1e4
    rho = 1.5e-8 * 20
    mu = 2.5e-8 * 20
    length = int(100e3 / 20)
    times = argweaver.get_time_points(ntimes=100)

    arg = arglib.sample_arg_smc(k, 2 * n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    print "muts", len(muts)
    print "recomb", len(arglib.get_recomb_pos(arg))

    argweaver.discretize_arg(arg, times)

    # remove chrom
    new_name = "n%d" % (k - 1)
    arg = argweaver.remove_arg_thread(arg, new_name)

    carg = argweaverc.arg2ctrees(arg, times)

    util.tic("C fast")
    probs1 = argweaverc.argweaver_forward_algorithm(carg, seqs, times=times)
    util.toc()

    util.tic("C slow")
    probs2 = argweaverc.argweaver_forward_algorithm(carg,
                                                    seqs,
                                                    times=times,
                                                    slow=True)
    util.toc()

    for i, (col1, col2) in enumerate(izip(probs1, probs2)):
        for a, b in izip(col1, col2):
            fequal(a, b, rel=.0001)
コード例 #8
0
ファイル: test_basic.py プロジェクト: swamidass/argweaver
    def test_recomb(self):
        """
        Investigate the fact that some recombinations are not visible
        """

        k = 3
        n = 1e4
        rho = 1.5e-8 * 20
        mu = 2.5e-8
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)

        times = arghmm.get_time_points(10)
        arghmm.discretize_arg(arg, times)
        arg.set_ancestral()
        arg.prune()

        recombs = arglib.get_recomb_pos(arg)

        # find recombs by walking
        recombs2 = []
        i = 0
        while True:
            tree = arg.get_marginal_tree(i - .5)
            recomb = arghmm.find_tree_next_recomb(tree, i + 1, tree=True)
            if recomb:
                recombs2.append(recomb.pos)
                i = recomb.pos
            else:
                break

        # these are suppose to differ because some recombination occur
        # in the hole of ancestral sequence intervals
        print recombs
        print recombs2

        arglib.write_arg("tmp/b.arg", arg)
コード例 #9
0
ファイル: test_basic.py プロジェクト: bredelings/argweaver
    def test_recomb(self):
        """
        Investigate the fact that some recombinations are not visible
        """

        k = 3
        n = 1e4
        rho = 1.5e-8 * 20
        mu = 2.5e-8
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)

        times = arghmm.get_time_points(10)
        arghmm.discretize_arg(arg, times)
        arg.set_ancestral()
        arg.prune()

        recombs = arglib.get_recomb_pos(arg)

        # find recombs by walking
        recombs2 = []
        i = 0
        while True:
            tree = arg.get_marginal_tree(i-.5)
            recomb = arghmm.find_tree_next_recomb(tree, i+1, tree=True)
            if recomb:
                recombs2.append(recomb.pos)
                i = recomb.pos
            else:
                break

        # these are suppose to differ because some recombination occur
        # in the hole of ancestral sequence intervals
        print recombs
        print recombs2

        arglib.write_arg("tmp/b.arg", arg)
コード例 #10
0
ファイル: test_basic.py プロジェクト: swamidass/argweaver
    def test_remove_thread(self):
        """
        Remove a leaf from an ARG
        """

        k = 3
        n = 1e4
        rho = 1.5e-8 * 20
        mu = 2.5e-8
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)

        times = arghmm.get_time_points(10)
        arghmm.discretize_arg(arg, times)
        chrom = "n%d" % (k - 1)
        arg = arghmm.remove_arg_thread(arg, chrom)

        recomb = arglib.get_recomb_pos(arg)
        print "recomb", recomb

        tree = arg.get_marginal_tree(-.5)

        draw_tree_names(tree.get_tree(), minlen=5, maxlen=5)
        print sorted([(x.pos, x.event) for x in tree])
コード例 #11
0
ファイル: test_basic.py プロジェクト: bredelings/argweaver
    def test_remove_thread(self):
        """
        Remove a leaf from an ARG
        """

        k = 3
        n = 1e4
        rho = 1.5e-8 * 20
        mu = 2.5e-8
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)

        times = arghmm.get_time_points(10)
        arghmm.discretize_arg(arg, times)
        chrom = "n%d" % (k-1)
        arg = arghmm.remove_arg_thread(arg, chrom)

        recomb = arglib.get_recomb_pos(arg)
        print "recomb", recomb

        tree = arg.get_marginal_tree(-.5)

        draw_tree_names(tree.get_tree(), minlen=5, maxlen=5)
        print sorted([(x.pos, x.event) for x in tree])
コード例 #12
0
ファイル: test_hmm.py プロジェクト: jeffhsu3/argweaver
def test_forward():

    k = 4
    n = 1e4
    rho = 1.5e-8 * 20
    mu = 2.5e-8 * 20
    length = int(100e3 / 20)
    times = argweaver.get_time_points(ntimes=100)

    arg = arglib.sample_arg_smc(k, 2*n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    print "muts", len(muts)
    print "recomb", len(arglib.get_recomb_pos(arg))

    argweaver.discretize_arg(arg, times)

    # remove chrom
    new_name = "n%d" % (k - 1)
    arg = argweaver.remove_arg_thread(arg, new_name)

    carg = argweaverc.arg2ctrees(arg, times)

    util.tic("C fast")
    probs1 = argweaverc.argweaver_forward_algorithm(carg, seqs, times=times)
    util.toc()

    util.tic("C slow")
    probs2 = argweaverc.argweaver_forward_algorithm(carg, seqs, times=times,
                                                    slow=True)
    util.toc()

    for i, (col1, col2) in enumerate(izip(probs1, probs2)):
        for a, b in izip(col1, col2):
            fequal(a, b, rel=.0001)
コード例 #13
0
def test_trans_two():
    """
    Calculate transition probabilities for k=2

    Only calculate a single matrix
    """

    k = 2
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=5, maxtime=200000)
    time_steps = [times[i] - times[i - 1] for i in range(1, len(times))]
    time_steps.append(200000 * 10000.0)
    popsizes = [n] * len(times)

    arg = arglib.sample_arg(k, 2 * n, rho, start=0, end=length)

    argweaver.discretize_arg(arg, times)
    print "recomb", arglib.get_recomb_pos(arg)

    arg = argweaver.make_trunk_arg(0, length, "n0")

    pos = 10
    tree = arg.get_marginal_tree(pos)
    nlineages = argweaver.get_nlineages_recomb_coal(tree, times)
    states = list(argweaver.iter_coal_states(tree, times))
    mat = argweaver.calc_transition_probs(tree, states, nlineages, times,
                                          time_steps, popsizes, rho)

    nstates = len(states)

    def coal(j):
        return 1.0 - exp(-time_steps[j] / (2.0 * n))

    def recoal2(k, j):
        p = coal(j)
        for m in range(k, j):
            p *= 1.0 - coal(m)
        return p

    def recoal(k, j):
        if j == nstates - 1:
            return exp(-sum(time_steps[m] / (2.0 * n) for m in range(k, j)))
        else:
            return ((1.0 - exp(-time_steps[j] / (2.0 * n))) *
                    exp(-sum(time_steps[m] / (2.0 * n) for m in range(k, j))))

    def isrecomb(i):
        return 1.0 - exp(-max(rho * 2.0 * times[i], rho))

    def recomb(i, k):
        treelen = 2 * times[i] + time_steps[i]
        if k < i:
            return 2.0 * time_steps[k] / treelen / 2.0
        else:
            return time_steps[k] / treelen / 2.0

    def trans(i, j):
        a = states[i][1]
        b = states[j][1]

        p = sum(recoal(k, b) * recomb(a, k) for k in range(0, min(a, b) + 1))
        p += sum(recoal(k, b) * recomb(a, k) for k in range(0, min(a, b) + 1))
        p *= isrecomb(a)
        if i == j:
            p += 1.0 - isrecomb(a)
        return p

    for i in range(len(states)):
        for j in range(len(states)):
            print isrecomb(states[i][1])
            print states[i], states[j], mat[i][j], log(trans(i, j))
            fequal(mat[i][j], log(trans(i, j)))

        # recombs add up to 1
        fequal(sum(recomb(i, k) for k in range(i + 1)), 0.5)

        # recoal add up to 1
        fequal(sum(recoal(i, j) for j in range(i, nstates)), 1.0)

        # recomb * recoal add up to .5
        fequal(
            sum(
                sum(
                    recoal(k, j) * recomb(i, k)
                    for k in range(0,
                                   min(i, j) + 1))
                for j in range(0, nstates)), 0.5)

        fequal(sum(trans(i, j) for j in range(len(states))), 1.0)
コード例 #14
0
ファイル: test_basic.py プロジェクト: swamidass/argweaver
    def test_trans2(self):
        """
        Calculate transition probabilities for k=2

        Only calculate a single matrix
        """

        k = 2
        n = 1e4
        rho = 1.5e-8 * 20
        mu = 2.5e-8 * 20
        length = 1000
        times = arghmm.get_time_points(ntimes=5, maxtime=200000)

        arg = arglib.sample_arg(k, 2 * n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        arghmm.discretize_arg(arg, times)
        print "recomb", arglib.get_recomb_pos(arg)

        new_name = "n%d" % (k - 1)
        arg = arghmm.make_trunk_arg(0, length, "n0")
        model = arghmm.ArgHmm(arg,
                              seqs,
                              new_name=new_name,
                              popsize=n,
                              rho=rho,
                              mu=mu,
                              times=times)

        pos = 10
        tree = arg.get_marginal_tree(pos)
        model.check_local_tree(pos, force=True)
        mat = arghmm.calc_transition_probs(tree, model.states[pos],
                                           model.nlineages, model.times,
                                           model.time_steps, model.popsizes,
                                           rho)

        states = model.states[pos]
        nstates = len(states)

        def coal(j):
            return 1.0 - exp(-model.time_steps[j] / (2.0 * n))

        def recoal2(k, j):
            p = coal(j)
            for m in range(k, j):
                p *= 1.0 - coal(m)
            return p

        def recoal(k, j):
            if j == nstates - 1:
                return exp(-sum(model.time_steps[m] / (2.0 * n)
                                for m in range(k, j)))
            else:
                return ((1.0 - exp(-model.time_steps[j] / (2.0 * n))) *
                        exp(-sum(model.time_steps[m] / (2.0 * n)
                                 for m in range(k, j))))

        def isrecomb(i):
            return 1.0 - exp(-max(rho * 2.0 * model.times[i], rho))

        def recomb(i, k):
            treelen = 2 * model.times[i] + model.time_steps[i]
            if k < i:
                return 2.0 * model.time_steps[k] / treelen / 2.0
            else:
                return model.time_steps[k] / treelen / 2.0

        def trans(i, j):
            a = states[i][1]
            b = states[j][1]

            p = sum(
                recoal(k, b) * recomb(a, k) for k in range(0,
                                                           min(a, b) + 1))
            p += sum(
                recoal(k, b) * recomb(a, k) for k in range(0,
                                                           min(a, b) + 1))
            p *= isrecomb(a)
            if i == j:
                p += 1.0 - isrecomb(a)
            return p

        for i in range(len(states)):
            for j in range(len(states)):
                print isrecomb(states[i][1])
                print states[i], states[j], mat[i][j], log(trans(i, j))
                fequal(mat[i][j], log(trans(i, j)))

            # recombs add up to 1
            fequal(sum(recomb(i, k) for k in range(i + 1)), 0.5)

            # recoal add up to 1
            fequal(sum(recoal(i, j) for j in range(i, nstates)), 1.0)

            # recomb * recoal add up to .5
            fequal(
                sum(
                    sum(
                        recoal(k, j) * recomb(i, k)
                        for k in range(0,
                                       min(i, j) + 1))
                    for j in range(0, nstates)), 0.5)

            fequal(sum(trans(i, j) for j in range(len(states))), 1.0)
コード例 #15
0
ファイル: test_basic.py プロジェクト: bredelings/argweaver
    def test_trans2(self):
        """
        Calculate transition probabilities for k=2

        Only calculate a single matrix
        """

        k = 2
        n = 1e4
        rho = 1.5e-8 * 20
        mu = 2.5e-8 * 20
        length = 1000
        times = arghmm.get_time_points(ntimes=5, maxtime=200000)

        arg = arglib.sample_arg(k, 2*n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        arghmm.discretize_arg(arg, times)
        print "recomb", arglib.get_recomb_pos(arg)

        new_name = "n%d" % (k-1)
        arg = arghmm.make_trunk_arg(0, length, "n0")
        model = arghmm.ArgHmm(arg, seqs, new_name=new_name,
                              popsize=n, rho=rho, mu=mu,
                              times=times)

        pos = 10
        tree = arg.get_marginal_tree(pos)
        model.check_local_tree(pos, force=True)
        mat = arghmm.calc_transition_probs(
            tree, model.states[pos], model.nlineages,
            model.times, model.time_steps, model.popsizes, rho)

        states = model.states[pos]
        nstates = len(states)

        def coal(j):
            return 1.0 - exp(-model.time_steps[j]/(2.0 * n))

        def recoal2(k, j):
            p = coal(j)
            for m in range(k, j):
                p *= 1.0 - coal(m)
            return p

        def recoal(k, j):
            if j == nstates-1:
                return exp(- sum(model.time_steps[m] / (2.0 * n)
                              for m in range(k, j)))
            else:
                return ((1.0 - exp(-model.time_steps[j]/(2.0 * n))) *
                        exp(- sum(model.time_steps[m] / (2.0 * n)
                                  for m in range(k, j))))

        def isrecomb(i):
            return 1.0 - exp(-max(rho * 2.0 * model.times[i], rho))

        def recomb(i, k):
            treelen = 2*model.times[i] + model.time_steps[i]
            if k < i:
                return 2.0 * model.time_steps[k] / treelen / 2.0
            else:
                return model.time_steps[k] / treelen / 2.0

        def trans(i, j):
            a = states[i][1]
            b = states[j][1]

            p = sum(recoal(k, b) * recomb(a, k)
                    for k in range(0, min(a, b)+1))
            p += sum(recoal(k, b) * recomb(a, k)
                     for k in range(0, min(a, b)+1))
            p *= isrecomb(a)
            if i == j:
                p += 1.0 - isrecomb(a)
            return p


        for i in range(len(states)):
            for j in range(len(states)):
                print isrecomb(states[i][1])
                print states[i], states[j], mat[i][j], log(trans(i, j))
                fequal(mat[i][j], log(trans(i, j)))


            # recombs add up to 1
            fequal(sum(recomb(i, k) for k in range(i+1)), 0.5)

            # recoal add up to 1
            fequal(sum(recoal(i, j) for j in range(i, nstates)), 1.0)

            # recomb * recoal add up to .5
            fequal(sum(sum(recoal(k, j) * recomb(i, k)
                           for k in range(0, min(i, j)+1))
                       for j in range(0, nstates)), 0.5)

            fequal(sum(trans(i, j) for j in range(len(states))), 1.0)
コード例 #16
0
ファイル: test_hmm.py プロジェクト: jeffhsu3/argweaver
def test_trans_two():
    """
    Calculate transition probabilities for k=2

    Only calculate a single matrix
    """

    k = 2
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=5, maxtime=200000)
    time_steps = [times[i] - times[i-1]
                  for i in range(1, len(times))]
    time_steps.append(200000*10000.0)
    popsizes = [n] * len(times)

    arg = arglib.sample_arg(k, 2*n, rho, start=0, end=length)

    argweaver.discretize_arg(arg, times)
    print "recomb", arglib.get_recomb_pos(arg)

    arg = argweaver.make_trunk_arg(0, length, "n0")

    pos = 10
    tree = arg.get_marginal_tree(pos)
    nlineages = argweaver.get_nlineages_recomb_coal(tree, times)
    states = list(argweaver.iter_coal_states(tree, times))
    mat = argweaver.calc_transition_probs(
        tree, states, nlineages,
        times, time_steps, popsizes, rho)

    nstates = len(states)

    def coal(j):
        return 1.0 - exp(-time_steps[j]/(2.0 * n))

    def recoal2(k, j):
        p = coal(j)
        for m in range(k, j):
            p *= 1.0 - coal(m)
        return p

    def recoal(k, j):
        if j == nstates-1:
            return exp(- sum(time_steps[m] / (2.0 * n)
                             for m in range(k, j)))
        else:
            return ((1.0 - exp(-time_steps[j]/(2.0 * n))) *
                    exp(- sum(time_steps[m] / (2.0 * n)
                              for m in range(k, j))))

    def isrecomb(i):
        return 1.0 - exp(-max(rho * 2.0 * times[i], rho))

    def recomb(i, k):
        treelen = 2*times[i] + time_steps[i]
        if k < i:
            return 2.0 * time_steps[k] / treelen / 2.0
        else:
            return time_steps[k] / treelen / 2.0

    def trans(i, j):
        a = states[i][1]
        b = states[j][1]

        p = sum(recoal(k, b) * recomb(a, k)
                for k in range(0, min(a, b)+1))
        p += sum(recoal(k, b) * recomb(a, k)
                 for k in range(0, min(a, b)+1))
        p *= isrecomb(a)
        if i == j:
            p += 1.0 - isrecomb(a)
        return p

    for i in range(len(states)):
        for j in range(len(states)):
            print isrecomb(states[i][1])
            print states[i], states[j], mat[i][j], log(trans(i, j))
            fequal(mat[i][j], log(trans(i, j)))

        # recombs add up to 1
        fequal(sum(recomb(i, k) for k in range(i+1)), 0.5)

        # recoal add up to 1
        fequal(sum(recoal(i, j) for j in range(i, nstates)), 1.0)

        # recomb * recoal add up to .5
        fequal(sum(sum(recoal(k, j) * recomb(i, k)
                       for k in range(0, min(i, j)+1))
                   for j in range(0, nstates)), 0.5)

        fequal(sum(trans(i, j) for j in range(len(states))), 1.0)