Esempio n. 1
0
def dup_loss_topology_prior(tree, stree, recon, birth, death, maxdoom=20, events=None):
    """
    Returns the log prior of a gene tree topology according to dup-loss model

    tree    -- gene tree.  Implied speciation nodes are added to it while the
               prior is computed; treelib.remove_single_children(tree) is
               always called before leaving, even when an error occurs.
    stree   -- species tree
    recon   -- reconciliation, indexed by gene tree node
    birth   -- birth parameter, forwarded to calc_doom_table and
               birthdeath.prob_birth_death1
    death   -- death parameter, forwarded to the same two functions
    maxdoom -- the inner sum runs over i = 0..maxdoom; also forwarded to
               calc_doom_table
    events  -- event label per gene tree node; computed with
               phylo.label_events(tree, recon) when None

    NOTE(review): recon and events are handed to
    phylo.add_implied_spec_nodes and are not cleaned up afterwards, so they
    presumably keep entries for the implied nodes -- confirm against callers.
    """

    def gene2species(gene):
        return recon[tree.nodes[gene]].name

    if events is None:
        events = phylo.label_events(tree, recon)

    # remember the leaves as they are before implied nodes are inserted
    leaves = set(tree.leaves())
    phylo.add_implied_spec_nodes(tree, stree, recon, events)

    try:
        # only the species-node lookup of the ptree conversion is used
        _pstree, _snodes, snodelookup = spidir.make_ptree(stree)

        # get doomtable
        doomtable = calc_doom_table(stree, birth, death, maxdoom)

        prod = 0.0
        for node in tree:
            if events[node] != "spec":
                continue

            for schild in recon[node].children:
                nodes2 = [x for x in node.children if recon[x] == schild]
                if nodes2:
                    node2 = nodes2[0]
                    subleaves = get_sub_tree(node2, schild, recon, events)
                    nhist = birthdeath.num_topology_histories(node2, subleaves)
                    s = len(subleaves)
                    # s! (s-1)! / 2^(s-1); presumably the total number of
                    # labeled histories on s leaves
                    thist = stats.factorial(s) * stats.factorial(s - 1) / 2 ** (s - 1)

                    # the last argument is True when none of the subtree
                    # leaves is a leaf of the original gene tree
                    is_internal = not (set(subleaves) & leaves)
                    prod += log(num_redundant_topology(
                        node2, gene2species, subleaves, is_internal))
                else:
                    # no gene lineage enters this species branch
                    nhist = 1.0
                    thist = 1.0
                    s = 0

                # the doom factor does not depend on i, so compute it once
                doom = exp(doomtable[snodelookup[schild]])
                t = sum(
                    stats.choose(s + i, i)
                    * birthdeath.prob_birth_death1(s + i, schild.dist, birth, death)
                    * doom ** i
                    for i in range(maxdoom + 1)
                )

                prod += log(nhist) - log(thist) + log(t)

        # correct for renumbering
        prod -= log(num_redundant_topology(tree.root, gene2species))
    finally:
        # undo the temporary tree modification on success and on failure
        treelib.remove_single_children(tree)

    return prod
Esempio n. 2
0
def prob_coal_recon_topology(tree, recon, locus_tree, n, daughters):
    """
    Returns the log probability of a reconciled gene tree ('tree', 'recon')
    from the coalescent model given a locus_tree 'locus_tree',
    population sizes 'n', and daughters set 'daughters'

    tree       -- gene tree
    recon      -- mapping from gene tree node to locus tree node
    locus_tree -- locus tree whose branches are visited in postorder
    n          -- population size argument forwarded to
                  coal.init_popsizes(locus_tree, n)
    daughters  -- collection of locus tree nodes; for these branches the
                  coalescent count probability is skipped and exactly one
                  lineage must leave the branch (checked with an assert)

    Raises Exception if 'recon' contains a node that is not in 'tree'.
    """

    # init population sizes
    popsizes = coal.init_popsizes(locus_tree, n)

    # log probability
    lnp = 0.0

    nodes = set(tree.postorder())

    # init reverse reconciliation
    rev_recon = {}
    for node, snode in recon.items():
        if node not in nodes:
            raise Exception("node '%s' not in tree" % node.name)
        rev_recon.setdefault(snode, []).append(node)

    # init lineage counts
    lineages = {}
    for snode in locus_tree:
        if snode.is_leaf():
            lineages[snode] = len([x for x in rev_recon[snode]
                                   if x.is_leaf()])
        else:
            lineages[snode] = 0

    # iterate through species tree branches
    for snode in locus_tree.postorder():
        if snode.parent:
            # non root branch
            u = lineages[snode]

            # subtract number of coals in branch
            v = u - len([x for x in rev_recon.get(snode, [])
                         if not x.is_leaf()])
            lineages[snode.parent] += v

            if snode not in daughters:
                try:
                    lnp += util.safelog(
                        coal.prob_coal_counts(u, v, snode.dist,
                                              popsizes[snode.name]))
                except Exception:
                    # dump the offending arguments, then re-raise; a single
                    # formatted string prints the same on Python 2 and 3
                    print("%s %s %s %s" % (u, v, snode.dist,
                                           popsizes[snode.name]))
                    raise
            else:
                assert v == 1
            lnp -= util.safelog(coal.num_labeled_histories(u, v))
        else:
            # normal coalescent
            u = lineages[snode]
            lnp -= util.safelog(coal.num_labeled_histories(u, 1))

    # correct for topologies H(T)
    # find connected subtrees that are in the same species branch
    subtrees = []
    subtree_root = {}
    for node in tree.preorder():
        if node.parent and recon[node] == recon[node.parent]:
            subtree_root[node] = subtree_root[node.parent]
        else:
            subtrees.append(node)
            subtree_root[node] = node

    # find leaves through recursion
    def walk(node, subtree, leaves):
        if node.is_leaf():
            leaves.append(node)
        elif (subtree_root[node.children[0]] != subtree and
              subtree_root[node.children[1]] != subtree):
            # both children start other subtrees, so this node ends ours
            leaves.append(node)
        else:
            for child in node.children:
                walk(child, subtree, leaves)

    # apply correction for each subtree
    for subtree in subtrees:
        # Walk once from the subtree root.  Walking once per child (as was
        # done before) collected every leaf several times and inflated the
        # count tested below.  Subtrees with at most two leaves are skipped;
        # a two-leaf topology has a single labeled history, so presumably
        # it contributes log(1) = 0.
        leaves = []
        walk(subtree, subtree, leaves)
        if len(leaves) > 2:
            lnp += util.safelog(
                birthdeath.num_topology_histories(subtree, leaves))

    return lnp
Esempio n. 3
0
def dup_loss_topology_prior(tree,
                            stree,
                            recon,
                            birth,
                            death,
                            maxdoom=20,
                            events=None):
    """
    Returns the log prior of a gene tree topology according to dup-loss model

    tree    -- gene tree.  It is temporarily extended with implied
               speciation nodes; treelib.remove_single_children(tree) is
               called on the way out whether or not an error occurred.
    stree   -- species tree
    recon   -- reconciliation, indexed by gene tree node
    birth   -- birth parameter handed to calc_doom_table and
               birthdeath.prob_birth_death1
    death   -- death parameter handed to the same two functions
    maxdoom -- upper index (inclusive) of the inner sum; also handed to
               calc_doom_table
    events  -- event label per gene tree node; when None it is computed
               with phylo.label_events(tree, recon)

    NOTE(review): recon and events are passed to
    phylo.add_implied_spec_nodes and are not pruned here, so they presumably
    retain entries for the implied nodes -- confirm against callers.
    """
    def gene2species(gene):
        return recon[tree.nodes[gene]].name

    if events is None:
        events = phylo.label_events(tree, recon)

    # leaves of the gene tree as given, before implied nodes are added
    leaves = set(tree.leaves())
    phylo.add_implied_spec_nodes(tree, stree, recon, events)

    try:
        # of the three values returned, only the lookup table is needed
        _pstree, _snodes, snodelookup = spidir.make_ptree(stree)

        # get doomtable
        doomtable = calc_doom_table(stree, birth, death, maxdoom)

        prod = 0.0
        for node in tree:
            if events[node] != "spec":
                continue

            for schild in recon[node].children:
                nodes2 = [x for x in node.children if recon[x] == schild]
                if nodes2:
                    node2 = nodes2[0]
                    subleaves = get_sub_tree(node2, schild, recon, events)
                    nhist = birthdeath.num_topology_histories(
                        node2, subleaves)
                    s = len(subleaves)
                    # s! (s-1)! / 2^(s-1); presumably the count of all
                    # labeled histories on s leaves
                    thist = (stats.factorial(s) * stats.factorial(s - 1) /
                             2**(s - 1))

                    # flag is True for a purely internal subtree, i.e. one
                    # that shares no leaf with the original gene tree
                    internal = not (set(subleaves) & leaves)
                    prod += log(
                        num_redundant_topology(node2, gene2species,
                                               subleaves, internal))
                else:
                    # species branch without any gene lineage
                    nhist = 1.0
                    thist = 1.0
                    s = 0

                # hoisted: the doom factor is the same for every i
                doom = exp(doomtable[snodelookup[schild]])
                t = sum(
                    stats.choose(s + i, i) * birthdeath.prob_birth_death1(
                        s + i, schild.dist, birth, death) * doom**i
                    for i in range(maxdoom + 1))

                prod += log(nhist) - log(thist) + log(t)

        # correct for renumbering
        nt = num_redundant_topology(tree.root, gene2species)
        prod -= log(nt)
    finally:
        # always restore the tree, including when the computation fails
        treelib.remove_single_children(tree)

    return prod