예제 #1
0
def compute_phylogenetic_weight_matrix(
    tree: CassiopeiaTree,
    inverse: bool = False,
    inverse_fn: Callable[[Union[int, float]], float] = lambda x: 1 / x,
) -> pd.DataFrame:
    """Computes the phylogenetic weight matrix.

    Computes the distances between all pairs of leaves in a tree by calling
    `tree.get_distances(leaf, leaves_only=True)` once per leaf. The user has
    the option to return the inverse matrix (i.e., transform distances to
    proximities) and to specify an appropriate inverse function.

    When `inverse` is set, strictly positive distances are passed through
    `inverse_fn`; non-positive distances are mapped to `np.inf` instead, so
    that the default `1 / x` never divides by zero. The diagonal of the
    returned matrix is always 0, regardless of `inverse`.

    Args:
        tree: CassiopeiaTree
        inverse: Convert distances to proximities
        inverse_fn: Inverse function (default = 1 / x)

    Returns:
        An NxN phylogenetic weight matrix, indexed (rows and columns) by the
        leaves of the tree in the order given by `tree.leaves`.
    """
    # Read the leaves once; they define both the matrix size and the labels.
    leaves = list(tree.leaves)
    index_of = {leaf: i for i, leaf in enumerate(leaves)}

    # Fill a plain NumPy array and wrap it in a DataFrame at the end. Writing
    # through `DataFrame.values` is not reliable: under pandas Copy-on-Write
    # the returned array is read-only, so an in-place `np.fill_diagonal` on it
    # fails. It is also much cheaper than per-cell `.loc` assignment.
    weights = np.zeros((len(leaves), len(leaves)))

    for leaf1 in leaves:
        row = index_of[leaf1]
        distances = tree.get_distances(leaf1, leaves_only=True)
        for leaf2, dist in distances.items():
            if inverse:
                dist = inverse_fn(dist) if dist > 0 else np.inf

            col = index_of[leaf2]
            # Write both triangles so that the matrix stays symmetric.
            weights[row, col] = weights[col, row] = dist

    # A leaf has no weight with respect to itself (this also clears the
    # `np.inf` entries produced for zero self-distances when `inverse` is on).
    np.fill_diagonal(weights, 0)

    return pd.DataFrame(weights, index=leaves, columns=leaves)
예제 #2
0
def place_tree(
    tree: CassiopeiaTree,
    depth_key: Optional[str] = None,
    orient: Union[Literal["down", "up", "left", "right"], float] = "down",
    depth_scale: float = 1.0,
    width_scale: float = 1.0,
    extend_branches: bool = True,
    angled_branches: bool = True,
    polar_interpolation_threshold: float = 5.0,
    polar_interpolation_step: float = 1.0,
    add_root: bool = False,
) -> Tuple[
    Dict[str, Tuple[float, float]],
    Dict[Tuple[str, str], Tuple[List[float], List[float]]],
]:
    """Given a tree, computes the coordinates of the nodes and branches.

    This function computes the x and y coordinates of all nodes and branches (as
    lines) to be used for visualization. Several options are provided to
    modify how the elements are placed. This function returns two dictionaries,
    where the first has nodes as keys and an (x, y) tuple as the values.
    Similarly, the second dictionary has (parent, child) tuples as keys denoting
    branches in the tree and a tuple (xs, ys) of two lists containing the x and
    y coordinates to draw the branch.

    This function also provides functionality to place the tree in polar
    coordinates as a circular plot when the `orient` is a number. In this case,
    the returned dictionaries have (thetas, radii) as its elements, which are
    the angles (in degrees) and radii in polar coordinates respectively.

    Leaves are assigned consecutive integer positions in the order they are
    visited by `tree.depth_first_traverse_nodes(postorder=False)`; every
    internal node is centered between its left-most and right-most child.

    Note:
        This function only *places* (i.e. computes the x and y coordinates) a
        tree on a coordinate system and does no real plotting.

    Args:
        tree: The CassiopeiaTree to place on the coordinate grid.
        depth_key: The node attribute to use as the depth of the nodes. If
            not provided, the distances from the root is used by calling
            `tree.get_distances`.
        orient: The orientation of the tree. Valid arguments are `left`, `right`,
            `up`, `down` to display a rectangular plot (indicating the direction
            of going from root -> leaves) or any number, in which case the
            tree is placed in polar coordinates with the provided number used
            as an angle offset. Any other string is treated as `up`.
        depth_scale: Scale the depth of the tree by this amount. This option
            has no effect in polar coordinates.
        width_scale: Scale the width of the tree by this amount. This option
            has no effect in polar coordinates.
        extend_branches: Extend branch lengths such that all leaves end up at
            the same depth. Without `depth_key`, leaves are placed at depth 0
            and every internal node one unit above its shallowest child. If
            `depth_key` is also provided, then only the leaf branches are
            extended to the deepest node depth.
        angled_branches: Display branches as angled, instead of as just a
            line from the parent to a child.
        polar_interpolation_threshold: When displaying in polar coordinates,
            many plotting frameworks (such as Plotly) just draw a straight line
            from one point to another, instead of scaling the radius appropriately.
            This effect is most noticeable when two points are connected that
            have a large angle between them. When the angle between two connected
            points in polar coordinates exceeds this amount (in degrees), this
            function adds additional points that smoothly interpolate between
            the two endpoints.
        polar_interpolation_step: Interpolation step. See above.
        add_root: Add a root node so that only one branch connects to the
            start of the tree. This node will have the name `synthetic_root`.

    Returns:
        Two dictionaries, where the first contains the node coordinates and
            the second contains the branch coordinates.
    """
    root = tree.root
    # Copies, so that adding the synthetic root never touches the tree itself.
    nodes = tree.nodes.copy()
    edges = tree.edges.copy()
    if depth_key:
        depths = {
            node: tree.get_attribute(node, depth_key)
            for node in tree.nodes
        }
    else:
        depths = tree.get_distances(root)

    # Depth shared by all leaves when branches are extended. Computed once
    # here rather than once per leaf.
    extended_leaf_depth = None
    if extend_branches:
        extended_leaf_depth = max(depths.values()) if depth_key else 0

    placement_depths = {}
    positions = {}
    leaves = set()
    internal_nodes = []
    for node in tree.depth_first_traverse_nodes(postorder=False):
        if tree.is_leaf(node):
            # Leaves get consecutive positions in traversal order.
            positions[node] = len(leaves)
            leaves.add(node)
            placement_depths[node] = (extended_leaf_depth
                                      if extend_branches else depths[node])
        else:
            internal_nodes.append(node)

    # Place internal nodes bottom-up. This relies on the pre-order traversal
    # above yielding every node before its descendants, so the reversed order
    # visits all children before their parent. Ordering by depth value is not
    # sufficient: a parent and an internal child with equal depth (e.g. a
    # zero-length branch) could then be visited in the wrong order.
    for node in reversed(internal_nodes):
        children = tree.children(node)
        child_positions = [positions[child] for child in children]
        # Center the node over the span of its immediate children.
        positions[node] = (min(child_positions) + max(child_positions)) / 2
        if extend_branches and not depth_key:
            # One unit above the shallowest child.
            placement_depths[node] = min(
                placement_depths[child] for child in children) - 1
        else:
            placement_depths[node] = depths[node]

    # Add synthetic root
    if add_root:
        root_name = "synthetic_root"
        positions[root_name] = 0
        placement_depths[root_name] = min(placement_depths.values()) - 1
        nodes.append(root_name)
        edges.append((root_name, root))

    polar = isinstance(orient, (float, int))
    # Shift radii so that the shallowest node sits at radius 0.
    polar_depth_offset = -min(placement_depths.values())
    # One extra slot keeps the first and last leaf from sharing an angle.
    polar_angle_scale = 360 / (len(leaves) + 1)

    # Define some helper functions to modify coordinate system.
    def reorient(pos, depth):
        pos *= width_scale
        depth *= depth_scale
        if orient == "down":
            return (pos, -depth)
        elif orient == "right":
            return (depth, pos)
        elif orient == "left":
            return (-depth, pos)
        # default: up
        return (pos, depth)

    def polarize(pos, depth):
        # angle, radius
        return (
            (pos + 1) * polar_angle_scale + orient,
            depth + polar_depth_offset,
        )

    transform = polarize if polar else reorient

    node_coords = {
        node: transform(positions[node], placement_depths[node])
        for node in nodes
    }

    branch_coords = {}
    for parent, child in edges:
        parent_pos, parent_depth = positions[parent], placement_depths[parent]
        child_pos, child_depth = positions[child], placement_depths[child]

        # Each branch is drawn through three points: parent, middle, child.
        # Angled branches put the middle point at the corner (child position,
        # parent depth); otherwise it is the midpoint of the straight line.
        if angled_branches:
            middle_pos, middle_depth = child_pos, parent_depth
        else:
            middle_pos = (parent_pos + child_pos) / 2
            middle_depth = (parent_depth + child_depth) / 2

        waypoints = (
            (parent_pos, parent_depth),
            (middle_pos, middle_depth),
            (child_pos, child_depth),
        )

        xs = []
        ys = []
        for pos, depth in waypoints:
            x, y = transform(pos, depth)

            if polar and xs:
                # Interpolate to prevent weird angles if the angle exceeds
                # the threshold.
                prev_x = xs[-1]
                prev_y = ys[-1]
                if abs(x - prev_x) > polar_interpolation_threshold:
                    num = int(abs(x - prev_x) / polar_interpolation_step)
                    # Endpoints are dropped: the previous point is already
                    # stored and the current one is appended below.
                    xs.extend(np.linspace(prev_x, x, num)[1:-1])
                    ys.extend(np.linspace(prev_y, y, num)[1:-1])

            xs.append(x)
            ys.append(y)

        branch_coords[(parent, child)] = (xs, ys)

    return node_coords, branch_coords