def annotate_tree_depths(tree: CassiopeiaTree) -> Dict[int, List[str]]:
    """Annotates tree depth and rooted-triplet counts at every node.

    Adds two attributes to every node of the tree, modifying it in place:

    - "depth": the number of edges between the node and the root of the tree.
    - "number_of_triplets": the number of leaf triplets rooted at the node,
      computed as nCr(n, 3) - sum_c nCr(n_c, 3), where n_c is the number of
      leaves below child c and n = sum_c n_c.

    Args:
        tree: The CassiopeiaTree to annotate.

    Returns:
        A dictionary mapping depth to the list of nodes at that depth, with
        nodes listed in pre-order traversal order.
    """

    depth_to_nodes = defaultdict(list)
    # Pre-order traversal: a parent's depth is always set before any of its
    # children are visited.
    for n in tree.depth_first_traverse_nodes(source=tree.root,
                                             postorder=False):
        if tree.is_root(n):
            depth = 0
        else:
            depth = tree.get_attribute(tree.parent(n), "depth") + 1
        tree.set_attribute(n, "depth", depth)
        depth_to_nodes[depth].append(n)

        # Count each child's leaves once and reuse the counts for both sums.
        child_leaf_counts = [
            len(tree.leaves_in_subtree(child)) for child in tree.children(n)
        ]
        number_of_leaves = sum(child_leaf_counts)
        # Triplets lying entirely inside a single child clade are rooted
        # lower in the tree, so they are subtracted out.
        correction = sum(nCr(count, 3) for count in child_leaf_counts)

        tree.set_attribute(n, "number_of_triplets",
                           nCr(number_of_leaves, 3) - correction)

    return depth_to_nodes
def fitch_hartigan_top_down(
    cassiopeia_tree: CassiopeiaTree,
    root: Optional[str] = None,
    state_key: str = "S1",
    label_key: str = "label",
    copy: bool = False,
) -> Optional[CassiopeiaTree]:
    """Run Fitch-Hartigan top-down refinement.

    Walks the subtree below ``root`` in pre-order and gives every node a
    single label taken from its Fitch-Hartigan state set (stored under
    ``state_key``). The starting node draws its label uniformly at random
    from its set. Every other node keeps its parent's label when that label
    is in its own set, and otherwise draws uniformly at random from its set.

    Args:
        cassiopeia_tree: CassiopeiaTree that has been processed with the
            Fitch-Hartigan bottom-up algorithm.
        root: Root from which to begin this refinement. Only the subtree below
            this node will be considered. Defaults to the tree's root.
        state_key: Attribute key that stores the Fitch-Hartigan ancestral
            states.
        label_key: Key to add that stores the maximum-parsimony assignment
            inferred from the Fitch-Hartigan top-down refinement.
        copy: Modify the tree in place or not.

    Returns:
        A new CassiopeiaTree if the copy is set to True, else None.

    Raises:
        A CassiopeiaTreeError if Fitch-Hartigan bottom-up has not been called
        or if the state_key does not exist for a node.
    """

    start = cassiopeia_tree.root if root is None else root
    tree = cassiopeia_tree.copy() if copy else cassiopeia_tree

    # Pre-order traversal: every parent is labeled before its children.
    for node in tree.depth_first_traverse_nodes(source=start,
                                                postorder=False):
        if node != start:
            parent_label = tree.get_attribute(tree.parent(node), label_key)
            candidates = tree.get_attribute(node, state_key)
            # Reusing the parent's label when it is allowed here avoids a
            # state change on the edge into this node.
            chosen = (parent_label if parent_label in candidates else
                      np.random.choice(candidates))
        else:
            chosen = np.random.choice(tree.get_attribute(start, state_key))
        tree.set_attribute(node, label_key, chosen)

    return tree if copy else None
def score_small_parsimony(
    cassiopeia_tree: CassiopeiaTree,
    meta_item: str,
    root: Optional[str] = None,
    infer_ancestral_states: bool = True,
    label_key: Optional[str] = "label",
) -> int:
    """Computes the small-parsimony of the tree.

    Using the meta data stored in the specified cell meta column, compute the
    parsimony score of the tree: the number of edges (below ``root``) whose
    two endpoints carry different labels under ``label_key``. The input tree
    is never modified.

    Args:
        cassiopeia_tree: CassiopeiaTree object with cell meta data.
        meta_item: A column in the CassiopeiaTree cell meta corresponding to a
            categorical variable.
        root: Node to treat as the root. Only the subtree below
            this node will be considered.
        infer_ancestral_states: Whether or not ancestral states must be inferred
            (this will be False if `fitch_hartigan` has already been called on
            the tree.)
        label_key: If ancestral states have already been inferred, this key
            indicates the name of the attribute they're stored in.

    Returns:
        The parsimony score.

    Raises:
        CassiopeiaError if label_key has not been populated.
    """

    if infer_ancestral_states:
        # fitch_hartigan is called for its effect on the tree (its return
        # value is unused), so run it on a copy to leave the caller's tree
        # untouched. Scoring alone only reads attributes and needs no copy.
        cassiopeia_tree = cassiopeia_tree.copy()
        fitch_hartigan(cassiopeia_tree, meta_item, root, label_key=label_key)

    parsimony = 0
    for parent, child in cassiopeia_tree.depth_first_traverse_edges(
            source=root):
        try:
            parent_label = cassiopeia_tree.get_attribute(parent, label_key)
            child_label = cassiopeia_tree.get_attribute(child, label_key)
        except CassiopeiaTreeError as error:
            raise CassiopeiaError(f"{label_key} does not exist for a node, "
                                  "try running Fitch-Hartigan or passing "
                                  "infer_ancestral_states=True.") from error
        if parent_label != child_label:
            parsimony += 1
    return parsimony
def _N_fitch_count(
    cassiopeia_tree: CassiopeiaTree,
    unique_states: List[str],
    node_to_i: Dict[str, int],
    label_to_j: Dict[str, int],
    state_key: str = "S1",
) -> np.ndarray:
    """Fill in the dynamic programming table N for FitchCount.

    Computes N[v, s], corresponding to the number of solutions below
    a node v in the tree given v takes on the state s. Only states in a
    node's ``state_key`` set are filled; all other entries stay 0. A leaf has
    N[v, s] = 1. For an internal node, N[v, s] is the product over its
    children of the summed N entries of each child's legal states.

    Args:
        cassiopeia_tree: CassiopeiaTree object
        unique_states: The state space that a node can take on
        node_to_i: Helper array storing a mapping of each node to a unique
            integer
        label_to_j: Helper array storing a mapping of each unique state in the
            state space to a unique integer
        state_key: Attribute name in the CassiopeiaTree storing the possible
            states for each node, as inferred with the Fitch-Hartigan algorithm

    Returns:
        A 2-dimensional float array of shape
            (number of nodes, number of unique states) storing N[v, s] - the
            number of equally-parsimonious solutions below node v, given v
            takes on state s
    """

    def _fill(v: str, s: str):
        """Computes N[v, s] from the already-filled entries of v's children."""

        if cassiopeia_tree.is_leaf(v):
            return 1

        children = cassiopeia_tree.children(v)
        A = np.zeros(len(children))

        for i, u in enumerate(children):
            child_states = cassiopeia_tree.get_attribute(u, state_key)
            # A child whose optimal set contains s is restricted to s;
            # otherwise it may take any of its optimal states.
            legal_states = [s] if s in child_states else child_states
            A[i] = np.sum(
                [N[node_to_i[u], label_to_j[sp]] for sp in legal_states])
        return np.prod(A)

    N = np.full((len(cassiopeia_tree.nodes), len(unique_states)), 0.0)
    # The default traversal must visit children before parents, because
    # _fill reads the children's rows of N.
    for n in cassiopeia_tree.depth_first_traverse_nodes():
        for s in cassiopeia_tree.get_attribute(n, state_key):
            N[node_to_i[n], label_to_j[s]] = _fill(n, s)

    return N
def fitch_hartigan_bottom_up(
    cassiopeia_tree: CassiopeiaTree,
    meta_item: str,
    add_key: str = "S1",
    copy: bool = False,
) -> Optional[CassiopeiaTree]:
    """Performs Fitch-Hartigan bottom-up ancestral reconstruction.

    Performs the bottom-up phase of the Fitch-Hartigan small parsimony
    algorithm. A new attribute named ``add_key`` (default "S1") is added to
    each node, storing the optimal set of ancestral states inferred from this
    bottom-up algorithm. A leaf gets a one-element list holding its value of
    ``meta_item``. An internal node gets the states that occur in the most
    children's sets. If copy is False, the tree will be modified in place.

    Args:
        cassiopeia_tree: CassiopeiaTree object with cell meta data.
        meta_item: A column in the CassiopeiaTree cell meta corresponding to a
            categorical variable.
        add_key: Key to add for bottom-up reconstruction
        copy: Modify the tree in place or not.

    Returns:
        A new CassiopeiaTree if the copy is set to True, else None.

    Raises:
        CassiopeiaError if the tree does not have the specified meta data
            or the meta data is not categorical.
    """
    import pandas as pd

    if meta_item not in cassiopeia_tree.cell_meta.columns:
        raise CassiopeiaError(
            "Meta item does not exist in the cassiopeia tree")

    meta = cassiopeia_tree.cell_meta[meta_item]

    if is_numeric_dtype(meta):
        raise CassiopeiaError("Meta item is not a categorical variable.")

    # Check the dtype directly: `pandas.api.types.is_categorical_dtype` is
    # deprecated.
    if not isinstance(meta.dtype, pd.CategoricalDtype):
        meta = meta.astype("category")

    cassiopeia_tree = cassiopeia_tree.copy() if copy else cassiopeia_tree

    # The default traversal must visit children before their parent, since
    # every internal node reads its children's sets.
    for node in cassiopeia_tree.depth_first_traverse_nodes():

        if cassiopeia_tree.is_leaf(node):
            cassiopeia_tree.set_attribute(node, add_key, [meta.loc[node]])
            continue

        # A node with a single child needs no special case: every state of
        # that child occurs once, so the child's whole set is kept.
        all_labels = np.concatenate([
            cassiopeia_tree.get_attribute(child, add_key)
            for child in cassiopeia_tree.children(node)
        ])
        states, frequencies = np.unique(all_labels, return_counts=True)

        S1 = states[np.where(frequencies == np.max(frequencies))]
        cassiopeia_tree.set_attribute(node, add_key, S1)

    return cassiopeia_tree if copy else None
def _C_fitch_count(
    cassiopeia_tree: CassiopeiaTree,
    N: np.ndarray,
    unique_states: List[str],
    node_to_i: Dict[str, int],
    label_to_j: Dict[str, int],
    state_key: str = "S1",
) -> np.ndarray:
    """Fill in the dynamic programming table C for FitchCount.

    Computes C[v, s, s1, s2], the number of transitions from state s1 to
    state s2 in the subtree rooted at v, accumulated over the
    equally-parsimonious solutions counted in N, given that v takes on the
    state s. Only states s in a node's ``state_key`` set are filled; all other
    entries stay 0.

    Args:
        cassiopeia_tree: CassiopeiaTree object
        N: N array computed during FitchCount storing the number of solutions
            below a node v given v takes on state s
        unique_states: The state space that a node can take on
        node_to_i: Helper array storing a mapping of each node to a unique
            integer
        label_to_j: Helper array storing a mapping of each unique state in the
            state space to a unique integer
        state_key: Attribute name in the CassiopeiaTree storing the possible
            states for each node, as inferred with the Fitch-Hartigan algorithm

    Returns:
        A 4-dimensional float array of shape (number of nodes, number of
            states, number of states, number of states) storing
            C[v, s, s1, s2] - the number of transitions from state s1 to s2
            below a node v given v takes on the state s.
    """

    def _fill(v: str, s: str, s1: str, s2: str) -> int:
        """Computes C[v, s, s1, s2] from the filled entries of v's children."""

        if cassiopeia_tree.is_leaf(v):
            return 0

        children = cassiopeia_tree.children(v)
        A = np.zeros(len(children))
        LS = []

        for i, u in enumerate(children):
            child_states = cassiopeia_tree.get_attribute(u, state_key)
            # A child whose optimal set contains s is restricted to s;
            # otherwise it may take any of its optimal states.
            legal = [s] if s in child_states else child_states
            LS.append(legal)

            A[i] = np.sum([
                C[node_to_i[u], label_to_j[sp], label_to_j[s1],
                  label_to_j[s2], ] for sp in legal
            ])

            # When v takes s1 (= s) and u takes s2, the edge (v, u) itself is
            # an s1 -> s2 transition, counted once per solution below u.
            if s1 == s and s2 in legal:
                A[i] += N[node_to_i[u], label_to_j[s2]]

        # Number of solutions below each child over its legal states. These
        # totals do not depend on which child is being expanded, so they are
        # computed once here.
        child_solutions = []
        for u, legal in zip(children, LS):
            total = 0
            for sp in legal:
                total += N[node_to_i[u], label_to_j[sp]]
            child_solutions.append(total)

        # Transitions found below child i are multiplied by the number of
        # solutions available below each of its siblings.
        parts = []
        for i in range(len(children)):
            prod = 1
            for k, count in enumerate(child_solutions):
                if k != i:
                    prod *= count
            parts.append(A[i] * prod)

        return np.sum(parts)

    C = np.zeros(
        (len(cassiopeia_tree.nodes), N.shape[1], N.shape[1], N.shape[1]))

    # The default traversal must visit children before parents, because
    # _fill reads the children's entries of C.
    for n in cassiopeia_tree.depth_first_traverse_nodes():
        for s in cassiopeia_tree.get_attribute(n, state_key):
            for (s1, s2) in itertools.product(unique_states, repeat=2):
                C[node_to_i[n], label_to_j[s], label_to_j[s1],
                  label_to_j[s2]] = _fill(n, s, s1, s2)

    return C
Beispiel #7
0
def place_tree(
    tree: CassiopeiaTree,
    depth_key: Optional[str] = None,
    orient: Union[Literal["down", "up", "left", "right"], float] = "down",
    depth_scale: float = 1.0,
    width_scale: float = 1.0,
    extend_branches: bool = True,
    angled_branches: bool = True,
    polar_interpolation_threshold: float = 5.0,
    polar_interpolation_step: float = 1.0,
    add_root: bool = False,
) -> Tuple[Dict[str, Tuple[float, float]], Dict[Tuple[str, str], Tuple[
        List[float], List[float]]], ]:
    """Given a tree, computes the coordinates of the nodes and branches.

    This function computes the x and y coordinates of all nodes and branches (as
    lines) to be used for visualization. Several options are provided to
    modify how the elements are placed. This function returns two dictionaries,
    where the first has nodes as keys and an (x, y) tuple as the values.
    Similarly, the second dictionary has (parent, child) tuples as keys denoting
    branches in the tree and a tuple (xs, ys) of two lists containing the x and
    y coordinates to draw the branch.

    This function also provides functionality to place the tree in polar
    coordinates as a circular plot when the `orient` is a number. In this case,
    the returned dictionaries have (thetas, radii) as its elements, which are
    the angles and radii in polar coordinates respectively.

    Note:
        This function only *places* (i.e. computes the x and y coordinates) a
        tree on a coordinate system and does no real plotting.

    Args:
        tree: The CassiopeiaTree to place on the coordinate grid.
        depth_key: The node attribute to use as the depth of the nodes. If
            not provided, the distances from the root is used by calling
            `tree.get_distances`.
        orient: The orientation of the tree. Valid arguments are `left`, `right`,
            `up`, `down` to display a rectangular plot (indicating the direction
            of going from root -> leaves) or any number, in which case the
            tree is placed in polar coordinates with the provided number used
            as an angle offset.
        depth_scale: Scale the depth of the tree by this amount. This option
            has no effect in polar coordinates.
        width_scale: Scale the width of the tree by this amount. This option
            has no effect in polar coordinates.
        extend_branches: Extend branch lengths such that the distance from the
            root to every node is the same. If `depth_key` is also provided, then
            only the leaf branches are extended to the deepest leaf.
        angled_branches: Display branches as angled, instead of as just a
            line from the parent to a child.
        polar_interpolation_threshold: When displaying in polar coordinates,
            many plotting frameworks (such as Plotly) just draw a straight line
            from one point to another, instead of scaling the radius appropriately.
            This effect is most noticeable when two points are connected that
            have a large angle between them. When the angle between two connected
            points in polar coordinates exceeds this amount (in degrees), this
            function adds additional points that smoothly interpolate between
            the two endpoints.
        polar_interpolation_step: Interpolation step. See above.
        add_root: Add a root node so that only one branch connects to the
            start of the tree. This node will have the name `synthetic_root`
            and is placed directly above the real root.

    Returns:
        Two dictionaries, where the first contains the node coordinates and
            the second contains the branch coordinates.
    """
    root = tree.root
    nodes = tree.nodes.copy()
    edges = tree.edges.copy()
    if depth_key:
        depths = {
            node: tree.get_attribute(node, depth_key)
            for node in tree.nodes
        }
    else:
        depths = tree.get_distances(root)

    # Depth that every leaf is moved to when branches are extended. It is
    # computed once here instead of once per leaf.
    extended_leaf_depth = None
    if extend_branches:
        extended_leaf_depth = max(depths.values()) if depth_key else 0

    # Leaves get consecutive integer positions in pre-order traversal order.
    placement_depths = {}
    positions = {}
    leaf_i = 0
    leaves = set()
    for node in tree.depth_first_traverse_nodes(postorder=False):
        if tree.is_leaf(node):
            positions[node] = leaf_i
            leaf_i += 1
            leaves.add(node)
            placement_depths[node] = (extended_leaf_depth
                                      if extend_branches else depths[node])

    # Center each internal node over its children. A post-order traversal
    # places every child before its parent, even when a child is not strictly
    # deeper than its parent (e.g. zero-length branches).
    for node in tree.depth_first_traverse_nodes(postorder=True):
        if node in leaves:
            continue

        children = tree.children(node)
        child_positions = [positions[child] for child in children]
        positions[node] = (min(child_positions) + max(child_positions)) / 2
        if extend_branches and not depth_key:
            placement_depths[node] = min(
                placement_depths[child] for child in children) - 1
        else:
            placement_depths[node] = depths[node]

    # Add synthetic root
    if add_root:
        root_name = "synthetic_root"
        # Sit directly above the real root so the added branch is a single
        # straight segment instead of an offset, angled one.
        positions[root_name] = positions[root]
        placement_depths[root_name] = min(placement_depths.values()) - 1
        nodes.append(root_name)
        edges.append((root_name, root))

    polar = isinstance(orient, (float, int))
    polar_depth_offset = -min(placement_depths.values())
    polar_angle_scale = 360 / (len(leaves) + 1)

    # Define some helper functions to modify coordinate system.
    def reorient(pos, depth):
        """Scales and rotates a rectangular (position, depth) pair."""
        pos *= width_scale
        depth *= depth_scale
        if orient == "down":
            return (pos, -depth)
        elif orient == "right":
            return (depth, pos)
        elif orient == "left":
            return (-depth, pos)
        # default: up
        return (pos, depth)

    def polarize(pos, depth):
        """Maps (position, depth) to (angle in degrees, radius)."""
        return (
            (pos + 1) * polar_angle_scale + orient,
            depth + polar_depth_offset,
        )

    node_coords = {}
    for node in nodes:
        pos = positions[node]
        depth = placement_depths[node]
        coords = polarize(pos, depth) if polar else reorient(pos, depth)
        node_coords[node] = coords

    branch_coords = {}
    for parent, child in edges:
        parent_pos, parent_depth = positions[parent], placement_depths[parent]
        child_pos, child_depth = positions[child], placement_depths[child]

        middle_x = (child_pos if angled_branches else
                    (parent_pos + child_pos) / 2)
        middle_y = (parent_depth if angled_branches else
                    (parent_depth + child_depth) / 2)
        _xs = [parent_pos, middle_x, child_pos]
        _ys = [parent_depth, middle_y, child_depth]

        xs = []
        ys = []
        for _x, _y in zip(_xs, _ys):
            if polar:
                x, y = polarize(_x, _y)

                if xs:
                    # Interpolate to prevent weird angles if the angle exceeds threshold.
                    prev_x = xs[-1]
                    prev_y = ys[-1]
                    if abs(x - prev_x) > polar_interpolation_threshold:
                        num = int(abs(x - prev_x) / polar_interpolation_step)
                        for inter_x, inter_y in zip(
                                np.linspace(prev_x, x, num)[1:-1],
                                np.linspace(prev_y, y, num)[1:-1],
                        ):
                            xs.append(inter_x)
                            ys.append(inter_y)
            else:
                x, y = reorient(_x, _y)
            xs.append(x)
            ys.append(y)

        branch_coords[(parent, child)] = (xs, ys)

    return node_coords, branch_coords
def sample_triplet_at_depth(
    tree: CassiopeiaTree,
    depth: int,
    depth_to_nodes: Optional[Dict[int, List[str]]] = None,
) -> Tuple[Tuple[str, str, str], str]:
    """Samples a triplet at a given depth.

    Samples a triplet of leaves such that the depth of the LCA of the triplet
    is at the specified depth. Nodes must carry the "depth" and
    "number_of_triplets" attributes, such as those written by
    `annotate_tree_depths`.

    Args:
        tree: CassiopeiaTree
        depth: Depth at which to sample the triplet
        depth_to_nodes: An optional dictionary that maps a depth to the nodes
            that appear at that depth. This speeds up the function considerably.

    Returns:
        A tuple of three leaf names forming the triplet, and the name of the
            outgroup leaf. When the three leaves come from three different
            daughter clades there is no outgroup and the string "None" is
            returned in its place. Otherwise the outgroup is the third leaf of
            the triplet.
    """

    if depth_to_nodes is None:
        candidate_nodes = tree.filter_nodes(
            lambda x: tree.get_attribute(x, "depth") == depth)
    else:
        candidate_nodes = depth_to_nodes[depth]

    # Sample a node from this depth with probability proportional to the
    # number of triplets rooted at it.
    triplet_counts = [
        tree.get_attribute(v, "number_of_triplets") for v in candidate_nodes
    ]
    total_triplets = sum(triplet_counts)
    probs = [count / total_triplets for count in triplet_counts]
    node = np.random.choice(candidate_nodes, size=1, replace=False, p=probs)[0]

    # Leaf sets of each daughter clade, computed once and reused below.
    children = list(tree.children(node))
    clade_leaves = {child: tree.leaves_in_subtree(child) for child in children}

    # Generate the probabilities to sample each combination of 3 daughter clades
    # to sample from, proportional to the number of triplets in each daughter
    # clade. Choices include all ways to choose 3 different daughter clades
    # or 2 from one daughter clade and one from another
    probs = []
    combos = []
    denom = 0
    for (i, j, k) in itertools.combinations_with_replacement(children, 3):

        if i == j and j == k:
            continue

        combos.append((i, j, k))

        size_of_i = len(clade_leaves[i])
        size_of_j = len(clade_leaves[j])
        size_of_k = len(clade_leaves[k])

        if i == j:
            val = nCr(size_of_i, 2) * size_of_k
        elif j == k:
            val = nCr(size_of_j, 2) * size_of_i
        elif i == k:
            val = nCr(size_of_k, 2) * size_of_j
        else:
            val = size_of_i * size_of_j * size_of_k
        probs.append(val)
        denom += val

    probs = [val / denom for val in probs]

    # choose daughter clades
    ind = np.random.choice(range(len(combos)), size=1, replace=False,
                           p=probs)[0]
    (i, j, k) = combos[ind]

    # The two in-group leaves must be distinct, so they are always sampled
    # without replacement.
    if i == j:
        in_group = np.random.choice(clade_leaves[i], 2, replace=False)
        out_group = np.random.choice(clade_leaves[k])
    elif j == k:
        in_group = np.random.choice(clade_leaves[j], 2, replace=False)
        out_group = np.random.choice(clade_leaves[i])
    elif i == k:
        in_group = np.random.choice(clade_leaves[k], 2, replace=False)
        out_group = np.random.choice(clade_leaves[j])
    else:
        return (
            (
                str(np.random.choice(clade_leaves[i])),
                str(np.random.choice(clade_leaves[j])),
                str(np.random.choice(clade_leaves[k])),
            ),
            "None",
        )

    out_group = str(out_group)
    return (str(in_group[0]), str(in_group[1]), out_group), out_group