def test_post_plot(self):
    """
    Plot a removed chromosome's true thread against posterior time curves

    Samples an ARG and an alignment, records the thread of chromosome
    "n<k-1>", removes that chromosome from the ARG, builds an ArgHmm for
    adding it back, and overlays the .95 and .05 outputs of
    iter_posterior_times on the plot of the true thread.  Makes no
    assertions; the plot is meant to be inspected by eye.
    """
    k = 6
    n = 1e4
    rho = 1.5e-8 * 50
    mu = 2.5e-8 * 50
    length = 10000
    # chromosome that is threaded, removed and then inferred again
    new_name = "n%d" % (k - 1)

    arg = arglib.sample_arg(k, n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    times = arghmm.get_time_points(ntimes=30)
    arghmm.discretize_arg(arg, times)
    # NOTE(review): this pause() runs before anything has been plotted;
    # presumably a leftover debugging stop -- confirm before removing
    pause()

    # save
    #arglib.write_arg("test/data/k4.arg", arg)
    #fasta.write_fasta("test/data/k4.fa", seqs)

    thread = list(arghmm.iter_chrom_thread(
        arg, arg[new_name], by_block=False))
    p = plot(cget(thread, 1), style="lines", ymin=times[1], ylog=10)

    # remove chrom
    arg = arghmm.remove_arg_thread(arg, new_name)

    model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                          rho=rho, mu=mu)
    # single-argument print(...) emits the same text on Python 2 and 3
    print("states %d" % len(model.states[0]))
    print("muts %d" % len(muts))
    print("recomb %d %s" % (len(model.recomb_pos) - 2,
                            model.recomb_pos[1:-1]))

    # one point per entry of recomb_pos, each paired with the constant 10000
    p.plot(model.recomb_pos, [10000] * len(model.recomb_pos),
           style="points")

    probs = arghmm.get_posterior_probs(model, length, verbose=True)
    print("done")

    high = list(arghmm.iter_posterior_times(model, probs, .95))
    low = list(arghmm.iter_posterior_times(model, probs, .05))
    p.gnuplot("set linestyle 2")
    p.plot(high, style="lines")
    p.gnuplot("set linestyle 2")
    p.plot(low, style="lines")

    #write_list("test/data/post_real.txt", cget(thread, 1))
    #write_list("test/data/post_high.txt", high)
    #write_list("test/data/post_low.txt", low)

    pause()
def test_post_plot(self):
    """
    Plot a removed chromosome's true thread against posterior time curves

    Samples an ARG and an alignment, records the thread of chromosome
    "n<k-1>", removes that chromosome from the ARG, builds an ArgHmm for
    adding it back, and overlays the .95 and .05 outputs of
    iter_posterior_times on the plot of the true thread.  Makes no
    assertions; the plot is meant to be inspected by eye.
    """
    k = 6
    n = 1e4
    rho = 1.5e-8 * 50
    mu = 2.5e-8 * 50
    length = 10000
    # chromosome that is threaded, removed and then inferred again
    new_name = "n%d" % (k - 1)

    arg = arglib.sample_arg(k, n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    times = arghmm.get_time_points(ntimes=30)
    arghmm.discretize_arg(arg, times)
    # NOTE(review): this pause() runs before anything has been plotted;
    # presumably a leftover debugging stop -- confirm before removing
    pause()

    # save
    #arglib.write_arg("test/data/k4.arg", arg)
    #fasta.write_fasta("test/data/k4.fa", seqs)

    thread = list(arghmm.iter_chrom_thread(
        arg, arg[new_name], by_block=False))
    p = plot(cget(thread, 1), style="lines", ymin=times[1], ylog=10)

    # remove chrom
    arg = arghmm.remove_arg_thread(arg, new_name)

    model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                          rho=rho, mu=mu)
    # single-argument print(...) emits the same text on Python 2 and 3
    print("states %d" % len(model.states[0]))
    print("muts %d" % len(muts))
    print("recomb %d %s" % (len(model.recomb_pos) - 2,
                            model.recomb_pos[1:-1]))

    # one point per entry of recomb_pos, each paired with the constant 10000
    p.plot(model.recomb_pos, [10000] * len(model.recomb_pos),
           style="points")

    probs = arghmm.get_posterior_probs(model, length, verbose=True)
    print("done")

    high = list(arghmm.iter_posterior_times(model, probs, .95))
    low = list(arghmm.iter_posterior_times(model, probs, .05))
    p.gnuplot("set linestyle 2")
    p.plot(high, style="lines")
    p.gnuplot("set linestyle 2")
    p.plot(low, style="lines")

    #write_list("test/data/post_real.txt", cget(thread, 1))
    #write_list("test/data/post_high.txt", high)
    #write_list("test/data/post_low.txt", low)

    pause()
def test_post_real(self):
    """
    Plot posterior time curves for a 3-chromosome, 100kb simulation

    Samples an ARG with unscaled rho and mu, records the thread of
    chromosome "n<k-1>", reduces the ARG to the remaining leaves, and
    overlays the .95 and .05 outputs of iter_posterior_times on the plot
    of the true thread.  Makes no assertions; ends on pause().
    """
    k = 3
    n = 1e4
    rho = 1.5e-8
    mu = 2.5e-8
    length = 100000

    arg = arglib.sample_arg(k, n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    #arg = arglib.read_arg("test/data/real.arg")
    #seqs = fasta.read_fasta("test/data/real.fa")

    #arglib.write_arg("test/data/real.arg", arg)
    #fasta.write_fasta("test/data/real.fa", seqs)

    times = arghmm.get_time_points(maxtime=50000, ntimes=20)
    arghmm.discretize_arg(arg, times)

    # chromosome whose thread is recorded and later inferred
    new_name = "n%d" % (k - 1)
    thread = list(arghmm.iter_chrom_thread(
        arg, arg[new_name], by_block=False))

    # show the marginal tree at position 0 for reference
    tree = arg.get_marginal_tree(0)
    # single-argument print(...) emits the same text on Python 2 and 3
    print(tree.root.age)
    treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

    p = plot(cget(thread, 1), style="lines", ymin=10, ylog=10)

    #alignlib.print_align(seqs)

    # remove chrom
    keep = ["n%d" % i for i in range(k - 1)]
    arglib.subarg_by_leaf_names(arg, keep)
    arg = arglib.smcify_arg(arg)

    model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                          rho=rho, mu=mu)
    print("states %d" % len(model.states[0]))
    #print "muts", len(muts)
    print("recomb %d %s" % (len(model.recomb_pos) - 2,
                            model.recomb_pos[1:-1]))

    probs = arghmm.get_posterior_probs(model, length, verbose=True)

    high = list(arghmm.iter_posterior_times(model, probs, .95))
    low = list(arghmm.iter_posterior_times(model, probs, .05))
    p.plot(high, style="lines")
    p.plot(low, style="lines")

    pause()
def test_post_real(self):
    """
    Plot posterior time curves for a 3-chromosome, 100kb simulation

    Samples an ARG with unscaled rho and mu, records the thread of
    chromosome "n<k-1>", reduces the ARG to the remaining leaves, and
    overlays the .95 and .05 outputs of iter_posterior_times on the plot
    of the true thread.  Makes no assertions; ends on pause().
    """
    k = 3
    n = 1e4
    rho = 1.5e-8
    mu = 2.5e-8
    length = 100000

    arg = arglib.sample_arg(k, n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    #arg = arglib.read_arg("test/data/real.arg")
    #seqs = fasta.read_fasta("test/data/real.fa")

    #arglib.write_arg("test/data/real.arg", arg)
    #fasta.write_fasta("test/data/real.fa", seqs)

    times = arghmm.get_time_points(maxtime=50000, ntimes=20)
    arghmm.discretize_arg(arg, times)

    # chromosome whose thread is recorded and later inferred
    new_name = "n%d" % (k - 1)
    thread = list(arghmm.iter_chrom_thread(
        arg, arg[new_name], by_block=False))

    # show the marginal tree at position 0 for reference
    tree = arg.get_marginal_tree(0)
    # single-argument print(...) emits the same text on Python 2 and 3
    print(tree.root.age)
    treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

    p = plot(cget(thread, 1), style="lines", ymin=10, ylog=10)

    #alignlib.print_align(seqs)

    # remove chrom
    keep = ["n%d" % i for i in range(k - 1)]
    arglib.subarg_by_leaf_names(arg, keep)
    arg = arglib.smcify_arg(arg)

    model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                          rho=rho, mu=mu)
    print("states %d" % len(model.states[0]))
    #print "muts", len(muts)
    print("recomb %d %s" % (len(model.recomb_pos) - 2,
                            model.recomb_pos[1:-1]))

    probs = arghmm.get_posterior_probs(model, length, verbose=True)

    high = list(arghmm.iter_posterior_times(model, probs, .95))
    low = list(arghmm.iter_posterior_times(model, probs, .05))
    p.plot(high, style="lines")
    p.plot(low, style="lines")

    pause()
def test_post3(self):
    """
    Plot posterior time curves when threading "n2" into a 2-leaf ARG

    Samples and prunes a 3-chromosome ARG, records the thread of "n2",
    reduces the ARG to leaves "n0" and "n1", and overlays the .95 and .05
    outputs of iter_posterior_times on the plot of the true thread.
    Makes no assertions; ends on pause().
    """
    k = 3
    n = 1e4
    rho = 1.5e-8 * 3
    mu = 2.5e-8 * 100
    length = 10000
    # chromosome whose thread is recorded and later inferred
    new_name = "n2"

    arg = arglib.sample_arg(k, n, rho, start=0, end=length)
    arg.prune()
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    times = arghmm.get_time_points(ntimes=10)
    arghmm.discretize_arg(arg, times)

    # show the marginal tree at position 0 for reference
    tree = arg.get_marginal_tree(0)
    treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

    thread = list(arghmm.iter_chrom_thread(
        arg, arg[new_name], by_block=False))
    p = plot(cget(thread, 1), style="lines", ymin=0)

    # remove chrom
    keep = ["n0", "n1"]
    arglib.subarg_by_leaf_names(arg, keep)
    arg.set_ancestral()
    arg.prune()

    model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                          rho=rho, mu=mu)
    # single-argument print(...) emits the same text on Python 2 and 3
    print("states %d" % len(model.states[0]))
    print("muts %d" % len(muts))
    print("recomb %d %s" % (len(model.recomb_pos) - 2,
                            model.recomb_pos[1:-1]))

    # one point per entry of recomb_pos, each paired with the constant 1000
    p.plot(model.recomb_pos, [1000] * len(model.recomb_pos),
           style="points")

    probs = arghmm.get_posterior_probs(model, length, verbose=True)

    high = list(arghmm.iter_posterior_times(model, probs, .95))
    low = list(arghmm.iter_posterior_times(model, probs, .05))
    p.plot(high, style="lines")
    p.plot(low, style="lines")

    pause()
def test_post2(self):
    """
    Plot posterior time curves when threading "n1" onto a single leaf

    Samples a 2-chromosome ARG, records the thread of "n1", reduces the
    ARG to leaf "n0", and overlays the .95 and .05 outputs of
    iter_posterior_times on the plot of the true thread.  Makes no
    assertions; ends on pause().
    """
    k = 2
    n = 1e4
    rho = 1.5e-8 * 10
    mu = 2.5e-8 * 10
    length = 10000
    # chromosome whose thread is recorded and later inferred
    new_name = "n1"

    arg = arglib.sample_arg(k, n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    # single-argument print(...) emits the same text on Python 2 and 3
    print("muts %d" % len(muts))

    times = arghmm.get_time_points()
    arghmm.discretize_arg(arg, times)

    thread = list(arghmm.iter_chrom_thread(
        arg, arg[new_name], by_block=False))

    # show the marginal tree at position 0 for reference
    tree = arg.get_marginal_tree(0)
    print(tree.root.age)
    treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)

    p = plot(cget(thread, 1), style="lines", ymin=0)

    #alignlib.print_align(seqs)

    # remove chrom
    keep = ["n0"]
    arglib.subarg_by_leaf_names(arg, keep)
    arg = arglib.smcify_arg(arg)

    model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                          rho=rho, mu=mu)
    print("states %d" % len(model.states[0]))

    probs = arghmm.get_posterior_probs(model, length, verbose=True)

    high = list(arghmm.iter_posterior_times(model, probs, .95))
    low = list(arghmm.iter_posterior_times(model, probs, .05))
    p.plot(high, style="lines")
    p.plot(low, style="lines")

    pause()
def test_plot_thread(self):
    """
    Test thread retrieval

    Samples a 60-chromosome ARG and plots, for every position, element 1
    of each item yielded by iter_chrom_thread for the first leaf.  Makes
    no assertions; ends on pause().
    """
    k = 60
    n = 1e4
    rho = 1.5e-8 * 20
    # floor division keeps length an int, so range(length) stays valid
    # even where "/" is true division
    length = int(1000e3) // 20

    arg = arglib.sample_arg(k, n, rho, start=0, end=length)

    # first leaf yielded by the ARG; builtin next() works on Python 2.6+
    # and Python 3 alike
    node = next(arg.leaves())
    x = range(length)
    y = cget(arghmm.iter_chrom_thread(arg, node, by_block=False), 1)

    # NOTE(review): p is not read again; the binding is kept as in the
    # original in case the plot object must stay referenced while paused
    p = plot(x, y, style='lines')

    pause()
def test_thread(self):
    """
    Test thread retrieval

    Samples a 10-chromosome ARG and, for each tree track paired with the
    by-block thread of "n9", prints the block and the thread entry and
    draws the tree.  Makes no assertions.
    """
    k = 10
    n = 1e4
    rho = 1.5e-8 * 10
    mu = 2.5e-8 * 100
    length = 1000

    arg = arglib.sample_arg(k, n, rho, start=0, end=length)
    # NOTE(review): muts and seqs are not read below; the calls are kept
    # so the test still exercises mutation sampling and alignment building
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    tracks = arglib.iter_tree_tracks(arg)
    threads = arghmm.iter_chrom_thread(arg, arg["n9"], by_block=True)

    # izip pairs the two iterators lazily, one block at a time
    for (block, tree), threadi in izip(tracks, threads):
        # single-argument print(...) emits the same text on Python 2 and 3
        print(block)
        print(threadi)
        treelib.draw_tree_names(tree.get_tree(), minlen=5, scale=4e-4)
def test_emit_argmax(self):
    """
    Calculate emission probabilities

    With rho = 0, sums prob_emission over positions 1..length-1 for
    every state and asserts that the state with the largest total equals
    the state of the first entry of the removed chromosome's true thread.
    """
    k = 10
    n = 1e4
    rho = 0.0
    mu = 2.5e-8 * 100
    length = 10000

    arg = arglib.sample_arg(k, n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    times = arghmm.get_time_points(10)
    arghmm.discretize_arg(arg, times)

    # record the true thread of the last chromosome, then remove it
    new_name = "n%d" % (k - 1)
    thread = list(arghmm.iter_chrom_thread(arg, arg[new_name]))
    arg = arghmm.remove_arg_thread(arg, new_name)

    model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times)

    nstates = model.get_num_states(1)
    # running total of prob_emission per state
    probs = [0.0] * nstates
    for i in range(1, length):
        if i % 100 == 0:
            # progress indicator
            print(i)
        for j in range(nstates):
            probs[j] += model.prob_emission(i, j)
    # blank line after the progress output
    print("")

    # is the maximum likelihood emission matching truth
    data = sorted(zip(probs, model.states[0]), reverse=True)
    pc(data[:20])

    # true state: (thread[0][0], index of thread[0][1] in times)
    state = (thread[0][0], times.index(thread[0][1]))

    print("%s %s" % (data[0][1], state))
    assert data[0][1] == state
def test_emit_argmax(self):
    """
    Calculate emission probabilities

    With rho = 0, sums prob_emission over positions 1..length-1 for
    every state and asserts that the state with the largest total equals
    the state of the first entry of the removed chromosome's true thread.
    """
    k = 10
    n = 1e4
    rho = 0.0
    mu = 2.5e-8 * 100
    length = 10000

    arg = arglib.sample_arg(k, n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    times = arghmm.get_time_points(10)
    arghmm.discretize_arg(arg, times)

    # record the true thread of the last chromosome, then remove it
    new_name = "n%d" % (k - 1)
    thread = list(arghmm.iter_chrom_thread(arg, arg[new_name]))
    arg = arghmm.remove_arg_thread(arg, new_name)

    model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times)

    nstates = model.get_num_states(1)
    # running total of prob_emission per state
    probs = [0.0] * nstates
    for i in range(1, length):
        if i % 100 == 0:
            # progress indicator
            print(i)
        for j in range(nstates):
            probs[j] += model.prob_emission(i, j)
    # blank line after the progress output
    print("")

    # is the maximum likelihood emission matching truth
    data = sorted(zip(probs, model.states[0]), reverse=True)
    pc(data[:20])

    # true state: (thread[0][0], index of thread[0][1] in times)
    state = (thread[0][0], times.index(thread[0][1]))

    print("%s %s" % (data[0][1], state))
    assert data[0][1] == state
def test_determ(self):
    """
    Check deterministic transitions against the true thread

    For each interior entry of model.recomb_pos, compares the state index
    that get_deterministic_transitions maps the true previous state to
    with the index of the true next state (both derived from the removed
    chromosome's clade thread).  A mapped index of -1 is not treated as
    a mismatch.

    Raises AssertionError on the first mismatch, after printing the
    blocks, the recombination/coalescence points and both local trees.
    """
    k = 8
    n = 1e4
    rho = 1.5e-8
    mu = 2.5e-8
    length = 100000

    arg = arglib.sample_arg(k, n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    times = arghmm.get_time_points(maxtime=50000, ntimes=20)
    arghmm.discretize_arg(arg, times)

    new_name = "n%d" % (k - 1)
    # NOTE(review): thread is never read below; kept as in the original
    # -- confirm the traversal is not needed and drop it
    thread = list(arghmm.iter_chrom_thread(
        arg, arg[new_name], by_block=False))
    thread_clades = list(arghmm.iter_chrom_thread(
        arg, arg[new_name], by_block=True, use_clades=True))

    # remove chrom
    keep = ["n%d" % i for i in range(k - 1)]
    arglib.subarg_by_leaf_names(arg, keep)
    arg = arglib.smcify_arg(arg)

    model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                          rho=rho, mu=mu)

    for i, rpos in enumerate(model.recomb_pos[1:-1]):
        pos = rpos + 1

        model.check_local_tree(pos, force=True)

        #recomb = arghmm.find_tree_next_recomb(arg, pos - 1)
        tree = arg.get_marginal_tree(pos - .5)
        last_tree = arg.get_marginal_tree(pos - 1 - .5)
        states1 = model.states[pos - 1]
        states2 = model.states[pos]
        (recomb_branch, recomb_time), (coal_branch, coal_time) = \
            arghmm.find_recomb_coal(tree, last_tree, pos=rpos)
        # convert the times to their indices in the time grid
        recomb_time = times.index(recomb_time)
        coal_time = times.index(coal_time)

        determ = arghmm.get_deterministic_transitions(
            states1, states2, times, tree, last_tree,
            recomb_branch, recomb_time, coal_branch, coal_time)

        # true states on either side, taken from the clade thread
        leaves1, time1, block1 = thread_clades[i]
        leaves2, time2, block2 = thread_clades[i + 1]
        # the threaded chromosome is no longer in the ARG, so it must
        # not take part in the lca lookup
        if new_name in leaves1:
            leaves1.remove(new_name)
        if new_name in leaves2:
            leaves2.remove(new_name)
        node1 = arghmm.arg_lca(arg, leaves1, None, pos - 1).name
        node2 = arghmm.arg_lca(arg, leaves2, None, pos).name

        state1 = (node1, times.index(time1))
        state2 = (node2, times.index(time2))
        # single-argument print(...) emits the same text on Python 2 and 3
        print("%s %s %s" % (pos, state1, state2))
        try:
            statei1 = states1.index(state1)
            statei2 = states2.index(state2)
        except ValueError:
            # a true state is missing from the model's state list;
            # dump both lists before re-raising
            print("states1 %s" % (states1,))
            print("states2 %s" % (states2,))
            raise

        statei3 = determ[statei1]
        print("  %s %s %s %s" % (statei1, statei2, statei3,
                                 states2[statei3]))

        if statei2 != statei3 and statei3 != -1:
            # mismatch: show everything needed to reproduce it
            tree = tree.get_tree()
            treelib.remove_single_children(tree)
            last_tree = last_tree.get_tree()
            treelib.remove_single_children(last_tree)

            print("block1 %s" % (block1,))
            print("block2 %s" % (block2,))
            print("r= %s" % ((recomb_branch, recomb_time),))
            print("c= %s" % ((coal_branch, coal_time),))

            print("tree")
            treelib.draw_tree_names(tree, minlen=8, maxlen=8)

            print("last_tree")
            treelib.draw_tree_names(last_tree, minlen=8, maxlen=8)

            # explicit raise so the failure survives "python -O"
            raise AssertionError(
                "deterministic transition disagrees with true thread "
                "at pos %s: thread state index %s, determ state index %s"
                % (pos, statei2, statei3))
def test_determ(self):
    """
    Check deterministic transitions against the true thread

    For each interior entry of model.recomb_pos, compares the state index
    that get_deterministic_transitions maps the true previous state to
    with the index of the true next state (both derived from the removed
    chromosome's clade thread).  A mapped index of -1 is not treated as
    a mismatch.

    Raises AssertionError on the first mismatch, after printing the
    blocks, the recombination/coalescence points and both local trees.
    """
    k = 8
    n = 1e4
    rho = 1.5e-8
    mu = 2.5e-8
    length = 100000

    arg = arglib.sample_arg(k, n, rho, start=0, end=length)
    muts = arglib.sample_arg_mutations(arg, mu)
    seqs = arglib.make_alignment(arg, muts)

    times = arghmm.get_time_points(maxtime=50000, ntimes=20)
    arghmm.discretize_arg(arg, times)

    new_name = "n%d" % (k - 1)
    # NOTE(review): thread is never read below; kept as in the original
    # -- confirm the traversal is not needed and drop it
    thread = list(arghmm.iter_chrom_thread(
        arg, arg[new_name], by_block=False))
    thread_clades = list(arghmm.iter_chrom_thread(
        arg, arg[new_name], by_block=True, use_clades=True))

    # remove chrom
    keep = ["n%d" % i for i in range(k - 1)]
    arglib.subarg_by_leaf_names(arg, keep)
    arg = arglib.smcify_arg(arg)

    model = arghmm.ArgHmm(arg, seqs, new_name=new_name, times=times,
                          rho=rho, mu=mu)

    for i, rpos in enumerate(model.recomb_pos[1:-1]):
        pos = rpos + 1

        model.check_local_tree(pos, force=True)

        #recomb = arghmm.find_tree_next_recomb(arg, pos - 1)
        tree = arg.get_marginal_tree(pos - .5)
        last_tree = arg.get_marginal_tree(pos - 1 - .5)
        states1 = model.states[pos - 1]
        states2 = model.states[pos]
        (recomb_branch, recomb_time), (coal_branch, coal_time) = \
            arghmm.find_recomb_coal(tree, last_tree, pos=rpos)
        # convert the times to their indices in the time grid
        recomb_time = times.index(recomb_time)
        coal_time = times.index(coal_time)

        determ = arghmm.get_deterministic_transitions(
            states1, states2, times, tree, last_tree,
            recomb_branch, recomb_time, coal_branch, coal_time)

        # true states on either side, taken from the clade thread
        leaves1, time1, block1 = thread_clades[i]
        leaves2, time2, block2 = thread_clades[i + 1]
        # the threaded chromosome is no longer in the ARG, so it must
        # not take part in the lca lookup
        if new_name in leaves1:
            leaves1.remove(new_name)
        if new_name in leaves2:
            leaves2.remove(new_name)
        node1 = arghmm.arg_lca(arg, leaves1, None, pos - 1).name
        node2 = arghmm.arg_lca(arg, leaves2, None, pos).name

        state1 = (node1, times.index(time1))
        state2 = (node2, times.index(time2))
        # single-argument print(...) emits the same text on Python 2 and 3
        print("%s %s %s" % (pos, state1, state2))
        try:
            statei1 = states1.index(state1)
            statei2 = states2.index(state2)
        except ValueError:
            # a true state is missing from the model's state list;
            # dump both lists before re-raising
            print("states1 %s" % (states1,))
            print("states2 %s" % (states2,))
            raise

        statei3 = determ[statei1]
        print("  %s %s %s %s" % (statei1, statei2, statei3,
                                 states2[statei3]))

        if statei2 != statei3 and statei3 != -1:
            # mismatch: show everything needed to reproduce it
            tree = tree.get_tree()
            treelib.remove_single_children(tree)
            last_tree = last_tree.get_tree()
            treelib.remove_single_children(last_tree)

            print("block1 %s" % (block1,))
            print("block2 %s" % (block2,))
            print("r= %s" % ((recomb_branch, recomb_time),))
            print("c= %s" % ((coal_branch, coal_time),))

            print("tree")
            treelib.draw_tree_names(tree, minlen=8, maxlen=8)

            print("last_tree")
            treelib.draw_tree_names(last_tree, minlen=8, maxlen=8)

            # explicit raise so the failure survives "python -O"
            raise AssertionError(
                "deterministic transition disagrees with true thread "
                "at pos %s: thread state index %s, determ state index %s"
                % (pos, statei2, statei3))