def _ml_t_marginal(self, assign_dates=False):
    """
    Compute the marginal probability distribution of every node position
    by two-pass message passing on the tree.

    In the first pass (leaves -> root), each node combines its own date
    constraint with the messages from its children into
    ``subtree_distribution``. For non-root nodes this is convolved
    (integrated) with the branch-length distribution, which gives
    ``marginal_pos_Lx``: the subtree likelihood as a function of the
    parent's position.

    In the second pass (root -> leaves), each non-root node receives
    ``msg_from_parent``. This message summarizes everything outside the
    node's own subtree: the parent's date constraint, the messages of the
    sister subtrees and the message the parent received from above. The
    product of the two messages is the marginal distribution of the node,
    stored as ``marginal_pos_LH``.

    Parameters
    ----------

     assign_dates : bool, default False
        If True, the peak positions of the marginal distributions are
        assigned to the nodes as :code:`time_before_present` attributes, and
        their branch lengths are corrected accordingly.

        .. Note::
            Normally, the dates are assigned by running joint reconstruction.

    Returns
    -------

     None
        Results are stored as attributes:

        * on every node: ``marginal_pos_LH`` (marginal distribution of the
          node position) and ``marginal_inverse_cdf``;
        * on nodes with a non-delta marginal: ``marginal_cdf`` as well;
        * on the tree: ``positional_marginal_LH``.

        Unless ``self.debug`` is set, the intermediate messages
        (``marginal_pos_Lx``, ``subtree_distribution``, ``msg_from_parent``)
        are deleted at the end.
    """
    def _cleanup():
        # Delete the temporary messages. Each attribute is handled on its own
        # because not every node carries all of them (e.g. nodes on a bad
        # branch never get `subtree_distribution`); a missing attribute must
        # not prevent the remaining ones from being removed.
        for node in self.tree.find_clades():
            for attr in ('marginal_pos_Lx', 'subtree_distribution', 'msg_from_parent'):
                try:
                    delattr(node, attr)
                except AttributeError:
                    pass

    self.logger("ClockTree - Marginal reconstruction: Propagating leaves -> root...", 2)
    # go through the nodes from leaves towards the root:
    for node in self.tree.find_clades(order='postorder'):  # children first, msg to parents
        if node.bad_branch:
            # no information
            node.marginal_pos_Lx = None
        else:  # all other nodes
            if node.date_constraint is not None and node.date_constraint.is_delta:  # there is a time constraint
                # initialize the Lx for nodes with precise date constraint:
                # subtree probability given the position of the parent node;
                # the parent position is the node's fixed date shifted by
                # each possible branch length.
                node.subtree_distribution = node.date_constraint
                bl = node.branch_length_interpolator.x
                x = bl + node.date_constraint.peak_pos
                node.marginal_pos_Lx = Distribution(x, node.branch_length_interpolator(bl),
                                                    min_width=self.min_width, is_log=True)

            else:  # all nodes without precise constraint but positional information
                # subtree likelihood given the node's constraint and child msgs:
                msgs_to_multiply = [node.date_constraint] if node.date_constraint is not None else []
                msgs_to_multiply.extend([child.marginal_pos_Lx for child in node.clades
                                         if child.marginal_pos_Lx is not None])

                if len(msgs_to_multiply) == 0:
                    # no information
                    node.marginal_pos_Lx = None
                    continue
                elif len(msgs_to_multiply) == 1:
                    node.subtree_distribution = msgs_to_multiply[0]
                else:  # combine the different msgs and constraints
                    node.subtree_distribution = Distribution.multiply(msgs_to_multiply)

                if node.up is None:  # this is the root, set dates
                    node.subtree_distribution._adjust_grid(rel_tol=self.rel_tol_prune)
                    node.marginal_pos_Lx = node.subtree_distribution
                    # nothing lies above the root, so its subtree distribution
                    # already is its marginal distribution
                    node.marginal_pos_LH = node.subtree_distribution
                    self.tree.positional_marginal_LH = -node.subtree_distribution.peak_val
                else:  # otherwise propagate to parent
                    res, res_t = NodeInterpolator.convolve(node.subtree_distribution,
                                                           node.branch_length_interpolator,
                                                           max_or_integral='integral',
                                                           n_grid_points=self.node_grid_points,
                                                           n_integral=self.n_integral,
                                                           rel_tol=self.rel_tol_refine)
                    res._adjust_grid(rel_tol=self.rel_tol_prune)
                    node.marginal_pos_Lx = res

    self.logger("ClockTree - Marginal reconstruction: Propagating root -> leaves...", 2)
    from scipy.interpolate import interp1d
    for node in self.tree.find_clades(order='preorder'):
        if node.up is None:
            # The root node: nothing beyond the root
            node.msg_from_parent = None
        else:
            # all other cases (all internal nodes + unconstrained terminals)
            parent = node.up
            # The message from the parent must only contain information from
            # outside this node's subtree: the parent's own date constraint,
            # the messages of all sister subtrees and the message the parent
            # received from its own parent. The parent's marginal_pos_LH must
            # NOT be used here: it already includes this node's own
            # marginal_pos_Lx, which would then be counted twice.
            complementary_msgs = [parent.date_constraint] if parent.date_constraint is not None else []
            complementary_msgs.extend([sister.marginal_pos_Lx for sister in parent.clades
                                       if (sister != node) and (sister.marginal_pos_Lx is not None)])

            # if parent itself got something from the root side, include it
            if parent.msg_from_parent is not None:
                complementary_msgs.append(parent.msg_from_parent)

            if len(complementary_msgs):
                msg_parent_to_node = NodeInterpolator.multiply(complementary_msgs)
                msg_parent_to_node._adjust_grid(rel_tol=self.rel_tol_prune)
            else:
                # No information outside the subtree: use a constant message
                # between the parent's numdate and numeric_date() (presumably
                # the current date).
                # NOTE(review): implicit relative import; under Python 3 this
                # may need to be `from .utils import ...` — confirm package layout.
                from utils import numeric_date
                x = [parent.numdate, numeric_date()]
                msg_parent_to_node = NodeInterpolator(x, [1.0, 1.0], min_width=self.min_width)

            # integral message, which delivers to the node the positional
            # information from the complementary subtree
            res, res_t = NodeInterpolator.convolve(msg_parent_to_node, node.branch_length_interpolator,
                                                   max_or_integral='integral',
                                                   inverse_time=False,
                                                   n_grid_points=self.node_grid_points,
                                                   n_integral=self.n_integral,
                                                   rel_tol=self.rel_tol_refine)

            node.msg_from_parent = res
            if node.marginal_pos_Lx is None:
                # no information from below: the marginal is the parent message
                node.marginal_pos_LH = node.msg_from_parent
            else:
                node.marginal_pos_LH = NodeInterpolator.multiply((node.msg_from_parent,
                                                                  node.subtree_distribution))

            self.logger('ClockTree._ml_t_root_to_leaves: computed convolution'
                        ' with %d points at node %s'%(len(res.x),node.name),4)

            if self.debug:
                # Debug-only diagnostic: when the slope of the message changes
                # sign more than once among points within 500 of the peak,
                # plot the messages and drop into an interactive debugger
                # (requires matplotlib and ipdb).
                tmp = np.diff(res.y-res.peak_val)
                nsign_changed = np.sum((tmp[1:]*tmp[:-1]<0)&(res.y[1:-1]-res.peak_val<500))
                if nsign_changed>1:
                    import matplotlib.pyplot as plt
                    plt.ion()
                    plt.plot(res.x, res.y-res.peak_val, '-o')
                    plt.plot(res.peak_pos - node.branch_length_interpolator.x,
                             node.branch_length_interpolator(node.branch_length_interpolator.x)-node.branch_length_interpolator.peak_val, '-o')
                    plt.plot(msg_parent_to_node.x,msg_parent_to_node.y-msg_parent_to_node.peak_val, '-o')
                    plt.ylim(0,100)
                    plt.xlim(-0.05, 0.05)
                    import ipdb; ipdb.set_trace()

        # assign positions of nodes and branch length only when desired
        # since marginal reconstruction can result in negative branch length
        if assign_dates:
            node.time_before_present = node.marginal_pos_LH.peak_pos
            if node.up:
                node.clock_length = node.up.time_before_present - node.time_before_present
                node.branch_length = node.clock_length

        # construct the (inverse) cumulative distribution to evaluate confidence intervals
        if node.marginal_pos_LH.is_delta:
            # a delta distribution maps every quantile onto its peak position
            node.marginal_inverse_cdf = interp1d([0, 1], node.marginal_pos_LH.peak_pos*np.ones(2),
                                                 kind="linear")
        else:
            # cumulative trapezoidal integral of the relative probability,
            # normalized so that it ends at 1
            dt = np.diff(node.marginal_pos_LH.x)
            y = node.marginal_pos_LH.prob_relative(node.marginal_pos_LH.x)
            int_y = np.concatenate(([0], np.cumsum(dt*(y[1:]+y[:-1])/2.0)))
            int_y /= int_y[-1]
            node.marginal_inverse_cdf = interp1d(int_y, node.marginal_pos_LH.x, kind="linear")
            node.marginal_cdf = interp1d(node.marginal_pos_LH.x, int_y, kind="linear")

    if not self.debug:
        _cleanup()
def _ml_t_joint(self):
    """
    Compute the joint maximum-likelihood assignment of the node positions.

    In the first pass (leaves -> root), each node combines its date
    constraint with the messages of its children. For non-root nodes the
    result is convolved with the branch-length distribution using
    maximization (instead of the integration used in marginal
    reconstruction). This yields two quantities:

    * ``joint_pos_Lx``: the maximal subtree likelihood as a function of the
      parent position;
    * ``joint_pos_Cx``: the branch length attaining that maximum.

    The root position is set to the peak of its subtree distribution.

    In the second pass (root -> leaves), the position of each node is fixed
    from its parent's already-fixed position via ``joint_pos_Cx``. Nodes
    without ``joint_pos_Cx`` (bad branch or no information) get the peak of
    their branch-length distribution instead.

    Results are stored as attributes:

    * on every node: ``time_before_present``, ``branch_length`` and
      ``clock_length``;
    * on the tree: ``positional_joint_LH``, computed by
      ``self.timetree_likelihood()``.

    Unless ``self.debug`` is set, ``joint_pos_Lx``/``joint_pos_Cx`` are
    deleted at the end.

    Returns
    -------
     None
    """
    def _cleanup():
        # every node gets both attributes during the leaves -> root pass
        for node in self.tree.find_clades():
            del node.joint_pos_Lx
            del node.joint_pos_Cx

    self.logger("ClockTree - Joint reconstruction: Propagating leaves -> root...", 2)
    # go through the nodes from leaves towards the root:
    for node in self.tree.find_clades(order='postorder'):  # children first, msg to parents
        # Lx is the maximal likelihood of a subtree given the parent position
        # Cx is the branch length corresponding to the maximally likely subtree
        if node.bad_branch:
            # no information at the node
            node.joint_pos_Lx = None
            node.joint_pos_Cx = None
        else:  # all other nodes
            if node.date_constraint is not None and node.date_constraint.is_delta:  # there is a time constraint
                # subtree probability given the position of the parent node
                # Lx.x is the position of the parent node
                # Lx.y is the probability of the subtree (consisting of one terminal node in this case)
                # Cx.y is the branch length corresponding to the optimal subtree
                bl = node.branch_length_interpolator.x
                x = bl + node.date_constraint.peak_pos
                node.joint_pos_Lx = Distribution(x, node.branch_length_interpolator(bl),
                                                 min_width=self.min_width, is_log=True)
                node.joint_pos_Cx = Distribution(x, bl, min_width=self.min_width)  # map back to the branch length
            else:  # all nodes without precise constraint but positional information
                msgs_to_multiply = [node.date_constraint] if node.date_constraint is not None else []
                msgs_to_multiply.extend([child.joint_pos_Lx for child in node.clades
                                         if child.joint_pos_Lx is not None])

                # subtree likelihood given the node's constraint and child messages
                if len(msgs_to_multiply) == 0:  # there are no constraints
                    node.joint_pos_Lx = None
                    node.joint_pos_Cx = None
                    continue
                elif len(msgs_to_multiply) > 1:  # combine the different msgs and constraints
                    subtree_distribution = Distribution.multiply(msgs_to_multiply)
                else:  # there is exactly one constraint.
                    subtree_distribution = msgs_to_multiply[0]

                if node.up is None:  # this is the root, set dates
                    subtree_distribution._adjust_grid(rel_tol=self.rel_tol_prune)
                    # set root position and joint likelihood of the tree
                    node.time_before_present = subtree_distribution.peak_pos
                    node.joint_pos_Lx = subtree_distribution
                    node.joint_pos_Cx = None
                    node.clock_length = node.branch_length
                else:  # otherwise propagate to parent
                    res, res_t = NodeInterpolator.convolve(subtree_distribution,
                                                           node.branch_length_interpolator,
                                                           max_or_integral='max',
                                                           inverse_time=True,
                                                           n_grid_points=self.node_grid_points,
                                                           n_integral=self.n_integral,
                                                           rel_tol=self.rel_tol_refine)

                    res._adjust_grid(rel_tol=self.rel_tol_prune)

                    node.joint_pos_Lx = res
                    node.joint_pos_Cx = res_t

    # go through the nodes from root towards the leaves and assign joint ML positions:
    self.logger("ClockTree - Joint reconstruction: Propagating root -> leaves...", 2)
    for node in self.tree.find_clades(order='preorder'):  # root first, msgs to children
        if node.up is None:  # root node
            continue  # the position was already set on the previous step

        if node.joint_pos_Cx is None:
            # no constraints or branch is bad - reconstruct from the branch len interpolator
            node.branch_length = node.branch_length_interpolator.peak_pos
        elif isinstance(node.joint_pos_Cx, Distribution):
            # Cx maps the parent position (Cx.x) to the branch length of the
            # most likely subtree configuration. Evaluate it at the parent's
            # already-fixed position, clamped from below at Cx.xmin.
            node.branch_length = node.joint_pos_Cx(max(node.joint_pos_Cx.xmin,
                                                       node.up.time_before_present) + ttconf.TINY_NUMBER)

        node.time_before_present = node.up.time_before_present - node.branch_length
        node.clock_length = node.branch_length

        # just sanity check, should never happen:
        if node.branch_length < 0 or node.time_before_present < 0:
            if -ttconf.TINY_NUMBER < node.branch_length < 0:
                self.logger("ClockTree - Joint reconstruction: correcting rounding error of %s"%node.name, 4)
                # snap the branch to zero length and keep the clock length and
                # node position consistent with the corrected branch
                node.branch_length = 0
                node.clock_length = 0
                node.time_before_present = node.up.time_before_present

    self.tree.positional_joint_LH = self.timetree_likelihood()
    # cleanup, if required
    if not self.debug:
        _cleanup()