def compute_phylogenetic_weight_matrix(
    tree: CassiopeiaTree,
    inverse: bool = False,
    inverse_fn: Callable[[Union[int, float]], float] = lambda x: 1 / x,
) -> pd.DataFrame:
    """Computes the phylogenetic weight matrix.

    Computes the distances between all leaves in a tree. The user has the
    option to return the inverse matrix (i.e., transform distances to
    proximities) and specify an appropriate inverse function.

    Distances are obtained by calling
    ``tree.get_distances(leaf, leaves_only=True)`` once per leaf. They are
    written into a dense NumPy array, which is wrapped in a DataFrame at the
    end. Writing into the array rather than through ``DataFrame.values``
    keeps the diagonal reset working under pandas Copy-on-Write, where
    ``.values`` may be read-only. It also avoids the overhead of
    per-element ``.loc`` assignment.

    Args:
        tree: CassiopeiaTree
        inverse: Convert distances to proximities
        inverse_fn: Inverse function (default = 1 / x). It is only applied
            to strictly positive distances; when ``inverse`` is True,
            non-positive distances are mapped to ``np.inf``.

    Returns:
        An NxN phylogenetic weight matrix indexed by ``tree.leaves`` on both
        axes, with a zero diagonal.
    """
    leaves = list(tree.leaves)
    leaf_index = {leaf: i for i, leaf in enumerate(leaves)}
    n_leaves = len(leaves)
    weights = np.zeros((n_leaves, n_leaves))

    for leaf1 in leaves:
        i = leaf_index[leaf1]
        distances = tree.get_distances(leaf1, leaves_only=True)
        for leaf2, distance in distances.items():
            if inverse:
                # Zero distances (identical leaf, or zero-length branches)
                # cannot be inverted; treat them as infinitely close.
                distance = inverse_fn(distance) if distance > 0 else np.inf
            j = leaf_index[leaf2]
            weights[i, j] = weights[j, i] = distance

    # A leaf is given no weight with respect to itself.
    np.fill_diagonal(weights, 0)

    return pd.DataFrame(weights, index=leaves, columns=leaves)
def place_tree(
    tree: CassiopeiaTree,
    depth_key: Optional[str] = None,
    orient: Union[Literal["down", "up", "left", "right"], float] = "down",
    depth_scale: float = 1.0,
    width_scale: float = 1.0,
    extend_branches: bool = True,
    angled_branches: bool = True,
    polar_interpolation_threshold: float = 5.0,
    polar_interpolation_step: float = 1.0,
    add_root: bool = False,
) -> Tuple[
    Dict[str, Tuple[float, float]],
    Dict[Tuple[str, str], Tuple[List[float], List[float]]],
]:
    """Given a tree, computes the coordinates of the nodes and branches.

    This function computes the x and y coordinates of all nodes and branches
    (as lines) to be used for visualization. Several options are provided to
    modify how the elements are placed. This function returns two
    dictionaries. The first has nodes as keys and an (x, y) tuple as the
    values. The second has (parent, child) tuples as keys, denoting branches
    in the tree, and a tuple (xs, ys) of two lists containing the x and y
    coordinates used to draw the branch.

    This function also provides functionality to place the tree in polar
    coordinates as a circular plot when `orient` is a number. In this case,
    the returned dictionaries have (thetas, radii) as their elements, which
    are the angles (in degrees) and radii in polar coordinates respectively.

    Leaves are assigned consecutive integer positions in depth-first
    (preorder) order. Each internal node is centered between the smallest
    and largest positions of its children. Internal nodes are placed in
    postorder, so every child is placed before its parent. This holds even
    for zero-length branches, where parent and child share the same depth.

    Note:
        This function only *places* (i.e. computes the x and y coordinates)
        a tree on a coordinate system and does no real plotting.

    Args:
        tree: The CassiopeiaTree to place on the coordinate grid.
        depth_key: The node attribute to use as the depth of the nodes. If
            not provided, the distances from the root are used by calling
            `tree.get_distances`.
        orient: The orientation of the tree. Valid arguments are `left`,
            `right`, `up`, `down` to display a rectangular plot (indicating
            the direction of going from root -> leaves) or any number, in
            which case the tree is placed in polar coordinates with the
            provided number used as an angle offset. Any other string falls
            back to `up`.
        depth_scale: Scale the depth of the tree by this amount. This option
            has no effect in polar coordinates.
        width_scale: Scale the width of the tree by this amount. This option
            has no effect in polar coordinates.
        extend_branches: Extend branch lengths such that the distance from
            the root to every leaf is the same. If `depth_key` is not
            provided, leaves are placed at depth 0 and each internal node one
            unit above its shallowest child. If `depth_key` is provided, only
            the leaf branches are extended, to the maximum depth among all
            nodes.
        angled_branches: Display branches as angled, instead of as just a
            line from the parent to a child.
        polar_interpolation_threshold: When displaying in polar coordinates,
            many plotting frameworks (such as Plotly) just draw a straight
            line from one point to another, instead of scaling the radius
            appropriately. This effect is most noticeable when two connected
            points have a large angle between them. When the angle between
            two connected points in polar coordinates exceeds this amount (in
            degrees), this function adds additional points that smoothly
            interpolate between the two endpoints.
        polar_interpolation_step: Interpolation step (in degrees). See above.
        add_root: Add a root node so that only one branch connects to the
            start of the tree. This node will have the name `synthetic_root`.
            It is placed one unit above the shallowest placed node, at the
            same position as the original root.

    Returns:
        Two dictionaries, where the first contains the node coordinates and
        the second contains the branch coordinates.
    """
    root = tree.root
    nodes = tree.nodes.copy()
    edges = tree.edges.copy()

    if depth_key:
        depths = {
            node: tree.get_attribute(node, depth_key) for node in tree.nodes
        }
    else:
        depths = tree.get_distances(root)

    # Depth every leaf is moved to when branches are extended. Computed once
    # here rather than once per leaf.
    extended_leaf_depth = None
    if extend_branches:
        extended_leaf_depth = max(depths.values()) if depth_key else 0

    placement_depths = {}
    positions = {}
    leaves = set()

    # Leaves get consecutive integer positions in preorder (DFS) order.
    leaf_i = 0
    for node in tree.depth_first_traverse_nodes(postorder=False):
        if tree.is_leaf(node):
            positions[node] = leaf_i
            leaf_i += 1
            leaves.add(node)
            placement_depths[node] = (
                extended_leaf_depth if extend_branches else depths[node]
            )

    # Place internal nodes at the center of their children. Postorder
    # guarantees children are placed first. The previous approach (sorting
    # by depth) broke on zero-length branches, where parent and child tie.
    for node in tree.depth_first_traverse_nodes(postorder=True):
        if node in leaves:
            continue
        children = tree.children(node)
        child_positions = [positions[child] for child in children]
        child_depths = [placement_depths[child] for child in children]
        positions[node] = (min(child_positions) + max(child_positions)) / 2
        if extend_branches and not depth_key:
            placement_depths[node] = min(child_depths) - 1
        else:
            placement_depths[node] = depths[node]

    # Add synthetic root directly above the real root, so the single
    # connecting branch is straight in rectangular layouts. In polar layouts
    # it sits at radius 0, so its angle does not matter.
    if add_root:
        root_name = "synthetic_root"
        positions[root_name] = positions[root]
        placement_depths[root_name] = min(placement_depths.values()) - 1
        nodes.append(root_name)
        edges.append((root_name, root))

    polar = isinstance(orient, (float, int))
    # Shift radii so the shallowest placed node sits at radius 0.
    polar_depth_offset = -min(placement_depths.values())
    # One extra slot so the first and last leaves do not share an angle.
    polar_angle_scale = 360 / (len(leaves) + 1)

    def reorient(pos, depth):
        """Map (position, depth) to rectangular (x, y) for `orient`."""
        pos *= width_scale
        depth *= depth_scale
        if orient == "down":
            return (pos, -depth)
        elif orient == "right":
            return (depth, pos)
        elif orient == "left":
            return (-depth, pos)
        # default: up
        return (pos, depth)

    def polarize(pos, depth):
        """Map (position, depth) to polar (angle in degrees, radius)."""
        return (
            (pos + 1) * polar_angle_scale + orient,
            depth + polar_depth_offset,
        )

    node_coords = {}
    for node in nodes:
        pos = positions[node]
        depth = placement_depths[node]
        node_coords[node] = (
            polarize(pos, depth) if polar else reorient(pos, depth)
        )

    branch_coords = {}
    for parent, child in edges:
        parent_pos, parent_depth = positions[parent], placement_depths[parent]
        child_pos, child_depth = positions[child], placement_depths[child]
        if angled_branches:
            # Elbow: go along the parent's depth to the child's position,
            # then along the child's position to the child's depth.
            middle_pos, middle_depth = child_pos, parent_depth
        else:
            middle_pos = (parent_pos + child_pos) / 2
            middle_depth = (parent_depth + child_depth) / 2
        _xs = [parent_pos, middle_pos, child_pos]
        _ys = [parent_depth, middle_depth, child_depth]

        xs = []
        ys = []
        for _x, _y in zip(_xs, _ys):
            if polar:
                x, y = polarize(_x, _y)
                if xs:
                    # Interpolate to prevent straight chords when the angle
                    # between consecutive points exceeds the threshold.
                    prev_x = xs[-1]
                    prev_y = ys[-1]
                    if abs(x - prev_x) > polar_interpolation_threshold:
                        num = int(abs(x - prev_x) / polar_interpolation_step)
                        for inter_x, inter_y in zip(
                            np.linspace(prev_x, x, num)[1:-1],
                            np.linspace(prev_y, y, num)[1:-1],
                        ):
                            xs.append(inter_x)
                            ys.append(inter_y)
            else:
                x, y = reorient(_x, _y)
            xs.append(x)
            ys.append(y)
        branch_coords[(parent, child)] = (xs, ys)

    return node_coords, branch_coords