def _ml_t_root_leaves(self):
    """
    Propagate location messages from the root towards the leaves.

    For every non-root node, the messages its parent received from all other
    children (``parent.msgs_from_leaves`` without this node) are multiplied
    with the message the parent itself received from above
    (``parent.msg_from_parent``, when not None). The product is convolved
    with the node's branch-length distribution and stored as
    ``node.msg_from_parent``. The root, and nodes whose own ``msg_to_parent``
    is a delta function (fixed position), get ``msg_from_parent = None``.

    This method only stores ``msg_from_parent`` on the nodes; it does not
    itself assign dates, probability distributions or branch lengths.

    Args:
     - None: all the required parameters are pre-set in the previous steps
       (``msgs_from_leaves`` and ``msg_to_parent`` from the leaves -> root
       pass, ``branch_length_interpolator`` on every non-root node).

    Returns:
     - None
    """
    self.logger("ClockTree: Propagating root -> leaves...", 2)
    # Main method - propagate from root to the leaves and set the LH distributions
    # to each node
    for node in self.tree.find_clades(order='preorder'):  # ancestors first, msg to children
        ## This is the root node
        if node.up is None:
            node.msg_from_parent = None  # nothing beyond the root
        elif (node.msg_to_parent is not None) and node.msg_to_parent.is_delta:
            # the node position is fixed, a message from the parent adds nothing
            node.msg_from_parent = None
        else:
            parent = node.up
            # messages the parent received from all *other* children
            complementary_msgs = [parent.msgs_from_leaves[k] for k in parent.msgs_from_leaves
                                  if k != node]
            # msg_from_parent is None for the root and for delta-constrained
            # nodes; otherwise include what the parent got from above
            if parent.msg_from_parent is not None:
                complementary_msgs.append(parent.msg_from_parent)

            # NOTE(review): complementary_msgs can be empty (e.g. the parent is
            # the root and this node is its only child with a message); the
            # behavior of NodeInterpolator.multiply on an empty list is not
            # visible here -- confirm it is handled.
            msg_parent_to_node = NodeInterpolator.multiply(complementary_msgs)
            msg_parent_to_node._adjust_grid(rel_tol=self.rel_tol_prune)
            # transport the parent message along the branch towards the node
            res = NodeInterpolator.convolve(msg_parent_to_node, node.branch_length_interpolator,
                                            inverse_time=False,
                                            n_integral=self.n_integral,
                                            rel_tol=self.rel_tol_refine)
            node.msg_from_parent = res
            self.logger('ClockTree._ml_t_root_to_leaves: computed convolution'
                        ' with %d points at node %s'%(len(res.x),node.name),4)

            if self.debug:
                # interactive diagnostics: stop when the convolution result has
                # more than one local extremum close to its peak
                tmp = np.diff(res.y-res.peak_val)
                nsign_changed = np.sum((tmp[1:]*tmp[:-1]<0)&(res.y[1:-1]-res.peak_val<500))
                if nsign_changed>1:
                    import matplotlib.pyplot as plt
                    plt.ion()
                    plt.plot(res.x, res.y-res.peak_val, '-o')
                    plt.plot(res.peak_pos - node.branch_length_interpolator.x,
                             node.branch_length_interpolator.y-node.branch_length_interpolator.peak_val, '-o')
                    plt.plot(msg_parent_to_node.x,msg_parent_to_node.y-msg_parent_to_node.peak_val, '-o')
                    plt.ylim(0,100)
                    plt.xlim(-0.01, 0.01)
                    import ipdb; ipdb.set_trace()
def _set_final_dates(self):
    """
    Assign final node positions and distributions after both message passes.

    Nodes are visited root first. For every node:

    * ``marginal_lh`` is built from the two messages: ``msg_to_parent``
      (subtree) and ``msg_from_parent`` (rest of the tree). If the subtree
      message is missing, only the message from the parent is used. If it is
      a delta function, it is used alone.
    * ``joint_lh`` is the subtree message multiplied with the parent's
      (already fixed) position, shifted by the branch-length distribution.
      For the root it is the root's subtree message.
    * ``time_before_present`` is the peak position of ``joint_lh``.
    * ``branch_length`` is the parent's minus the node's
      ``time_before_present``. The root gets ``self.one_mutation``.
    * ``clock_length`` is set equal to ``branch_length``.

    Args:
     - None: operates on ``self.tree``. Non-root nodes are expected to carry
       ``msg_from_parent`` and ``branch_length_interpolator``, and informative
       nodes ``msg_to_parent``, from the previous propagation steps.
    """
    self.logger("ClockTree: Setting dates and node distributions...", 2)

    def collapse_func(dist):
        # Delta and non-delta distributions alike collapse to their peak.
        return dist.peak_pos

    for node in self.tree.find_clades(order='preorder'):  # ancestors first, msg to children
        # set marginal distribution
        if node.up is None:  # this is the root node
            node.marginal_lh = node.msg_to_parent
        elif node.msg_to_parent is None:
            node.marginal_lh = node.msg_from_parent
        elif node.msg_to_parent.is_delta:
            node.marginal_lh = node.msg_to_parent
        else:
            node.marginal_lh = NodeInterpolator.multiply((node.msg_from_parent, node.msg_to_parent))

        if node.up is None:
            # NOTE(review): assumes the root received at least one message
            # (msg_to_parent is not None).
            node.joint_lh = node.msg_to_parent
            node.time_before_present = collapse_func(node.joint_lh)
            node.branch_length = self.one_mutation
        else:
            # Shift the position of the parent node (time_before_present) by the
            # branch length towards the present: add the branch length to the
            # negative time_before_present and rescale the result by -1.0.
            res = Distribution.shifted_x(node.branch_length_interpolator, -node.up.time_before_present)
            res.x_rescale(-1.0)
            # multiply the distribution from the parent with those from the children
            if node.msg_to_parent is not None:
                node.joint_lh = NodeInterpolator.multiply((node.msg_to_parent, res))
            else:
                node.joint_lh = res
            node.time_before_present = collapse_func(node.joint_lh)
            node.branch_length = node.up.time_before_present - node.time_before_present
        node.clock_length = node.branch_length
def _ml_t_leaves_root(self):
    """
    Propagate location messages from the leaves towards the root.

    Nodes are visited children first. For every internal node, each child
    with a ``msg_to_parent`` sends that message through its branch. A delta
    message is shifted by the branch-length distribution; any other message
    is convolved with it. The results are stored per child in
    ``node.msgs_from_leaves``. Their product becomes the node's own
    ``msg_to_parent``. The resulting distributions are therefore conditional
    only on the constraints in each node's subtree, except at the root, whose
    subtree contains all constrained leaves. The complementary messages are
    added by the root -> leaves pass.

    Args:
     - None: all required parameters are pre-set as node attributes during
       tree preparation (``msg_to_parent`` on constrained nodes,
       ``branch_length_interpolator`` on non-root nodes).

    Returns:
     - None: every informative internal node gets ``msgs_from_leaves`` and a
       ``msg_to_parent`` interpolation object; terminal nodes get an empty
       ``msgs_from_leaves``.
    """
    def _send_message(node, **kwargs):
        """
        Calc the desired LH distribution of the parent
        """
        if node.msg_to_parent.is_delta:
            # fixed node position: just shift the branch-length distribution
            res = Distribution.shifted_x(node.branch_length_interpolator, node.msg_to_parent.peak_pos)
        else:  # convolve two distributions
            res = NodeInterpolator.convolve(node.msg_to_parent, node.branch_length_interpolator,
                                            n_integral=self.n_integral,
                                            rel_tol=self.rel_tol_refine)
        self.logger("ClockTree._ml_t_leaves_root._send_message: "
                    "computed convolution with %d points at node %s"%(len(res.x),node.name),4)
        return res

    self.logger("ClockTree: Propagating leaves -> root...", 2)
    # go through the nodes from leaves towards the root:
    for node in self.tree.find_clades(order='postorder'):  # children first, msg to parents
        if node.is_terminal():
            node.msgs_from_leaves = {}
            continue

        # Save all messages from the children with constraints. Store them as
        # a dict so that a single child can be excluded later (root -> leaves).
        node.msgs_from_leaves = {clade: _send_message(clade) for clade in node.clades
                                 if clade.msg_to_parent is not None}
        if not node.msgs_from_leaves:  # we need at least one constraint
            continue

        # NOTE(review): this overwrites any msg_to_parent set for this internal
        # node by init_date_constraints (its own date constraint) -- confirm
        # that internal-node dates are meant to be ignored here.
        # Materialise the dict view: the other call sites pass lists/tuples to
        # multiply, and Python 3 dict views cannot be indexed.
        node.msg_to_parent = NodeInterpolator.multiply(list(node.msgs_from_leaves.values()))
        node.msg_to_parent._adjust_grid(rel_tol=self.rel_tol_prune)
def _send_message(node, **kwargs):
    """
    Calc the desired LH distribution of the parent.

    Transport ``node.msg_to_parent`` along the node's branch:

    * a delta message is turned into the branch-length distribution shifted
      to the delta's peak position;
    * any other message is convolved with the branch-length distribution.

    ``kwargs`` is accepted but unused.

    NOTE(review): ``self`` is not a parameter here; it must be provided by an
    enclosing scope (this mirrors the helper nested in ``_ml_t_leaves_root``).
    """
    child_msg = node.msg_to_parent
    branch_interp = node.branch_length_interpolator
    if not child_msg.is_delta:
        # general case: convolve the two distributions
        parent_msg = NodeInterpolator.convolve(child_msg, branch_interp,
                                               n_integral=self.n_integral,
                                               rel_tol=self.rel_tol_refine)
    else:
        parent_msg = Distribution.shifted_x(branch_interp, child_msg.peak_pos)

    n_points = len(parent_msg.x)
    self.logger("ClockTree._ml_t_leaves_root._send_message: "
                "computed convolution with %d points at node %s"%(n_points, node.name), 4)
    return parent_msg
def _ml_t_marginal(self, assign_dates=False):
    """
    Compute the marginal probability distribution of the node positions.

    The leaves -> root pass computes, for every node, the likelihood of its
    subtree as a function of the parent position (``marginal_pos_Lx``). The
    root -> leaves pass combines these with messages from the complementary
    part of the tree. The resulting distributions are conditional on the
    constraints of all dated nodes of the tree. They are stored as
    ``marginal_pos_LH`` attributes on the nodes, together with
    ``marginal_cdf`` / ``marginal_inverse_cdf`` interpolators that can be
    used to evaluate confidence intervals.

    Parameters
    ----------
    assign_dates : bool, default False
        If True, the peak of each node's marginal distribution is assigned as
        :code:`time_before_present`. For non-root nodes, :code:`clock_length`
        and :code:`branch_length` are updated accordingly.

        .. Note::
            Normally, the dates are assigned by running joint reconstruction.
            Marginal peaks can result in negative branch lengths.

    Returns
    -------
    None
        Results are stored as node attributes.
        ``self.tree.positional_marginal_LH`` is set to minus the peak value of
        the root's subtree distribution.
    """
    from scipy.interpolate import interp1d

    def _cleanup():
        # Remove temporary per-node attributes. Each attribute is removed
        # independently: not every node carries every attribute (e.g. nodes
        # without information have no subtree_distribution), and one missing
        # attribute must not prevent removal of the others.
        for node in self.tree.find_clades():
            for attr in ('marginal_pos_Lx', 'subtree_distribution', 'msg_from_parent'):
                try:
                    delattr(node, attr)
                except AttributeError:
                    pass  # never set on this node

    self.logger("ClockTree - Marginal reconstruction: Propagating leaves -> root...", 2)
    # go through the nodes from leaves towards the root:
    for node in self.tree.find_clades(order='postorder'):  # children first, msg to parents
        if node.bad_branch:
            # no information
            node.marginal_pos_Lx = None
        elif node.date_constraint is not None and node.date_constraint.is_delta:
            # There is a precise time constraint: the subtree probability given
            # the position of the parent node follows directly from the
            # branch-length distribution attached to the child position.
            node.subtree_distribution = node.date_constraint
            bl = node.branch_length_interpolator.x
            x = bl + node.date_constraint.peak_pos
            node.marginal_pos_Lx = Distribution(x, node.branch_length_interpolator(bl),
                                                min_width=self.min_width, is_log=True)
        else:
            # all nodes without precise constraint but positional information:
            # subtree likelihood given the node's constraint and child messages
            msgs_to_multiply = [node.date_constraint] if node.date_constraint is not None else []
            msgs_to_multiply.extend([child.marginal_pos_Lx for child in node.clades
                                     if child.marginal_pos_Lx is not None])

            # combine the different msgs and constraints
            if len(msgs_to_multiply)==0:
                # no information
                node.marginal_pos_Lx = None
                continue
            elif len(msgs_to_multiply)==1:
                node.subtree_distribution = msgs_to_multiply[0]
            else:
                node.subtree_distribution = Distribution.multiply(msgs_to_multiply)

            if node.up is None:  # this is the root, set dates
                node.subtree_distribution._adjust_grid(rel_tol=self.rel_tol_prune)
                node.marginal_pos_Lx = node.subtree_distribution
                node.marginal_pos_LH = node.subtree_distribution
                self.tree.positional_marginal_LH = -node.subtree_distribution.peak_val
            else:  # otherwise propagate to parent
                res, res_t = NodeInterpolator.convolve(node.subtree_distribution,
                                                       node.branch_length_interpolator,
                                                       max_or_integral='integral',
                                                       n_grid_points = self.node_grid_points,
                                                       n_integral=self.n_integral,
                                                       rel_tol=self.rel_tol_refine)
                res._adjust_grid(rel_tol=self.rel_tol_prune)
                node.marginal_pos_Lx = res

    self.logger("ClockTree - Marginal reconstruction: Propagating root -> leaves...", 2)
    for node in self.tree.find_clades(order='preorder'):
        if node.up is None:  # the root node
            node.msg_from_parent = None  # nothing beyond the root
        else:
            # all other cases (all internal nodes + unconstrained terminals)
            parent = node.up
            # messages from the complementary subtree (iterate over all sister nodes)
            complementary_msgs = [sister.marginal_pos_Lx for sister in parent.clades
                                  if (sister != node) and (sister.marginal_pos_Lx is not None)]

            # if the parent itself got something from the root side, include it
            if parent.msg_from_parent is not None:
                complementary_msgs.append(parent.msg_from_parent)
            elif parent.marginal_pos_Lx is not None:
                # NOTE(review): this branch is only reached for the root parent,
                # whose marginal_pos_LH is its full subtree distribution. That
                # already contains this node's marginal_pos_Lx and the sister
                # messages appended above. Confirm this double counting is intended.
                complementary_msgs.append(parent.marginal_pos_LH)

            if complementary_msgs:
                msg_parent_to_node = NodeInterpolator.multiply(complementary_msgs)
                msg_parent_to_node._adjust_grid(rel_tol=self.rel_tol_prune)
            else:
                # No information from the rest of the tree: use a flat message.
                # NOTE(review): parent.numdate and utils.numeric_date() are dates,
                # while the other messages live on the time_before_present axis
                # (see date2dist.get_time_before_present) -- confirm the units.
                x = [parent.numdate, utils.numeric_date()]
                msg_parent_to_node = NodeInterpolator(x, [1.0, 1.0], min_width=self.min_width)

            # integral message, which delivers to the node the positional
            # information from the complementary subtree
            res, res_t = NodeInterpolator.convolve(msg_parent_to_node, node.branch_length_interpolator,
                                                   max_or_integral='integral',
                                                   inverse_time=False,
                                                   n_grid_points = self.node_grid_points,
                                                   n_integral=self.n_integral,
                                                   rel_tol=self.rel_tol_refine)

            node.msg_from_parent = res
            if node.marginal_pos_Lx is None:
                node.marginal_pos_LH = node.msg_from_parent
            else:
                node.marginal_pos_LH = NodeInterpolator.multiply((node.msg_from_parent, node.subtree_distribution))

            self.logger('ClockTree._ml_t_root_to_leaves: computed convolution'
                        ' with %d points at node %s'%(len(res.x),node.name),4)

            if self.debug:
                # interactive diagnostics: stop when the convolution result has
                # more than one local extremum close to its peak
                tmp = np.diff(res.y-res.peak_val)
                nsign_changed = np.sum((tmp[1:]*tmp[:-1]<0)&(res.y[1:-1]-res.peak_val<500))
                if nsign_changed>1:
                    import matplotlib.pyplot as plt
                    plt.ion()
                    plt.plot(res.x, res.y-res.peak_val, '-o')
                    plt.plot(res.peak_pos - node.branch_length_interpolator.x,
                             node.branch_length_interpolator(node.branch_length_interpolator.x)-node.branch_length_interpolator.peak_val, '-o')
                    plt.plot(msg_parent_to_node.x,msg_parent_to_node.y-msg_parent_to_node.peak_val, '-o')
                    plt.ylim(0,100)
                    plt.xlim(-0.05, 0.05)
                    import ipdb; ipdb.set_trace()

        # assign positions of nodes and branch length only when desired
        # since marginal reconstruction can result in negative branch length
        if assign_dates:
            node.time_before_present = node.marginal_pos_LH.peak_pos
            if node.up is not None:
                node.clock_length = node.up.time_before_present - node.time_before_present
                node.branch_length = node.clock_length

        # construct the (inverse) cumulative distribution to evaluate confidence intervals
        if node.marginal_pos_LH.is_delta:
            node.marginal_inverse_cdf=interp1d([0,1], node.marginal_pos_LH.peak_pos*np.ones(2), kind="linear")
        else:
            dt = np.diff(node.marginal_pos_LH.x)
            y = node.marginal_pos_LH.prob_relative(node.marginal_pos_LH.x)
            # trapezoidal cumulative integral, normalised to end at 1
            int_y = np.concatenate(([0], np.cumsum(dt*(y[1:]+y[:-1])/2.0)))
            int_y/=int_y[-1]
            node.marginal_inverse_cdf = interp1d(int_y, node.marginal_pos_LH.x, kind="linear")
            node.marginal_cdf = interp1d(node.marginal_pos_LH.x, int_y, kind="linear")

    if not self.debug:
        _cleanup()

    return
def _ml_t_joint(self):
    """
    Compute the joint maximum-likelihood positions of all nodes.

    Leaves -> root: for every node, compute the maximal likelihood of its
    subtree as a function of the parent position (``joint_pos_Lx``). Also
    compute the branch length realising that maximum (``joint_pos_Cx``). At
    the root, the peak of the combined distribution fixes the root position.

    Root -> leaves: given the already fixed parent position, each node's
    branch length is read off ``joint_pos_Cx``. If no information is
    available, the peak of the branch-length interpolator is used. The node
    position follows from the branch length.

    The resulting ``time_before_present`` is the node position in branch
    length units, measured back from the present. It is assigned to every
    node together with ``branch_length`` and ``clock_length``.

    Returns
    -------
    None
        Results are stored as node attributes.
        ``self.tree.positional_joint_LH`` is set to ``self.timetree_likelihood()``.
    """
    def _cleanup():
        for node in self.tree.find_clades():
            del node.joint_pos_Lx
            del node.joint_pos_Cx

    self.logger("ClockTree - Joint reconstruction: Propagating leaves -> root...", 2)
    # go through the nodes from leaves towards the root:
    for node in self.tree.find_clades(order='postorder'):  # children first, msg to parents
        # Lx is the maximal likelihood of a subtree given the parent position
        # Cx is the branch length corresponding to the maximally likely subtree
        if node.bad_branch:
            # no information at the node
            node.joint_pos_Lx = None
            node.joint_pos_Cx = None
        elif node.date_constraint is not None and node.date_constraint.is_delta:
            # There is a precise time constraint.
            # Lx.x is the position of the parent node
            # Lx.y is the probability of the subtree (one terminal node in this case)
            # Cx.y is the branch length corresponding to the optimal subtree
            bl = node.branch_length_interpolator.x
            x = bl + node.date_constraint.peak_pos
            node.joint_pos_Lx = Distribution(x, node.branch_length_interpolator(bl),
                                             min_width=self.min_width, is_log=True)
            node.joint_pos_Cx = Distribution(x, bl, min_width=self.min_width)  # map back to the branch length
        else:
            # all nodes without precise constraint but positional information
            msgs_to_multiply = [node.date_constraint] if node.date_constraint is not None else []
            msgs_to_multiply.extend([child.joint_pos_Lx for child in node.clades
                                     if child.joint_pos_Lx is not None])

            # subtree likelihood given the node's constraint and child messages
            if len(msgs_to_multiply) == 0:  # there are no constraints
                node.joint_pos_Lx = None
                node.joint_pos_Cx = None
                continue
            elif len(msgs_to_multiply)>1:  # combine the different msgs and constraints
                subtree_distribution = Distribution.multiply(msgs_to_multiply)
            else:  # there is exactly one constraint
                subtree_distribution = msgs_to_multiply[0]

            if node.up is None:  # this is the root, set dates
                subtree_distribution._adjust_grid(rel_tol=self.rel_tol_prune)
                # set root position and joint likelihood of the tree
                node.time_before_present = subtree_distribution.peak_pos
                node.joint_pos_Lx = subtree_distribution
                node.joint_pos_Cx = None
                node.clock_length = node.branch_length
            else:  # otherwise propagate to parent
                res, res_t = NodeInterpolator.convolve(subtree_distribution,
                                                       node.branch_length_interpolator,
                                                       max_or_integral='max',
                                                       inverse_time=True,
                                                       n_grid_points = self.node_grid_points,
                                                       n_integral=self.n_integral,
                                                       rel_tol=self.rel_tol_refine)
                res._adjust_grid(rel_tol=self.rel_tol_prune)
                node.joint_pos_Lx = res
                node.joint_pos_Cx = res_t

    # go through the nodes from root towards the leaves and assign joint ML positions:
    self.logger("ClockTree - Joint reconstruction: Propagating root -> leaves...", 2)
    for node in self.tree.find_clades(order='preorder'):  # root first, msgs to children
        if node.up is None:  # root node
            continue  # the position was already set on the previous step

        if node.joint_pos_Cx is None:
            # no constraints or branch is bad - reconstruct from the branch len interpolator
            node.branch_length = node.branch_length_interpolator.peak_pos
        elif isinstance(node.joint_pos_Cx, Distribution):
            # Cx maps the (now fixed) parent position to the branch length of
            # the most likely subtree; clip the parent position to the grid.
            node.branch_length = node.joint_pos_Cx(max(node.joint_pos_Cx.xmin,
                                                       node.up.time_before_present)+ttconf.TINY_NUMBER)

        # Sanity check, should never happen beyond rounding errors. Correct
        # before deriving position and clock length so that all three agree.
        if -ttconf.TINY_NUMBER < node.branch_length < 0:
            self.logger("ClockTree - Joint reconstruction: correcting rounding error of %s"%node.name, 4)
            node.branch_length = 0

        node.time_before_present = node.up.time_before_present - node.branch_length
        node.clock_length = node.branch_length

    self.tree.positional_joint_LH = self.timetree_likelihood()
    # cleanup, if required
    if not self.debug:
        _cleanup()
def init_date_constraints(self, ancestral_inference=False, slope=None, **kwargs):
    """
    Prepare the per-node objects needed by the time-tree ML computations.

    * Reinfer ancestral sequences if requested, or if the root has no
      ``sequence`` attribute yet.
    * Create a ``BranchLenInterpolator`` for every non-root branch. The
      ``gamma`` and ``merger_rate`` of a previously existing interpolator are
      carried over; otherwise 1.0 and ``self.merger_rate_default`` are used.
      The root gets ``branch_length_interpolator = None``.
    * Build ``self.date2dist``, the conversion between dates and branch
      lengths, from the tree. The conversion formula is assumed to be
      'length = k*numdate_given + b'. For convenience, these coefficients and
      the regression parameters are stored in the date2dist object.
    * For every node with ``numdate_given``, set ``time_before_present`` and
      a date message ``msg_to_parent``. Scalar dates get a delta function;
      date ranges get a flat distribution over the range. Bad branches get no
      message. Nodes without a date get ``numdate_given``,
      ``time_before_present`` and ``msg_to_parent`` set to None.

    Note: the tree must have dates set on its nodes before calling this
    function (this is accomplished by calling the load_dates function).

    Params:
        ancestral_inference: bool -- whether or not to reinfer ancestral sequences;
            done by default when ancestral sequences are missing
        slope: passed to ``utils.DateConversion.from_tree``; presumably a
            fixed clock rate, None to estimate it -- TODO confirm
        **kwargs: forwarded to ``infer_ancestral_sequences``
    """
    self.logger("ClockTree.init_date_constraints...",2)

    if ancestral_inference or (not hasattr(self.tree.root, 'sequence')):
        self.infer_ancestral_sequences('ml',sample_from_profile='root',**kwargs)

    # set the None for the date-related attributes in the internal nodes.
    # make interpolation objects for the branches
    self.logger('ClockTree.init_date_constraints: Initializing branch length interpolation objects...',3)
    for node in self.tree.find_clades():
        if node.up is None:
            node.branch_length_interpolator = None
            continue

        # Copy the merger rate and gamma if they are set. After a reroot the
        # former root still carries branch_length_interpolator = None, so a
        # plain hasattr check is not sufficient.
        previous = getattr(node, 'branch_length_interpolator', None)
        if previous is not None:
            gamma = previous.gamma
            merger_rate = previous.merger_rate
        else:
            gamma = 1.0
            merger_rate = self.merger_rate_default

        node.branch_length_interpolator = BranchLenInterpolator(node, self.gtr,
                                                                one_mutation=self.one_mutation)
        node.branch_length_interpolator.merger_rate = merger_rate
        node.branch_length_interpolator.gamma = gamma

    self.date2dist = utils.DateConversion.from_tree(self.tree, slope)

    # make node distribution objects
    for node in self.tree.find_clades():
        if hasattr(node, 'numdate_given') and node.numdate_given is not None:
            # node is constrained: set the absolute time before present in branch length units
            if not np.isscalar(node.numdate_given):
                node.numdate_given = np.array(node.numdate_given)
            node.time_before_present = self.date2dist.get_time_before_present(node.numdate_given)

            if hasattr(node, 'bad_branch') and node.bad_branch==True:
                self.logger("ClockTree.init_date_constraints -- WARNING: Branch is marked as bad"
                            ", excluding it from the optimization process"
                            " Will be optimized freely", 4, warn=True)
                # if there are no constraints - log_prob will be set on-the-fly
                node.msg_to_parent = None
            elif np.isscalar(node.numdate_given):
                node.msg_to_parent = NodeInterpolator.delta_function(node.time_before_present, weight=1)
            else:
                node.msg_to_parent = NodeInterpolator(node.time_before_present,
                                                      np.ones_like(node.time_before_present),
                                                      is_log=False)
        else:
            # node without sampling date set
            node.numdate_given = None
            node.time_before_present = None
            # if there are no constraints - log_prob will be set on-the-fly
            node.msg_to_parent = None