def sample_dlcoal_no_ifix(stree, n, freq, duprate, lossrate, freqdup, freqloss,
                          forcetime, namefunc=lambda x: x,
                          remove_single=True, name_internal="n", minsize=0):
    """Sample a gene tree from the DLCoal model using the new simulator

    A locus tree is drawn with sim_DLILS_gene_tree and redrawn until it has
    at least `minsize` leaves.  Unless that tree has at most one node, a
    coalescent tree is then sampled inside its "logged" expansion.

    Returns (coal_tree, extra) where `extra` is a dict with the keys
    "locus_tree", "locus_recon", "locus_events", "coal_tree", "coal_recon"
    and "daughters".

    NOTE: `namefunc` is accepted but never read in this function; the
    sampler is always handed a name function built from the logged tree.
    """
    # generate the locus tree (rejection loop on the number of leaves)
    while True:
        locus_tree, locus_extras = sim_DLILS_gene_tree(
            stree, n, freq, duprate, lossrate, freqdup, freqloss, forcetime)
        if len(locus_tree.leaves()) >= minsize:
            break

    if len(locus_tree.nodes) > 1:  # TODO: check 1 value
        # simulate coalescence on the new (expanded) locus tree
        logged_locus_tree, logged_extras = locus_to_logged_tree(
            locus_tree, popsize=n)
        daughters = logged_extras[0]
        pops = logged_extras[1]
        log_recon = logged_extras[2]

        def logged_name(lognamex):
            # name = reconciled entry of the logged node + '_' + node name
            return log_recon[lognamex] + '_' + str(lognamex)

        # removed locus_tree_copy from below
        coal_tree, coal_recon = dlcoal.sample_locus_coal_tree(
            logged_locus_tree, n=pops, daughters=daughters,
            namefunc=logged_name)

        treelib.assert_tree(coal_tree)
    else:
        # total extinction: a lone root reconciled to the locus tree root
        coal_tree = treelib.Tree()
        coal_tree.make_root()
        coal_recon = {coal_tree.root: locus_tree.root}
        daughters = set()

    # clean up coal tree
    if remove_single:
        treelib.remove_single_children(coal_tree)
        phylo.subset_recon(coal_tree, coal_recon)

    if name_internal:
        for renamed in (coal_tree, locus_tree):
            dlcoal.rename_nodes(renamed, name_internal)

    # store extra information
    ### TODO: update this now that we're using logged locus tree, new sample function
    extra = {
        "locus_tree": locus_tree,
        "locus_recon": locus_extras['recon'],
        "locus_events": locus_extras['events'],
        "coal_tree": coal_tree,
        "coal_recon": coal_recon,
        "daughters": daughters,
    }
    return coal_tree, extra
def _test_bounded_multicoal_tree(stree, n, T, nsamples):
    """test multicoal_tree

    Draws `nsamples` trees with coal.sample_bounded_multicoal_tree, counts
    how often each topology hash occurs, and tabulates the observed
    frequency next to exp() of coal.prob_bounded_multicoal_recon_topology.

    Returns (tab, tops): `tab` is a Table sorted by descending "prob";
    `tops` maps each topology hash to [count, tree, recon], where tree and
    recon come from the first sample seen with that hash.
    """
    tops = {}

    # range() rather than xrange(): xrange does not exist on Python 3
    for _ in range(nsamples):
        # sample tree
        # (the original kept a commented-out call to
        #  coal.sample_bounded_multicoal_tree_reject with the same
        #  arguments, for rejection sampling)
        tree, recon = coal.sample_bounded_multicoal_tree(
            stree, n, T, namefunc=lambda x: x)

        top = phylo.hash_tree(tree)
        # setdefault keeps the first (tree, recon) as the representative
        tops.setdefault(top, [0, tree, recon])[0] += 1

    tab = Table(headers=["top", "simple_top", "percent", "prob"])

    for top, (num, tree, recon) in tops.items():
        # hash a simplified copy so the stored tree is left untouched
        tree2 = tree.copy()
        treelib.remove_single_children(tree2)
        tab.add(top=top,
                simple_top=phylo.hash_tree(tree2),
                percent=num / float(nsamples),
                prob=exp(coal.prob_bounded_multicoal_recon_topology(
                    tree, recon, stree, n, T)))
    tab.sort(col="prob", reverse=True)

    return tab, tops
def dup_loss_topology_prior(tree, stree, recon, birth, death, maxdoom=20, events=None):
    """
    Returns the log prior of a gene tree topology according to dup-loss model

    `events` is computed with phylo.label_events when not supplied.  The
    tree is temporarily expanded with implied speciation nodes and has its
    single-child nodes removed again before returning.
    """

    def gene2species(gene):
        return recon[tree.nodes[gene]].name

    if events is None:
        events = phylo.label_events(tree, recon)
    # snapshot the leaves before implied speciation nodes are added
    original_leaves = set(tree.leaves())
    phylo.add_implied_spec_nodes(tree, stree, recon, events)

    _pstree, _snodes, snodelookup = spidir.make_ptree(stree)

    # get doomtable
    doomtable = calc_doom_table(stree, birth, death, maxdoom)

    logp = 0.0
    for node in tree:
        if events[node] != "spec":
            continue

        for schild in recon[node].children:
            in_branch = [x for x in node.children if recon[x] == schild]

            if in_branch:
                top = in_branch[0]
                subleaves = get_sub_tree(top, schild, recon, events)
                nhist = birthdeath.num_topology_histories(top, subleaves)
                s = len(subleaves)
                thist = stats.factorial(s) * stats.factorial(s - 1) / 2 ** (s - 1)

                # True: internal (no original leaf in the subtree),
                # False: leaves
                is_internal = len(set(subleaves) & original_leaves) == 0
                logp += log(num_redundant_topology(
                    top, gene2species, subleaves, is_internal))
            else:
                # no gene node of this spec node maps to the species child
                nhist = 1.0
                thist = 1.0
                s = 0

            t = sum(
                stats.choose(s + i, i)
                * birthdeath.prob_birth_death1(s + i, schild.dist, birth, death)
                * exp(doomtable[snodelookup[schild]]) ** i
                for i in range(maxdoom + 1)
            )

            logp += log(nhist) - log(thist) + log(t)

    # correct for renumbering
    nt = num_redundant_topology(tree.root, gene2species)
    logp -= log(nt)

    # phylo.removeImpliedSpecNodes(tree, recon, events)
    treelib.remove_single_children(tree)

    return logp
def sample_dlcoal(stree, n, duprate, lossrate, namefunc=lambda x: x,
                  remove_single=True, name_internal="n", minsize=0):
    """Sample a gene tree from the DLCoal model

    Returns (coal_tree, extra) where `extra` is a dict with the keys
    "locus_tree", "locus_recon", "locus_events", "coal_tree", "coal_recon"
    and "daughters".
    """
    # generate the locus tree; redraw until it has at least minsize leaves
    while True:
        sampled = birthdeath.sample_birth_death_gene_tree(
            stree, duprate, lossrate)
        locus_tree, locus_recon, locus_events = sampled
        if len(locus_tree.leaves()) >= minsize:
            break

    if len(locus_tree.nodes) > 1:
        # simulate coalescence

        # choose daughter duplications: one random child of each dup node
        daughters = set(
            node.children[random.randint(0, 1)]
            for node in locus_tree
            if locus_events[node] == "dup")

        coal_tree, coal_recon = sample_multicoal_tree(
            locus_tree, n, daughters=daughters, namefunc=namefunc)
    else:
        # total extinction
        coal_tree = treelib.Tree()
        coal_tree.make_root()
        coal_recon = {coal_tree.root: locus_tree.root}
        daughters = set()

    # clean up coal tree
    if remove_single:
        treelib.remove_single_children(coal_tree)
        phylo.subset_recon(coal_tree, coal_recon)

    if name_internal:
        for renamed in (coal_tree, locus_tree):
            rename_nodes(renamed, name_internal)

    # store extra information
    extra = {
        "locus_tree": locus_tree,
        "locus_recon": locus_recon,
        "locus_events": locus_events,
        "coal_tree": coal_tree,
        "coal_recon": coal_recon,
        "daughters": daughters,
    }
    return coal_tree, extra
def test_trans_switch():
    """
    Calculate transition probabilities for switch matrix

    Only calculate a single matrix
    """

    def arg_path(index):
        # location of the stored ARG for test case `index`
        return 'test/data/test_trans_switch/%d.arg' % index

    create_data = False
    if create_data:
        make_clean_dir('test/data/test_trans_switch')

    # model parameters
    k = 12
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [n] * len(times)
    ntests = 100

    # generate test data
    if create_data:
        for i in range(ntests):
            # Sample ARG with at least one recombination.
            has_recomb = False
            while not has_recomb:
                arg = argweaver.sample_arg_dsmc(
                    k, 2 * n, rho, start=0, end=length, times=times)
                has_recomb = any(x.event == "recomb" for x in arg)
            arg.write(arg_path(i))

    for i in range(ntests):
        print('arg', i)
        arg = arglib.read_arg(arg_path(i))
        argweaver.discretize_arg(arg, times)

        # work at the first recombination listed in the ARG
        recomb_positions = [x.pos for x in arg if x.event == "recomb"]
        pos = recomb_positions[0]
        tree = arg.get_marginal_tree(pos - .5)
        rpos, r, c = next(arglib.iter_arg_sprs(arg, start=pos - .5))
        spr = (r, c)

        if argweaverc.assert_transition_switch_probs(
                tree, spr, times, popsizes, rho):
            continue

        # check failed: draw the simplified tree, then fail the test
        tree2 = tree.get_tree()
        treelib.remove_single_children(tree2)
        treelib.draw_tree_names(tree2, maxlen=5, minlen=5)
        assert False
def test_trans_switch():
    """
    Calculate transition probabilities for switch matrix

    Only calculate a single matrix

    Reads 100 stored ARGs from test/data/test_trans_switch (regenerated
    first when `create_data` is switched on) and checks each with
    argweaverc.assert_transition_switch_probs at its first recombination.
    """
    create_data = False
    if create_data:
        make_clean_dir('test/data/test_trans_switch')

    # model parameters
    k = 12
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [n] * len(times)
    ntests = 100

    # generate test data
    if create_data:
        for i in range(ntests):
            # Sample ARG with at least one recombination.
            while True:
                arg = argweaver.sample_arg_dsmc(
                    k, 2 * n, rho, start=0, end=length, times=times)
                if any(x.event == "recomb" for x in arg):
                    break
            arg.write('test/data/test_trans_switch/%d.arg' % i)

    for i in range(ntests):
        # single pre-formatted argument: prints "arg <i>" on Python 2 and 3
        # (the former print statement is a SyntaxError on Python 3)
        print('arg %d' % i)
        arg = arglib.read_arg('test/data/test_trans_switch/%d.arg' % i)
        argweaver.discretize_arg(arg, times)
        recombs = [x.pos for x in arg if x.event == "recomb"]
        pos = recombs[0]
        tree = arg.get_marginal_tree(pos - .5)
        # builtin next() works on both versions; iterator.next() is
        # Python 2 only
        rpos, r, c = next(arglib.iter_arg_sprs(arg, start=pos - .5))
        spr = (r, c)

        if not argweaverc.assert_transition_switch_probs(
                tree, spr, times, popsizes, rho):
            # draw the simplified tree before failing
            tree2 = tree.get_tree()
            treelib.remove_single_children(tree2)
            treelib.draw_tree_names(tree2, maxlen=5, minlen=5)
            assert False
def _test_prog_infsites():
    """Run arg-sim and arg-sample with --infsites and report noncompatible sites.

    Simulates data into test/data/test_prog_infsites, samples one ARG from
    it, then checks every site column against the local tree covering it.
    Positions for which arglib.split_to_tree_branch returns None are
    printed and counted.

    All output uses single-argument print(...) calls so that it is
    identical on Python 2 and Python 3; the former print statements are a
    SyntaxError on Python 3.
    """
    make_clean_dir("test/data/test_prog_infsites")

    run_cmd("""bin/arg-sim \
    -k 40 -L 200000 \
    -N 1e4 -r 1.5e-8 -m 2.5e-8 --infsites \
    --ntimes 20 --maxtime 400e3 \
    -o test/data/test_prog_infsites/0""")

    make_clean_dir("test/data/test_prog_infsites/0.sample")
    run_cmd("""bin/arg-sample \
    -s test/data/test_prog_infsites/0.sites \
    -N 1e4 -r 1.5e-8 -m 2.5e-8 \
    --ntimes 5 --maxtime 100e3 -c 1 \
    --climb 0 -n 20 --infsites \
    -x 1 \
    -o test/data/test_prog_infsites/0.sample/out""")

    arg = argweaver.read_arg(
        "test/data/test_prog_infsites/0.sample/out.0.smc.gz")
    sites = argweaver.read_sites("test/data/test_prog_infsites/0.sites")
    print("names %s" % (sites.names,))
    print("")

    noncompats = []
    for block, tree in arglib.iter_local_trees(arg):
        tree = tree.get_tree()
        treelib.remove_single_children(tree)
        phylo.hash_order_tree(tree)
        for pos, col in sites.iter_region(block[0] + 1, block[1] + 1):
            assert block[0] + 1 <= pos <= block[1] + 1, (block, pos)
            split = sites_split(sites.names, col)
            node = arglib.split_to_tree_branch(tree, split)
            if node is None:
                # no branch of the local tree matches this site's split
                noncompats.append(pos)
                print("noncompat %s %s %s" % (block, pos, col))
                print(phylo.hash_tree(tree))
                print(tree.leaf_names())
                # column characters reordered to the tree's leaf order
                print("".join(col[sites.names.index(name)]
                              for name in tree.leaf_names()))
                print(split)
                print("")
    print("num noncompats %s" % len(noncompats))
def _test_prog_infsites():
    """Run arg-sim and arg-sample with --infsites and report noncompatible sites.

    Simulates data into test/tmp/test_prog_infsites, samples one ARG from
    it, then checks every site column against the local tree covering it.
    Positions for which arglib.split_to_tree_branch returns None are
    printed and counted.

    All output uses single-argument print(...) calls so that it is
    identical on Python 2 and Python 3; the former print statements are a
    SyntaxError on Python 3.
    """
    make_clean_dir("test/tmp/test_prog_infsites")

    run_cmd("""bin/arg-sim \
    -k 40 -L 200000 \
    -N 1e4 -r 1.5e-8 -m 2.5e-8 --infsites \
    --ntimes 20 --maxtime 400e3 \
    -o test/tmp/test_prog_infsites/0""")

    make_clean_dir("test/tmp/test_prog_infsites/0.sample")
    run_cmd("""bin/arg-sample \
    -s test/tmp/test_prog_infsites/0.sites \
    -N 1e4 -r 1.5e-8 -m 2.5e-8 \
    --ntimes 5 --maxtime 100e3 -c 1 \
    --climb 0 -n 20 --infsites \
    -x 1 \
    -o test/tmp/test_prog_infsites/0.sample/out""")

    arg = argweaver.read_arg(
        "test/tmp/test_prog_infsites/0.sample/out.0.smc.gz")
    sites = argweaver.read_sites("test/tmp/test_prog_infsites/0.sites")
    print("names %s" % (sites.names,))
    print("")

    noncompats = []
    for block, tree in arglib.iter_local_trees(arg):
        tree = tree.get_tree()
        treelib.remove_single_children(tree)
        phylo.hash_order_tree(tree)
        for pos, col in sites.iter_region(block[0] + 1, block[1] + 1):
            assert block[0] + 1 <= pos <= block[1] + 1, (block, pos)
            split = sites_split(sites.names, col)
            node = arglib.split_to_tree_branch(tree, split)
            if node is None:
                # no branch of the local tree matches this site's split
                noncompats.append(pos)
                print("noncompat %s %s %s" % (block, pos, col))
                print(phylo.hash_tree(tree))
                print(tree.leaf_names())
                # column characters reordered to the tree's leaf order
                print("".join(col[sites.names.index(name)]
                              for name in tree.leaf_names()))
                print(split)
                print("")
    print("num noncompats %s" % len(noncompats))
def eval_proposal(self, proposal):
    """Compute probability of proposal

    The proposal's locus tree is expanded with implied speciation nodes for
    the probability computation and simplified again (single-child nodes
    removed, reconciliation subset) before the value is returned.
    """
    locus_tree = proposal["locus_tree"]
    locus_recon = proposal["locus_recon"]
    locus_events = proposal["locus_events"]

    # compute recon probability
    phylo.add_implied_spec_nodes(
        locus_tree, self.stree, locus_recon, locus_events)
    prob = prob_dlcoal_recon_topology(
        self.coal_tree, proposal["coal_recon"],
        locus_tree, locus_recon, locus_events,
        proposal["daughters"],
        self.stree, self.n, self.duprate, self.lossrate,
        self.pretime, self.premean,
        maxdoom=self.maxdoom, nsamples=self.nsamples,
        add_spec=False)  # spec nodes were already added above

    # undo the expansion of the proposal's locus tree
    treelib.remove_single_children(locus_tree)
    phylo.subset_recon(locus_tree, locus_recon)
    return prob
def sample_birth_death_gene_tree(stree, birth, death,
                                 genename=lambda sp, x: sp + "_" + str(x),
                                 removeloss=True):
    """Simulate a gene tree within a species tree with birth and death rates

    Returns (tree, recon, events): the gene tree, a dict from gene node to
    species node, and a dict from gene node to one of "spec", "dup" or
    "gene".
    """
    # initialize gene tree
    tree = treelib.Tree()
    tree.make_root()
    recon = {tree.root: stree.root}
    events = {tree.root: "spec"}
    losses = set()

    def descend(snode, gnode):
        # extant species: name the gene after its species and stop
        if snode.is_leaf():
            tree.rename(gnode.name, genename(snode.name, gnode.name))
            events[gnode] = "gene"
            return

        for schild in snode:
            # determine if loss will occur
            _tree2, doom = sample_birth_death_tree(
                schild.dist, birth, death,
                tree=tree, node=gnode, keepdoom=True)

            # record reconciliation
            survivors = []

            def annotate(sub):
                sub.recurse(annotate)
                recon[sub] = schild
                if sub in doom:
                    losses.add(sub)
                    events[sub] = "gene"
                elif sub.is_leaf():
                    events[sub] = "spec"
                    survivors.append(sub)
                else:
                    events[sub] = "dup"

            # the subtree just sampled hangs off gnode's newest child
            annotate(gnode.children[-1])

            # recurse
            for leaf in survivors:
                descend(schild, leaf)

        # if no child for node then it is a loss
        if gnode.is_leaf():
            losses.add(gnode)

    descend(stree.root, tree.root)

    # remove lost nodes
    if removeloss:
        kept_leaves = set(tree.leaves()) - losses
        treelib.remove_exposed_internal_nodes(tree, kept_leaves)
        treelib.remove_single_children(tree, simplify_root=False)

        # drop bookkeeping for nodes that are no longer in the tree
        stale = [node for node in recon if node.name not in tree.nodes]
        for node in stale:
            del recon[node]
            del events[node]

    # nothing but the root is left: reset to a one-node tree
    if len(tree.nodes) <= 1:
        tree.nodes = {tree.root.name: tree.root}
        recon = {tree.root: stree.root}
        events = {tree.root: "spec"}

    return tree, recon, events
def sample_dlcoal(stree, n, duprate, lossrate,
                  leaf_counts=None,
                  namefunc=lambda x: x,
                  remove_single=True, name_internal="n",
                  minsize=0, reject=False):
    """Sample a gene tree from the DLCoal model

    `n` and `leaf_counts` may each be a dict keyed by species node name; in
    that case they are re-keyed by locus tree node name before being passed
    to the coalescent sampler.  Any other value is passed through as is.
    With `reject` set, sample_multilocus_tree_reject is used instead of
    sample_multilocus_tree.

    Returns (coal_tree, extra) where `extra` is a dict with the keys
    "locus_tree", "locus_recon", "locus_events", "coal_tree", "coal_recon"
    and "daughters".
    """
    # generate the locus tree; redraw until it has at least minsize leaves
    while True:
        # TODO: does this take a namefunc?
        locus_tree, locus_recon, locus_events = \
            birthdeath.sample_birth_death_gene_tree(
                stree, duprate, lossrate)
        if len(locus_tree.leaves()) >= minsize:
            break

    # if n is a dict, update it with gene names from locus tree
    if isinstance(n, dict):
        n2 = {}
        # .items() works on Python 2 and 3; .iteritems() is Python 2 only
        for node, snode in locus_recon.items():
            n2[node.name] = n[snode.name]
    else:
        n2 = n

    # if leaf_counts is a dict, update it with gene names from locus tree
    # TODO: how to handle copy number polymorphism?
    if isinstance(leaf_counts, dict):
        leaf_counts2 = {}
        for node in locus_tree.leaves():
            snode = locus_recon[node]
            leaf_counts2[node.name] = leaf_counts[snode.name]
    else:
        leaf_counts2 = leaf_counts

    if len(locus_tree.nodes) <= 1:
        # total extinction
        coal_tree = treelib.Tree()
        coal_tree.make_root()
        coal_recon = {coal_tree.root: locus_tree.root}
        daughters = set()
    else:
        # simulate coalescence

        # choose daughter duplications: one random child of each dup node
        daughters = set()
        for node in locus_tree:
            if locus_events[node] == "dup":
                daughters.add(node.children[random.randint(0, 1)])

        # both samplers take the same arguments; the rejection sampler is
        # the slow one (for testing)
        if reject:
            sampler = sample_multilocus_tree_reject
        else:
            sampler = sample_multilocus_tree
        coal_tree, coal_recon = sampler(
            locus_tree, n2, leaf_counts=leaf_counts2,
            daughters=daughters, namefunc=namefunc)

    # clean up coal tree
    if remove_single:
        treelib.remove_single_children(coal_tree)
        phylo.subset_recon(coal_tree, coal_recon)

    if name_internal:
        dlcoal.rename_nodes(coal_tree, name_internal)
        dlcoal.rename_nodes(locus_tree, name_internal)

    # store extra information
    extra = {"locus_tree": locus_tree,
             "locus_recon": locus_recon,
             "locus_events": locus_events,
             "coal_tree": coal_tree,
             "coal_recon": coal_recon,
             "daughters": daughters}

    return coal_tree, extra
def dlcoal_recon_old(tree, stree, gene2species,
                     n, duprate, lossrate,
                     pretime=None, premean=None,
                     nsearch=1000,
                     maxdoom=20, nsamples=100,
                     search=phylo.TreeSearchNni):
    """
    Perform reconciliation using the DLCoal model

    Returns (maxp, maxrecon) where 'maxp' is the probability of the
    MAP reconciliation 'maxrecon' which further defined as

    maxrecon = {'coal_recon': coal_recon,
                'locus_tree': locus_tree,
                'locus_recon': locus_recon,
                'locus_events': locus_events,
                'daughters': daughters}

    The search runs `nsearch` iterations.  Each one scores the current
    locus tree proposal, keeps it if it beats the best score so far
    (otherwise the search object is reverted), and then asks the search
    object for the next proposal.  'maxrecon' stays None if no proposal
    ever scores above -util.INF.
    """
    # init coal tree
    coal_tree = tree

    # init locus tree as congruent to coal tree
    # equivalent to assuming no ILS
    locus_tree = coal_tree.copy()

    maxp = - util.INF
    maxrecon = None

    # init search
    locus_search = search(locus_tree)

    # range() rather than xrange(): xrange does not exist on Python 3
    for _ in range(nsearch):
        # TODO: propose other reconciliations beside LCA
        locus_tree2 = locus_tree.copy()
        phylo.recon_root(locus_tree2, stree, gene2species, newCopy=False)
        locus_recon = phylo.reconcile(locus_tree2, stree, gene2species)
        locus_events = phylo.label_events(locus_tree2, locus_recon)

        # propose daughters (TODO)
        daughters = set()

        # propose coal recon (TODO: propose others beside LCA)
        coal_recon = phylo.reconcile(coal_tree, locus_tree2, lambda x: x)

        # compute recon probability
        # (spec nodes are added here, hence add_spec=False below, and are
        #  removed again right after scoring)
        phylo.add_implied_spec_nodes(locus_tree2, stree,
                                     locus_recon, locus_events)
        p = prob_dlcoal_recon_topology(coal_tree, coal_recon,
                                       locus_tree2, locus_recon, locus_events,
                                       daughters,
                                       stree, n, duprate, lossrate,
                                       pretime, premean,
                                       maxdoom=maxdoom, nsamples=nsamples,
                                       add_spec=False)
        treelib.remove_single_children(locus_tree2)

        if p > maxp:
            # new best: record it and continue the search from a copy
            maxp = p
            maxrecon = {"coal_recon": coal_recon,
                        "locus_tree": locus_tree2,
                        "locus_recon": locus_recon,
                        "locus_events": locus_events,
                        "daughters": daughters}
            locus_tree = locus_tree2.copy()
            locus_search.set_tree(locus_tree)
        else:
            locus_search.revert()

        # perform local rearrangement to locus tree
        locus_search.propose()

    return maxp, maxrecon
def sample_birth_death_gene_tree(stree, birth, death,
                                 genename=lambda sp, x: sp + "_" + str(x),
                                 removeloss=True):
    """Simulate a gene tree within a species tree with birth and death rates

    Returns (tree, recon, events): the gene tree, a dict from gene node to
    species node, and a dict from gene node to one of "spec", "dup" or
    "gene".
    """
    # initialize gene tree
    tree = treelib.Tree()
    tree.make_root()
    recon = {tree.root: stree.root}
    events = {tree.root: "spec"}
    losses = set()

    def grow(species_node, gene_node):
        if not species_node.is_leaf():
            for species_child in species_node:
                # determine if loss will occur
                _unused_tree, doomed = sample_birth_death_tree(
                    species_child.dist, birth, death,
                    tree=tree, node=gene_node, keepdoom=True)

                # record reconciliation
                frontier = []

                def label(current):
                    current.recurse(label)
                    recon[current] = species_child
                    if current in doomed:
                        losses.add(current)
                        events[current] = "gene"
                    elif current.is_leaf():
                        events[current] = "spec"
                        frontier.append(current)
                    else:
                        events[current] = "dup"

                # label the subtree under gene_node's most recent child
                label(gene_node.children[-1])

                # recurse
                for survivor in frontier:
                    grow(species_child, survivor)

            # if no child for node then it is a loss
            if gene_node.is_leaf():
                losses.add(gene_node)
        else:
            # extant species: name the gene after its species
            new_name = genename(species_node.name, gene_node.name)
            tree.rename(gene_node.name, new_name)
            events[gene_node] = "gene"

    grow(stree.root, tree.root)

    # remove lost nodes
    if removeloss:
        treelib.remove_exposed_internal_nodes(
            tree, set(tree.leaves()) - losses)
        treelib.remove_single_children(tree, simplify_root=False)

        # forget reconciliation/events of nodes no longer in the tree
        gone = set(node for node in recon if node.name not in tree.nodes)
        for node in gone:
            del recon[node]
            del events[node]

    # nothing but the root is left: reset to a one-node tree
    if len(tree.nodes) <= 1:
        tree.nodes = {tree.root.name: tree.root}
        recon = {tree.root: stree.root}
        events = {tree.root: "spec"}

    return tree, recon, events
def labeledrecon_to_recon(gene_tree, labeled_recon, stree, name_internal="n"):
    """Convert from DLCpar to DLCoal reconciliation model

    NOTE: This is non-reversible because it produces NON-dated coalescent and locus trees

    Arguments:
      gene_tree     -- gene tree; a copy of it becomes the coalescent tree
      labeled_recon -- object providing the attributes `locus_map`,
                       `species_map` and `order`
      stree         -- species tree
      name_internal -- passed to common.rename_nodes for the locus tree

    Returns (coal_tree, recon) where `recon` is
    phyloDLC.Recon(coal_recon, locus_tree, locus_recon, locus_events, daughters).
    """

    locus_map = labeled_recon.locus_map
    species_map = labeled_recon.species_map
    order = labeled_recon.order

    # coalescent tree equals gene tree
    coal_tree = gene_tree.copy()

    # factor gene tree
    events = phylo.label_events(gene_tree, species_map)
    subtrees = factor_tree(gene_tree, stree, species_map, events)

    # gene names
    # (genenames[snode][locus] = name of the gene tree leaf with that species and locus)
    genenames = {}
    for snode in stree:
        genenames[snode] = {}
    for leaf in gene_tree.leaves():
        genenames[species_map[leaf]][locus_map[leaf]] = leaf.name

    # 2D dict to keep track of locus tree nodes by hashing by speciation node and locus
    # key1 = snode, key2 = locus, value = list of nodes (sorted from oldest to most recent)
    locus_tree_map = {}
    for snode in stree:
        locus_tree_map[snode] = {}

    # initialize locus tree, coal/locus recon, and daughters
    locus_tree = treelib.Tree()
    coal_recon = {}
    locus_recon = {}
    locus_events = {}
    daughters = []

    # initialize root of locus tree
    root = treelib.TreeNode(locus_tree.new_name())
    locus_tree.add(root)
    locus_tree.root = root
    sroot = species_map[gene_tree.root]
    locus = locus_map[gene_tree.root]
    coal_recon[coal_tree.root] = root
    locus_recon[root] = sroot
    locus_tree_map[sroot][locus] = [root]

    # build locus tree along each species branch
    for snode in stree.preorder(sroot):
        subtrees_snode = subtrees[snode]

        # skip if no branches in this species branch
        if len(subtrees_snode) == 0:
            continue

        # build locus tree
        # 1) speciation
        if snode.parent:
            # NOTE: `root` is rebound here; from now on it is a subtree root
            # of the gene tree, not the locus tree root created above
            for (root, rootchild, leaves) in subtrees_snode:
                if rootchild:
                    locus = locus_map[root]     # use root locus!

                    # create new locus tree node in this species branch
                    if locus not in locus_tree_map[snode]:
                        # hang it below the most recent node of this locus
                        # in the parent species
                        old_node = locus_tree_map[snode.parent][locus][-1]

                        new_node = treelib.TreeNode(locus_tree.new_name())
                        locus_tree.add_child(old_node, new_node)
                        locus_recon[new_node] = snode
                        locus_events[old_node] = "spec"

                        locus_tree_map[snode][locus] = [new_node]

                    # update coal_recon
                    cnode = coal_tree.nodes[rootchild.name]
                    lnode = locus_tree_map[snode][locus][-1]
                    coal_recon[cnode] = lnode

        # 2) duplication
        if snode in order:
            # may have to reorder loci (in case of multiple duplications)
            queue = collections.deque(order[snode].keys())
            while len(queue) > 0:
                plocus = queue.popleft()
                if plocus not in locus_tree_map[snode]:
                    # punt
                    # (parent locus has no locus tree node in this species
                    #  yet; retry after the other queued loci)
                    queue.append(plocus)
                    continue

                # handle this ordered list
                lst = order[snode][plocus]
                for gnode in lst:
                    locus = locus_map[gnode]
                    cnode = coal_tree.nodes[gnode.name]

                    if locus != plocus:     # duplication
                        # update locus_tree, locus_recon, and daughters
                        old_node = locus_tree_map[snode][plocus][-1]

                        # first child continues the parent locus
                        new_node1 = treelib.TreeNode(locus_tree.new_name())
                        locus_tree.add_child(old_node, new_node1)
                        locus_recon[new_node1] = snode

                        # second child starts the new locus and is the daughter
                        new_node2 = treelib.TreeNode(locus_tree.new_name())
                        locus_tree.add_child(old_node, new_node2)
                        coal_recon[cnode] = new_node2
                        locus_recon[new_node2] = snode
                        daughters.append(new_node2)

                        locus_events[old_node] = "dup"

                        locus_tree_map[snode][plocus].append(new_node1)
                        locus_tree_map[snode][locus] = [new_node2]

                    else:                   # deep coalescence
                        lnode = locus_tree_map[snode][locus][-1]
                        coal_recon[cnode] = lnode

        # reconcile remaining coal tree nodes to locus tree
        # (no duplication so only a single locus tree node with the desired locus)
        for (root, rootchild, leaves) in subtrees_snode:
            if rootchild:
                for gnode in gene_tree.preorder(rootchild, is_leaf=lambda x: x in leaves):
                    cnode = coal_tree.nodes[gnode.name]
                    if cnode not in coal_recon:
                        locus = locus_map[gnode]
                        assert len(locus_tree_map[snode][locus]) == 1
                        lnode = locus_tree_map[snode][locus][-1]
                        coal_recon[cnode] = lnode

        # tidy up if at an extant species
        if snode.is_leaf():
            # NOTE(review): dict.iteritems() exists only on Python 2
            for locus, nodes in locus_tree_map[snode].iteritems():
                genename = genenames[snode][locus]
                lnode = nodes[-1]
                cnode = coal_tree.nodes[genename]

                # relabel genes in locus tree
                locus_tree.rename(lnode.name, genename)

                # relabel locus events
                locus_events[lnode] = "gene"

                # reconcile genes (genes in coal tree reconcile to genes in locus tree)
                # possible mismatch due to genes having an internal ordering even though all exist to present time
                # [could also do a new round of "speciation" at bottom of extant species branches,
                # but this introduces single children nodes that would just be removed anyway]
                coal_recon[cnode] = lnode

    # rename internal nodes
    common.rename_nodes(locus_tree, name_internal)

    # simplify coal_tree (and reconciliations)
    # (`removed` is whatever remove_single_children returns; it is used
    #  below as the collection of removed nodes)
    removed = treelib.remove_single_children(coal_tree)
    for cnode in removed:
        del coal_recon[cnode]

    # simplify locus_tree (and reconciliations + daughters)
    removed = treelib.remove_single_children(locus_tree)
    # only values of existing keys are reassigned while iterating
    for cnode, lnode in coal_recon.items():
        if lnode in removed:
            # reconciliation updates to first child that is not removed
            new_lnode = lnode
            while new_lnode in removed:
                new_lnode = new_lnode.children[0]
            coal_recon[cnode] = new_lnode
    for lnode in removed:
        del locus_recon[lnode]
        del locus_events[lnode]
    for ndx, lnode in enumerate(daughters):
        if lnode in removed:
            # daughter updates to first child that is not removed
            new_lnode = lnode
            while new_lnode in removed:
                new_lnode = new_lnode.children[0]
            daughters[ndx] = new_lnode

    ## locus_events = phylo.label_events(locus_tree, locus_recon)
    # every remaining locus tree node must have received an event label
    assert all([lnode in locus_events for lnode in locus_tree])

    #========================================
    # put everything together

    return coal_tree, phyloDLC.Recon(coal_recon, locus_tree, locus_recon, locus_events, daughters)
def dup_loss_topology_prior(tree, stree, recon, birth, death, maxdoom=20, events=None):
    """
    Returns the log prior of a gene tree topology according to dup-loss model

    `events` is computed with phylo.label_events when not supplied.  The
    tree is temporarily expanded with implied speciation nodes and has its
    single-child nodes removed again before returning.
    """

    def gene2species(gene):
        return recon[tree.nodes[gene]].name

    if events is None:
        events = phylo.label_events(tree, recon)
    # leaves as they are before implied speciation nodes are added
    leaf_set = set(tree.leaves())
    phylo.add_implied_spec_nodes(tree, stree, recon, events)

    pstree, snodes, snodelookup = spidir.make_ptree(stree)

    # get doomtable
    doomtable = calc_doom_table(stree, birth, death, maxdoom)

    total = 0.0
    for gnode in tree:
        if events[gnode] == "spec":
            for schild in recon[gnode].children:
                candidates = [c for c in gnode.children if recon[c] == schild]

                if not candidates:
                    # no gene lineage enters this species branch
                    nhist = 1.0
                    thist = 1.0
                    s = 0
                else:
                    first = candidates[0]
                    subleaves = get_sub_tree(first, schild, recon, events)
                    nhist = birthdeath.num_topology_histories(first, subleaves)
                    s = len(subleaves)
                    thist = stats.factorial(s) * stats.factorial(s - 1) / 2**(s - 1)

                    if set(subleaves) & leaf_set:
                        # leaves
                        redundant = num_redundant_topology(
                            first, gene2species, subleaves, False)
                    else:
                        # internal
                        redundant = num_redundant_topology(
                            first, gene2species, subleaves, True)
                    total += log(redundant)

                t = sum(
                    stats.choose(s + i, i)
                    * birthdeath.prob_birth_death1(
                        s + i, schild.dist, birth, death)
                    * exp(doomtable[snodelookup[schild]])**i
                    for i in range(maxdoom + 1))

                total += log(nhist) - log(thist) + log(t)

    # correct for renumbering
    total -= log(num_redundant_topology(tree.root, gene2species))

    #phylo.removeImpliedSpecNodes(tree, recon, events)
    treelib.remove_single_children(tree)

    return total
def sample_dlcoal_hem(stree, n, duprate, lossrate,
                      freq, freqdup, freqloss, steptime,
                      namefunc=lambda x: x,
                      keep_extinct=False,
                      remove_single=True, name_internal="n", minsize=0):
    """Sample a gene tree from the DLCoal model with hemiplasy

    A locus tree is drawn with sample_locus_tree_hem and redrawn until it
    has at least `minsize` leaves.  Unless that tree has at most one node,
    a coalescent tree is then sampled inside its "logged" expansion.

    Returns (coal_tree, extra) where `extra` is a dict with the keys
    "locus_tree", "locus_recon", "locus_events", "coal_tree", "coal_recon"
    and "daughters", plus "full_locus_tree" when `keep_extinct` is set.

    NOTE: `namefunc` is accepted but never read in this function; the
    sampler is always handed a name function built from the logged tree.
    """
    # generate the locus tree (rejection loop on the number of leaves)
    while True:
        locus_tree, locus_extras = sample_locus_tree_hem(
            stree, n, duprate, lossrate,
            freq, freqdup, freqloss, steptime,
            keep_extinct=keep_extinct)
        if len(locus_tree.leaves()) >= minsize:
            break

    if len(locus_tree.nodes) > 1:  # TODO: check 1 value
        # simulate coalescence on the new (expanded) locus tree
        logged_locus_tree, logged_extras = locus_to_logged_tree(
            locus_tree, popsize=n)
        daughters = logged_extras[0]
        pops = logged_extras[1]
        log_recon = logged_extras[2]

        def logged_name(lognamex):
            # name = reconciled entry of the logged node + '_' + node name
            return log_recon[lognamex]+'_'+str(lognamex)

        # removed locus_tree_copy from below
        coal_tree, coal_recon = dlcoal.sim.sample_multilocus_tree(
            logged_locus_tree, n=pops, daughters=daughters,
            namefunc=logged_name)

        treelib.assert_tree(coal_tree)
    else:
        # total extinction: a lone root reconciled to the locus tree root
        coal_tree = treelib.Tree()
        coal_tree.make_root()
        coal_recon = {coal_tree.root: locus_tree.root}
        daughters = set()

    # clean up coal tree
    if remove_single:
        treelib.remove_single_children(coal_tree)
        phylo.subset_recon(coal_tree, coal_recon)

    if name_internal:
        for renamed in (coal_tree, locus_tree):
            dlcoal.rename_nodes(renamed, name_internal)

    # store extra information
    extra = {
        "locus_tree": locus_tree,
        "locus_recon": locus_extras['recon'],
        "locus_events": locus_extras['events'],
        "coal_tree": coal_tree,
        "coal_recon": coal_recon,
        "daughters": daughters,
    }
    if keep_extinct:
        extra["full_locus_tree"] = locus_extras["full_locus_tree"]

    return coal_tree, extra