Exemplo n.º 1
0
    def test_iter_sprs_time(self):
        """
        Run every SPR iterator variant to completion on one sampled ARG.

        Both arglib.iter_arg_sprs and arglib.iter_arg_sprs_simple are
        consumed, with and without use_leaves=True.  The yielded SPRs are
        not checked; the test name suggests it exists for timing these
        iterators (TODO confirm).
        """
        rho = 1.5e-8       # recomb/site/gen
        length = 100000    # length of locus
        k = 40             # number of lineages
        n = 2*10000        # effective popsize

        arg = arglib.sample_arg(k, n, rho, 0, length)

        # Fully consume each iterator; the yielded values are discarded.
        list(arglib.iter_arg_sprs(arg))
        list(arglib.iter_arg_sprs_simple(arg))
        list(arglib.iter_arg_sprs(arg, use_leaves=True))
        list(arglib.iter_arg_sprs_simple(arg, use_leaves=True))
Exemplo n.º 2
0
def test_trans_switch():
    """
    Calculate transition probabilities for switch matrix

    Only calculate a single matrix: the one for the first recombination of
    an ARG sampled with argweaver.sample_arg_dsmc.  ARGs are resampled until
    at least one recombination is present.
    """

    k = 12
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [n] * len(times)

    recombs = []

    while not recombs:
        arg = argweaver.sample_arg_dsmc(k,
                                        2 * n,
                                        rho,
                                        start=0,
                                        end=length,
                                        times=times)
        recombs = [x.pos for x in arg if x.event == "recomb"]

    pos = recombs[0]
    # pos - .5 is a position just to the left of the first recombination.
    tree = arg.get_marginal_tree(pos - .5)
    # Builtin next() works on Python 2.6+ and 3; the iterator's .next()
    # method does not exist on Python 3.
    rpos, r, c = next(arglib.iter_arg_sprs(arg, start=pos - .5))
    spr = (r, c)

    assert argweaverc.assert_transition_switch_probs(tree, spr, times,
                                                     popsizes, rho)
Exemplo n.º 3
0
    def test_trans_switch(self):
        """
        Calculate transition probabilities for the switch matrix (k=12).

        Only calculate a single matrix: the one for the first recombination
        of an ARG sampled with arghmm.sample_arg_dsmc.  ARGs are resampled
        until at least one recombination is present.
        """

        k = 12
        n = 1e4
        rho = 1.5e-8 * 20
        length = 1000
        times = arghmm.get_time_points(ntimes=20, maxtime=200000)
        popsizes = [n] * len(times)

        recombs = []

        while not recombs:
            arg = arghmm.sample_arg_dsmc(k, 2*n, rho, start=0, end=length,
                                         times=times)
            recombs = [x.pos for x in arg if x.event == "recomb"]

        pos = recombs[0]
        # pos - .5 is a position just to the left of the first recombination.
        tree = arg.get_marginal_tree(pos - .5)
        # Builtin next() works on Python 2.6+ and 3, unlike iterator.next().
        rpos, r, c = next(arglib.iter_arg_sprs(arg, start=pos - .5))
        spr = (r, c)

        assert arghmm.assert_transition_switch_probs(tree, spr,
                                                     times, popsizes, rho)
Exemplo n.º 4
0
def test_arg_equal(arg, arg2):
    """
    Assert that two ARGs are equivalent.

    Checks that:
      * both ARGs have the same sorted list of "recomb" node positions;
      * every local tree of `arg` from arglib.iter_tree_tracks has the same
        phylo.hash_tree() hash as the marginal tree of `arg2` at the track
        midpoint (both after arglib.remove_single_lineages);
      * arglib.iter_arg_sprs(..., use_leaves=True) yields the same SPRs, in
        the same order and the same number of them, for both ARGs, with
        leaf lists compared in sorted order.

    Raises AssertionError describing the first mismatch found.
    """
    # test recomb points
    recombs = sorted(x.pos for x in arg if x.event == "recomb")
    recombs2 = sorted(x.pos for x in arg2 if x.event == "recomb")
    assert recombs == recombs2, (
        "recomb positions differ: %r != %r" % (recombs, recombs2))

    # check local tree topologies
    for (start, end), tree in arglib.iter_tree_tracks(arg):
        pos = (start + end) / 2.0

        arglib.remove_single_lineages(tree)
        tree1 = tree.get_tree()

        tree2 = arg2.get_marginal_tree(pos)
        arglib.remove_single_lineages(tree2)
        tree2 = tree2.get_tree()

        hash1 = phylo.hash_tree(tree1)
        hash2 = phylo.hash_tree(tree2)
        assert hash1 == hash2, (
            "local trees differ at pos %s:\n%s\n%s" % (pos, hash1, hash2))

    # check sprs
    def normalize(spr):
        # Leaf order within an SPR endpoint is not significant; sort it.
        pos, recomb, coal = spr
        return (pos, (sorted(recomb[0]), recomb[1]),
                (sorted(coal[0]), coal[1]))

    sprs1 = iter(arglib.iter_arg_sprs(arg, use_leaves=True))
    sprs2 = iter(arglib.iter_arg_sprs(arg2, use_leaves=True))
    done = object()

    while True:
        spr1 = next(sprs1, done)
        spr2 = next(sprs2, done)
        # izip would stop silently at the shorter stream and hide an extra
        # or missing SPR; both streams must run out together.
        if spr1 is done or spr2 is done:
            assert spr1 is done and spr2 is done, (
                "ARGs yield different numbers of SPRs")
            break

        pos1, recomb1, coal1 = normalize(spr1)
        pos2, recomb2, coal2 = normalize(spr2)

        # check pos, leaves, time
        assert pos1 == pos2, "SPR positions differ: %r != %r" % (pos1, pos2)
        assert recomb1 == recomb2, (
            "recomb points differ at %r: %r != %r" % (pos1, recomb1, recomb2))
        assert coal1 == coal2, (
            "coal points differ at %r: %r != %r" % (pos1, coal1, coal2))
Exemplo n.º 5
0
def test_arg_equal(arg, arg2):
    """
    Assert that two ARGs are equivalent.

    Checks that:
      * both ARGs have the same sorted list of "recomb" node positions;
      * every local tree of `arg` from arglib.iter_tree_tracks has the same
        phylo.hash_tree() hash as the marginal tree of `arg2` at the track
        midpoint (both after arglib.remove_single_lineages);
      * arglib.iter_arg_sprs(..., use_leaves=True) yields the same SPRs, in
        the same order and the same number of them, for both ARGs, with
        leaf lists compared in sorted order.

    Raises AssertionError describing the first mismatch found.
    """
    # test recomb points
    recombs = sorted(x.pos for x in arg if x.event == "recomb")
    recombs2 = sorted(x.pos for x in arg2 if x.event == "recomb")
    assert recombs == recombs2, (
        "recomb positions differ: %r != %r" % (recombs, recombs2))

    # check local tree topologies
    for (start, end), tree in arglib.iter_tree_tracks(arg):
        pos = (start + end) / 2.0

        arglib.remove_single_lineages(tree)
        tree1 = tree.get_tree()

        tree2 = arg2.get_marginal_tree(pos)
        arglib.remove_single_lineages(tree2)
        tree2 = tree2.get_tree()

        hash1 = phylo.hash_tree(tree1)
        hash2 = phylo.hash_tree(tree2)
        assert hash1 == hash2, (
            "local trees differ at pos %s:\n%s\n%s" % (pos, hash1, hash2))

    # check sprs
    def normalize(spr):
        # Leaf order within an SPR endpoint is not significant; sort it.
        pos, recomb, coal = spr
        return (pos, (sorted(recomb[0]), recomb[1]),
                (sorted(coal[0]), coal[1]))

    sprs1 = iter(arglib.iter_arg_sprs(arg, use_leaves=True))
    sprs2 = iter(arglib.iter_arg_sprs(arg2, use_leaves=True))
    done = object()

    while True:
        spr1 = next(sprs1, done)
        spr2 = next(sprs2, done)
        # izip would stop silently at the shorter stream and hide an extra
        # or missing SPR; both streams must run out together.
        if spr1 is done or spr2 is done:
            assert spr1 is done and spr2 is done, (
                "ARGs yield different numbers of SPRs")
            break

        pos1, recomb1, coal1 = normalize(spr1)
        pos2, recomb2, coal2 = normalize(spr2)

        # check pos, leaves, time
        assert pos1 == pos2, "SPR positions differ: %r != %r" % (pos1, pos2)
        assert recomb1 == recomb2, (
            "recomb points differ at %r: %r != %r" % (pos1, recomb1, recomb2))
        assert coal1 == coal2, (
            "coal points differ at %r: %r != %r" % (pos1, coal1, coal2))
Exemplo n.º 6
0
    def test_smcify_arg(self):
        """
        After arglib.smcify_arg, no SPR of a sampled ARG may have its
        recombination node equal to its coalescence node.
        """
        k = 6                 # number of lineages
        n = 2*10000           # effective popsize
        rho = 1.5e-8          # recomb/site/gen
        locus_len = 100000    # length of locus

        smc_arg = arglib.smcify_arg(
            arglib.sample_arg(k, n, rho, 0, locus_len))

        for _, (recomb_node, _rtime), (coal_node, _ctime) in \
                arglib.iter_arg_sprs(smc_arg):
            self.assertNotEqual(recomb_node, coal_node)
Exemplo n.º 7
0
    def test_iter_sprs(self):
        """
        Check that arglib.iter_arg_sprs and arglib.iter_arg_sprs_simple
        yield the same SPRs, in the same order and the same number of them,
        for a sampled ARG.
        """
        rho = 1.5e-8       # recomb/site/gen
        length = 100000    # length of locus
        k = 6              # number of lineages
        n = 2*10000        # effective popsize

        arg = arglib.sample_arg(k, n, rho, 0, length)

        fast = iter(arglib.iter_arg_sprs(arg))
        simple = iter(arglib.iter_arg_sprs_simple(arg))
        done = object()

        while True:
            a = next(fast, done)
            b = next(simple, done)
            # izip would stop silently at the shorter iterator and hide an
            # extra or missing SPR; both must run out together.
            if a is done or b is done:
                self.assertTrue(a is done and b is done,
                                "SPR iterators yield different numbers "
                                "of SPRs")
                break
            self.assertEqual(a, b)
Exemplo n.º 8
0
def test_trans_switch():
    """
    Check the switch-matrix transition probabilities on stored test ARGs.

    Only a single matrix is checked per ARG: the one for its first
    recombination.  Setting ``create_data`` to True first regenerates the
    ARG files (each sampled until it contains a recombination).  On
    failure, the offending local tree is drawn before the test fails.
    """
    create_data = False
    if create_data:
        make_clean_dir('test/data/test_trans_switch')

    # model parameters
    k = 12
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [n] * len(times)
    ntests = 100
    arg_path = 'test/data/test_trans_switch/%d.arg'

    def sample_with_recomb():
        # Resample until the ARG contains at least one recombination.
        sampled = argweaver.sample_arg_dsmc(
            k, 2 * n, rho, start=0, end=length, times=times)
        while not any(node.event == "recomb" for node in sampled):
            sampled = argweaver.sample_arg_dsmc(
                k, 2 * n, rho, start=0, end=length, times=times)
        return sampled

    # generate test data
    if create_data:
        for idx in range(ntests):
            sample_with_recomb().write(arg_path % idx)

    for idx in range(ntests):
        print('arg', idx)
        arg = arglib.read_arg(arg_path % idx)
        argweaver.discretize_arg(arg, times)
        recomb_positions = [node.pos for node in arg
                            if node.event == "recomb"]
        before = recomb_positions[0] - .5

        tree = arg.get_marginal_tree(before)
        _, recomb, coal = next(arglib.iter_arg_sprs(arg, start=before))

        passed = argweaverc.assert_transition_switch_probs(
            tree, (recomb, coal), times, popsizes, rho)
        if passed:
            continue

        drawn = tree.get_tree()
        treelib.remove_single_children(drawn)
        treelib.draw_tree_names(drawn, maxlen=5, minlen=5)
        assert False
Exemplo n.º 9
0
def arg_equal(arg, arg2):
    """
    Assert (via nose.tools) that two ARGs are equivalent.

    Checks that:
      * both ARGs have the same sorted list of "recomb" node positions;
      * every local tree of `arg` from arglib.iter_local_trees has the same
        phylo.hash_tree() hash as the marginal tree of `arg2` at the tree's
        midpoint (both after arglib.remove_single_lineages);
      * arglib.iter_arg_sprs(..., use_leaves=True) yields the same SPRs, in
        the same order and the same number of them, for both ARGs, with
        leaf lists compared in sorted order.
    """
    # test recomb points
    recombs = sorted(x.pos for x in arg if x.event == "recomb")
    recombs2 = sorted(x.pos for x in arg2 if x.event == "recomb")
    nose.tools.assert_equal(recombs, recombs2)

    # check local tree topologies
    for (start, end), tree in arglib.iter_local_trees(arg):
        pos = (start + end) / 2.0

        arglib.remove_single_lineages(tree)
        tree1 = tree.get_tree()

        tree2 = arg2.get_marginal_tree(pos)
        arglib.remove_single_lineages(tree2)
        tree2 = tree2.get_tree()

        hash1 = phylo.hash_tree(tree1)
        hash2 = phylo.hash_tree(tree2)
        nose.tools.assert_equal(hash1, hash2)

    # check sprs
    def normalize(spr):
        # Leaf order within an SPR endpoint is not significant; sort it.
        pos, recomb, coal = spr
        return (pos, (sorted(recomb[0]), recomb[1]),
                (sorted(coal[0]), coal[1]))

    sprs1 = iter(arglib.iter_arg_sprs(arg, use_leaves=True))
    sprs2 = iter(arglib.iter_arg_sprs(arg2, use_leaves=True))
    done = object()

    while True:
        spr1 = next(sprs1, done)
        spr2 = next(sprs2, done)
        # zip would stop silently at the shorter stream and hide an extra
        # or missing SPR; both streams must run out together.
        if spr1 is done or spr2 is done:
            nose.tools.assert_equal(spr1 is done, spr2 is done,
                                    "ARGs yield different numbers of SPRs")
            break

        pos1, recomb1, coal1 = normalize(spr1)
        pos2, recomb2, coal2 = normalize(spr2)

        # check pos, leaves, time
        nose.tools.assert_equal(pos1, pos2)
        nose.tools.assert_equal(recomb1, recomb2)
        nose.tools.assert_equal(coal1, coal2)
Exemplo n.º 10
0
def arg_equal(arg, arg2):
    """
    Assert (via nose.tools) that two ARGs are equivalent.

    Checks that:
      * both ARGs have the same sorted list of "recomb" node positions;
      * every local tree of `arg` from arglib.iter_local_trees has the same
        phylo.hash_tree() hash as the marginal tree of `arg2` at the tree's
        midpoint (both after arglib.remove_single_lineages);
      * arglib.iter_arg_sprs(..., use_leaves=True) yields the same SPRs, in
        the same order and the same number of them, for both ARGs, with
        leaf lists compared in sorted order.
    """
    # test recomb points
    recombs = sorted(x.pos for x in arg if x.event == "recomb")
    recombs2 = sorted(x.pos for x in arg2 if x.event == "recomb")
    nose.tools.assert_equal(recombs, recombs2)

    # check local tree topologies
    for (start, end), tree in arglib.iter_local_trees(arg):
        pos = (start + end) / 2.0

        arglib.remove_single_lineages(tree)
        tree1 = tree.get_tree()

        tree2 = arg2.get_marginal_tree(pos)
        arglib.remove_single_lineages(tree2)
        tree2 = tree2.get_tree()

        hash1 = phylo.hash_tree(tree1)
        hash2 = phylo.hash_tree(tree2)
        nose.tools.assert_equal(hash1, hash2)

    # check sprs
    def normalize(spr):
        # Leaf order within an SPR endpoint is not significant; sort it.
        pos, recomb, coal = spr
        return (pos, (sorted(recomb[0]), recomb[1]),
                (sorted(coal[0]), coal[1]))

    sprs1 = iter(arglib.iter_arg_sprs(arg, use_leaves=True))
    sprs2 = iter(arglib.iter_arg_sprs(arg2, use_leaves=True))
    done = object()

    while True:
        spr1 = next(sprs1, done)
        spr2 = next(sprs2, done)
        # izip would stop silently at the shorter stream and hide an extra
        # or missing SPR; both streams must run out together.
        if spr1 is done or spr2 is done:
            nose.tools.assert_equal(spr1 is done, spr2 is done,
                                    "ARGs yield different numbers of SPRs")
            break

        pos1, recomb1, coal1 = normalize(spr1)
        pos2, recomb2, coal2 = normalize(spr2)

        # check pos, leaves, time
        nose.tools.assert_equal(pos1, pos2)
        nose.tools.assert_equal(recomb1, recomb2)
        nose.tools.assert_equal(coal1, coal2)
Exemplo n.º 11
0
    def test_iter_sprs_remove_thread(self):
        """
        Check that arglib.iter_arg_sprs and arglib.iter_arg_sprs_simple
        agree (same SPRs, same order, same count) on a sampled ARG after one
        leaf has been removed with arglib.subarg_by_leaf_names.
        """
        rho = 1.5e-8       # recomb/site/gen
        length = 100000    # length of locus
        k = 6              # number of lineages
        n = 2*10000        # effective popsize

        arg = arglib.sample_arg(k, n, rho, 0, length)
        # A one-element set holding the last leaf's name.  set("n5") would
        # instead be the set of characters {"n", "5"} and match no leaf.
        # NOTE(review): assumes leaves are named "n0".."n{k-1}"; confirm.
        remove_chroms = {"n%d" % (k - 1)}
        keep = [x for x in arg.leaf_names() if x not in remove_chroms]
        arg = arg.copy()
        arglib.subarg_by_leaf_names(arg, keep)

        fast = iter(arglib.iter_arg_sprs(arg))
        simple = iter(arglib.iter_arg_sprs_simple(arg))
        done = object()

        while True:
            a = next(fast, done)
            b = next(simple, done)
            # izip would stop silently at the shorter iterator and hide an
            # extra or missing SPR; both must run out together.
            if a is done or b is done:
                self.assertTrue(a is done and b is done,
                                "SPR iterators yield different numbers "
                                "of SPRs")
                break
            self.assertEqual(a, b)
Exemplo n.º 12
0
    def test_iter_sprs_leaves(self):
        """
        Check that arglib.iter_arg_sprs and arglib.iter_arg_sprs_simple with
        use_leaves=True yield the same SPRs, in the same order and the same
        number of them, comparing each endpoint's leaf list in sorted order.
        """
        rho = 1.5e-8       # recomb/site/gen
        length = 100000    # length of locus
        k = 40             # number of lineages
        n = 2*10000        # effective popsize

        arg = arglib.sample_arg(k, n, rho, 0, length)

        def normalize(spr):
            # Sort copies of the leaf lists rather than mutating the values
            # the iterators under test handed out.
            pos, recomb, coal = spr
            return (pos, (sorted(recomb[0]), recomb[1]),
                    (sorted(coal[0]), coal[1]))

        fast = iter(arglib.iter_arg_sprs(arg, use_leaves=True))
        simple = iter(arglib.iter_arg_sprs_simple(arg, use_leaves=True))
        done = object()

        while True:
            a = next(fast, done)
            b = next(simple, done)
            # izip would stop silently at the shorter iterator and hide an
            # extra or missing SPR; both must run out together.
            if a is done or b is done:
                self.assertTrue(a is done and b is done,
                                "SPR iterators yield different numbers "
                                "of SPRs")
                break
            self.assertEqual(normalize(a), normalize(b))
Exemplo n.º 13
0
def test_trans_switch():
    """
    Calculate transition probabilities for switch matrix

    Only calculate a single matrix per stored test ARG: the one for the
    ARG's first recombination.  Setting ``create_data`` to True first
    regenerates the ARG files, each sampled until it contains at least one
    recombination.  On failure the offending local tree is drawn.
    """
    create_data = False
    if create_data:
        make_clean_dir('test/data/test_trans_switch')

    # model parameters
    k = 12
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [n] * len(times)
    ntests = 100

    # generate test data
    if create_data:
        for i in range(ntests):
            # Sample ARG with at least one recombination.
            while True:
                arg = argweaver.sample_arg_dsmc(
                    k, 2*n, rho, start=0, end=length, times=times)
                if any(x.event == "recomb" for x in arg):
                    break
            arg.write('test/data/test_trans_switch/%d.arg' % i)

    for i in range(ntests):
        # Single-string print gives the same "arg <i>" output on Python 2
        # and Python 3.
        print('arg %d' % i)
        arg = arglib.read_arg('test/data/test_trans_switch/%d.arg' % i)
        argweaver.discretize_arg(arg, times)
        recombs = [x.pos for x in arg if x.event == "recomb"]
        pos = recombs[0]
        # pos - .5 is a position just to the left of the first recombination.
        tree = arg.get_marginal_tree(pos - .5)
        # Builtin next() works on Python 2.6+ and 3, unlike iterator.next().
        rpos, r, c = next(arglib.iter_arg_sprs(arg, start=pos - .5))
        spr = (r, c)

        if not argweaverc.assert_transition_switch_probs(
                tree, spr, times, popsizes, rho):
            tree2 = tree.get_tree()
            treelib.remove_single_children(tree2)
            treelib.draw_tree_names(tree2, maxlen=5, minlen=5)
            assert False
Exemplo n.º 14
0
def layout_chroms(arg, start=None, end=None):
    """
    Compute a leaf layout for each local-tree block of a region of an ARG.

    Starts from the marginal tree at `start` (after
    arglib.remove_single_lineages).  Each SPR yielded by
    arglib.iter_arg_sprs(arg, start=start, end=end, use_leaves=True) closes
    the current block and records layout_tree_leaves() of the current tree.
    The SPR is then applied to the tree.  If the recombining node preceded
    the coalescing node in in-order traversal, the recombining node's new
    parent has its child list reversed so that node comes first (assumes
    two children per node -- TODO confirm).  Prints "layout <pos>" for each
    SPR.

    Args:
        arg: the ARG to lay out.
        start: region start; defaults to arg.start.
        end: region end; defaults to arg.end.

    Returns:
        (blocks, leaf_layout): blocks is a list of contiguous
        [block_start, block_end] pairs from `start` to `end`, and
        leaf_layout[i] is the layout of the local tree for blocks[i].
    """
    if start is None:
        start = arg.start
    if end is None:
        end = arg.end

    tree = arg.get_marginal_tree(start)
    arglib.remove_single_lineages(tree)
    last_pos = start
    blocks = []
    leaf_layout = []

    layout_func = layout_tree_leaves
    #layout_func = layout_tree_leaves_even

    for spr in arglib.iter_arg_sprs(arg, start=start, end=end,
                                    use_leaves=True):
        # Same "layout <pos>" output as the old print statement, but valid
        # on both Python 2 and Python 3.
        print("layout %s" % (spr[0],))
        blocks.append([last_pos, spr[0]])
        leaf_layout.append(layout_func(tree))
        # Recomputed every iteration because apply_spr changes the tree.
        inorder = {node: i for i, node in enumerate(inorder_tree(tree))}

        # determine SPR nodes
        rnode = arglib.arg_lca(tree, spr[1][0], spr[0])
        cnode = arglib.arg_lca(tree, spr[2][0], spr[0])

        # determine best side for adding new sister
        left = (inorder[rnode] < inorder[cnode])

        # apply spr
        arglib.apply_spr(tree, rnode, spr[1][1], cnode, spr[2][1], spr[0])

        # adjust sister
        rindex = rnode.parents[0].children.index(rnode)
        if left and rindex != 0:
            rnode.parents[0].children.reverse()

        last_pos = spr[0]

    blocks.append([last_pos, end])
    leaf_layout.append(layout_func(tree))

    return blocks, leaf_layout
def layout_chroms(arg, start=None, end=None):
    """
    Compute a leaf layout for each local-tree block of a region of an ARG.

    Starts from the marginal tree at `start` (after
    arglib.remove_single_lineages).  Each SPR yielded by
    arglib.iter_arg_sprs(arg, start=start, end=end, use_leaves=True) closes
    the current block and records layout_tree_leaves() of the current tree.
    The SPR is then applied to the tree.  If the recombining node preceded
    the coalescing node in in-order traversal, the recombining node's new
    parent has its child list reversed so that node comes first (assumes
    two children per node -- TODO confirm).  Prints "layout <pos>" for each
    SPR.

    Args:
        arg: the ARG to lay out.
        start: region start; defaults to arg.start.
        end: region end; defaults to arg.end.

    Returns:
        (blocks, leaf_layout): blocks is a list of contiguous
        [block_start, block_end] pairs from `start` to `end`, and
        leaf_layout[i] is the layout of the local tree for blocks[i].
    """
    if start is None:
        start = arg.start
    if end is None:
        end = arg.end

    tree = arg.get_marginal_tree(start)
    arglib.remove_single_lineages(tree)
    last_pos = start
    blocks = []
    leaf_layout = []

    layout_func = layout_tree_leaves
    #layout_func = layout_tree_leaves_even

    for spr in arglib.iter_arg_sprs(arg, start=start, end=end,
                                    use_leaves=True):
        # Same "layout <pos>" output as the old print statement, but valid
        # on both Python 2 and Python 3.
        print("layout %s" % (spr[0],))
        blocks.append([last_pos, spr[0]])
        leaf_layout.append(layout_func(tree))
        # Recomputed every iteration because apply_spr changes the tree.
        inorder = {node: i for i, node in enumerate(inorder_tree(tree))}

        # determine SPR nodes
        rnode = arglib.arg_lca(tree, spr[1][0], spr[0])
        cnode = arglib.arg_lca(tree, spr[2][0], spr[0])

        # determine best side for adding new sister
        left = (inorder[rnode] < inorder[cnode])

        # apply spr
        arglib.apply_spr(tree, rnode, spr[1][1], cnode, spr[2][1], spr[0])

        # adjust sister
        rindex = rnode.parents[0].children.index(rnode)
        if left and rindex != 0:
            rnode.parents[0].children.reverse()

        last_pos = spr[0]

    blocks.append([last_pos, end])
    leaf_layout.append(layout_func(tree))

    return blocks, leaf_layout
Exemplo n.º 16
0
    def add_arg(self, arg):
        """
        Accumulate counts from one ARG into this object.

        Appends the count_tree_lineages() result for the ARG's first local
        tree to self.init_trees.  Then, for every SPR of the ARG (from
        arglib.iter_arg_sprs(arg, use_local=True)):
          * increments self.ncoals at the util.binsearch index of the
            coalescence time;
          * computes, for each interval ending at an entry of self.times[1:],
            the time-weighted average number of lineages in the local tree
            just left of the recombination, with the broken branch removed;
          * adds those averages to self.k_lineages for every time index from
            the recombination's (inclusive) to the coalescence's (exclusive).
        """
        nleaves = len(list(arg.leaves()))
        times = self.times
        assert times
        eps = 1e-3
        # Marker mixed into the event list for time-step boundaries.
        time_step = "time step"

        def get_local_children(node, pos, local):
            return set(child for child in arg.get_local_children(node, pos)
                       if child in local)

        def get_parent(node, pos, local):
            # Walk up to the first ancestor whose number of local children
            # is not exactly 1.
            parent = arg.get_local_parent(node, pos)
            while len(get_local_children(parent, pos, local)) == 1:
                parent = arg.get_local_parent(parent, pos)
            return parent

        # add initial tree
        tree = arg.get_marginal_tree(arg.start)
        starts, ends, time_steps = count_tree_lineages(tree, times)
        self.init_trees.append({"starts": starts, "ends": ends,
                                "time_steps": time_steps})

        # loop through sprs
        for recomb_pos, (rnode, rtime), (cnode, ctime), local in \
                arglib.iter_arg_sprs(arg, use_local=True):
            i, _ = util.binsearch(times, ctime)
            self.ncoals[i] += 1

            recomb_node = arg[rnode]
            # Inspect the local tree just left of the recombination point.
            broken_node = get_parent(recomb_node, recomb_pos - eps, local)
            coals = [0.0] + [
                node.age for node in local
                if len(get_local_children(node, recomb_pos - eps, local)) == 2
            ]

            coals.sort()
            # list(): items are decremented below, which a Python 3 range
            # does not allow.
            nlineages = list(range(nleaves, 0, -1))
            assert len(nlineages) == len(coals)

            # subtract broken branch
            r = coals.index(recomb_node.age)
            r2 = coals.index(broken_node.age)
            for i in range(r, r2):
                nlineages[i] -= 1

            # get average number of branches in the time interval
            data = list(zip(coals, nlineages))
            for t in times[1:]:
                data.append((t, time_step))
            # Sort by time, placing coalescences before a time-step marker at
            # the same time.  Plain tuple sorting would compare an int with
            # "time step" on equal times (TypeError on Python 3) and would
            # reorder tied coalescences by lineage count.  This stable sort
            # keeps tied coalescences in coalescence order, so the count
            # after the last of them is the one carried forward.
            data.sort(key=lambda item: (item[0], item[1] == time_step))

            lineages_per_time = []
            counts = []
            last_lineages = 0
            last_time = 0.0
            for a, b in data:
                if b != time_step:
                    if a > last_time:
                        counts.append((last_lineages, a - last_time))
                    last_lineages = b
                else:
                    counts.append((last_lineages, a - last_time))
                    s = sum(u * v for u, v in counts)
                    total_time = sum(v for u, v in counts)
                    if s == 0.0:
                        lineages_per_time.append(last_lineages)
                    else:
                        lineages_per_time.append(s / total_time)
                    counts = []
                last_time = a

            assert len(lineages_per_time) == len(self.time_steps)

            r, _ = util.binsearch(times, rtime)
            c, _ = util.binsearch(times, ctime)
            for j in range(r, c):
                self.k_lineages[j] += lineages_per_time[j]
Exemplo n.º 17
0
    def add_arg(self, arg):
        """
        Accumulate counts from one ARG into this object.

        Appends the count_tree_lineages() result for the ARG's first local
        tree to self.init_trees.  Then, for every SPR of the ARG (from
        arglib.iter_arg_sprs(arg, use_local=True)):
          * increments self.ncoals at the util.binsearch index of the
            coalescence time;
          * computes, for each interval ending at an entry of self.times[1:],
            the time-weighted average number of lineages in the local tree
            just left of the recombination, with the broken branch removed;
          * adds those averages to self.k_lineages for every time index from
            the recombination's (inclusive) to the coalescence's (exclusive).
        """
        nleaves = len(list(arg.leaves()))
        times = self.times
        assert times
        eps = 1e-3
        # Marker mixed into the event list for time-step boundaries.
        time_step = "time step"

        def get_local_children(node, pos, local):
            return set(child for child in arg.get_local_children(node, pos)
                       if child in local)

        def get_parent(node, pos, local):
            # Walk up to the first ancestor whose number of local children
            # is not exactly 1.
            parent = arg.get_local_parent(node, pos)
            while len(get_local_children(parent, pos, local)) == 1:
                parent = arg.get_local_parent(parent, pos)
            return parent

        # add initial tree
        tree = arg.get_marginal_tree(arg.start)
        starts, ends, time_steps = count_tree_lineages(tree, times)
        self.init_trees.append({
            "starts": starts,
            "ends": ends,
            "time_steps": time_steps
        })

        # loop through sprs
        for recomb_pos, (rnode, rtime), (cnode, ctime), local in \
                arglib.iter_arg_sprs(arg, use_local=True):
            i, _ = util.binsearch(times, ctime)
            self.ncoals[i] += 1

            recomb_node = arg[rnode]
            # Inspect the local tree just left of the recombination point.
            broken_node = get_parent(recomb_node, recomb_pos - eps, local)
            coals = [0.0] + [
                node.age for node in local
                if len(get_local_children(node, recomb_pos - eps, local)) == 2
            ]

            coals.sort()
            nlineages = list(range(nleaves, 0, -1))
            assert len(nlineages) == len(coals)

            # subtract broken branch
            r = coals.index(recomb_node.age)
            r2 = coals.index(broken_node.age)
            for i in range(r, r2):
                nlineages[i] -= 1

            # get average number of branches in the time interval
            data = list(zip(coals, nlineages))
            for t in times[1:]:
                data.append((t, time_step))
            # Sort by time, placing coalescences before a time-step marker at
            # the same time.  Plain tuple sorting would compare an int with
            # "time step" whenever a coalescence age equals a time point,
            # raising TypeError on Python 3.  This stable sort keeps tied
            # coalescences in coalescence order, so the count after the last
            # of them is the one carried forward.
            data.sort(key=lambda item: (item[0], item[1] == time_step))

            lineages_per_time = []
            counts = []
            last_lineages = 0
            last_time = 0.0
            for a, b in data:
                if b != time_step:
                    if a > last_time:
                        counts.append((last_lineages, a - last_time))
                    last_lineages = b
                else:
                    counts.append((last_lineages, a - last_time))
                    s = sum(u * v for u, v in counts)
                    total_time = sum(v for u, v in counts)
                    if s == 0.0:
                        lineages_per_time.append(last_lineages)
                    else:
                        lineages_per_time.append(s / total_time)
                    counts = []
                last_time = a

            assert len(lineages_per_time) == len(self.time_steps)

            r, _ = util.binsearch(times, rtime)
            c, _ = util.binsearch(times, ctime)
            for j in range(r, c):
                self.k_lineages[j] += lineages_per_time[j]