コード例 #1
0
ファイル: test_hmm.py プロジェクト: swamidass/argweaver
def test_trans():
    """
    Calculate transition probabilities
    """
    create_data = False
    if create_data:
        make_clean_dir('test/data/test_trans')

    k = 8
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=10, maxtime=200000)
    popsizes = [n] * len(times)
    ntests = 40

    # generate test data
    if create_data:
        for i in range(ntests):
            arg = arglib.sample_arg(k, 2*n, rho, start=0, end=length)
            argweaver.discretize_arg(arg, times)
            arg.write('test/data/test_trans/%d.arg' % i)

    for i in range(ntests):
        print 'arg', i
        arg = arglib.read_arg('test/data/test_trans/%d.arg' % i)
        argweaver.discretize_arg(arg, times)
        pos = 10
        tree = arg.get_marginal_tree(pos)

        assert argweaverc.assert_transition_probs(tree, times, popsizes, rho)
コード例 #2
0
def test_trans_switch():
    """
    Check the switch transition matrix at the first recombination

    For every stored ARG, only the matrix at its first recombination
    breakpoint is checked.  Set `create_data` to True to regenerate the
    stored ARGs (each is re-sampled until it has a recombination).
    """
    create_data = False
    if create_data:
        make_clean_dir('test/data/test_trans_switch')

    # model parameters
    nchroms = 12
    popsize = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = len(times) * [popsize]
    ntests = 100

    # optionally (re)generate the stored ARGs
    if create_data:
        for idx in range(ntests):
            # keep sampling until the ARG contains a recombination
            while True:
                sampled = argweaver.sample_arg_dsmc(
                    nchroms, 2 * popsize, rho,
                    start=0, end=length, times=times)
                if any(node.event == "recomb" for node in sampled):
                    break
            sampled.write('test/data/test_trans_switch/%d.arg' % idx)

    for idx in range(ntests):
        print('arg', idx)
        stored = arglib.read_arg('test/data/test_trans_switch/%d.arg' % idx)
        argweaver.discretize_arg(stored, times)

        # look just to the left of the first recombination breakpoint
        recomb_positions = [
            node.pos for node in stored if node.event == "recomb"]
        left_of_break = recomb_positions[0] - .5
        tree = stored.get_marginal_tree(left_of_break)
        _, recomb_point, coal_point = next(
            arglib.iter_arg_sprs(stored, start=left_of_break))
        spr = (recomb_point, coal_point)

        if argweaverc.assert_transition_switch_probs(
                tree, spr, times, popsizes, rho):
            continue

        # on failure, draw the offending tree before failing the test
        drawn = tree.get_tree()
        treelib.remove_single_children(drawn)
        treelib.draw_tree_names(drawn, maxlen=5, minlen=5)
        assert False
コード例 #3
0
ファイル: test_hmm.py プロジェクト: bredelings/argweaver
def test_trans_switch():
    """
    Calculate transition probabilities for switch matrix

    Only calculate a single matrix
    """
    create_data = False
    if create_data:
        make_clean_dir('test/data/test_trans_switch')

    # model parameters
    k = 12
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [n] * len(times)
    ntests = 100

    # generate test data
    if create_data:
        for i in range(ntests):
            # Sample ARG with at least one recombination.
            while True:
                arg = argweaver.sample_arg_dsmc(
                    k, 2*n, rho, start=0, end=length, times=times)
                if any(x.event == "recomb" for x in arg):
                    break
            arg.write('test/data/test_trans_switch/%d.arg' % i)

    for i in range(ntests):
        print 'arg', i
        arg = arglib.read_arg('test/data/test_trans_switch/%d.arg' % i)
        argweaver.discretize_arg(arg, times)
        recombs = [x.pos for x in arg if x.event == "recomb"]
        pos = recombs[0]
        tree = arg.get_marginal_tree(pos-.5)
        rpos, r, c = arglib.iter_arg_sprs(arg, start=pos-.5).next()
        spr = (r, c)

        if not argweaverc.assert_transition_switch_probs(
                tree, spr, times, popsizes, rho):
            tree2 = tree.get_tree()
            treelib.remove_single_children(tree2)
            treelib.draw_tree_names(tree2, maxlen=5, minlen=5)
            assert False
コード例 #4
0
def test_trans():
    """
    Check transition probabilities on one freshly sampled ARG
    """
    # model parameters
    nchroms = 4
    popsize = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=4, maxtime=200000)
    popsizes = len(times) * [popsize]

    # sample an ARG and snap it onto the time grid
    sampled = arglib.sample_arg(
        nchroms, 2 * popsize, rho, start=0, end=length)
    argweaver.discretize_arg(sampled, times)

    # check the marginal tree at a fixed position
    tree = sampled.get_marginal_tree(10)
    ok = argweaverc.assert_transition_probs(tree, times, popsizes, rho)
    assert ok
コード例 #5
0
ファイル: test_hmm.py プロジェクト: jeffhsu3/argweaver
def test_trans():
    """
    Verify the transition probabilities of a single sampled tree
    """
    nseqs, ne = 4, 1e4
    rho = 1.5e-8 * 20
    seqlen = 1000
    times = argweaver.get_time_points(ntimes=4, maxtime=200000)
    popsizes = [ne] * len(times)

    # draw an ARG, then discretize its node ages
    new_arg = arglib.sample_arg(nseqs, 2 * ne, rho, start=0, end=seqlen)
    argweaver.discretize_arg(new_arg, times)

    site = 10
    local_tree = new_arg.get_marginal_tree(site)

    result = argweaverc.assert_transition_probs(
        local_tree, times, popsizes, rho)
    assert result
コード例 #6
0
def test_trans_internal():
    """
    Check transition probabilities for internal branch re-sampling

    A single matrix is checked, for the marginal tree of one sampled ARG.
    """
    # model parameters
    nchroms = 5
    popsize = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=5, maxtime=200000)
    popsizes = len(times) * [popsize]

    # sample an ARG and snap it onto the time grid
    sampled = arglib.sample_arg(
        nchroms, 2 * popsize, rho, start=0, end=length)
    argweaver.discretize_arg(sampled, times)

    tree = sampled.get_marginal_tree(10)
    ok = argweaverc.assert_transition_probs_internal(
        tree, times, popsizes, rho)
    assert ok
コード例 #7
0
ファイル: test_hmm.py プロジェクト: bredelings/argweaver
def test_trans_internal():
    """
    Verify the internal-branch re-sampling transition matrix

    Only one matrix is computed, using one marginal tree.
    """
    nseqs, ne = 5, 1e4
    rho = 1.5e-8 * 20
    seqlen = 1000
    times = argweaver.get_time_points(ntimes=5, maxtime=200000)
    popsizes = [ne] * len(times)

    # draw an ARG, then discretize its node ages
    new_arg = arglib.sample_arg(nseqs, 2 * ne, rho, start=0, end=seqlen)
    argweaver.discretize_arg(new_arg, times)

    site = 10
    local_tree = new_arg.get_marginal_tree(site)

    result = argweaverc.assert_transition_probs_internal(
        local_tree, times, popsizes, rho)
    assert result
コード例 #8
0
def test_forward():

    k = 4
    n = 1e4
    rho = 1.5e-8 * 20
    mu = 2.5e-8 * 20
    length = int(100e3 / 20)
    times = argweaver.get_time_points(ntimes=100)

    arg = arglib.sample_arg_smc(k, 2 * n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    print "muts", len(muts)
    print "recomb", len(arglib.get_recomb_pos(arg))

    argweaver.discretize_arg(arg, times)

    # remove chrom
    new_name = "n%d" % (k - 1)
    arg = argweaver.remove_arg_thread(arg, new_name)

    carg = argweaverc.arg2ctrees(arg, times)

    util.tic("C fast")
    probs1 = argweaverc.argweaver_forward_algorithm(carg, seqs, times=times)
    util.toc()

    util.tic("C slow")
    probs2 = argweaverc.argweaver_forward_algorithm(carg,
                                                    seqs,
                                                    times=times,
                                                    slow=True)
    util.toc()

    for i, (col1, col2) in enumerate(izip(probs1, probs2)):
        for a, b in izip(col1, col2):
            fequal(a, b, rel=.0001)
コード例 #9
0
ファイル: test_hmm.py プロジェクト: bredelings/argweaver
def test_forward():

    k = 4
    n = 1e4
    rho = 1.5e-8 * 20
    mu = 2.5e-8 * 20
    length = int(100e3 / 20)
    times = argweaver.get_time_points(ntimes=100)

    arg = arglib.sample_arg_smc(k, 2*n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    print "muts", len(muts)
    print "recomb", len(arglib.get_recombs(arg))

    argweaver.discretize_arg(arg, times)

    # remove chrom
    new_name = "n%d" % (k - 1)
    arg = argweaver.remove_arg_thread(arg, new_name)

    carg = argweaverc.arg2ctrees(arg, times)

    util.tic("C fast")
    probs1 = argweaverc.argweaver_forward_algorithm(carg, seqs, times=times)
    util.toc()

    util.tic("C slow")
    probs2 = argweaverc.argweaver_forward_algorithm(carg, seqs, times=times,
                                                    slow=True)
    util.toc()

    for i, (col1, col2) in enumerate(izip(probs1, probs2)):
        for a, b in izip(col1, col2):
            fequal(a, b, rel=.0001)
コード例 #10
0
ファイル: test_prog.py プロジェクト: bredelings/argweaver
def show_plots(arg_file, sites_file, stats_file, output_prefix,
               rho, mu, popsize, ntimes=20, maxtime=200000):
    """
    Show plots of convergence.

    For each statistic (joint, likelihood, prior, number of
    recombinations, ARG length) the trace read from `stats_file` is drawn
    to `output_prefix + ".trace.<name>.pdf"`, with a gray horizontal line
    marking the value computed from the true ARG in `arg_file`.
    """

    # read true arg and seqs
    times = argweaver.get_time_points(ntimes=ntimes, maxtime=maxtime)
    true_arg = arglib.read_arg(arg_file)
    argweaver.discretize_arg(
        true_arg, times, ignore_top=False, round_age="closer")
    true_arg = arglib.smcify_arg(true_arg)
    seqs = argweaver.sites2seqs(argweaver.read_sites(sites_file))

    # compute true stats
    true_arglen = arglib.arglen(true_arg)
    ctrees = argweaverc.arg2ctrees(true_arg, times)
    true_nrecombs = argweaverc.get_local_trees_ntrees(ctrees[0]) - 1
    true_lk = argweaverc.calc_likelihood(
        ctrees, seqs, mu=mu, times=times, delete_arg=False)
    true_prior = argweaverc.calc_prior_prob(
        ctrees, rho=rho, times=times, popsizes=popsize, delete_arg=False)
    true_joint = true_lk + true_prior

    data = read_table(stats_file)

    # (table column, output suffix, plot title / y label, true value)
    panels = [
        ("joint", ".trace.joint.pdf", "joint probability", true_joint),
        ("likelihood", ".trace.lk.pdf", "likelihood", true_lk),
        ("prior", ".trace.prior.pdf", "prior probability", true_prior),
        ("recombs", ".trace.nrecombs.pdf", "number of recombinations",
         true_nrecombs),
        ("arglen", ".trace.arglen.pdf", "ARG branch length", true_arglen),
    ]

    for column, suffix, label, truth in panels:
        trace = data.cget(column)
        rplot_start(output_prefix + suffix, width=8, height=5)
        # the y range must cover both the trace and the true value
        lower = min(min(trace), truth)
        upper = max(max(trace), truth)
        rp.plot(trace, t="l", ylim=[lower, upper],
                main=label,
                xlab="iterations",
                ylab=label)
        rp.lines([0, len(trace)], [truth, truth], col="gray")
        rplot_end(True)
コード例 #11
0
def sample_dsmc_sprs(k,
                     popsize,
                     rho,
                     recombmap=None,
                     start=0.0,
                     end=0.0,
                     times=None,
                     times2=None,
                     init_tree=None,
                     names=None,
                     make_names=True):
    """
    Sample ARG using Discrete Sequentially Markovian Coalescent (SMC)

    This is a generator.  The first item yielded is the initial local
    tree.  Every later item is a tuple

        (pos, (rleaves, recomb_time), (cleaves, coal_time))

    where pos is the breakpoint coordinate, rleaves / cleaves are the leaf
    names below the recombining / re-coalescing node, and the two times
    are entries of `times`.

    k          -- chromosomes
    popsize    -- effective population size (haploid); either one value,
                  which is repeated for every time step, or a sequence
                  (anything with __len__) indexed by time interval
    rho        -- recombination rate (recombinations / site / generation)
    recombmap  -- map for variable recombination rate
    start      -- starting chromosome coordinate
    end        -- ending chromosome coordinate
    times      -- discretized time points (required)
    times2     -- second time grid (required); indexed below as
                  times2[2*j - 1], times2[2*j], times2[2*j + 1], so it
                  presumably interleaves `times` with intermediate
                  points -- TODO confirm
    init_tree  -- initial local tree; sampled with sample_tree if None
    names      -- names to use for leaves (default: None)
    make_names -- make names using strings (default: True)
    """

    assert times is not None
    assert times2 is not None
    # ntimes counts the intervals between consecutive time points
    ntimes = len(times) - 1
    time_steps = [times[i] - times[i - 1] for i in range(1, ntimes + 1)]
    # NOTE: times2 must be supplied by the caller; a disabled line here
    # used to derive it as get_coal_times(times).

    # accept either a single population size or one per time interval
    if hasattr(popsize, "__len__"):
        popsizes = popsize
    else:
        popsizes = [popsize] * len(time_steps)

    # yield initial tree first
    if init_tree is None:
        init_tree = sample_tree(k,
                                popsizes,
                                times,
                                start=start,
                                end=end,
                                names=names,
                                make_names=make_names)
        argweaver.discretize_arg(init_tree, times2)
    yield init_tree

    # sample SPRs
    pos = start
    # work on a copy so the tree already yielded is not modified below
    tree = init_tree.copy()
    while True:
        # sample next recomb point
        treelen = sum(x.get_dist() for x in tree)
        blocklen = int(
            sample_next_recomb(treelen,
                               rho,
                               pos=pos,
                               recombmap=recombmap,
                               minlen=1))
        pos += blocklen
        # stop once the next breakpoint reaches the end of the region
        if pos >= end - 1:
            break

        root_age_index = times.index(tree.root.age)

        # choose time interval for recombination
        # each interval up to the root is weighted by
        # (number of branches) * (interval length)
        states = set(argweaver.iter_coal_states(tree, times))
        nbranches, nrecombs, ncoals = argweaver.get_nlineages_recomb_coal(
            tree, times)
        probs = [
            nbranches[i] * time_steps[i] for i in range(root_age_index + 1)
        ]
        recomb_time_index = stats.sample(probs)
        recomb_time = times[recomb_time_index]

        # choose branch for recombination
        # uniformly among the non-root states at the chosen time index
        branches = [
            x for x in states
            if x[1] == recomb_time_index and x[0] != tree.root.name
        ]
        recomb_node = tree[random.sample(branches, 1)[0][0]]

        # choose coal time
        # walk upward one time index at a time, stopping with probability
        # coal_prob; if no stop occurs, j ends at ntimes - 1
        j = recomb_time_index
        last_kj = nbranches[max(j - 1, 0)]
        while j < ntimes - 1:
            kj = nbranches[j]
            # do not count the recombining branch itself while it is
            # still present at time index j
            if ((recomb_node.name, j) in states
                    and recomb_node.parents[0].age > times[j]):
                kj -= 1
            assert kj > 0, (j, root_age_index, states)

            # A accumulates (time span) * (lineage count) around index j:
            # the span after times2[2*j] uses kj; above the recombination
            # index the span before it is added using last_kj
            A = (times2[2 * j + 1] - times2[2 * j]) * kj
            if j > recomb_time_index:
                A += (times2[2 * j] - times2[2 * j - 1]) * last_kj
            coal_prob = 1.0 - exp(-A / float(popsizes[j]))
            if random.random() < coal_prob:
                break
            j += 1
            last_kj = kj
        coal_time_index = j
        coal_time = times[j]

        # choose coal node
        # since coal points collapse, exclude parent node, but allow sibling
        exclude = []

        def walk(node):
            # exclude this node; descend only while the node's age equals
            # the chosen coalescence time
            exclude.append(node.name)
            if node.age == coal_time:
                for child in node.children:
                    walk(child)

        walk(recomb_node)
        # the parent's state at the parent's own age is excluded as well
        exclude2 = (recomb_node.parents[0].name,
                    times.index(recomb_node.parents[0].age))
        branches = [
            x for x in states if x[1] == coal_time_index
            and x[0] not in exclude and x != exclude2
        ]
        coal_node = tree[random.sample(branches, 1)[0][0]]

        # yield SPR
        # leaf names are collected before the tree is modified
        rleaves = list(tree.leaf_names(recomb_node))
        cleaves = list(tree.leaf_names(coal_node))

        yield pos, (rleaves, recomb_time), (cleaves, coal_time)

        # apply SPR to local tree
        # `broken` is the old parent of the recombining node; `recoal` is
        # the new parent joining it with coal_node
        broken = recomb_node.parents[0]
        recoal = tree.new_node(age=coal_time,
                               children=[recomb_node, coal_node])

        # add recoal node to tree
        recomb_node.parents[0] = recoal
        broken.children.remove(recomb_node)
        if coal_node.parents:
            # splice recoal in between coal_node and its old parent
            recoal.parents.append(coal_node.parents[0])
            util.replace(coal_node.parents[0].children, coal_node, recoal)
            coal_node.parents[0] = recoal
        else:
            # coal_node had no parent, so recoal becomes its parent
            coal_node.parents.append(recoal)

        # remove broken node
        # its first remaining child is attached to broken's parent, or is
        # left without a parent when broken had none
        broken_child = broken.children[0]
        if broken.parents:
            broken_child.parents[0] = broken.parents[0]
            util.replace(broken.parents[0].children, broken, broken_child)
        else:
            broken_child.parents.remove(broken)

        del tree.nodes[broken.name]
        tree.set_root()
コード例 #12
0
ファイル: test_prog.py プロジェクト: jjberg2/argweaver
def show_plots(arg_file, sites_file, stats_file, output_prefix,
               rho, mu, popsize, ntimes=20, maxtime=200000):
    """
    Show plots of convergence.

    One PDF trace plot is written per statistic found in `stats_file`
    (joint, likelihood, prior, recombination count, ARG length).  Each
    plot also carries a gray reference line at the value obtained from
    the true ARG stored in `arg_file`.
    """

    def draw_trace(column, suffix, title, reference):
        # draw one statistic's trace plus its true-value reference line
        values = data.cget(column)
        rplot_start(output_prefix + suffix, width=8, height=5)
        ylim = [min(min(values), reference), max(max(values), reference)]
        rp.plot(values, t="l", ylim=ylim,
                main=title,
                xlab="iterations",
                ylab=title)
        rp.lines([0, len(values)], [reference, reference], col="gray")
        rplot_end(True)

    # read true arg and seqs
    times = argweaver.get_time_points(ntimes=ntimes, maxtime=maxtime)
    truth = arglib.read_arg(arg_file)
    argweaver.discretize_arg(
        truth, times, ignore_top=False, round_age="closer")
    truth = arglib.smcify_arg(truth)
    seqs = argweaver.sites2seqs(argweaver.read_sites(sites_file))

    # compute true stats
    ref_arglen = arglib.arglen(truth)
    trees = argweaverc.arg2ctrees(truth, times)
    ref_nrecombs = argweaverc.get_local_trees_ntrees(trees[0]) - 1
    ref_lk = argweaverc.calc_likelihood(
        trees, seqs, mu=mu, times=times, delete_arg=False)
    ref_prior = argweaverc.calc_prior_prob(
        trees, rho=rho, times=times, popsizes=popsize, delete_arg=False)
    ref_joint = ref_lk + ref_prior

    data = read_table(stats_file)

    draw_trace("joint", ".trace.joint.pdf", "joint probability", ref_joint)
    draw_trace("likelihood", ".trace.lk.pdf", "likelihood", ref_lk)
    draw_trace("prior", ".trace.prior.pdf", "prior probability", ref_prior)
    draw_trace("recombs", ".trace.nrecombs.pdf",
               "number of recombinations", ref_nrecombs)
    draw_trace("arglen", ".trace.arglen.pdf",
               "ARG branch length", ref_arglen)
コード例 #13
0
def test_trans_two():
    """
    Calculate transition probabilities for k=2

    Only calculate a single matrix
    """

    k = 2
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=5, maxtime=200000)
    time_steps = [times[i] - times[i - 1] for i in range(1, len(times))]
    time_steps.append(200000 * 10000.0)
    popsizes = [n] * len(times)

    arg = arglib.sample_arg(k, 2 * n, rho, start=0, end=length)

    argweaver.discretize_arg(arg, times)
    print "recomb", arglib.get_recomb_pos(arg)

    arg = argweaver.make_trunk_arg(0, length, "n0")

    pos = 10
    tree = arg.get_marginal_tree(pos)
    nlineages = argweaver.get_nlineages_recomb_coal(tree, times)
    states = list(argweaver.iter_coal_states(tree, times))
    mat = argweaver.calc_transition_probs(tree, states, nlineages, times,
                                          time_steps, popsizes, rho)

    nstates = len(states)

    def coal(j):
        return 1.0 - exp(-time_steps[j] / (2.0 * n))

    def recoal2(k, j):
        p = coal(j)
        for m in range(k, j):
            p *= 1.0 - coal(m)
        return p

    def recoal(k, j):
        if j == nstates - 1:
            return exp(-sum(time_steps[m] / (2.0 * n) for m in range(k, j)))
        else:
            return ((1.0 - exp(-time_steps[j] / (2.0 * n))) *
                    exp(-sum(time_steps[m] / (2.0 * n) for m in range(k, j))))

    def isrecomb(i):
        return 1.0 - exp(-max(rho * 2.0 * times[i], rho))

    def recomb(i, k):
        treelen = 2 * times[i] + time_steps[i]
        if k < i:
            return 2.0 * time_steps[k] / treelen / 2.0
        else:
            return time_steps[k] / treelen / 2.0

    def trans(i, j):
        a = states[i][1]
        b = states[j][1]

        p = sum(recoal(k, b) * recomb(a, k) for k in range(0, min(a, b) + 1))
        p += sum(recoal(k, b) * recomb(a, k) for k in range(0, min(a, b) + 1))
        p *= isrecomb(a)
        if i == j:
            p += 1.0 - isrecomb(a)
        return p

    for i in range(len(states)):
        for j in range(len(states)):
            print isrecomb(states[i][1])
            print states[i], states[j], mat[i][j], log(trans(i, j))
            fequal(mat[i][j], log(trans(i, j)))

        # recombs add up to 1
        fequal(sum(recomb(i, k) for k in range(i + 1)), 0.5)

        # recoal add up to 1
        fequal(sum(recoal(i, j) for j in range(i, nstates)), 1.0)

        # recomb * recoal add up to .5
        fequal(
            sum(
                sum(
                    recoal(k, j) * recomb(i, k)
                    for k in range(0,
                                   min(i, j) + 1))
                for j in range(0, nstates)), 0.5)

        fequal(sum(trans(i, j) for j in range(len(states))), 1.0)
コード例 #14
0
ファイル: sim.py プロジェクト: mjhubisz/argweaver
def sample_dsmc_sprs(
        k, popsize, rho, recombmap=None, start=0.0, end=0.0, times=None,
        times2=None, init_tree=None, names=None, make_names=True):
    """
    Sample ARG using Discrete Sequentially Markovian Coalescent (SMC)

    This is a generator.  It yields the initial local tree first, and
    after that one tuple per sampled recombination:

        (pos, (rleaves, recomb_time), (cleaves, coal_time))

    pos is the breakpoint coordinate; rleaves and cleaves are the leaf
    names under the recombining node and the re-coalescing node; both
    times are taken from `times`.

    k          -- chromosomes
    popsize    -- effective population size (haploid); a single value is
                  repeated for every time step, while an object with
                  __len__ is used as-is and indexed by time interval
    rho        -- recombination rate (recombinations / site / generation)
    recombmap  -- map for variable recombination rate
    start      -- starting chromosome coordinate
    end        -- ending chromosome coordinate
    times      -- discretized time points (required)
    times2     -- second time grid (required); read at indices 2*j - 1,
                  2*j and 2*j + 1, so presumably `times` interleaved with
                  intermediate points -- TODO confirm
    init_tree  -- initial local tree; sampled with sample_tree if None
    names      -- names to use for leaves (default: None)
    make_names -- make names using strings (default: True)
    """

    assert times is not None
    assert times2 is not None
    # number of intervals between consecutive time points
    ntimes = len(times) - 1
    time_steps = [times[i] - times[i-1] for i in range(1, ntimes+1)]
    # NOTE: times2 has to be passed in; a disabled line here used to
    # compute it as get_coal_times(times).

    # allow a single population size or one per time interval
    if hasattr(popsize, "__len__"):
        popsizes = popsize
    else:
        popsizes = [popsize] * len(time_steps)


    # yield initial tree first
    if init_tree is None:
        init_tree = sample_tree(k, popsizes, times, start=start, end=end,
                                names=names, make_names=make_names)
        argweaver.discretize_arg(init_tree, times2)
    yield init_tree

    # sample SPRs
    pos = start
    # the SPRs below are applied to a copy, not to the tree yielded above
    tree = init_tree.copy()
    while True:
        # sample next recomb point
        treelen = sum(x.get_dist() for x in tree)
        blocklen = int(sample_next_recomb(treelen, rho, pos=pos,
                                          recombmap=recombmap, minlen=1))
        pos += blocklen
        # finish when the next breakpoint reaches the end of the region
        if pos >= end - 1:
            break

        root_age_index = times.index(tree.root.age)

        # choose time interval for recombination
        # weight of interval i is (number of branches) * (interval length),
        # for intervals up to and including the root's
        states = set(argweaver.iter_coal_states(tree, times))
        nbranches, nrecombs, ncoals = argweaver.get_nlineages_recomb_coal(
            tree, times)
        probs = [nbranches[i] * time_steps[i]
                 for i in range(root_age_index+1)]
        recomb_time_index = stats.sample(probs)
        recomb_time = times[recomb_time_index]

        # choose branch for recombination
        # picked uniformly from the non-root states at that time index
        branches = [x for x in states if x[1] == recomb_time_index and
                    x[0] != tree.root.name]
        recomb_node = tree[random.sample(branches, 1)[0][0]]

        # choose coal time
        # step upward through the time indices, stopping with probability
        # coal_prob; without a stop, j finishes at ntimes - 1
        j = recomb_time_index
        last_kj = nbranches[max(j-1, 0)]
        while j < ntimes - 1:
            kj = nbranches[j]
            # leave out the recombining branch while it still exists at
            # time index j
            if ((recomb_node.name, j) in states and
                    recomb_node.parents[0].age > times[j]):
                kj -= 1
            assert kj > 0, (j, root_age_index, states)

            # A sums (time span) * (lineage count) around index j: the
            # span after times2[2*j] with kj, plus (above the recombination
            # index) the span before it with last_kj
            A = (times2[2*j+1] - times2[2*j]) * kj
            if j > recomb_time_index:
                A += (times2[2*j] - times2[2*j-1]) * last_kj
            coal_prob = 1.0 - exp(-A/float(popsizes[j]))
            if random.random() < coal_prob:
                break
            j += 1
            last_kj = kj
        coal_time_index = j
        coal_time = times[j]

        # choose coal node
        # since coal points collapse, exclude parent node, but allow sibling
        exclude = []

        def walk(node):
            # record this node as excluded; recurse into its children only
            # when the node's age equals the chosen coalescence time
            exclude.append(node.name)
            if node.age == coal_time:
                for child in node.children:
                    walk(child)

        walk(recomb_node)
        # also exclude the parent's state at the parent's own age
        exclude2 = (recomb_node.parents[0].name,
                    times.index(recomb_node.parents[0].age))
        branches = [x for x in states if x[1] == coal_time_index and
                    x[0] not in exclude and x != exclude2]
        coal_node = tree[random.sample(branches, 1)[0][0]]

        # yield SPR
        # leaf names are gathered before the tree is rewired
        rleaves = list(tree.leaf_names(recomb_node))
        cleaves = list(tree.leaf_names(coal_node))

        yield pos, (rleaves, recomb_time), (cleaves, coal_time)

        # apply SPR to local tree
        # `broken` is the recombining node's old parent; `recoal` is the
        # new node joining the recombining node and coal_node
        broken = recomb_node.parents[0]
        recoal = tree.new_node(age=coal_time,
                               children=[recomb_node, coal_node])

        # add recoal node to tree
        recomb_node.parents[0] = recoal
        broken.children.remove(recomb_node)
        if coal_node.parents:
            # insert recoal between coal_node and its previous parent
            recoal.parents.append(coal_node.parents[0])
            util.replace(coal_node.parents[0].children, coal_node, recoal)
            coal_node.parents[0] = recoal
        else:
            # coal_node had no parent, so recoal becomes its parent
            coal_node.parents.append(recoal)

        # remove broken node
        # its first remaining child takes its place under broken's parent,
        # or is left parentless when broken had no parent
        broken_child = broken.children[0]
        if broken.parents:
            broken_child.parents[0] = broken.parents[0]
            util.replace(broken.parents[0].children, broken, broken_child)
        else:
            broken_child.parents.remove(broken)

        del tree.nodes[broken.name]
        tree.set_root()
コード例 #15
0
ファイル: test_hmm.py プロジェクト: bredelings/argweaver
def test_trans_two():
    """
    Calculate transition probabilities for k=2

    Only calculate a single matrix
    """

    k = 2
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=5, maxtime=200000)
    time_steps = [times[i] - times[i-1]
                  for i in range(1, len(times))]
    time_steps.append(200000*10000.0)
    popsizes = [n] * len(times)

    arg = arglib.sample_arg(k, 2*n, rho, start=0, end=length)

    argweaver.discretize_arg(arg, times)
    print "recomb", arglib.get_recombs(arg)

    arg = argweaver.make_trunk_arg(0, length, "n0")

    pos = 10
    tree = arg.get_marginal_tree(pos)
    nlineages = argweaver.get_nlineages_recomb_coal(tree, times)
    states = list(argweaver.iter_coal_states(tree, times))
    mat = argweaver.calc_transition_probs(
        tree, states, nlineages,
        times, time_steps, popsizes, rho)

    nstates = len(states)

    def coal(j):
        return 1.0 - exp(-time_steps[j]/(2.0 * n))

    def recoal2(k, j):
        p = coal(j)
        for m in range(k, j):
            p *= 1.0 - coal(m)
        return p

    def recoal(k, j):
        if j == nstates-1:
            return exp(- sum(time_steps[m] / (2.0 * n)
                             for m in range(k, j)))
        else:
            return ((1.0 - exp(-time_steps[j]/(2.0 * n))) *
                    exp(- sum(time_steps[m] / (2.0 * n)
                              for m in range(k, j))))

    def isrecomb(i):
        return 1.0 - exp(-max(rho * 2.0 * times[i], rho))

    def recomb(i, k):
        treelen = 2*times[i] + time_steps[i]
        if k < i:
            return 2.0 * time_steps[k] / treelen / 2.0
        else:
            return time_steps[k] / treelen / 2.0

    def trans(i, j):
        a = states[i][1]
        b = states[j][1]

        p = sum(recoal(k, b) * recomb(a, k)
                for k in range(0, min(a, b)+1))
        p += sum(recoal(k, b) * recomb(a, k)
                 for k in range(0, min(a, b)+1))
        p *= isrecomb(a)
        if i == j:
            p += 1.0 - isrecomb(a)
        return p

    for i in range(len(states)):
        for j in range(len(states)):
            print isrecomb(states[i][1])
            print states[i], states[j], mat[i][j], log(trans(i, j))
            fequal(mat[i][j], log(trans(i, j)))

        # recombs add up to 1
        fequal(sum(recomb(i, k) for k in range(i+1)), 0.5)

        # recoal add up to 1
        fequal(sum(recoal(i, j) for j in range(i, nstates)), 1.0)

        # recomb * recoal add up to .5
        fequal(sum(sum(recoal(k, j) * recomb(i, k)
                       for k in range(0, min(i, j)+1))
                   for j in range(0, nstates)), 0.5)

        fequal(sum(trans(i, j) for j in range(len(states))), 1.0)