Ejemplo n.º 1
0
    def test_est_popsize(self):
        """
        Plot mle_popsize_tree() for every marginal tree of a simulated ARG.

        The ARG is drawn with arghmm.sample_arg_dsmc and a horizontal line
        at the simulated popsize is overlaid.  Interactive: ends in pause().
        """
        nchroms = 50
        recomb_rate = 1.5e-8
        mut_rate = 2.5e-8    # assigned but not referenced below
        region_len = int(1e6)
        time_points = arghmm.get_time_points(ntimes=30, maxtime=200000)
        true_popsize = 1e4
        refine_iters = 0     # assigned but not referenced below

        util.tic("sim ARG")
        sim_arg = arghmm.sample_arg_dsmc(
            nchroms, 2 * true_popsize, recomb_rate,
            start=0, end=region_len, times=time_points)
        util.toc()

        def estimate(marginal):
            # strip single lineages from the tree, then estimate from it
            arglib.remove_single_lineages(marginal)
            return mle_popsize_tree(marginal, mintime=0)

        estimates = [estimate(marginal)
                     for marginal in arglib.iter_marginal_trees(sim_arg)]

        fig = plot(estimates, ymin=0)
        fig.plot([0, len(estimates)], [true_popsize, true_popsize],
                 style='lines')

        pause()
Ejemplo n.º 2
0
    def test_est_popsize(self):
        """
        Simulate an ARG (arghmm.sample_arg_dsmc) and plot the value of
        mle_popsize_tree() for each of its marginal trees, together with a
        flat reference line at the simulated popsize.

        Interactive: the test finishes by calling pause().
        """
        popsize = 1e4
        sample_size = 50
        seq_len = int(1e6)
        rho = 1.5e-8
        mu = 2.5e-8     # not used by this test
        refine = 0      # not used by this test
        times = arghmm.get_time_points(ntimes=30, maxtime=200000)

        util.tic("sim ARG")
        arg = arghmm.sample_arg_dsmc(sample_size, 2 * popsize, rho,
                                     start=0, end=seq_len, times=times)
        util.toc()

        popsize_est = []
        for marginal in arglib.iter_marginal_trees(arg):
            arglib.remove_single_lineages(marginal)
            est = mle_popsize_tree(marginal, mintime=0)
            popsize_est.append(est)

        reference = [popsize, popsize]
        canvas = plot(popsize_est, ymin=0)
        canvas.plot([0, len(popsize_est)], reference, style='lines')

        pause()
Ejemplo n.º 3
0
def sample_arg_mutations(arg, mu, times=None):
    """
    Simulate mutations on an ARG.

    Mutations are represented as (node, parent, site, time).

    arg -- ARG on which to simulate mutations
    mu -- mutation rate (mutations/site/gen)
    times -- optional list of discretized time points; when it has at
             least two entries, every branch is treated as being at least
             10% of times[1] long

    Returns the list of mutation tuples.  A branch whose mutation rate
    (branch length * mu) is not positive receives no mutations.
    """
    mutations = []

    # Explicit length test instead of truthiness: works for array-like
    # `times` and avoids an IndexError when only one time point is given.
    if times is not None and len(times) > 1:
        minlen = times[1] * .1
    else:
        minlen = 0.0

    for (start, end), tree in arglib.iter_local_trees(arg):
        arglib.remove_single_lineages(tree)
        for node in tree:
            if not node.parents:
                continue
            blen = max(node.get_dist(), minlen)
            rate = blen * mu
            if rate <= 0:
                # random.expovariate() divides by its argument, so a zero
                # rate raises ZeroDivisionError; a negative rate would move
                # `i` backwards and never leave the loop below.
                continue
            i = start
            while i < end:
                # distance to the next mutation along the block
                i += random.expovariate(rate)
                if i < end:
                    t = random.uniform(node.age, node.age + blen)
                    mutations.append((node, node.parents[0], int(i), t))
    return mutations
Ejemplo n.º 4
0
def sample_arg_mutations(arg, mu, times=None):
    """
    Simulate mutations on an ARG.

    Mutations are represented as (node, parent, site, time).

    arg -- ARG on which to simulate mutations
    mu -- mutation rate (mutations/site/gen)
    times -- optional list of discretized time points; with two or more
             entries, branch lengths are floored at 10% of times[1]

    Returns a list of mutation tuples.  Branches with a non-positive
    mutation rate (zero-length branch, or mu <= 0) get no mutations.
    """
    mutations = []

    # len() check rather than `if times`: safe for array-like inputs and
    # for a single-element list, where times[1] does not exist.
    has_floor = times is not None and len(times) > 1
    minlen = times[1] * .1 if has_floor else 0.0

    for (start, end), tree in arglib.iter_local_trees(arg):
        arglib.remove_single_lineages(tree)
        for node in tree:
            if not node.parents:
                continue

            blen = max(node.get_dist(), minlen)
            rate = blen * mu
            if rate <= 0:
                # expovariate(0) raises ZeroDivisionError and a negative
                # rate yields negative steps (endless loop); skip instead.
                continue

            i = start
            while i < end:
                i += random.expovariate(rate)
                if i < end:
                    # place the mutation uniformly along the branch
                    t = random.uniform(node.age, node.age + blen)
                    mutations.append((node, node.parents[0], int(i), t))
    return mutations
Ejemplo n.º 5
0
    def test_marginal_leaves(self):
        """
        For every node of every local tree of a sampled ARG, check that
        arglib.get_marginal_leaves (queried at the block midpoint) returns
        the same leaf set as tree.leaves.
        """
        recomb_rate = 1.5e-8     # recomb/site/gen
        locus_len = 10000        # length of locus
        nlineages = 10           # number of lineages
        popsize = 2 * 10000      # effective popsize

        arg = arglib.sample_arg(nlineages, popsize, recomb_rate,
                                0, locus_len)

        for (start, end), tree in arglib.iter_local_trees(arg):
            arglib.remove_single_lineages(tree)
            midpoint = (start + end) / 2.0
            for node in tree:
                self.assertEqual(
                    set(tree.leaves(node)),
                    set(arglib.get_marginal_leaves(arg, node, midpoint)))
Ejemplo n.º 6
0
def test_arg_equal(arg, arg2):
    """
    Assert that two ARGs agree, printing the compared values as it goes.

    Checked in order:
      * the sorted positions of all "recomb" nodes,
      * phylo.hash_tree of the local tree at the midpoint of each block
        of `arg` versus the marginal tree of `arg2` at the same position,
      * the (pos, recomb, coal) SPRs from arglib.iter_arg_sprs, with the
        leaf lists sorted before comparison.

    Raises AssertionError on the first mismatch.
    """
    # test recomb points
    recombs = sorted(x.pos for x in arg if x.event == "recomb")
    recombs2 = sorted(x.pos for x in arg2 if x.event == "recomb")
    assert recombs == recombs2

    # check local tree topologies
    for (start, end), tree in arglib.iter_tree_tracks(arg):
        pos = (start + end) / 2.0

        arglib.remove_single_lineages(tree)
        tree1 = tree.get_tree()

        tree2 = arg2.get_marginal_tree(pos)
        arglib.remove_single_lineages(tree2)
        tree2 = tree2.get_tree()

        hash1 = phylo.hash_tree(tree1)
        hash2 = phylo.hash_tree(tree2)

        # Single-argument print(...) calls produce the same output whether
        # print is a statement (Python 2) or a function (Python 3).
        print("")
        print(pos)
        print(hash1)
        print(hash2)
        assert hash1 == hash2

    # check sprs
    sprs1 = arglib.iter_arg_sprs(arg, use_leaves=True)
    sprs2 = arglib.iter_arg_sprs(arg2, use_leaves=True)

    # NOTE(review): if izip is itertools.izip it stops at the shorter
    # iterator, so surplus SPRs in one ARG are not compared -- confirm
    # that this is intended.
    for (pos1, recomb1, coal1), (pos2, recomb2, coal2) in izip(sprs1, sprs2):
        recomb1 = (sorted(recomb1[0]), recomb1[1])
        recomb2 = (sorted(recomb2[0]), recomb2[1])
        coal1 = (sorted(coal1[0]), coal1[1])
        coal2 = (sorted(coal2[0]), coal2[1])

        # printed as one tuple each
        print("")
        print((pos1, recomb1, coal1))
        print((pos2, recomb2, coal2))

        # check pos, leaves, time
        assert pos1 == pos2
        assert recomb1 == recomb2
        assert coal1 == coal2
Ejemplo n.º 7
0
def layout_chroms(arg, start=None, end=None):
    """
    Compute a leaf layout for every block between consecutive SPRs.

    arg -- ARG to lay out
    start -- first position (defaults to arg.start)
    end -- last position (defaults to arg.end)

    Returns (blocks, leaf_layout): `blocks` is a list of [start, end]
    pairs and `leaf_layout` holds the result of layout_tree_leaves(tree)
    for the matching block.
    """
    if start is None:
        start = arg.start
    if end is None:
        end = arg.end

    tree = arg.get_marginal_tree(start)
    arglib.remove_single_lineages(tree)
    last_pos = start
    blocks = []
    leaf_layout = []

    layout_func = layout_tree_leaves
    #layout_func = layout_tree_leaves_even

    for spr in arglib.iter_arg_sprs(arg, start=start, end=end,
                                    use_leaves=True):
        # Formatted into one string so the output is the same whether
        # print is a statement (Python 2) or a function (Python 3).
        print("layout %s" % (spr[0],))

        # close the block that ends at this SPR
        blocks.append([last_pos, spr[0]])
        leaf_layout.append(layout_func(tree))
        inorder = dict((n, i) for i, n in enumerate(inorder_tree(tree)))

        # determine SPR nodes
        rnode = arglib.arg_lca(tree, spr[1][0], spr[0])
        cnode = arglib.arg_lca(tree, spr[2][0], spr[0])

        # determine best side for adding new sister
        # (must be read before apply_spr modifies the tree)
        left = (inorder[rnode] < inorder[cnode])

        # apply spr
        arglib.apply_spr(tree, rnode, spr[1][1], cnode, spr[2][1], spr[0])

        # adjust sister
        rindex = rnode.parents[0].children.index(rnode)
        if left and rindex != 0:
            rnode.parents[0].children.reverse()

        last_pos = spr[0]

    # final block, from the last SPR to `end`
    blocks.append([last_pos, end])
    leaf_layout.append(layout_func(tree))

    return blocks, leaf_layout
Ejemplo n.º 8
0
def test_arg_equal(arg, arg2):
    """
    Assert that `arg` and `arg2` describe the same ARG.

    Compares the sorted "recomb" positions, then the phylo.hash_tree
    value of the local tree at each block midpoint, then the SPRs
    returned by arglib.iter_arg_sprs(..., use_leaves=True).  The compared
    values are printed along the way; the first mismatch raises
    AssertionError.
    """
    # test recomb points
    recombs = sorted(x.pos for x in arg if x.event == "recomb")
    recombs2 = sorted(x.pos for x in arg2 if x.event == "recomb")
    assert recombs == recombs2

    # check local tree topologies
    for (start, end), tree in arglib.iter_tree_tracks(arg):
        pos = (start + end) / 2.0

        arglib.remove_single_lineages(tree)
        tree1 = tree.get_tree()

        tree2 = arg2.get_marginal_tree(pos)
        arglib.remove_single_lineages(tree2)
        tree2 = tree2.get_tree()

        hash1 = phylo.hash_tree(tree1)
        hash2 = phylo.hash_tree(tree2)

        # print(...) with exactly one argument is valid, and prints the
        # same text, under both Python 2 and Python 3.
        print("")
        print(pos)
        print(hash1)
        print(hash2)
        assert hash1 == hash2

    # check sprs
    sprs1 = arglib.iter_arg_sprs(arg, use_leaves=True)
    sprs2 = arglib.iter_arg_sprs(arg2, use_leaves=True)

    # NOTE(review): pairing presumably ends with the shorter iterator
    # (itertools.izip semantics), so trailing SPRs present in only one
    # ARG go unchecked -- confirm.
    for (pos1, recomb1, coal1), (pos2, recomb2, coal2) in izip(sprs1, sprs2):
        recomb1 = (sorted(recomb1[0]), recomb1[1])
        recomb2 = (sorted(recomb2[0]), recomb2[1])
        coal1 = (sorted(coal1[0]), coal1[1])
        coal2 = (sorted(coal2[0]), coal2[1])

        # each SPR is printed as a single tuple
        print("")
        print((pos1, recomb1, coal1))
        print((pos2, recomb2, coal2))

        # check pos, leaves, time
        assert pos1 == pos2
        assert recomb1 == recomb2
        assert coal1 == coal2
def layout_chroms(arg, start=None, end=None):
    """
    Lay out the leaves of an ARG for each block delimited by its SPRs.

    arg -- ARG to lay out
    start -- first position; arg.start when None
    end -- last position; arg.end when None

    Returns (blocks, leaf_layout).  `blocks` lists [start, end] pairs, one
    per block; `leaf_layout[i]` is layout_tree_leaves(tree) evaluated on
    the tree as it stood for block i.
    """
    if start is None:
        start = arg.start
    if end is None:
        end = arg.end

    tree = arg.get_marginal_tree(start)
    arglib.remove_single_lineages(tree)
    last_pos = start
    blocks = []
    leaf_layout = []

    layout_func = layout_tree_leaves
    #layout_func = layout_tree_leaves_even

    for spr in arglib.iter_arg_sprs(arg, start=start, end=end, use_leaves=True):
        # One pre-formatted argument keeps the output identical under
        # Python 2 (print statement) and Python 3 (print function).
        print("layout %s" % (spr[0],))

        # record the block that this SPR terminates
        blocks.append([last_pos, spr[0]])
        leaf_layout.append(layout_func(tree))
        inorder = dict((n, i) for i, n in enumerate(inorder_tree(tree)))

        # determine SPR nodes
        rnode = arglib.arg_lca(tree, spr[1][0], spr[0])
        cnode = arglib.arg_lca(tree, spr[2][0], spr[0])

        # determine best side for adding new sister; uses the in-order
        # ranks taken before the tree is modified
        left = (inorder[rnode] < inorder[cnode])

        # apply spr
        arglib.apply_spr(tree, rnode, spr[1][1], cnode, spr[2][1], spr[0])

        # adjust sister
        rindex = rnode.parents[0].children.index(rnode)
        if left and rindex != 0:
            rnode.parents[0].children.reverse()

        last_pos = spr[0]

    # trailing block after the last SPR
    blocks.append([last_pos, end])
    leaf_layout.append(layout_func(tree))

    return blocks, leaf_layout
Ejemplo n.º 10
0
    def test_est_popsize2(self):
        """
        Plot per-tree mle_popsize_tree() values for an ARG simulated in
        three consecutive segments with arglib.sample_arg_smc; the middle
        segment uses half the population size.

        The raw values, a stats.smooth2 curve and the simulated step
        function are drawn.  Interactive: ends in pause().
        """
        nseqs = 20
        rho = 1.5e-8
        mu = 2.5e-8      # assigned but not referenced below
        length = int(4e6)
        popsize = 1e4
        popsize2 = 1e4 * .5
        a = int(.3 * length)
        b = int(.7 * length)
        refine = 0       # assigned but not referenced below

        # (popsize, start, end) for each simulated segment
        segments = [(popsize, 0, a), (popsize2, a, b), (popsize, b, length)]

        util.tic("sim ARG")
        arg = None
        for size, seg_start, seg_end in segments:
            # every segment after the first continues from the previous ARG
            extra = {} if arg is None else {"init_tree": arg}
            arg = arglib.sample_arg_smc(nseqs, 2 * size, rho,
                                        start=seg_start, end=seg_end,
                                        **extra)
        util.toc()

        starts = []
        estimates = []
        for (start, end), tree in arglib.iter_tree_tracks(arg):
            arglib.remove_single_lineages(tree)
            starts.append(start)
            estimates.append(mle_popsize_tree(tree, mintime=0))

        smooth_x, smooth_y = stats.smooth2(starts, estimates, 100e3)
        fig = plot(starts, estimates, ymin=0)
        fig.plot(smooth_x, smooth_y, style='lines')
        fig.plot([0, a, a, b, b, length],
                 [popsize, popsize, popsize2, popsize2, popsize, popsize],
                 style='lines')

        pause()
Ejemplo n.º 11
0
    def test_est_popsize2(self):
        """
        Simulate an ARG whose middle stretch [a, b) uses half the
        population size of its flanks, then plot mle_popsize_tree() for
        each local tree together with a smoothed curve and the simulated
        step function.

        Interactive: the test ends by calling pause().
        """
        k = 20
        rho = 1.5e-8
        mu = 2.5e-8     # not used by this test
        refine = 0      # not used by this test
        popsize = 1e4
        popsize2 = 1e4 * .5
        length = int(4e6)
        a = int(.3 * length)
        b = int(.7 * length)

        util.tic("sim ARG")
        # left flank
        left_arg = arglib.sample_arg_smc(k, 2 * popsize, rho,
                                         start=0, end=a)
        # middle stretch, continued from the left flank
        mid_arg = arglib.sample_arg_smc(k, 2 * popsize2, rho,
                                        start=a, end=b, init_tree=left_arg)
        # right flank, continued from the middle stretch
        arg = arglib.sample_arg_smc(k, 2 * popsize, rho,
                                    start=b, end=length, init_tree=mid_arg)
        util.toc()

        track_starts, track_popsizes = [], []
        for (start, end), tree in arglib.iter_tree_tracks(arg):
            arglib.remove_single_lineages(tree)
            track_starts.append(start)
            track_popsizes.append(mle_popsize_tree(tree, mintime=0))

        smoothed = stats.smooth2(track_starts, track_popsizes, 100e3)
        x2, y2 = smoothed

        truth_x = [0, a, a, b, b, length]
        truth_y = [popsize, popsize, popsize2, popsize2, popsize, popsize]

        p = plot(track_starts, track_popsizes, ymin=0)
        p.plot(x2, y2, style='lines')
        p.plot(truth_x, truth_y, style='lines')

        pause()
Ejemplo n.º 12
0
    def test_pars_seq(self):
        """
        Test parsimony ancestral sequence inference

        Samples an ARG and mutations, builds the alignment, and prints the
        result of arghmm.emit.parsimony_ancestral_seq for the marginal
        tree at the site of the first sampled mutation.
        """

        k = 10
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8 * 100
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        # site of the first mutation in the list
        pos = int(muts[0][2])
        tree = arg.get_marginal_tree(pos)

        # Pre-formatted single argument: same output whether print is a
        # statement (Python 2) or a function (Python 3).
        print("pos = %s" % (pos,))
        treelib.draw_tree_names(tree.get_tree(), scale=4e-4, minlen=5)

        arglib.remove_single_lineages(tree)
        ancestral = arghmm.emit.parsimony_ancestral_seq(tree, seqs, pos)
        util.print_dict(ancestral)
Ejemplo n.º 13
0
    def test_pars_seq(self):
        """
        Test parsimony ancestral sequence inference

        Draws the marginal tree at the first sampled mutation's site and
        prints what arghmm.emit.parsimony_ancestral_seq returns for it.
        """

        k = 10
        n = 1e4
        rho = 1.5e-8
        mu = 2.5e-8 * 100
        length = 1000
        arg = arglib.sample_arg(k, n, rho, start=0, end=length)
        muts = arglib.sample_arg_mutations(arg, mu)
        seqs = arglib.make_alignment(arg, muts)

        # muts[0][2] is the site of the first mutation tuple
        pos = int(muts[0][2])
        tree = arg.get_marginal_tree(pos)

        # print(...) with one formatted argument is valid and equivalent
        # under both Python 2 and Python 3.
        print("pos = %s" % (pos,))
        treelib.draw_tree_names(tree.get_tree(), scale=4e-4, minlen=5)

        arglib.remove_single_lineages(tree)
        ancestral = arghmm.emit.parsimony_ancestral_seq(tree, seqs, pos)
        util.print_dict(ancestral)
Ejemplo n.º 14
0
def arg_equal(arg, arg2):
    """
    Assert, through nose.tools.assert_equal, that two ARGs match.

    Compared in order: the sorted positions of "recomb" nodes, the
    phylo.hash_tree value of the local tree at each block midpoint of
    `arg` against the marginal tree of `arg2` there, and the SPRs from
    arglib.iter_arg_sprs(..., use_leaves=True) with leaf lists sorted.
    """
    check = nose.tools.assert_equal

    def recomb_positions(some_arg):
        return sorted(node.pos for node in some_arg
                      if node.event == "recomb")

    def simplified(local_tree):
        # drop single lineages in place, then convert with get_tree()
        arglib.remove_single_lineages(local_tree)
        return local_tree.get_tree()

    def canonical(event):
        # (leaves, time) with the leaves put in sorted order
        return (sorted(event[0]), event[1])

    # test recomb points
    positions1 = recomb_positions(arg)
    positions2 = recomb_positions(arg2)
    check(positions1, positions2)

    # check local tree topologies
    for (start, end), tree in arglib.iter_local_trees(arg):
        midpoint = (start + end) / 2.0
        tree1 = simplified(tree)
        tree2 = simplified(arg2.get_marginal_tree(midpoint))
        check(phylo.hash_tree(tree1), phylo.hash_tree(tree2))

    # check sprs
    paired_sprs = zip(arglib.iter_arg_sprs(arg, use_leaves=True),
                      arglib.iter_arg_sprs(arg2, use_leaves=True))

    for (pos1, recomb1, coal1), (pos2, recomb2, coal2) in paired_sprs:
        recomb1, recomb2 = canonical(recomb1), canonical(recomb2)
        coal1, coal2 = canonical(coal1), canonical(coal2)

        # check pos, leaves, time
        check(pos1, pos2)
        check(recomb1, recomb2)
        check(coal1, coal2)
Ejemplo n.º 15
0
def arg_equal(arg, arg2):
    """
    Verify with nose.tools.assert_equal that `arg` and `arg2` agree on
    their recombination positions, on the phylo.hash_tree value of every
    local tree (taken at block midpoints of `arg`), and on their SPRs.
    """
    # test recomb points
    recomb_pos1 = [node.pos for node in arg if node.event == "recomb"]
    recomb_pos1.sort()
    recomb_pos2 = [node.pos for node in arg2 if node.event == "recomb"]
    recomb_pos2.sort()
    nose.tools.assert_equal(recomb_pos1, recomb_pos2)

    # check local tree topologies
    for (block_start, block_end), local1 in arglib.iter_local_trees(arg):
        mid = (block_start + block_end) / 2.0

        arglib.remove_single_lineages(local1)
        plain1 = local1.get_tree()

        local2 = arg2.get_marginal_tree(mid)
        arglib.remove_single_lineages(local2)
        plain2 = local2.get_tree()

        nose.tools.assert_equal(phylo.hash_tree(plain1),
                                phylo.hash_tree(plain2))

    # check sprs
    spr_iter1 = arglib.iter_arg_sprs(arg, use_leaves=True)
    spr_iter2 = arglib.iter_arg_sprs(arg2, use_leaves=True)

    for spr1, spr2 in izip(spr_iter1, spr_iter2):
        pos1, (rleaves1, rtime1), (cleaves1, ctime1) = (
            spr1[0], spr1[1][:2], spr1[2][:2])
        pos2, (rleaves2, rtime2), (cleaves2, ctime2) = (
            spr2[0], spr2[1][:2], spr2[2][:2])

        # leaf lists are sorted so their order does not matter
        recomb1 = (sorted(rleaves1), rtime1)
        recomb2 = (sorted(rleaves2), rtime2)
        coal1 = (sorted(cleaves1), ctime1)
        coal2 = (sorted(cleaves2), ctime2)

        # check pos, leaves, time
        nose.tools.assert_equal(pos1, pos2)
        nose.tools.assert_equal(recomb1, recomb2)
        nose.tools.assert_equal(coal1, coal2)
Ejemplo n.º 16
0
def draw_arg_threads(arg,
                     blocks,
                     layout,
                     sites=None,
                     chrom_colors=None,
                     chrom_color=None,
                     snp_colors=None,
                     get_snp_color=None,
                     spr_alpha=1,
                     spr_trim=10,
                     compat=False,
                     draw_group=None,
                     win=None):
    """
    Draw the leaf threads of an ARG, one horizontal line per leaf and
    block, joined across block boundaries, with click hotspots and
    optional SNP ticks.

    arg -- ARG providing leaf names (and marginal trees when compat=True)
    blocks -- sequence of (x1, x2) block boundaries
    layout -- layout[k][name] is the y coordinate of leaf `name` in block k
    sites -- optional sites object; SNP ticks are drawn when it is truthy
    chrom_colors -- optional dict leaf name -> color; when None every leaf
                    uses chrom_color
    chrom_color -- color for all leaves when chrom_colors is None
                   (default [.2, .2, .8, .8])
    snp_colors -- dict with "compat" and "noncompat" colors
                  (default {"compat": [1, 0, 0], "noncompat": [0, 1, 0]})
    get_snp_color -- optional callable (arg.chrom, pos, allele) -> color
                     that overrides the SNP color
    spr_alpha -- factor applied to the 4th color component of the lines
                 joining consecutive blocks
    spr_trim -- maximum amount trimmed from each end of a block's lines
    compat -- when True, SNPs are classified against the marginal tree at
              the block midpoint via arglib.split_to_arg_branch
    draw_group -- group to append to; a new group() is created when None
    win -- forwarded to chrom_click for the hotspots

    Returns draw_group.
    """
    # Defaults are created per call instead of being shared mutable
    # objects in the signature.
    if chrom_color is None:
        chrom_color = [.2, .2, .8, .8]
    if snp_colors is None:
        snp_colors = {"compat": [1, 0, 0], "noncompat": [0, 1, 0]}

    leaf_names = set(arg.leaf_names())

    if draw_group is None:
        draw_group = group()

    # set chromosome color
    if chrom_colors is None:
        chrom_colors = dict((name, chrom_color) for name in leaf_names)

    # colors of the joining lines: a copy of the chromosome color with a
    # 4th component (added as 1.0 when missing) scaled by spr_alpha
    spr_colors = {}
    for name in leaf_names:
        spr_colors[name] = list(chrom_colors[name])
        if len(spr_colors[name]) < 4:
            spr_colors[name].append(1.0)
        spr_colors[name][3] *= spr_alpha

    trims = []

    for k, (x1, x2) in enumerate(blocks):
        # calc trims; never trim more than half of the block
        spr_trim2 = min(spr_trim, (x2 - x1) / 2.0)
        trims.append((x1 + spr_trim2, x2 - spr_trim2))
        trim = trims[-1]

        # horizontal lines
        l = []
        for name in leaf_names:
            c = chrom_colors[name]
            y = layout[k][name]
            l.extend([color(*c), trim[0], y, trim[1], y])
        draw_group.append(lines(*l))

        # SPRs: join the end of the previous block to the start of this one
        if k > 0:
            l = []
            for name in leaf_names:
                c = spr_colors[name]
                y1 = layout[k - 1][name]
                y2 = layout[k][name]
                l.extend([color(*c), trims[k - 1][1], y1, trims[k][0], y2])

            draw_group.append(lines(*l))

        # hotspots
        # NOTE(review): this uses spr_trim, not the clamped spr_trim2, so
        # for blocks shorter than 2 * spr_trim the left edge lies to the
        # right of the right edge -- confirm whether that is intended.
        g = group()
        for name in leaf_names:
            y = layout[k][name]
            g.append(
                hotspot("click", x1 + spr_trim, y + .4, x2 - spr_trim, y - .4,
                        chrom_click(win, name, (x1, x2))))
        draw_group.append(g)

        # SNPs
        tree = None
        if sites:
            l = []
            for pos, col in sites.iter_region(x1, x2):
                split = set(sites.get_minor(pos)) & leaf_names
                if len(split) == 0:
                    continue
                if compat:
                    # the marginal tree is fetched once per block, on demand
                    if tree is None:
                        tree = arg.get_marginal_tree((x1 + x2) / 2.0)
                        arglib.remove_single_lineages(tree)
                    node = arglib.split_to_arg_branch(tree, pos - .5, split)
                    if node is not None:
                        derived = list(tree.leaf_names(node))
                        c = color(*snp_colors["compat"])
                    else:
                        c = color(*snp_colors["noncompat"])
                        derived = split
                else:
                    c = color(*snp_colors["compat"])
                    derived = split
                if get_snp_color and derived:
                    # the allele of one derived leaf picks the color
                    allele = sites.get(pos, next(iter(derived)))
                    c = color(*get_snp_color(arg.chrom, pos, allele))

                for d in derived:
                    if d in layout[k]:
                        y = layout[k][d]
                        l.extend([c, pos, y + .4, pos, y - .4])
            draw_group.append(lines(*l))

    return draw_group
Ejemplo n.º 17
0
def calc_transition_probs_switch_c(tree, last_tree, recomb_name,
                                   states1, states2,
                                   nlineages, times,
                                   time_steps, popsizes, rho, raw=True):
    """
    Marshal two consecutive local trees and their state lists into the
    arguments of new_transition_probs_switch and return its result.

    tree, last_tree -- current and previous local trees
    recomb_name -- forwarded to argweaver.find_recomb_coal
    states1, states2 -- sequences of (node name, time index) pairs for
                        last_tree and tree respectively
    nlineages -- triple unpacked as (nbranches, nrecombs, ncoals)
    times -- sequence of time points; must support .index()
    time_steps, popsizes, rho -- passed through unchanged
    raw -- if True, return the object produced by
           new_transition_probs_switch as is; otherwise copy the first
           len(states2) entries of each of its len(states1) rows into
           Python lists, call delete_transition_probs on it and return
           the lists
    """

    # time value -> index into `times`
    times_lookup = dict((t, i) for i, t in enumerate(times))
    nbranches, nrecombs, ncoals = nlineages
    (recomb_branch, recomb_time), (coal_branch, coal_time) = \
        argweaver.find_recomb_coal(tree, last_tree, recomb_name=recomb_name)

    # from here on recomb_time / coal_time are indices into `times`
    recomb_time = times.index(recomb_time)
    coal_time = times.index(coal_time)

    # work on copies so the caller's trees keep their single lineages
    last_tree2 = last_tree.copy()
    arglib.remove_single_lineages(last_tree2)
    tree2 = tree.copy()
    arglib.remove_single_lineages(tree2)

    # get last ptree
    last_tree2 = last_tree2.get_tree()
    tree2 = tree2.get_tree()
    last_ptree, last_nodes, last_nodelookup = make_ptree(last_tree2)

    # find old node and new node
    recomb_parent = last_tree2[recomb_branch].parent
    # first node of tree2 whose name does not occur in last_tree2
    recoal = [x for x in tree2 if x.name not in last_tree2][0]

    # make nodes array consistent
    # nodes[i] is the tree2 node named like last_nodes[i] (None if absent);
    # the slot of recomb_parent is expected to be empty and takes recoal
    nodes = [tree2.nodes.get(x.name, None) for x in last_nodes]
    i = last_nodes.index(recomb_parent)
    assert nodes[i] is None
    nodes[i] = recoal

    # get ptree
    ptree, nodes, nodelookup = make_ptree(tree2, nodes=nodes)

    # get recomb and coal branches
    # NOTE: recomb_name is rebound here from the caller's argument to a
    # value looked up in last_nodelookup (presumably an integer node
    # index -- confirm against make_ptree)
    recomb_name = last_nodelookup[last_tree2[recomb_branch]]
    coal_name = last_nodelookup[last_tree2[coal_branch]]

    # translate (node name, time index) states through the node lookups
    int_states1 = [[last_nodelookup[last_tree2[node]], timei]
                   for node, timei in states1]
    nstates1 = len(int_states1)
    int_states2 = [[nodelookup[tree2[node]], timei]
                   for node, timei in states2]
    nstates2 = len(int_states2)

    # node ages as indices into `times`; ages are read from the original
    # (uncopied) trees, looked up by node name
    last_ages_index = [times_lookup[last_tree[node.name].age]
                       for node in last_nodes]
    ages_index = [times_lookup[tree[node.name].age]
                  for node in nodes]

    # sum of .dist over all nodes of each simplified tree
    last_treelen = sum(x.dist for x in last_tree2)
    treelen = sum(x.dist for x in tree2)

    # the state lists are packed into ctypes arrays of int pairs,
    # shape (nstates, 2)
    transmat = new_transition_probs_switch(
        ptree, last_ptree, len(nodes),
        recomb_name, recomb_time, coal_name, coal_time,

        ages_index, last_ages_index,
        treelen, last_treelen,
        ((C.c_int * 2) * nstates1)
        (* ((C.c_int * 2)(n, t) for n, t in int_states1)), nstates1,
        ((C.c_int * 2) * nstates2)
        (* ((C.c_int * 2)(n, t) for n, t in int_states2)), nstates2,

        len(time_steps), times, time_steps,
        nbranches, nrecombs, ncoals,
        popsizes, rho)

    if raw:
        # NOTE(review): delete_transition_probs is not called on this
        # path; presumably the caller is responsible for it -- confirm
        return transmat
    else:
        # copy out an nstates1 x nstates2 matrix, then release transmat
        transmat2 = [transmat[j][:nstates2]
                     for j in range(nstates1)]
        delete_transition_probs(transmat, nstates1)
        return transmat2
Ejemplo n.º 18
0
    def test_est_arg_popsize(self):
        """
        Infer an ARG from simulated sequences and plot mle_popsize_tree()
        along it.

        The true ARG is simulated in three segments with
        arglib.sample_arg_smc (the middle one at half the population
        size); sequences are simulated on it, an ARG is sampled and
        resampled from them with a constant popsizes=1e4, and the
        per-tree estimates are plotted against the simulated step
        function.  Interactive: ends in pause().
        """

        k = 20
        rho = 1.5e-8 * 20
        mu = 2.5e-8 * 20
        # Floor division keeps `length` an integer under true division as
        # well; range() below rejects a float.
        length = int(2e6) // 20
        times = arghmm.get_time_points(ntimes=20, maxtime=200000)
        popsize = 1e4
        popsize2 = 1e4 * .5
        a = int(.3 * length)
        b = int(.7 * length)
        refine = 0    # assigned but not referenced below

        util.tic("sim ARG")
        arg = arglib.sample_arg_smc(k, 2 * popsize,
                                    rho, start=0, end=a)
        arg = arglib.sample_arg_smc(k, 2 * popsize2,
                                    rho, start=a, end=b,
                                    init_tree=arg)
        arg = arglib.sample_arg_smc(k, 2 * popsize,
                                    rho, start=b, end=length,
                                    init_tree=arg)

        # sim seq
        mut = arghmm.sample_arg_mutations(arg, mu, times)
        seqs = arghmm.make_alignment(arg, mut)
        util.toc()

        # sample arg
        util.tic("sample arg")
        arg2 = arghmm.sample_arg(seqs, rho=rho, mu=mu, times=times,
                                 popsizes=1e4, carg=True)
        arg2 = arghmm.resample_climb_arg(arg2, seqs, popsizes=1e4,
                                         rho=rho, mu=mu, times=times,
                                         refine=200)
        arg2 = arghmm.resample_all_arg(arg2, seqs, popsizes=1e4,
                                       rho=rho, mu=mu, times=times,
                                       refine=200)
        util.toc()

        x = []
        y = []
        for (start, end), tree in arglib.iter_tree_tracks(arg2):
            arglib.remove_single_lineages(tree)
            x.append(start)
            y.append(mle_popsize_tree(tree, mintime=0))

        # thin popsizes: for each grid position take the estimate of the
        # first tree starting at or after it (the last tree if none does)
        x2 = list(range(0, length, length // 5000))
        y2 = []
        j = 0
        for grid_pos in x2:
            while j < len(x) and x[j] < grid_pos:
                j += 1
            y2.append(y[min(j, len(y) - 1)])

        x3, y3 = stats.smooth2(x2, y2, 100e3)
        p = plot(x, y, ymin=0)
        p.plot(x3, y3, style='lines')
        p.plot([0, a, a, b, b, length],
               [popsize, popsize, popsize2, popsize2, popsize, popsize],
               style='lines')

        pause()
def draw_arg_threads(arg, blocks, layout, sites=None,
                     chrom_colors=None, chrom_color=None,
                     snp_colors=None,
                     spr_alpha=1,
                     spr_trim=10,
                     compat=False,
                     draw_group=None,
                     win=None):
    """
    Draw the leaf threads of an ARG: one horizontal line per leaf and
    block, lines joining consecutive blocks, click hotspots, and SNP
    ticks when `sites` is given.

    arg -- ARG providing leaf names (and marginal trees when compat=True)
    blocks -- sequence of (x1, x2) block boundaries
    layout -- layout[k][name] is the y coordinate of leaf `name` in block k
    sites -- optional sites object; SNP ticks are drawn when it is truthy
    chrom_colors -- optional dict leaf name -> color; when None every leaf
                    uses chrom_color
    chrom_color -- color for all leaves when chrom_colors is None
                   (default [.2, .2, .8, .8])
    snp_colors -- dict with "compat" and "noncompat" colors
                  (default {"compat": [1, 0, 0], "noncompat": [0, 1, 0]})
    spr_alpha -- factor applied to the 4th color component of the lines
                 joining consecutive blocks
    spr_trim -- maximum amount trimmed from each end of a block's lines
    compat -- when True, SNPs are classified against the marginal tree at
              the block midpoint via arglib.split_to_arg_branch
    draw_group -- group to append to; a new group() is created when None
    win -- forwarded to chrom_click for the hotspots

    Returns draw_group.
    """
    # Build the defaults per call rather than sharing mutable objects
    # stored in the function signature.
    if chrom_color is None:
        chrom_color = [.2, .2, .8, .8]
    if snp_colors is None:
        snp_colors = {"compat": [1, 0, 0],
                      "noncompat": [0, 1, 0]}

    leaf_names = set(arg.leaf_names())

    if draw_group is None:
        draw_group = group()

    # set chromosome color
    if chrom_colors is None:
        chrom_colors = {}
        for name in leaf_names:
            chrom_colors[name] = chrom_color

    # joining-line colors: copy of the chromosome color, padded to four
    # components with 1.0, with the 4th component scaled by spr_alpha
    spr_colors = {}
    for name in leaf_names:
        spr_color = list(chrom_colors[name])
        if len(spr_color) < 4:
            spr_color.append(1.0)
        spr_color[3] *= spr_alpha
        spr_colors[name] = spr_color

    trims = []

    for k, (x1, x2) in enumerate(blocks):
        # calc trims; at most half of the block is trimmed from each end
        length = x2 - x1
        spr_trim2 = min(spr_trim, length / 2.0)
        trims.append((x1 + spr_trim2, x2 - spr_trim2))
        trim = trims[-1]

        # horizontal lines
        l = []
        for name in leaf_names:
            c = chrom_colors[name]
            y = layout[k][name]
            l.extend([color(*c), trim[0], y, trim[1], y])
        draw_group.append(lines(*l))

        # SPRs: connect each leaf from the previous block to this one
        if k > 0:
            l = []
            for name in leaf_names:
                c = spr_colors[name]
                y1 = layout[k-1][name]
                y2 = layout[k][name]
                l.extend([color(*c), trims[k-1][1], y1, trims[k][0], y2])

            draw_group.append(lines(*l))

        # hotspots
        # NOTE(review): spr_trim (not the clamped spr_trim2) is used here,
        # so for blocks shorter than 2 * spr_trim the two x coordinates
        # cross -- confirm whether that is intended.
        g = group()
        for name in leaf_names:
            y = layout[k][name]
            g.append(hotspot("click", x1+spr_trim, y+.4, x2-spr_trim, y-.4,
                             chrom_click(win, name, (x1, x2))))
        draw_group.append(g)

        # SNPs
        tree = None
        if sites:
            l = []
            for pos, col in sites.iter_region(x1, x2):
                split = set(sites.get_minor(pos)) & leaf_names
                if len(split) == 0:
                    continue
                if compat:
                    # marginal tree is looked up lazily, once per block
                    if tree is None:
                        tree = arg.get_marginal_tree((x1+x2)/2.0)
                        arglib.remove_single_lineages(tree)
                    node = arglib.split_to_arg_branch(tree, pos-.5, split)
                    if node is not None:
                        derived = list(tree.leaf_names(node))
                        c = color(*snp_colors["compat"])
                    else:
                        c = color(*snp_colors["noncompat"])
                        derived = split
                else:
                    c = color(*snp_colors["compat"])
                    derived = split

                for d in derived:
                    if d in layout[k]:
                        y = layout[k][d]
                        l.extend([c, pos, y+.4, pos, y-.4])
            draw_group.append(lines(*l))

    return draw_group
Ejemplo n.º 20
0
    def test_est_arg_popsize(self):
        """
        Sample an ARG from simulated sequences and plot mle_popsize_tree()
        for each of its local trees.

        Sequences are simulated on a three-segment ARG whose middle
        segment has half the population size; inference uses a constant
        popsizes=1e4.  The estimates, a smoothed thinned curve and the
        simulated step function are plotted.  Interactive: ends in
        pause().
        """

        k = 20
        rho = 1.5e-8 * 20
        mu = 2.5e-8 * 20
        # `//` rather than `/`: with true division `length` would be a
        # float and the range() call further down would raise TypeError.
        length = int(2e6) // 20
        times = arghmm.get_time_points(ntimes=20, maxtime=200000)
        popsize = 1e4
        popsize2 = 1e4 * .5
        a = int(.3 * length)
        b = int(.7 * length)
        refine = 0    # assigned but not referenced below

        util.tic("sim ARG")
        arg = arglib.sample_arg_smc(k, 2 * popsize,
                                    rho, start=0, end=a)
        arg = arglib.sample_arg_smc(k, 2 * popsize2,
                                    rho, start=a, end=b,
                                    init_tree=arg)
        arg = arglib.sample_arg_smc(k, 2 * popsize,
                                    rho, start=b, end=length,
                                    init_tree=arg)

        # sim seq
        mut = arghmm.sample_arg_mutations(arg, mu, times)
        seqs = arghmm.make_alignment(arg, mut)
        util.toc()

        # sample arg
        util.tic("sample arg")
        arg2 = arghmm.sample_arg(seqs, rho=rho, mu=mu, times=times,
                                 popsizes=1e4, carg=True)
        arg2 = arghmm.resample_climb_arg(arg2, seqs, popsizes=1e4,
                                         rho=rho, mu=mu, times=times,
                                         refine=200)
        arg2 = arghmm.resample_all_arg(arg2, seqs, popsizes=1e4,
                                       rho=rho, mu=mu, times=times,
                                       refine=200)
        util.toc()

        x = []
        y = []
        for (start, end), tree in arglib.iter_tree_tracks(arg2):
            arglib.remove_single_lineages(tree)
            x.append(start)
            y.append(mle_popsize_tree(tree, mintime=0))

        # thin popsizes: sample the estimates on a regular grid, taking
        # for each grid position the first tree starting at or after it
        # (falling back to the last tree)
        x2 = list(range(0, length, length // 5000))
        y2 = []
        j = 0
        for grid_pos in x2:
            while j < len(x) and x[j] < grid_pos:
                j += 1
            y2.append(y[min(j, len(y) - 1)])

        x3, y3 = stats.smooth2(x2, y2, 100e3)
        p = plot(x, y, ymin=0)
        p.plot(x3, y3, style='lines')
        p.plot([0, a, a, b, b, length],
               [popsize, popsize, popsize2, popsize2, popsize, popsize],
               style='lines')

        pause()