示例#1
0
def sample_dlcoal_no_ifix(stree, n, freq, duprate, lossrate, freqdup, freqloss,
                          forcetime, namefunc=lambda x: x,
                          remove_single=True, name_internal="n", minsize=0):
    """Sample a gene tree from the DLCoal model using the new simulator.

    A locus tree is drawn with ``sim_DLILS_gene_tree`` and redrawn until it
    has at least ``minsize`` leaves.  If the locus tree has more than one
    node it is expanded with ``locus_to_logged_tree`` and a coalescent tree
    is sampled inside the expanded tree; otherwise (total extinction) a
    root-only coalescent tree is returned.

    ``namefunc`` is accepted but not used by this function: the sampler is
    given a naming function built from the third element of the logged
    extras instead.

    Returns ``(coal_tree, extra)`` where ``extra`` is a dict with the keys
    "locus_tree", "locus_recon", "locus_events", "coal_tree", "coal_recon"
    and "daughters".
    """

    # keep drawing locus trees until one is large enough
    while True:
        locus_tree, locus_extras = sim_DLILS_gene_tree(
            stree, n, freq, duprate, lossrate, freqdup, freqloss, forcetime)
        if len(locus_tree.leaves()) < minsize:
            continue
        break

    if len(locus_tree.nodes) > 1:  # TODO: check 1 value
        # simulate coalescence on the expanded ("logged") locus tree
        logged_locus_tree, logged_extras = locus_to_logged_tree(
            locus_tree, popsize=n)
        daughters = logged_extras[0]
        pops = logged_extras[1]
        log_recon = logged_extras[2]

        def logged_name(log_name):
            # name = reconciled entry followed by the logged-tree name
            return log_recon[log_name] + '_' + str(log_name)

        coal_tree, coal_recon = dlcoal.sample_locus_coal_tree(
            logged_locus_tree, n=pops, daughters=daughters,
            namefunc=logged_name)

        treelib.assert_tree(coal_tree)

        # clean up coal tree
        if remove_single:
            treelib.remove_single_children(coal_tree)
            phylo.subset_recon(coal_tree, coal_recon)
    else:
        # total extinction: the coalescent tree is a bare root
        coal_tree = treelib.Tree()
        coal_tree.make_root()
        coal_recon = {coal_tree.root: locus_tree.root}
        daughters = set()

    if name_internal:
        for named_tree in (coal_tree, locus_tree):
            dlcoal.rename_nodes(named_tree, name_internal)

    # store extra information
    ### TODO: update this now that we're using logged locus tree, new sample function
    extra = {
        "locus_tree": locus_tree,
        "locus_recon": locus_extras['recon'],
        "locus_events": locus_extras['events'],
        "coal_tree": coal_tree,
        "coal_recon": coal_recon,
        "daughters": daughters,
    }

    return coal_tree, extra
示例#2
0
def _test_bounded_multicoal_tree(stree, n, T, nsamples):
    """Tabulate sampled bounded-multicoal topologies against their probability.

    Draws ``nsamples`` trees with ``coal.sample_bounded_multicoal_tree``,
    groups them by ``phylo.hash_tree`` and builds a table holding, for each
    topology, its observed fraction and ``exp`` of the value returned by
    ``coal.prob_bounded_multicoal_recon_topology``.

    Returns ``(tab, tops)``; ``tops`` maps each topology hash to
    ``[count, first tree seen, its recon]``.
    """
    tops = {}

    for _ in xrange(nsamples):
        # sample tree (direct sampler, not the rejection sampler)
        tree, recon = coal.sample_bounded_multicoal_tree(
            stree, n, T, namefunc=lambda x: x)

        key = phylo.hash_tree(tree)
        if key not in tops:
            # remember the first tree/recon seen for this topology
            tops[key] = [0, tree, recon]
        tops[key][0] += 1

    tab = Table(headers=["top", "simple_top", "percent", "prob"])
    for top, (count, tree, recon) in tops.items():
        simplified = tree.copy()
        treelib.remove_single_children(simplified)
        simple_top = phylo.hash_tree(simplified)
        fraction = count / float(nsamples)
        logp = coal.prob_bounded_multicoal_recon_topology(
            tree, recon, stree, n, T)
        tab.add(top=top,
                simple_top=simple_top,
                percent=fraction,
                prob=exp(logp))
    tab.sort(col="prob", reverse=True)

    return tab, tops
示例#3
0
def dup_loss_topology_prior(tree, stree, recon, birth, death, maxdoom=20, events=None):
    """
    Returns the log prior of a gene tree topology according to dup-loss model

    ``events`` is computed with ``phylo.label_events`` when not supplied.
    ``tree`` is modified while the prior is computed: implied speciation
    nodes are added at the start and single-child nodes are removed again
    before returning.
    """

    def gene2species(gene):
        return recon[tree.nodes[gene]].name

    if events is None:
        events = phylo.label_events(tree, recon)
    # leaves are recorded before the implied speciation nodes are added
    leaves = set(tree.leaves())
    phylo.add_implied_spec_nodes(tree, stree, recon, events)

    pstree, snodes, snodelookup = spidir.make_ptree(stree)

    # get doomtable
    doomtable = calc_doom_table(stree, birth, death, maxdoom)

    logp = 0.0
    for node in tree:
        if events[node] != "spec":
            continue

        for schild in recon[node].children:
            matching = [x for x in node.children if recon[x] == schild]
            if matching:
                top = matching[0]
                subleaves = get_sub_tree(top, schild, recon, events)
                nhist = birthdeath.num_topology_histories(top, subleaves)
                s = len(subleaves)
                thist = stats.factorial(s) * stats.factorial(s - 1) / 2 ** (s - 1)

                # flag is True when the subtree contains no leaf of the gene tree
                is_internal = len(set(subleaves) & leaves) == 0
                logp += log(num_redundant_topology(
                    top, gene2species, subleaves, is_internal))
            else:
                nhist = 1.0
                thist = 1.0
                s = 0

            # sum over the number i of additional doomed lineages
            t = 0
            for i in range(maxdoom + 1):
                t += (stats.choose(s + i, i)
                      * birthdeath.prob_birth_death1(s + i, schild.dist, birth, death)
                      * exp(doomtable[snodelookup[schild]]) ** i)

            logp += log(nhist) - log(thist) + log(t)

    # correct for renumbering
    nt = num_redundant_topology(tree.root, gene2species)
    logp -= log(nt)

    # undo the implied speciation nodes added above
    treelib.remove_single_children(tree)

    return logp
示例#4
0
def sample_dlcoal(stree, n, duprate, lossrate, namefunc=lambda x: x,
                  remove_single=True, name_internal="n",
                  minsize=0):
    """Sample a gene tree from the DLCoal model.

    A locus tree is drawn with ``birthdeath.sample_birth_death_gene_tree``
    and redrawn until it has at least ``minsize`` leaves.  One child of every
    duplication node is picked at random as the daughter, and a coalescent
    tree is sampled inside the locus tree with ``sample_multicoal_tree``.

    Returns ``(coal_tree, extra)`` where ``extra`` is a dict with the keys
    "locus_tree", "locus_recon", "locus_events", "coal_tree", "coal_recon"
    and "daughters".
    """

    # generate the locus tree, rejecting trees that are too small
    while True:
        locus_tree, locus_recon, locus_events = (
            birthdeath.sample_birth_death_gene_tree(stree, duprate, lossrate))
        if len(locus_tree.leaves()) < minsize:
            continue
        break

    if len(locus_tree.nodes) > 1:
        # choose daughter duplications
        daughters = set(
            node.children[random.randint(0, 1)]
            for node in locus_tree
            if locus_events[node] == "dup")

        # simulate coalescence
        coal_tree, coal_recon = sample_multicoal_tree(
            locus_tree, n, daughters=daughters, namefunc=namefunc)

        # clean up coal tree
        if remove_single:
            treelib.remove_single_children(coal_tree)
            phylo.subset_recon(coal_tree, coal_recon)
    else:
        # total extinction: the coalescent tree is a bare root
        coal_tree = treelib.Tree()
        coal_tree.make_root()
        coal_recon = {coal_tree.root: locus_tree.root}
        daughters = set()

    if name_internal:
        for named_tree in (coal_tree, locus_tree):
            rename_nodes(named_tree, name_internal)

    # store extra information
    extra = {
        "locus_tree": locus_tree,
        "locus_recon": locus_recon,
        "locus_events": locus_events,
        "coal_tree": coal_tree,
        "coal_recon": coal_recon,
        "daughters": daughters,
    }

    return coal_tree, extra
示例#5
0
def test_trans_switch():
    """
    Check transition probabilities for the switch matrix.

    For each stored ARG only a single matrix is checked: the one at the
    first recombination found in the ARG.
    """
    create_data = False

    # model parameters
    k = 12
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [n] * len(times)
    ntests = 100

    # generate test data (disabled unless create_data is switched on)
    if create_data:
        make_clean_dir('test/data/test_trans_switch')
        for i in range(ntests):
            # Sample ARG with at least one recombination.
            has_recomb = False
            while not has_recomb:
                arg = argweaver.sample_arg_dsmc(
                    k, 2 * n, rho, start=0, end=length, times=times)
                has_recomb = any(x.event == "recomb" for x in arg)
            arg.write('test/data/test_trans_switch/%d.arg' % i)

    for i in range(ntests):
        print('arg', i)
        arg = arglib.read_arg('test/data/test_trans_switch/%d.arg' % i)
        argweaver.discretize_arg(arg, times)

        recomb_positions = [x.pos for x in arg if x.event == "recomb"]
        # half a position to the left of the first recombination
        just_before = recomb_positions[0] - .5
        tree = arg.get_marginal_tree(just_before)
        _, r, c = next(arglib.iter_arg_sprs(arg, start=just_before))
        spr = (r, c)

        if argweaverc.assert_transition_switch_probs(
                tree, spr, times, popsizes, rho):
            continue

        # on failure, draw the offending tree before failing the test
        tree2 = tree.get_tree()
        treelib.remove_single_children(tree2)
        treelib.draw_tree_names(tree2, maxlen=5, minlen=5)
        assert False
示例#6
0
def test_trans_switch():
    """
    Calculate transition probabilities for switch matrix

    Only calculate a single matrix: for each stored ARG the check is made at
    the first recombination found in the ARG.
    """
    create_data = False
    if create_data:
        make_clean_dir('test/data/test_trans_switch')

    # model parameters
    k = 12
    n = 1e4
    rho = 1.5e-8 * 20
    length = 1000
    times = argweaver.get_time_points(ntimes=20, maxtime=200000)
    popsizes = [n] * len(times)
    ntests = 100

    # generate test data
    if create_data:
        for i in range(ntests):
            # Sample ARG with at least one recombination.
            while True:
                arg = argweaver.sample_arg_dsmc(
                    k, 2*n, rho, start=0, end=length, times=times)
                if any(x.event == "recomb" for x in arg):
                    break
            arg.write('test/data/test_trans_switch/%d.arg' % i)

    for i in range(ntests):
        # single pre-formatted argument: same output on Python 2 and 3
        print('arg %d' % i)
        arg = arglib.read_arg('test/data/test_trans_switch/%d.arg' % i)
        argweaver.discretize_arg(arg, times)
        recombs = [x.pos for x in arg if x.event == "recomb"]
        pos = recombs[0]
        tree = arg.get_marginal_tree(pos-.5)
        # builtin next() works on Python 2.6+ and 3; the .next() method is
        # Python 2 only
        rpos, r, c = next(arglib.iter_arg_sprs(arg, start=pos-.5))
        spr = (r, c)

        if not argweaverc.assert_transition_switch_probs(
                tree, spr, times, popsizes, rho):
            # draw the offending tree before failing the test
            tree2 = tree.get_tree()
            treelib.remove_single_children(tree2)
            treelib.draw_tree_names(tree2, maxlen=5, minlen=5)
            assert False
示例#7
0
def _test_prog_infsites():
    """Run arg-sim and arg-sample with --infsites and report incompatible sites.

    Simulates data with ``bin/arg-sim``, samples an ARG with
    ``bin/arg-sample``, then walks every local tree of the sampled ARG and
    prints each site whose split does not map to a branch of its local tree.
    Nothing is asserted about the number of incompatible sites; it is only
    printed.
    """

    make_clean_dir("test/data/test_prog_infsites")

    run_cmd("""bin/arg-sim \
        -k 40 -L 200000 \
        -N 1e4 -r 1.5e-8 -m 2.5e-8 --infsites \
        --ntimes 20 --maxtime 400e3 \
        -o test/data/test_prog_infsites/0""")

    make_clean_dir("test/data/test_prog_infsites/0.sample")
    run_cmd("""bin/arg-sample \
        -s test/data/test_prog_infsites/0.sites \
        -N 1e4 -r 1.5e-8 -m 2.5e-8 \
        --ntimes 5 --maxtime 100e3 -c 1 \
        --climb 0 -n 20 --infsites \
        -x 1 \
        -o test/data/test_prog_infsites/0.sample/out""")

    arg = argweaver.read_arg(
        "test/data/test_prog_infsites/0.sample/out.0.smc.gz")
    sites = argweaver.read_sites("test/data/test_prog_infsites/0.sites")
    # every print below passes one pre-formatted argument so the output is
    # the same on Python 2 and Python 3
    print("names %s" % (sites.names,))
    print("")

    noncompats = []
    for block, tree in arglib.iter_local_trees(arg):
        tree = tree.get_tree()
        treelib.remove_single_children(tree)
        phylo.hash_order_tree(tree)
        for pos, col in sites.iter_region(block[0]+1, block[1]+1):
            assert block[0]+1 <= pos <= block[1]+1, (block, pos)
            split = sites_split(sites.names, col)
            node = arglib.split_to_tree_branch(tree, split)
            if node is None:
                # the site's split matches no branch of this local tree
                noncompats.append(pos)
                print("noncompat %s %s %s" % (block, pos, col))
                print(phylo.hash_tree(tree))
                print(tree.leaf_names())
                # site column reordered to follow the tree's leaf order
                print("".join(col[sites.names.index(name)]
                              for name in tree.leaf_names()))
                print(split)
                print("")
    print("num noncompats %d" % len(noncompats))
示例#8
0
def _test_prog_infsites():
    """Run arg-sim and arg-sample with --infsites and report incompatible sites.

    Works in ``test/tmp/test_prog_infsites``: simulates data with
    ``bin/arg-sim``, samples an ARG with ``bin/arg-sample``, then walks every
    local tree of the sampled ARG and prints each site whose split does not
    map to a branch of its local tree.  Nothing is asserted about the number
    of incompatible sites; it is only printed.
    """

    make_clean_dir("test/tmp/test_prog_infsites")

    run_cmd("""bin/arg-sim \
        -k 40 -L 200000 \
        -N 1e4 -r 1.5e-8 -m 2.5e-8 --infsites \
        --ntimes 20 --maxtime 400e3 \
        -o test/tmp/test_prog_infsites/0""")

    make_clean_dir("test/tmp/test_prog_infsites/0.sample")
    run_cmd("""bin/arg-sample \
        -s test/tmp/test_prog_infsites/0.sites \
        -N 1e4 -r 1.5e-8 -m 2.5e-8 \
        --ntimes 5 --maxtime 100e3 -c 1 \
        --climb 0 -n 20 --infsites \
        -x 1 \
        -o test/tmp/test_prog_infsites/0.sample/out""")

    arg = argweaver.read_arg(
        "test/tmp/test_prog_infsites/0.sample/out.0.smc.gz")
    sites = argweaver.read_sites("test/tmp/test_prog_infsites/0.sites")
    # every print below passes one pre-formatted argument so the output is
    # the same on Python 2 and Python 3
    print("names %s" % (sites.names,))
    print("")

    noncompats = []
    for block, tree in arglib.iter_local_trees(arg):
        tree = tree.get_tree()
        treelib.remove_single_children(tree)
        phylo.hash_order_tree(tree)
        for pos, col in sites.iter_region(block[0] + 1, block[1] + 1):
            assert block[0] + 1 <= pos <= block[1] + 1, (block, pos)
            split = sites_split(sites.names, col)
            node = arglib.split_to_tree_branch(tree, split)
            if node is None:
                # the site's split matches no branch of this local tree
                noncompats.append(pos)
                print("noncompat %s %s %s" % (block, pos, col))
                print(phylo.hash_tree(tree))
                print(tree.leaf_names())
                # site column reordered to follow the tree's leaf order
                print("".join(col[sites.names.index(name)]
                              for name in tree.leaf_names()))
                print(split)
                print("")
    print("num noncompats %d" % len(noncompats))
示例#9
0
    def eval_proposal(self, proposal):
        """Compute probability of proposal.

        ``proposal`` is indexed by the keys "locus_tree", "locus_recon",
        "locus_events", "coal_recon" and "daughters".  Implied speciation
        nodes are added to the proposal's locus tree for the computation and
        removed again (with the locus recon subset to match) before the
        value from ``prob_dlcoal_recon_topology`` is returned.
        """
        locus_tree = proposal["locus_tree"]
        locus_recon = proposal["locus_recon"]
        locus_events = proposal["locus_events"]

        # expand the locus tree; add_spec=False below because it is done here
        phylo.add_implied_spec_nodes(locus_tree, self.stree,
                                     locus_recon, locus_events)

        # compute recon probability
        p = prob_dlcoal_recon_topology(
            self.coal_tree, proposal["coal_recon"],
            locus_tree, locus_recon, locus_events,
            proposal["daughters"],
            self.stree, self.n,
            self.duprate, self.lossrate,
            self.pretime, self.premean,
            maxdoom=self.maxdoom,
            nsamples=self.nsamples,
            add_spec=False)

        # restore the proposal's locus tree and its reconciliation
        treelib.remove_single_children(locus_tree)
        phylo.subset_recon(locus_tree, locus_recon)

        return p
示例#10
0
def sample_birth_death_gene_tree(stree,
                                 birth,
                                 death,
                                 genename=lambda sp, x: sp + "_" + str(x),
                                 removeloss=True):
    """Simulate a gene tree within a species tree with birth and death rates.

    The gene tree is grown branch by branch along ``stree`` by calling
    ``sample_birth_death_tree`` with each species child's ``dist``.  Genes
    that reach a species leaf are renamed with ``genename(species name,
    node name)``.

    Returns ``(tree, recon, events)``: ``recon`` maps gene nodes to species
    nodes and ``events`` maps gene nodes to "spec", "dup" or "gene".  When
    ``removeloss`` is true, lost lineages are pruned from the tree and from
    both mappings.
    """

    # initialize gene tree
    tree = treelib.Tree()
    tree.make_root()
    recon = {tree.root: stree.root}
    events = {tree.root: "spec"}
    losses = set()

    def walk(snode, node):
        if snode.is_leaf():
            # reached an extant species: this node is a gene
            tree.rename(node.name, genename(snode.name, node.name))
            events[node] = "gene"
            return

        for child in snode:
            # grow a birth-death subtree below node for this species branch
            _, doom = sample_birth_death_tree(child.dist,
                                              birth,
                                              death,
                                              tree=tree,
                                              node=node,
                                              keepdoom=True)

            # record reconciliation
            survivors = []

            def annotate(gnode):
                gnode.recurse(annotate)
                recon[gnode] = child
                if gnode in doom:
                    losses.add(gnode)
                    events[gnode] = "gene"
                elif gnode.is_leaf():
                    events[gnode] = "spec"
                    survivors.append(gnode)
                else:
                    events[gnode] = "dup"

            # the subtree just sampled is annotated via node's last child
            annotate(node.children[-1])

            # recurse
            for leaf in survivors:
                walk(child, leaf)

        # if no child for node then it is a loss
        if node.is_leaf():
            losses.add(node)

    walk(stree.root, tree.root)

    # remove lost nodes
    if removeloss:
        kept_leaves = set(tree.leaves()) - losses
        treelib.remove_exposed_internal_nodes(tree, kept_leaves)
        treelib.remove_single_children(tree, simplify_root=False)

        # drop mapping entries for nodes no longer in the tree
        stale = [gnode for gnode in recon if gnode.name not in tree.nodes]
        for gnode in stale:
            del recon[gnode]
            del events[gnode]

    if len(tree.nodes) <= 1:
        # at most the root is left: reset everything to a root-only tree
        tree.nodes = {tree.root.name: tree.root}
        recon = {tree.root: stree.root}
        events = {tree.root: "spec"}

    return tree, recon, events
示例#11
0
def sample_dlcoal(stree, n, duprate, lossrate,
                  leaf_counts=None,
                  namefunc=lambda x: x,
                  remove_single=True, name_internal="n",
                  minsize=0, reject=False):
    """Sample a gene tree from the DLCoal model.

    A locus tree is drawn with ``birthdeath.sample_birth_death_gene_tree``
    and redrawn until it has at least ``minsize`` leaves.  One child of every
    duplication node is picked at random as the daughter, and a coalescent
    tree is sampled inside the locus tree.

    ``n`` and ``leaf_counts`` may each be a dict keyed by species node name;
    such a dict is re-keyed by locus-tree node name before sampling.  Any
    other value is passed through unchanged.  ``reject=True`` selects the
    slow rejection sampler (for testing).

    Returns ``(coal_tree, extra)`` where ``extra`` is a dict with the keys
    "locus_tree", "locus_recon", "locus_events", "coal_tree", "coal_recon"
    and "daughters".
    """

    # generate the locus tree
    while True:
        # TODO: does this take a namefunc?
        locus_tree, locus_recon, locus_events = \
                    birthdeath.sample_birth_death_gene_tree(
                        stree, duprate, lossrate)
        if len(locus_tree.leaves()) >= minsize:
            break

    # if n is a dict, update it with gene names from locus tree
    # (.items() rather than the Python-2-only .iteritems())
    if isinstance(n, dict):
        n2 = {}
        for node, snode in locus_recon.items():
            n2[node.name] = n[snode.name]
    else:
        n2 = n

    # if leaf_counts is a dict, update it with gene names from locus tree
    # TODO: how to handle copy number polymorphism?
    if isinstance(leaf_counts, dict):
        leaf_counts2 = {}
        for node in locus_tree.leaves():
            snode = locus_recon[node]
            leaf_counts2[node.name] = leaf_counts[snode.name]
    else:
        leaf_counts2 = leaf_counts

    if len(locus_tree.nodes) <= 1:
        # total extinction
        coal_tree = treelib.Tree()
        coal_tree.make_root()
        coal_recon = {coal_tree.root: locus_tree.root}
        daughters = set()
    else:
        # simulate coalescence

        # choose daughter duplications
        daughters = set()
        for node in locus_tree:
            if locus_events[node] == "dup":
                daughters.add(node.children[random.randint(0, 1)])

        if reject:
            # use slow rejection sampling (for testing)
            coal_tree, coal_recon = sample_multilocus_tree_reject(
                locus_tree, n2, leaf_counts=leaf_counts2,
                daughters=daughters, namefunc=namefunc)
        else:
            coal_tree, coal_recon = sample_multilocus_tree(
                locus_tree, n2, leaf_counts=leaf_counts2,
                daughters=daughters, namefunc=namefunc)

        # clean up coal tree
        if remove_single:
            treelib.remove_single_children(coal_tree)
            phylo.subset_recon(coal_tree, coal_recon)

    if name_internal:
        dlcoal.rename_nodes(coal_tree, name_internal)
        dlcoal.rename_nodes(locus_tree, name_internal)

    # store extra information
    extra = {"locus_tree": locus_tree,
             "locus_recon": locus_recon,
             "locus_events": locus_events,
             "coal_tree": coal_tree,
             "coal_recon": coal_recon,
             "daughters": daughters}

    return coal_tree, extra
示例#12
0
def dlcoal_recon_old(tree, stree, gene2species,
                 n, duprate, lossrate,
                 pretime=None, premean=None,
                 nsearch=1000,
                 maxdoom=20, nsamples=100,
                 search=phylo.TreeSearchNni):
    """
    Perform reconciliation using the DLCoal model

    Runs ``nsearch`` rounds of local search over locus trees, starting from
    a copy of the coalescent tree, and keeps the best-scoring proposal.

    Returns (maxp, maxrecon) where 'maxp' is the probability of the
    MAP reconciliation 'maxrecon' which further defined as

    maxrecon = {'coal_recon': coal_recon,
                'locus_tree': locus_tree,
                'locus_recon': locus_recon,
                'locus_events': locus_events,
                'daughters': daughters}

    'maxrecon' is None if no proposal scored above -INF.
    """

    # the coal tree is the input tree; the locus tree starts congruent to
    # it, which is equivalent to assuming no ILS
    coal_tree = tree
    locus_tree = coal_tree.copy()

    maxp = - util.INF
    maxrecon = None

    # init search
    locus_search = search(locus_tree)

    for _ in xrange(nsearch):
        # TODO: propose other reconciliations beside LCA
        candidate = locus_tree.copy()
        phylo.recon_root(candidate, stree, gene2species, newCopy=False)
        locus_recon = phylo.reconcile(candidate, stree, gene2species)
        locus_events = phylo.label_events(candidate, locus_recon)

        # propose daughters (TODO)
        daughters = set()

        # propose coal recon (TODO: propose others beside LCA)
        coal_recon = phylo.reconcile(coal_tree, candidate, lambda x: x)

        # compute recon probability on the tree with implied speciation
        # nodes, then strip those nodes again
        phylo.add_implied_spec_nodes(candidate, stree,
                                     locus_recon, locus_events)
        p = prob_dlcoal_recon_topology(
            coal_tree, coal_recon,
            candidate, locus_recon, locus_events,
            daughters,
            stree, n, duprate, lossrate,
            pretime, premean,
            maxdoom=maxdoom, nsamples=nsamples,
            add_spec=False)
        treelib.remove_single_children(candidate)

        if p > maxp:
            # new best: record it and continue the search from it
            maxp = p
            maxrecon = {
                "coal_recon": coal_recon,
                "locus_tree": candidate,
                "locus_recon": locus_recon,
                "locus_events": locus_events,
                "daughters": daughters,
            }
            locus_tree = candidate.copy()
            locus_search.set_tree(locus_tree)
        else:
            locus_search.revert()

        # perform local rearrangement to locus tree
        locus_search.propose()

    return maxp, maxrecon
示例#13
0
def sample_birth_death_gene_tree(stree, birth, death, 
                                 genename=lambda sp, x: sp + "_" + str(x),
                                 removeloss=True):
    """Simulate a gene tree within a species tree with birth and death rates.

    Returns ``(tree, recon, events)``.  ``recon`` maps each gene node to a
    species node; ``events`` labels each gene node "spec", "dup" or "gene".
    Genes at species leaves are renamed with ``genename(species name, node
    name)``.  With ``removeloss`` true, lost lineages are pruned from the
    tree and from both mappings.
    """

    # start from a root-only gene tree reconciled to the species root
    tree = treelib.Tree()
    tree.make_root()
    recon = {tree.root: stree.root}
    events = {tree.root: "spec"}
    losses = set()

    def walk(snode, node):
        if snode.is_leaf():
            # extant species: name the gene and stop descending
            tree.rename(node.name, genename(snode.name, node.name))
            events[node] = "gene"
            return

        for child in snode:
            # sample the birth-death process along this species branch
            _, doom = sample_birth_death_tree(
                child.dist, birth, death,
                tree=tree, node=node, keepdoom=True)

            # nodes that survive to the end of the branch
            next_nodes = []

            def label(gnode):
                gnode.recurse(label)
                recon[gnode] = child
                if gnode in doom:
                    losses.add(gnode)
                    events[gnode] = "gene"
                elif gnode.is_leaf():
                    events[gnode] = "spec"
                    next_nodes.append(gnode)
                else:
                    events[gnode] = "dup"

            # label the newly sampled subtree, reached via node's last child
            label(node.children[-1])

            # continue down the species tree from each survivor
            for survivor in next_nodes:
                walk(child, survivor)

        # a node that gained no children is a loss
        if node.is_leaf():
            losses.add(node)

    walk(stree.root, tree.root)

    if removeloss:
        # prune lost lineages, keeping the root even if it has one child
        surviving = set(tree.leaves()) - losses
        treelib.remove_exposed_internal_nodes(tree, surviving)
        treelib.remove_single_children(tree, simplify_root=False)

        # forget mapping entries of nodes that were pruned
        pruned = set(gnode for gnode in recon if gnode.name not in tree.nodes)
        for gnode in pruned:
            del recon[gnode]
            del events[gnode]

    if len(tree.nodes) <= 1:
        # nothing but the root remains: reset to a consistent root-only state
        tree.nodes = {tree.root.name : tree.root}
        recon = {tree.root: stree.root}
        events = {tree.root: "spec"}

    return tree, recon, events
示例#14
0
def labeledrecon_to_recon(gene_tree, labeled_recon, stree,
                          name_internal="n"):
    """Convert from DLCpar to DLCoal reconciliation model

    NOTE: This is non-reversible because it produces NON-dated coalescent and locus trees

    Arguments:
    gene_tree     -- gene tree; a copy of it becomes the coalescent tree
    labeled_recon -- object providing 'locus_map', 'species_map' and 'order'
    stree         -- species tree
    name_internal -- passed to common.rename_nodes for the locus tree

    Returns (coal_tree, recon) where 'recon' is a phyloDLC.Recon built from
    (coal_recon, locus_tree, locus_recon, locus_events, daughters).
    """

    locus_map = labeled_recon.locus_map
    species_map = labeled_recon.species_map
    order = labeled_recon.order

    # coalescent tree equals gene tree
    coal_tree = gene_tree.copy()

    # factor gene tree
    events = phylo.label_events(gene_tree, species_map)
    subtrees = factor_tree(gene_tree, stree, species_map, events)

    # gene names
    genenames = {}
    for snode in stree:
        genenames[snode] = {}
    for leaf in gene_tree.leaves():
        genenames[species_map[leaf]][locus_map[leaf]] = leaf.name

    # 2D dict to keep track of locus tree nodes by hashing by speciation node and locus
    # key1 = snode, key2 = locus, value = list of nodes (sorted from oldest to most recent)
    locus_tree_map = {}
    for snode in stree:
        locus_tree_map[snode] = {}

    # initialize locus tree, coal/locus recon, and daughters
    locus_tree = treelib.Tree()
    coal_recon = {}
    locus_recon = {}
    locus_events = {}
    daughters = []

    # initialize root of locus tree
    root = treelib.TreeNode(locus_tree.new_name())
    locus_tree.add(root)
    locus_tree.root = root
    sroot = species_map[gene_tree.root]
    locus = locus_map[gene_tree.root]
    coal_recon[coal_tree.root] = root
    locus_recon[root] = sroot
    locus_tree_map[sroot][locus] = [root]

    # build locus tree along each species branch
    for snode in stree.preorder(sroot):
        subtrees_snode = subtrees[snode]

        # skip if no branches in this species branch
        if len(subtrees_snode) == 0:
            continue

        # build locus tree
        # 1) speciation
        if snode.parent:
            for (root, rootchild, leaves) in subtrees_snode:
                if rootchild:
                    locus = locus_map[root]     # use root locus!

                    # create new locus tree node in this species branch
                    if locus not in locus_tree_map[snode]:
                        old_node = locus_tree_map[snode.parent][locus][-1]

                        new_node = treelib.TreeNode(locus_tree.new_name())
                        locus_tree.add_child(old_node, new_node)
                        locus_recon[new_node] = snode

                        locus_events[old_node] = "spec"

                        locus_tree_map[snode][locus] = [new_node]

                    # update coal_recon
                    cnode = coal_tree.nodes[rootchild.name]
                    lnode = locus_tree_map[snode][locus][-1]
                    coal_recon[cnode] = lnode

        # 2) duplication
        if snode in order:
            # may have to reorder loci (in case of multiple duplications)
            # NOTE(review): a parent locus that never appears in
            # locus_tree_map[snode] is re-queued forever, so this loop would
            # not terminate for such input -- confirm callers rule this out
            queue = collections.deque(order[snode].keys())
            while len(queue) > 0:
                plocus = queue.popleft()
                if plocus not in locus_tree_map[snode]:
                    # punt
                    queue.append(plocus)
                    continue

                # handle this ordered list
                lst = order[snode][plocus]
                for gnode in lst:
                    locus = locus_map[gnode]
                    cnode = coal_tree.nodes[gnode.name]

                    if locus != plocus:     # duplication
                        # update locus_tree, locus_recon, and daughters
                        old_node = locus_tree_map[snode][plocus][-1]

                        # first child continues the parent locus
                        new_node1 = treelib.TreeNode(locus_tree.new_name())
                        locus_tree.add_child(old_node, new_node1)
                        locus_recon[new_node1] = snode

                        # second child starts the new locus (the daughter)
                        new_node2 = treelib.TreeNode(locus_tree.new_name())
                        locus_tree.add_child(old_node, new_node2)
                        coal_recon[cnode] = new_node2
                        locus_recon[new_node2] = snode
                        daughters.append(new_node2)

                        locus_events[old_node] = "dup"

                        locus_tree_map[snode][plocus].append(new_node1)
                        locus_tree_map[snode][locus] = [new_node2]

                    else:                   # deep coalescence
                        lnode = locus_tree_map[snode][locus][-1]
                        coal_recon[cnode] = lnode

        # reconcile remaining coal tree nodes to locus tree
        # (no duplication so only a single locus tree node with the desired locus)
        for (root, rootchild, leaves) in subtrees_snode:
            if rootchild:
                for gnode in gene_tree.preorder(rootchild, is_leaf=lambda x: x in leaves):
                    cnode = coal_tree.nodes[gnode.name]
                    if cnode not in coal_recon:
                        locus = locus_map[gnode]
                        assert len(locus_tree_map[snode][locus]) == 1
                        lnode = locus_tree_map[snode][locus][-1]
                        coal_recon[cnode] = lnode

        # tidy up if at an extant species
        if snode.is_leaf():
            # .items() (not the Python-2-only .iteritems()), matching the
            # coal_recon.items() call further down
            for locus, nodes in locus_tree_map[snode].items():
                genename = genenames[snode][locus]
                lnode = nodes[-1]
                cnode = coal_tree.nodes[genename]

                # relabel genes in locus tree
                locus_tree.rename(lnode.name, genename)

                # relabel locus events
                locus_events[lnode] = "gene"

                # reconcile genes (genes in coal tree reconcile to genes in locus tree)
                # possible mismatch due to genes having an internal ordering even though all exist to present time
                # [could also do a new round of "speciation" at bottom of extant species branches,
                # but this introduces single children nodes that would just be removed anyway]
                coal_recon[cnode] = lnode

    # rename internal nodes
    common.rename_nodes(locus_tree, name_internal)

    # simplify coal_tree (and reconciliations)
    removed = treelib.remove_single_children(coal_tree)
    for cnode in removed:
        del coal_recon[cnode]

    # simplify locus_tree (and reconciliations + daughters)
    removed = treelib.remove_single_children(locus_tree)
    for cnode, lnode in coal_recon.items():
        if lnode in removed:
            # reconciliation updates to first child that is not removed
            new_lnode = lnode
            while new_lnode in removed:
                new_lnode = new_lnode.children[0]
            coal_recon[cnode] = new_lnode
    for lnode in removed:
        del locus_recon[lnode]
        del locus_events[lnode]
    for ndx, lnode in enumerate(daughters):
        if lnode in removed:
            # daughter updates to first child that is not removed
            new_lnode = lnode
            while new_lnode in removed:
                new_lnode = new_lnode.children[0]
            daughters[ndx] = new_lnode

##    locus_events = phylo.label_events(locus_tree, locus_recon)
    # every remaining locus tree node must carry an event label
    assert all([lnode in locus_events for lnode in locus_tree])

    #========================================
    # put everything together

    return coal_tree, phyloDLC.Recon(coal_recon, locus_tree, locus_recon, locus_events, daughters)
示例#15
0
def dup_loss_topology_prior(tree,
                            stree,
                            recon,
                            birth,
                            death,
                            maxdoom=20,
                            events=None):
    """
    Returns the log prior of a gene tree topology according to dup-loss model

    tree    -- gene tree; NOTE: it is modified in place (implied speciation
               nodes are added up front and single-child nodes are removed
               before returning)
    stree   -- species tree
    recon   -- mapping from gene tree nodes to species tree nodes
    birth   -- birth rate, forwarded to the birth-death helpers
    death   -- death rate, forwarded to the birth-death helpers
    maxdoom -- forwarded to calc_doom_table; also the largest index (inclusive)
               of the per-branch summation below
    events  -- mapping from gene tree nodes to event labels; computed with
               phylo.label_events(tree, recon) when None
    """
    def gene2species(gene):
        # species name of the node that the named gene reconciles to
        return recon[tree.nodes[gene]].name

    if events is None:
        events = phylo.label_events(tree, recon)

    # leaf set is captured before add_implied_spec_nodes touches the tree
    leaves = set(tree.leaves())
    phylo.add_implied_spec_nodes(tree, stree, recon, events)

    # only the species-node lookup table is used from the ptree conversion
    _pstree, _snodes, snodelookup = spidir.make_ptree(stree)

    # get doomtable
    doomtable = calc_doom_table(stree, birth, death, maxdoom)

    log_prior = 0.0
    for gnode in tree:
        # only nodes labeled as speciations contribute per-branch terms
        if events[gnode] != "spec":
            continue

        for schild in recon[gnode].children:
            # gene children of this node that reconcile to this species child
            matches = [child for child in gnode.children
                       if recon[child] == schild]

            if not matches:
                # no gene child along this species branch
                nhist = 1.0
                thist = 1.0
                s = 0
            else:
                sub_root = matches[0]
                subleaves = get_sub_tree(sub_root, schild, recon, events)
                nhist = birthdeath.num_topology_histories(sub_root, subleaves)
                s = len(subleaves)
                thist = (stats.factorial(s) * stats.factorial(s - 1) /
                         2**(s - 1))

                # flag is True when the subtree shares no leaf with the
                # leaves of the tree as it was passed in
                is_internal = len(set(subleaves) & leaves) == 0
                log_prior += log(
                    num_redundant_topology(sub_root, gene2species,
                                           subleaves, is_internal))

            # sum over i = 0 .. maxdoom, accumulated in index order
            t = 0
            for i in range(maxdoom + 1):
                t += (stats.choose(s + i, i) *
                      birthdeath.prob_birth_death1(
                          s + i, schild.dist, birth, death) *
                      exp(doomtable[snodelookup[schild]])**i)

            log_prior += log(nhist) - log(thist) + log(t)

    # correct for renumbering
    log_prior -= log(num_redundant_topology(tree.root, gene2species))

    # strip single-child nodes from the tree before returning
    treelib.remove_single_children(tree)

    return log_prior
示例#16
0
def sample_dlcoal_hem(stree, n, duprate, lossrate,
                      freq, freqdup, freqloss, steptime,
                      namefunc=lambda x: x,
                      keep_extinct=False,
                      remove_single=True,
                      name_internal="n", minsize=0):
    """Sample a gene tree from the DLCoal model with hemiplasy

    A locus tree is drawn with sample_locus_tree_hem (redrawn until it has at
    least `minsize` leaves), then a coalescent tree is sampled inside it.

    stree, n, duprate, lossrate, freq, freqdup, freqloss, steptime
                  -- forwarded unchanged to sample_locus_tree_hem; `n` is
                     also passed as `popsize` to locus_to_logged_tree
    namefunc      -- accepted for interface compatibility; not used by this
                     function (leaf names are built from the logged locus
                     tree instead)
    keep_extinct  -- forwarded to sample_locus_tree_hem; when true the
                     "full_locus_tree" entry is copied into the extras
    remove_single -- when true, single-child nodes are removed from the
                     coalescent tree and its reconciliation is subset
    name_internal -- when truthy, passed to dlcoal.rename_nodes for both the
                     coalescent tree and the locus tree
    minsize       -- minimum number of locus tree leaves to accept a draw

    Returns (coal_tree, extra) where extra is a dict with keys "locus_tree",
    "locus_recon", "locus_events", "coal_tree", "coal_recon", "daughters"
    (plus "full_locus_tree" when keep_extinct is set).
    """

    def draw_locus_tree():
        # one draw of (locus_tree, locus_extras) with the caller's parameters
        return sample_locus_tree_hem(
            stree, n, duprate, lossrate,
            freq, freqdup, freqloss, steptime,
            keep_extinct=keep_extinct)

    # generate the locus tree, redrawing while it has too few leaves
    locus_tree, locus_extras = draw_locus_tree()
    while len(locus_tree.leaves()) < minsize:
        locus_tree, locus_extras = draw_locus_tree()

    if len(locus_tree.nodes) > 1:  # TODO: check 1 value
        # simulate coalescence on the expanded (logged) locus tree
        logged_locus_tree, logged_extras = locus_to_logged_tree(
            locus_tree, popsize=n)
        daughters, pops, log_recon = (
            logged_extras[0], logged_extras[1], logged_extras[2])

        def coal_leaf_name(lognamex):
            # "<log_recon entry>_<logged node name>"
            return log_recon[lognamex]+'_'+str(lognamex)

        coal_tree, coal_recon = dlcoal.sim.sample_multilocus_tree(
            logged_locus_tree, n=pops, daughters=daughters,
            namefunc=coal_leaf_name)

        treelib.assert_tree(coal_tree)

        # clean up coal tree
        if remove_single:
            treelib.remove_single_children(coal_tree)
            phylo.subset_recon(coal_tree, coal_recon)
    else:
        # total extinction: a root-only coalescent tree mapped to the locus
        # tree root, with no daughters
        coal_tree = treelib.Tree()
        coal_tree.make_root()
        coal_recon = {coal_tree.root: locus_tree.root}
        daughters = set()

    if name_internal:
        for renamed_tree in (coal_tree, locus_tree):
            dlcoal.rename_nodes(renamed_tree, name_internal)

    # store extra information
    extra = {
        "locus_tree": locus_tree,
        "locus_recon": locus_extras['recon'],
        "locus_events": locus_extras['events'],
        "coal_tree": coal_tree,
        "coal_recon": coal_recon,
        "daughters": daughters,
    }
    if keep_extinct:
        extra["full_locus_tree"] = locus_extras["full_locus_tree"]

    return coal_tree, extra