Esempio n. 1
0
    def _ml_t_root_leaves(self):
        """
        Send the location messages from the root down towards the leaves.

        The tree is walked in preorder and every node receives a
        ``msg_from_parent`` attribute:

        - the root gets ``None``;
        - a node whose ``msg_to_parent`` is a delta function gets ``None``;
        - any other node gets the product of its siblings' entries in the
          parent's ``msgs_from_leaves`` (and of the parent's own
          ``msg_from_parent`` when that is not ``None``), convolved with the
          node's ``branch_length_interpolator``.

        When ``self.debug`` is set and the resulting message changes slope
        sign more than once (counting only points less than 500 above the
        peak value), the message is plotted and an ipdb session is opened.

        Args:

        - None: everything needed is read from attributes set in earlier steps.

        Returns:
         - None: the result is stored on the nodes as ``msg_from_parent``.

        """
        self.logger("ClockTree: Propagating root -> leaves...", 2)
        # preorder: a parent's message is ready before any child asks for it
        for node in self.tree.find_clades(order='preorder'):
            if node.up is None:
                # the root has nothing above it
                node.msg_from_parent = None
                continue

            if (node.msg_to_parent is not None) and node.msg_to_parent.is_delta:
                node.msg_from_parent = None
                continue

            ancestor = node.up
            # everything the parent heard from below, except from this node
            incoming = [msg for sibling, msg in ancestor.msgs_from_leaves.items()
                        if sibling != node]

            # a non-root parent also relays what it got from above
            if ancestor.msg_from_parent is not None:
                incoming.append(ancestor.msg_from_parent)

            msg_parent_to_node = NodeInterpolator.multiply(incoming)
            msg_parent_to_node._adjust_grid(rel_tol=self.rel_tol_prune)
            downward = NodeInterpolator.convolve(
                msg_parent_to_node,
                node.branch_length_interpolator,
                inverse_time=False,
                n_integral=self.n_integral,
                rel_tol=self.rel_tol_refine)
            node.msg_from_parent = downward
            self.logger('ClockTree._ml_t_root_to_leaves: computed convolution'
                        ' with %d points at node %s'%(len(downward.x),node.name),4)

            if not self.debug:
                continue

            # debugging aid: count slope sign changes of the message
            rel_height = downward.y - downward.peak_val
            slope = np.diff(rel_height)
            turning_points = (slope[1:]*slope[:-1] < 0) & (rel_height[1:-1] < 500)
            if np.sum(turning_points) > 1:
                import matplotlib.pyplot as plt
                branch = node.branch_length_interpolator
                plt.ion()
                plt.plot(downward.x, rel_height, '-o')
                plt.plot(downward.peak_pos - branch.x,
                         branch.y - branch.peak_val, '-o')
                plt.plot(msg_parent_to_node.x,
                         msg_parent_to_node.y - msg_parent_to_node.peak_val, '-o')
                plt.ylim(0, 100)
                plt.xlim(-0.01, 0.01)
                import ipdb
                ipdb.set_trace()
Esempio n. 2
0
    def _set_final_dates(self):
        """
        Assign the final distributions, positions and branch lengths to the nodes.

        The tree is walked in preorder, so a parent's ``time_before_present``
        is always set before its children are processed. Each node receives:

         - ``marginal_lh``: ``msg_to_parent`` for the root and for nodes whose
           ``msg_to_parent`` is a delta function; ``msg_from_parent`` when
           ``msg_to_parent`` is None; otherwise the product of both messages.
         - ``joint_lh``: for the root, its ``msg_to_parent``; otherwise the
           branch length distribution anchored at the parent position,
           multiplied by ``msg_to_parent`` when that is not None.
         - ``time_before_present``: the peak position of ``joint_lh``.
         - ``branch_length``: ``self.one_mutation`` for the root, otherwise
           the difference between parent and node ``time_before_present``.
         - ``clock_length``: a copy of ``branch_length``.

        Args:
         - None: all inputs are read from attributes of ``self`` and the nodes.

        Returns:
         - None: the results are stored as node attributes.

        """
        self.logger("ClockTree: Setting dates and node distributions...", 2)

        def collapse_func(dist):
            # Reduce a distribution to a single position: its peak. Delta and
            # non-delta distributions are handled identically.
            return dist.peak_pos

        for node in self.tree.find_clades(order='preorder'):  # ancestors first, msg to children
            if node.up is None:
                # root: the message gathered from the leaves serves as both
                # the marginal and the joint distribution
                node.marginal_lh = node.msg_to_parent
                node.joint_lh = node.msg_to_parent
                node.time_before_present = collapse_func(node.joint_lh)
                node.branch_length = self.one_mutation
            else:
                # set marginal distribution
                if node.msg_to_parent is None:
                    node.marginal_lh = node.msg_from_parent
                elif node.msg_to_parent.is_delta:
                    node.marginal_lh = node.msg_to_parent
                else:
                    node.marginal_lh = NodeInterpolator.multiply((node.msg_from_parent, node.msg_to_parent))

                # shift position of parent node (time_before_present) by the branch length
                # towards the present. To do so, add branch length to negative time_before_present
                # and rescale the resulting distribution by -1.0
                res = Distribution.shifted_x(node.branch_length_interpolator, -node.up.time_before_present)
                res.x_rescale(-1.0)
                # multiply distribution from parent with those from children and determine peak
                if node.msg_to_parent is not None:
                    node.joint_lh = NodeInterpolator.multiply((node.msg_to_parent, res))
                else:
                    node.joint_lh = res
                node.time_before_present = collapse_func(node.joint_lh)

                node.branch_length = node.up.time_before_present - node.time_before_present
            node.clock_length = node.branch_length
Esempio n. 3
0
    def _ml_t_leaves_root(self):
        """
        Propagate the location probability distributions from the leaves
        towards the root.

        The tree is walked in postorder. Every terminal node gets an empty
        ``msgs_from_leaves`` dictionary. Every internal node collects one
        message from each child whose ``msg_to_parent`` is not None and stores
        them in ``msgs_from_leaves``, keyed by the child. The product of these
        messages, with its grid pruned by ``_adjust_grid``, becomes the node's
        own ``msg_to_parent``. An internal node that receives no message is
        skipped and its ``msg_to_parent`` is left untouched.

        The final location distributions of the internal nodes are computed
        afterwards by the root-to-leaves propagation.

        Args:

         - None: all required parameters are pre-set as the node attributes during
           tree preparation

        Returns:

         - None: the messages are stored on the nodes as ``msgs_from_leaves``
           and ``msg_to_parent``.

        """
        def _send_message(node, **kwargs):
            """
            Calc the desired LH distribution of the parent
            """
            if node.msg_to_parent.is_delta:
                # exact node position: just shift the branch length distribution
                res = Distribution.shifted_x(node.branch_length_interpolator, node.msg_to_parent.peak_pos)
            else: # convolve two distributions
                res =  NodeInterpolator.convolve(node.msg_to_parent,
                            node.branch_length_interpolator, n_integral=self.n_integral,
                            rel_tol=self.rel_tol_refine)
            self.logger("ClockTree._ml_t_leaves_root._send_message: "
                        "computed convolution with %d points at node %s"%(len(res.x),node.name),4)
            return res

        self.logger("ClockTree: Propagating leaves -> root...", 2)
        # go through the nodes from leaves towards the root:
        for node in self.tree.find_clades(order='postorder'):  # children first, msg to parents
            if node.is_terminal():
                node.msgs_from_leaves = {}
                continue

            # save all messages from the children nodes with constraints
            # store as dictionary to exclude nodes from the set when necessary
            # (the root-to-leaves pass leaves out the receiving child)
            node.msgs_from_leaves = {clade: _send_message(clade) for clade in node.clades
                                     if clade.msg_to_parent is not None}

            if not node.msgs_from_leaves:  # we need at least one constraint
                continue
            # this is what the node sends to the parent; pass a real list
            # (not a dict view), like the other multiply() call sites do
            node.msg_to_parent = NodeInterpolator.multiply(list(node.msgs_from_leaves.values()))
            node.msg_to_parent._adjust_grid(rel_tol=self.rel_tol_prune)
Esempio n. 4
0
 def _send_message(node, **kwargs):
     """
     Build the message a node sends to its parent.

     A delta-function ``msg_to_parent`` yields the node's branch length
     distribution shifted by the delta position; any other message is
     convolved with the branch length distribution. The number of grid
     points of the result is logged before it is returned.
     """
     upward = node.msg_to_parent
     branch = node.branch_length_interpolator
     if not upward.is_delta:
         # general case: convolve the two distributions
         res = NodeInterpolator.convolve(
             upward,
             branch,
             n_integral=self.n_integral,
             rel_tol=self.rel_tol_refine)
     else:
         res = Distribution.shifted_x(branch, upward.peak_pos)
     n_points = len(res.x)
     self.logger("ClockTree._ml_t_leaves_root._send_message: "
                 "computed convolution with %d points at node %s"%(n_points,node.name),4)
     return res
Esempio n. 5
0
    def _ml_t_marginal(self, assign_dates=False):
        """
        Compute the marginal probability distribution of the internal nodes positions by
        propagating from the tree leaves towards the root and back. The result of
        this operation are the probability distributions of each internal node,
        conditional on the constraints on all leaves of the tree, which have sampling dates.
        The probability distributions are set as marginal_pos_LH attributes to the nodes.
        In addition, every node gets a ``marginal_inverse_cdf`` interpolator
        (and a ``marginal_cdf`` one unless its distribution is a delta function).

        Parameters
        ----------

         assign_dates : bool, default False
            If True, the inferred dates will be assigned to the nodes as
            :code:`time_before_present` attributes, and their branch lengths
            will be corrected accordingly.
            .. Note::
                Normally, the dates are assigned by running joint reconstruction.

        Returns
        -------

         None
            Every internal node is assigned the probability distribution in form
            of an interpolation object and sends this distribution further towards the
            root.

        """

        def _cleanup():
            # Remove the temporary per-node attributes. Each attribute is
            # deleted on its own so that a missing one does not prevent the
            # removal of the others.
            temporary = ('marginal_pos_Lx', 'subtree_distribution', 'msg_from_parent')
            for node in self.tree.find_clades():
                for attr in temporary:
                    try:
                        delattr(node, attr)
                    except AttributeError:
                        pass


        self.logger("ClockTree - Marginal reconstruction:  Propagating leaves -> root...", 2)
        # go through the nodes from leaves towards the root:
        for node in self.tree.find_clades(order='postorder'):  # children first, msg to parents
            if node.bad_branch:
                # no information
                node.marginal_pos_Lx = None
                continue

            if node.date_constraint is not None and node.date_constraint.is_delta: # there is a time constraint
                # initialize the Lx for nodes with precise date constraint:
                # subtree probability given the position of the parent node
                # position of the parent node is given by the branch length
                # distribution attached to the child node position
                node.subtree_distribution = node.date_constraint
                bl = node.branch_length_interpolator.x
                x = bl + node.date_constraint.peak_pos
                node.marginal_pos_Lx = Distribution(x, node.branch_length_interpolator(bl),
                                                    min_width=self.min_width, is_log=True)
                continue

            # all nodes without precise constraint but positional information
            # subtree likelihood given the node's constraint and child msg:
            msgs_to_multiply = [node.date_constraint] if node.date_constraint is not None else []
            msgs_to_multiply.extend([child.marginal_pos_Lx for child in node.clades
                                     if child.marginal_pos_Lx is not None])

            # combine the different msgs and constraints
            if not msgs_to_multiply:
                # no information
                node.marginal_pos_Lx = None
                continue
            elif len(msgs_to_multiply)==1:
                node.subtree_distribution = msgs_to_multiply[0]
            else: # combine the different msgs and constraints
                node.subtree_distribution = Distribution.multiply(msgs_to_multiply)

            if node.up is None: # this is the root, set dates
                node.subtree_distribution._adjust_grid(rel_tol=self.rel_tol_prune)
                node.marginal_pos_Lx = node.subtree_distribution
                node.marginal_pos_LH = node.subtree_distribution
                self.tree.positional_marginal_LH = -node.subtree_distribution.peak_val
            else: # otherwise propagate to parent
                res, _ = NodeInterpolator.convolve(node.subtree_distribution,
                                node.branch_length_interpolator,
                                max_or_integral='integral',
                                n_grid_points = self.node_grid_points,
                                n_integral=self.n_integral,
                                rel_tol=self.rel_tol_refine)
                res._adjust_grid(rel_tol=self.rel_tol_prune)
                node.marginal_pos_Lx = res

        self.logger("ClockTree - Marginal reconstruction:  Propagating root -> leaves...", 2)
        from scipy.interpolate import interp1d
        for node in self.tree.find_clades(order='preorder'):

            ## The root node
            if node.up is None:
                node.msg_from_parent = None # nothing beyond the root
            # all other cases (All internal nodes + unconstrained terminals)
            else:
                parent = node.up
                # messages from the complementary subtree (iterate over all sister nodes)
                complementary_msgs = [sister.marginal_pos_Lx for sister in parent.clades
                                            if (sister != node) and (sister.marginal_pos_Lx is not None)]

                # if parent itself got smth from the root node, include it
                if parent.msg_from_parent is not None:
                    complementary_msgs.append(parent.msg_from_parent)
                elif parent.marginal_pos_Lx is not None:
                    complementary_msgs.append(parent.marginal_pos_LH)

                if complementary_msgs:
                    msg_parent_to_node = NodeInterpolator.multiply(complementary_msgs)
                    msg_parent_to_node._adjust_grid(rel_tol=self.rel_tol_prune)
                else:
                    # nothing known about the complementary subtree: use a flat
                    # message between the parent's numdate and numeric_date()
                    from utils import numeric_date
                    x = [parent.numdate, numeric_date()]
                    msg_parent_to_node = NodeInterpolator(x, [1.0, 1.0],min_width=self.min_width)

                # integral message, which delivers to the node the positional information
                # from the complementary subtree
                res, _ = NodeInterpolator.convolve(msg_parent_to_node, node.branch_length_interpolator,
                                                    max_or_integral='integral',
                                                    inverse_time=False,
                                                    n_grid_points = self.node_grid_points,
                                                    n_integral=self.n_integral,
                                                    rel_tol=self.rel_tol_refine)

                node.msg_from_parent = res
                if node.marginal_pos_Lx is None:
                    node.marginal_pos_LH = node.msg_from_parent
                else:
                    node.marginal_pos_LH = NodeInterpolator.multiply((node.msg_from_parent, node.subtree_distribution))

                self.logger('ClockTree._ml_t_marginal: computed convolution'
                                ' with %d points at node %s'%(len(res.x),node.name),4)

                if self.debug:
                    # debugging aid: plot messages whose slope changes sign
                    # more than once and drop into the debugger
                    tmp = np.diff(res.y-res.peak_val)
                    nsign_changed = np.sum((tmp[1:]*tmp[:-1]<0)&(res.y[1:-1]-res.peak_val<500))
                    if nsign_changed>1:
                        import matplotlib.pyplot as plt
                        plt.ion()
                        plt.plot(res.x, res.y-res.peak_val, '-o')
                        plt.plot(res.peak_pos - node.branch_length_interpolator.x,
                                 node.branch_length_interpolator(node.branch_length_interpolator.x)-node.branch_length_interpolator.peak_val, '-o')
                        plt.plot(msg_parent_to_node.x,msg_parent_to_node.y-msg_parent_to_node.peak_val, '-o')
                        plt.ylim(0,100)
                        plt.xlim(-0.05, 0.05)
                        import ipdb; ipdb.set_trace()

            # assign positions of nodes and branch length only when desired
            # since marginal reconstruction can result in negative branch length
            if assign_dates:
                node.time_before_present = node.marginal_pos_LH.peak_pos
                if node.up:
                    node.clock_length = node.up.time_before_present - node.time_before_present
                    node.branch_length = node.clock_length

            # construct the inverse cumulant distribution to evaluate confidence intervals
            if node.marginal_pos_LH.is_delta:
                # a delta function maps every quantile onto its peak position
                node.marginal_inverse_cdf=interp1d([0,1], node.marginal_pos_LH.peak_pos*np.ones(2), kind="linear")
            else:
                # trapezoidal cumulative integral, normalised to end at 1
                dt = np.diff(node.marginal_pos_LH.x)
                y = node.marginal_pos_LH.prob_relative(node.marginal_pos_LH.x)
                int_y = np.concatenate(([0], np.cumsum(dt*(y[1:]+y[:-1])/2.0)))
                int_y/=int_y[-1]
                node.marginal_inverse_cdf = interp1d(int_y, node.marginal_pos_LH.x, kind="linear")
                node.marginal_cdf = interp1d(node.marginal_pos_LH.x, int_y, kind="linear")

        if not self.debug:
            _cleanup()

        return
Esempio n. 6
0
    def _ml_t_joint(self):
        """
        Compute the joint maximum likelihood assignment of the internal nodes positions by
        propagating from the tree leaves towards the root. Given the assignment of parent nodes,
        reconstruct the maximum-likelihood positions of the child nodes by propagating
        from the root to the leaves. The result of this operation is the time_before_present
        value, which is the position of the node, expressed in the units of the
        branch length, and scaled from the present-day. The value is assigned to the
        corresponding attribute of each node of the tree.

        Returns
        -------

         None
            Every non-root node is assigned ``branch_length``,
            ``time_before_present`` and ``clock_length``; the root gets
            ``time_before_present`` and ``clock_length``. The value of
            ``self.timetree_likelihood()`` is stored as
            ``self.tree.positional_joint_LH``. Unless ``self.debug`` is set,
            the temporary ``joint_pos_Lx`` / ``joint_pos_Cx`` node attributes
            are removed at the end.

        """

        def _cleanup():
            # both attributes are set on every node by the leaves -> root pass
            for node in self.tree.find_clades():
                del node.joint_pos_Lx
                del node.joint_pos_Cx


        self.logger("ClockTree - Joint reconstruction:  Propagating leaves -> root...", 2)
        # go through the nodes from leaves towards the root:
        for node in self.tree.find_clades(order='postorder'):  # children first, msg to parents
            # Lx is the maximal likelihood of a subtree given the parent position
            # Cx is the branch length corresponding to the maximally likely subtree
            if node.bad_branch:
                # no information at the node
                node.joint_pos_Lx = None
                node.joint_pos_Cx = None
            else: # all other nodes
                if node.date_constraint is not None and node.date_constraint.is_delta: # there is a time constraint
                    # subtree probability given the position of the parent node
                    # Lx.x is the position of the parent node
                    # Lx.y is the probablity of the subtree (consisting of one terminal node in this case)
                    # Cx.y is the branch length corresponding the optimal subtree
                    bl = node.branch_length_interpolator.x
                    x = bl + node.date_constraint.peak_pos
                    node.joint_pos_Lx = Distribution(x, node.branch_length_interpolator(bl),
                                                     min_width=self.min_width, is_log=True)
                    node.joint_pos_Cx = Distribution(x, bl, min_width=self.min_width) # map back to the branch length
                else: # all nodes without precise constraint but positional information
                    msgs_to_multiply = [node.date_constraint] if node.date_constraint is not None else []
                    msgs_to_multiply.extend([child.joint_pos_Lx for child in node.clades
                                             if child.joint_pos_Lx is not None])

                    # subtree likelihood given the node's constraint and child messages
                    if not msgs_to_multiply: # there are no constraints
                        node.joint_pos_Lx = None
                        node.joint_pos_Cx = None
                        continue
                    elif len(msgs_to_multiply)>1: # combine the different msgs and constraints
                        subtree_distribution = Distribution.multiply(msgs_to_multiply)
                    else: # there is exactly one constraint.
                        subtree_distribution = msgs_to_multiply[0]
                    if node.up is None: # this is the root, set dates
                        subtree_distribution._adjust_grid(rel_tol=self.rel_tol_prune)

                        # set root position and joint likelihood of the tree
                        node.time_before_present = subtree_distribution.peak_pos
                        node.joint_pos_Lx = subtree_distribution
                        node.joint_pos_Cx = None
                        node.clock_length = node.branch_length
                    else: # otherwise propagate to parent
                        res, res_t = NodeInterpolator.convolve(subtree_distribution,
                                        node.branch_length_interpolator,
                                        max_or_integral='max',
                                        inverse_time=True,
                                        n_grid_points = self.node_grid_points,
                                        n_integral=self.n_integral,
                                        rel_tol=self.rel_tol_refine)

                        res._adjust_grid(rel_tol=self.rel_tol_prune)

                        node.joint_pos_Lx = res
                        node.joint_pos_Cx = res_t


        # go through the nodes from root towards the leaves and assign joint ML positions:
        self.logger("ClockTree - Joint reconstruction:  Propagating root -> leaves...", 2)
        for node in self.tree.find_clades(order='preorder'):  # root first, msgs to children

            if node.up is None: # root node
                continue # the position was already set on the previous step

            if node.joint_pos_Cx is None: # no constraints or branch is bad - reconstruct from the branch len interpolator
                node.branch_length = node.branch_length_interpolator.peak_pos

            elif isinstance(node.joint_pos_Cx, Distribution):
                # NOTE Cx maps the position of the parent onto the branch
                # length of the most likely subtree; evaluate it at the
                # parent position, clamped to the lower end of the Cx grid
                node.branch_length = node.joint_pos_Cx(max(node.joint_pos_Cx.xmin,
                                            node.up.time_before_present)+ttconf.TINY_NUMBER)

            node.time_before_present = node.up.time_before_present - node.branch_length

            # just sanity check, should never happen:
            if -ttconf.TINY_NUMBER < node.branch_length < 0:
                self.logger("ClockTree - Joint reconstruction: correcting rounding error of %s"%node.name, 4)
                node.branch_length = 0

            # copy only after the rounding correction so that both lengths agree
            node.clock_length = node.branch_length

        self.tree.positional_joint_LH = self.timetree_likelihood()
        # cleanup, if required
        if not self.debug:
            _cleanup()
Esempio n. 7
0
    def init_date_constraints(self, ancestral_inference=False, slope=None, **kwarks):
        """
        Get the conversion coefficients between the dates and the branch
        lengths as they are used in ML computations. The conversion formula is
        assumed to be 'length = k*numdate_given + b'. For convenience, these
        coefficients as well as regression parameters are stored in the
        date2dist object.

        In addition, every non-root node gets a fresh branch length
        interpolator (keeping ``gamma`` and ``merger_rate`` of an existing
        one), and every node gets its ``time_before_present`` and
        ``msg_to_parent`` attributes initialised from ``numdate_given``.

        Note: that tree must have dates set to all nodes before calling this
        function. (This is accomplished by calling load_dates func).

        Params:
            ancestral_inference: bool -- whether or not to reinfer ancestral sequences
                                 done by default when ancestral sequences are missing
            slope: passed on to utils.DateConversion.from_tree
            **kwarks: passed on to infer_ancestral_sequences

        """
        self.logger("ClockTree.init_date_constraints...",2)

        if ancestral_inference or (not hasattr(self.tree.root, 'sequence')):
            self.infer_ancestral_sequences('ml',sample_from_profile='root',**kwarks)

        # set the None  for the date-related attributes in the internal nodes.
        # make interpolation objects for the branches
        self.logger('ClockTree.init_date_constraints: Initializing branch length interpolation objects...',3)
        for node in self.tree.find_clades():
            if node.up is None:
                node.branch_length_interpolator = None
            else:
                # copy the merger rate and gamma if they are set. The attribute
                # may exist but hold None (it is set to None on the root
                # above), in which case the defaults are used.
                previous = getattr(node, 'branch_length_interpolator', None)
                if previous is not None:
                    gamma = previous.gamma
                    merger_rate = previous.merger_rate
                else:
                    gamma = 1.0
                    merger_rate = self.merger_rate_default
                node.branch_length_interpolator = BranchLenInterpolator(node, self.gtr, one_mutation=self.one_mutation)
                node.branch_length_interpolator.merger_rate = merger_rate
                node.branch_length_interpolator.gamma = gamma
        self.date2dist = utils.DateConversion.from_tree(self.tree, slope)

        # make node distribution objects
        for node in self.tree.find_clades():
            if getattr(node, 'numdate_given', None) is None:
                # node without sampling date set
                node.numdate_given = None
                node.time_before_present = None
                # if there are no constraints - log_prob will be set on-the-fly
                node.msg_to_parent = None
                continue

            # node is constrained:
            # set the absolute time before present in branch length units
            exact_date = np.isscalar(node.numdate_given)
            if not exact_date:
                node.numdate_given = np.array(node.numdate_given)
            node.time_before_present = self.date2dist.get_time_before_present(node.numdate_given)
            if getattr(node, 'bad_branch', False):
                self.logger("ClockTree.init_date_constraints -- WARNING: Branch is marked as bad"
                            ", excluding it from the optimization process"
                            " Will be optimized freely", 4, warn=True)
                # if there are no constraints - log_prob will be set on-the-fly
                node.msg_to_parent = None
            elif exact_date:
                # a single date pins the node to one position
                node.msg_to_parent = NodeInterpolator.delta_function(node.time_before_present, weight=1)
            else:
                # several dates: equal weight on each of the given positions
                node.msg_to_parent = NodeInterpolator(node.time_before_present,
                                                      np.ones_like(node.time_before_present), is_log=False)