Example #1
0
def count_dup_loss_coal_tree(gene_tree, extra, stree, gene2species,
                             implied=True):
    """
    Tally duplications, losses, extra lineages and appearances.

    The species map and order stored in `extra` are first re-expressed in
    terms of the nodes of `stree` (matched by node name).  Per-branch
    tallies are accumulated into ``snode.data`` under the keys "appear",
    "genes", "dup", "loss" and "coal".

    Returns the totals as (ndup, nloss, ncoal, nappear).
    """

    def to_stree(snode):
        # node of `stree` that carries the same name
        return stree.nodes[snode.name]

    # re-key the reconciliation against `stree`; the updated maps go into a
    # copy of `extra` rather than into the dictionary that was passed in
    srecon = util.mapdict(extra["species_map"], val=to_stree)
    order = util.mapdict(extra["order"], key=to_stree)
    extra = extra.copy()
    extra["species_map"] = srecon
    extra["order"] = order

    # the gene tree root marks a single appearance on its species branch
    root_snode = stree.nodes[srecon[gene_tree.root].name]
    root_snode.data["appear"] += 1
    nappear = 1

    # split the gene tree into per-species-branch subtrees
    events = phylo.label_events(gene_tree, srecon)
    subtrees = factor_tree(gene_tree, stree, srecon, events)

    total_dup = 0
    total_loss = 0
    total_coal = 0
    for snode in stree:
        branch_subtrees = subtrees[snode]
        if len(branch_subtrees) == 0:
            # nothing maps to this species branch
            continue

        # genes are only counted on leaf species
        if snode.is_leaf():
            for _root, _rootchild, leaves in branch_subtrees:
                if leaves is not None:
                    snode.data["genes"] += len(leaves)

        # duplications
        branch_dup = count_dup_snode(gene_tree, stree, extra, snode,
                                     subtrees, branch_subtrees)
        snode.data["dup"] += branch_dup
        total_dup += branch_dup

        # losses
        branch_loss = count_loss_snode(gene_tree, stree, extra, snode,
                                       subtrees, branch_subtrees)
        snode.data["loss"] += branch_loss
        total_loss += branch_loss

        # deep coalescence (extra lineages)
        branch_coal = count_coal_snode(gene_tree, stree, extra, snode,
                                       subtrees, branch_subtrees,
                                       implied=implied)
        snode.data["coal"] += branch_coal
        total_coal += branch_coal

    return total_dup, total_loss, total_coal, nappear
Example #2
0
    def next_proposal(self):
        """
        Propose the next reconciliation.

        Advances the locus search, passes a copy of its current tree through
        phylo.recon_root (newCopy=False), reconciles that tree to the species
        tree and reconciles the coalescent tree to it.

        Returns a dict with the keys "coal_recon", "locus_tree",
        "locus_recon", "locus_events" and "daughters".
        """
        def same_node(node):
            # coalescent tree and locus tree are matched node-for-node
            return node

        self.locus_search.propose()

        # TODO: propose other reconciliations beside LCA
        proposed_tree = self.locus_search.get_tree().copy()
        phylo.recon_root(proposed_tree, self.reconer.stree,
                         self.reconer.gene2species, newCopy=False)
        proposed_recon = phylo.reconcile(
            proposed_tree, self.reconer.stree, self.reconer.gene2species)
        proposed_events = phylo.label_events(proposed_tree, proposed_recon)

        # TODO: propose daughters; for now the set is always empty
        no_daughters = set()

        # TODO: propose coal recons other than LCA
        proposed_coal_recon = phylo.reconcile(
            self.reconer.coal_tree, proposed_tree, same_node)

        return {
            "coal_recon": proposed_coal_recon,
            "locus_tree": proposed_tree,
            "locus_recon": proposed_recon,
            "locus_events": proposed_events,
            "daughters": no_daughters,
        }
Example #3
0
def count_dup_loss_coal_tree(coal_tree, extra, stree, gene2species,
                             implied=True, locus_mpr=True):
    """
    Count duplications, losses, extra lineages and appearances.

    coal_tree    -- coalescent tree
    extra        -- mapping that provides "locus_tree" and "coal_recon"
    stree        -- species tree
    gene2species -- passed to phylo.reconcile for the locus tree
    implied      -- if true, implied speciation nodes are temporarily added
                    to the locus tree while extra lineages are counted
    locus_mpr    -- must be true; anything else is not implemented

    Returns (ndup, nloss, ncoal, nappear).  As a side effect, the 'coal'
    entry in the data of the node each locus branch is reconciled to is
    incremented by that branch's extra-lineage count.

    Raises NotImplementedError (an Exception subclass) if locus_mpr is false.
    """

    if not locus_mpr:
        raise NotImplementedError("not implemented")

    # TODO: use locus_recon and locus_events rather than MPR
    #       (currently, phylo.py reconciliation functions fail for non-MPR)
    locus_tree = extra["locus_tree"]
    locus_recon = phylo.reconcile(locus_tree, stree, gene2species)
    locus_events = phylo.label_events(locus_tree, locus_recon)
    coal_recon = extra["coal_recon"]

    ndup, nloss, nappear = phylo.count_dup_loss_tree(locus_tree, stree, gene2species,
                                                     locus_recon, locus_events)

    # add implied speciation nodes if desired
    # this must be added AFTER counting dups and losses since it affects loss inference
    added = None
    if implied:
        added = phylo.add_implied_spec_nodes(locus_tree, stree, locus_recon, locus_events)

    # count coals
    ncoal = 0
    try:
        counts = coal.count_lineages_per_branch(coal_tree, coal_recon, locus_tree)
        # only the lineage count at the top of each branch is needed
        for lnode, (_count_bot, count_top) in counts.items():
            n = max(count_top - 1, 0)
            locus_recon[lnode].data['coal'] += n
            ncoal += n
    finally:
        # the locus tree belongs to the caller: always undo the temporary
        # implied nodes, even if counting failed part-way through
        if implied:
            phylo.remove_implied_spec_nodes(locus_tree, added, locus_recon, locus_events)

    return ndup, nloss, ncoal, nappear
Example #4
0
def dup_loss_topology_prior(tree, stree, recon, birth, death, maxdoom=20, events=None):
    """
    Log prior of a gene tree topology under the dup-loss model.

    tree    -- gene tree; implied speciation nodes are added while the prior
               is computed and single-child nodes are removed before returning
    stree   -- species tree
    recon   -- reconciliation of gene tree nodes to species tree nodes
    birth   -- birth (duplication) rate
    death   -- death (loss) rate
    maxdoom -- the inner sum runs over i = 0 .. maxdoom
    events  -- event labels; computed with phylo.label_events when None
    """

    def gene2species(gene):
        # species name that a gene (looked up by name in `tree`) maps to
        return recon[tree.nodes[gene]].name

    if events is None:
        events = phylo.label_events(tree, recon)

    # remember the leaves before implied nodes are inserted
    leaves = set(tree.leaves())
    phylo.add_implied_spec_nodes(tree, stree, recon, events)

    # only the species-node lookup from make_ptree is used below
    _pstree, _snodes, snodelookup = spidir.make_ptree(stree)

    doomtable = calc_doom_table(stree, birth, death, maxdoom)

    prod = 0.0
    for node in tree:
        if events[node] != "spec":
            continue

        for schild in recon[node].children:
            below = [x for x in node.children if recon[x] == schild]
            if below:
                top = below[0]
                subleaves = get_sub_tree(top, schild, recon, events)
                nhist = birthdeath.num_topology_histories(top, subleaves)
                s = len(subleaves)
                thist = stats.factorial(s) * stats.factorial(s - 1) / 2 ** (s - 1)

                # True when none of the subtree's leaves is a leaf of the
                # gene tree as it was before implied nodes were added
                internal = len(set(subleaves) & leaves) == 0
                prod += log(num_redundant_topology(top, gene2species, subleaves, internal))
            else:
                # no gene child maps to this species child
                nhist = 1.0
                thist = 1.0
                s = 0

            t = sum(
                stats.choose(s + i, i)
                * birthdeath.prob_birth_death1(s + i, schild.dist, birth, death)
                * exp(doomtable[snodelookup[schild]]) ** i
                for i in range(maxdoom + 1)
            )

            prod += log(nhist) - log(thist) + log(t)

    # correct for renumbering
    prod -= log(num_redundant_topology(tree.root, gene2species))

    # drop the single-child nodes again before handing the tree back
    treelib.remove_single_children(tree)

    return prod
Example #5
0
def _subtree_helper(tree, stree, extra,
                    subtrees=None):
    """Returns a dictionary of subtrees for each species branch"""

    if subtrees is None:
        recon = extra["species_map"]
        events = phylo.label_events(tree, recon)
        subtrees = factor_tree(tree, stree, recon, events)
    return subtrees
 def compute_cost(self, gtree):
     """
     Returns the duplication-loss cost of `gtree`.

     The tree is reconciled to self.stree; the duplication and loss counts
     are weighted by self.dupcost and self.losscost.  A count whose weight
     is zero is not computed.
     """
     reconciliation = phylo.reconcile(gtree, self.stree, self.gene2species)
     labels = phylo.label_events(gtree, reconciliation)

     total = 0
     if self.dupcost != 0:
         ndup = phylo.count_dup(gtree, labels)
         total += ndup * self.dupcost
     if self.losscost != 0:
         nloss = phylo.count_loss(gtree, self.stree, reconciliation)
         total += nloss * self.losscost
     return total
Example #7
0
    def prescreen(self, tree):
        """
        Returns duploss.prob_dup_loss for `tree`.

        The tree is reconciled to self.stree and its events are labeled;
        both are passed on together with self.duprate and self.lossrate.
        """
        reconciliation = phylo.reconcile(tree, self.stree, self.gene2species)
        labels = phylo.label_events(tree, reconciliation)

        score = duploss.prob_dup_loss(tree, self.stree,
                                      reconciliation, labels,
                                      self.duprate, self.lossrate)
        return score
Example #8
0
    def _compute_coalcost(self, gtree, ltree):
        """Returns deep coalescent cost from coalescent tree (gene tree) to locus tree

        Note: uses Zhang (RECOMB 2000) result that C = L - 2*D
        """
        cost = 0
        if self.coalcost > 0:
            recon = phylo.reconcile(gtree, ltree)
            events = phylo.label_events(gtree, recon)
            cost = (phylo.count_loss(gtree, ltree, recon) - 2*phylo.count_dup(gtree, events)) * self.coalcost
        return cost
Example #9
0
 def _compute_duplosscost(self, ltree):
     """Returns dup/loss cost from locus tree to species tree"""
     cost = 0
     if self.dupcost > 0 or self.losscost > 0:
         recon = phylo.reconcile(ltree, self.stree, self.gene2species)
         events = phylo.label_events(ltree, recon)
         if self.dupcost != 0:
             cost += phylo.count_dup(ltree, events) * self.dupcost
         if self.losscost != 0:
             cost += phylo.count_loss(ltree, self.stree, recon) * self.losscost
     return cost
Example #10
0
 def setup_recon(self, recon=None):
     """
     Set self.recon, self.events and self.losses.

     recon -- reconciliation to use.  When it is None and both self.stree
              and self.gene2species are set, a default reconciliation is
              computed with phylo.reconcile; otherwise `recon` is stored
              as given.

     If the resulting self.recon is falsy (None or empty), self.events and
     self.losses are set to None.
     """
     # construct default reconciliation
     # (identity test: a reconciliation object must not be asked to compare
     # itself against None)
     if recon is None and self.stree and self.gene2species:
         self.recon = phylo.reconcile(self.tree, self.stree, self.gene2species)
     else:
         self.recon = recon

     # construct events
     if self.recon:
         self.events = phylo.label_events(self.tree, self.recon)
         self.losses = phylo.find_loss(self.tree, self.stree, self.recon)
     else:
         self.events = None
         self.losses = None
Example #11
0
    def _recon_lca(self, locus_tree):
        """
        Build a phyloDLC.Recon for `locus_tree` from LCA (MPR) reconciliations.

        The locus tree is reconciled to the species tree, the coalescent
        tree is reconciled to the locus tree, and daughters are proposed
        with self._propose_daughters.
        """
        def same_node(node):
            # coalescent tree and locus tree are matched node-for-node
            return node

        # locus tree -> species tree
        lrecon = phylo.reconcile(locus_tree, self._stree, self._gene2species)
        levents = phylo.label_events(locus_tree, lrecon)

        # coalescent tree -> locus tree
        crecon = phylo.reconcile(self._coal_tree, locus_tree, same_node)

        daughters = self._propose_daughters(
            self._coal_tree, crecon, locus_tree, lrecon, levents)

        return phyloDLC.Recon(crecon, locus_tree, lrecon, levents, daughters)
Example #12
0
    def recon(self):
        """
        Perform reconciliation.

        Logs the input trees, infers the species map, restricts the species
        tree to the subtree under the species node of the gene tree root,
        adds implied nodes, relabels events, renames gene tree nodes, and
        infers the locus map.  Returns self.count_vectors.
        """

        self.log.start("Reconciling")

        # record the inputs
        self.log.log("gene tree\n")
        log_tree(self.gtree, self.log, func=treelib.draw_tree_names)
        self.log.log("species tree\n")
        log_tree(self.stree, self.log, func=treelib.draw_tree_names)

        # species map
        self._infer_species_map()
        self.log.log("\n\n")

        # root the species tree at the species of the gene tree root and
        # re-point the species map at the nodes of that subtree
        rooted_stree = treelib.subtree(self.stree, self.srecon[self.gtree.root])

        def into_rooted(snode):
            return rooted_stree.nodes[snode.name]

        rooted_srecon = util.mapdict(self.srecon, val=into_rooted)

        # from here on, work with the restricted tree and map
        self.stree = rooted_stree
        self.srecon = rooted_srecon

        # add implied nodes (standard speciation, speciation from duplication, delay nodes)
        # then relabel events (so that factor_tree works)
        reconlib.add_implied_nodes(self.gtree, self.stree, self.srecon,
                                   self.sevents, delay=self.delay)
        self.sevents = phylo.label_events(self.gtree, self.srecon)
        common.rename_nodes(self.gtree, self.name_internal)

        # record the gene tree together with its species map
        self.log.log("gene tree (with species map)\n")
        log_tree(self.gtree, self.log, func=draw_tree_srecon, srecon=self.srecon)

        # locus map
        self._infer_locus_map()

        self.log.stop()

        return self.count_vectors
Example #13
0
    def prescreen(self, tree):
        """
        Returns the weighted duplication plus loss cost of `tree`.

        A count whose weight (self.dupcost / self.losscost) is zero is not
        computed and contributes 0.
        """
        recon = phylo.reconcile(tree, self.stree, self.gene2species)
        events = phylo.label_events(tree, recon)

        dupcost = 0
        if self.dupcost != 0:
            dupcost = phylo.count_dup(tree, events) * self.dupcost

        losscost = 0
        if self.losscost != 0:
            losscost = phylo.count_loss(tree, self.stree, recon) * self.losscost

        return dupcost + losscost
Example #14
0
    def _recon_lca(self, locus_tree):
        """
        Returns a phyloDLC.Recon built from LCA (MPR) reconciliations.

        `locus_tree` is reconciled to self._stree and self._coal_tree is
        reconciled to `locus_tree`; daughters come from
        self._propose_daughters.
        """
        # LCA (MPR) reconciliation of the locus tree and its event labels
        species_of = phylo.reconcile(
            locus_tree, self._stree, self._gene2species)
        event_of = phylo.label_events(locus_tree, species_of)

        # LCA (MPR) reconciliation of the coalescent tree within the locus
        # tree; nodes are mapped to themselves
        def identity(x):
            return x

        locus_of = phylo.reconcile(self._coal_tree, locus_tree, identity)

        # daughters
        chosen = self._propose_daughters(self._coal_tree, locus_of,
                                         locus_tree, species_of, event_of)

        result = phyloDLC.Recon(locus_of, locus_tree, species_of, event_of,
                                chosen)
        return result
Example #15
0
    def prescreen(self, tree):
        """
        Returns dupcost + losscost for `tree` after reconciling it to
        self.stree.

        Each term is the event count times its weight; when a weight is
        zero the corresponding count is skipped and the term is 0.
        """
        reconciliation = phylo.reconcile(tree, self.stree, self.gene2species)
        labels = phylo.label_events(tree, reconciliation)

        dup_term = (0 if self.dupcost == 0
                    else phylo.count_dup(tree, labels) * self.dupcost)
        loss_term = (0 if self.losscost == 0
                     else phylo.count_loss(tree, self.stree, reconciliation)
                     * self.losscost)

        return dup_term + loss_term
Example #16
0
    def compute_cost(self, gtree):
        """
        Returns the negated sum of the topology prior and the branch prior
        of `gtree`.

        Presumably both priors are log probabilities, so that
        min cost = min neg log prob = max log prob = max prob
        (confirm against spidir).

        Branch lengths of `gtree` are optimized before the priors are
        evaluated.
        """
        reconciliation = phylo.reconcile(gtree, self.stree, self.gene2species)
        labels = phylo.label_events(gtree, reconciliation)

        # fit branch lengths first; the branch prior below reads them
        spidir.find_ml_branch_lengths_hky(
            gtree, self.align, self.bgfreq, self.kappa,
            maxiter=10, parsinit=False)

        branch_term = spidir.branch_prior(
            gtree, self.stree, reconciliation, labels,
            self.params, self.duprate, self.lossrate, self.pretime)
        topology_term = spidir.calc_birth_death_prior(
            gtree, self.stree, reconciliation,
            self.duprate, self.lossrate, labels)

        return -(topology_term + branch_term)
Example #17
0
def count_dup_loss_coal_tree(coal_tree,
                             extra,
                             stree,
                             gene2species,
                             implied=True,
                             locus_mpr=True):
    """
    Count duplications, losses, extra lineages and appearances.

    coal_tree    -- coalescent tree
    extra        -- mapping that provides "locus_tree" and "coal_recon"
    stree        -- species tree
    gene2species -- passed to phylo.reconcile for the locus tree
    implied      -- if true, implied speciation nodes are temporarily added
                    to the locus tree while extra lineages are counted
    locus_mpr    -- must be true; anything else is not implemented

    Returns (ndup, nloss, ncoal, nappear).  As a side effect, the 'coal'
    entry in the data of the node each locus branch is reconciled to is
    incremented by that branch's extra-lineage count.

    Raises NotImplementedError (an Exception subclass) if locus_mpr is false.
    """

    if not locus_mpr:
        raise NotImplementedError("not implemented")

    # TODO: use locus_recon and locus_events rather than MPR
    #       (currently, phylo.py reconciliation functions fail for non-MPR)
    locus_tree = extra["locus_tree"]
    locus_recon = phylo.reconcile(locus_tree, stree, gene2species)
    locus_events = phylo.label_events(locus_tree, locus_recon)
    coal_recon = extra["coal_recon"]

    ndup, nloss, nappear = phylo.count_dup_loss_tree(locus_tree, stree,
                                                     gene2species, locus_recon,
                                                     locus_events)

    # add implied speciation nodes if desired
    # this must be added AFTER counting dups and losses since it affects loss inference
    added = None
    if implied:
        added = phylo.add_implied_spec_nodes(locus_tree, stree, locus_recon,
                                             locus_events)

    # count coals
    ncoal = 0
    try:
        counts = coal.count_lineages_per_branch(coal_tree, coal_recon,
                                                locus_tree)
        # only the lineage count at the top of each branch is needed
        for lnode, (_count_bot, count_top) in counts.items():
            n = max(count_top - 1, 0)
            locus_recon[lnode].data['coal'] += n
            ncoal += n
    finally:
        # the locus tree belongs to the caller: always undo the temporary
        # implied nodes, even if counting failed part-way through
        if implied:
            phylo.remove_implied_spec_nodes(locus_tree, added, locus_recon,
                                            locus_events)

    return ndup, nloss, ncoal, nappear
Example #18
0
    def recon(self):
        """
        Perform reconciliation and return self.count_vectors.

        Steps: log the gene and species trees, infer the species map, swap
        in the species subtree under the species of the gene tree root, add
        implied nodes, relabel events, rename gene tree nodes, log the
        annotated gene tree, and infer the locus map.
        """

        self.log.start("Reconciling")

        # log input gene and species trees
        for title, input_tree in (("gene tree\n", self.gtree),
                                  ("species tree\n", self.stree)):
            self.log.log(title)
            log_tree(input_tree, self.log, func=treelib.draw_tree_names)

        # infer species map
        self._infer_species_map()
        self.log.log("\n\n")

        # start the species tree at the species of the gene tree root
        top_species = self.srecon[self.gtree.root]
        pruned_stree = treelib.subtree(self.stree, top_species)
        pruned_srecon = util.mapdict(
            self.srecon,
            val=lambda snode: pruned_stree.nodes[snode.name])

        # replace internal storage with the pruned versions
        self.stree = pruned_stree
        self.srecon = pruned_srecon

        # add implied nodes (standard speciation, speciation from duplication, delay nodes)
        # then relabel events (so that factor_tree works)
        reconlib.add_implied_nodes(
            self.gtree, self.stree, self.srecon, self.sevents,
            delay=self.delay)
        self.sevents = phylo.label_events(self.gtree, self.srecon)
        common.rename_nodes(self.gtree, self.name_internal)

        # log gene tree (with species map)
        self.log.log("gene tree (with species map)\n")
        log_tree(self.gtree, self.log,
                 func=draw_tree_srecon, srecon=self.srecon)

        # infer locus map
        self._infer_locus_map()

        self.log.stop()

        return self.count_vectors
Example #19
0
def prob_dup_loss(tree, stree, recon, events, duprate, lossrate):
    """
    Returns the topology prior of a gene tree.

    tree     -- gene tree
    stree    -- species tree
    recon    -- reconciliation of `tree` to `stree`
    events   -- event labels; in the native path they are computed with
                phylo.label_events when None, in the Python path they are
                passed through unchanged
    duprate  -- duplication rate
    lossrate -- loss rate

    Uses the native library (dlcoal.dlcoalc) when it is available and
    falls back to spidir's Python implementation otherwise, printing a
    one-time warning to stderr.
    """

    if dlcoal.dlcoalc:
        if events is None:
            events = phylo.label_events(tree, recon)

        ptree, nodes, nodelookup = dlcoal.make_ptree(tree)
        pstree, snodes, snodelookup = dlcoal.make_ptree(stree)

        ctree = dlcoal.tree2ctree(tree)
        cstree = dlcoal.tree2ctree(stree)
        try:
            recon2 = dlcoal.make_recon_array(tree, recon, nodes, snodelookup)
            events2 = dlcoal.make_events_array(nodes, events)

            doomtable = c_list(c_double, [0] * len(stree.nodes))
            dlcoal.dlcoalc.calcDoomTable(cstree, duprate, lossrate, doomtable)

            p = dlcoal.dlcoalc.birthDeathTreePriorFull(ctree, cstree,
                                        c_list(c_int, recon2),
                                        c_list(c_int, events2),
                                        duprate, lossrate, doomtable)
        finally:
            # free the native trees even if one of the calls above raised
            dlcoal.dlcoalc.deleteTree(ctree)
            dlcoal.dlcoalc.deleteTree(cstree)

        return p

    else:
        # warn only the first time the fallback is taken
        if "dlcoal_python_fallback" not in globals():
            sys.stderr.write("warning: using python code instead of native\n")
            globals()["dlcoal_python_fallback"] = 1

        # spidir libs
        # The import binds a function-local name, so it has to run on every
        # call; doing it only together with the first-time warning left
        # `topology_prior` unbound on every later call.
        from spidir import topology_prior

        return topology_prior.dup_loss_topology_prior(
            tree, stree, recon, duprate, lossrate,
            events=events)
Example #20
0
    def __init__(self,
                 stree,
                 locus_tree,
                 daughters,
                 gene2species,
                 search=phylo.TreeSearchNni,
                 num_coal_recons=1):
        """
        stree           -- species tree
        locus_tree      -- locus tree (kept fixed)
        daughters       -- daughters, stored as given
        gene2species    -- used once to reconcile the locus tree to `stree`
        search          -- search class; instantiated as search(None)
        num_coal_recons -- stored as self._num_coal_recons
        """
        # inputs
        self._stree = stree
        self._locus_tree = locus_tree
        self._daughters = daughters

        # search over coalescent trees, created without an initial tree
        self._coal_search = search(None)

        # locus recon (static) -- propose LCA reconciliation
        locus_recon = phylo.reconcile(locus_tree, stree, gene2species)
        self._locus_recon = locus_recon
        self._locus_events = phylo.label_events(locus_tree, locus_recon)

        # state of the coal recon search
        self._num_coal_recons = num_coal_recons
        self._i_coal_recons = 1
        self._coal_recon_enum = None
        self._coal_recon_depth = 2
        self._accept_coal = False

        # no reconciliation has been proposed yet
        self._recon = None
Example #21
0
    def _recon_lca(self, locus_tree):
        """
        Returns a Recon for `locus_tree` built from LCA reconciliations.

        Also resets self._coal_recon_enum to a phylo.enum_recon enumeration
        that is seeded with the LCA coal recon and uses
        self._coal_recon_depth as its depth.
        """
        def same_node(node):
            # coalescent tree and locus tree are matched node-for-node
            return node

        # LCA reconciliation of the locus tree to the species tree
        lrecon = phylo.reconcile(locus_tree, self._stree, self._gene2species)
        levents = phylo.label_events(locus_tree, lrecon)

        # LCA reconciliation of the coalescent tree to the locus tree
        crecon = phylo.reconcile(self._coal_tree, locus_tree, same_node)

        # propose daughters (TODO)
        daughters = self._propose_daughters(self._coal_tree, crecon,
                                            locus_tree, lrecon, levents)

        # enumeration of coal recons, seeded with the LCA one
        self._coal_recon_enum = phylo.enum_recon(self._coal_tree, locus_tree,
                                                 recon=crecon,
                                                 depth=self._coal_recon_depth)

        return Recon(crecon, locus_tree, lrecon, levents, daughters)
Example #22
0
def draw_tree(tree,
              labels={},
              xscale=100,
              yscale=20,
              canvas=None,
              leafPadding=10,
              leafFunc=lambda x: str(x.name),
              labelOffset=None,
              fontSize=10,
              labelSize=None,
              minlen=1,
              maxlen=util.INF,
              filename=sys.stdout,
              rmargin=150,
              lmargin=10,
              tmargin=0,
              bmargin=None,
              colormap=None,
              stree=None,
              layout=None,
              gene2species=None,
              lossColor=(0, 0, 1),
              dupColor=(1, 0, 0),
              eventSize=4,
              legendScale=False,
              autoclose=None,
              extendRoot=True,
              labelLeaves=True,
              drawHoriz=True,
              nodeSize=0):
    """
    Draw a tree on an SVG canvas and return the canvas.

    tree         -- tree to draw
    labels       -- dict of branch labels keyed by node name (only read
                    here, never modified)
    xscale, yscale, minlen, maxlen
                 -- passed to treelib.layout_tree when `layout` is None
    canvas       -- canvas to draw on; when None a new svg.Svg is opened on
                    `filename` and sized from the layout and the margins
    leafPadding  -- horizontal offset of leaf text from the leaf
    leafFunc     -- function giving the text of a leaf
    labelOffset  -- vertical offset of branch labels (default -1)
    fontSize     -- font size of leaf text
    labelSize    -- font size of branch labels (default .7 * fontSize)
    filename     -- output stream or filename for a new canvas
    rmargin, lmargin, tmargin, bmargin
                 -- margins; bmargin defaults to yscale
    colormap     -- function called on the tree (presumably to set
                    node.color); when None every node is colored black
    stree, gene2species
                 -- when both are given the tree is reconciled and its
                    events and losses are drawn with draw_events
    layout       -- precomputed node coordinates, dict node -> (x, y)
    lossColor, dupColor, eventSize
                 -- passed to draw_events
    legendScale  -- draw a scale legend; True chooses the length
                    automatically
    autoclose    -- call canvas.endSvg() at the end; defaults to True when
                    the canvas is created here and to False otherwise
    extendRoot   -- draw the root branch starting from x = 0
    labelLeaves  -- draw leaf text
    drawHoriz    -- draw horizontal branches joined by vertical lines
                    instead of direct parent-to-child lines; forced to True
                    when labels or a reconciliation are drawn
    nodeSize     -- radius of node circles; 0 draws none
    """

    # set defaults
    # used below to estimate the width of a label from its character count
    fontRatio = 8. / 11.

    if labelSize == None:
        labelSize = .7 * fontSize

    if labelOffset == None:
        labelOffset = -1

    if bmargin == None:
        bmargin = yscale

    # a tree whose branch lengths are all zero gets no scale legend and a
    # minimum drawn branch length of xscale
    if sum(x.dist for x in tree.nodes.values()) == 0:
        legendScale = False
        minlen = xscale

    if colormap == None:
        for node in tree:
            node.color = (0, 0, 0)
    else:
        colormap(tree)

    # reconcile only when both a species tree and a gene-to-species map
    # are available
    if stree and gene2species:
        recon = phylo.reconcile(tree, stree, gene2species)
        events = phylo.label_events(tree, recon)
        losses = phylo.find_loss(tree, stree, recon)
    else:
        events = None
        losses = None

    if len(labels) > 0 or (stree and gene2species):
        drawHoriz = True

    # layout tree
    if layout is None:
        coords = treelib.layout_tree(tree, xscale, yscale, minlen, maxlen)
    else:
        coords = layout

    xcoords, ycoords = zip(*coords.values())
    maxwidth = max(xcoords)
    maxheight = max(ycoords) + labelOffset

    # initialize canvas
    if canvas == None:
        canvas = svg.Svg(util.open_stream(filename, "w"))
        width = int(rmargin + maxwidth + lmargin)
        height = int(tmargin + maxheight + bmargin)

        canvas.beginSvg(width, height)

        # a canvas created here is also closed here unless told otherwise
        if autoclose == None:
            autoclose = True
    else:
        # a canvas supplied by the caller is left open unless told otherwise
        if autoclose == None:
            autoclose = False

    # draw tree
    def walk(node):
        # draws the branch above `node`, its label, its node marker and
        # leaf text, then recurses into its children
        x, y = coords[node]
        if node.parent:
            parentx, parenty = coords[node.parent]
        else:
            if extendRoot:
                parentx, parenty = 0, y
            else:
                parentx, parenty = x, y  # e.g. no branch

        # draw branch
        if drawHoriz:
            canvas.line(parentx, y, x, y, color=node.color)
        else:
            canvas.line(parentx, parenty, x, y, color=node.color)

        # draw branch labels
        if node.name in labels:
            branchlen = x - parentx
            lines = str(labels[node.name]).split("\n")
            labelwidth = max(map(len, lines))
            # estimated label width, capped by the drawn branch length
            labellen = min(labelwidth * fontRatio * fontSize,
                           max(int(branchlen - 1), 0))

            # each line is offset so that the last one sits at
            # y + labelOffset and earlier ones are stacked above it
            for i, line in enumerate(lines):
                canvas.text(
                    line, parentx + (branchlen - labellen) / 2.,
                    y + labelOffset + (-len(lines) + 1 + i) * (labelSize + 1),
                    labelSize)

        # draw nodes
        if nodeSize > 0:
            canvas.circle(x,
                          y,
                          nodeSize,
                          strokeColor=svg.null,
                          fillColor=node.color)

        # draw leaf labels or recur
        if node.is_leaf():
            if labelLeaves:
                canvas.text(leafFunc(node),
                            x + leafPadding,
                            y + fontSize / 2.,
                            fontSize,
                            fillColor=node.color)
        else:
            if drawHoriz:
                # draw vertical part of branch
                top = coords[node.children[0]][1]
                bot = coords[node.children[-1]][1]
                canvas.line(x, top, x, bot, color=node.color)

            # draw children
            for child in node.children:
                walk(child)

    # the tree and its events are drawn shifted by the left and top margins
    canvas.beginTransform(("translate", lmargin, tmargin))
    walk(tree.root)

    if stree and gene2species:
        draw_events(canvas,
                    tree,
                    coords,
                    events,
                    losses,
                    lossColor=lossColor,
                    dupColor=dupColor,
                    size=eventSize)
    canvas.endTransform()

    # draw legend
    if legendScale:
        # NOTE(review): `length` is only assigned when legendScale == True;
        # any other truthy value reaches drawScale with `length` unbound --
        # confirm whether such a value was meant to be used as the length
        if legendScale == True:
            # automatically choose a scale
            length = maxwidth / float(xscale)
            order = math.floor(math.log10(length))
            length = 10**order

        drawScale(lmargin,
                  tmargin + maxheight + bmargin - fontSize,
                  length,
                  xscale,
                  fontSize,
                  canvas=canvas)

    if autoclose:
        canvas.endSvg()

    return canvas
Example #23
0
def sample_dup_times(tree,
                     stree,
                     recon,
                     birth,
                     death,
                     pretime=None,
                     premean=None,
                     events=None):
    """
    Sample duplication times for a gene tree in the dup-loss model

    tree    -- gene tree
    stree   -- species tree
    recon   -- reconciliation of gene tree nodes to species tree nodes
    birth   -- birth (duplication) rate
    death   -- death (loss) rate
    pretime -- time span above the species tree root, used when the gene
               tree root is a duplication mapped to the species root;
               sampled from an exponential with mean `premean` when None
    premean -- mean used for sampling `pretime`
    events  -- event labels; computed with phylo.label_events when None

    Returns a dict of timestamps keyed by gene tree node.

    Raises Exception("must set pre-mean") when a pretime has to be sampled
    and premean is None.

    NOTE: Implied speciation nodes must be present
    """
    # NOTE(review): not referenced anywhere in this function -- confirm
    # whether it can be removed
    def gene2species(gene):
        return recon[tree.nodes[gene]].name

    if events is None:
        events = phylo.label_events(tree, recon)

    # get species tree timestamps
    stimes = treelib.get_tree_timestamps(stree)
    #treelib.check_timestamps(stree, stimes)

    # init timestamps for gene tree
    times = {}

    # set pretimes
    # (only needed when the gene tree root itself is not a speciation)
    if events[tree.root] != "spec":
        if recon[tree.root] != stree.root:
            # tree root is a dup within species tree
            snode = recon[tree.root]
            start_time = stimes[snode.parent]
            time_span = snode.dist

        if recon[tree.root] == stree.root:
            # tree root is a pre-spec dup
            if pretime is None:
                if premean is None:
                    raise Exception("must set pre-mean")

                # resample until the pretime is nonzero
                # NOTE(review): under Python 2 an integer premean makes
                # `1 / premean` an integer division -- confirm callers pass
                # a float
                pretime = 0.0
                while pretime == 0.0:
                    pretime = random.expovariate(1 / premean)
            start_time = stimes[stree.root] + pretime
            time_span = pretime

        sample_dup_times_subtree(times, start_time, time_span, tree.root,
                                 recon, events, stree, birth, death)

    # set times
    # preorder, so that a parent's time is available when a node is checked
    for node in tree.preorder():
        if events[node] == "spec":
            # set speciation time
            start_time = times[node] = stimes[recon[node]]
            if node.parent:
                # a node timed later than its parent is only reported,
                # not treated as an error
                if times[node] > times[node.parent]:
                    print "bad", node.name
                    #raise Exception("bad time")

            # set duplication times within duplication subtree
            for duproot in node.children:
                if events[duproot] == "dup":
                    snode = recon[duproot]
                    time_span = snode.dist

                    #assert start_time - time_span >= stimes[snode], \
                    #       (duproot.name, start_time, time_span, stimes[snode])
                    sample_dup_times_subtree(times, start_time, time_span,
                                             duproot, recon, events, stree,
                                             birth, death)
        elif events[node] == "gene":
            # nodes labeled "gene" are timed at 0
            times[node] = 0.0

    return times
Example #24
0
def dlcoal_recon_old(tree, stree, gene2species,
                 n, duprate, lossrate,
                 pretime=None, premean=None,
                 nsearch=1000,
                 maxdoom=20, nsamples=100,
                 search=phylo.TreeSearchNni):
    """
    Perform reconciliation using the DLCoal model

    tree         -- coalescent (gene) tree
    stree        -- species tree
    gene2species -- used to root and reconcile each proposed locus tree
    n, duprate, lossrate, pretime, premean, maxdoom, nsamples
                 -- passed through to prob_dlcoal_recon_topology
    nsearch      -- number of search iterations
    search       -- search class, instantiated on the initial locus tree

    Returns (maxp, maxrecon) where 'maxp' is the probability of the
    MAP reconciliation 'maxrecon' which further defined as

    maxrecon = {'coal_recon': coal_recon,
                'locus_tree': locus_tree,
                'locus_recon': locus_recon,
                'locus_events': locus_events,
                'daughters': daughters}

    If no proposal scores above -util.INF, maxrecon is None.
    """

    # init coal tree
    coal_tree = tree

    # init locus tree as congruent to coal tree
    # equivalent to assuming no ILS
    locus_tree = coal_tree.copy()

    maxp = - util.INF
    maxrecon = None

    # init search
    locus_search = search(locus_tree)

    for i in xrange(nsearch):
        # TODO: propose other reconciliations beside LCA
        # work on a copy so the tree held by the search is left untouched
        locus_tree2 = locus_tree.copy()
        phylo.recon_root(locus_tree2, stree, gene2species, newCopy=False)
        locus_recon = phylo.reconcile(locus_tree2, stree, gene2species)
        locus_events = phylo.label_events(locus_tree2, locus_recon)

        # propose daughters (TODO)
        daughters = set()

        # propose coal recon (TODO: propose others beside LCA)
        coal_recon = phylo.reconcile(coal_tree, locus_tree2, lambda x: x)

        # compute recon probability
        # implied speciation nodes are added here (hence add_spec=False
        # below) and the single-child nodes are removed again afterwards
        phylo.add_implied_spec_nodes(locus_tree2, stree,
                                     locus_recon, locus_events)
        p = prob_dlcoal_recon_topology(coal_tree, coal_recon,
                                       locus_tree2, locus_recon, locus_events,
                                       daughters,
                                       stree, n, duprate, lossrate,
                                       pretime, premean,
                                       maxdoom=maxdoom, nsamples=nsamples,
                                       add_spec=False)
        treelib.remove_single_children(locus_tree2)

        if p > maxp:
            # new best: record it and continue the search from a copy of it
            maxp = p
            maxrecon = {"coal_recon": coal_recon,
                        "locus_tree": locus_tree2,
                        "locus_recon": locus_recon,
                        "locus_events": locus_events,
                        "daughters": daughters}
            locus_tree = locus_tree2.copy()
            locus_search.set_tree(locus_tree)
        else:
            # NOTE(review): on the first iteration nothing has been proposed
            # yet, so this would revert without a prior propose() -- confirm
            # that the search class tolerates that
            locus_search.revert()

        # perform local rearrangement to locus tree
        locus_search.propose()




    return maxp, maxrecon
#=============================================================================
# parse options

conf, args = o.parse_args()

# read the species tree and the gene tree
stree = treelib1.read_tree(conf.stree)
tree = treelib1.read_tree(conf.tree)

# optional table read from conf.names, turned into a dict
snames = dict(util.read_delim(conf.names)) if conf.names else None

# obtain a branch reconciliation: read it directly, convert a stored
# recon/events pair, or reconcile from the gene-to-species map
if conf.brecon:
    brecon = phylo.read_brecon(conf.brecon, tree, stree)
else:
    if conf.recon:
        recon, events = phylo.read_recon_events(conf.recon, tree, stree)
    else:
        gene2species = phylo.read_gene2species(conf.smap)
        recon = phylo.reconcile(tree, stree, gene2species)
        events = phylo.label_events(tree, recon)
    brecon = phylo.recon_events2brecon(recon, events)

# add implied speciation nodes, then draw the tree
phylo.add_implied_spec_nodes_brecon(tree, brecon)
transsvg.draw_tree(tree, brecon, stree, filename=conf.output, snames=snames)
Example #26
0
def prob_locus_gene_species_alignment_recon(alnfile,
                                            partfile,
                                            stree,
                                            popsizes,
                                            duprate,
                                            lossrate,
                                            subrate,
                                            beta,
                                            pretime,
                                            premean,
                                            coal_tree,
                                            coal_recon,
                                            nsamples_coal,
                                            locus_tree,
                                            locus_recon,
                                            nsamples_locus,
                                            daughters,
                                            rates,
                                            freqs,
                                            alphas,
                                            threads=1,
                                            seed=ALIGNMENT_SEED,
                                            eps=0.1,
                                            info=None):
    """
    Log of the joint probability of locus_tree, locus_recon, coal_tree,
    coal_recon, daughters and the alignment. The three log terms are summed:

    log P(T^G, T^L, R^G, R^L, delta^L, A | S, theta) =
          log P(delta^L | T^L, R^L, S)
        + log P(T^L, R^L | S, theta^S)
        + log int int P(t^L | T^L, R^L, S, theta)
                      * P(T^G, R^G, t^G | t^L, T^L, daughters, R^L, theta)
                      * P(A | T^G, t^G) dt^L dt^G

    alnfile           -- alignment file
    partfile          -- partition file
    stree             -- species tree
    popsizes          -- population sizes in species tree
    duprate           -- duplication rate
    lossrate          -- loss rate
    subrate           -- substitution rate
    beta              -- regularization parameter
    pretime           -- starting time before species tree
    premean           -- mean starting time before species tree

    coal_tree         -- coalescent tree
    coal_recon        -- reconciliation of coalescent tree to locus tree
    nsamples_coal     -- number of times to sample coal times t^G
    locus_tree        -- locus tree (has dup-loss)
    locus_recon       -- reconciliation of locus tree to species tree
    nsamples_locus    -- number of times to sample the locus tree times t^L
    daughters         -- daughter nodes

    rates, freqs, alphas  -- optimization parameters
    threads, seed, eps    -- forwarded unchanged to
                             prob_gene_species_alignment_recon
    info              -- optional dict, forwarded unchanged to
                         prob_gene_species_alignment_recon

    Returns the sum of the three log terms above.

    Note: Adapted from dlcoal.prob_dlcoal_recon_topology(...) [in __init.py]
    """

    # duploss probability: P(T^L, R^L | S, theta)
    locus_events = phylo.label_events(locus_tree, locus_recon)
    dl_prob = duploss.prob_dup_loss(locus_tree, stree, locus_recon,
                                    locus_events, duprate, lossrate)

    # daughters probability: P(daughters | T^L, R^L, S)
    # each duplication contributes a factor of 1/2
    dups = phylo.count_dup(locus_tree, locus_events)
    daughter_prob = dups * log(.5)

    # double integral; the caller's threads/seed/eps/info are passed through
    # (they were previously replaced by hard-coded defaults)
    double_integral = prob_gene_species_alignment_recon(alnfile,
                                                        partfile,
                                                        stree,
                                                        popsizes,
                                                        duprate,
                                                        lossrate,
                                                        subrate,
                                                        beta,
                                                        pretime,
                                                        premean,
                                                        coal_tree,
                                                        coal_recon,
                                                        nsamples_coal,
                                                        locus_tree,
                                                        locus_recon,
                                                        nsamples_locus,
                                                        daughters,
                                                        rates,
                                                        freqs,
                                                        alphas,
                                                        threads=threads,
                                                        seed=seed,
                                                        eps=eps,
                                                        info=info)

    return dl_prob + daughter_prob + double_integral
Example #27
0
def prob_gene_species_alignment_recon(alnfile,
                                      partfile,
                                      stree,
                                      popsizes,
                                      duprate,
                                      lossrate,
                                      subrate,
                                      beta,
                                      pretime,
                                      premean,
                                      coal_tree,
                                      coal_recon,
                                      nsamples_coal,
                                      locus_tree,
                                      locus_recon,
                                      nsamples_locus,
                                      daughters,
                                      rates,
                                      freqs,
                                      alphas,
                                      threads=1,
                                      seed=ALIGNMENT_SEED,
                                      eps=0.1,
                                      info=None):
    """
    Evaluate terms that depend on T^G and R^G.

    That is, fix T^L, R^L, and daughters and evaluate (in log space, by
    Monte Carlo sampling) the double integral:
    int int P(t^L | T^L, R^L, S, theta) * P(T^G, R^G, t^G | t^L, T^L, daughters, R^L, theta) * P(A | T^G, t^G) dt^L dt^G

    This is the probability we used in the searching process.

    alnfile           -- alignment file
    partfile          -- partition file
    stree             -- species tree
    popsizes          -- population sizes in species tree
    duprate           -- duplication rate
    lossrate          -- loss rate
    subrate           -- substitution rate
    beta              -- regularization parameter
    pretime           -- starting time before species tree
    premean           -- mean starting time before species tree

    coal_tree         -- coalescent tree
    coal_recon        -- reconciliation of coalescent tree to locus tree
    nsamples_coal     -- number of times to sample coal times t^G
    locus_tree        -- locus tree (has dup-loss)
    locus_recon       -- reconciliation of locus tree to species tree
    nsamples_locus    -- number of times to sample the locus tree times t^L
    daughters         -- daughter nodes

    rates, freqs, alphas  -- optimization parameters
    threads, seed, eps    -- forwarded to prob_alignment
    info              -- optional dict; if given, the keys "topology_prob",
                         "alignment_prob" and "coal_prob" are overwritten
                         after every t^L sample

    Returns the log of the Monte Carlo estimate (0.0 if nsamples_locus is 0,
    because the loop body never runs).

    Side effects: branch lengths of locus_tree and coal_tree are overwritten
    by the sampled times.
    """

    locus_events = phylo.label_events(locus_tree, locus_recon)

    # one log probability per sampled t^L
    double_integral_list = []
    double_integral = 0.0
    util.tic("recon prob")
    for i in xrange(nsamples_locus):

        # sample t^L, the unit should be in myr
        locus_times = duploss.sample_dup_times(locus_tree,
                                               stree,
                                               locus_recon,
                                               duprate,
                                               lossrate,
                                               pretime,
                                               premean,
                                               events=locus_events)
        treelib.set_dists_from_timestamps(locus_tree, locus_times)

        # calculate P(T^G, R^G | T^L, t^L, daughters, theta)
        topology_prob = prob_locus_coal_recon_topology(coal_tree, coal_recon,
                                                       locus_tree, popsizes,
                                                       daughters)

        # for a fixed t^L, compute coal_prob by sampling t^G and averaging
        # the probability of observing the alignment (Monte Carlo integration)
        coal_prob = 0.0
        alignment_prob_MonteCarlo = 0.0
        alignment_prob_list = []

        # check probability of lineage counts for this locus tree; the
        # counts are the same for every branch inspected, so they are
        # computed once per sample rather than once per locus node
        lineages = coal.count_lineages_per_branch(coal_tree, coal_recon,
                                                  locus_tree)
        zero_lineage_prob = False
        for lnode in locus_tree:
            bottom_num, top_num = lineages[lnode]
            if lnode.parent:
                T = lnode.dist
            else:
                # the root branch is treated as infinitely long
                T = util.INF

            # a single zero-probability branch is enough to rule out
            # this sample, so stop at the first one
            if prob_coal_counts(bottom_num, top_num, T, popsizes) == 0.0:
                zero_lineage_prob = True
                break

        if zero_lineage_prob:
            # if any lineage probability is zero, coal_prob is zero
            coal_prob = -float("inf")

        else:
            for j in xrange(nsamples_coal):

                # sample coal times and set the coal_tree accordingly
                # locus tree branch lengths are in myr
                # make sure the input popsizes are scaled to fit the time
                # unit (typically myr)
                try:
                    sample_coal_times_topology(coal_tree, coal_recon,
                                               locus_tree, popsizes)
                except (ZeroDivisionError, ValueError):
                    # bad sample: contributes nothing to the sum, but is
                    # still counted in the nsamples_coal denominator below
                    util.log("bad sample")
                    continue

                # convert branch lengths from myr to sub/site
                for node in coal_tree:
                    node.dist *= subrate

                # (log) probability of observing the alignment, scaled by
                # the regularization parameter beta
                alignment_prob = beta * prob_alignment(alnfile,
                                                       partfile,
                                                       coal_tree,
                                                       rates,
                                                       freqs,
                                                       alphas,
                                                       threads=threads,
                                                       seed=seed,
                                                       eps=eps)
                alignment_prob_list.append(alignment_prob)

            # log_sum_exp exponentiates the log probabilities of observing
            # the alignment, adds them up, and takes the log again
            if len(alignment_prob_list) == 0:
                # all bad samples
                alignment_prob_MonteCarlo = -util.INF
            else:
                alignment_prob_MonteCarlo = log_sum_exp(
                    alignment_prob_list) - log(nsamples_coal)

            # P(T^G, R^G | T^L, t^L, daughters, theta) * $ P(t^G | ~) * P(A | T^G,t^G) dtG
            # coal_prob is a log probability
            coal_prob = topology_prob + alignment_prob_MonteCarlo

        # add coal probability to the list for further processing
        double_integral_list.append(coal_prob)

        # running estimate; only the value computed in the final iteration
        # averages over all nsamples_locus samples
        double_integral = log_sum_exp(double_integral_list) - log(
            nsamples_locus)

        # logging info
        if info is not None:
            info["topology_prob"] = topology_prob  # one sample of t^L
            # one sample of t^L, averaged over t^G
            info["alignment_prob"] = alignment_prob_MonteCarlo
            info["coal_prob"] = double_integral
    util.toc()
    return double_integral
Example #28
0
def draw_tree(tree, labels=None, xscale=100, yscale=20, canvas=None,
              leafPadding=10, leafFunc=lambda x: str(x.name),
              labelOffset=None, fontSize=10, labelSize=None,
              minlen=1, maxlen=util.INF, filename=sys.stdout,
              rmargin=150, lmargin=10, tmargin=0, bmargin=None,
              colormap=None,
              stree=None,
              layout=None,
              gene2species=None,
              lossColor=(0, 0, 1),
              dupColor=(1, 0, 0),
              eventSize=4,
              legendScale=False, autoclose=None,
              extendRoot=True, labelLeaves=True, drawHoriz=True, nodeSize=0):
    """
    Draw a tree onto an SVG canvas and return the canvas.

    tree          -- tree to draw
    labels        -- optional mapping from node name to branch label text
                     (may contain newlines); defaults to no labels
    xscale, yscale, minlen, maxlen
                  -- passed to treelib.layout_tree when no layout is given
    canvas        -- existing canvas to draw on; if None, a new svg.Svg is
                     opened on `filename` and sized from the layout
    leafFunc      -- function producing the text drawn next to a leaf
    labelOffset   -- vertical offset of branch labels (default -1)
    labelSize     -- branch label size (default .7 * fontSize)
    bmargin       -- bottom margin (default yscale)
    colormap      -- callable applied to the tree to set node colors; if None
                     every node is colored black
    stree, gene2species
                  -- if both are given, the tree is reconciled and
                     duplication/loss events are drawn
    layout        -- precomputed node -> (x, y) coordinates
    legendScale   -- False for no scale legend, True to choose a scale length
                     automatically; any other truthy value is used as the
                     scale length itself
    autoclose     -- whether to end the SVG document before returning;
                     defaults to True when the canvas is created here and
                     False when a canvas was passed in

    Side effect: sets `node.color` on every node when colormap is None.
    """

    # set defaults
    fontRatio = 8. / 11.

    if labels is None:
        labels = {}

    if labelSize is None:
        labelSize = .7 * fontSize

    if labelOffset is None:
        labelOffset = -1

    if bmargin is None:
        bmargin = yscale

    # a tree without branch lengths has no meaningful scale; give every
    # branch the width of one x unit instead
    if sum(x.dist for x in tree.nodes.values()) == 0:
        legendScale = False
        minlen = xscale

    if colormap is None:
        for node in tree:
            node.color = (0, 0, 0)
    else:
        colormap(tree)

    show_events = stree and gene2species
    if show_events:
        recon = phylo.reconcile(tree, stree, gene2species)
        events = phylo.label_events(tree, recon)
        losses = phylo.find_loss(tree, stree, recon)
    else:
        events = None
        losses = None

    # labels and events are positioned on horizontal branches
    if len(labels) > 0 or show_events:
        drawHoriz = True

    # layout tree
    if layout is None:
        coords = treelib.layout_tree(tree, xscale, yscale, minlen, maxlen)
    else:
        coords = layout

    xcoords, ycoords = zip(* coords.values())
    maxwidth = max(xcoords)
    maxheight = max(ycoords) + labelOffset

    # initialize canvas
    if canvas is None:
        canvas = svg.Svg(util.open_stream(filename, "w"))
        width = int(rmargin + maxwidth + lmargin)
        height = int(tmargin + maxheight + bmargin)

        canvas.beginSvg(width, height)

        if autoclose is None:
            autoclose = True
    else:
        if autoclose is None:
            autoclose = False

    # draw tree
    def walk(node):
        x, y = coords[node]
        if node.parent:
            parentx, parenty = coords[node.parent]
        else:
            if extendRoot:
                parentx, parenty = 0, y
            else:
                parentx, parenty = x, y     # e.g. no branch

        # draw branch
        if drawHoriz:
            canvas.line(parentx, y, x, y, color=node.color)
        else:
            canvas.line(parentx, parenty, x, y, color=node.color)

        # draw branch labels
        if node.name in labels:
            branchlen = x - parentx
            lines = str(labels[node.name]).split("\n")
            labelwidth = max(map(len, lines))
            # estimated text width, capped so it fits on the branch
            labellen = min(labelwidth * fontRatio * fontSize,
                           max(int(branchlen-1), 0))

            # stack the lines upward from the branch, last line nearest it
            for i, line in enumerate(lines):
                canvas.text(line,
                            parentx + (branchlen - labellen)/2.,
                            y + labelOffset
                            +(-len(lines)+1+i)*(labelSize+1),
                            labelSize)

        # draw nodes
        if nodeSize > 0:
            canvas.circle(x, y, nodeSize, strokeColor=svg.null,
                          fillColor=node.color)

        # draw leaf labels or recur
        if node.is_leaf():
            if labelLeaves:
                canvas.text(leafFunc(node),
                            x + leafPadding, y+fontSize/2., fontSize,
                            fillColor=node.color)
        else:
            if drawHoriz:
                # draw vertical part of branch
                top = coords[node.children[0]][1]
                bot = coords[node.children[-1]][1]
                canvas.line(x, top, x, bot, color=node.color)

            # draw children
            for child in node.children:
                walk(child)

    canvas.beginTransform(("translate", lmargin, tmargin))
    walk(tree.root)

    if show_events:
        draw_events(canvas, tree, coords, events, losses,
                    lossColor=lossColor,
                    dupColor=dupColor,
                    size=eventSize)
    canvas.endTransform()

    # draw legend
    if legendScale:
        # `== True` (not `is True`) is kept so that 1 / 1.0 still select
        # the automatic scale, as before
        if legendScale == True:
            # automatically choose a scale: largest power of ten that
            # fits in the drawn width
            length = maxwidth / float(xscale)
            order = math.floor(math.log10(length))
            length = 10 ** order
        else:
            # explicit scale length given by the caller; previously this
            # case left `length` unbound and raised NameError
            length = legendScale

        drawScale(lmargin, tmargin + maxheight + bmargin - fontSize,
                  length, xscale, fontSize, canvas=canvas)

    if autoclose:
        canvas.endSvg()

    return canvas
# read the species tree and the tree to draw from the configured paths
stree = treelib1.read_tree(conf.stree)
tree = treelib1.read_tree(conf.tree)

# name table handed to draw_tree as `snames`; None when not configured
snames = None
if conf.names:
    snames = dict(util.read_delim(conf.names))

# a branch reconciliation is either read directly or converted from a
# reconciliation/events pair, which in turn is either read or computed
if conf.brecon:
    brecon = phylo.read_brecon(conf.brecon, tree, stree)
else:
    if conf.recon:
        recon, events = phylo.read_recon_events(conf.recon, tree, stree)
    else:
        gene2species = phylo.read_gene2species(conf.smap)
        recon = phylo.reconcile(tree, stree, gene2species)
        events = phylo.label_events(tree, recon)
    brecon = phylo.recon_events2brecon(recon, events)

phylo.add_implied_spec_nodes_brecon(tree, brecon)

transsvg.draw_tree(tree, brecon, stree,
                   filename=conf.output, snames=snames)


Example #30
0
def sample_dup_times(tree, stree, recon, birth, death,
                     pretime=None, premean=None, events=None):
    """
    Sample duplication times for a gene tree in the dup-loss model

    tree     -- gene tree
    stree    -- species tree
    recon    -- reconciliation mapping gene tree nodes to species tree nodes
    birth    -- birth rate, passed to sample_dup_times_subtree
    death    -- death rate, passed to sample_dup_times_subtree
    pretime  -- time span before the species tree root; if None it is
                sampled from an exponential distribution with mean premean
    premean  -- mean of the pre-root time; required only when the gene tree
                root is a non-speciation node mapped to the species root and
                pretime is None
    events   -- event labels for gene tree nodes; computed with
                phylo.label_events if None

    Returns the `times` dict. Speciation nodes get the timestamp of their
    species node and "gene" nodes get 0.0; the dict is also handed to
    sample_dup_times_subtree, which presumably fills in the duplication
    nodes.

    Raises Exception if a pre-root time is needed but neither pretime nor
    premean is given.
    """

    if events is None:
        events = phylo.label_events(tree, recon)

    # get species tree timestamps
    stimes = treelib.get_tree_timestamps(stree)

    # init timestamps for gene tree
    times = {}

    # set pretimes
    if events[tree.root] != "spec":
        if recon[tree.root] != stree.root:
            # tree root is a dup within species tree
            snode = recon[tree.root]
            start_time = stimes[snode.parent]
            time_span = start_time - stimes[snode]
        else:
            # tree root is a pre-spec dup
            if pretime is None:
                if premean is None:
                    raise Exception("must set pre-mean")

                # float division: with an integer premean, Python 2 would
                # truncate 1/premean to 0 and expovariate would fail;
                # resample until the pre-root time is strictly positive
                pretime = 0.0
                while pretime == 0.0:
                    pretime = random.expovariate(1.0 / premean)
            start_time = stimes[stree.root] + pretime
            time_span = pretime

        sample_dup_times_subtree(times, start_time, time_span, tree.root,
                                 recon, events,
                                 stree, birth, death)

    # set times
    for node in tree.preorder():
        if events[node] == "spec":
            # set speciation time
            times[node] = stimes[recon[node]]

        elif (events[node] == "dup" and
              node.parent is not None and
              recon[node] != recon[node.parent]):
            # set duplication times within duplication subtree
            # node is duproot: the first dup on its species branch, so the
            # available span is the whole length of that branch
            snode = recon[node]
            start_time = stimes[snode.parent]
            time_span = start_time - stimes[snode]
            sample_dup_times_subtree(times, start_time, time_span,
                                     node,
                                     recon, events,
                                     stree, birth, death)
        elif events[node] == "gene":
            times[node] = 0.0

    return times
Example #31
0
def sample_dup_times(tree, stree, recon, birth, death, pretime=None, premean=None, events=None):
    """
    Sample duplication times for a gene tree in the dup-loss model

    NOTE: Implied speciation nodes must be present

    tree     -- gene tree
    stree    -- species tree
    recon    -- reconciliation mapping gene tree nodes to species tree nodes
    birth    -- birth rate, passed to sample_dup_times_subtree
    death    -- death rate, passed to sample_dup_times_subtree
    pretime  -- time span before the species tree root; if None it is
                sampled from an exponential distribution with mean premean
    premean  -- mean of the pre-root time; required only when the gene tree
                root is a non-speciation node mapped to the species root and
                pretime is None
    events   -- event labels for gene tree nodes; computed with
                phylo.label_events if None

    Returns the `times` dict. Speciation nodes get the timestamp of their
    species node and "gene" nodes get 0.0; the dict is also handed to
    sample_dup_times_subtree, which presumably fills in the duplication
    nodes.

    Raises Exception if a pre-root time is needed but neither pretime nor
    premean is given.
    """

    if events is None:
        events = phylo.label_events(tree, recon)

    # get species tree timestamps
    stimes = treelib.get_tree_timestamps(stree)

    # init timestamps for gene tree
    times = {}

    # set pretimes
    if events[tree.root] != "spec":
        if recon[tree.root] != stree.root:
            # tree root is a dup within species tree
            snode = recon[tree.root]
            start_time = stimes[snode.parent]
            time_span = snode.dist
        else:
            # tree root is a pre-spec dup
            if pretime is None:
                if premean is None:
                    raise Exception("must set pre-mean")

                # float division: with an integer premean, Python 2 would
                # truncate 1/premean to 0 and expovariate would fail;
                # resample until the pre-root time is strictly positive
                pretime = 0.0
                while pretime == 0.0:
                    pretime = random.expovariate(1.0 / premean)
            start_time = stimes[stree.root] + pretime
            time_span = pretime

        sample_dup_times_subtree(times, start_time, time_span, tree.root, recon, events, stree, birth, death)

    # set times
    for node in tree.preorder():
        if events[node] == "spec":
            # set speciation time
            start_time = times[node] = stimes[recon[node]]
            if node.parent:
                if times[node] > times[node.parent]:
                    # diagnostic only: a node timestamped above its parent
                    # is reported but not treated as an error
                    print("bad %s" % (node.name,))

            # set duplication times within duplication subtree: each dup
            # child starts at this speciation and may use the full length
            # of its species branch
            for duproot in node.children:
                if events[duproot] == "dup":
                    snode = recon[duproot]
                    time_span = snode.dist

                    sample_dup_times_subtree(times, start_time, time_span, duproot, recon, events, stree, birth, death)
        elif events[node] == "gene":
            times[node] = 0.0

    return times
Example #32
0
def recon_to_labeledrecon(coal_tree, recon, stree, gene2species,
                          name_internal="n", locus_mpr=True):
    """Convert from DLCoal to DLCpar reconciliation model

    If locus_mpr is set (default), use MPR from locus_tree to stree.

    coal_tree     -- coalescent tree; it is copied, not modified
    recon         -- object with attributes coal_recon, locus_tree,
                     locus_recon and daughters
    stree         -- species tree
    gene2species  -- gene-to-species mapping, used by phylo.reconcile when
                     locus_mpr is set
    name_internal -- passed to common.rename_nodes when renaming the
                     internal nodes of the new gene tree
    locus_mpr     -- if true, recompute the locus reconciliation with
                     phylo.reconcile and keep only those daughters whose
                     parent is labeled "dup" under it; otherwise use
                     recon.locus_recon and recon.daughters as given

    Returns (gene_tree, LabeledRecon(species_map, locus_map, order)) where
    gene_tree is the copy of coal_tree with implied nodes added.
    """

    gene_tree = coal_tree.copy()
    coal_recon = recon.coal_recon
    locus_tree = recon.locus_tree
    if not locus_mpr:
        locus_recon = recon.locus_recon
        daughters = recon.daughters
    else:
        locus_recon = phylo.reconcile(locus_tree, stree, gene2species)
        locus_events = phylo.label_events(locus_tree, locus_recon)
        # a set: `daughters` is only used for membership tests below
        daughters = set(node for node in recon.daughters
                        if locus_events[node.parent] == "dup")

    #========================================
    # find species map

    # find species tree subtree
    substree = treelib.subtree(stree, locus_recon[coal_recon[coal_tree.root]])

    # find species map (gene node -> coal node -> locus node -> species node)
    species_map = {}
    for node in gene_tree:
        cnode = coal_tree.nodes[node.name]
        lnode = coal_recon[cnode]
        snode = locus_recon[lnode]
        species_map[node] = substree[snode.name]

    # add implied speciation and delay nodes to gene tree
    events = phylo.label_events(gene_tree, species_map)
    add_implied_nodes(gene_tree, substree, species_map, events)

    # rename internal nodes
    common.rename_nodes(gene_tree, name_internal)

    #========================================
    # helper functions

    def walk_up(node):
        # nearest ancestor-or-self whose name exists in the coal tree
        if node.name in coal_tree.nodes:
            return coal_tree.nodes[node.name]
        return walk_up(node.parent)

    def walk_down(node):
        # nearest descendant-or-self whose name exists in the coal tree;
        # nodes missing from the coal tree must have exactly one child
        if node.name in coal_tree.nodes:
            return coal_tree.nodes[node.name]
        assert len(node.children) == 1, (node.name, node.children)
        return walk_down(node.children[0])

    #========================================
    # find locus map

    # label loci in locus tree
    loci = {}
    next_locus = 1
    # keep track of duplication ages (measured as dist from leaf since root dist may differ in coal and locus trees)
    locus_times = treelib.get_tree_ages(locus_tree)
    dup_times = {}
    dup_snodes = {}
    for lnode in locus_tree.preorder():
        if not lnode.parent:            # root
            loci[lnode] = next_locus
        elif lnode in daughters:        # duplication
            next_locus += 1
            loci[lnode] = next_locus
            dup_times[next_locus] = locus_times[lnode.parent]
            dup_snodes[next_locus] = locus_recon[lnode.parent]
        else:                           # regular node
            loci[lnode] = loci[lnode.parent]

    # label loci in gene tree
    locus_map = {}
    for node in gene_tree:
        if node.name in coal_tree.nodes:
            # node in coal tree
            cnode = coal_tree.nodes[node.name]
            lnode = coal_recon[cnode]
            locus_map[node] = loci[lnode]
        else:
            # node not in coal tree, so use either parent or child locus
            cnode_up = walk_up(node)
            lnode_up = coal_recon[cnode_up]
            loci_up = loci[lnode_up]

            cnode_down = walk_down(node)
            lnode_down = coal_recon[cnode_down]
            loci_down = loci[lnode_down]

            if loci_up == loci_down:
                # parent and child locus match
                locus_map[node] = loci_up
            else:
                # determine whether to use parent or child locus:
                # use the child locus if this node's species is the
                # species of the duplication or one of its descendants
                snode = species_map[node]
                dup_snode = dup_snodes[loci_down]
                if (snode.name == dup_snode.name) or (snode.name in dup_snode.descendant_names()):
                    locus_map[node] = loci_down
                else:
                    locus_map[node] = loci_up

    #========================================
    # find order

    # find loci that give rise to new loci in each sbranch
    parent_loci = set()
    for node in gene_tree:
        if node.parent:
            locus = locus_map[node]
            plocus = locus_map[node.parent]

            if locus != plocus:
                snode = species_map[node]
                parent_loci.add((snode, plocus))

    # find order (locus tree and coal tree must use same timescale)
    order = {}
    for node in gene_tree:
        if node.parent:
            snode = species_map[node]
            plocus = locus_map[node.parent]

            if (snode, plocus) in parent_loci:
                order.setdefault(snode, {})
                order[snode].setdefault(plocus, [])
                order[snode][plocus].append(node)

    # find coalescent/duplication times (= negative age) and depths
    coal_times = treelib.get_tree_ages(coal_tree)
    depths = get_tree_depths(gene_tree, distfunc=lambda node: 1)
    def get_time(node):
        # only called on nodes stored in `order`, which all have a parent
        if locus_map[node.parent] != locus_map[node]:
            # duplication
            return -dup_times[locus_map[node]], depths[node]
        else:
            # walk up to the nearest node in the coal tree
            # if the node was added (due to spec or dup), it has a single child
            # so it can be placed directly after its parent without affecting the extra lineage count
            if node.name in coal_tree.nodes:
                cnode = coal_tree.nodes[node.name]
            else:
                cnode = walk_up(node)
            return -coal_times[cnode], depths[node]

    # sort by node times
    # 1) larger age (smaller dist from root) are earlier in sort
    # 2) if equal dist, then smaller depths are earlier in sort
    for per_locus in order.values():
        for lst in per_locus.values():
            lst.sort(key=get_time)

    #========================================
    # put everything together

    return gene_tree, LabeledRecon(species_map, locus_map, order)
Example #33
0
def dup_loss_topology_prior(tree,
                            stree,
                            recon,
                            birth,
                            death,
                            maxdoom=20,
                            events=None):
    """
    Returns the log prior of a gene tree topology according to dup-loss model

    tree     -- gene tree; implied speciation nodes are added to it while
                the prior is computed and single-child nodes are removed
                again before returning
    stree    -- species tree
    recon    -- reconciliation mapping gene tree nodes to species tree nodes
    birth    -- birth rate
    death    -- death rate
    maxdoom  -- upper bound (inclusive) of the summation index used with
                the doom table
    events   -- event labels for gene tree nodes; computed with
                phylo.label_events if None
    """
    def gene2species(gene):
        return recon[tree.nodes[gene]].name

    if events is None:
        events = phylo.label_events(tree, recon)

    # remember the real leaves before implied nodes are inserted
    original_leaves = set(tree.leaves())
    phylo.add_implied_spec_nodes(tree, stree, recon, events)

    # only the species-node lookup of the ptree is used below
    _, _, snodelookup = spidir.make_ptree(stree)

    # get doomtable
    doomtable = calc_doom_table(stree, birth, death, maxdoom)

    logp = 0.0
    for node in tree:
        if events[node] != "spec":
            continue

        # one term per species branch below this speciation
        for schild in recon[node].children:
            on_branch = [c for c in node.children if recon[c] == schild]

            if on_branch:
                subroot = on_branch[0]
                subleaves = get_sub_tree(subroot, schild, recon, events)
                nhist = birthdeath.num_topology_histories(subroot, subleaves)
                s = len(subleaves)
                thist = (stats.factorial(s) * stats.factorial(s - 1)
                         / 2**(s - 1))

                # internal: the subtree ends at no real leaf of the gene tree
                is_internal = len(set(subleaves) & original_leaves) == 0
                logp += log(
                    num_redundant_topology(subroot, gene2species,
                                           subleaves, is_internal))
            else:
                # no gene lineage survives on this species branch
                nhist = 1.0
                thist = 1.0
                s = 0

            # sum over i = 0..maxdoom, weighting each term by the i-th
            # power of the exponentiated doom table entry
            t = sum(
                stats.choose(s + i, i) *
                birthdeath.prob_birth_death1(s + i, schild.dist,
                                             birth, death) *
                exp(doomtable[snodelookup[schild]])**i
                for i in range(maxdoom + 1))

            logp += log(nhist) - log(thist) + log(t)

    # correct for renumbering
    logp -= log(num_redundant_topology(tree.root, gene2species))

    # undo the temporary tree modification
    treelib.remove_single_children(tree)

    return logp
Example #34
0
def sample_coal_times_topology(coal_tree, coal_recon, locus_tree, popsizes):
    """
    Sample the coalescent times for a topology by doing in a two-step way.
    Sample the labeled history consistent with the topology and then sample
    branch length using the labeled history. This sampling process captures
    the conditional probability:
    P(t^G | T^G, R^G, T^L, t^L, N^L)

    coal_tree    -- coalescent tree
    coal_recon   -- reconciliation of coalescent tree to locus tree
    locus_tree   -- locus tree
    popsizes     -- population sizes of the locus tree (passed unchanged to
                    sample_coal_times_one_branch for every locus branch)

    The function has no return value.  Every node of coal_tree first has its
    dist reset to 0.0; the sampled history and times for each locus branch
    are then handed to set_coal_tree_subtree (presumably that helper writes
    the new branch lengths onto coal_tree -- confirm against its definition).

    Implied nodes are temporarily added to coal_tree / coal_recon while
    sampling and are always removed again, including when an exception
    propagates out of any sampling step.
    """
    coal_events = phylo.label_events(coal_tree, coal_recon)

    # must be computed before implied nodes are inserted into coal_tree
    lineages = coal.count_lineages_per_branch(coal_tree, coal_recon,
                                              locus_tree)

    added_spec, added_dup, added_delay = reconlib.add_implied_nodes(
        coal_tree, locus_tree, coal_recon, coal_events, delay=True)
    added = added_spec + added_dup + added_delay

    try:
        subtrees = reconlib.factor_tree(coal_tree, locus_tree, coal_recon,
                                        coal_events)

        for coal_node in coal_tree.preorder():
            coal_node.dist = 0.0

        for lnode in locus_tree.preorder():
            subtree = subtrees[lnode]

            # enumerate all labeled histories
            all_labeled_histories_subtree = list(
                reconlib.enum_labeled_histories_subtree(coal_tree, subtree))

            # sample a labeled history (uniformly among the enumerated ones)
            sample_labeled_history = random.sample(
                all_labeled_histories_subtree, 1)[0]

            bottom_num, top_num = lineages[lnode]

            # the root branch has no parent, so it is treated as unbounded
            if lnode.parent:
                T = lnode.dist
            else:
                T = util.INF

            # sample coalescent times for a branch
            coal_times_subtree = sample_coal_times_one_branch(
                bottom_num, top_num, popsizes, T)

            set_coal_tree_subtree(sample_labeled_history, coal_times_subtree,
                                  coal_tree, subtree)
    finally:
        # remove the implied nodes, on success and on failure alike, so the
        # caller's tree and reconciliation are never left with them
        reconlib.remove_implied_nodes(coal_tree, added, coal_recon,
                                      coal_events)
Example #35
0
def draw_tree(tree,
              stree,
              extra,
              xscale=100,
              yscale=100,
              leaf_padding=10,
              label_size=None,
              label_offset=None,
              font_size=12,
              stree_font_size=20,
              canvas=None,
              autoclose=True,
              rmargin=10,
              lmargin=10,
              tmargin=0,
              bmargin=0,
              stree_color=(.4, .4, 1),
              snode_color=(.2, .2, .7),
              event_size=10,
              rootlen=None,
              stree_width=.8,
              filename=sys.stdout,
              labels=None,
              slabels=None):
    """Draw a gene tree embedded inside a species tree on an SVG canvas.

    tree            -- gene tree to draw
    stree           -- species tree that the gene tree is laid out within
    extra           -- dict; extra["species_map"] maps gene tree nodes to
                       species nodes and extra["locus_map"] maps gene tree
                       nodes to loci (only these two keys are read)
    xscale, yscale  -- scales handed to treelib.layout_tree for the species
                       tree; yscale also sizes the vertical band per branch
    leaf_padding    -- horizontal gap before leaf labels
    label_size      -- font size of texts from `labels`
                       (default: .7 * font_size)
    label_offset    -- accepted for compatibility; not used
    font_size       -- font size of gene tree leaf names
    stree_font_size -- font size of species tree leaf names
    canvas          -- canvas to draw on; if None, a new svg.Svg is opened
                       on `filename`
    autoclose       -- if true, end the style and svg sections when done;
                       None means True for a newly created canvas and False
                       for a supplied one
    rmargin, lmargin, tmargin, bmargin
                    -- margins added to the size of a newly created canvas;
                       the drawing is translated by (lmargin, tmargin)
    stree_color, snode_color
                    -- colors forwarded to draw_stree; snode_color is also
                       used for species leaf names
    event_size      -- side length of the square drawn where a node's locus
                       differs from its parent's locus
    rootlen         -- length of the branch above the species root
                       (default: .1 * largest x coordinate of the layout)
    stree_width     -- fraction of yscale used to spread gene nodes within a
                       species branch
    filename        -- output stream or name for a newly created canvas
    labels          -- dict from gene node name to text drawn at that node
    slabels         -- species labels forwarded to draw_stree

    Returns the canvas that was drawn on.
    """

    recon = extra["species_map"]
    loci = extra["locus_map"]

    # setup color map (one color per distinct locus)
    all_loci = sorted(set(loci.values()))
    num_loci = len(all_loci)
    colormap = util.rainbow_color_map(low=0, high=num_loci - 1)
    locus_color = {}
    for ndx, locus in enumerate(all_loci):
        locus_color[locus] = colormap.get(ndx)

    # set defaults
    font_ratio = 8. / 11.

    if label_size is None:
        label_size = .7 * font_size

    snames = dict((x, x) for x in stree.leaf_names())

    if labels is None:
        labels = {}
    if slabels is None:
        slabels = {}

    # layout stree
    slayout = treelib.layout_tree(stree, xscale, yscale)

    if rootlen is None:
        rootlen = .1 * max(l[0] for l in slayout.values())

    # setup slayout
    # (the None key stands for the point above the species root; every
    # entry, including that one, is then shifted by the same offset)
    x, y = slayout[stree.root]
    slayout[None] = (x - rootlen, y)
    # iterate over a snapshot because values are reassigned inside the loop
    for node, (x, y) in list(slayout.items()):
        slayout[node] = (x + rootlen, y - .5 * yscale)

    # layout tree
    ylists = defaultdict(list)
    yorders = {}

    # layout speciations and genes (y)
    events = phylo.label_events(tree, recon)
    for node in tree.preorder():
        snode = recon[node]
        event = events[node]
        if event == "spec" or event == "gene":
            yorders[node] = len(ylists[snode])
            ylists[snode].append(node)

    # layout internal nodes (y): average of the children's orders
    for node in tree.postorder():
        event = events[node]
        if event != "spec" and event != "gene":
            v = [yorders[child] for child in node.children]
            yorders[node] = stats.mean(v)

    # layout node (x): distance (in nodes) from the bottom of the branch
    xorders = {}
    xmax = defaultdict(int)
    for node in tree.postorder():
        snode = recon[node]
        event = events[node]
        if event == "spec" or event == "gene":
            xorders[node] = 0
        else:
            v = [xorders[child] for child in node.children]
            xorders[node] = max(v) + 1
        xmax[snode] = max(xmax[snode], xorders[node])

    # setup layout
    layout = {None: slayout[None]}
    for node in tree:
        snode = recon[node]
        nx, ny = slayout[snode]
        px, py = slayout[snode.parent]

        # calc x
        frac = (xorders[node]) / float(xmax[snode] + 1)
        deltax = nx - px
        x = nx - frac * deltax

        # calc y (interpolate along the species branch, then offset within it)
        deltay = ny - py
        slope = deltay / float(deltax)
        deltax2 = x - px
        deltay2 = slope * deltax2
        offset = py + deltay2

        frac = (yorders[node] + 1) / float(max(len(ylists[snode]), 1) + 1)
        y = offset + (frac - .5) * stree_width * yscale

        layout[node] = (x, y)

    # layout label sizes
    max_label_size = max(len(x.name)
                         for x in tree.leaves()) * font_ratio * font_size
    max_slabel_size = max(
        len(x.name) for x in stree.leaves()) * font_ratio * stree_font_size

    xcoords, ycoords = zip(*slayout.values())
    maxwidth = max(xcoords) + max_label_size + max_slabel_size
    maxheight = max(ycoords) + .5 * yscale

    # initialize canvas
    if canvas is None:
        canvas = svg.Svg(util.open_stream(filename, "w"))
        width = int(rmargin + maxwidth + lmargin)
        height = int(tmargin + maxheight + bmargin)

        canvas.beginSvg(width, height)
        canvas.beginStyle("font-family: \"Sans\";")

        if autoclose is None:
            autoclose = True
    else:
        if autoclose is None:
            autoclose = False

    canvas.beginTransform(("translate", lmargin, tmargin))

    draw_stree(canvas,
               stree,
               slayout,
               yscale=yscale,
               stree_width=stree_width,
               stree_color=stree_color,
               snode_color=snode_color,
               slabels=slabels)

    # draw stree leaves
    for node in stree:
        x, y = slayout[node]
        if node.is_leaf():
            canvas.text(snames[node.name],
                        x + leaf_padding + max_label_size,
                        y + stree_font_size / 2.,
                        stree_font_size,
                        fillColor=snode_color)

    # draw tree (each branch is colored by the locus of its upper node)
    for node in tree:
        x, y = layout[node]
        px, py = layout[node.parent]

        if node.parent:
            color = locus_color[loci[node.parent]]
        else:
            color = locus_color[loci[tree.root]]

        canvas.line(x, y, px, py, color=color)

    # draw tree names
    for node in tree:
        x, y = layout[node]

        if node.is_leaf():
            canvas.text(node.name,
                        x + leaf_padding,
                        y + font_size / 2.,
                        font_size,
                        fillColor=(0, 0, 0))

        if node.name in labels:
            canvas.text(labels[node.name],
                        x,
                        y,
                        label_size,
                        fillColor=(0, 0, 0))

    # draw events (a square wherever the locus changes along a branch)
    for node in tree:
        if node.parent:
            locus = loci[node]
            plocus = loci[node.parent]

            if locus != plocus:
                color = locus_color[locus]
                x, y = layout[node]
                o = event_size / 2.0

                canvas.rect(x - o,
                            y - o,
                            event_size,
                            event_size,
                            fillColor=color,
                            strokeColor=color)

    canvas.endTransform()

    if autoclose:
        canvas.endStyle()
        canvas.endSvg()

    return canvas
Example #36
0
def sample_locus_tree_hem(stree, popsize, duprate, lossrate,
                          freq=1.0, freqdup=.05, freqloss=.05,
                          steptime=1e6, keep_extinct=False):

    """
    Sample a locus tree with birth-death and hemiplasy


    Runs a relaxed fixation assumption simulation on a species tree.
    Some simplifying assumptions are made for this version of the simulator:
      1) All branches of the species tree have the same population size
      2) All branches of the species tree have the same duplication rate
      3) All branches of the species tree have the same loss rate
      4) All branches of the species tree have the same duplication effect
      5) All branches of the species tree have the same loss effect
      6) All branches of the species tree have the same time between forced
           frequency changes
      7) There is a single allele at the root of the species tree.

    A duplication/loss effect is the change in frequency for either event.
    Appropriate default values for these effects may need to be determined.
    Future iterations should remove these assumptions by incorporating
    dictionaries to allow values for each branch.

    parameters:
    stree is the initial species tree; it may be mutated by the simulator
    popsize is the population size (assmpt. 1)
    freq is the allele frequency (assmpt. 7)
    duprate is the duplication rate (in events/myr/indiv(?); assmpt. 2)
    lossrate is the loss rate (in events/myr/indiv(?); assmpt. 3)
    freqdup is the duplication effect (assmpt. 4)
    freqloss is the loss effect (assmpt. 5)
    steptime is the maximum time between frequency changes (assmpt. 6)
    keep_extinct, if true, stores a copy of the locus tree made before
      extinct lineages are pruned under extras["full_locus_tree"]

    Note that the effective rates divide by freqdup and freqloss, so the
    simulation itself raises ZeroDivisionError if either one is 0.

    Returns the locus tree, as well as a dictionary of extra information.
    When both rates are 0 that dictionary holds "recon", "events" and
    "daughters"; otherwise it is whatever generate_extras(stree, gtree)
    returns.
    """

    ## sanity checks before running the simulator; may be removed or relaxed
    treelib.assert_tree(stree)
    assert popsize > 0
    assert 0.0 <= freq and freq <= 1.0
    assert duprate >= 0.0
    assert lossrate >= 0.0
    assert 0.0 <= freqdup and freqdup <= 1.0
    assert 0.0 <= freqloss and freqloss <= 1.0
    assert steptime > 0.0


    # special case: no duplications or losses
    if duprate == 0.0 and lossrate == 0.0:
        locus_tree = stree.copy()
        recon = phylo.reconcile(locus_tree, stree, lambda x: x)
        events = phylo.label_events(locus_tree, recon)

        return locus_tree, {"recon": recon,
                            "events": events,
                            "daughters": set()}


    def event_is_dup(duprate, fullrate):
        """Randomly decide whether an event is a dup (else it is a loss)."""
        return random.random() <= duprate / fullrate


    def sim_walk(gtree, snode, gparent, p,
                 s_walk_time=0.0, remaining_steptime=steptime,
                 daughter=False):
        """
        Grow one locus tree branch below gparent within species branch
        snode, starting at frequency p, and recurse at its terminal event.

        s_walk_time is the time already walked along snode's branch,
        remaining_steptime is the time left until the next forced frequency
        change, and daughter marks the new locus created by a duplication.

        eventlog is a log of events along the gtree branch.
        Each entry has the form
          (time_on_branch, event_type, frequency, species_node),

        where
           0.0 <= time_on_branch <= branch_node.dist

        event_type is one of
           {'extinction', 'frequency', 'speciation', 'duplication',
            'loss', 'root', 'gene', 'daughter'},

        where 'root' is a unique event not added during the sim_walk process

        frequency is the branch frequency at the event time

        species_node is the name of the node of the species tree branch in
        which the event occurs
        """

        # create new node
        gnode = treelib.TreeNode(gtree.new_name())
        gtree.add_child(gparent, gnode)
        gnode.data = {"freq": p,
                      "log": []}
        eventlog = gnode.data["log"]
        g_walk_time = 0.0
        if daughter:
            eventlog.append((0.0, 'daughter', freqdup, snode.name))


        # grow this branch, determine next event
        event = None
        while True:
            if p <= 0.0:
                event = "extinct"
                break

            # forget the previous iteration's event; otherwise a plain
            # frequency step that follows a loss would be processed as
            # another loss (and the step timer would never be reset)
            event = None

            # determine remaining time
            remaining_s_dist = snode.dist - s_walk_time
            remaining_time = min(remaining_steptime, remaining_s_dist)

            # sample next dup/loss event
            eff_duprate = duprate * p / freqdup
            eff_lossrate = lossrate * p / freqloss
            eff_bothrate = eff_duprate + eff_lossrate
            event_time = stats.exponentialvariate(eff_bothrate)

            # advance times
            time_delta = min(event_time, remaining_time)
            s_walk_time += time_delta
            g_walk_time += time_delta

            # sample new frequency
            p = coal.sample_freq_CDF(p, popsize, time_delta)

            # determine event
            if event_time < remaining_time:
                # dup/loss occurs
                if event_is_dup(eff_duprate, eff_bothrate):
                    # dup, stop growing
                    event = "dup"
                    break
                else:
                    # loss, continue growing
                    event = "loss"

            else:
                if remaining_s_dist < remaining_steptime:
                    # we are at a speciation, stop growing
                    event = "spec"
                    break

            # process step
            if event == "loss":
                # LOSS EVENT
                p = max(p - freqloss, 0.0)
                remaining_steptime -= time_delta
                eventlog.append((g_walk_time, 'loss', p, snode.name))
            else:
                # NEXT TIME STEP
                remaining_steptime = steptime
                eventlog.append((g_walk_time, 'frequency', p, snode.name))


        # process event
        if event == "extinct":
            # EXTINCTION EVENT (p <= 0)
            gnode.dist = g_walk_time
            gnode.data['freq'] = 0.0
            eventlog.append((g_walk_time, 'extinction', 0.0, snode.name))


        elif event == "spec":
            # SPECIATION EVENT
            gnode.dist = g_walk_time
            gnode.data['freq'] = p

            # add speciation event to event log and
            # continue into each child species branch
            if snode.is_leaf():
                eventlog.append((g_walk_time, 'gene', p, snode.name))
            else:
                eventlog.append((g_walk_time, 'speciation', p, snode.name))
                for schild in snode.children:
                    sim_walk(gtree, schild, gnode, p)


        elif event == "dup":
            # DUPLICATION EVENT
            gnode.dist = g_walk_time
            gnode.data['freq'] = p
            eventlog.append((g_walk_time, 'duplication', p, snode.name))

            # recurse on mother
            sim_walk(gtree, snode, gnode, p,
                     s_walk_time=s_walk_time,
                     remaining_steptime=remaining_steptime)

            # recurse on daughter (starts at the duplication effect)
            sim_walk(gtree, snode, gnode, freqdup,
                     s_walk_time=s_walk_time,
                     remaining_steptime=remaining_steptime,
                     daughter=True)

        else:
            raise Exception("unknown event '%s'" % event)


    # create new gene tree and simulate its evolution
    gtree = treelib.Tree()
    gtree.make_root()
    gtree.root.dist = 0.0
    gtree.root.data['freq'] = freq
    gtree.root.data['log'] = [(0.0, 'speciation', freq, stree.root.name)]

    # simulate locus tree (one walk per child of the species root)
    sim_walk(gtree, stree.root.children[0], gtree.root, freq)
    sim_walk(gtree, stree.root.children[1], gtree.root, freq)


    # remove dead branches and single children
    extant_leaves = [leaf.name for leaf in gtree.leaves()
                     if leaf.data['freq'] > 0.0]

    if keep_extinct:
        full_gtree = gtree.copy()
        # do deep copy of data
        for node in full_gtree:
            node2 = gtree.nodes[node.name]
            for key, val in node2.data.items():
                node.data[key] = copy.copy(val)

    treelib.subtree_by_leaf_names(gtree, extant_leaves, keep_single=True)
    remove_single_children(gtree)

    # determine extra information (recon, events, daughters)
    extras = generate_extras(stree, gtree)

    if keep_extinct:
        extras["full_locus_tree"] = full_gtree

    return gtree, extras
Example #37
0
def labeledrecon_to_recon(gene_tree, labeled_recon, stree,
                          name_internal="n"):
    """Convert from DLCpar to DLCoal reconciliation model

    gene_tree     -- gene tree; a copy of it becomes the coalescent tree
    labeled_recon -- object providing locus_map, species_map and order
    stree         -- species tree
    name_internal -- passed to common.rename_nodes when renaming the
                     internal nodes of the locus tree

    Returns (coal_tree, recon) where recon is a phyloDLC.Recon built from
    the coal recon, locus tree, locus recon, locus events and daughters.

    Raises ValueError if the duplication order within a species branch
    refers to a parent locus that never appears in that branch.

    NOTE: This is non-reversible because it produces NON-dated coalescent and locus trees
    """

    locus_map = labeled_recon.locus_map
    species_map = labeled_recon.species_map
    order = labeled_recon.order

    # coalescent tree equals gene tree
    coal_tree = gene_tree.copy()

    # factor gene tree
    events = phylo.label_events(gene_tree, species_map)
    subtrees = factor_tree(gene_tree, stree, species_map, events)

    # gene names
    genenames = {}
    for snode in stree:
        genenames[snode] = {}
    for leaf in gene_tree.leaves():
        genenames[species_map[leaf]][locus_map[leaf]] = leaf.name

    # 2D dict to keep track of locus tree nodes by hashing by speciation node and locus
    # key1 = snode, key2 = locus, value = list of nodes (sorted from oldest to most recent)
    locus_tree_map = {}
    for snode in stree:
        locus_tree_map[snode] = {}

    # initialize locus tree, coal/locus recon, and daughters
    locus_tree = treelib.Tree()
    coal_recon = {}
    locus_recon = {}
    locus_events = {}
    daughters = []

    # initialize root of locus tree
    root = treelib.TreeNode(locus_tree.new_name())
    locus_tree.add(root)
    locus_tree.root = root
    sroot = species_map[gene_tree.root]
    locus = locus_map[gene_tree.root]
    coal_recon[coal_tree.root] = root
    locus_recon[root] = sroot
    locus_tree_map[sroot][locus] = [root]

    # build locus tree along each species branch
    for snode in stree.preorder(sroot):
        subtrees_snode = subtrees[snode]

        # skip if no branches in this species branch
        if len(subtrees_snode) == 0:
            continue

        # build locus tree
        # 1) speciation
        if snode.parent:
            for (subroot, rootchild, leaves) in subtrees_snode:
                if rootchild:
                    locus = locus_map[subroot]     # use root locus!

                    # create new locus tree node in this species branch
                    if locus not in locus_tree_map[snode]:
                        old_node = locus_tree_map[snode.parent][locus][-1]

                        new_node = treelib.TreeNode(locus_tree.new_name())
                        locus_tree.add_child(old_node, new_node)
                        locus_recon[new_node] = snode

                        locus_events[old_node] = "spec"

                        locus_tree_map[snode][locus] = [new_node]

                    # update coal_recon
                    cnode = coal_tree.nodes[rootchild.name]
                    lnode = locus_tree_map[snode][locus][-1]
                    coal_recon[cnode] = lnode

        # 2) duplication
        if snode in order:
            # may have to reorder loci (in case of multiple duplications)
            queue = collections.deque(order[snode].keys())
            npunts = 0      # consecutive punts since the last handled locus
            while len(queue) > 0:
                plocus = queue.popleft()
                if plocus not in locus_tree_map[snode]:
                    # punt
                    queue.append(plocus)
                    npunts += 1
                    if npunts >= len(queue):
                        # every remaining locus was punted without progress,
                        # so punting again can never succeed
                        raise ValueError(
                            "parent loci never appear in species branch "
                            "'%s': %s" % (snode.name, list(queue)))
                    continue
                npunts = 0

                # handle this ordered list
                lst = order[snode][plocus]
                for gnode in lst:
                    locus = locus_map[gnode]
                    cnode = coal_tree.nodes[gnode.name]

                    if locus != plocus:     # duplication
                        # update locus_tree, locus_recon, and daughters
                        old_node = locus_tree_map[snode][plocus][-1]

                        # mother lineage keeps the parent locus
                        new_node1 = treelib.TreeNode(locus_tree.new_name())
                        locus_tree.add_child(old_node, new_node1)
                        locus_recon[new_node1] = snode

                        # daughter lineage starts the new locus
                        new_node2 = treelib.TreeNode(locus_tree.new_name())
                        locus_tree.add_child(old_node, new_node2)
                        coal_recon[cnode] = new_node2
                        locus_recon[new_node2] = snode
                        daughters.append(new_node2)

                        locus_events[old_node] = "dup"

                        locus_tree_map[snode][plocus].append(new_node1)
                        locus_tree_map[snode][locus] = [new_node2]

                    else:                   # deep coalescence
                        lnode = locus_tree_map[snode][locus][-1]
                        coal_recon[cnode] = lnode

        # reconcile remaining coal tree nodes to locus tree
        # (no duplication so only a single locus tree node with the desired locus)
        for (subroot, rootchild, leaves) in subtrees_snode:
            if rootchild:
                for gnode in gene_tree.preorder(rootchild, is_leaf=lambda x: x in leaves):
                    cnode = coal_tree.nodes[gnode.name]
                    if cnode not in coal_recon:
                        locus = locus_map[gnode]
                        assert len(locus_tree_map[snode][locus]) == 1
                        lnode = locus_tree_map[snode][locus][-1]
                        coal_recon[cnode] = lnode

        # tidy up if at an extant species
        if snode.is_leaf():
            # items() rather than iteritems() so this also runs on Python 3;
            # the dict is not resized inside the loop
            for locus, nodes in locus_tree_map[snode].items():
                genename = genenames[snode][locus]
                lnode = nodes[-1]
                cnode = coal_tree.nodes[genename]

                # relabel genes in locus tree
                locus_tree.rename(lnode.name, genename)

                # relabel locus events
                locus_events[lnode] = "gene"

                # reconcile genes (genes in coal tree reconcile to genes in locus tree)
                # possible mismatch due to genes having an internal ordering even though all exist to present time
                # [could also do a new round of "speciation" at bottom of extant species branches,
                # but this introduces single children nodes that would just be removed anyway]
                coal_recon[cnode] = lnode

    # rename internal nodes
    common.rename_nodes(locus_tree, name_internal)

    # simplify coal_tree (and reconciliations)
    removed = treelib.remove_single_children(coal_tree)
    for cnode in removed:
        del coal_recon[cnode]

    # simplify locus_tree (and reconciliations + daughters)
    removed = treelib.remove_single_children(locus_tree)
    # snapshot the items because values are reassigned inside the loop
    for cnode, lnode in list(coal_recon.items()):
        if lnode in removed:
            # reconciliation updates to first child that is not removed
            new_lnode = lnode
            while new_lnode in removed:
                new_lnode = new_lnode.children[0]
            coal_recon[cnode] = new_lnode
    for lnode in removed:
        del locus_recon[lnode]
        del locus_events[lnode]
    for ndx, lnode in enumerate(daughters):
        if lnode in removed:
            # daughter updates to first child that is not removed
            new_lnode = lnode
            while new_lnode in removed:
                new_lnode = new_lnode.children[0]
            daughters[ndx] = new_lnode

    # every surviving locus tree node must have been given an event above
    assert all(lnode in locus_events for lnode in locus_tree)

    #========================================
    # put everything together

    return coal_tree, phyloDLC.Recon(coal_recon, locus_tree, locus_recon, locus_events, daughters)