def test_post(self): k = 6 n = 1e4 rho = 1.5e-8 * 10 mu = 2.5e-8 * 10 length = 10000 arg = arglib.sample_arg(k, n, rho, start=0, end=length) muts = arglib.sample_arg_mutations(arg, mu) seqs = arglib.make_alignment(arg, muts) print "muts", len(muts) print "recombs", len(arglib.get_recomb_pos(arg)) times = arghmm.get_time_points(ntimes=10) arghmm.discretize_arg(arg, times) tree = arg.get_marginal_tree(0) treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4) # remove chrom new_name = "n%d" % (k - 1) keep = set(arg.leaf_names()) - set([new_name]) arglib.subarg_by_leaf_names(arg, keep) arg = arglib.smcify_arg(arg) model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times, rho=rho, mu=mu) print "states", len(model.states[0]) probs = arghmm.get_posterior_probs(model, length, verbose=True) for pcol in probs: p = sum(map(exp, pcol)) print p, " ".join("%.3f" % f for f in map(exp, pcol)) fequal(p, 1.0, rel=1e-2)
def test_post_real(self): k = 3 n = 1e4 rho = 1.5e-8 mu = 2.5e-8 length = 100000 arg = arglib.sample_arg(k, n, rho, start=0, end=length) muts = arglib.sample_arg_mutations(arg, mu) seqs = arglib.make_alignment(arg, muts) #arg = arglib.read_arg("test/data/real.arg") #seqs = fasta.read_fasta("test/data/real.fa") #arglib.write_arg("test/data/real.arg", arg) #fasta.write_fasta("test/data/real.fa", seqs) times = arghmm.get_time_points(maxtime=50000, ntimes=20) arghmm.discretize_arg(arg, times) new_name = "n%d" % (k - 1) thread = list( arghmm.iter_chrom_thread(arg, arg[new_name], by_block=False)) tree = arg.get_marginal_tree(0) print tree.root.age treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4) p = plot(cget(thread, 1), style="lines", ymin=10, ylog=10) #alignlib.print_align(seqs) # remove chrom keep = ["n%d" % i for i in range(k - 1)] arglib.subarg_by_leaf_names(arg, keep) arg = arglib.smcify_arg(arg) model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times, rho=rho, mu=mu) print "states", len(model.states[0]) #print "muts", len(muts) print "recomb", len(model.recomb_pos) - 2, model.recomb_pos[1:-1] probs = arghmm.get_posterior_probs(model, length, verbose=True) high = list(arghmm.iter_posterior_times(model, probs, .95)) low = list(arghmm.iter_posterior_times(model, probs, .05)) p.plot(high, style="lines") p.plot(low, style="lines") pause()
def test_smcify_arg(self):
    # Simulation settings.
    rho = 1.5e-8          # recomb/site/gen
    locus_len = 100000    # length of locus
    nlineages = 6         # number of lineages
    popsize = 2 * 10000   # effective popsize

    smc_arg = arglib.smcify_arg(
        arglib.sample_arg(nlineages, popsize, rho, 0, locus_len))

    # Every SPR of the converted ARG must recombine and coalesce on
    # different nodes.
    for _, (rnode, _), (cnode, _) in arglib.iter_arg_sprs(smc_arg):
        self.assertNotEqual(rnode, cnode)
def test_smcify_arg_remove_thread(self):
    """SMC-ify an ARG after removing its last chromosome.

    Smoke test: it passes as long as subsetting the ARG to the remaining
    leaves and calling ``arglib.smcify_arg`` do not raise.
    """
    rho = 1.5e-8  # recomb/site/gen
    l = 100000  # length of locus
    k = 6  # number of lineages
    n = 2*10000  # effective popsize
    arg = arglib.sample_arg(k, n, rho, 0, l)

    # Wrap the name in a list: set("n5") would be the set of characters
    # {"n", "5"}, which matches no leaf and so would remove nothing.
    remove_chroms = set(["n%d" % (k - 1)])
    keep = [x for x in arg.leaf_names() if x not in remove_chroms]
    # Guard that exactly one leaf was actually dropped.
    self.assertEqual(len(keep), k - 1)

    # Work on a copy; subarg_by_leaf_names is called for its in-place effect.
    arg = arg.copy()
    arglib.subarg_by_leaf_names(arg, keep)
    arg = arglib.smcify_arg(arg)
def test_post_real(self): k = 3 n = 1e4 rho = 1.5e-8 mu = 2.5e-8 length = 100000 arg = arglib.sample_arg(k, n, rho, start=0, end=length) muts = arglib.sample_arg_mutations(arg, mu) seqs = arglib.make_alignment(arg, muts) #arg = arglib.read_arg("test/data/real.arg") #seqs = fasta.read_fasta("test/data/real.fa") #arglib.write_arg("test/data/real.arg", arg) #fasta.write_fasta("test/data/real.fa", seqs) times = arghmm.get_time_points(maxtime=50000, ntimes=20) arghmm.discretize_arg(arg, times) new_name = "n%d" % (k - 1) thread = list(arghmm.iter_chrom_thread(arg, arg[new_name], by_block=False)) tree = arg.get_marginal_tree(0) print tree.root.age treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4) p = plot(cget(thread, 1), style="lines", ymin=10, ylog=10) #alignlib.print_align(seqs) # remove chrom keep = ["n%d" % i for i in range(k-1)] arglib.subarg_by_leaf_names(arg, keep) arg = arglib.smcify_arg(arg) model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times, rho=rho, mu=mu) print "states", len(model.states[0]) #print "muts", len(muts) print "recomb", len(model.recomb_pos) - 2, model.recomb_pos[1:-1] probs = arghmm.get_posterior_probs(model, length, verbose=True) high = list(arghmm.iter_posterior_times(model, probs, .95)) low = list(arghmm.iter_posterior_times(model, probs, .05)) p.plot(high, style="lines") p.plot(low, style="lines") pause()
def test_post2(self): k = 2 n = 1e4 rho = 1.5e-8 * 10 mu = 2.5e-8 * 10 length = 10000 arg = arglib.sample_arg(k, n, rho, start=0, end=length) muts = arglib.sample_arg_mutations(arg, mu) seqs = arglib.make_alignment(arg, muts) print "muts", len(muts) times = arghmm.get_time_points() arghmm.discretize_arg(arg, times) thread = list(arghmm.iter_chrom_thread(arg, arg["n1"], by_block=False)) tree = arg.get_marginal_tree(0) print tree.root.age treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4) p = plot(cget(thread, 1), style="lines", ymin=0) #alignlib.print_align(seqs) # remove chrom keep = ["n0"] arglib.subarg_by_leaf_names(arg, keep) arg = arglib.smcify_arg(arg) model = arghmm.ArgHmm(arg, seqs, new_name="n1", times=times, rho=rho, mu=mu) print "states", len(model.states[0]) probs = arghmm.get_posterior_probs(model, length, verbose=True) high = list(arghmm.iter_posterior_times(model, probs, .95)) low = list(arghmm.iter_posterior_times(model, probs, .05)) p.plot(high, style="lines") p.plot(low, style="lines") pause()
def _plot_trace(values, true_value, filename, title):
    """Write one trace plot of `values` to `filename`.

    A gray horizontal line marks `true_value`, and the y-range is widened so
    that both the trace and the reference line are visible. `title` is used
    as both the main title and the y-axis label.
    """
    rplot_start(filename, width=8, height=5)
    rp.plot(values, t="l",
            ylim=[min(min(values), true_value), max(max(values), true_value)],
            main=title, xlab="iterations", ylab=title)
    rp.lines([0, len(values)], [true_value, true_value], col="gray")
    rplot_end(True)


def show_plots(arg_file, sites_file, stats_file, output_prefix,
               rho, mu, popsize, ntimes=20, maxtime=200000):
    """
    Show plots of convergence.

    Reads the true ARG (`arg_file`) and its sites (`sites_file`) and computes
    the true ARG length, number of recombinations, likelihood, prior and
    joint probability. It then writes one PDF trace plot per statistic in
    `stats_file`, each named `output_prefix + ".trace.<stat>.pdf"`, with the
    true value drawn as a gray reference line.
    """
    # read true arg and seqs
    times = argweaver.get_time_points(ntimes=ntimes, maxtime=maxtime)
    arg = arglib.read_arg(arg_file)
    argweaver.discretize_arg(arg, times, ignore_top=False,
                             round_age="closer")
    arg = arglib.smcify_arg(arg)
    seqs = argweaver.sites2seqs(argweaver.read_sites(sites_file))

    # compute true stats
    arglen = arglib.arglen(arg)
    carg = argweaverc.arg2ctrees(arg, times)
    nrecombs = argweaverc.get_local_trees_ntrees(carg[0]) - 1
    lk = argweaverc.calc_likelihood(
        carg, seqs, mu=mu, times=times, delete_arg=False)
    # Last use of the C trees: delete_arg=True so they are released rather
    # than leaked (assumes delete_arg frees the C-side trees -- TODO confirm).
    prior = argweaverc.calc_prior_prob(
        carg, rho=rho, times=times, popsizes=popsize, delete_arg=True)
    joint = lk + prior

    data = read_table(stats_file)

    # (stats column, true value, file tag, plot title), in output order.
    traces = [
        ("joint", joint, "joint", "joint probability"),
        ("likelihood", lk, "lk", "likelihood"),
        ("prior", prior, "prior", "prior probability"),
        ("recombs", nrecombs, "nrecombs", "number of recombinations"),
        ("arglen", arglen, "arglen", "ARG branch length"),
    ]
    for column, true_value, tag, title in traces:
        _plot_trace(data.cget(column), true_value,
                    output_prefix + ".trace.%s.pdf" % tag, title)
from compbio import arglib if 1: cwr_coals_list = [] smc_coals_list = [] for i in range(20): k = 10 n = 10e3 length = 500e3 rho = 1.5e-8 # simulate an ARG from the CwR and convert it into SMC-style tic("simulate %d" % i) cwr_arg = arglib.sample_arg(k, n, rho, start=0, end=length) cwr_arg_converted = arglib.smcify_arg(cwr_arg) toc() # simulate an ARG directly from SMC process smc_arg = arglib.sample_arg_smc(k, n, rho, start=0, end=length) # gather all coalescence times cwr_coals = [node.age for node in cwr_arg_converted if node.event == 'coal'] smc_coals = [node.age for node in smc_arg if node.event == 'coal'] print len(cwr_coals), len(smc_coals) cwr_coals_list.append(cwr_coals) smc_coals_list.append(smc_coals)
def test_determ(self): k = 8 n = 1e4 rho = 1.5e-8 mu = 2.5e-8 length = 100000 arg = arglib.sample_arg(k, n, rho, start=0, end=length) muts = arglib.sample_arg_mutations(arg, mu) seqs = arglib.make_alignment(arg, muts) times = arghmm.get_time_points(maxtime=50000, ntimes=20) arghmm.discretize_arg(arg, times) new_name = "n%d" % (k - 1) thread = list( arghmm.iter_chrom_thread(arg, arg[new_name], by_block=False)) thread_clades = list( arghmm.iter_chrom_thread(arg, arg[new_name], by_block=True, use_clades=True)) # remove chrom keep = ["n%d" % i for i in range(k - 1)] arglib.subarg_by_leaf_names(arg, keep) arg = arglib.smcify_arg(arg) model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times, rho=rho, mu=mu) for i, rpos in enumerate(model.recomb_pos[1:-1]): pos = rpos + 1 model.check_local_tree(pos, force=True) #recomb = arghmm.find_tree_next_recomb(arg, pos - 1) tree = arg.get_marginal_tree(pos - .5) last_tree = arg.get_marginal_tree(pos - 1 - .5) states1 = model.states[pos - 1] states2 = model.states[pos] (recomb_branch, recomb_time), (coal_branch, coal_time) = \ arghmm.find_recomb_coal(tree, last_tree, pos=rpos) recomb_time = times.index(recomb_time) coal_time = times.index(coal_time) determ = arghmm.get_deterministic_transitions( states1, states2, times, tree, last_tree, recomb_branch, recomb_time, coal_branch, coal_time) leaves1, time1, block1 = thread_clades[i] leaves2, time2, block2 = thread_clades[i + 1] if new_name in leaves1: leaves1.remove(new_name) if new_name in leaves2: leaves2.remove(new_name) node1 = arghmm.arg_lca(arg, leaves1, None, pos - 1).name node2 = arghmm.arg_lca(arg, leaves2, None, pos).name state1 = (node1, times.index(time1)) state2 = (node2, times.index(time2)) print pos, state1, state2 try: statei1 = states1.index(state1) statei2 = states2.index(state2) except: print "states1", states1 print "states2", states2 raise statei3 = determ[statei1] print " ", statei1, statei2, statei3, states2[statei3] if statei2 != statei3 and statei3 != -1: 
tree = tree.get_tree() treelib.remove_single_children(tree) last_tree = last_tree.get_tree() treelib.remove_single_children(last_tree) print "block1", block1 print "block2", block2 print "r=", (recomb_branch, recomb_time) print "c=", (coal_branch, coal_time) print "tree" treelib.draw_tree_names(tree, minlen=8, maxlen=8) print "last_tree" treelib.draw_tree_names(last_tree, minlen=8, maxlen=8) assert False
def test_determ(self): k = 8 n = 1e4 rho = 1.5e-8 mu = 2.5e-8 length = 100000 arg = arglib.sample_arg(k, n, rho, start=0, end=length) muts = arglib.sample_arg_mutations(arg, mu) seqs = arglib.make_alignment(arg, muts) times = arghmm.get_time_points(maxtime=50000, ntimes=20) arghmm.discretize_arg(arg, times) new_name = "n%d" % (k - 1) thread = list(arghmm.iter_chrom_thread(arg, arg[new_name], by_block=False)) thread_clades = list(arghmm.iter_chrom_thread( arg, arg[new_name], by_block=True, use_clades=True)) # remove chrom keep = ["n%d" % i for i in range(k-1)] arglib.subarg_by_leaf_names(arg, keep) arg = arglib.smcify_arg(arg) model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times, rho=rho, mu=mu) for i, rpos in enumerate(model.recomb_pos[1:-1]): pos = rpos + 1 model.check_local_tree(pos, force=True) #recomb = arghmm.find_tree_next_recomb(arg, pos - 1) tree = arg.get_marginal_tree(pos-.5) last_tree = arg.get_marginal_tree(pos-1-.5) states1 = model.states[pos-1] states2 = model.states[pos] (recomb_branch, recomb_time), (coal_branch, coal_time) = \ arghmm.find_recomb_coal(tree, last_tree, pos=rpos) recomb_time = times.index(recomb_time) coal_time = times.index(coal_time) determ = arghmm.get_deterministic_transitions( states1, states2, times, tree, last_tree, recomb_branch, recomb_time, coal_branch, coal_time) leaves1, time1, block1 = thread_clades[i] leaves2, time2, block2 = thread_clades[i+1] if new_name in leaves1: leaves1.remove(new_name) if new_name in leaves2: leaves2.remove(new_name) node1 = arghmm.arg_lca(arg, leaves1, None, pos-1).name node2 = arghmm.arg_lca(arg, leaves2, None, pos).name state1 = (node1, times.index(time1)) state2 = (node2, times.index(time2)) print pos, state1, state2 try: statei1 = states1.index(state1) statei2 = states2.index(state2) except: print "states1", states1 print "states2", states2 raise statei3 = determ[statei1] print " ", statei1, statei2, statei3, states2[statei3] if statei2 != statei3 and statei3 != -1: tree = 
tree.get_tree() treelib.remove_single_children(tree) last_tree = last_tree.get_tree() treelib.remove_single_children(last_tree) print "block1", block1 print "block2", block2 print "r=", (recomb_branch, recomb_time) print "c=", (coal_branch, coal_time) print "tree" treelib.draw_tree_names(tree, minlen=8, maxlen=8) print "last_tree" treelib.draw_tree_names(last_tree, minlen=8, maxlen=8) assert False