def test_iter_sprs_time(self):
    """Drain every SPR iterator variant over one sampled ARG.

    Runs ``arglib.iter_arg_sprs`` and ``arglib.iter_arg_sprs_simple``, each
    with and without ``use_leaves=True``, to completion. Nothing is
    asserted; the test fails only if an iterator raises. Its purpose is to
    make the iterators' running time observable.
    """
    rho = 1.5e-8  # recomb/site/gen
    length = 100000  # length of locus
    k = 40  # number of lineages
    n = 2 * 10000  # effective popsize
    arg = arglib.sample_arg(k, n, rho, 0, length)

    # Materialize each iterator so its full cost is paid; the resulting
    # lists are intentionally discarded.
    list(arglib.iter_arg_sprs(arg))
    list(arglib.iter_arg_sprs_simple(arg))
    list(arglib.iter_arg_sprs(arg, use_leaves=True))
    list(arglib.iter_arg_sprs_simple(arg, use_leaves=True))
def test_trans_switch():
    """
    Calculate transition probabilities for switch matrix

    Only calculate a single matrix: the one for the SPR at the first
    recombination node encountered while iterating a freshly sampled ARG.
    """
    k = 12
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [n] * len(times)

    # Resample until the ARG contains at least one recombination, so that
    # there is an SPR whose switch matrix can be checked.
    recombs = []
    while not recombs:
        arg = argweaver.sample_arg_dsmc(k, 2 * n, rho, start=0, end=length,
                                        times=times)
        recombs = [x.pos for x in arg if x.event == "recomb"]

    # Local tree just left of the recombination, and the SPR applied there.
    pos = recombs[0]
    tree = arg.get_marginal_tree(pos - .5)
    # Builtin next() works on Python 2.6+ and 3; generator.next() is
    # Python 2 only.
    _, r, c = next(arglib.iter_arg_sprs(arg, start=pos - .5))
    spr = (r, c)

    assert argweaverc.assert_transition_switch_probs(tree, spr, times,
                                                     popsizes, rho)
def test_trans_switch(self):
    """
    Calculate transition probabilities for the switch matrix (k=12)

    Only calculate a single matrix: the one for the SPR at the first
    recombination node encountered while iterating a freshly sampled ARG.
    """
    k = 12
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = arghmm.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [n] * len(times)

    # Resample until the ARG has at least one recombination.
    recombs = []
    while not recombs:
        arg = arghmm.sample_arg_dsmc(k, 2*n, rho, start=0, end=length,
                                     times=times)
        recombs = [x.pos for x in arg if x.event == "recomb"]

    # Local tree just left of the recombination, and the SPR applied there.
    pos = recombs[0]
    tree = arg.get_marginal_tree(pos-.5)
    # Builtin next() works on Python 2.6+ and 3; generator.next() is
    # Python 2 only.
    _, r, c = next(arglib.iter_arg_sprs(arg, start=pos-.5))
    spr = (r, c)

    assert arghmm.assert_transition_switch_probs(tree, spr, times,
                                                 popsizes, rho)
def test_arg_equal(arg, arg2):
    """Assert that two ARGs have the same recombinations, local trees and SPRs.

    Checks, in order:

    * the sorted recombination positions are identical;
    * at the midpoint of every tree track of ``arg``, the local topologies
      of ``arg`` and ``arg2`` (after ``arglib.remove_single_lineages``)
      give the same ``phylo.hash_tree`` value;
    * the SPRs from ``arglib.iter_arg_sprs(..., use_leaves=True)`` match
      pairwise in position, recombination leaves/time and coalescence
      leaves/time. Leaf lists are sorted first, so leaf order is ignored.

    Positions, hashes and SPRs are printed before each comparison to help
    diagnose a failure.

    NOTE: SPRs are paired with zip(), which stops at the shorter sequence.
    """
    # test recomb points
    recombs = sorted(x.pos for x in arg if x.event == "recomb")
    recombs2 = sorted(x.pos for x in arg2 if x.event == "recomb")
    assert recombs == recombs2

    # check local tree topologies
    for (start, end), tree in arglib.iter_tree_tracks(arg):
        pos = (start + end) / 2.0
        arglib.remove_single_lineages(tree)
        tree1 = tree.get_tree()
        tree2 = arg2.get_marginal_tree(pos)
        arglib.remove_single_lineages(tree2)
        tree2 = tree2.get_tree()
        hash1 = phylo.hash_tree(tree1)
        hash2 = phylo.hash_tree(tree2)
        # Single-argument print(...) produces the same output on Python 2
        # and Python 3.
        print("")
        print(pos)
        print(hash1)
        print(hash2)
        assert hash1 == hash2

    # check sprs
    sprs1 = arglib.iter_arg_sprs(arg, use_leaves=True)
    sprs2 = arglib.iter_arg_sprs(arg2, use_leaves=True)
    for (pos1, recomb1, coal1), (pos2, recomb2, coal2) in zip(sprs1, sprs2):
        # Sort the leaf lists so the comparison ignores leaf order.
        recomb1 = (sorted(recomb1[0]), recomb1[1])
        recomb2 = (sorted(recomb2[0]), recomb2[1])
        coal1 = (sorted(coal1[0]), coal1[1])
        coal2 = (sorted(coal2[0]), coal2[1])
        print("")
        # Print each SPR as a tuple (as the Python 2 original did).
        print((pos1, recomb1, coal1))
        print((pos2, recomb2, coal2))

        # check pos, leaves, time
        assert pos1 == pos2
        assert recomb1 == recomb2
        assert coal1 == coal2
def test_arg_equal(arg, arg2):
    """Assert that two ARGs have the same recombinations, local trees and SPRs.

    Checks, in order:

    * the sorted recombination positions are identical;
    * at the midpoint of every tree track of ``arg``, the local topologies
      of ``arg`` and ``arg2`` (after ``arglib.remove_single_lineages``)
      give the same ``phylo.hash_tree`` value;
    * the SPRs from ``arglib.iter_arg_sprs(..., use_leaves=True)`` match
      pairwise in position, recombination leaves/time and coalescence
      leaves/time. Leaf lists are sorted first, so leaf order is ignored.

    Positions, hashes and SPRs are printed before each comparison to help
    diagnose a failure.

    NOTE: SPRs are paired with zip(), which stops at the shorter sequence.
    """
    # test recomb points
    recombs = sorted(x.pos for x in arg if x.event == "recomb")
    recombs2 = sorted(x.pos for x in arg2 if x.event == "recomb")
    assert recombs == recombs2

    # check local tree topologies
    for (start, end), tree in arglib.iter_tree_tracks(arg):
        pos = (start + end) / 2.0
        arglib.remove_single_lineages(tree)
        tree1 = tree.get_tree()
        tree2 = arg2.get_marginal_tree(pos)
        arglib.remove_single_lineages(tree2)
        tree2 = tree2.get_tree()
        hash1 = phylo.hash_tree(tree1)
        hash2 = phylo.hash_tree(tree2)
        # A bare `print` is a no-op expression under Python 3; print("")
        # emits the intended blank line on both Python 2 and 3.
        print("")
        print(pos)
        print(hash1)
        print(hash2)
        assert hash1 == hash2

    # check sprs
    sprs1 = arglib.iter_arg_sprs(arg, use_leaves=True)
    sprs2 = arglib.iter_arg_sprs(arg2, use_leaves=True)
    for (pos1, recomb1, coal1), (pos2, recomb2, coal2) in zip(sprs1, sprs2):
        # Sort the leaf lists so the comparison ignores leaf order.
        recomb1 = (sorted(recomb1[0]), recomb1[1])
        recomb2 = (sorted(recomb2[0]), recomb2[1])
        coal1 = (sorted(coal1[0]), coal1[1])
        coal2 = (sorted(coal2[0]), coal2[1])
        print("")
        # Print each SPR as a tuple; identical output on Python 2 and 3.
        print((pos1, recomb1, coal1))
        print((pos2, recomb2, coal2))

        # check pos, leaves, time
        assert pos1 == pos2
        assert recomb1 == recomb2
        assert coal1 == coal2
def test_smcify_arg(self):
    """After arglib.smcify_arg, no SPR may use the same node for both
    its recombination and its coalescence."""
    recomb_rate = 1.5e-8  # recomb/site/gen
    locus_length = 100000  # length of locus
    nlineages = 6  # number of lineages
    popsize = 2 * 10000  # effective popsize

    sampled = arglib.sample_arg(nlineages, popsize, recomb_rate, 0,
                                locus_length)
    smc_arg = arglib.smcify_arg(sampled)

    sprs = arglib.iter_arg_sprs(smc_arg)
    for _pos, (recomb_node, _recomb_time), (coal_node, _coal_time) in sprs:
        self.assertNotEqual(recomb_node, coal_node)
def test_iter_sprs(self):
    """arglib.iter_arg_sprs and arglib.iter_arg_sprs_simple must yield
    equal SPRs, compared pairwise."""
    arg = arglib.sample_arg(
        6,          # number of lineages
        2 * 10000,  # effective popsize
        1.5e-8,     # recomb/site/gen
        0,
        100000,     # length of locus
    )

    fast_sprs = arglib.iter_arg_sprs(arg)
    simple_sprs = arglib.iter_arg_sprs_simple(arg)
    # NOTE: izip stops at the shorter iterator; differing SPR counts are
    # not detected here.
    for fast, simple in izip(fast_sprs, simple_sprs):
        self.assertEqual(fast, simple)
def test_trans_switch():
    """Validate switch-matrix transition probabilities on stored ARGs.

    For each of the ``ntests`` ARG files in test/data/test_trans_switch,
    the ARG is discretized to ``times``. The test then takes the local tree
    just left of the first recombination node found when iterating the ARG,
    together with the SPR at that point, and requires
    ``argweaverc.assert_transition_switch_probs`` to accept them. On
    failure the offending tree is drawn before the test fails.

    Set ``create_data`` to True to resample and rewrite the ARG files first.
    """
    create_data = False  # flip to regenerate the stored test ARGs
    if create_data:
        make_clean_dir('test/data/test_trans_switch')

    # model parameters
    k = 12
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [n] * len(times)
    ntests = 100

    def sample_with_recomb():
        # Keep sampling until the ARG holds at least one recombination.
        while True:
            candidate = argweaver.sample_arg_dsmc(
                k, 2 * n, rho, start=0, end=length, times=times)
            if any(node.event == "recomb" for node in candidate):
                return candidate

    # generate test data
    if create_data:
        for idx in range(ntests):
            sample_with_recomb().write(
                'test/data/test_trans_switch/%d.arg' % idx)

    for idx in range(ntests):
        print('arg', idx)
        arg = arglib.read_arg('test/data/test_trans_switch/%d.arg' % idx)
        argweaver.discretize_arg(arg, times)

        first_recomb = [node.pos for node in arg
                        if node.event == "recomb"][0]
        left_of_recomb = first_recomb - .5
        tree = arg.get_marginal_tree(left_of_recomb)
        _, recomb, coal = next(
            arglib.iter_arg_sprs(arg, start=left_of_recomb))

        if argweaverc.assert_transition_switch_probs(
                tree, (recomb, coal), times, popsizes, rho):
            continue

        # Show the failing tree to aid debugging, then fail.
        plain_tree = tree.get_tree()
        treelib.remove_single_children(plain_tree)
        treelib.draw_tree_names(plain_tree, maxlen=5, minlen=5)
        assert False
def arg_equal(arg, arg2):
    """Assert with nose.tools that ``arg`` and ``arg2`` are equivalent.

    Three properties are compared: the sorted recombination positions; the
    local tree topology (via ``phylo.hash_tree``) at the midpoint of each
    local tree of ``arg``; and the pairwise leaf-labelled SPRs (position,
    recombination leaves/time, coalescence leaves/time), with leaf lists
    sorted before comparison.
    """
    def recomb_positions(graph):
        # Sorted positions of all recombination nodes.
        return sorted(node.pos for node in graph if node.event == "recomb")

    def flattened(local_tree):
        # Strip single-child lineages, then convert to a plain tree.
        arglib.remove_single_lineages(local_tree)
        return local_tree.get_tree()

    def with_sorted_leaves(event):
        # (leaves, time) -> (sorted leaves, time)
        return (sorted(event[0]), event[1])

    nose.tools.assert_equal(recomb_positions(arg), recomb_positions(arg2))

    for (start, end), local_tree in arglib.iter_local_trees(arg):
        midpoint = (start + end) / 2.0
        first = flattened(local_tree)
        second = flattened(arg2.get_marginal_tree(midpoint))
        nose.tools.assert_equal(phylo.hash_tree(first),
                                phylo.hash_tree(second))

    pairs = zip(arglib.iter_arg_sprs(arg, use_leaves=True),
                arglib.iter_arg_sprs(arg2, use_leaves=True))
    for (pos1, recomb1, coal1), (pos2, recomb2, coal2) in pairs:
        recomb1 = with_sorted_leaves(recomb1)
        recomb2 = with_sorted_leaves(recomb2)
        coal1 = with_sorted_leaves(coal1)
        coal2 = with_sorted_leaves(coal2)

        nose.tools.assert_equal(pos1, pos2)
        nose.tools.assert_equal(recomb1, recomb2)
        nose.tools.assert_equal(coal1, coal2)
def arg_equal(arg, arg2):
    """Check, via nose.tools assertions, that two ARGs match.

    Compares recombination positions, then the local topology at the
    middle of each local tree of ``arg``, then every leaf-labelled SPR
    pair (leaf lists sorted so their order does not matter).
    """
    # 1. recombination positions
    positions = sorted(node.pos for node in arg if node.event == "recomb")
    positions2 = sorted(node.pos for node in arg2 if node.event == "recomb")
    nose.tools.assert_equal(positions, positions2)

    # 2. local topologies
    for (block_start, block_end), local in arglib.iter_local_trees(arg):
        mid = (block_start + block_end) / 2.0
        arglib.remove_single_lineages(local)
        plain = local.get_tree()
        other = arg2.get_marginal_tree(mid)
        arglib.remove_single_lineages(other)
        other_plain = other.get_tree()
        digest = phylo.hash_tree(plain)
        other_digest = phylo.hash_tree(other_plain)
        nose.tools.assert_equal(digest, other_digest)

    # 3. SPRs: position, leaves, time
    spr_pairs = izip(arglib.iter_arg_sprs(arg, use_leaves=True),
                     arglib.iter_arg_sprs(arg2, use_leaves=True))
    for spr, spr2 in spr_pairs:
        pos, recomb, coal = spr
        pos2, recomb2, coal2 = spr2
        recomb = (sorted(recomb[0]), recomb[1])
        recomb2 = (sorted(recomb2[0]), recomb2[1])
        coal = (sorted(coal[0]), coal[1])
        coal2 = (sorted(coal2[0]), coal2[1])

        nose.tools.assert_equal(pos, pos2)
        nose.tools.assert_equal(recomb, recomb2)
        nose.tools.assert_equal(coal, coal2)
def test_iter_sprs_remove_thread(self):
    """After removing one leaf's thread, the two SPR iterators must agree.

    Samples an ARG, drops the leaf named ``n<k-1>`` with
    ``arglib.subarg_by_leaf_names`` on a copy, and compares
    ``iter_arg_sprs`` with ``iter_arg_sprs_simple`` pairwise.
    """
    rho = 1.5e-8  # recomb/site/gen
    length = 100000  # length of locus
    k = 6  # number of lineages
    n = 2*10000  # effective popsize
    arg = arglib.sample_arg(k, n, rho, 0, length)

    # The original set("n%d" % (k-1)) built a set of the string's
    # characters ({"n", "5"}), which matched no leaf, so nothing was
    # removed. A set literal holds the whole leaf name.
    remove_chroms = {"n%d" % (k - 1)}
    keep = [x for x in arg.leaf_names() if x not in remove_chroms]
    arg = arg.copy()
    arglib.subarg_by_leaf_names(arg, keep)

    for a, b in izip(arglib.iter_arg_sprs(arg),
                     arglib.iter_arg_sprs_simple(arg)):
        self.assertEqual(a, b)
def test_iter_sprs_leaves(self):
    """With use_leaves=True, both SPR iterators must yield equal SPRs
    once each SPR's recombination and coalescence leaf lists are sorted."""
    nlineages = 40  # number of lineages
    popsize = 2 * 10000  # effective popsize
    recomb_rate = 1.5e-8  # recomb/site/gen
    locus_len = 100000  # length of locus
    arg = arglib.sample_arg(nlineages, popsize, recomb_rate, 0, locus_len)

    fast_sprs = arglib.iter_arg_sprs(arg, use_leaves=True)
    simple_sprs = arglib.iter_arg_sprs_simple(arg, use_leaves=True)
    for fast, simple in izip(fast_sprs, simple_sprs):
        # Sort the leaf lists in place so leaf order is ignored.
        for spr in (fast, simple):
            spr[1][0].sort()
            spr[2][0].sort()
        self.assertEqual(fast, simple)
def test_trans_switch():
    """
    Calculate transition probabilities for switch matrix

    Only calculate a single matrix per stored ARG: the one for the SPR at
    the first recombination node found when iterating the ARG. On failure
    the offending tree is drawn before the test fails. Set ``create_data``
    to True to regenerate the ARG files in test/data/test_trans_switch.
    """
    create_data = False
    if create_data:
        make_clean_dir('test/data/test_trans_switch')

    # model parameters
    k = 12
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [n] * len(times)
    ntests = 100

    # generate test data
    if create_data:
        for i in range(ntests):
            # Sample ARG with at least one recombination.
            while True:
                arg = argweaver.sample_arg_dsmc(
                    k, 2*n, rho, start=0, end=length, times=times)
                if any(x.event == "recomb" for x in arg):
                    break
            arg.write('test/data/test_trans_switch/%d.arg' % i)

    for i in range(ntests):
        # Same text as the Python 2 `print 'arg', i`, on Python 2 and 3.
        print('arg %d' % i)
        arg = arglib.read_arg('test/data/test_trans_switch/%d.arg' % i)
        argweaver.discretize_arg(arg, times)
        recombs = [x.pos for x in arg if x.event == "recomb"]
        pos = recombs[0]
        tree = arg.get_marginal_tree(pos-.5)
        # Builtin next() works on Python 2.6+ and 3; generator.next() is
        # Python 2 only.
        _, r, c = next(arglib.iter_arg_sprs(arg, start=pos-.5))
        spr = (r, c)
        if not argweaverc.assert_transition_switch_probs(
                tree, spr, times, popsizes, rho):
            tree2 = tree.get_tree()
            treelib.remove_single_children(tree2)
            treelib.draw_tree_names(tree2, maxlen=5, minlen=5)
            assert False
def layout_chroms(arg, start=None, end=None):
    """Compute a leaf layout for each run of sites sharing one local tree.

    Starting from the marginal tree at ``start`` (with single lineages
    removed), every SPR between ``start`` and ``end`` (iterated with
    ``use_leaves=True``) is applied to that tree in order. Before each SPR,
    the current tree's leaf layout is recorded for the block ending at the
    SPR's position.

    Args:
        arg: the ARG to lay out.
        start: first position; defaults to ``arg.start``.
        end: last position; defaults to ``arg.end``.

    Returns:
        ``(blocks, leaf_layout)``. ``blocks`` is a list of
        ``[block_start, block_end]`` pairs covering ``start``..``end``.
        ``leaf_layout[i]`` is ``layout_tree_leaves(tree)`` for the tree in
        effect over ``blocks[i]``.
    """
    if start is None:
        start = arg.start
    if end is None:
        end = arg.end

    tree = arg.get_marginal_tree(start)
    arglib.remove_single_lineages(tree)
    last_pos = start
    blocks = []
    leaf_layout = []

    layout_func = layout_tree_leaves
    # alternative: layout_func = layout_tree_leaves_even

    for spr in arglib.iter_arg_sprs(arg, start=start, end=end,
                                    use_leaves=True):
        # Progress output. "%s" formatting prints the same text on Python 2
        # and 3 (the old print statement is a SyntaxError on Python 3).
        print("layout %s" % (spr[0],))
        blocks.append([last_pos, spr[0]])
        leaf_layout.append(layout_func(tree))
        inorder = {node: i for i, node in enumerate(inorder_tree(tree))}

        # determine SPR nodes
        rnode = arglib.arg_lca(tree, spr[1][0], spr[0])
        cnode = arglib.arg_lca(tree, spr[2][0], spr[0])

        # determine best side for adding new sister: was the pruned branch
        # left of the regraft point in the in-order traversal?
        left = (inorder[rnode] < inorder[cnode])

        # apply spr
        arglib.apply_spr(tree, rnode, spr[1][1], cnode, spr[2][1], spr[0])

        # adjust sister: if the moved branch came from the left, try to
        # make it the first child of its new parent (reversing the child
        # list does this when the parent has two children)
        rindex = rnode.parents[0].children.index(rnode)
        if left and rindex != 0:
            rnode.parents[0].children.reverse()

        last_pos = spr[0]

    blocks.append([last_pos, end])
    leaf_layout.append(layout_func(tree))

    return blocks, leaf_layout
def add_arg(self, arg):
    """Accumulate coalescence and lineage-count statistics from one ARG.

    First records the lineage counts of the ARG's initial marginal tree
    (from ``count_tree_lineages``) in ``self.init_trees``. Then, for every
    SPR (iterated with ``use_local=True``):

    * increments ``self.ncoals`` at the index ``util.binsearch`` returns
      for the coalescence time;
    * in the local tree just left of the recombination, with the broken
      branch (recombination node up to its local parent) excluded, computes
      the time-weighted mean number of lineages in each interval of
      ``self.times``;
    * adds those means to ``self.k_lineages`` for interval indices from the
      recombination time up to, but not including, the coalescence time.
    """
    nleaves = len(list(arg.leaves()))
    times = self.times
    assert times
    eps = 1e-3  # offset left of a recombination to query the old tree

    def get_local_children(node, pos, local):
        # Children of ``node`` at ``pos`` that belong to ``local``.
        return set(child for child in arg.get_local_children(node, pos)
                   if child in local)

    def get_parent(node, pos, local):
        # Nearest local ancestor of ``node`` that does not have exactly
        # one child in ``local``.
        parent = arg.get_local_parent(node, pos)
        while len(get_local_children(parent, pos, local)) == 1:
            parent = arg.get_local_parent(parent, pos)
        return parent

    def event_order(item):
        # Sort key for (time, value) events, where value is a lineage count
        # or the marker "time step". Python 2 ordered numbers before
        # strings on ties. Python 3 raises TypeError comparing int with
        # str, which happens whenever a coalescence age equals a time
        # point. This key reproduces the Python 2 order explicitly.
        time, value = item
        if value == "time step":
            return (time, 1, 0)
        return (time, 0, value)

    def mean_lineages_per_interval(coals, nlineages):
        # Time-weighted mean lineage count within each interval of
        # ``times``; nlineages[i] lineages exist from coals[i] onward.
        data = list(zip(coals, nlineages))
        for t in times[1:]:
            data.append((t, "time step"))
        data.sort(key=event_order)

        lineages_per_time = []
        counts = []  # (lineages, duration) pieces in the current interval
        last_lineages = 0
        last_time = 0.0
        for a, b in data:
            if b != "time step":
                if a > last_time:
                    counts.append((last_lineages, a - last_time))
                last_lineages = b
            else:
                counts.append((last_lineages, a - last_time))
                s = sum(u * v for u, v in counts)
                total_time = sum(v for u, v in counts)
                if s == 0.0:
                    lineages_per_time.append(last_lineages)
                else:
                    lineages_per_time.append(s / total_time)
                counts = []
            last_time = a
        return lineages_per_time

    # add initial tree
    tree = arg.get_marginal_tree(arg.start)
    starts, ends, time_steps = count_tree_lineages(tree, times)
    self.init_trees.append({"starts": starts, "ends": ends,
                            "time_steps": time_steps})

    # loop through sprs
    for recomb_pos, (rnode, rtime), (cnode, ctime), local in \
            arglib.iter_arg_sprs(arg, use_local=True):
        i, _ = util.binsearch(times, ctime)
        self.ncoals[i] += 1

        recomb_node = arg[rnode]
        broken_node = get_parent(recomb_node, recomb_pos - eps, local)
        coals = [0.0] + [
            node.age for node in local
            if len(get_local_children(node, recomb_pos - eps, local)) == 2]
        coals.sort()
        # list(): the counts are decremented in place below (a Python 3
        # range object is immutable).
        nlineages = list(range(nleaves, 0, -1))
        assert len(nlineages) == len(coals)

        # subtract broken branch
        r = coals.index(recomb_node.age)
        r2 = coals.index(broken_node.age)
        for idx in range(r, r2):
            nlineages[idx] -= 1

        # get average number of branches in the time interval
        lineages_per_time = mean_lineages_per_interval(coals, nlineages)
        assert len(lineages_per_time) == len(self.time_steps)

        r, _ = util.binsearch(times, rtime)
        c, _ = util.binsearch(times, ctime)
        for j in range(r, c):
            self.k_lineages[j] += lineages_per_time[j]
def add_arg(self, arg):
    """Accumulate coalescence and lineage-count statistics from one ARG.

    First records the lineage counts of the ARG's initial marginal tree
    (from ``count_tree_lineages``) in ``self.init_trees``. Then, for every
    SPR (iterated with ``use_local=True``):

    * increments ``self.ncoals`` at the index ``util.binsearch`` returns
      for the coalescence time;
    * in the local tree just left of the recombination, with the broken
      branch (recombination node up to its local parent) excluded, computes
      the time-weighted mean number of lineages in each interval of
      ``self.times``;
    * adds those means to ``self.k_lineages`` for interval indices from the
      recombination time up to, but not including, the coalescence time.
    """
    nleaves = len(list(arg.leaves()))
    times = self.times
    assert times
    eps = 1e-3  # offset left of a recombination to query the old tree

    def get_local_children(node, pos, local):
        # Children of ``node`` at ``pos`` that belong to ``local``.
        return set(child for child in arg.get_local_children(node, pos)
                   if child in local)

    def get_parent(node, pos, local):
        # Nearest local ancestor of ``node`` that does not have exactly
        # one child in ``local``.
        parent = arg.get_local_parent(node, pos)
        while len(get_local_children(parent, pos, local)) == 1:
            parent = arg.get_local_parent(parent, pos)
        return parent

    def event_order(item):
        # Sort key for (time, value) events, where value is a lineage count
        # or the marker "time step". A plain data.sort() raises TypeError
        # on Python 3 whenever a coalescence age equals a time point,
        # because it compares an int with "time step". This key keeps the
        # Python 2 order (numbers before the marker on ties).
        time, value = item
        if value == "time step":
            return (time, 1, 0)
        return (time, 0, value)

    def mean_lineages_per_interval(coals, nlineages):
        # Time-weighted mean lineage count within each interval of
        # ``times``; nlineages[i] lineages exist from coals[i] onward.
        data = list(zip(coals, nlineages))
        for t in times[1:]:
            data.append((t, "time step"))
        data.sort(key=event_order)

        lineages_per_time = []
        counts = []  # (lineages, duration) pieces in the current interval
        last_lineages = 0
        last_time = 0.0
        for a, b in data:
            if b != "time step":
                if a > last_time:
                    counts.append((last_lineages, a - last_time))
                last_lineages = b
            else:
                counts.append((last_lineages, a - last_time))
                s = sum(u * v for u, v in counts)
                total_time = sum(v for u, v in counts)
                if s == 0.0:
                    lineages_per_time.append(last_lineages)
                else:
                    lineages_per_time.append(s / total_time)
                counts = []
            last_time = a
        return lineages_per_time

    # add initial tree
    tree = arg.get_marginal_tree(arg.start)
    starts, ends, time_steps = count_tree_lineages(tree, times)
    self.init_trees.append({
        "starts": starts,
        "ends": ends,
        "time_steps": time_steps
    })

    # loop through sprs
    for recomb_pos, (rnode, rtime), (cnode, ctime), local in \
            arglib.iter_arg_sprs(arg, use_local=True):
        i, _ = util.binsearch(times, ctime)
        self.ncoals[i] += 1

        recomb_node = arg[rnode]
        broken_node = get_parent(recomb_node, recomb_pos - eps, local)
        coals = [0.0] + [
            node.age for node in local
            if len(get_local_children(node, recomb_pos - eps, local)) == 2
        ]
        coals.sort()
        nlineages = list(range(nleaves, 0, -1))
        assert len(nlineages) == len(coals)

        # subtract broken branch
        r = coals.index(recomb_node.age)
        r2 = coals.index(broken_node.age)
        for idx in range(r, r2):
            nlineages[idx] -= 1

        # get average number of branches in the time interval
        lineages_per_time = mean_lineages_per_interval(coals, nlineages)
        assert len(lineages_per_time) == len(self.time_steps)

        r, _ = util.binsearch(times, rtime)
        c, _ = util.binsearch(times, ctime)
        for j in range(r, c):
            self.k_lineages[j] += lineages_per_time[j]