예제 #1
0
 def test_options(self):
     """
     A tree reports the options it was built with (as does its copy), and
     the tracked-sample and sample-list accessors only work when the
     corresponding option is enabled.
     """
     ts = self.get_example_tree_sequence()
     default_tree = _tskit.Tree(ts)
     self.assertEqual(default_tree.get_options(), 0)
     # Samples can still be counted without SAMPLE_COUNTS (just
     # inefficiently), but tracked samples cannot.
     self.assertEqual(default_tree.get_num_samples(0), 1)
     self.assertRaises(_tskit.LibraryError,
                       default_tree.get_num_tracked_samples, 0)
     counts = _tskit.SAMPLE_COUNTS
     lists = _tskit.SAMPLE_LISTS
     for options in (0, counts, lists, counts | lists):
         source = _tskit.Tree(ts, options=options)
         # The copy must behave exactly like the tree it was copied from.
         for tree in (source, source.copy()):
             self.assertEqual(tree.get_options(), options)
             self.assertEqual(tree.get_num_samples(0), 1)
             if not (options & counts):
                 self.assertRaises(_tskit.LibraryError,
                                   tree.get_num_tracked_samples, 0)
             else:
                 self.assertEqual(tree.get_num_tracked_samples(0), 0)
             if not (options & lists):
                 list_accessors = (tree.get_left_sample,
                                   tree.get_right_sample,
                                   tree.get_next_sample)
                 for accessor in list_accessors:
                     self.assertRaises(ValueError, accessor, 0)
             else:
                 self.assertEqual(tree.get_left_sample(0), 0)
                 self.assertEqual(tree.get_right_sample(0), 0)
예제 #2
0
    def test_sample_list(self):
        """
        With SAMPLE_LISTS enabled, check the left/right sample pointers of
        the nodes in each tree, and that walking each root's sample list
        with get_next_sample visits every sample.

        Note: we're assuming that samples are 0-n here, i.e. the sample
        node IDs are 0, 1, ..., num_samples - 1.
        """
        options = _tskit.SAMPLE_COUNTS | _tskit.SAMPLE_LISTS
        for ts in self.get_example_tree_sequences():
            num_samples = ts.get_num_samples()
            t = _tskit.Tree(ts, options=options)
            while t.next():
                # Every sample's list consists of just the sample itself.
                for j in range(num_samples):
                    self.assertEqual(t.get_left_sample(j), j)
                    self.assertEqual(t.get_right_sample(j), j)

                # Non-sample nodes that are not in the tree (no parent and
                # no children) should have empty lists, i.e. NULL left and
                # right samples. Samples are skipped here: the loop above
                # requires every sample's list to contain the sample
                # itself, which a NULL check would contradict for a sample
                # with no parent and no children.
                for j in range(num_samples, t.get_num_nodes()):
                    if (t.get_parent(j) == _tskit.NULL
                            and t.get_left_child(j) == _tskit.NULL):
                        self.assertEqual(t.get_left_sample(j), _tskit.NULL)
                        self.assertEqual(t.get_right_sample(j), _tskit.NULL)

                # Concatenating the sample lists of all roots, following
                # get_next_sample from each root's left sample to its right
                # sample, must yield every sample exactly once.
                samples = []
                u = t.get_left_root()
                while u != _tskit.NULL:
                    sample = t.get_left_sample(u)
                    end = t.get_right_sample(u)
                    while True:
                        samples.append(sample)
                        if sample == end:
                            break
                        sample = t.get_next_sample(sample)
                    u = t.get_right_sib(u)
                self.assertEqual(sorted(samples), list(range(num_samples)))
예제 #3
0
 def test_index(self):
     """Tree indexes start at zero and go up by one on each next()."""
     for ts in self.get_example_tree_sequences():
         tree = _tskit.Tree(ts)
         expected_index = 0
         while tree.next():
             self.assertEqual(expected_index, tree.get_index())
             expected_index = expected_index + 1
예제 #4
0
    def test_newick_precision(self):
        """
        get_newick rejects out-of-range precision values and, for each
        precision in 0..16, prints branch lengths with exactly that many
        digits after the decimal point (and no point at all for 0).
        """
        def get_times(tree):
            """
            Returns the time strings from the specified newick tree, i.e.
            the text following each ':' up to the next ',' or ')'.
            """
            ret = []
            current_time = None
            for c in tree:
                if c == ":":
                    current_time = ""
                elif c in [",", ")"]:
                    ret.append(current_time)
                    current_time = None
                elif current_time is not None:
                    current_time += c
            return ret

        ts = self.get_example_tree_sequence()
        st = _tskit.Tree(ts)
        while st.next():
            # Precision values outside 0..16 are rejected.
            for bad_precision in [-1, 17, 100]:
                self.assertRaises(ValueError, st.get_newick, root=0,
                                  precision=bad_precision)
            for precision in range(17):
                tree = st.get_newick(root=st.get_left_root(),
                                     precision=precision).decode()
                times = get_times(tree)
                self.assertGreater(len(times), ts.get_num_samples())
                for t in times:
                    if precision == 0:
                        self.assertNotIn(".", t)
                    else:
                        # Require the point explicitly: if it were missing,
                        # str.find would return -1 and the length
                        # arithmetic would wrongly pass whenever
                        # len(t) == precision.
                        self.assertIn(".", t)
                        fraction = t.partition(".")[2]
                        self.assertEqual(precision, len(fraction))
예제 #5
0
 def test_sites(self):
     """
     Each tree's sites lie within its interval and are numbered
     consecutively. Their mutations are numbered consecutively and sit on
     non-root nodes, and the per-tree sites together reproduce the tree
     sequence's full site list.
     """
     for ts in self.get_example_tree_sequences():
         tree = _tskit.Tree(ts)
         expected_sites = [
             ts.get_site(site_id) for site_id in range(ts.get_num_sites())
         ]
         collected_sites = []
         next_site_id = 0
         next_mutation_id = 0
         while tree.next():
             sites = tree.get_sites()
             self.assertEqual(tree.get_num_sites(), len(sites))
             collected_sites.extend(sites)
             for site in sites:
                 position, _, mutation_ids, site_id, site_metadata = site
                 self.assertTrue(
                     tree.get_left() <= position < tree.get_right())
                 self.assertEqual(site_id, next_site_id)
                 self.assertEqual(site_metadata, b"")
                 for mutation_id in mutation_ids:
                     mutation = ts.get_mutation(mutation_id)
                     mut_site, node, _, _, mut_metadata = mutation
                     self.assertEqual(mut_site, site_id)
                     self.assertEqual(next_mutation_id, mutation_id)
                     self.assertNotEqual(tree.get_parent(node), _tskit.NULL)
                     self.assertEqual(mut_metadata, b"")
                     next_mutation_id += 1
                 next_site_id += 1
         self.assertEqual(collected_sites, expected_sites)
예제 #6
0
 def test_copy(self):
     """
     copy() returns a distinct object at the same tree index, both before
     iteration starts and at every step of a forward pass.
     """
     for ts in self.get_example_tree_sequences():
         source = _tskit.Tree(ts)
         first_copy = source.copy()
         self.assertEqual(source.get_index(), first_copy.get_index())
         self.assertIsNot(source, first_copy)
         while source.next():
             self.assertEqual(source.get_index(), source.copy().get_index())
예제 #7
0
 def test_mrca_interface(self):
     """
     get_mrca rejects out-of-range node IDs in either position and, on a
     tree that has not been advanced with next(), returns NULL for every
     pair of distinct nodes.
     """
     for ts in self.get_example_tree_sequences():
         num_nodes = ts.get_num_nodes()
         tree = _tskit.Tree(ts)
         for bad_node in (num_nodes, 10**6, _tskit.NULL):
             bad_pairs = ((bad_node, bad_node), (bad_node, 1), (1, bad_node))
             for pair in bad_pairs:
                 self.assertRaises(ValueError, tree.get_mrca, *pair)
         # All the mrcas for an uninitialised tree should be _tskit.NULL
         for u, v in itertools.combinations(range(num_nodes), 2):
             self.assertEqual(tree.get_mrca(u, v), _tskit.NULL)
예제 #8
0
 def test_equality(self):
     """
     Two trees over the same tree sequence are equal when positioned at the
     same tree and unequal when one is a step ahead. A fresh tree over a
     different tree sequence is unequal to one that has iterated to the
     end.
     """
     def check_symmetric(x, y, expected):
         # equals() must give the same answer in both directions.
         check = self.assertTrue if expected else self.assertFalse
         check(x.equals(y))
         check(y.equals(x))

     previous_ts = None
     for ts in self.get_example_tree_sequences():
         leader = _tskit.Tree(ts)
         follower = _tskit.Tree(ts)
         check_symmetric(leader, follower, True)
         more = True
         while more:
             check_symmetric(leader, follower, True)
             more = leader.next()
             check_symmetric(leader, follower, False)
             self.assertEqual(more, follower.next())
         if previous_ts is not None:
             other = _tskit.Tree(previous_ts)
             check_symmetric(leader, other, False)
         previous_ts = ts
예제 #9
0
 def test_newick_interface(self):
     """
     get_newick raises TypeError for non-integer keyword arguments and
     produces b';'-terminated output for every tree.
     """
     ts = self.get_example_tree_sequence()
     st = _tskit.Tree(ts)
     # TODO this will break when we correctly handle multiple roots.
     self.assertEqual(st.get_newick(0), b"1;")
     for bad_type in [None, "", [], {}]:
         # Pass a valid integer root (as in the call above) so that the
         # TypeError can only come from the keyword argument under test,
         # rather than from a missing or non-integer root.
         self.assertRaises(TypeError, st.get_newick, 0, precision=bad_type)
         # NOTE(review): time_scale is not used by any other get_newick call
         # in these tests; confirm it is still an accepted keyword, otherwise
         # this only checks that an unknown keyword is rejected.
         self.assertRaises(TypeError, st.get_newick, 0, time_scale=bad_type)
     while st.next():
         newick = st.get_newick(st.get_left_root())
         self.assertTrue(newick.endswith(b";"))
예제 #10
0
 def test_count_tracked_samples(self):
     """
     Per-node tracked-sample counts agree with get_tracked_sample_counts for
     a selection of tracked subsets (in shuffled order), and tracking the
     same sample more than once raises a LibraryError.
     """
     # Ensure that there are some non-binary nodes, i.e. nodes with more
     # than two children, so the counts are exercised beyond binary trees.
     non_binary = False
     for ts in self.get_example_tree_sequences():
         st = _tskit.Tree(ts)
         while st.next():
             for u in range(ts.get_num_nodes()):
                 if len(st.get_children(u)) > 2:
                     non_binary = True
         samples = list(range(ts.get_num_samples()))
         powerset = itertools.chain.from_iterable(
             itertools.combinations(samples, r)
             for r in range(len(samples) + 1))
         # The powerset has 2**n members, so only the first max_sets are used.
         max_sets = 100
         for subset in itertools.islice(map(list, powerset), max_sets):
             # Ordering shouldn't make any difference.
             random.shuffle(subset)
             st = _tskit.Tree(ts,
                              options=_tskit.SAMPLE_COUNTS,
                              tracked_samples=subset)
             while st.next():
                 nu = get_tracked_sample_counts(st, subset)
                 nu_prime = [
                     st.get_num_tracked_samples(j)
                     for j in range(st.get_num_nodes())
                 ]
                 self.assertEqual(nu, nu_prime)
         # Passing duplicated values should raise an error
         sample = 1
         for j in range(2, 20):
             tracked_samples = [sample] * j
             self.assertRaises(_tskit.LibraryError,
                               _tskit.Tree,
                               ts,
                               options=_tskit.SAMPLE_COUNTS,
                               tracked_samples=tracked_samples)
     self.assertTrue(non_binary)
예제 #11
0
 def test_bounds_checking(self):
     """
     Node-indexed accessors and get_next_sample raise ValueError for a
     selection of out-of-range indexes.
     """
     def out_of_range(limit):
         # Negative values and values past ``limit``.
         return [-100, -1, limit + 1, limit + 100, limit * 100]

     for ts in self.get_example_tree_sequences():
         num_nodes = ts.get_num_nodes()
         tree = _tskit.Tree(
             ts, options=_tskit.SAMPLE_COUNTS | _tskit.SAMPLE_LISTS)
         node_accessors = (tree.get_parent, tree.get_children,
                           tree.get_time, tree.get_left_sample,
                           tree.get_right_sample)
         for bad in out_of_range(num_nodes):
             for accessor in node_accessors:
                 self.assertRaises(ValueError, accessor, bad)
             self.assertRaises(ValueError, tree.is_descendant, bad, 0)
             self.assertRaises(ValueError, tree.is_descendant, 0, bad)
         num_samples = ts.get_num_samples()
         for bad in out_of_range(num_samples):
             self.assertRaises(ValueError, tree.get_next_sample, bad)
예제 #12
0
    def test_cleared_tree(self):
        """
        A tree is in the same cleared state when freshly constructed,
        after running forwards off the end, and after running backwards
        off the start.
        """
        ts = self.get_example_tree_sequence()
        samples = ts.get_samples()

        def assert_cleared(target):
            # Cleared state: index -1, first sample as left root, and no
            # parent/child links for any node.
            self.assertEqual(target.get_index(), -1)
            self.assertEqual(target.get_left_root(), samples[0])
            self.assertEqual(target.get_mrca(0, 1), _tskit.NULL)
            links = (target.get_parent, target.get_left_child,
                     target.get_right_child)
            for node in range(ts.get_num_nodes()):
                for link in links:
                    self.assertEqual(link(node), _tskit.NULL)

        tree = _tskit.Tree(ts)
        assert_cleared(tree)
        for advance in (tree.next, tree.prev):
            while advance():
                pass
            assert_cleared(tree)
예제 #13
0
 def test_while_loop_semantics(self):
     """
     Repeated full next() and prev() passes over the same tree object each
     visit every tree with the expected indexes.
     """
     def forward_pass(ts, tree):
         visited = 0
         while tree.next():
             self.assertEqual(tree.get_index(), visited)
             visited += 1
         self.assertEqual(visited, ts.get_num_trees())

     def backward_pass(ts, tree):
         remaining = ts.get_num_trees()
         while tree.prev():
             self.assertEqual(tree.get_index(), remaining - 1)
             remaining -= 1
         self.assertEqual(remaining, 0)

     # Any mixture of prev and next is OK and gives a valid iteration.
     schedule = (forward_pass, forward_pass,
                 backward_pass, backward_pass,
                 forward_pass)
     for ts in self.get_example_tree_sequences():
         tree = _tskit.Tree(ts)
         for run_pass in schedule:
             run_pass(ts, tree)
예제 #14
0
 def test_count_all_samples(self):
     """
     With SAMPLE_COUNTS, per-node sample counts match get_sample_counts on
     every tree, and tracked-sample counts are zero when nothing is tracked.
     """
     for ts in self.get_example_tree_sequences():
         self.verify_iterator(_tskit.TreeDiffIterator(ts))
         tree = _tskit.Tree(ts, options=_tskit.SAMPLE_COUNTS)
         # Before the first next(), each sample node (IDs below the sample
         # count) counts one sample and every other node counts none.
         for node in range(tree.get_num_nodes()):
             expected = int(node < ts.get_num_samples())
             self.assertEqual(tree.get_num_samples(node), expected)
             self.assertEqual(tree.get_num_tracked_samples(node), 0)
         while tree.next():
             nodes = range(tree.get_num_nodes())
             self.assertEqual(
                 get_sample_counts(ts, tree),
                 [tree.get_num_samples(node) for node in nodes])
             # Nothing is tracked, so every tracked count must be zero.
             tracked = [tree.get_num_tracked_samples(node) for node in nodes]
             self.assertEqual(tracked, [0] * len(tracked))
예제 #15
0
 def test_constructor(self):
     """
     Tree() raises TypeError when called without a tree sequence or with
     badly typed tracked_samples/options. For each example tree sequence, a
     freshly constructed tree reports the tree sequence's node count, a
     left root, left and right of 0, and no parent, no children and time 0
     for each sample node.
     """
     # No arguments at all, and non-tree-sequence first arguments.
     self.assertRaises(TypeError, _tskit.Tree)
     for bad_type in ["", {}, [], None, 0]:
         self.assertRaises(TypeError, _tskit.Tree, bad_type)
     ts = self.get_example_tree_sequence()
     for bad_type in ["", {}, True, 1, None]:
         self.assertRaises(TypeError,
                           _tskit.Tree,
                           ts,
                           tracked_samples=bad_type)
     for bad_type in ["", {}, None, []]:
         self.assertRaises(TypeError, _tskit.Tree, ts, options=bad_type)
     for ts in self.get_example_tree_sequences():
         st = _tskit.Tree(ts)
         self.assertEqual(st.get_num_nodes(), ts.get_num_nodes())
         # Before next() is called, the tree's left root and its left and
         # right coordinates are all zero.
         self.assertEqual(st.get_left_root(), 0)
         self.assertEqual(st.get_left(), 0)
         self.assertEqual(st.get_right(), 0)
         # Assumes the sample nodes are IDs 0..num_samples-1.
         for j in range(ts.get_num_samples()):
             self.assertEqual(st.get_parent(j), _tskit.NULL)
             self.assertEqual(st.get_children(j), tuple())
             self.assertEqual(st.get_time(j), 0)