Exemplo n.º 1
0
def annotate_tree_depths(tree: CassiopeiaTree) -> Dict[int, List[str]]:
    """Annotates depth and triplet counts at every node of the tree.

    Adds two attributes to every node, modifying the tree in place:

    * ``"depth"``: the number of edges between the node and the root (the
      root has depth 0).
    * ``"number_of_triplets"``: how many leaf triplets are rooted at that
      node, i.e. all triplets of leaves below the node minus the triplets
      that fall entirely inside a single child's subtree.

    Args:
        tree: The CassiopeiaTree to annotate.

    Returns:
        A dictionary mapping depth to the list of nodes at that depth.
    """

    depth_to_nodes = defaultdict(list)

    # Pre-order traversal: a parent's "depth" attribute must already be set
    # when its children are visited.
    for node in tree.depth_first_traverse_nodes(source=tree.root,
                                                postorder=False):
        if tree.is_root(node):
            depth = 0
        else:
            depth = tree.get_attribute(tree.parent(node), "depth") + 1

        tree.set_attribute(node, "depth", depth)
        depth_to_nodes[depth].append(node)

        number_of_leaves = 0
        correction = 0
        for child in tree.children(node):
            # Compute the subtree size once; it feeds both running totals.
            child_size = len(tree.leaves_in_subtree(child))
            number_of_leaves += child_size
            # Triplets drawn entirely from one child are rooted below `node`.
            correction += nCr(child_size, 3)

        tree.set_attribute(node, "number_of_triplets",
                           nCr(number_of_leaves, 3) - correction)

    return depth_to_nodes
Exemplo n.º 2
0
def fitch_hartigan_bottom_up(
    cassiopeia_tree: CassiopeiaTree,
    meta_item: str,
    add_key: str = "S1",
    copy: bool = False,
) -> Optional[CassiopeiaTree]:
    """Performs Fitch-Hartigan bottom-up ancestral reconstruction.

    Performs the bottom-up phase of the Fitch-Hartigan small parsimony
    algorithm. An attribute named by `add_key` ("S1" by default) is set on
    every node, storing the optimal set of ancestral states inferred by this
    bottom-up pass: leaves get their observed state, and every internal node
    gets the state(s) that occur most frequently across its children's sets.
    If copy is False, the tree is modified in place.

    Args:
        cassiopeia_tree: CassiopeiaTree object with cell meta data.
        meta_item: A column in the CassiopeiaTree cell meta corresponding to a
            categorical variable.
        add_key: Key to add for bottom-up reconstruction
        copy: If True, annotate and return a copy of the tree instead of
            modifying the tree in place.

    Returns:
        A new CassiopeiaTree if the copy is set to True, else None.

    Raises:
        CassiopeiaError if the tree does not have the specified meta data
            or the meta data is not categorical.
    """

    if meta_item not in cassiopeia_tree.cell_meta.columns:
        raise CassiopeiaError(
            "Meta item does not exist in the cassiopeia tree")

    meta = cassiopeia_tree.cell_meta[meta_item]

    if is_numeric_dtype(meta):
        raise CassiopeiaError("Meta item is not a categorical variable.")

    if not is_categorical_dtype(meta):
        meta = meta.astype("category")

    cassiopeia_tree = cassiopeia_tree.copy() if copy else cassiopeia_tree

    # NOTE(review): relies on the default traversal yielding children before
    # their parent (presumably post-order) - confirm against
    # CassiopeiaTree.depth_first_traverse_nodes.
    for node in cassiopeia_tree.depth_first_traverse_nodes():

        if cassiopeia_tree.is_leaf(node):
            cassiopeia_tree.set_attribute(node, add_key, [meta.loc[node]])
            continue

        # Pool the state sets of all children and keep the most frequent
        # state(s). A node with a single child needs no special case: every
        # state in the child's set occurs exactly once, so the child's whole
        # set is carried up unchanged.
        children = cassiopeia_tree.children(node)
        all_labels = np.concatenate([
            cassiopeia_tree.get_attribute(child, add_key)
            for child in children
        ])
        states, frequencies = np.unique(all_labels, return_counts=True)

        S1 = states[np.where(frequencies == np.max(frequencies))]
        cassiopeia_tree.set_attribute(node, add_key, S1)

    return cassiopeia_tree if copy else None
Exemplo n.º 3
0
def place_tree(
    tree: CassiopeiaTree,
    depth_key: Optional[str] = None,
    orient: Union[Literal["down", "up", "left", "right"], float] = "down",
    depth_scale: float = 1.0,
    width_scale: float = 1.0,
    extend_branches: bool = True,
    angled_branches: bool = True,
    polar_interpolation_threshold: float = 5.0,
    polar_interpolation_step: float = 1.0,
    add_root: bool = False,
) -> Tuple[Dict[str, Tuple[float, float]], Dict[Tuple[str, str], Tuple[
        List[float], List[float]]], ]:
    """Given a tree, computes the coordinates of the nodes and branches.

    This function computes the x and y coordinates of all nodes and branches (as
    lines) to be used for visualization. Several options are provided to
    modify how the elements are placed. This function returns two dictionaries,
    where the first has nodes as keys and an (x, y) tuple as the values.
    Similarly, the second dictionary has (parent, child) tuples as keys denoting
    branches in the tree and a tuple (xs, ys) of two lists containing the x and
    y coordinates to draw the branch.

    This function also provides functionality to place the tree in polar
    coordinates as a circular plot when the `orient` is a number. In this case,
    the returned dictionaries have (thetas, radii) as its elements, which are
    the angles and radii in polar coordinates respectively.

    Note:
        This function only *places* (i.e. computes the x and y coordinates) a
        tree on a coordinate system and does no real plotting.

    Args:
        tree: The CassiopeiaTree to place on the coordinate grid.
        depth_key: The node attribute to use as the depth of the nodes. If
            not provided, the distances from the root is used by calling
            `tree.get_distances`.
        orient: The orientation of the tree. Valid arguments are `left`, `right`,
            `up`, `down` to display a rectangular plot (indicating the direction
            of going from root -> leaves) or any number, in which case the
            tree is placed in polar coordinates with the provided number used
            as an angle offset. Any other string is treated as `up`.
        depth_scale: Scale the depth of the tree by this amount. This option
            has no effect in polar coordinates.
        width_scale: Scale the width of the tree by this amount. This option
            has no effect in polar coordinates.
        extend_branches: Align all leaves at the same depth. Without
            `depth_key`, leaves are placed at depth 0 and every internal node
            one unit above its shallowest child. If `depth_key` is also
            provided, then only the leaf branches are extended, to the largest
            depth found in the tree.
        angled_branches: Display branches as angled, instead of as just a
            line from the parent to a child.
        polar_interpolation_threshold: When displaying in polar coordinates,
            many plotting frameworks (such as Plotly) just draw a straight line
            from one point to another, instead of scaling the radius appropriately.
            This effect is most noticeable when two points are connected that
            have a large angle between them. When the angle between two connected
            points in polar coordinates exceeds this amount (in degrees), this
            function adds additional points that smoothly interpolate between
            the two endpoints.
        polar_interpolation_step: Interpolation step. See above.
        add_root: Add a root node so that only one branch connects to the
            start of the tree. This node will have the name `synthetic_root`.

    Returns:
        Two dictionaries, where the first contains the node coordinates and
            the second contains the branch coordinates.
    """
    root = tree.root
    # Copies, because a synthetic root may be appended below.
    nodes = tree.nodes.copy()
    edges = tree.edges.copy()
    if depth_key:
        depths = {
            node: tree.get_attribute(node, depth_key)
            for node in tree.nodes
        }
    else:
        depths = tree.get_distances(root)

    # Depth shared by every leaf when branches are extended. Computed once
    # here rather than once per leaf.
    extended_leaf_depth = None
    if extend_branches:
        extended_leaf_depth = max(depths.values()) if depth_key else 0

    # Leaves get consecutive integer positions in traversal order.
    placement_depths = {}
    positions = {}
    leaf_i = 0
    leaves = set()
    for node in tree.depth_first_traverse_nodes(postorder=False):
        if tree.is_leaf(node):
            positions[node] = leaf_i
            leaf_i += 1
            leaves.add(node)
            placement_depths[node] = (extended_leaf_depth
                                      if extend_branches else depths[node])

    # Place internal nodes. A post-order traversal guarantees that all
    # children of a node are placed before the node itself, even when parent
    # and child share the same depth (e.g. zero-length branches) or when the
    # `depth_key` attribute does not strictly increase towards the leaves.
    use_unit_depths = extend_branches and not depth_key
    for node in tree.depth_first_traverse_nodes(postorder=True):
        # Leaves have already been placed
        if node in leaves:
            continue

        # Center the node over its immediate children.
        children = tree.children(node)
        child_positions = [positions[child] for child in children]
        positions[node] = (min(child_positions) + max(child_positions)) / 2
        if use_unit_depths:
            placement_depths[node] = min(
                placement_depths[child] for child in children) - 1
        else:
            placement_depths[node] = depths[node]

    # Add synthetic root
    if add_root:
        root_name = "synthetic_root"
        positions[root_name] = 0
        placement_depths[root_name] = min(placement_depths.values()) - 1
        nodes.append(root_name)
        edges.append((root_name, root))

    polar = isinstance(orient, (float, int))
    # Shift radii so that the shallowest node sits at radius 0.
    polar_depth_offset = -min(placement_depths.values())
    # One extra slot, so the first and last leaf do not share an angle.
    polar_angle_scale = 360 / (len(leaves) + 1)

    # Define some helper functions to modify coordinate system.
    def reorient(pos, depth):
        pos *= width_scale
        depth *= depth_scale
        if orient == "down":
            return (pos, -depth)
        elif orient == "right":
            return (depth, pos)
        elif orient == "left":
            return (-depth, pos)
        # default: up
        return (pos, depth)

    def polarize(pos, depth):
        # angle, radius
        return (
            (pos + 1) * polar_angle_scale + orient,
            depth + polar_depth_offset,
        )

    node_coords = {}
    for node in nodes:
        pos = positions[node]
        depth = placement_depths[node]
        coords = polarize(pos, depth) if polar else reorient(pos, depth)
        node_coords[node] = coords

    branch_coords = {}
    for parent, child in edges:
        parent_pos, parent_depth = positions[parent], placement_depths[parent]
        child_pos, child_depth = positions[child], placement_depths[child]

        # Every branch is drawn through three points: parent, an intermediate
        # point and child. Angled branches bend at (child position, parent
        # depth); straight branches use the midpoint of the segment.
        middle_x = (child_pos if angled_branches else
                    (parent_pos + child_pos) / 2)
        middle_y = (parent_depth if angled_branches else
                    (parent_depth + child_depth) / 2)
        _xs = [parent_pos, middle_x, child_pos]
        _ys = [parent_depth, middle_y, child_depth]

        xs = []
        ys = []
        for _x, _y in zip(_xs, _ys):
            if polar:
                x, y = polarize(_x, _y)

                if xs:
                    # Interpolate to prevent weird angles if the angle exceeds threshold.
                    prev_x = xs[-1]
                    prev_y = ys[-1]
                    if abs(x - prev_x) > polar_interpolation_threshold:
                        num = int(abs(x - prev_x) / polar_interpolation_step)
                        # Drop both endpoints: the previous point is already
                        # stored and the current one is appended below.
                        for inter_x, inter_y in zip(
                                np.linspace(prev_x, x, num)[1:-1],
                                np.linspace(prev_y, y, num)[1:-1],
                        ):
                            xs.append(inter_x)
                            ys.append(inter_y)
            else:
                x, y = reorient(_x, _y)
            xs.append(x)
            ys.append(y)

        branch_coords[(parent, child)] = (xs, ys)

    return node_coords, branch_coords
Exemplo n.º 4
0
def compute_expansion_pvalues(
    tree: CassiopeiaTree,
    min_clade_size: int = 10,
    min_depth: int = 1,
    copy: bool = False,
) -> Union[CassiopeiaTree, None]:
    """Call expansion pvalues on a tree.

    Uses the methodology described in Yang, Jones et al, BioRxiv (2021) to
    assess the expansion probability of a given subclade of a phylogeny.
    Mathematical treatment of the coalescent probability is described in
    Griffiths and Tavare, Stochastic Models (1998).

    The probability computed corresponds to the probability that, under a simple
    neutral coalescent model, a given subclade contains the observed number of
    cells; in other words, a one-sided p-value. Often, if the probability is
    less than some threshold (e.g., 0.05), this might indicate that there exists
    some subclade under this node to which this expansion probability can
    be attributed (i.e. the null hypothesis that the subclade is undergoing
    neutral drift can be rejected).

    This function will add an attribute "expansion_pvalue" to every node of
    the tree, and return None unless :param:`copy` is set to True. Nodes that
    are not evaluated (the root, clades with fewer than `min_clade_size`
    leaves, and clades shallower than `min_depth`) keep the default value 1.0.

    On a typical balanced tree, this function will perform in O(n log n) time,
    but can be up to O(n^3) on highly unbalanced trees. A future endeavor may
    be to implement the function in O(n) time.

    Args:
        tree: CassiopeiaTree
        min_clade_size: Minimum number of leaves in a subtree to be considered.
        min_depth: Minimum depth of clade to be considered. Depth is measured
            in number of nodes from the root, not branch lengths.
        copy: If True, annotate and return a copy of the tree instead of
            modifying the tree in place.

    Returns:
        If copy is set to True, a new annotated CassiopeiaTree. Otherwise
            None, with the attributes added to the input tree in place.
    """

    tree = tree.copy() if copy else tree

    # Instantiate the attribute on every node and record node depths. The
    # traversal is pre-order, so a parent's depth is known before its
    # children are visited.
    _depths = {}
    for node in tree.depth_first_traverse_nodes(postorder=False):
        tree.set_attribute(node, "expansion_pvalue", 1.0)

        if tree.is_root(node):
            _depths[node] = 0
        else:
            _depths[node] = _depths[tree.parent(node)] + 1

    for node in tree.depth_first_traverse_nodes(postorder=False):

        children = tree.children(node)
        n = len(tree.leaves_in_subtree(node))
        k = len(children)

        for c in children:

            # Size of the child clade; computed once and used for both the
            # size filter and the p-value.
            b = len(tree.leaves_in_subtree(c))
            if b < min_clade_size:
                continue

            if _depths[c] < min_depth:
                continue

            # this value below is a simplification of the quantity:
            # sum[simple_coalescent_probability(n, b2, k) for \
            #   b2 in range(b, n - k + 2)]
            p = nCk(n - b, k - 1) / nCk(n - 1, k - 1)

            tree.set_attribute(c, "expansion_pvalue", p)

    return tree if copy else None
Exemplo n.º 5
0
def log_likelihood_of_character(
    tree: CassiopeiaTree,
    character: int,
    use_internal_character_states: bool,
    mutation_probability_function_of_time: Callable[[float], float],
    missing_probability_function_of_time: Callable[[float], float],
    stochastic_missing_probability: float,
    implicit_root_branch_length: float,
) -> float:
    """Calculates the log likelihood of a given character on the tree.

    Calculates the log likelihood of a tree given the states at a given
    character in the leaves using Felsenstein's Pruning Algorithm, which sets
    up a recursive relation between the likelihoods of states at nodes for this
    character. The likelihood L(s, n) at a given state s at a given node n is:

    L(s, n) = Π_{n'}(Σ_{s'}(P(s'|s) * L(s', n')))

    for all n' that are children of n, and s' in the state space, with
    P(s'|s) being the transition probability from s to s'. That is,
    the likelihood at a given state at a given node is the product of
    the likelihoods of the states at this character at the children scaled by
    the probability of the current state transitioning to those states. This
    includes the missing state, as specified by `tree.missing_state_indicator`.

    The implementation works in log space throughout: the product over
    children becomes a sum of logs and the sum over child states is evaluated
    with `scipy.special.logsumexp`.

    We assume here that mutations are irreversible. Once a character mutates to
    a certain state that character cannot mutate again, with the exception of
    the fact that any non-missing state can mutate to a missing state.
    `mutation_probability_function_of_time` is expected to be a function that
    determines the probability of a mutation occurring given an amount of time.
    To determine the probability of acquiring a given (non-missing) state once
    a mutation occurs, the priors of the tree are used. Likewise,
    `missing_probability_function_of_time` determines the probability of a
    missing data event occurring given an amount of time.

    The user can choose to use the character states annotated at internal
    nodes. If these are not used, then the likelihood is marginalized over
    all possible internal state characters. If the actual internal states
    are not provided, then the root is assumed to have the unmutated state
    at each character. Additionally, it is assumed that there is a single
    branch leading from the root that represents the roots' lifetime. If
    this branch does not exist (the root does not have exactly one child) and
    `use_internal_character_states` is set to False, then this branch is
    imposed with branch length `implicit_root_branch_length`.

    Args:
        tree: The tree on which to calculate the likelihood
        character: The index of the character to calculate the likelihood of
        use_internal_character_states: Indicates if internal node
            character states should be assumed to be specified exactly
        mutation_probability_function_of_time: The function defining the
            probability of a lineage acquiring a mutation within a given time
        missing_probability_function_of_time: The function defining the
            probability of a lineage acquiring heritable missing data within a
            given time
        stochastic_missing_probability: The probability that a cell/character
            pair acquires stochastic missing data at the end of the lineage
        implicit_root_branch_length: The length of the implicit root branch.
            Used if the implicit root needs to be added

    Returns:
        The log likelihood of the tree on one character
    """

    # This dictionary uses a nested dictionary structure. Each node is mapped
    # to a dictionary storing the log likelihood for each possible state
    # (only states that have non-0 likelihood are stored)
    likelihoods_at_nodes = {}

    # Perform a post-order DFS to propagate the likelihood from the leaves,
    # so that every child is processed before its parent
    for n in tree.depth_first_traverse_nodes(postorder=True):
        state_at_n = tree.get_character_states(n)
        # The observed state at a leaf has likelihood 1, which is stored as
        # log likelihood 0
        if tree.is_leaf(n):
            likelihoods_at_nodes[n] = {state_at_n[character]: 0}
            continue

        possible_states = []
        # If internal character states are to be used, then the likelihood
        # for all other states are ignored. Otherwise, marginalize over
        # only states that do not break irreversibility, as all states that
        # do have likelihood of 0
        if use_internal_character_states:
            possible_states = [state_at_n[character]]
        else:
            # Collect the state sets of the children that carry information,
            # i.e. that contain neither the missing state nor the "&"
            # placeholder
            child_possible_states = []
            for c in [
                    set(likelihoods_at_nodes[child])
                    for child in tree.children(n)
            ]:
                if tree.missing_state_indicator not in c and "&" not in c:
                    child_possible_states.append(c)
            # "&" stands in for any non-missing state (including uncut), and
            # is a possible state when all children are missing, as any
            # state could have occurred at the parent if all missing data
            # events occurred independently. Used to avoid marginalizing
            # over the entire state space.
            if child_possible_states == []:
                possible_states = [
                    "&",
                    tree.missing_state_indicator,
                ]
            else:
                # Only states shared by all informative children can be the
                # parent state; the uncut state (0) is always possible
                possible_states = list(
                    set.intersection(*child_possible_states))
                if 0 not in possible_states:
                    possible_states.append(0)

        # This stores the log likelihood of each possible state at the
        # current node
        likelihoods_per_state_at_n = {}

        # We calculate the likelihood of the states at the current node
        # according to the recurrence relation. For each state, we marginalize
        # over the likelihoods of the states that it could transition to in the
        # daughter nodes
        for s in possible_states:
            # Sum of per-child log likelihoods (a product in linear space)
            likelihood_for_s = 0
            for child in tree.children(n):
                likelihoods_for_s_marginalize_over_s_ = []
                for s_ in likelihoods_at_nodes[child]:
                    likelihood_s_ = (log_transition_probability(
                        tree,
                        character,
                        s,
                        s_,
                        tree.get_branch_length(n, child),
                        mutation_probability_function_of_time,
                        missing_probability_function_of_time,
                    ) + likelihoods_at_nodes[child][s_])
                    # Here we take into account the probability of
                    # stochastic missing data
                    if tree.is_leaf(child):
                        # A missing leaf below a non-missing parent state can
                        # also arise stochastically: add, in linear space, the
                        # probability of no heritable missing event on the
                        # branch times the stochastic missing probability
                        if (s_ == tree.missing_state_indicator
                                and s != tree.missing_state_indicator):
                            likelihood_s_ = np.log(
                                np.exp(likelihood_s_) +
                                (1 - missing_probability_function_of_time(
                                    tree.get_branch_length(n, child))) *
                                stochastic_missing_probability)
                        # An observed (non-missing) leaf state additionally
                        # requires that no stochastic missing event occurred
                        if s_ != tree.missing_state_indicator:
                            likelihood_s_ += np.log(
                                1 - stochastic_missing_probability)
                    likelihoods_for_s_marginalize_over_s_.append(likelihood_s_)
                likelihood_for_s += scipy.special.logsumexp(
                    np.array(likelihoods_for_s_marginalize_over_s_))
            likelihoods_per_state_at_n[s] = likelihood_for_s

        likelihoods_at_nodes[n] = likelihoods_per_state_at_n

    # If we are not to use the internal state annotations explicitly,
    # then we assume an implicit root where each state is the uncut state (0)
    # Thus, we marginalize over the transition from 0 in the implicit root
    # to every state stored for its child
    if not use_internal_character_states:
        # If the implicit root does not exist in the tree, then we impose it,
        # with the length of the branch being specified as
        # `implicit_root_branch_length`. Otherwise, we just use the existing
        # root with a singleton child as the implicit root
        if len(tree.children(tree.root)) != 1:

            likelihood_contribution_from_each_root_state = [
                log_transition_probability(
                    tree,
                    character,
                    0,
                    s_,
                    implicit_root_branch_length,
                    mutation_probability_function_of_time,
                    missing_probability_function_of_time,
                ) + likelihoods_at_nodes[tree.root][s_]
                for s_ in likelihoods_at_nodes[tree.root]
            ]
            likelihood_at_implicit_root = scipy.special.logsumexp(
                likelihood_contribution_from_each_root_state)

            return likelihood_at_implicit_root

        else:
            # Here we account for the edge case in which all of the leaves are
            # missing, in which case the root will have "&" in place of 0. The
            # likelihood at "&" will have the same likelihood as 0 based on the
            # transition rules regarding "&". As "&" is a placeholder when the
            # state is unknown, this can be thought of realizing "&" as 0.
            if 0 not in likelihoods_at_nodes[tree.root]:
                return likelihoods_at_nodes[tree.root]["&"]
            else:
                # Otherwise, we return the likelihood of the 0 state at the
                # existing implicit root
                return likelihoods_at_nodes[tree.root][0]

    # If we use the internal state annotations explicitly, then we return
    # the likelihood of the state annotated at this character at the root
    # (the only entry stored for the root in that case)
    else:
        return list(likelihoods_at_nodes[tree.root].values())[0]
Exemplo n.º 6
0
    def percolate(
        self,
        character_matrix: pd.DataFrame,
        samples: List[str],
        priors: Optional[Dict[int, Dict[int, float]]] = None,
        weights: Optional[Dict[int, Dict[int, float]]] = None,
        missing_state_indicator: int = -1,
    ) -> Tuple[List[str], List[str]]:
        """Partitions the samples in two via percolation of a similarity graph.

        A pairwise similarity graph is built with the samples as nodes; two
        samples are connected when `self.similarity_function` scores the pair
        above `self.threshold`. If no pair passes the threshold, all samples
        are returned on one side. Otherwise the lowest-weight edges are
        removed (all edges tied at that weight together) until the graph has
        more than one connected component. If more than two components
        result, the LCA character vector of each component is computed, a
        tree is built over these LCAs with `self.joining_solver`, and the
        components are grouped according to the first split of that tree
        that is not a unifurcation.

        Args:
            character_matrix: Character matrix
            samples: A list of samples to partition
            priors: A dictionary storing the probability of each character
                mutating to a particular state.
            weights: Weighting of each (character, state) pair. Typically a
                transformation of the priors.
            missing_state_indicator: Character representing missing data.

        Returns:
            The partition groups as lists of sample names (the left and
            right sides of the split)
        """
        sample_indices = solver_utilities.convert_sample_names_to_indices(
            character_matrix.index, samples)
        character_array = character_matrix.to_numpy()

        similarity_graph = nx.Graph()
        similarity_graph.add_nodes_from(sample_indices)

        # Connect every sufficiently similar pair, bucketing the edges by
        # their similarity so that ties can be removed together later
        edges_by_weight = defaultdict(list)
        for left, right in itertools.combinations(sample_indices, 2):
            weight = self.similarity_function(
                character_array[left, :],
                character_array[right, :],
                missing_state_indicator,
                weights,
            )
            if weight > self.threshold:
                edges_by_weight[weight].append((left, right))
                similarity_graph.add_edge(left, right)

        if len(similarity_graph.edges) == 0:
            return samples, []

        components = list(nx.connected_components(similarity_graph))
        # Sorted in descending order, so pop() yields the smallest weight
        descending_weights = sorted(edges_by_weight, reverse=True)

        # Percolate: drop the weakest edges until the graph falls apart
        while len(components) <= 1:
            weakest = descending_weights.pop()
            for u, v in edges_by_weight[weakest]:
                similarity_graph.remove_edge(u, v)
            components = list(nx.connected_components(similarity_graph))

        if len(components) > 2:
            # More than two components: cluster them into groups using a
            # tree built on the LCA characters of each component
            components = [list(component) for component in components]

            lcas = {}
            component_to_nodes = {}
            for ind, members in enumerate(components):
                component_identifier = "component" + str(ind)
                component_to_nodes[component_identifier] = members
                character_vectors = [
                    list(row) for row in character_array[members, :]
                ]
                lcas[component_identifier] = data_utilities.get_lca_characters(
                    character_vectors, missing_state_indicator)

            lca_tree = CassiopeiaTree(
                pd.DataFrame.from_dict(lcas, orient="index"),
                missing_state_indicator=missing_state_indicator,
                priors=priors,
            )
            self.joining_solver.solve(lca_tree,
                                      collapse_mutationless_edges=False)

            # Walk down from the root, skipping unifurcations, and take the
            # first real split as the clusters of components
            current_node = lca_tree.root
            grouped_components = []
            while len(grouped_components) == 0:
                successors = lca_tree.children(current_node)
                if len(successors) == 1:
                    current_node = successors[0]
                else:
                    grouped_components = [
                        lca_tree.leaves_in_subtree(successor)
                        for successor in successors
                    ]

            # Expand every cluster of components into its sample indices
            partition_sides = [
                list(
                    itertools.chain.from_iterable(
                        component_to_nodes[component]
                        for component in cluster))
                for cluster in grouped_components
            ]
        else:
            partition_sides = [list(component) for component in components]

        # Convert from indices back to the sample names in the original
        # character matrix
        sample_names = list(character_matrix.index)
        return [[sample_names[index] for index in side]
                for side in partition_sides]
Exemplo n.º 7
0
def sample_triplet_at_depth(
    tree: CassiopeiaTree,
    depth: int,
    depth_to_nodes: Optional[Dict[int, List[str]]] = None,
) -> Tuple[Tuple[str, str, str], str]:
    """Samples a triplet at a given depth.

    Samples a triplet of leaves such that the depth of the LCA of the triplet
    is at the specified depth. The tree is expected to carry the "depth" and
    "number_of_triplets" node attributes that this function reads.

    Args:
        tree: CassiopeiaTree
        depth: Depth at which to sample the triplet
        depth_to_nodes: An optional dictionary that maps a depth to the nodes
            that appear at that depth. This speeds up the function considerably.

    Returns:
        A tuple of (triplet, outgroup). The triplet is a tuple of three leaf
            names. When two of the leaves come from the same daughter clade,
            they are listed first and the outgroup is the name of the third
            leaf. When all three leaves come from different daughter clades,
            the outgroup is the string "None".
    """

    if depth_to_nodes is None:
        candidate_nodes = tree.filter_nodes(
            lambda x: tree.get_attribute(x, "depth") == depth)
    else:
        candidate_nodes = depth_to_nodes[depth]

    # sample a node from this depth with probability proportional to the
    # number of triplets underneath it
    triplet_counts = [
        tree.get_attribute(v, "number_of_triplets") for v in candidate_nodes
    ]
    total_triplets = sum(triplet_counts)
    probs = [count / total_triplets for count in triplet_counts]
    node = np.random.choice(candidate_nodes, size=1, replace=False, p=probs)[0]

    # Look up the leaves of each daughter clade once; they are needed both
    # for the weights and for the final sampling.
    children = list(tree.children(node))
    leaves_below = {
        child: tree.leaves_in_subtree(child) for child in children
    }

    # Generate the probabilities to sample each combination of 3 daughter
    # clades to sample from, proportional to the number of triplets in each
    # combination. Choices include all ways to choose 3 different daughter
    # clades or 2 from one daughter clade and one from another.
    # combinations_with_replacement emits tuples in input order, so repeated
    # clades are always adjacent: only i == j or j == k can occur (i == k
    # would imply i == j == k, which is skipped).
    weights = []
    combos = []
    denom = 0
    for (i, j, k) in itertools.combinations_with_replacement(children, 3):

        if i == j and j == k:
            continue

        combos.append((i, j, k))

        size_of_i = len(leaves_below[i])
        size_of_j = len(leaves_below[j])
        size_of_k = len(leaves_below[k])

        if i == j:
            val = nCr(size_of_i, 2) * size_of_k
        elif j == k:
            val = nCr(size_of_j, 2) * size_of_i
        else:
            val = size_of_i * size_of_j * size_of_k
        weights.append(val)
        denom += val

    probs = [val / denom for val in weights]

    # choose daughter clades
    ind = np.random.choice(range(len(combos)), size=1, replace=False,
                           p=probs)[0]
    (i, j, k) = combos[ind]

    if i == j:
        in_clade, out_clade = i, k
    elif j == k:
        in_clade, out_clade = j, i
    else:
        # Three different daughter clades: the triplet has no outgroup.
        return (
            (
                str(np.random.choice(leaves_below[i])),
                str(np.random.choice(leaves_below[j])),
                str(np.random.choice(leaves_below[k])),
            ),
            "None",
        )

    # Two distinct leaves from the shared clade, one leaf from the other.
    in_group = np.random.choice(leaves_below[in_clade], 2, replace=False)
    out_group = str(np.random.choice(leaves_below[out_clade]))

    return (str(in_group[0]), str(in_group[1]), out_group), out_group