Example #1
0
 def test_total_branch_length(self):
     """Check ``ARG.total_branch_length`` against a hand-computed value.

     The expected value sums, over every branch between two nodes of the
     full ARG simulated with ``random_seed=20``, (upper node time - lower
     node time) multiplied by the number of bases of ancestral material on
     that branch. The branch list was worked out by hand for this seed and
     must be redone if any simulation parameter changes.
     """
     recombination_rate = 1e-8
     Ne = 5000
     sample_size = 5
     length = 6e2
     ts_full = msprime.simulate(sample_size=sample_size,
                                Ne=Ne,
                                length=length,
                                mutation_rate=1e-8,
                                recombination_rate=recombination_rate,
                                random_seed=20,
                                record_full_arg=True)
     tsarg = treeSequence.TreeSeq(ts_full)
     tsarg.ts_to_argnode()
     argnode = tsarg.arg
     # Read the node-time column once instead of going through
     # ts_full.tables again for every term below.
     node_time = ts_full.tables.nodes.time
     true_total_branch_length = (
         (node_time[5] * 600) + (node_time[5] * 600)
         + (node_time[6] * 600) + (node_time[6] * 600)
         + ((node_time[7] - node_time[6]) * 600)
         + ((node_time[9] - node_time[8]) * 554)
         + ((node_time[9] - node_time[3]) * 600)
         + ((node_time[10] - node_time[9]) * 600)
         + ((node_time[12] - node_time[10]) * 448)
         + ((node_time[12] - node_time[11]) * 152)
         + ((node_time[13] - node_time[5]) * 600)
         + ((node_time[13] - node_time[7]) * 46)
         + ((node_time[14] - node_time[13]) * 600)
         + ((node_time[14] - node_time[12]) * 600))
     self.assertTrue(
         math.isclose(true_total_branch_length,
                      argnode.total_branch_length()))
     # Also verify other methods of the ARG: membership via ``in``
     # (ARG.__contains__) for node 14, the highest node index used above.
     self.assertIn(14, argnode)
Example #2
0
 def test_dump_and_load(self):
     """Round-trip an ARG through ``dump``/``load`` and compare node by node.

     For each node, the index, SNPs, parent/child indices, first and last
     segments and breakpoint must survive the round trip. The pickle is
     written into a temporary directory that is removed afterwards, so the
     test does not leave ``pickle_out.arg`` behind in the working directory.
     """
     import tempfile

     recombination_rate = 1e-8
     Ne = 5000
     sample_size = 25
     length = 6e2
     ts_full = msprime.simulate(sample_size=sample_size,
                                Ne=Ne,
                                length=length,
                                mutation_rate=1e-8,
                                recombination_rate=recombination_rate,
                                random_seed=20,
                                record_full_arg=True)
     tsarg = treeSequence.TreeSeq(ts_full)
     tsarg.ts_to_argnode()
     argnode = tsarg.arg
     with tempfile.TemporaryDirectory() as tmp_dir:
         argnode.dump(path=tmp_dir, file_name='pickle_out.arg')
         loaded_arg = argbook.ARG().load(
             path=os.path.join(tmp_dir, 'pickle_out.arg'))
     # Iterating only over loaded_arg would silently ignore nodes that were
     # lost during the round trip, so compare the node keys first.
     self.assertEqual(sorted(argnode.nodes), sorted(loaded_arg.nodes))
     for key in loaded_arg.nodes:
         original = argnode[key]
         loaded = loaded_arg[key]
         self.assertEqual(original.index, loaded.index)
         self.assertEqual(sorted(original.snps), sorted(loaded.snps))
         if loaded.left_parent is not None:
             self.assertEqual(original.left_parent.index,
                              loaded.left_parent.index)
             self.assertEqual(original.right_parent.index,
                              loaded.right_parent.index)
             self.assertEqual(original.first_segment.left,
                              loaded.first_segment.left)
             self.assertEqual(original.first_segment.right,
                              loaded.first_segment.right)
             self.assertEqual(sorted(original.first_segment.samples),
                              sorted(loaded.first_segment.samples))
             original_tail = original.get_tail()
             loaded_tail = loaded.get_tail()
             self.assertEqual(original_tail.left, loaded_tail.left)
             self.assertEqual(original_tail.right, loaded_tail.right)
             self.assertEqual(sorted(original_tail.samples),
                              sorted(loaded_tail.samples))
         else:
             # Nodes without a parent are expected to carry no segments.
             self.assertEqual(original.left_parent, loaded.left_parent)
             self.assertEqual(original.right_parent, loaded.right_parent)
             self.assertIsNone(original.first_segment)
             self.assertIsNone(loaded.first_segment)
         if loaded.left_child is not None:
             self.assertEqual(original.left_child.index,
                              loaded.left_child.index)
             self.assertEqual(original.right_child.index,
                              loaded.right_child.index)
         else:
             self.assertEqual(original.left_child, loaded.left_child)
             self.assertEqual(original.right_child, loaded.right_child)
         self.assertEqual(original.breakpoint, loaded.breakpoint)
Example #3
0
 def test_arg_allele_age(self):
     """Smoke test: ``ARG.allele_age()`` runs on a simulated 60 kb ARG.

     NOTE(review): the result is not checked; this only verifies that the
     call completes without raising. TODO: assert expected allele ages.
     """
     ts_full = msprime.simulate(sample_size=5,
                                Ne=5000,
                                length=6e4,
                                mutation_rate=1e-8,
                                recombination_rate=1e-8,
                                random_seed=20,
                                record_full_arg=True)
     tree_seq = treeSequence.TreeSeq(ts_full)
     tree_seq.ts_to_argnode()
     tree_seq.arg.allele_age()
Example #4
0
 def test_arg_breakpoints(self):
     """Smoke test: ``ARG.breakpoints`` runs with and without ``only_ancRec``.

     NOTE(review): neither result is checked; this only verifies that both
     calls complete without raising. TODO: assert the expected breakpoints.
     """
     ts_full = msprime.simulate(sample_size=5,
                                Ne=5000,
                                length=6e4,
                                mutation_rate=1e-8,
                                recombination_rate=1e-8,
                                random_seed=20,
                                record_full_arg=True)
     tree_seq = treeSequence.TreeSeq(ts_full)
     tree_seq.ts_to_argnode()
     simulated_arg = tree_seq.arg
     # All recombination events first, then only the ancestral ones.
     simulated_arg.breakpoints()
     simulated_arg.breakpoints(only_ancRec=True)
Example #5
0
    def test_arg_copy_and_equal(self):
        """``ARG.copy`` returns an ARG that is ``equal`` to, but not ``==``, the original.

        ``arg.equal`` is the content comparison and must accept the copy.
        ``==`` is required to be False for the copy, i.e. ARG's ``==`` does
        not compare contents (presumably default object identity).
        """
        recombination_rate = 1e-8
        Ne = 5000
        sample_size = 5
        length = 6e2
        ts_full = msprime.simulate(sample_size=sample_size,
                                   Ne=Ne,
                                   length=length,
                                   mutation_rate=1e-8,
                                   recombination_rate=recombination_rate,
                                   random_seed=20,
                                   record_full_arg=True)

        tsarg = treeSequence.TreeSeq(ts_full)
        tsarg.ts_to_argnode()
        arg = tsarg.arg
        arg_copy = arg.copy()
        # assertTrue/assertFalse take only the condition (plus an optional
        # failure message), so no expected boolean is passed here.
        self.assertTrue(arg.equal(arg_copy))
        self.assertFalse(arg == arg_copy)
Example #6
0
 def test_arg_coal_and_rec_nodes(self):
     """Nodes are classified into recombination and coalescence nodes.

     For ``random_seed=20`` the recombination nodes are 7, 8, 10 and 11, and
     the coalescence nodes are 5, 6, 9, 12, 13 and 14.
     """
     recombination_rate = 1e-8
     Ne = 5000
     sample_size = 5
     length = 6e2
     ts_full = msprime.simulate(sample_size=sample_size,
                                Ne=Ne,
                                length=length,
                                mutation_rate=1e-8,
                                recombination_rate=recombination_rate,
                                random_seed=20,
                                record_full_arg=True)
     tsarg = treeSequence.TreeSeq(ts_full)
     tsarg.ts_to_argnode()
     argnode = tsarg.arg
     self.assertEqual(set(argnode.rec.keys()), {7, 8, 10, 11})
     self.assertEqual(set(argnode.coal.keys()), {5, 6, 9, 12, 13, 14})
     self.assertEqual(len(argnode.coal), 6)
Example #7
0
 def test_arg_leaves(self):
     """All sample nodes are leaves below the root with the largest key.

     The root is taken as ``arg.roots.max_key()``; the indices of the nodes
     yielded by ``arg.leaves(root)`` must be exactly ``0 .. sample_size - 1``.
     """
     recombination_rate = 1e-8
     Ne = 5000
     sample_size = 5
     length = 6e2
     ts_full = msprime.simulate(sample_size=sample_size,
                                Ne=Ne,
                                length=length,
                                mutation_rate=1e-8,
                                recombination_rate=recombination_rate,
                                random_seed=20,
                                record_full_arg=True)
     tsarg = tsarg_result = treeSequence.TreeSeq(ts_full)
     tsarg_result.ts_to_argnode()
     arg = tsarg.arg
     root = arg[arg.roots.max_key()]
     leaf_indices = sorted(node.index for node in arg.leaves(root))
     # assertEqual (not assertTrue(a == b)) so a failure shows the diff.
     self.assertEqual(leaf_indices, list(range(sample_size)))
Example #8
0
    def test_verify_mutation_is_on_the_lowest_possible_node(self):
        """Every SNP sits on the lowest node that can carry it.

        For each SNP ``x`` stored on a node:
          * if the node has children, its two children must be distinct and
            both must contain ``x``;
          * the samples of the node's segment containing ``x`` must equal
            the derived-allele samples ``data[x]`` from the genotype data.
        """
        recombination_rate = 1e-8
        Ne = 5000
        sample_size = 5
        length = 6e5
        ts_full = msprime.simulate(sample_size=sample_size,
                                   Ne=Ne,
                                   length=length,
                                   mutation_rate=1e-8,
                                   recombination_rate=recombination_rate,
                                   random_seed=20,
                                   record_full_arg=True)

        tsarg = treeSequence.TreeSeq(ts_full)
        tsarg.ts_to_argnode()
        argnode = tsarg.arg
        data = treeSequence.get_arg_genotype(ts_full)

        def verify_mutation_node(node, data):
            """Assert that no SNP on ``node`` could sit on a lower node."""
            for x in node.snps:
                if node.left_child is not None:
                    # Compare indices by value. ``is not`` on ints tests
                    # object identity, which only coincides with equality
                    # for small cached ints; with many nodes it could pass
                    # even when both children are the same node.
                    self.assertNotEqual(node.left_child.index,
                                        node.right_child.index)
                    self.assertTrue(node.left_child.contains(x))
                    self.assertTrue(node.right_child.contains(x))
                # The node's samples at x must be exactly the carriers of
                # the derived allele for SNP x.
                self.assertEqual(sorted(node.x_segment(x).samples),
                                 sorted(data[x]))

        for node in argnode.nodes.values():
            verify_mutation_node(node, data)
Example #9
0
 def test_tree_node_age_and_also_upward_path(self):
     """Compare ``tree_node_age(x)`` with ages read from the msprime node table.

     Each expected value is the time of the node's parent in the marginal
     tree at ``x`` minus the node's own time. For the sample node 2 it is
     just the parent's time. Parents were worked out by hand for
     ``random_seed=20``.

     NOTE(review): despite the name, only ``tree_node_age`` is exercised;
     nothing about an upward path is asserted.
     """
     ts_full = msprime.simulate(sample_size=5,
                                Ne=5000,
                                length=6e2,
                                mutation_rate=1e-8,
                                recombination_rate=1e-8,
                                random_seed=20,
                                record_full_arg=True)
     tsarg = treeSequence.TreeSeq(ts_full)
     tsarg.ts_to_argnode()
     argnode = tsarg.arg

     def node_time(index):
         return ts_full.tables.nodes[index].time

     age_above_9 = node_time(14) - node_time(9)
     # (node index, position x, expected tree_node_age(x))
     expectations = [
         # sample node 2 at x=10: its parent is node 6
         (2, 10, node_time(6)),
         # node 6 recombines; at x=10 its parent is node 13
         (6, 10, node_time(13) - node_time(6)),
         # back recombination above node 9: at x=10 the age reaches node 14
         (9, 10, age_above_9),
         # boundary positions
         (9, 0, age_above_9),
         (9, 447, age_above_9),
     ]
     for node_index, x, expected in expectations:
         self.assertEqual(argnode[node_index].tree_node_age(x), expected)
Example #10
0
 def test_log_prior(self):
     """Check ``ARG.log_prior`` against a hand-computed prior.

     The expected value walks the events of the ``random_seed=20`` ARG in
     time order. In each interval the total event rate is
     ``k*(k-1)/(2*2*Ne)`` (coalescence among ``k`` lineages) plus
     ``num_link * r`` (recombination over ``num_link`` links summed across
     lineages). Every interval contributes ``-rate * dt``; each
     coalescence additionally contributes ``-log(2*Ne)`` and each
     recombination ``+log(r)``.

     NOTE(review): ``r = 0.1`` is the per-link recombination rate passed to
     ``log_prior``; it is independent of ``recombination_rate`` used for
     the simulation.
     """
     recombination_rate = 1e-8
     Ne = 5000
     sample_size = 5
     length = 6e2
     ts_full = msprime.simulate(sample_size=sample_size,
                                Ne=Ne,
                                length=length,
                                mutation_rate=1e-8,
                                recombination_rate=recombination_rate,
                                random_seed=20,
                                record_full_arg=True)
     tsarg = treeSequence.TreeSeq(ts_full)
     tsarg.ts_to_argnode()
     argnode = tsarg.arg
     r = 0.1
     k = 5  # number_of_lineages
     # Each of the 5 full-length sample lineages contributes length - 1 links.
     num_link = 5 * (length - 1)
     rate = (k * (k - 1) / (2 * 2 * Ne)) + (num_link * r)
     # Interval [0, t5]: coalescence creating node 5.
     true_log_prior = 0
     true_log_prior -= rate * (ts_full.tables.nodes.time[5] - 0) + math.log(
         2 * Ne)
     # Two full-length lineages became one: 599 links fewer.
     num_link -= 599
     k = 4
     # Interval [t5, t6]: coalescence creating node 6.
     rate = (k * (k - 1) / (2 * 2 * Ne)) + (num_link * r)
     true_log_prior  -= rate * (ts_full.tables.nodes.time[6] - ts_full.tables.nodes.time[5])+\
                              math.log(2*Ne)
     num_link -= 599
     k = 3
     # Interval [t6, t7]: recombination of node 6 into nodes 7 and 8.
     rate = (k * (k - 1) / (2 * 2 * Ne)) + (num_link * r)
     # NOTE(review): gap is never used.
     gap = 1
     true_log_prior -= rate * (ts_full.tables.nodes.time[7] -
                               ts_full.tables.nodes.time[6])
     true_log_prior += math.log(r)
     # The link at the breakpoint is lost; one more lineage exists.
     num_link -= 1
     k = 4
     # Interval [t8, t9]: coalescence creating node 9.
     rate = (k * (k - 1) / (2 * 2 * Ne)) + (num_link * r)
     true_log_prior  -= rate * (ts_full.tables.nodes.time[9] - ts_full.tables.nodes.time[8]) +\
                                  math.log(2*Ne)
     # Hand-derived link change for this merge (presumably the shorter
     # input lineage's 553 links are absorbed into a full-length one).
     num_link -= 553
     k = 3
     # Interval [t9, t10]: recombination of node 9 into nodes 10 and 11.
     rate = (k * (k - 1) / (2 * 2 * Ne)) + (num_link * r)
     # NOTE(review): gap is never used.
     gap = 1
     true_log_prior -= rate * (ts_full.tables.nodes.time[10] -
                               ts_full.tables.nodes.time[9])
     true_log_prior += math.log(r)
     num_link -= 1
     k = 4
     # Interval [t10, t12]: nodes 10 and 11 coalesce back into node 12,
     # which restores the link lost at the breakpoint.
     rate = (k * (k - 1) / (2 * 2 * Ne)) + (num_link * r)
     true_log_prior  -= rate * (ts_full.tables.nodes.time[12] - ts_full.tables.nodes.time[10]) +\
                                                                          math.log(2*Ne)
     num_link += 1
     k = 3
     # Interval [t12, t13]: coalescence creating node 13.
     rate = (k * (k - 1) / (2 * 2 * Ne)) + (num_link * r)
     true_log_prior  -= rate * (ts_full.tables.nodes.time[13] - ts_full.tables.nodes.time[12])+\
                                                                              math.log(2*Ne)
     num_link -= 45
     k = 2
     # Interval [t13, t14]: final coalescence creating node 14.
     rate = (k * (k - 1) / (2 * 2 * Ne)) + (num_link * r)
     true_log_prior  -= rate * (ts_full.tables.nodes.time[14] - ts_full.tables.nodes.time[13])+\
                                                                      math.log(2*Ne)
     # Final bookkeeping; these values are not used afterwards.
     num_link -= (599 + 599)
     k = 1
     # Compare with the library implementation.
     self.assertTrue(
         math.isclose(true_log_prior,
                      argnode.log_prior(sample_size, length, r, Ne)))
Example #11
0
    def test_log_likelihood(self):
        """Check ``ARG.log_likelihood`` against a hand-computed value.

        The expected value is
        ``n*log(T*mu) - T*mu + sum_x log(b_x / T)``, where:
          * ``T`` is the total branch length (``total_material``);
          * ``n`` is the number of mutations;
          * ``b_x`` is the branch length above the node carrying SNP ``x``
            at position ``x``.
        No ``-log(n!)`` term appears in the expected value.

        Four SNPs (10 and 101 on node 6, 20 on node 3, 448 on node 9) are
        added by hand together with matching ``data`` entries. The other two
        SNPs (111 and 558 on node 5) presumably come from the simulation
        itself.
        """
        recombination_rate = 1e-8
        Ne = 5000
        sample_size = 5
        length = 6e2
        ts_full = msprime.simulate(sample_size=sample_size,
                                   Ne=Ne,
                                   length=length,
                                   mutation_rate=1e-8,
                                   recombination_rate=recombination_rate,
                                   random_seed=20,
                                   record_full_arg=True)

        tsarg = treeSequence.TreeSeq(ts_full)
        tsarg.ts_to_argnode()
        argnode = tsarg.arg
        data = treeSequence.get_arg_genotype(ts_full)

        # put some mutations on some nodes
        argnode[6].snps.__setitem__(101, 101)
        argnode[6].snps.__setitem__(10, 10)
        argnode[3].snps.__setitem__(20, 20)
        argnode[9].snps.__setitem__(448, 448)
        # Matching genotype entries: data[x] holds the sample indices that
        # carry the derived allele at x (presumably the samples below the
        # node the SNP was placed on).
        a = bintrees.AVLTree()
        a.update({2: 2, 4: 4})
        data[10] = a
        a = bintrees.AVLTree()
        a.update({2: 2, 4: 4})
        data[101] = a
        a = bintrees.AVLTree()
        a.update({3: 3})
        data[20] = a
        a = bintrees.AVLTree()
        a.update({2: 2, 4: 4, 3: 3})
        data[448] = a
        # NOTE(review): nodes_with_mutation is collected but never used
        # (debugging leftover).
        nodes_with_mutation = []
        for node in argnode.nodes.values():
            if node.snps:
                nodes_with_mutation.append(node)
                # print("node", node.index, node.snps)

        # Total branch length: (upper time - lower time) * bases carried,
        # summed over every branch of the ARG for this seed.
        total_material = (ts_full.tables.nodes.time[5] * 600) + (ts_full.tables.nodes.time[5] * 600) + \
                         (ts_full.tables.nodes.time[6] * 600) + (ts_full.tables.nodes.time[6] * 600) + \
                        ((ts_full.tables.nodes.time[7] - ts_full.tables.nodes.time[6]) * 600) + \
                        ((ts_full.tables.nodes.time[9] - ts_full.tables.nodes.time[8]) * 554) + \
                         ((ts_full.tables.nodes.time[9] - ts_full.tables.nodes.time[3]) * 600) + \
                         ((ts_full.tables.nodes.time[10] - ts_full.tables.nodes.time[9]) * 600) + \
                         ((ts_full.tables.nodes.time[12] - ts_full.tables.nodes.time[10]) * 448) + \
                         ((ts_full.tables.nodes.time[12] - ts_full.tables.nodes.time[11]) * 152) + \
                         ((ts_full.tables.nodes.time[13] - ts_full.tables.nodes.time[5]) * 600) +  \
                         ((ts_full.tables.nodes.time[13] - ts_full.tables.nodes.time[7]) * 46) + \
                         ((ts_full.tables.nodes.time[14] - ts_full.tables.nodes.time[13]) * 600) + \
                         ((ts_full.tables.nodes.time[14] - ts_full.tables.nodes.time[12]) * 600)
        number_of_mutations = 6
        # NOTE(review): m is never used.
        m = 6  # number of snps
        mu = 0.1
        # Poisson part over the whole ARG.
        true_log_likelihood = (
            number_of_mutations * math.log(total_material * mu) -
            total_material * mu)
        # Placement part: one log(branch length / total) term per SNP.
        true_log_likelihood += math.log(
            (ts_full.tables.nodes.time[14] - ts_full.tables.nodes.time[3]) /
            total_material)  # x=20, node=3
        true_log_likelihood += math.log(
            (ts_full.tables.nodes.time[14] - ts_full.tables.nodes.time[5]) /
            total_material)  # x=111, node=5
        true_log_likelihood += math.log(
            (ts_full.tables.nodes.time[14] - ts_full.tables.nodes.time[5]) /
            total_material)  # x=558, node=5
        true_log_likelihood += math.log(
            (ts_full.tables.nodes.time[13] - ts_full.tables.nodes.time[6]) /
            total_material)  # x=10, node=6
        true_log_likelihood += math.log(
            (ts_full.tables.nodes.time[9] - ts_full.tables.nodes.time[6]) /
            total_material)  # x=101, node=6
        true_log_likelihood += math.log(
            (ts_full.tables.nodes.time[14] - ts_full.tables.nodes.time[9]) /
            total_material)  # x=448, node=9
        #----- log likelihood function
        self.assertTrue(
            math.isclose(true_log_likelihood, argnode.log_likelihood(mu,
                                                                     data)))
Example #12
0
    def test_ts_to_argnode(self):
        """Check that ``TreeSeq.ts_to_argnode`` mirrors the msprime edge table.

        Edges are grouped by parent and by child, and each parent group is
        checked against the converted ARG:
          * a recombination parent (flags equal to ``NODE_IS_RE_EVENT``) is
            handled together with ``parent + 1``, the pair sharing one child;
            times, the child's breakpoint, segment bounds and parent/child
            links are compared;
          * every other parent is treated as a coalescence of two children;
            times, the children's segment bounds, parent/child links and
            sibling relations are compared.
        Edge coordinates are passed through ``math.ceil`` before comparing
        with ARG segment bounds.
        """
        #-------- ts_full from msprime:
        recombination_rate = 1e-8
        Ne = 5000
        sample_size = 5
        length = 6e2
        ts_full = msprime.simulate(sample_size=sample_size,
                                   Ne=Ne,
                                   length=length,
                                   mutation_rate=1e-8,
                                   recombination_rate=recombination_rate,
                                   random_seed=20,
                                   record_full_arg=True)

        tsarg = treeSequence.TreeSeq(ts_full)
        tsarg.ts_to_argnode()
        argnode = tsarg.arg
        # Group edges by parent and by child, keeping edge-table order.
        edges_dict = collections.defaultdict(list)
        child_edges_dict = collections.defaultdict(list)
        for edge in ts_full.tables.edges:
            edges_dict[edge.parent].append(edge)
            child_edges_dict[edge.child].append(edge)
        # Number of nodes: distinct edge parents plus the samples.
        self.assertEqual(
            len(edges_dict) + ts_full.sample_size, argnode.__len__())
        # Consume parent groups until none remain; a recombination pair
        # removes two entries at once.
        while edges_dict:
            parent = next(iter(edges_dict))
            # NOTE(review): `==` compares the whole flag word; a node with
            # extra flag bits set would fall into the CA branch. A bitmask
            # test (`flags & NODE_IS_RE_EVENT`) may be what is intended.
            if ts_full.tables.nodes[parent].flags == msprime.NODE_IS_RE_EVENT:
                child = edges_dict[parent][0].child
                # Assumes msprime records the second recombination parent
                # with the next node id.
                parent2 = parent + 1
                # time
                self.assertEqual(ts_full.tables.nodes[parent].time,
                                 argnode[parent].time)
                self.assertEqual(ts_full.tables.nodes[parent2].time,
                                 argnode[parent2].time)
                # Find the breakpoint. l_break is the same in both branches;
                # NOTE(review): r_break is computed but never asserted.
                if edges_dict[parent][-1].right == edges_dict[parent2][0].left:
                    l_break = math.ceil(edges_dict[parent][-1].right)
                    r_break = None
                else:
                    r_break = math.ceil(edges_dict[parent2][0].left)
                    l_break = math.ceil(edges_dict[parent][-1].right)
                # breakpoints
                self.assertEqual(argnode[child].breakpoint, l_break)
                # assert left
                self.assertEqual(argnode[parent].first_segment.left,
                                 math.ceil(edges_dict[parent][0].left))
                self.assertEqual(argnode[parent].get_tail().right,
                                 math.ceil(edges_dict[parent][-1].right))
                # parent2
                self.assertEqual(argnode[parent2].first_segment.left,
                                 math.ceil(edges_dict[parent2][0].left))
                self.assertEqual(argnode[parent2].get_tail().right,
                                 math.ceil(edges_dict[parent2][-1].right))
                # child: spans from parent's first edge to parent2's last edge
                self.assertEqual(argnode[child].first_segment.left,
                                 math.ceil(edges_dict[parent][0].left))
                self.assertEqual(argnode[child].get_tail().right,
                                 math.ceil(edges_dict[parent2][-1].right))
                # parent or child
                self.assertEqual(argnode[child].left_parent.index, parent)
                self.assertEqual(argnode[child].right_parent.index, parent2)
                self.assertEqual(argnode[parent].left_child.index,
                                 argnode[parent2].left_child.index)
                self.assertEqual(argnode[parent].right_child.index,
                                 argnode[parent2].right_child.index)
                del edges_dict[parent]
                del edges_dict[parent2]
            else:  # CA
                # The first and last edges of the parent name its two children.
                child0 = edges_dict[parent][0].child
                child1 = edges_dict[parent][-1].child
                assert child0 != child1
                # time
                self.assertEqual(ts_full.tables.nodes[parent].time,
                                 argnode[parent].time)
                # child0
                self.assertEqual(argnode[child0].first_segment.left,
                                 math.ceil(child_edges_dict[child0][0].left))
                self.assertEqual(argnode[child0].get_tail().right,
                                 math.ceil(child_edges_dict[child0][-1].right))
                # child1
                self.assertEqual(argnode[child1].first_segment.left,
                                 math.ceil(child_edges_dict[child1][0].left))
                self.assertEqual(argnode[child1].get_tail().right,
                                 math.ceil(child_edges_dict[child1][-1].right))
                # Both parent links of each child point at the CA node.
                self.assertEqual(argnode[child0].left_parent.index, parent)
                self.assertEqual(argnode[child0].right_parent.index, parent)
                self.assertEqual(argnode[child1].left_parent.index, parent)
                self.assertEqual(argnode[child1].right_parent.index, parent)
                # sibling
                self.assertEqual(argnode[child0].sibling().index, child1)
                self.assertEqual(argnode[child1].sibling().index, child0)
                # Child links of the CA node.
                self.assertEqual(argnode[parent].left_child.index, child0)
                self.assertEqual(argnode[parent].right_child.index, child1)
                del edges_dict[parent]