Exemplo n.º 1
0
def annotate_tree_depths(tree: CassiopeiaTree) -> Dict[int, List[str]]:
    """Annotates tree depth and triplet counts at every node.

    Adds two attributes to every node of the tree, modifying it in place:

    * ``"depth"``: number of edges between the node and the root.
    * ``"number_of_triplets"``: number of leaf triplets rooted at the node,
      i.e. triples of leaves below the node that do not all lie within the
      subtree of a single child of the node.

    Args:
        tree: A CassiopeiaTree.

    Returns:
        A dictionary mapping depth to the list of nodes at that depth, with
        nodes listed in pre-order traversal order.
    """

    depth_to_nodes = defaultdict(list)
    # Pre-order guarantees a parent's depth is set before its children read it.
    for n in tree.depth_first_traverse_nodes(source=tree.root,
                                             postorder=False):
        if tree.is_root(n):
            depth = 0
        else:
            depth = tree.get_attribute(tree.parent(n), "depth") + 1
        tree.set_attribute(n, "depth", depth)
        depth_to_nodes[depth].append(n)

        # Compute each child's leaf count once (it was computed twice before).
        child_leaf_counts = [
            len(tree.leaves_in_subtree(child)) for child in tree.children(n)
        ]
        number_of_leaves = sum(child_leaf_counts)
        # Triples lying entirely inside one child's subtree are rooted deeper,
        # so they are subtracted from all triples of leaves below n.
        correction = sum(nCr(count, 3) for count in child_leaf_counts)

        tree.set_attribute(n, "number_of_triplets",
                           nCr(number_of_leaves, 3) - correction)

    return depth_to_nodes
Exemplo n.º 2
0
def fitch_hartigan_top_down(
    cassiopeia_tree: CassiopeiaTree,
    root: Optional[str] = None,
    state_key: str = "S1",
    label_key: str = "label",
    copy: bool = False,
) -> Optional[CassiopeiaTree]:
    """Run Fitch-Hartigan top-down refinement

    Walks the subtree below ``root`` in pre-order and picks a single label
    for every node from the candidate states stored under ``state_key`` by
    the bottom-up pass. The start node draws uniformly at random from its
    candidates. Every other node inherits its parent's label when that
    label is among its own candidates, and otherwise draws uniformly at
    random from them.

    Args:
        cassiopeia_tree: CassiopeiaTree that has been processed with the
            Fitch-Hartigan bottom-up algorithm.
        root: Node at which to start the refinement; defaults to the tree's
            root. Only the subtree below this node is labeled.
        state_key: Attribute key that stores the Fitch-Hartigan ancestral
            states.
        label_key: Attribute key under which the chosen maximum-parsimony
            label is stored.
        copy: If True, label a copy of the tree instead of the input.

    Returns:
        The labeled copy if ``copy`` is True, otherwise None.

    Raises:
        Whatever ``get_attribute`` raises when ``state_key`` (or a parent's
        ``label_key``) is missing on a node; presumably a
        CassiopeiaTreeError (confirm against CassiopeiaTree).
    """
    start = cassiopeia_tree.root if root is None else root
    tree = cassiopeia_tree.copy() if copy else cassiopeia_tree

    traversal = tree.depth_first_traverse_nodes(source=start, postorder=False)
    for node in traversal:
        if node == start:
            start_candidates = tree.get_attribute(start, state_key)
            tree.set_attribute(start, label_key,
                               np.random.choice(start_candidates))
            continue

        inherited = tree.get_attribute(tree.parent(node), label_key)
        candidates = tree.get_attribute(node, state_key)
        # Only draw a random state when the parent's label is not allowed.
        chosen = (inherited if inherited in candidates else
                  np.random.choice(candidates))
        tree.set_attribute(node, label_key, chosen)

    return tree if copy else None
Exemplo n.º 3
0
def _N_fitch_count(
    cassiopeia_tree: CassiopeiaTree,
    unique_states: List[str],
    node_to_i: Dict[str, int],
    label_to_j: Dict[str, int],
    state_key: str = "S1",
) -> np.ndarray:
    """Fill in the dynamic programming table N for FitchCount.

    Computes N[v, s], corresponding to the number of solutions below
    a node v in the tree given v takes on the state s. Leaves have exactly
    one solution. For an internal node, N[v, s] is the product over the
    children u of the summed N[u, s'] over the states s' that u may take
    given v takes s: just ``[s]`` if s is one of u's Fitch states, else all
    of u's Fitch states.

    Args:
        cassiopeia_tree: CassiopeiaTree object
        unique_states: The state space that a node can take on
        node_to_i: Helper array storing a mapping of each node to a unique
            integer
        label_to_j: Helper array storing a mapping of each unique state in the
            state space to a unique integer
        state_key: Attribute name in the CassiopeiaTree storing the possible
            states for each node, as inferred with the Fitch-Hartigan algorithm

    Returns:
        A 2-dimensional float array storing N[v, s] - the number of
            equally-parsimonious solutions below node v, given v takes on
            state s. Entries for states outside a node's Fitch set stay 0.
    """

    def _fill(v: str, s: str) -> float:
        """Helper function to fill in a single entry in N."""

        if cassiopeia_tree.is_leaf(v):
            return 1

        children = cassiopeia_tree.children(v)
        A = np.zeros(len(children))

        for i, u in enumerate(children):
            child_states = cassiopeia_tree.get_attribute(u, state_key)
            # Staying in s is the only parsimonious choice when u allows it.
            legal_states = [s] if s in child_states else child_states
            A[i] = np.sum(
                [N[node_to_i[u], label_to_j[sp]] for sp in legal_states])

        return np.prod(A)

    N = np.full((len(cassiopeia_tree.nodes), len(unique_states)), 0.0)
    # The default traversal is post-order (children first), which the
    # recurrence in _fill relies on.
    for n in cassiopeia_tree.depth_first_traverse_nodes():
        for s in cassiopeia_tree.get_attribute(n, state_key):
            N[node_to_i[n], label_to_j[s]] = _fill(n, s)

    return N
Exemplo n.º 4
0
def create_clade_colors(
    tree: CassiopeiaTree, clade_colors: Dict[str, Tuple[float, float, float]]
) -> Tuple[Dict[str, Tuple[float, float, float]], Dict[Tuple[str, str], Tuple[
        float, float, float]], ]:
    """Assign colors to nodes and branches by clade.

    Each key of ``clade_colors`` is the root of a clade. The clade root,
    every node reached by ``tree.depth_first_traverse_edges`` from it, and
    every such branch receive the clade's color. Clades are processed from
    largest to smallest, so when clades overlap the smaller (nested) clade's
    color wins. A PlottingWarning is emitted if any clades overlap.

    Args:
        tree: The CassiopeiaTree.
        clade_colors: Dictionary containing internal node-color mappings. These
            colors will be used to color all the paths from this node to the
            leaves the provided color. May be empty.

    Returns:
        Two dictionaries. The first contains the node colors, and the second
            contains the branch colors.
    """
    # Deal with clade colors.
    descendants = {
        node: set(tree.depth_first_traverse_nodes(node))
        for node in clade_colors
    }

    # Clades overlap iff their union is smaller than the sum of their sizes.
    # set().union(...) (rather than set.union(...)) keeps an empty
    # `clade_colors` from raising TypeError.
    union_size = len(set().union(*descendants.values()))
    if union_size != sum(len(d) for d in descendants.values()):
        warnings.warn(
            "Some clades specified with `clade_colors` are overlapping. "
            "Colors may be overridden.",
            PlottingWarning,
        )

    # Color by largest clade first
    node_colors = {}
    branch_colors = {}
    for node in sorted(descendants,
                       key=lambda x: len(descendants[x]),
                       reverse=True):
        color = clade_colors[node]
        # Color the clade root explicitly: a clade rooted at a leaf has no
        # outgoing branches and would otherwise stay uncolored.
        node_colors[node] = color
        for n1, n2 in tree.depth_first_traverse_edges(node):
            node_colors[n1] = node_colors[n2] = color
            branch_colors[(n1, n2)] = color
    return node_colors, branch_colors
Exemplo n.º 5
0
def fitch_hartigan_bottom_up(
    cassiopeia_tree: CassiopeiaTree,
    meta_item: str,
    add_key: str = "S1",
    copy: bool = False,
) -> Optional[CassiopeiaTree]:
    """Performs Fitch-Hartigan bottom-up ancestral reconstruction.

    Performs the bottom-up phase of the Fitch-Hartigan small parsimony
    algorithm. A new attribute named by ``add_key`` (default "S1") is added
    to each node, storing the optimal set of ancestral states inferred from
    this bottom-up algorithm. A leaf's set is its own value of ``meta_item``.
    An internal node's set is the states occurring most often among its
    children's sets (Hartigan's rule). If copy is False, the tree is
    modified in place.

    Args:
        cassiopeia_tree: CassiopeiaTree object with cell meta data.
        meta_item: A column in the CassiopeiaTree cell meta corresponding to a
            categorical variable.
        add_key: Key to add for bottom-up reconstruction
        copy: Modify the tree in place or not.

    Returns:
        A new CassiopeiaTree if the copy is set to True, else None.

    Raises:
        CassiopeiaError if the tree does not have the specified meta data
            or the meta data is not categorical.
    """

    if meta_item not in cassiopeia_tree.cell_meta.columns:
        raise CassiopeiaError(
            "Meta item does not exist in the cassiopeia tree")

    meta = cassiopeia_tree.cell_meta[meta_item]

    if is_numeric_dtype(meta):
        raise CassiopeiaError("Meta item is not a categorical variable.")

    if not is_categorical_dtype(meta):
        meta = meta.astype("category")

    cassiopeia_tree = cassiopeia_tree.copy() if copy else cassiopeia_tree

    # Default traversal is post-order, so children are labeled before parents.
    for node in cassiopeia_tree.depth_first_traverse_nodes():

        if cassiopeia_tree.is_leaf(node):
            cassiopeia_tree.set_attribute(node, add_key, [meta.loc[node]])
            continue

        # A single child needs no special case: every state then has
        # frequency 1, so the node simply inherits the child's set. (The old
        # special case stored a nested list that was immediately overwritten.)
        all_labels = np.concatenate([
            cassiopeia_tree.get_attribute(child, add_key)
            for child in cassiopeia_tree.children(node)
        ])
        states, frequencies = np.unique(all_labels, return_counts=True)

        # Hartigan's rule: keep every state tied for the highest frequency.
        S1 = states[frequencies == np.max(frequencies)]
        cassiopeia_tree.set_attribute(node, add_key, S1)

    return cassiopeia_tree if copy else None
Exemplo n.º 6
0
def _C_fitch_count(
    cassiopeia_tree: CassiopeiaTree,
    N: np.ndarray,
    unique_states: List[str],
    node_to_i: Dict[str, int],
    label_to_j: Dict[str, int],
    state_key: str = "S1",
) -> np.ndarray:
    """Fill in the dynamic programming table C for FitchCount.

    Computes C[v, s, s1, s2], the number of transitions from state s1 to
    state s2 in the subtree rooted at v, given that v takes on the state s.
    C is 0 at leaves. The per-(v, s) quantities (the states each child may
    take and the products of solution counts over sibling subtrees) do not
    depend on (s1, s2). They are computed once per (v, s) instead of once
    per table entry.

    Args:
        cassiopeia_tree: CassiopeiaTree object
        N: N array computed during FitchCount storing the number of solutions
            below a node v given v takes on state s
        unique_states: The state space that a node can take on
        node_to_i: Helper array storing a mapping of each node to a unique
            integer
        label_to_j: Helper array storing a mapping of each unique state in the
            state space to a unique integer
        state_key: Attribute name in the CassiopeiaTree storing the possible
            states for each node, as inferred with the Fitch-Hartigan algorithm

    Returns:
        A 4-dimensional float array storing C[v, s, s1, s2] - the number of
            transitions from state s1 to s2 below a node v given v takes on
            the state s.
    """

    def _legal_child_states(children: List[str], s: str) -> List[List[str]]:
        """States each child may take given its parent takes state s."""
        legal = []
        for u in children:
            child_states = cassiopeia_tree.get_attribute(u, state_key)
            legal.append([s] if s in child_states else child_states)
        return legal

    def _sibling_products(children: List[str], LS: List[List[str]]) -> list:
        """For each child i, the product over the other children k of the
        number of solutions below k (N summed over LS[k])."""
        totals = []
        for up, legal in zip(children, LS):
            total = 0
            for sp in legal:
                total += N[node_to_i[up], label_to_j[sp]]
            totals.append(total)

        products = []
        for i in range(len(children)):
            prod = 1
            for k, total in enumerate(totals):
                if k != i:
                    prod *= total
            products.append(prod)
        return products

    def _fill(s: str, s1: str, s2: str, children: List[str],
              LS: List[List[str]], sibling_products: list) -> float:
        """Helper function to fill in a single entry in C for an internal
        node with the given children."""
        A = np.zeros(len(children))

        for i, u in enumerate(children):
            A[i] = np.sum([
                C[node_to_i[u], label_to_j[sp], label_to_j[s1],
                  label_to_j[s2], ] for sp in LS[i]
            ])

            # The edge (v, u) itself contributes an s -> s2 transition.
            if s1 == s and s2 in LS[i]:
                A[i] += N[node_to_i[u], label_to_j[s2]]

        return np.sum(
            [A[i] * sibling_products[i] for i in range(len(children))])

    C = np.zeros(
        (len(cassiopeia_tree.nodes), N.shape[1], N.shape[1], N.shape[1]))

    state_pairs = list(itertools.product(unique_states, repeat=2))

    # Default traversal is post-order, so children's entries are filled first.
    for n in cassiopeia_tree.depth_first_traverse_nodes():
        if cassiopeia_tree.is_leaf(n):
            # No transitions below a leaf: its entries stay 0.
            continue

        children = cassiopeia_tree.children(n)
        for s in cassiopeia_tree.get_attribute(n, state_key):
            LS = _legal_child_states(children, s)
            sibling_products = _sibling_products(children, LS)
            for (s1, s2) in state_pairs:
                C[node_to_i[n], label_to_j[s], label_to_j[s1],
                  label_to_j[s2]] = _fill(s, s1, s2, children, LS,
                                          sibling_products)

    return C
Exemplo n.º 7
0
def place_tree(
    tree: CassiopeiaTree,
    depth_key: Optional[str] = None,
    orient: Union[Literal["down", "up", "left", "right"], float] = "down",
    depth_scale: float = 1.0,
    width_scale: float = 1.0,
    extend_branches: bool = True,
    angled_branches: bool = True,
    polar_interpolation_threshold: float = 5.0,
    polar_interpolation_step: float = 1.0,
    add_root: bool = False,
) -> Tuple[Dict[str, Tuple[float, float]], Dict[Tuple[str, str], Tuple[
        List[float], List[float]]], ]:
    """Given a tree, computes the coordinates of the nodes and branches.

    This function computes the x and y coordinates of all nodes and branches (as
    lines) to be used for visualization. Several options are provided to
    modify how the elements are placed. This function returns two dictionaries,
    where the first has nodes as keys and an (x, y) tuple as the values.
    Similarly, the second dictionary has (parent, child) tuples as keys denoting
    branches in the tree and a tuple (xs, ys) of two lists containing the x and
    y coordinates to draw the branch.

    This function also provides functionality to place the tree in polar
    coordinates as a circular plot when the `orient` is a number. In this case,
    the returned dictionaries have (thetas, radii) as its elements, which are
    the angles and radii in polar coordinates respectively.

    Leaves are spaced one unit apart in pre-order. Each internal node is
    centered between its outermost children. Internal nodes are placed in
    post-order, so every child is placed before its parent even when
    branches have zero length.

    Note:
        This function only *places* (i.e. computes the x and y coordinates) a
        tree on a coordinate system and does no real plotting.

    Args:
        tree: The CassiopeiaTree to place on the coordinate grid.
        depth_key: The node attribute to use as the depth of the nodes. If
            not provided, the distances from the root is used by calling
            `tree.get_distances`.
        orient: The orientation of the tree. Valid arguments are `left`, `right`,
            `up`, `down` to display a rectangular plot (indicating the direction
            of going from root -> leaves) or any number, in which case the
            tree is placed in polar coordinates with the provided number used
            as an angle offset.
        depth_scale: Scale the depth of the tree by this amount. This option
            has no effect in polar coordinates.
        width_scale: Scale the width of the tree by this amount. This option
            has no effect in polar coordinates.
        extend_branches: Extend branch lengths such that the distance from the
            root to every node is the same. If `depth_key` is also provided, then
            only the leaf branches are extended to the deepest leaf.
        angled_branches: Display branches as angled, instead of as just a
            line from the parent to a child.
        polar_interpolation_threshold: When displaying in polar coordinates,
            many plotting frameworks (such as Plotly) just draw a straight line
            from one point to another, instead of scaling the radius appropriately.
            This effect is most noticeable when two points are connected that
            have a large angle between them. When the angle between two connected
            points in polar coordinates exceeds this amount (in degrees), this
            function adds additional points that smoothly interpolate between
            the two endpoints.
        polar_interpolation_step: Interpolation step. See above.
        add_root: Add a root node so that only one branch connects to the
            start of the tree. This node will have the name `synthetic_root`.

    Returns:
        Two dictionaries, where the first contains the node coordinates and
            the second contains the branch coordinates.
    """
    root = tree.root
    nodes = tree.nodes.copy()
    edges = tree.edges.copy()
    if depth_key:
        depths = {
            node: tree.get_attribute(node, depth_key)
            for node in tree.nodes
        }
    else:
        depths = tree.get_distances(root)

    # Depth at which leaves are drawn when extending branches. Hoisted out of
    # the leaf loop: recomputing max() for every leaf was O(n^2).
    leaf_extension_depth = None
    if extend_branches:
        leaf_extension_depth = max(depths.values()) if depth_key else 0

    placement_depths = {}
    positions = {}
    leaf_i = 0
    leaves = set()
    for node in tree.depth_first_traverse_nodes(postorder=False):
        if tree.is_leaf(node):
            positions[node] = leaf_i
            leaf_i += 1
            leaves.add(node)
            placement_depths[node] = (leaf_extension_depth
                                      if extend_branches else depths[node])

    # Place internal nodes bottom-up. Post-order guarantees every child is
    # placed before its parent. Sorting by depth did not guarantee this when a
    # branch had zero length (parent and child tie), which raised KeyError.
    for node in tree.depth_first_traverse_nodes(postorder=True):
        # Leaves have already been placed
        if node in leaves:
            continue

        # Center the node between its outermost children.
        children = tree.children(node)
        child_positions = [positions[child] for child in children]
        positions[node] = (min(child_positions) + max(child_positions)) / 2
        if extend_branches and not depth_key:
            placement_depths[node] = min(
                placement_depths[child] for child in children) - 1
        else:
            placement_depths[node] = depths[node]

    # Add synthetic root
    if add_root:
        root_name = "synthetic_root"
        positions[root_name] = 0
        placement_depths[root_name] = min(placement_depths.values()) - 1
        nodes.append(root_name)
        edges.append((root_name, root))

    polar = isinstance(orient, (float, int))
    polar_depth_offset = -min(placement_depths.values())
    polar_angle_scale = 360 / (len(leaves) + 1)

    # Define some helper functions to modify coordinate system.
    def reorient(pos, depth):
        pos *= width_scale
        depth *= depth_scale
        if orient == "down":
            return (pos, -depth)
        elif orient == "right":
            return (depth, pos)
        elif orient == "left":
            return (-depth, pos)
        # default: up
        return (pos, depth)

    def polarize(pos, depth):
        # angle, radius
        return (
            (pos + 1) * polar_angle_scale + orient,
            depth + polar_depth_offset,
        )

    node_coords = {}
    for node in nodes:
        pos = positions[node]
        depth = placement_depths[node]
        coords = polarize(pos, depth) if polar else reorient(pos, depth)
        node_coords[node] = coords

    branch_coords = {}
    for parent, child in edges:
        parent_pos, parent_depth = positions[parent], placement_depths[parent]
        child_pos, child_depth = positions[child], placement_depths[child]

        middle_x = (child_pos if angled_branches else
                    (parent_pos + child_pos) / 2)
        middle_y = (parent_depth if angled_branches else
                    (parent_depth + child_depth) / 2)
        _xs = [parent_pos, middle_x, child_pos]
        _ys = [parent_depth, middle_y, child_depth]

        xs = []
        ys = []
        for _x, _y in zip(_xs, _ys):
            if polar:
                x, y = polarize(_x, _y)

                if xs:
                    # Interpolate to prevent weird angles if the angle exceeds threshold.
                    prev_x = xs[-1]
                    prev_y = ys[-1]
                    if abs(x - prev_x) > polar_interpolation_threshold:
                        num = int(abs(x - prev_x) / polar_interpolation_step)
                        for inter_x, inter_y in zip(
                                np.linspace(prev_x, x, num)[1:-1],
                                np.linspace(prev_y, y, num)[1:-1],
                        ):
                            xs.append(inter_x)
                            ys.append(inter_y)
            else:
                x, y = reorient(_x, _y)
            xs.append(x)
            ys.append(y)

        branch_coords[(parent, child)] = (xs, ys)

    return node_coords, branch_coords
Exemplo n.º 8
0
def compute_expansion_pvalues(
    tree: CassiopeiaTree,
    min_clade_size: int = 10,
    min_depth: int = 1,
    copy: bool = False,
) -> Union[CassiopeiaTree, None]:
    """Call expansion pvalues on a tree.

    Uses the methodology described in Yang, Jones et al, BioRxiv (2021) to
    assess the expansion probability of a given subclade of a phylogeny.
    Mathematical treatment of the coalescent probability is described in
    Griffiths and Tavare, Stochastic Models (1998).

    The probability computed corresponds to the probability that, under a simple
    neutral coalescent model, a given subclade contains the observed number of
    cells; in other words, a one-sided p-value. Often, if the probability is
    less than some threshold (e.g., 0.05), this might indicate that there exists
    some subclade under this node to which this expansion probability can
    be attributed (i.e. the null hypothesis that the subclade is undergoing
    neutral drift can be rejected).

    This function sets the attribute "expansion_pvalue" on every node of the
    tree. Nodes that are not tested (clade smaller than ``min_clade_size``,
    shallower than ``min_depth``, or the root) keep the value 1.0.

    Leaf counts for every subtree are computed once in a single post-order
    pass. Apart from the cost of ``nCk``, the running time is therefore
    linear in the number of nodes.

    Args:
        tree: CassiopeiaTree
        min_clade_size: Minimum number of leaves in a subtree to be considered.
        min_depth: Minimum depth of clade to be considered. Depth is measured
            in number of nodes from the root, not branch lengths.
        copy: Return copy.

    Returns:
        If copy is True, a new CassiopeiaTree with the attribute added.
            Otherwise None; the input tree is modified in place.
    """

    tree = tree.copy() if copy else tree

    # Initialize p-values and record each node's depth (edges from the root).
    # Pre-order ensures a parent's depth is known before its children's.
    depths = {}
    for node in tree.depth_first_traverse_nodes(postorder=False):
        tree.set_attribute(node, "expansion_pvalue", 1.0)
        depths[node] = (0 if tree.is_root(node) else
                        depths[tree.parent(node)] + 1)

    # Number of leaves below every node. Post-order visits children first.
    n_leaves = {}
    for node in tree.depth_first_traverse_nodes(postorder=True):
        if tree.is_leaf(node):
            n_leaves[node] = 1
        else:
            n_leaves[node] = sum(n_leaves[c] for c in tree.children(node))

    for node in tree.depth_first_traverse_nodes(postorder=False):
        children = tree.children(node)
        n = n_leaves[node]
        k = len(children)

        for c in children:
            b = n_leaves[c]
            if b < min_clade_size or depths[c] < min_depth:
                continue

            # this value below is a simplification of the quantity:
            # sum[simple_coalescent_probability(n, b2, k) for \
            #   b2 in range(b, n - k + 2)]
            p = nCk(n - b, k - 1) / nCk(n - 1, k - 1)

            tree.set_attribute(c, "expansion_pvalue", p)

    return tree if copy else None
Exemplo n.º 9
0
    def overlay_data(self, tree: CassiopeiaTree):
        """Overlays Cas9-based lineage tracing data onto the CassiopeiaTree.

        Seeds numpy's global RNG if ``self.random_seed`` is set. On first use,
        it draws ``self.number_of_states`` weights from
        ``self.state_generating_distribution`` and normalizes them into
        ``self.mutation_priors``, keyed by states 1..number_of_states. These
        priors are cached on the instance for later calls.

        The tree is then walked in pre-order starting at an all-zero root.
        Along each branch, every uncut (0) site of the parent is cut with
        probability ``1 - exp(-t * rate)``, where ``t`` is the time between
        parent and child. Cuts are optionally collapsed per cassette, new
        states are introduced, and heritable silencing is applied. Finally,
        stochastic silencing is applied to the leaves and the resulting
        character matrix is stored on the tree.

        Args:
            tree: Input CassiopeiaTree
        """
        if self.random_seed is not None:
            np.random.seed(self.random_seed)

        # Lazily create state priors; they are kept on the instance and
        # reused by every subsequent simulation.
        if self.mutation_priors is None:
            self.mutation_priors = {}
            raw_weights = [
                self.state_generating_distribution()
                for _ in range(self.number_of_states)
            ]
            total_weight = np.sum(raw_weights)
            self.mutation_priors.update(
                (state, weight / total_weight)
                for state, weight in enumerate(raw_weights, start=1))

        n_characters = self.number_of_cassettes * self.size_of_cassette

        # Placeholder states; every node is overwritten during the traversal.
        character_matrix = {node: [-1] * n_characters for node in tree.nodes}

        for node in tree.depth_first_traverse_nodes(tree.root, postorder=False):
            if tree.is_root(node):
                character_matrix[node] = [0] * n_characters
                continue

            parent = tree.parent(node)
            branch_time = tree.get_time(node) - tree.get_time(parent)

            states = character_matrix[parent]
            uncut_sites = [
                idx for idx, value in enumerate(states) if value == 0
            ]

            # One uniform draw per uncut site, in site order.
            cuts = []
            for site in uncut_sites:
                rate = self.mutation_rate_per_character[site]
                cut_probability = 1 - (np.exp(-branch_time * rate))
                if np.random.uniform() < cut_probability:
                    cuts.append(site)

            remaining_cuts = cuts
            if self.collapse_sites_on_cassette and self.size_of_cassette > 1:
                states, remaining_cuts = self.collapse_sites(states, cuts)

            states = self.introduce_states(states, remaining_cuts)

            heritable_probability = 1 - (
                np.exp(-branch_time * self.heritable_silencing_rate)
            )
            character_matrix[node] = self.silence_cassettes(
                states,
                heritable_probability,
                self.heritable_missing_data_state,
            )

        # Stochastic (non-heritable) dropout only affects observed leaves.
        for leaf in tree.leaves:
            character_matrix[leaf] = self.silence_cassettes(
                character_matrix[leaf],
                self.stochastic_silencing_rate,
                self.stochastic_missing_data_state,
            )

        tree.set_all_character_states(character_matrix)
Exemplo n.º 10
0
def log_likelihood_of_character(
    tree: CassiopeiaTree,
    character: int,
    use_internal_character_states: bool,
    mutation_probability_function_of_time: Callable[[float], float],
    missing_probability_function_of_time: Callable[[float], float],
    stochastic_missing_probability: float,
    implicit_root_branch_length: float,
) -> float:
    """Calculates the log likelihood of a given character on the tree.

    Calculates the log likelihood of a tree given the states at a given
    character in the leaves using Felsenstein's Pruning Algorithm, which sets
    up a recursive relation between the likelihoods of states at nodes for this
    character. The likelihood L(s, n) at a given state s at a given node n is:

    L(s, n) = Π_{n'}(Σ_{s'}(P(s'|s) * L(s', n')))

    for all n' that are children of n, and s' in the state space, with
    P(s'|s) being the transition probability from s to s'. That is,
    the likelihood at a given state at a given node is the product of
    the likelihoods of the states at this character at the children scaled by
    the probability of the current state transitioning to those states. This
    includes the missing state, as specified by `tree.missing_state_indicator`.
    All quantities are handled in log space: the product over children becomes
    a sum and the sum over child states is computed with
    `scipy.special.logsumexp`.

    We assume here that mutations are irreversible. Once a character mutates to
    a certain state that character cannot mutate again, with the exception of
    the fact that any non-missing state can mutate to a missing state.
    `mutation_probability_function_of_time` is expected to be a function that
    determines the probability of a mutation occurring given an amount of
    time. To determine the probability of acquiring a given (non-missing)
    state once a mutation occurs, the priors of the tree are used. Likewise,
    `missing_probability_function_of_time` determines the probability of a
    missing data event occurring given an amount of time.

    At the leaves, stochastic missing data is also taken into account. An
    observed missing leaf can arise from a heritable missing event on its
    branch or, if no heritable event occurred, from a stochastic missing event
    with probability `stochastic_missing_probability`. An observed non-missing
    leaf is additionally weighted by the probability that no stochastic
    missing event occurred.

    The user can choose to use the character states annotated at internal
    nodes. If these are not used, then the likelihood is marginalized over
    all possible internal state characters. If the actual internal states
    are not provided, then the root is assumed to have the unmutated state
    at each character. Additionally, it is assumed that there is a single
    branch leading from the root that represents the root's lifetime. If
    the root does not have exactly one child and
    `use_internal_character_states` is set to False, then this branch is
    imposed with length `implicit_root_branch_length`.

    Args:
        tree: The tree on which to calculate the likelihood
        character: The index of the character to calculate the likelihood of
        use_internal_character_states: Indicates if internal node
            character states should be assumed to be specified exactly
        mutation_probability_function_of_time: The function defining the
            probability of a lineage acquiring a mutation within a given time
        missing_probability_function_of_time: The function defining the
            probability of a lineage acquiring heritable missing data within a
            given time
        stochastic_missing_probability: The probability that a cell/character
            pair acquires stochastic missing data at the end of the lineage
        implicit_root_branch_length: The length of the implicit root branch.
            Used if the implicit root needs to be added

    Returns:
        The log likelihood of the tree on one character
    """
    missing_state = tree.missing_state_indicator

    # Nested dictionary: each node maps to a dictionary from every state that
    # has non-zero likelihood at that node to the LOG likelihood of that state.
    likelihoods_at_nodes = {}

    # Post-order DFS so that every child is processed before its parent.
    for n in tree.depth_first_traverse_nodes(postorder=True):
        state_at_n = tree.get_character_states(n)
        # Observed leaf states have likelihood 1, i.e. log likelihood 0.
        if tree.is_leaf(n):
            likelihoods_at_nodes[n] = {state_at_n[character]: 0}
            continue

        children = tree.children(n)

        # If internal character states are to be used, then the likelihood
        # for all other states are ignored. Otherwise, marginalize over
        # only states that do not break irreversibility, as all states that
        # do have likelihood of 0.
        if use_internal_character_states:
            possible_states = [state_at_n[character]]
        else:
            # Children that may be missing (or are "&") place no constraint
            # on the parent's non-missing state, so they are excluded.
            child_possible_states = []
            for child in children:
                states = set(likelihoods_at_nodes[child])
                if missing_state not in states and "&" not in states:
                    child_possible_states.append(states)
            # "&" stands in for any non-missing state (including uncut), and
            # is a possible state when all children are missing, as any
            # state could have occurred at the parent if all missing data
            # events occurred independently. Used to avoid marginalizing
            # over the entire state space.
            if not child_possible_states:
                possible_states = ["&", missing_state]
            else:
                possible_states = list(
                    set.intersection(*child_possible_states))
                # The uncut state can always precede any child state.
                if 0 not in possible_states:
                    possible_states.append(0)

        # Quantities that depend only on the child (not on the parent state)
        # are computed once per child instead of once per (state, child).
        child_info = [(
            likelihoods_at_nodes[child],
            tree.get_branch_length(n, child),
            tree.is_leaf(child),
        ) for child in children]

        # This stores the log likelihood of each possible state at n.
        likelihoods_per_state_at_n = {}

        # For each state, marginalize over the likelihoods of the states it
        # could transition to in each daughter node (recurrence above).
        for s in possible_states:
            likelihood_for_s = 0
            for child_likelihoods, branch_length, child_is_leaf in child_info:
                likelihoods_for_s_marginalize_over_s_ = []
                for s_, child_log_likelihood in child_likelihoods.items():
                    likelihood_s_ = (log_transition_probability(
                        tree,
                        character,
                        s,
                        s_,
                        branch_length,
                        mutation_probability_function_of_time,
                        missing_probability_function_of_time,
                    ) + child_log_likelihood)
                    # Account for stochastic missing data at the leaves.
                    if child_is_leaf:
                        if s_ == missing_state and s != missing_state:
                            # Missing leaf: heritable missing event OR (no
                            # heritable event AND a stochastic missing event).
                            likelihood_s_ = np.log(
                                np.exp(likelihood_s_) +
                                (1 - missing_probability_function_of_time(
                                    branch_length)) *
                                stochastic_missing_probability)
                        if s_ != missing_state:
                            # Observed leaf: no stochastic missing event.
                            likelihood_s_ += np.log(
                                1 - stochastic_missing_probability)
                    likelihoods_for_s_marginalize_over_s_.append(likelihood_s_)
                likelihood_for_s += scipy.special.logsumexp(
                    np.array(likelihoods_for_s_marginalize_over_s_))
            likelihoods_per_state_at_n[s] = likelihood_for_s

        likelihoods_at_nodes[n] = likelihoods_per_state_at_n

    root_likelihoods = likelihoods_at_nodes[tree.root]

    # If we use the internal state annotations explicitly, then the root has
    # exactly one (annotated) state; return its likelihood.
    if use_internal_character_states:
        return next(iter(root_likelihoods.values()))

    # Otherwise we assume an implicit root in the uncut state (0). If the
    # tree's root does not have a single child, impose the implicit root with
    # a branch of length `implicit_root_branch_length` and marginalize over
    # the transition from 0 to every possible state at the actual root.
    if len(tree.children(tree.root)) != 1:
        likelihood_contribution_from_each_root_state = [
            log_transition_probability(
                tree,
                character,
                0,
                s_,
                implicit_root_branch_length,
                mutation_probability_function_of_time,
                missing_probability_function_of_time,
            ) + root_log_likelihood
            for s_, root_log_likelihood in root_likelihoods.items()
        ]
        return scipy.special.logsumexp(
            likelihood_contribution_from_each_root_state)

    # The existing root (with a singleton child) is the implicit root. Use the
    # likelihood of its uncut state 0.
    if 0 in root_likelihoods:
        return root_likelihoods[0]
    # Edge case: all leaves are missing, so the root carries "&" instead of 0.
    # By the transition rules for "&" its likelihood equals that of 0, i.e.
    # "&" is realized as 0 here.
    return root_likelihoods["&"]