def annotate_tree_depths(tree: CassiopeiaTree) -> Dict[int, List[str]]:
    """Annotates depth and rooted-triplet counts at every node.

    Adds two attributes to every node of the tree: "depth", the number of
    edges between the node and the root, and "number_of_triplets", the number
    of leaf triplets below the node that are not contained entirely within a
    single child's subtree. Modifies the tree in place.

    Args:
        tree: The CassiopeiaTree to annotate.

    Returns:
        A dictionary mapping depth to the list of nodes at that depth.
    """
    depth_to_nodes = defaultdict(list)

    # Pre-order guarantees a parent's "depth" is set before its children.
    for node in tree.depth_first_traverse_nodes(
        source=tree.root, postorder=False
    ):
        if tree.is_root(node):
            depth = 0
        else:
            depth = tree.get_attribute(tree.parent(node), "depth") + 1
        tree.set_attribute(node, "depth", depth)
        depth_to_nodes[depth].append(node)

        # Enumerate each child's leaves once and reuse the count for both
        # the total and the per-child correction term.
        child_leaf_counts = [
            len(tree.leaves_in_subtree(child)) for child in tree.children(node)
        ]
        number_of_leaves = sum(child_leaf_counts)
        correction = sum(nCr(count, 3) for count in child_leaf_counts)

        tree.set_attribute(
            node, "number_of_triplets", nCr(number_of_leaves, 3) - correction
        )

    return depth_to_nodes
def fitch_hartigan_top_down(
    cassiopeia_tree: CassiopeiaTree,
    root: Optional[str] = None,
    state_key: str = "S1",
    label_key: str = "label",
    copy: bool = False,
) -> Optional[CassiopeiaTree]:
    """Run the Fitch-Hartigan top-down refinement.

    Walks the subtree below `root` from the top down and commits every node
    to a single state drawn from its bottom-up candidate set. A node keeps its
    parent's label whenever that label is among its candidates; otherwise one
    of its candidates is picked at random.

    Args:
        cassiopeia_tree: CassiopeiaTree that has been processed with the
            Fitch-Hartigan bottom-up algorithm.
        root: Node at which to start the refinement. Only the subtree below
            this node is considered. Defaults to the root of the tree.
        state_key: Attribute key that stores the Fitch-Hartigan candidate
            (ancestral) states.
        label_key: Attribute key under which the chosen maximum-parsimony
            assignment is stored.
        copy: Operate on a copy of the tree instead of modifying it in place.

    Returns:
        A new CassiopeiaTree if `copy` is True, else None.

    Raises:
        A CassiopeiaTreeError if Fitch-Hartigan bottom-up has not been called
            or if the state_key does not exist for a node.
    """
    if root is None:
        root = cassiopeia_tree.root

    working_tree = cassiopeia_tree
    if copy:
        working_tree = cassiopeia_tree.copy()

    for node in working_tree.depth_first_traverse_nodes(
        source=root, postorder=False
    ):
        if node == root:
            # Nothing to inherit at the top: draw from the candidate set.
            chosen = np.random.choice(
                working_tree.get_attribute(root, state_key)
            )
        else:
            inherited = working_tree.get_attribute(
                working_tree.parent(node), label_key
            )
            candidates = working_tree.get_attribute(node, state_key)
            if inherited in candidates:
                chosen = inherited
            else:
                chosen = np.random.choice(candidates)

        working_tree.set_attribute(node, label_key, chosen)

    if copy:
        return working_tree
    return None
def _N_fitch_count(
    cassiopeia_tree: CassiopeiaTree,
    unique_states: List[str],
    node_to_i: Dict[str, int],
    label_to_j: Dict[str, int],
    state_key: str = "S1",
) -> np.ndarray:
    """Fill in the dynamic programming table N for FitchCount.

    Computes N[v, s], corresponding to the number of solutions below a node v
    in the tree given v takes on the state s.

    Args:
        cassiopeia_tree: CassiopeiaTree object
        unique_states: The state space that a node can take on
        node_to_i: Helper mapping of each node to a unique integer (row of N)
        label_to_j: Helper mapping of each unique state in the state space to
            a unique integer (column of N)
        state_key: Attribute name in the CassiopeiaTree storing the possible
            states for each node, as inferred with the Fitch-Hartigan
            algorithm

    Returns:
        A 2-dimensional float array storing N[v, s] - the number of
            equally-parsimonious solutions below node v, given v takes on
            state s. Entries for states outside a node's `state_key` set are
            left at 0.
    """

    def _fill(v: str, s: str):
        """Helper function to fill in a single entry in N."""
        if cassiopeia_tree.is_leaf(v):
            return 1

        children = cassiopeia_tree.children(v)
        A = np.zeros(len(children))
        for i, u in enumerate(children):
            child_states = cassiopeia_tree.get_attribute(u, state_key)
            # A child that can keep the parent's state must do so; otherwise
            # any of the child's own candidate states is allowed.
            legal_states = [s] if s in child_states else child_states
            A[i] = np.sum(
                [N[node_to_i[u], label_to_j[sp]] for sp in legal_states]
            )
        return np.prod(A)

    N = np.zeros((len(cassiopeia_tree.nodes), len(unique_states)))

    # `_fill` reads the rows of a node's children, so children must be
    # visited before their parent.
    for n in cassiopeia_tree.depth_first_traverse_nodes():
        for s in cassiopeia_tree.get_attribute(n, state_key):
            N[node_to_i[n], label_to_j[s]] = _fill(n, s)

    return N
def create_clade_colors(
    tree: CassiopeiaTree, clade_colors: Dict[str, Tuple[float, float, float]]
) -> Tuple[
    Dict[str, Tuple[float, float, float]],
    Dict[Tuple[str, str], Tuple[float, float, float]],
]:
    """Assign colors to nodes and branches by clade.

    Args:
        tree: The CassiopeiaTree.
        clade_colors: Dictionary containing internal node-color mappings. All
            the paths from each of these nodes down to the leaves are drawn
            in the provided color. May be empty, in which case two empty
            dictionaries are returned.

    Returns:
        Two dictionaries. The first contains the node colors, and the second
            contains the branch colors.
    """
    # Collect the nodes covered by every requested clade.
    descendants = {
        node: set(tree.depth_first_traverse_nodes(node)) for node in clade_colors
    }

    # `set().union(...)` (rather than the unbound `set.union(...)`) also works
    # when there are no clades at all.
    covered_nodes = set().union(*descendants.values())
    if len(covered_nodes) != sum(len(d) for d in descendants.values()):
        warnings.warn(
            "Some clades specified with `clade_colors` are overlapping. "
            "Colors may be overridden.",
            PlottingWarning,
        )

    # Color by largest clade first, so that nested (smaller) clades override
    # the color of the clade that contains them.
    node_colors = {}
    branch_colors = {}
    for node in sorted(
        descendants, key=lambda x: len(descendants[x]), reverse=True
    ):
        color = clade_colors[node]
        for n1, n2 in tree.depth_first_traverse_edges(node):
            node_colors[n1] = node_colors[n2] = color
            branch_colors[(n1, n2)] = color
    return node_colors, branch_colors
def fitch_hartigan_bottom_up(
    cassiopeia_tree: CassiopeiaTree,
    meta_item: str,
    add_key: str = "S1",
    copy: bool = False,
) -> Optional[CassiopeiaTree]:
    """Performs Fitch-Hartigan bottom-up ancestral reconstruction.

    Performs the bottom-up phase of the Fitch-Hartigan small parsimony
    algorithm. A new attribute, named by `add_key` ("S1" by default), is added
    to each node storing the optimal set of ancestral states inferred from
    this bottom-up algorithm. If copy is False, the tree is modified in place.

    Args:
        cassiopeia_tree: CassiopeiaTree object with cell meta data.
        meta_item: A column in the CassiopeiaTree cell meta corresponding to a
            categorical variable.
        add_key: Key to add for bottom-up reconstruction
        copy: Modify the tree in place or not.

    Returns:
        A new CassiopeiaTree if the copy is set to True, else None.

    Raises:
        CassiopeiaError if the tree does not have the specified meta data
            or the meta data is not categorical.
    """
    if meta_item not in cassiopeia_tree.cell_meta.columns:
        raise CassiopeiaError(
            "Meta item does not exist in the cassiopeia tree")

    meta = cassiopeia_tree.cell_meta[meta_item]

    if is_numeric_dtype(meta):
        raise CassiopeiaError("Meta item is not a categorical variable.")

    # Casting an already-categorical column is a no-op, so cast
    # unconditionally instead of probing the dtype first.
    meta = meta.astype("category")

    cassiopeia_tree = cassiopeia_tree.copy() if copy else cassiopeia_tree

    # Children are visited before their parent, so every child's state set
    # is available when the parent is processed.
    for node in cassiopeia_tree.depth_first_traverse_nodes():
        if cassiopeia_tree.is_leaf(node):
            cassiopeia_tree.set_attribute(node, add_key, [meta.loc[node]])
            continue

        # Pool the candidate states of all children and keep the most
        # frequent ones. With a single child this reduces to that child's
        # own state set, so no special case is needed.
        all_labels = np.concatenate([
            cassiopeia_tree.get_attribute(child, add_key)
            for child in cassiopeia_tree.children(node)
        ])
        states, frequencies = np.unique(all_labels, return_counts=True)
        S1 = states[frequencies == np.max(frequencies)]
        cassiopeia_tree.set_attribute(node, add_key, S1)

    return cassiopeia_tree if copy else None
def _C_fitch_count(
    cassiopeia_tree: CassiopeiaTree,
    N: np.ndarray,
    unique_states: List[str],
    node_to_i: Dict[str, int],
    label_to_j: Dict[str, int],
    state_key: str = "S1",
) -> np.ndarray:
    """Fill in the dynamic programming table C for FitchCount.

    Computes C[v, s, s1, s2], the number of transitions from state s1 to
    state s2 in the subtree rooted at v, given that state v takes on the
    state s.

    Args:
        cassiopeia_tree: CassiopeiaTree object
        N: N array computed during FitchCount storing the number of solutions
            below a node v given v takes on state s
        unique_states: The state space that a node can take on
        node_to_i: Helper mapping of each node to a unique integer
        label_to_j: Helper mapping of each unique state in the state space to
            a unique integer
        state_key: Attribute name in the CassiopeiaTree storing the possible
            states for each node, as inferred with the Fitch-Hartigan
            algorithm

    Returns:
        A 4-dimensional float array storing C[v, s, s1, s2] - the number of
            transitions from state s1 to s2 below a node v given v takes on
            the state s.
    """

    def _fill(v: str, s: str, s1: str, s2: str) -> float:
        """Helper function to fill in a single entry in C."""
        if cassiopeia_tree.is_leaf(v):
            return 0

        children = cassiopeia_tree.children(v)
        j1, j2 = label_to_j[s1], label_to_j[s2]

        # A[i]: s1 -> s2 transitions contributed through child i.
        # LS[i]: states child i may take given that v takes state s.
        A = np.zeros(len(children))
        LS = []
        for i, u in enumerate(children):
            child_states = cassiopeia_tree.get_attribute(u, state_key)
            legal_states = [s] if s in child_states else child_states
            LS.append(legal_states)

            A[i] = np.sum([
                C[node_to_i[u], label_to_j[sp], j1, j2] for sp in legal_states
            ])

            # The edge (v, u) is itself an s1 -> s2 transition in every
            # solution where v == s1 and u takes state s2.
            if s1 == s and s2 in legal_states:
                A[i] += N[node_to_i[u], j2]

        # Number of solutions below each child, computed once instead of
        # being re-accumulated for every sibling pairing.
        solutions_below = [
            sum(N[node_to_i[u], label_to_j[sp]] for sp in legal_states)
            for u, legal_states in zip(children, LS)
        ]

        parts = []
        for i in range(len(children)):
            # Transitions through child i are counted once for every
            # combination of solutions in the other children.
            prod = 1
            for k, fact in enumerate(solutions_below):
                if k != i:
                    prod *= fact
            parts.append(A[i] * prod)

        return np.sum(parts)

    num_states = N.shape[1]
    C = np.zeros(
        (len(cassiopeia_tree.nodes), num_states, num_states, num_states))

    # `_fill` reads the entries of a node's children, so children must be
    # visited before their parent.
    for n in cassiopeia_tree.depth_first_traverse_nodes():
        for s in cassiopeia_tree.get_attribute(n, state_key):
            for s1, s2 in itertools.product(unique_states, repeat=2):
                C[node_to_i[n], label_to_j[s], label_to_j[s1],
                  label_to_j[s2]] = _fill(n, s, s1, s2)

    return C
def place_tree(
    tree: CassiopeiaTree,
    depth_key: Optional[str] = None,
    orient: Union[Literal["down", "up", "left", "right"], float] = "down",
    depth_scale: float = 1.0,
    width_scale: float = 1.0,
    extend_branches: bool = True,
    angled_branches: bool = True,
    polar_interpolation_threshold: float = 5.0,
    polar_interpolation_step: float = 1.0,
    add_root: bool = False,
) -> Tuple[
    Dict[str, Tuple[float, float]],
    Dict[Tuple[str, str], Tuple[List[float], List[float]]],
]:
    """Given a tree, computes the coordinates of the nodes and branches.

    This function computes the x and y coordinates of all nodes and branches
    (as lines) to be used for visualization. Several options are provided to
    modify how the elements are placed. This function returns two
    dictionaries, where the first has nodes as keys and an (x, y) tuple as the
    values. Similarly, the second dictionary has (parent, child) tuples as
    keys denoting branches in the tree and a tuple (xs, ys) of two lists
    containing the x and y coordinates to draw the branch.

    This function also provides functionality to place the tree in polar
    coordinates as a circular plot when the `orient` is a number. In this
    case, the returned dictionaries have (thetas, radii) as its elements,
    which are the angles and radii in polar coordinates respectively.

    Note:
        This function only *places* (i.e. computes the x and y coordinates) a
        tree on a coordinate system and does no real plotting.

    Args:
        tree: The CassiopeiaTree to place on the coordinate grid.
        depth_key: The node attribute to use as the depth of the nodes. If
            not provided, the distances from the root is used by calling
            `tree.get_distances`.
        orient: The orientation of the tree. Valid arguments are `left`,
            `right`, `up`, `down` to display a rectangular plot (indicating
            the direction of going from root -> leaves) or any number, in
            which case the tree is placed in polar coordinates with the
            provided number used as an angle offset. Any other string is
            treated as `up`.
        depth_scale: Scale the depth of the tree by this amount. This option
            has no effect in polar coordinates.
        width_scale: Scale the width of the tree by this amount. This option
            has no effect in polar coordinates.
        extend_branches: Extend branch lengths such that the distance from the
            root to every leaf is the same. If `depth_key` is also provided,
            then only the leaf branches are extended to the deepest leaf.
        angled_branches: Display branches as angled, instead of as just a
            line from the parent to a child.
        polar_interpolation_threshold: When displaying in polar coordinates,
            many plotting frameworks (such as Plotly) just draw a straight
            line from one point to another, instead of scaling the radius
            appropriately. This effect is most noticeable when two points are
            connected that have a large angle between them. When the angle
            between two connected points in polar coordinates exceeds this
            amount (in degrees), this function adds additional points that
            smoothly interpolate between the two endpoints.
        polar_interpolation_step: Interpolation step. See above.
        add_root: Add a root node so that only one branch connects to the
            start of the tree. This node will have the name `synthetic_root`.

    Returns:
        Two dictionaries, where the first contains the node coordinates and
            the second contains the branch coordinates.
    """
    root = tree.root
    nodes = tree.nodes.copy()
    edges = tree.edges.copy()

    if depth_key:
        depths = {
            node: tree.get_attribute(node, depth_key) for node in tree.nodes
        }
    else:
        depths = tree.get_distances(root)

    # Depth at which extended leaves are drawn; computed once rather than
    # once per leaf.
    if extend_branches:
        extended_leaf_depth = max(depths.values()) if depth_key else 0

    # Leaves are spread evenly along the width axis in traversal order.
    placement_depths = {}
    positions = {}
    leaves = set()
    for node in tree.depth_first_traverse_nodes(postorder=False):
        if tree.is_leaf(node):
            positions[node] = len(leaves)
            leaves.add(node)
            placement_depths[node] = (
                extended_leaf_depth if extend_branches else depths[node]
            )

    # Place internal nodes bottom-up. A post-order traversal guarantees every
    # child has been placed before its parent, even when a parent and child
    # share the same depth (which a depth-sorted order does not).
    for node in tree.depth_first_traverse_nodes(postorder=True):
        # Leaves have already been placed
        if node in leaves:
            continue

        # Center the node over its immediate children.
        children = tree.children(node)
        child_positions = [positions[child] for child in children]
        positions[node] = (min(child_positions) + max(child_positions)) / 2

        if extend_branches and not depth_key:
            # One level above the child that is closest to the root.
            placement_depths[node] = (
                min(placement_depths[child] for child in children) - 1
            )
        else:
            placement_depths[node] = depths[node]

    # Add synthetic root
    if add_root:
        root_name = "synthetic_root"
        positions[root_name] = 0
        placement_depths[root_name] = min(placement_depths.values()) - 1
        nodes.append(root_name)
        edges.append((root_name, root))

    polar = isinstance(orient, (float, int))
    # Shift radii so that the shallowest node sits at radius 0.
    polar_depth_offset = -min(placement_depths.values())
    polar_angle_scale = 360 / (len(leaves) + 1)

    # Define some helper functions to modify coordinate system.
    def reorient(pos, depth):
        pos *= width_scale
        depth *= depth_scale
        if orient == "down":
            return (pos, -depth)
        elif orient == "right":
            return (depth, pos)
        elif orient == "left":
            return (-depth, pos)
        # default: up
        return (pos, depth)

    def polarize(pos, depth):
        # angle, radius
        return (
            (pos + 1) * polar_angle_scale + orient,
            depth + polar_depth_offset,
        )

    transform = polarize if polar else reorient

    node_coords = {
        node: transform(positions[node], placement_depths[node])
        for node in nodes
    }

    branch_coords = {}
    for parent, child in edges:
        parent_pos, parent_depth = positions[parent], placement_depths[parent]
        child_pos, child_depth = positions[child], placement_depths[child]

        # Angled branches bend directly above the child; straight branches
        # pass through the midpoint.
        if angled_branches:
            middle_x, middle_y = child_pos, parent_depth
        else:
            middle_x = (parent_pos + child_pos) / 2
            middle_y = (parent_depth + child_depth) / 2
        _xs = [parent_pos, middle_x, child_pos]
        _ys = [parent_depth, middle_y, child_depth]

        xs = []
        ys = []
        for _x, _y in zip(_xs, _ys):
            x, y = transform(_x, _y)
            if polar and xs:
                # Interpolate to prevent weird angles if the angle exceeds
                # threshold.
                prev_x, prev_y = xs[-1], ys[-1]
                if abs(x - prev_x) > polar_interpolation_threshold:
                    num = int(abs(x - prev_x) / polar_interpolation_step)
                    xs.extend(np.linspace(prev_x, x, num)[1:-1])
                    ys.extend(np.linspace(prev_y, y, num)[1:-1])
            xs.append(x)
            ys.append(y)
        branch_coords[(parent, child)] = (xs, ys)

    return node_coords, branch_coords
def compute_expansion_pvalues(
    tree: CassiopeiaTree,
    min_clade_size: int = 10,
    min_depth: int = 1,
    copy: bool = False,
) -> Union[CassiopeiaTree, None]:
    """Call expansion pvalues on a tree.

    Uses the methodology described in Yang, Jones et al, BioRxiv (2021) to
    assess the expansion probability of a given subclade of a phylogeny.
    Mathematical treatment of the coalescent probability is described in
    Griffiths and Tavare, Stochastic Models (1998).

    The probability computed corresponds to the probability that, under a
    simple neutral coalescent model, a given subclade contains the observed
    number of cells; in other words, a one-sided p-value. Often, if the
    probability is less than some threshold (e.g., 0.05), this might indicate
    that there exists some subclade under this node to which this expansion
    probability can be attributed (i.e. the null hypothesis that the subclade
    is undergoing neutral drift can be rejected).

    This function will add an attribute "expansion_pvalue" to every node of
    the tree (1.0 for nodes that are not tested), and return None unless
    :param:`copy` is set to True.

    The number of leaves below every node is gathered in a single post-order
    pass, so subtree leaf sets are never re-enumerated.

    Args:
        tree: CassiopeiaTree
        min_clade_size: Minimum number of leaves in a subtree to be
            considered.
        min_depth: Minimum depth of clade to be considered. Depth is measured
            in number of nodes from the root, not branch lengths.
        copy: Return copy.

    Returns:
        If copy is set to True, a new CassiopeiaTree carrying the attributes.
            Otherwise None; the attributes are added to `tree` in place.
    """
    tree = tree.copy() if copy else tree

    # instantiate attributes
    _depths = {}
    for node in tree.depth_first_traverse_nodes(postorder=False):
        tree.set_attribute(node, "expansion_pvalue", 1.0)
        if tree.is_root(node):
            _depths[node] = 0
        else:
            _depths[node] = _depths[tree.parent(node)] + 1

    # Number of leaves below each node; children are visited before parents.
    leaf_counts = {}
    for node in tree.depth_first_traverse_nodes(postorder=True):
        if tree.is_leaf(node):
            leaf_counts[node] = 1
        else:
            leaf_counts[node] = sum(
                leaf_counts[child] for child in tree.children(node)
            )

    for node in tree.depth_first_traverse_nodes(postorder=False):
        children = tree.children(node)
        n = leaf_counts[node]
        k = len(children)
        for c in children:
            b = leaf_counts[c]
            if b < min_clade_size or _depths[c] < min_depth:
                continue

            # this value below is a simplification of the quantity:
            # sum[simple_coalescent_probability(n, b2, k) for \
            #   b2 in range(b, n - k + 2)]
            p = nCk(n - b, k - 1) / nCk(n - 1, k - 1)

            tree.set_attribute(c, "expansion_pvalue", p)

    return tree if copy else None
def overlay_data(self, tree: CassiopeiaTree):
    """Overlays Cas9-based lineage tracing data onto the CassiopeiaTree.

    Simulates character states for every node from the root downwards and
    stores the result on the tree with `set_all_character_states`.

    Args:
        tree: Input CassiopeiaTree
    """
    if self.random_seed is not None:
        np.random.seed(self.random_seed)

    # Draw the state priors on first use. They are stored on the instance and
    # reused by every later simulation.
    if self.mutation_priors is None:
        self.mutation_priors = {}
        weights = [
            self.state_generating_distribution()
            for _ in range(self.number_of_states)
        ]
        normalizer = np.sum(weights)
        for state, weight in enumerate(weights, start=1):
            self.mutation_priors[state] = weight / normalizer

    number_of_characters = self.number_of_cassettes * self.size_of_cassette

    # Every node starts out with a placeholder array of -1 entries.
    character_matrix = {
        node: [-1] * number_of_characters for node in tree.nodes
    }

    for node in tree.depth_first_traverse_nodes(tree.root, postorder=False):
        if tree.is_root(node):
            # The root is entirely uncut.
            character_matrix[node] = [0] * number_of_characters
            continue

        parent = tree.parent(node)
        elapsed = tree.get_time(node) - tree.get_time(parent)
        states = character_matrix[parent]

        # Only sites that are still uncut (state 0) may mutate on this branch.
        uncut_sites = [
            index for index, state in enumerate(states) if state == 0
        ]

        cut_sites = []
        for site in uncut_sites:
            rate = self.mutation_rate_per_character[site]
            cut_probability = 1 - np.exp(-elapsed * rate)
            if np.random.uniform() < cut_probability:
                cut_sites.append(site)

        # collapse cuts that are on the same cassette
        if self.collapse_sites_on_cassette and self.size_of_cassette > 1:
            states, remaining_cuts = self.collapse_sites(states, cut_sites)
        else:
            remaining_cuts = cut_sites

        # introduce new states at cut sites
        states = self.introduce_states(states, remaining_cuts)

        # silence cassettes (heritable missing data)
        heritable_silencing_probability = 1 - np.exp(
            -elapsed * self.heritable_silencing_rate
        )
        states = self.silence_cassettes(
            states,
            heritable_silencing_probability,
            self.heritable_missing_data_state,
        )

        character_matrix[node] = states

    # apply stochastic silencing, which only affects the leaves
    for leaf in tree.leaves:
        character_matrix[leaf] = self.silence_cassettes(
            character_matrix[leaf],
            self.stochastic_silencing_rate,
            self.stochastic_missing_data_state,
        )

    tree.set_all_character_states(character_matrix)
def log_likelihood_of_character(
    tree: CassiopeiaTree,
    character: int,
    use_internal_character_states: bool,
    mutation_probability_function_of_time: Callable[[float], float],
    missing_probability_function_of_time: Callable[[float], float],
    stochastic_missing_probability: float,
    implicit_root_branch_length: float,
) -> float:
    """Calculates the log likelihood of a given character on the tree.

    Calculates the log likelihood of a tree given the states at a given
    character in the leaves using Felsenstein's Pruning Algorithm, which sets
    up a recursive relation between the likelihoods of states at nodes for
    this character. The likelihood L(s, n) at a given state s at a given node
    n is:

    L(s, n) = Π_{n'}(Σ_{s'}(P(s'|s) * L(s', n')))

    for all n' that are children of n, and s' in the state space, with
    P(s'|s) being the transition probability from s to s'. That is, the
    likelihood at a given state at a given node is the product of the
    likelihoods of the states at this character at the children scaled by the
    probability of the current state transitioning to those states. This
    includes the missing state, as specified by
    `tree.missing_state_indicator`. All quantities are accumulated in log
    space.

    We assume here that mutations are irreversible. Once a character mutates
    to a certain state that character cannot mutate again, with the exception
    of the fact that any non-missing state can mutate to a missing state.
    `mutation_probability_function_of_time` is expected to be a function that
    determines the probability of a mutation occurring given an amount of
    time. Likewise, `missing_probability_function_of_time` determines the
    probability of a missing data event occurring given an amount of time.
    The transition probabilities themselves are obtained from
    `log_transition_probability`.

    The user can choose to use the character states annotated at internal
    nodes. If these are not used, then the likelihood is marginalized over
    all possible internal state characters. In that case the root is assumed
    to have the unmutated state at each character, and it is assumed that
    there is a single branch leading from the root that represents the root's
    lifetime. If the root does not have exactly one child, this branch is
    imposed with length `implicit_root_branch_length`.

    Args:
        tree: The tree on which to calculate the likelihood
        character: The index of the character to calculate the likelihood of
        use_internal_character_states: Indicates if internal node
            character states should be assumed to be specified exactly
        mutation_probability_function_of_time: The function defining the
            probability of a lineage acquiring a mutation within a given time
        missing_probability_function_of_time: The function defining the
            probability of a lineage acquiring heritable missing data within a
            given time
        stochastic_missing_probability: The probability that a cell/character
            pair acquires stochastic missing data at the end of the lineage
        implicit_root_branch_length: The length of the implicit root branch.
            Used if the implicit root needs to be added

    Returns:
        The log likelihood of the tree on one character
    """
    missing = tree.missing_state_indicator

    # node -> {state: log likelihood}. Only states with non-zero likelihood
    # are stored for a node.
    log_likelihoods = {}

    # Post-order traversal: children are resolved before their parent.
    for node in tree.depth_first_traverse_nodes(postorder=True):
        node_states = tree.get_character_states(node)

        # An observed leaf state has likelihood 1, i.e. log likelihood 0.
        if tree.is_leaf(node):
            log_likelihoods[node] = {node_states[character]: 0}
            continue

        if use_internal_character_states:
            # Trust the annotation; every other state is ignored.
            candidate_states = [node_states[character]]
        else:
            # Marginalize only over states that do not break
            # irreversibility, as all others have likelihood 0.
            informative_child_states = []
            for child in tree.children(node):
                child_states = set(log_likelihoods[child])
                if missing not in child_states and "&" not in child_states:
                    informative_child_states.append(child_states)

            if not informative_child_states:
                # "&" stands in for any non-missing state (including uncut),
                # and is a possible state when all children are missing, as
                # any state could have occurred at the parent if all missing
                # data events occurred independently. Used to avoid
                # marginalizing over the entire state space.
                candidate_states = ["&", missing]
            else:
                candidate_states = list(
                    set.intersection(*informative_child_states)
                )
                if 0 not in candidate_states:
                    candidate_states.append(0)

        # Apply the recurrence: for each candidate state, sum over children
        # the log of the marginal over that child's possible states.
        node_log_likelihoods = {}
        for s in candidate_states:
            total = 0
            for child in tree.children(node):
                child_is_leaf = tree.is_leaf(child)
                terms = []
                for s_, child_log_likelihood in log_likelihoods[child].items():
                    term = (
                        log_transition_probability(
                            tree,
                            character,
                            s,
                            s_,
                            tree.get_branch_length(node, child),
                            mutation_probability_function_of_time,
                            missing_probability_function_of_time,
                        )
                        + child_log_likelihood
                    )

                    # Account for stochastic missing data, which only
                    # applies on branches leading to leaves.
                    if child_is_leaf:
                        if s_ != missing:
                            term = term + np.log(
                                1 - stochastic_missing_probability
                            )
                        elif s != missing:
                            heritable_missing = (
                                missing_probability_function_of_time(
                                    tree.get_branch_length(node, child)
                                )
                            )
                            term = np.log(
                                np.exp(term)
                                + (1 - heritable_missing)
                                * stochastic_missing_probability
                            )

                    terms.append(term)

                total += scipy.special.logsumexp(np.array(terms))
            node_log_likelihoods[s] = total

        log_likelihoods[node] = node_log_likelihoods

    # With explicit internal annotations, the answer is the likelihood of the
    # state annotated at the root.
    if use_internal_character_states:
        return list(log_likelihoods[tree.root].values())[0]

    # Otherwise an implicit root in the uncut state (0) is assumed. If the
    # root does not already have a single child, impose the implicit root
    # above it and marginalize over the transition from 0 to every root
    # state.
    if len(tree.children(tree.root)) != 1:
        root_contributions = [
            log_transition_probability(
                tree,
                character,
                0,
                s_,
                implicit_root_branch_length,
                mutation_probability_function_of_time,
                missing_probability_function_of_time,
            )
            + log_likelihoods[tree.root][s_]
            for s_ in log_likelihoods[tree.root]
        ]
        return scipy.special.logsumexp(root_contributions)

    # The existing root is the implicit root. When every leaf is missing the
    # root holds "&" in place of 0; "&" is a placeholder for an unknown state
    # and is realized here as 0.
    if 0 not in log_likelihoods[tree.root]:
        return log_likelihoods[tree.root]["&"]
    return log_likelihoods[tree.root][0]