def test_trans():
    """
    Calculate transition probabilities

    Reads the stored ARGs 'test/data/test_trans/<i>.arg' for every i in
    range(ntests), discretizes each one onto `times`, and asserts that
    argweaverc.assert_transition_probs returns a true value for the
    marginal tree at position 10.

    Set `create_data` to True to wipe the data directory and sample a fresh
    set of ARGs before checking.
    """
    create_data = False
    if create_data:
        make_clean_dir('test/data/test_trans')

    # model parameters
    k = 8
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=10, maxtime=200000)
    popsizes = [n] * len(times)
    ntests = 40

    # generate test data
    if create_data:
        for i in range(ntests):
            arg = arglib.sample_arg(k, 2*n, rho, start=0, end=length)
            argweaver.discretize_arg(arg, times)
            arg.write('test/data/test_trans/%d.arg' % i)

    for i in range(ntests):
        # Pre-formatted so the output is "arg <i>" whether print is the
        # Python 2 statement or the Python 3 function.
        print('arg %d' % i)
        arg = arglib.read_arg('test/data/test_trans/%d.arg' % i)
        argweaver.discretize_arg(arg, times)
        pos = 10
        tree = arg.get_marginal_tree(pos)

        assert argweaverc.assert_transition_probs(tree, times, popsizes, rho)
def test_trans_switch():
    """
    Check the switch transition matrix at the first recombination of each
    stored ARG.

    For every ARG under test/data/test_trans_switch the marginal tree just
    left of the first recombination and the SPR starting there are handed
    to argweaverc.assert_transition_switch_probs.  When that check returns
    a false value the tree is drawn and the test fails.
    """
    regenerate = False
    if regenerate:
        make_clean_dir('test/data/test_trans_switch')

    # model parameters
    nchroms = 12
    popsize = 1e4
    recomb_rate = 1.5e-8 * 20
    seq_end = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [popsize] * len(times)
    ntests = 100

    # generate test data
    if regenerate:
        for idx in range(ntests):
            # Keep sampling until the ARG has at least one recombination.
            while True:
                arg = argweaver.sample_arg_dsmc(
                    nchroms, 2 * popsize, recomb_rate,
                    start=0, end=seq_end, times=times)
                if any(node.event == "recomb" for node in arg):
                    break
            arg.write('test/data/test_trans_switch/%d.arg' % idx)

    for idx in range(ntests):
        print('arg', idx)
        arg = arglib.read_arg('test/data/test_trans_switch/%d.arg' % idx)
        argweaver.discretize_arg(arg, times)

        recomb_positions = [node.pos for node in arg
                            if node.event == "recomb"]
        # half a site to the left of the first recombination
        left_of_recomb = recomb_positions[0] - .5
        tree = arg.get_marginal_tree(left_of_recomb)
        rpos, recomb_point, coal_point = next(
            arglib.iter_arg_sprs(arg, start=left_of_recomb))
        spr = (recomb_point, coal_point)

        if argweaverc.assert_transition_switch_probs(
                tree, spr, times, popsizes, recomb_rate):
            continue

        # check failed: show the offending tree before aborting
        drawable = tree.get_tree()
        treelib.remove_single_children(drawable)
        treelib.draw_tree_names(drawable, maxlen=5, minlen=5)
        assert False
def test_trans_switch():
    """
    Calculate transition probabilities for switch matrix

    Only calculate a single matrix

    For each stored ARG 'test/data/test_trans_switch/<i>.arg' the marginal
    tree half a site left of the first recombination, together with the
    first SPR found from that position, is passed to
    argweaverc.assert_transition_switch_probs.  On failure the tree is
    drawn and an AssertionError is raised.

    Set `create_data` to True to regenerate the stored ARGs first.
    """
    create_data = False
    if create_data:
        make_clean_dir('test/data/test_trans_switch')

    # model parameters
    k = 12
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [n] * len(times)
    ntests = 100

    # generate test data
    if create_data:
        for i in range(ntests):
            # Sample ARG with at least one recombination.
            while True:
                arg = argweaver.sample_arg_dsmc(
                    k, 2*n, rho, start=0, end=length, times=times)
                if any(x.event == "recomb" for x in arg):
                    break
            arg.write('test/data/test_trans_switch/%d.arg' % i)

    for i in range(ntests):
        # Pre-formatted so the output is "arg <i>" on Python 2 and 3 alike.
        print('arg %d' % i)
        arg = arglib.read_arg('test/data/test_trans_switch/%d.arg' % i)
        argweaver.discretize_arg(arg, times)
        recombs = [x.pos for x in arg if x.event == "recomb"]
        pos = recombs[0]
        tree = arg.get_marginal_tree(pos-.5)

        # next() builtin instead of the Python 2-only .next() method; the
        # SPR position itself is not needed by the check.
        _rpos, r, c = next(arglib.iter_arg_sprs(arg, start=pos-.5))
        spr = (r, c)

        if not argweaverc.assert_transition_switch_probs(
                tree, spr, times, popsizes, rho):
            # draw the failing tree to help debugging, then fail
            tree2 = tree.get_tree()
            treelib.remove_single_children(tree2)
            treelib.draw_tree_names(tree2, maxlen=5, minlen=5)
            assert False, (
                'switch transition probabilities check failed for arg %d'
                % i)
def test_trans():
    """
    Check transition probabilities on one marginal tree.

    Samples a 4-chromosome ARG, discretizes it onto a 4-point time grid and
    asserts that argweaverc.assert_transition_probs returns a true value
    for the marginal tree at position 10.
    """
    # model parameters
    nchroms = 4
    popsize = 1e4
    recomb_rate = 1.5e-8 * 20
    seq_end = 1000
    times = argweaver.get_time_points(ntimes=4, maxtime=200000)
    popsizes = [popsize] * len(times)

    # simulate and discretize
    arg = arglib.sample_arg(nchroms, 2 * popsize, recomb_rate,
                            start=0, end=seq_end)
    argweaver.discretize_arg(arg, times)

    local_tree = arg.get_marginal_tree(10)
    passed = argweaverc.assert_transition_probs(
        local_tree, times, popsizes, recomb_rate)
    assert passed
def test_trans():
    """
    Verify transition probabilities for a freshly sampled ARG.

    The marginal tree at site 10 of a discretized 4-chromosome ARG must be
    accepted by argweaverc.assert_transition_probs.
    """
    num_chroms, pop_n, site = 4, 1e4, 10
    rho = 1.5e-8 * 20
    region_len = 1000

    time_points = argweaver.get_time_points(ntimes=4, maxtime=200000)
    # constant population size at every time point
    popsizes = [pop_n] * len(time_points)

    sampled = arglib.sample_arg(
        num_chroms, 2*pop_n, rho, start=0, end=region_len)
    argweaver.discretize_arg(sampled, time_points)

    assert argweaverc.assert_transition_probs(
        sampled.get_marginal_tree(site), time_points, popsizes, rho)
def test_trans_internal():
    """
    Check transition probabilities for internal branch re-sampling.

    Only a single matrix is checked: the one for the marginal tree at
    position 10 of a sampled, discretized 5-chromosome ARG, via
    argweaverc.assert_transition_probs_internal.
    """
    # model parameters
    nchroms = 5
    popsize = 1e4
    recomb_rate = 1.5e-8 * 20
    seq_end = 1000
    times = argweaver.get_time_points(ntimes=5, maxtime=200000)
    popsizes = [popsize] * len(times)

    # simulate and discretize
    arg = arglib.sample_arg(nchroms, 2 * popsize, recomb_rate,
                            start=0, end=seq_end)
    argweaver.discretize_arg(arg, times)

    local_tree = arg.get_marginal_tree(10)
    passed = argweaverc.assert_transition_probs_internal(
        local_tree, times, popsizes, recomb_rate)
    assert passed
def test_trans_internal():
    """
    Verify the internal-branch re-sampling transition matrix.

    One matrix only: the marginal tree at site 10 of a discretized
    5-chromosome ARG must be accepted by
    argweaverc.assert_transition_probs_internal.
    """
    num_chroms, pop_n, site = 5, 1e4, 10
    rho = 1.5e-8 * 20
    region_len = 1000

    time_points = argweaver.get_time_points(ntimes=5, maxtime=200000)
    # constant population size at every time point
    popsizes = [pop_n] * len(time_points)

    sampled = arglib.sample_arg(
        num_chroms, 2*pop_n, rho, start=0, end=region_len)
    argweaver.discretize_arg(sampled, time_points)

    assert argweaverc.assert_transition_probs_internal(
        sampled.get_marginal_tree(site), time_points, popsizes, rho)
def test_forward():
    """
    Compare the default and `slow=True` C forward algorithms.

    Simulates an SMC ARG with mutations, removes the thread of the last
    chromosome ("n<k-1>"), converts the remaining ARG to C trees and runs
    argweaverc.argweaver_forward_algorithm twice (default and slow=True).
    Every pair of corresponding entries must agree under
    fequal(..., rel=.0001).
    """
    k = 4
    n = 1e4
    rho = 1.5e-8 * 20
    mu = 2.5e-8 * 20
    length = int(100e3 / 20)
    times = argweaver.get_time_points(ntimes=100)

    arg = arglib.sample_arg_smc(k, 2 * n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    # Pre-formatted so the output is identical whether print is the
    # Python 2 statement or the Python 3 function.
    print("muts %d" % len(muts))
    print("recomb %d" % len(arglib.get_recomb_pos(arg)))

    argweaver.discretize_arg(arg, times)

    # remove chrom
    new_name = "n%d" % (k - 1)
    arg = argweaver.remove_arg_thread(arg, new_name)

    carg = argweaverc.arg2ctrees(arg, times)

    util.tic("C fast")
    probs1 = argweaverc.argweaver_forward_algorithm(carg, seqs, times=times)
    util.toc()

    util.tic("C slow")
    probs2 = argweaverc.argweaver_forward_algorithm(carg, seqs, times=times,
                                                    slow=True)
    util.toc()

    # Builtin zip works on Python 2 and 3 (izip is Python 2 only); the
    # column index was never used, so enumerate is dropped.
    for col1, col2 in zip(probs1, probs2):
        for a, b in zip(col1, col2):
            fequal(a, b, rel=.0001)
def test_forward():
    """
    Compare the default and `slow=True` C forward algorithms.

    Simulates an SMC ARG with mutations, removes the thread of the last
    chromosome ("n<k-1>"), converts the remaining ARG to C trees and runs
    argweaverc.argweaver_forward_algorithm twice (default and slow=True).
    Every pair of corresponding entries must agree under
    fequal(..., rel=.0001).
    """
    k = 4
    n = 1e4
    rho = 1.5e-8 * 20
    mu = 2.5e-8 * 20
    length = int(100e3 / 20)
    times = argweaver.get_time_points(ntimes=100)

    arg = arglib.sample_arg_smc(k, 2*n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    # Pre-formatted so the output is identical whether print is the
    # Python 2 statement or the Python 3 function.
    print("muts %d" % len(muts))
    print("recomb %d" % len(arglib.get_recombs(arg)))

    argweaver.discretize_arg(arg, times)

    # remove chrom
    new_name = "n%d" % (k - 1)
    arg = argweaver.remove_arg_thread(arg, new_name)

    carg = argweaverc.arg2ctrees(arg, times)

    util.tic("C fast")
    probs1 = argweaverc.argweaver_forward_algorithm(carg, seqs, times=times)
    util.toc()

    util.tic("C slow")
    probs2 = argweaverc.argweaver_forward_algorithm(carg, seqs, times=times,
                                                    slow=True)
    util.toc()

    # Builtin zip works on Python 2 and 3 (izip is Python 2 only); the
    # column index was never used, so enumerate is dropped.
    for col1, col2 in zip(probs1, probs2):
        for a, b in zip(col1, col2):
            fequal(a, b, rel=.0001)
def show_plots(arg_file, sites_file, stats_file, output_prefix,
               rho, mu, popsize, ntimes=20, maxtime=200000):
    """
    Show plots of convergence.

    Computes five reference statistics from the true ARG (joint
    probability, likelihood, prior, number of recombinations, ARG branch
    length) and, for each one, writes a trace plot of the matching column
    of the stats table with the reference value drawn as a gray
    horizontal line.

    arg_file      -- file holding the true ARG (read with arglib.read_arg)
    sites_file    -- sites file the sequences are read from
    stats_file    -- table with columns "joint", "likelihood", "prior",
                     "recombs" and "arglen"
    output_prefix -- plots are written to
                     <output_prefix>.trace.<name>.pdf
    rho           -- passed to argweaverc.calc_prior_prob as `rho`
    mu            -- passed to argweaverc.calc_likelihood as `mu`
    popsize       -- passed to argweaverc.calc_prior_prob as `popsizes`
    ntimes        -- number of discretized time points (default: 20)
    maxtime       -- maximum time of the time grid (default: 200000)
    """
    # read true arg and seqs
    times = argweaver.get_time_points(ntimes=ntimes, maxtime=maxtime)
    arg = arglib.read_arg(arg_file)
    argweaver.discretize_arg(arg, times, ignore_top=False, round_age="closer")
    arg = arglib.smcify_arg(arg)
    seqs = argweaver.sites2seqs(argweaver.read_sites(sites_file))

    # compute true stats
    # (arglen must be taken before `arg` is replaced by its C-tree form)
    arglen = arglib.arglen(arg)
    arg = argweaverc.arg2ctrees(arg, times)
    nrecombs = argweaverc.get_local_trees_ntrees(arg[0]) - 1
    lk = argweaverc.calc_likelihood(
        arg, seqs, mu=mu, times=times, delete_arg=False)
    prior = argweaverc.calc_prior_prob(
        arg, rho=rho, times=times, popsizes=popsize, delete_arg=False)
    joint = lk + prior

    data = read_table(stats_file)

    # (file-name part, stats column, true value, title and y-axis label),
    # in the order the plots are produced
    traces = [
        ("joint", "joint", joint, "joint probability"),
        ("lk", "likelihood", lk, "likelihood"),
        ("prior", "prior", prior, "prior probability"),
        ("nrecombs", "recombs", nrecombs, "number of recombinations"),
        ("arglen", "arglen", arglen, "ARG branch length"),
    ]

    for name, column, y2, label in traces:
        y = data.cget(column)
        rplot_start(output_prefix + ".trace." + name + ".pdf",
                    width=8, height=5)
        # widen the y-range so the true value is always visible
        rp.plot(y, t="l", ylim=[min(min(y), y2), max(max(y), y2)],
                main=label,
                xlab="iterations",
                ylab=label)
        rp.lines([0, len(y)], [y2, y2], col="gray")
        rplot_end(True)
def sample_dsmc_sprs(k,
                     popsize,
                     rho,
                     recombmap=None,
                     start=0.0,
                     end=0.0,
                     times=None,
                     times2=None,
                     init_tree=None,
                     names=None,
                     make_names=True):
    """
    Sample ARG using Discrete Sequentially Markovian Coalescent (SMC)

    This is a generator.  It first yields the initial local tree and then,
    for every sampled recombination, yields a tuple

        (pos, (rleaves, recomb_time), (cleaves, coal_time))

    where `rleaves`/`cleaves` are the leaf names below the recombining and
    re-coalescing branches.  Sampling stops once `pos >= end - 1`.

    k   -- chromosomes
    popsize -- effective population size (haploid); either a single value
               (replicated for every time step) or a sequence with one
               entry per time step
    rho -- recombination rate (recombinations / site / generation)
    recombmap -- map for variable recombination rate
    start -- starting chromosome coordinate
    end   -- ending chromosome coordinate
    times -- discretized time points (required); must support .index()
    times2 -- second time grid (required), indexed at 2*j-1, 2*j and 2*j+1
              for time index j; presumably the half-step coalescent times
              for `times` -- TODO confirm against caller
    init_tree -- initial local tree; when None a tree is sampled with
                 sample_tree and discretized onto `times2`
    names -- names to use for leaves (default: None)
    make_names -- make names using strings (default: True)

    NOTE(review): an earlier version of this docstring listed a parameter
    `t` (initial time); the signature has no such parameter.
    """
    # Both grids are mandatory despite their None defaults.
    assert times is not None
    assert times2 is not None
    # ntimes counts the intervals between consecutive time points.
    ntimes = len(times) - 1
    time_steps = [times[i] - times[i - 1] for i in range(1, ntimes + 1)]
    # times2 = get_coal_times(times)

    # Accept either one population size or one per time step.
    if hasattr(popsize, "__len__"):
        popsizes = popsize
    else:
        popsizes = [popsize] * len(time_steps)

    # yield initial tree first
    if init_tree is None:
        init_tree = sample_tree(k,
                                popsizes,
                                times,
                                start=start,
                                end=end,
                                names=names,
                                make_names=make_names)
        argweaver.discretize_arg(init_tree, times2)
    yield init_tree

    # sample SPRs
    pos = start
    # The SPRs below are applied to this copy, not to the yielded tree.
    tree = init_tree.copy()
    while True:
        # sample next recomb point
        treelen = sum(x.get_dist() for x in tree)
        blocklen = int(
            sample_next_recomb(treelen,
                               rho,
                               pos=pos,
                               recombmap=recombmap,
                               minlen=1))
        pos += blocklen
        if pos >= end - 1:
            break

        root_age_index = times.index(tree.root.age)

        # choose time interval for recombination
        states = set(argweaver.iter_coal_states(tree, times))
        nbranches, nrecombs, ncoals = argweaver.get_nlineages_recomb_coal(
            tree, times)
        # weight of interval i: (number of branches) * (interval length),
        # for intervals up to and including the root's
        probs = [
            nbranches[i] * time_steps[i] for i in range(root_age_index + 1)
        ]
        recomb_time_index = stats.sample(probs)
        recomb_time = times[recomb_time_index]

        # choose branch for recombination
        # (any non-root state at the chosen time index)
        branches = [
            x for x in states
            if x[1] == recomb_time_index and x[0] != tree.root.name
        ]
        recomb_node = tree[random.sample(branches, 1)[0][0]]

        # choose coal time
        # Walk upward one time index at a time, drawing at each step
        # whether the lineage coalesces there.
        j = recomb_time_index
        last_kj = nbranches[max(j - 1, 0)]
        while j < ntimes - 1:
            kj = nbranches[j]
            # Do not count the recombining branch itself while it is still
            # below its parent.
            if ((recomb_node.name, j) in states
                    and recomb_node.parents[0].age > times[j]):
                kj -= 1
            assert kj > 0, (j, root_age_index, states)

            A = (times2[2 * j + 1] - times2[2 * j]) * kj
            if j > recomb_time_index:
                A += (times2[2 * j] - times2[2 * j - 1]) * last_kj
            coal_prob = 1.0 - exp(-A / float(popsizes[j]))
            if random.random() < coal_prob:
                break
            j += 1
            last_kj = kj
        coal_time_index = j
        coal_time = times[j]

        # choose coal node
        # since coal points collapse, exclude parent node, but allow sibling
        exclude = []

        def walk(node):
            # Collect `node`, and descend only through nodes whose age
            # equals the coalescence time.
            exclude.append(node.name)
            if node.age == coal_time:
                for child in node.children:
                    walk(child)

        walk(recomb_node)
        exclude2 = (recomb_node.parents[0].name,
                    times.index(recomb_node.parents[0].age))
        branches = [
            x for x in states if x[1] == coal_time_index
            and x[0] not in exclude and x != exclude2
        ]
        coal_node = tree[random.sample(branches, 1)[0][0]]

        # yield SPR
        rleaves = list(tree.leaf_names(recomb_node))
        cleaves = list(tree.leaf_names(coal_node))

        yield pos, (rleaves, recomb_time), (cleaves, coal_time)

        # apply SPR to local tree
        broken = recomb_node.parents[0]
        recoal = tree.new_node(age=coal_time,
                               children=[recomb_node, coal_node])

        # add recoal node to tree
        recomb_node.parents[0] = recoal
        broken.children.remove(recomb_node)
        if coal_node.parents:
            recoal.parents.append(coal_node.parents[0])
            util.replace(coal_node.parents[0].children, coal_node, recoal)
            coal_node.parents[0] = recoal
        else:
            # coal_node had no parent, so recoal sits above it with no
            # parent of its own
            coal_node.parents.append(recoal)

        # remove broken node
        # (after losing recomb_node it has a single child left)
        broken_child = broken.children[0]
        if broken.parents:
            broken_child.parents[0] = broken.parents[0]
            util.replace(broken.parents[0].children, broken, broken_child)
        else:
            broken_child.parents.remove(broken)

        del tree.nodes[broken.name]
        tree.set_root()
def test_trans_two():
    """
    Calculate transition probabilities for k=2

    Only calculate a single matrix

    Builds a trunk ARG for the single chromosome "n0", computes its
    transition matrix with argweaver.calc_transition_probs and compares
    every entry against log(trans(i, j)), where `trans` is the closed-form
    expression defined below.  Also checks that the pieces of the closed
    form (recomb, recoal, trans) sum to the expected totals.
    """
    k = 2
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=5, maxtime=200000)
    time_steps = [times[i] - times[i - 1] for i in range(1, len(times))]
    # one extra, very long step after the last time point
    time_steps.append(200000 * 10000.0)
    popsizes = [n] * len(times)

    arg = arglib.sample_arg(k, 2 * n, rho, start=0, end=length)
    argweaver.discretize_arg(arg, times)
    # Pre-formatted so the output is identical whether print is the
    # Python 2 statement or the Python 3 function.
    print("recomb %s" % (arglib.get_recomb_pos(arg),))

    # The sampled ARG is only reported above; the matrix is computed on a
    # trunk ARG.
    arg = argweaver.make_trunk_arg(0, length, "n0")

    pos = 10
    tree = arg.get_marginal_tree(pos)
    nlineages = argweaver.get_nlineages_recomb_coal(tree, times)
    states = list(argweaver.iter_coal_states(tree, times))
    mat = argweaver.calc_transition_probs(tree, states, nlineages, times,
                                          time_steps, popsizes, rho)

    nstates = len(states)

    def coal(j):
        # probability of coalescing within time step j
        return 1.0 - exp(-time_steps[j] / (2.0 * n))

    def recoal2(k, j):
        # Step-by-step form of the non-terminal branch of recoal();
        # not called by the checks below.
        p = coal(j)
        for m in range(k, j):
            p *= 1.0 - coal(m)
        return p

    def recoal(k, j):
        # survive steps k..j-1 without coalescing, then coalesce in step j
        # (the last state absorbs all remaining probability)
        if j == nstates - 1:
            return exp(-sum(time_steps[m] / (2.0 * n) for m in range(k, j)))
        else:
            return ((1.0 - exp(-time_steps[j] / (2.0 * n))) *
                    exp(-sum(time_steps[m] / (2.0 * n)
                             for m in range(k, j))))

    def isrecomb(i):
        return 1.0 - exp(-max(rho * 2.0 * times[i], rho))

    def recomb(i, k):
        treelen = 2 * times[i] + time_steps[i]
        if k < i:
            return 2.0 * time_steps[k] / treelen / 2.0
        else:
            return time_steps[k] / treelen / 2.0

    def trans(i, j):
        a = states[i][1]
        b = states[j][1]

        # The original added this same sum to itself; doubling it gives
        # the identical floating-point value.
        p = 2.0 * sum(recoal(k, b) * recomb(a, k)
                      for k in range(0, min(a, b) + 1))
        p *= isrecomb(a)
        if i == j:
            p += 1.0 - isrecomb(a)
        return p

    for i in range(len(states)):
        for j in range(len(states)):
            expected = log(trans(i, j))
            print(isrecomb(states[i][1]))
            print("%s %s %s %s" % (states[i], states[j], mat[i][j],
                                   expected))
            fequal(mat[i][j], expected)

        # recombs add up to 1
        fequal(sum(recomb(i, k) for k in range(i + 1)), 0.5)

        # recoal add up to 1
        fequal(sum(recoal(i, j) for j in range(i, nstates)), 1.0)

        # recomb * recoal add up to .5
        fequal(
            sum(
                sum(
                    recoal(k, j) * recomb(i, k)
                    for k in range(0, min(i, j) + 1))
                for j in range(0, nstates)), 0.5)

        fequal(sum(trans(i, j) for j in range(len(states))), 1.0)
def sample_dsmc_sprs(
        k, popsize, rho, recombmap=None, start=0.0, end=0.0, times=None,
        times2=None, init_tree=None, names=None, make_names=True):
    """
    Sample ARG using Discrete Sequentially Markovian Coalescent (SMC)

    This is a generator.  It first yields the initial local tree and then,
    for every sampled recombination, yields a tuple

        (pos, (rleaves, recomb_time), (cleaves, coal_time))

    where `rleaves`/`cleaves` are the leaf names below the recombining and
    re-coalescing branches.  Sampling stops once `pos >= end - 1`.

    k   -- chromosomes
    popsize -- effective population size (haploid); either a single value
               (replicated for every time step) or a sequence with one
               entry per time step
    rho -- recombination rate (recombinations / site / generation)
    recombmap -- map for variable recombination rate
    start -- starting chromosome coordinate
    end   -- ending chromosome coordinate
    times -- discretized time points (required); must support .index()
    times2 -- second time grid (required), indexed at 2*j-1, 2*j and 2*j+1
              for time index j; presumably the half-step coalescent times
              for `times` -- TODO confirm against caller
    init_tree -- initial local tree; when None a tree is sampled with
                 sample_tree and discretized onto `times2`
    names -- names to use for leaves (default: None)
    make_names -- make names using strings (default: True)

    NOTE(review): an earlier version of this docstring listed a parameter
    `t` (initial time); the signature has no such parameter.
    """

    # Both grids are mandatory despite their None defaults.
    assert times is not None
    assert times2 is not None
    # ntimes counts the intervals between consecutive time points.
    ntimes = len(times) - 1
    time_steps = [times[i] - times[i-1] for i in range(1, ntimes+1)]
    # times2 = get_coal_times(times)

    # Accept either one population size or one per time step.
    if hasattr(popsize, "__len__"):
        popsizes = popsize
    else:
        popsizes = [popsize] * len(time_steps)

    # yield initial tree first
    if init_tree is None:
        init_tree = sample_tree(k, popsizes, times, start=start, end=end,
                                names=names, make_names=make_names)
        argweaver.discretize_arg(init_tree, times2)
    yield init_tree

    # sample SPRs
    pos = start
    # The SPRs below are applied to this copy, not to the yielded tree.
    tree = init_tree.copy()
    while True:
        # sample next recomb point
        treelen = sum(x.get_dist() for x in tree)
        blocklen = int(sample_next_recomb(treelen, rho, pos=pos,
                                          recombmap=recombmap, minlen=1))
        pos += blocklen
        if pos >= end - 1:
            break

        root_age_index = times.index(tree.root.age)

        # choose time interval for recombination
        states = set(argweaver.iter_coal_states(tree, times))
        nbranches, nrecombs, ncoals = argweaver.get_nlineages_recomb_coal(
            tree, times)
        # weight of interval i: (number of branches) * (interval length),
        # for intervals up to and including the root's
        probs = [nbranches[i] * time_steps[i]
                 for i in range(root_age_index+1)]
        recomb_time_index = stats.sample(probs)
        recomb_time = times[recomb_time_index]

        # choose branch for recombination
        # (any non-root state at the chosen time index)
        branches = [x for x in states if x[1] == recomb_time_index and
                    x[0] != tree.root.name]
        recomb_node = tree[random.sample(branches, 1)[0][0]]

        # choose coal time
        # Walk upward one time index at a time, drawing at each step
        # whether the lineage coalesces there.
        j = recomb_time_index
        last_kj = nbranches[max(j-1, 0)]
        while j < ntimes - 1:
            kj = nbranches[j]
            # Do not count the recombining branch itself while it is still
            # below its parent.
            if ((recomb_node.name, j) in states and
                    recomb_node.parents[0].age > times[j]):
                kj -= 1
            assert kj > 0, (j, root_age_index, states)

            A = (times2[2*j+1] - times2[2*j]) * kj
            if j > recomb_time_index:
                A += (times2[2*j] - times2[2*j-1]) * last_kj
            coal_prob = 1.0 - exp(-A/float(popsizes[j]))
            if random.random() < coal_prob:
                break
            j += 1
            last_kj = kj
        coal_time_index = j
        coal_time = times[j]

        # choose coal node
        # since coal points collapse, exclude parent node, but allow sibling
        exclude = []

        def walk(node):
            # Collect `node`, and descend only through nodes whose age
            # equals the coalescence time.
            exclude.append(node.name)
            if node.age == coal_time:
                for child in node.children:
                    walk(child)
        walk(recomb_node)
        exclude2 = (recomb_node.parents[0].name,
                    times.index(recomb_node.parents[0].age))
        branches = [x for x in states if x[1] == coal_time_index and
                    x[0] not in exclude and x != exclude2]
        coal_node = tree[random.sample(branches, 1)[0][0]]

        # yield SPR
        rleaves = list(tree.leaf_names(recomb_node))
        cleaves = list(tree.leaf_names(coal_node))

        yield pos, (rleaves, recomb_time), (cleaves, coal_time)

        # apply SPR to local tree
        broken = recomb_node.parents[0]
        recoal = tree.new_node(age=coal_time,
                               children=[recomb_node, coal_node])

        # add recoal node to tree
        recomb_node.parents[0] = recoal
        broken.children.remove(recomb_node)
        if coal_node.parents:
            recoal.parents.append(coal_node.parents[0])
            util.replace(coal_node.parents[0].children, coal_node, recoal)
            coal_node.parents[0] = recoal
        else:
            # coal_node had no parent, so recoal sits above it with no
            # parent of its own
            coal_node.parents.append(recoal)

        # remove broken node
        # (after losing recomb_node it has a single child left)
        broken_child = broken.children[0]
        if broken.parents:
            broken_child.parents[0] = broken.parents[0]
            util.replace(broken.parents[0].children, broken, broken_child)
        else:
            broken_child.parents.remove(broken)

        del tree.nodes[broken.name]
        tree.set_root()
def test_trans_two():
    """
    Calculate transition probabilities for k=2

    Only calculate a single matrix

    Builds a trunk ARG for the single chromosome "n0", computes its
    transition matrix with argweaver.calc_transition_probs and compares
    every entry against log(trans(i, j)), where `trans` is the closed-form
    expression defined below.  Also checks that the pieces of the closed
    form (recomb, recoal, trans) sum to the expected totals.
    """
    k = 2
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=5, maxtime=200000)
    time_steps = [times[i] - times[i-1] for i in range(1, len(times))]
    # one extra, very long step after the last time point
    time_steps.append(200000*10000.0)
    popsizes = [n] * len(times)

    arg = arglib.sample_arg(k, 2*n, rho, start=0, end=length)
    argweaver.discretize_arg(arg, times)
    # Pre-formatted so the output is identical whether print is the
    # Python 2 statement or the Python 3 function.
    print("recomb %s" % (arglib.get_recombs(arg),))

    # The sampled ARG is only reported above; the matrix is computed on a
    # trunk ARG.
    arg = argweaver.make_trunk_arg(0, length, "n0")

    pos = 10
    tree = arg.get_marginal_tree(pos)
    nlineages = argweaver.get_nlineages_recomb_coal(tree, times)
    states = list(argweaver.iter_coal_states(tree, times))
    mat = argweaver.calc_transition_probs(
        tree, states, nlineages, times, time_steps, popsizes, rho)

    nstates = len(states)

    def coal(j):
        # probability of coalescing within time step j
        return 1.0 - exp(-time_steps[j]/(2.0 * n))

    def recoal2(k, j):
        # Step-by-step form of the non-terminal branch of recoal();
        # not called by the checks below.
        p = coal(j)
        for m in range(k, j):
            p *= 1.0 - coal(m)
        return p

    def recoal(k, j):
        # survive steps k..j-1 without coalescing, then coalesce in step j
        # (the last state absorbs all remaining probability)
        if j == nstates-1:
            return exp(- sum(time_steps[m] / (2.0 * n)
                             for m in range(k, j)))
        else:
            return ((1.0 - exp(-time_steps[j]/(2.0 * n))) *
                    exp(- sum(time_steps[m] / (2.0 * n)
                              for m in range(k, j))))

    def isrecomb(i):
        return 1.0 - exp(-max(rho * 2.0 * times[i], rho))

    def recomb(i, k):
        treelen = 2*times[i] + time_steps[i]
        if k < i:
            return 2.0 * time_steps[k] / treelen / 2.0
        else:
            return time_steps[k] / treelen / 2.0

    def trans(i, j):
        a = states[i][1]
        b = states[j][1]

        # The original added this same sum to itself; doubling it gives
        # the identical floating-point value.
        p = 2.0 * sum(recoal(k, b) * recomb(a, k)
                      for k in range(0, min(a, b)+1))
        p *= isrecomb(a)
        if i == j:
            p += 1.0 - isrecomb(a)
        return p

    for i in range(len(states)):
        for j in range(len(states)):
            expected = log(trans(i, j))
            print(isrecomb(states[i][1]))
            print("%s %s %s %s" % (states[i], states[j], mat[i][j],
                                   expected))
            fequal(mat[i][j], expected)

        # recombs add up to 1
        fequal(sum(recomb(i, k) for k in range(i+1)), 0.5)

        # recoal add up to 1
        fequal(sum(recoal(i, j) for j in range(i, nstates)), 1.0)

        # recomb * recoal add up to .5
        fequal(sum(sum(recoal(k, j) * recomb(i, k)
                       for k in range(0, min(i, j)+1))
                   for j in range(0, nstates)), 0.5)

        fequal(sum(trans(i, j) for j in range(len(states))), 1.0)