Esempio n. 1
0
def place_tree(
    tree: CassiopeiaTree,
    depth_key: Optional[str] = None,
    orient: Union[Literal["down", "up", "left", "right"], float] = "down",
    depth_scale: float = 1.0,
    width_scale: float = 1.0,
    extend_branches: bool = True,
    angled_branches: bool = True,
    polar_interpolation_threshold: float = 5.0,
    polar_interpolation_step: float = 1.0,
    add_root: bool = False,
) -> Tuple[Dict[str, Tuple[float, float]], Dict[Tuple[str, str], Tuple[
        List[float], List[float]]], ]:
    """Given a tree, computes the coordinates of the nodes and branches.

    This function computes the x and y coordinates of all nodes and branches (as
    lines) to be used for visualization. Several options are provided to
    modify how the elements are placed. This function returns two dictionaries,
    where the first has nodes as keys and an (x, y) tuple as the values.
    Similarly, the second dictionary has (parent, child) tuples as keys denoting
    branches in the tree and a tuple (xs, ys) of two lists containing the x and
    y coordinates to draw the branch.

    This function also provides functionality to place the tree in polar
    coordinates as a circular plot when the `orient` is a number. In this case,
    the returned dictionaries have (thetas, radii) as its elements, which are
    the angles and radii in polar coordinates respectively.

    Leaves are assigned consecutive integer positions (0, 1, 2, ...) in
    depth-first preorder, and every internal node is centered between the
    extreme positions of its immediate children.

    Note:
        This function only *places* (i.e. computes the x and y coordinates) a
        tree on a coordinate system and does no real plotting.

    Args:
        tree: The CassiopeiaTree to place on the coordinate grid.
        depth_key: The node attribute to use as the depth of the nodes. If
            not provided, the distances from the root is used by calling
            `tree.get_distances`.
        orient: The orientation of the tree. Valid arguments are `left`, `right`,
            `up`, `down` to display a rectangular plot (indicating the direction
            of going from root -> leaves) or any number, in which case the
            tree is placed in polar coordinates with the provided number used
            as an angle offset. Any other string falls back to `up`.
        depth_scale: Scale the depth of the tree by this amount. This option
            has no effect in polar coordinates.
        width_scale: Scale the width of the tree by this amount. This option
            has no effect in polar coordinates.
        extend_branches: Extend branch lengths such that the distance from the
            root to every node is the same. If `depth_key` is also provided, then
            only the leaf branches are extended to the deepest leaf.
        angled_branches: Display branches as angled, instead of as just a
            line from the parent to a child.
        polar_interpolation_threshold: When displaying in polar coordinates,
            many plotting frameworks (such as Plotly) just draws a straight line
            from one point to another, instead of scaling the radius appropriately.
            This effect is most noticeable when two points are connected that
            have a large angle between them. When the angle between two connected
            points in polar coordinates exceed this amount (in degrees), this function
            adds additional points that smoothly interpolate between the two
            endpoints.
        polar_interpolation_step: Interpolation step. See above.
        add_root: Add a root node so that only one branch connects to the
            start of the tree. This node will have the name `synthetic_root`
            and is placed directly above the original root.

    Returns:
        Two dictionaries, where the first contains the node coordinates and
            the second contains the branch coordinates.
    """
    root = tree.root
    nodes = tree.nodes.copy()
    edges = tree.edges.copy()
    if depth_key:
        depths = {
            node: tree.get_attribute(node, depth_key)
            for node in tree.nodes
        }
    else:
        depths = tree.get_distances(root)

    # Only needed when leaves are extended to a common attribute depth;
    # computed once instead of once per leaf.
    max_depth = (max(depths.values())
                 if extend_branches and depth_key else None)

    placement_depths = {}
    positions = {}
    leaf_i = 0
    leaves = set()
    # Leaves get consecutive integer positions in preorder.
    for node in tree.depth_first_traverse_nodes(postorder=False):
        if not tree.is_leaf(node):
            continue
        positions[node] = leaf_i
        leaf_i += 1
        leaves.add(node)
        if extend_branches:
            placement_depths[node] = max_depth if depth_key else 0
        else:
            placement_depths[node] = depths[node]

    # Place internal nodes bottom-up. A postorder traversal guarantees every
    # child has been placed before its parent, even when zero-length branches
    # make parent and child depths tie (sorting by depth does not).
    for node in tree.depth_first_traverse_nodes(postorder=True):
        if node in leaves:
            continue
        children = tree.children(node)
        child_positions = [positions[child] for child in children]
        # Center the node between its outermost children.
        positions[node] = (min(child_positions) + max(child_positions)) / 2
        if extend_branches and not depth_key:
            # One unit above the shallowest child.
            placement_depths[node] = min(
                placement_depths[child] for child in children) - 1
        else:
            placement_depths[node] = depths[node]

    # Add synthetic root one unit above everything, directly over the real
    # root so the stub branch is a straight line.
    if add_root:
        root_name = "synthetic_root"
        positions[root_name] = positions[root]
        placement_depths[root_name] = min(placement_depths.values()) - 1
        nodes.append(root_name)
        edges.append((root_name, root))

    polar = isinstance(orient, (float, int))
    # Shift radii so the shallowest node sits at the center (radius 0).
    polar_depth_offset = -min(placement_depths.values())
    # Leaves occupy angular slots 1..n out of n + 1, so the first and last
    # leaves do not coincide at 0/360 degrees.
    polar_angle_scale = 360 / (len(leaves) + 1)

    # Define some helper functions to modify coordinate system.
    def reorient(pos, depth):
        pos *= width_scale
        depth *= depth_scale
        if orient == "down":
            return (pos, -depth)
        elif orient == "right":
            return (depth, pos)
        elif orient == "left":
            return (-depth, pos)
        # default: up
        return (pos, depth)

    def polarize(pos, depth):
        # angle, radius
        return (
            (pos + 1) * polar_angle_scale + orient,
            depth + polar_depth_offset,
        )

    node_coords = {}
    for node in nodes:
        pos = positions[node]
        depth = placement_depths[node]
        node_coords[node] = (polarize(pos, depth)
                             if polar else reorient(pos, depth))

    branch_coords = {}
    for parent, child in edges:
        parent_pos, parent_depth = positions[parent], placement_depths[parent]
        child_pos, child_depth = positions[child], placement_depths[child]

        # Angled branches bend at (child position, parent depth); straight
        # branches use the midpoint so every branch has three points.
        middle_x = (child_pos if angled_branches else
                    (parent_pos + child_pos) / 2)
        middle_y = (parent_depth if angled_branches else
                    (parent_depth + child_depth) / 2)
        _xs = [parent_pos, middle_x, child_pos]
        _ys = [parent_depth, middle_y, child_depth]

        xs = []
        ys = []
        for _x, _y in zip(_xs, _ys):
            if polar:
                x, y = polarize(_x, _y)

                if xs:
                    # Interpolate to prevent weird angles if the angle exceeds threshold.
                    prev_x = xs[-1]
                    prev_y = ys[-1]
                    if abs(x - prev_x) > polar_interpolation_threshold:
                        num = int(abs(x - prev_x) / polar_interpolation_step)
                        for inter_x, inter_y in zip(
                                np.linspace(prev_x, x, num)[1:-1],
                                np.linspace(prev_y, y, num)[1:-1],
                        ):
                            xs.append(inter_x)
                            ys.append(inter_y)
            else:
                x, y = reorient(_x, _y)
            xs.append(x)
            ys.append(y)

        branch_coords[(parent, child)] = (xs, ys)

    return node_coords, branch_coords
Esempio n. 2
0
def fitch_hartigan_bottom_up(
    cassiopeia_tree: CassiopeiaTree,
    meta_item: str,
    add_key: str = "S1",
    copy: bool = False,
) -> Optional[CassiopeiaTree]:
    """Performs Fitch-Hartigan bottom-up ancestral reconstruction.

    Performs the bottom-up phase of the Fitch-Hartigan small parsimony
    algorithm. A new attribute named `add_key` (by default "S1") is added to
    each node, storing the optimal set of ancestral states inferred from this
    bottom-up algorithm. Leaves store a one-element list with their observed
    state. Each internal node stores the states that occur most often among
    its children's sets. If copy is False, the tree is modified in place.

    Args:
        cassiopeia_tree: CassiopeiaTree object with cell meta data.
        meta_item: A column in the CassiopeiaTree cell meta corresponding to a
            categorical variable.
        add_key: Key to add for bottom-up reconstruction
        copy: Whether to operate on (and return) a copy of the tree instead of
            modifying it in place.

    Returns:
        A new CassiopeiaTree if the copy is set to True, else None.

    Raises:
        CassiopeiaError if the tree does not have the specified meta data
            or the meta data is not categorical.
    """
    # Local import: only needed for the dtype check below. Avoids the
    # deprecated `is_categorical_dtype` helper.
    import pandas as pd

    if meta_item not in cassiopeia_tree.cell_meta.columns:
        raise CassiopeiaError(
            "Meta item does not exist in the cassiopeia tree")

    meta = cassiopeia_tree.cell_meta[meta_item]

    if is_numeric_dtype(meta):
        raise CassiopeiaError("Meta item is not a categorical variable.")

    if not isinstance(meta.dtype, pd.CategoricalDtype):
        meta = meta.astype("category")

    cassiopeia_tree = cassiopeia_tree.copy() if copy else cassiopeia_tree

    # Default traversal is postorder, so children are labelled before parents.
    for node in cassiopeia_tree.depth_first_traverse_nodes():

        if cassiopeia_tree.is_leaf(node):
            cassiopeia_tree.set_attribute(node, add_key, [meta.loc[node]])
            continue

        # Pool the children's candidate sets and keep the most frequent
        # states. For a single child this simply reproduces the child's set.
        all_labels = np.concatenate([
            cassiopeia_tree.get_attribute(child, add_key)
            for child in cassiopeia_tree.children(node)
        ])
        states, frequencies = np.unique(all_labels, return_counts=True)

        S1 = states[frequencies == frequencies.max()]
        cassiopeia_tree.set_attribute(node, add_key, S1)

    return cassiopeia_tree if copy else None
Esempio n. 3
0
def log_likelihood_of_character(
    tree: CassiopeiaTree,
    character: int,
    use_internal_character_states: bool,
    mutation_probability_function_of_time: Callable[[float], float],
    missing_probability_function_of_time: Callable[[float], float],
    stochastic_missing_probability: float,
    implicit_root_branch_length: float,
) -> float:
    """Calculates the log likelihood of a given character on the tree.

    Calculates the log likelihood of a tree given the states at a given
    character in the leaves using Felsenstein's Pruning Algorithm, which sets
    up a recursive relation between the likelihoods of states at nodes for this
    character. The likelihood L(s, n) at a given state s at a given node n is:

    L(s, n) = Π_{n'}(Σ_{s'}(P(s'|s) * L(s', n')))

    for all n' that are children of n, and s' in the state space, with
    P(s'|s) being the transition probability from s to s'. That is,
    the likelihood at a given state at a given node is the product of
    the likelihoods of the states at this character at the children scaled by
    the probability of the current state transitioning to those states. This
    includes the missing state, as specified by `tree.missing_state_indicator`.
    All quantities are handled in log space: products become sums and the
    inner sum is computed with `scipy.special.logsumexp`. Transition
    probabilities come from `log_transition_probability`, which is defined
    outside this function.

    We assume here that mutations are irreversible. Once a character mutates to
    a certain state that character cannot mutate again, with the exception of
    the fact that any non-missing state can mutate to a missing state.
    `mutation_probability_function_of_time` is expected to be a function that
    determines the probability of a mutation occurring given an amount of time.
    To determine the probability of acquiring a given (non-missing) state once
    a mutation occurs, the priors of the tree are used. Likewise,
    `missing_probability_function_of_time` determines the probability of a
    missing data event occurring given an amount of time.

    The user can choose to use the character states annotated at internal
    nodes. If these are not used, then the likelihood is marginalized over
    all possible internal state characters, and an implicit root carrying
    the unmutated state (0) is assumed. It is also assumed that there is a
    single branch leading from the root that represents the root's
    lifetime. If this branch does not exist (i.e. the root does not have
    exactly one child) and `use_internal_character_states` is set to False,
    then this branch is added with length `implicit_root_branch_length`.

    Args:
        tree: The tree on which to calculate the likelihood
        character: The index of the character to calculate the likelihood of
        use_internal_character_states: Indicates if internal node
            character states should be assumed to be specified exactly
        mutation_probability_function_of_time: The function defining the
            probability of a lineage acquiring a mutation within a given time
        missing_probability_function_of_time: The function defining the
            probability of a lineage acquiring heritable missing data within a
            given time
        stochastic_missing_probability: The probability that a cell/character
            pair acquires stochastic missing data at the end of the lineage
        implicit_root_branch_length: The length of the implicit root branch.
            Used if the implicit root needs to be added

    Returns:
        The (natural) log likelihood of the tree on one character
    """

    # This dictionary uses a nested dictionary structure. Each node is mapped
    # to a dictionary storing the log-likelihood for each possible state
    # (only states that have non-0 likelihood are stored)
    likelihoods_at_nodes = {}

    # Perform a DFS to propagate the likelihood from the leaves
    for n in tree.depth_first_traverse_nodes(postorder=True):
        state_at_n = tree.get_character_states(n)
        # Observed leaf states have likelihood 1, i.e. log-likelihood 0
        if tree.is_leaf(n):
            likelihoods_at_nodes[n] = {state_at_n[character]: 0}
            continue

        possible_states = []
        # If internal character states are to be used, then the likelihood
        # for all other states are ignored. Otherwise, marginalize over
        # only states that do not break irreversibility, as all states that
        # do have likelihood of 0
        if use_internal_character_states:
            possible_states = [state_at_n[character]]
        else:
            child_possible_states = []
            for c in [
                    set(likelihoods_at_nodes[child])
                    for child in tree.children(n)
            ]:
                if tree.missing_state_indicator not in c and "&" not in c:
                    child_possible_states.append(c)
            # "&" stands in for any non-missing state (including uncut), and
            # is a possible state when all children are missing, as any
            # state could have occurred at the parent if all missing data
            # events occurred independently. Used to avoid marginalizing
            # over the entire state space.
            if child_possible_states == []:
                possible_states = [
                    "&",
                    tree.missing_state_indicator,
                ]
            else:
                possible_states = list(
                    set.intersection(*child_possible_states))
                if 0 not in possible_states:
                    possible_states.append(0)

        # This stores the likelihood of each possible state at the current node
        likelihoods_per_state_at_n = {}

        # We calculate the likelihood of the states at the current node
        # according to the recurrence relation. For each state, we marginalize
        # over the likelihoods of the states that it could transition to in the
        # daughter nodes
        for s in possible_states:
            # Running sum of per-child log-likelihoods (a product in
            # probability space)
            likelihood_for_s = 0
            for child in tree.children(n):
                likelihoods_for_s_marginalize_over_s_ = []
                for s_ in likelihoods_at_nodes[child]:
                    likelihood_s_ = (log_transition_probability(
                        tree,
                        character,
                        s,
                        s_,
                        tree.get_branch_length(n, child),
                        mutation_probability_function_of_time,
                        missing_probability_function_of_time,
                    ) + likelihoods_at_nodes[child][s_])
                    # Here we take into account the probability of
                    # stochastic missing data
                    if tree.is_leaf(child):
                        # A missing leaf under a non-missing parent state may
                        # also come from stochastic dropout: in probability
                        # space, add (1 - P(heritable missing on branch)) *
                        # stochastic_missing_probability
                        if (s_ == tree.missing_state_indicator
                                and s != tree.missing_state_indicator):
                            likelihood_s_ = np.log(
                                np.exp(likelihood_s_) +
                                (1 - missing_probability_function_of_time(
                                    tree.get_branch_length(n, child))) *
                                stochastic_missing_probability)
                        # An observed (non-missing) leaf state additionally
                        # requires that no stochastic dropout happened
                        if s_ != tree.missing_state_indicator:
                            likelihood_s_ += np.log(
                                1 - stochastic_missing_probability)
                    likelihoods_for_s_marginalize_over_s_.append(likelihood_s_)
                likelihood_for_s += scipy.special.logsumexp(
                    np.array(likelihoods_for_s_marginalize_over_s_))
            likelihoods_per_state_at_n[s] = likelihood_for_s

        likelihoods_at_nodes[n] = likelihoods_per_state_at_n

    # If we are not to use the internal state annotations explicitly,
    # then we assume an implicit root where each state is the uncut state (0)
    # Thus, we marginalize over the transition from 0 in the implicit root
    # to all non-0 states in its child
    if not use_internal_character_states:
        # If the implicit root does not exist in the tree, then we impose it,
        # with the length of the branch being specified as
        # `implicit_root_branch_length`. Otherwise, we just use the existing
        # root with a singleton child as the implicit root
        if len(tree.children(tree.root)) != 1:

            likelihood_contribution_from_each_root_state = [
                log_transition_probability(
                    tree,
                    character,
                    0,
                    s_,
                    implicit_root_branch_length,
                    mutation_probability_function_of_time,
                    missing_probability_function_of_time,
                ) + likelihoods_at_nodes[tree.root][s_]
                for s_ in likelihoods_at_nodes[tree.root]
            ]
            likelihood_at_implicit_root = scipy.special.logsumexp(
                likelihood_contribution_from_each_root_state)

            return likelihood_at_implicit_root

        else:
            # Here we account for the edge case in which all of the leaves are
            # missing, in which case the root will have "&" in place of 0. The
            # likelihood at "&" will have the same likelihood as 0 based on the
            # transition rules regarding "&". As "&" is a placeholder when the
            # state is unknown, this can be thought of realizing "&" as 0.
            if 0 not in likelihoods_at_nodes[tree.root]:
                return likelihoods_at_nodes[tree.root]["&"]
            else:
                # Otherwise, we return the likelihood of the 0 state at the
                # existing implicit root
                return likelihoods_at_nodes[tree.root][0]

    # If we use the internal state annotations explicitly, then we return
    # the likelihood of the state annotated at this character at the root
    # (the root's dict holds exactly that one state)
    else:
        return list(likelihoods_at_nodes[tree.root].values())[0]
Esempio n. 4
0
def calculate_parsimony(
    tree: CassiopeiaTree,
    infer_ancestral_characters: bool = False,
    treat_missing_as_mutation: bool = False,
) -> int:
    """
    Calculates the number of mutations that have occurred on a tree.

    Calculates the parsimony, defined as the number of character/state
    mutations that occur on edges of the tree, from the character state
    annotations at the nodes. A mutation is said to have occurred on an
    edge if a state is present at a character at the child node and this
    state is not in the parent node.

    If `infer_ancestral_characters` is set to True, then the internal
    nodes' character states are inferred by Camin-Sokal Parsimony from the
    current character states at the leaves. Use
    `tree.set_character_states_at_leaves` to use a different layer to infer
    ancestral states. Otherwise, the current annotations at the internal
    states are used. If `treat_missing_as_mutation` is set to True, then
    transitions from a non-missing state to a missing state are counted in
    the parsimony calculation. Otherwise, they are not included.

    Args:
        tree: The tree to calculate parsimony over
        infer_ancestral_characters: Whether to infer the ancestral
            characters states of the tree
        treat_missing_as_mutation: Whether to treat missing states as
            mutations

    Returns:
        The number of mutations that have occurred on the tree

    Raises:
        TreeMetricError if the tree has not been initialized or if
            a node does not have character states initialized
    """
    # Shared message for missing annotations at the root / internal nodes.
    empty_internal_message = (
        "Character states empty at internal node. Annotate"
        " character states or infer ancestral characters by"
        " setting infer_ancestral_characters=True.")

    if infer_ancestral_characters:
        tree.reconstruct_ancestral_characters()

    # Edges only expose children, so the root has to be validated separately.
    if tree.get_character_states(tree.root) == []:
        raise TreeMetricError(empty_internal_message)

    parsimony = 0
    for u, v in tree.depth_first_traverse_edges():
        if tree.get_character_states(v) == []:
            if tree.is_leaf(v):
                raise TreeMetricError(
                    "Character states have not been initialized at leaves."
                    " Use set_character_states_at_leaves or populate_tree"
                    " with the character matrix that specifies the leaf"
                    " character states.")
            raise TreeMetricError(empty_internal_message)

        parsimony += len(
            tree.get_mutations_along_edge(u, v, treat_missing_as_mutation))

    return parsimony
Esempio n. 5
0
def plot_plotly(
    tree: CassiopeiaTree,
    depth_key: Optional[str] = None,
    meta_data: Optional[List[str]] = None,
    allele_table: Optional[pd.DataFrame] = None,
    indel_colors: Optional[pd.DataFrame] = None,
    indel_priors: Optional[pd.DataFrame] = None,
    orient: Union[Literal["up", "down", "left", "right"], float] = 90.0,
    extend_branches: bool = True,
    angled_branches: bool = True,
    add_root: bool = False,
    width: float = 500.0,
    height: float = 500.0,
    colorstrip_width: Optional[float] = None,
    colorstrip_spacing: Optional[float] = None,
    clade_colors: Optional[Dict[str, Tuple[float, float, float]]] = None,
    internal_node_kwargs: Optional[Dict] = None,
    leaf_kwargs: Optional[Dict] = None,
    branch_kwargs: Optional[Dict] = None,
    colorstrip_kwargs: Optional[Dict] = None,
    continuous_cmap: Union[str, mpl.colors.Colormap] = "viridis",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    categorical_cmap: Union[str, mpl.colors.Colormap] = "tab10",
    value_mapping: Optional[Dict[str, int]] = None,
    figure: Optional[go.Figure] = None,
    random_state: Optional[np.random.RandomState] = None,
) -> go.Figure:
    """Generate a plot of a tree using Plotly.

    Args:
        tree: The CassiopeiaTree to plot.
        depth_key: The node attribute to use as the depth of the nodes. If
            not provided, the distances from the root is used by calling
            `tree.get_distances`.
        meta_data: Meta data to plot alongside the tree, which must be columns
            in the CassiopeiaTree.cell_meta variable.
        allele_table: Allele table to plot alongside the tree.
        indel_colors: Color mapping to use for plotting the alleles for each
            cell. Only necessary if `allele_table` is specified.
        indel_priors: Prior probabilities for each indel. Only useful if an
            allele table is to be plotted and `indel_colors` is None.
        orient: The orientation of the tree. Valid arguments are `left`, `right`,
            `up`, `down` to display a rectangular plot (indicating the direction
            of going from root -> leaves) or any number, in which case the
            tree is placed in polar coordinates with the provided number used
            as an angle offset. Defaults to 90.
        extend_branches: Extend branch lengths such that the distance from the
            root to every node is the same. If `depth_key` is also provided, then
            only the leaf branches are extended to the deepest leaf.
        angled_branches: Display branches as angled, instead of as just a
            line from the parent to a child.
        add_root: Add a root node so that only one branch connects to the
            start of the tree. This node will have the name `synthetic_root`.
        width: Width of the figure.
        height: Height of the figure.
        colorstrip_width: Width of the colorstrip. Width is defined as the
            length in the direction of the leaves. Defaults to 5% of the tree
            depth.
        colorstrip_spacing: Space between consecutive colorstrips. Defaults to
            half of `colorstrip_width`.
        clade_colors: Dictionary containing internal node-color mappings. These
            colors will be used to color all the paths from this node to the
            leaves the provided color. Non-string colors (e.g. RGB tuples) are
            converted to hex strings before being handed to Plotly; string
            colors are passed through unchanged.
        internal_node_kwargs: Keyword arguments to pass to `go.Scatter` when
            plotting internal nodes.
        leaf_kwargs: Keyword arguments to pass to `go.Scatter` when
            plotting leaf nodes.
        branch_kwargs: Keyword arguments to pass to `go.Scatter` when plotting
            branches.
        colorstrip_kwargs: Keyword arguments to pass to `go.Scatter` when
            plotting colorstrips.
        continuous_cmap: Colormap to use for continuous variables. Defaults to
            `viridis`.
        vmin: Value representing the lower limit of the color scale. Only applied
            to continuous variables.
        vmax: Value representing the upper limit of the color scale. Only applied
            to continuous variables.
        categorical_cmap: Colormap to use for categorical variables. Defaults to
            `tab10`.
        value_mapping: An optional dictionary containing string values to their
            integer mappings. These mappings are used to assign colors by
            calling the `cmap` with the designated integer mapping. By default,
            the values are assigned pseudo-randomly (whatever order the set()
            operation returns). Only applied for categorical variables.
        figure: Plotly figure to plot the tree. A new figure is created if not
            provided.
        random_state: A random state for reproducibility

    Returns:
        The Plotly figure.
    """

    def _to_plotly_color(color):
        # Plotly accepts color strings (hex, rgb(...), CSS names) but rejects
        # RGB(A) tuples such as those documented for `clade_colors`, so
        # convert non-strings to hex (as is already done for colorstrips).
        # Strings pass through so Plotly-specific formats keep working.
        if isinstance(color, str):
            return color
        return mpl.colors.to_hex(color)

    # Warn user if there are many leaves
    if len(tree.leaves) > 2000:
        warnings.warn(
            "Tree has greater than 2000 leaves. This may take a while.",
            PlottingWarning,
        )

    is_polar = isinstance(orient, (float, int))
    (
        node_coords,
        branch_coords,
        node_colors,
        branch_colors,
        colorstrips,
    ) = place_tree_and_annotations(
        tree,
        depth_key,
        meta_data,
        allele_table,
        indel_colors,
        indel_priors,
        orient,
        extend_branches,
        angled_branches,
        add_root,
        colorstrip_width,
        colorstrip_spacing,
        clade_colors,
        continuous_cmap,
        vmin,
        vmax,
        categorical_cmap,
        value_mapping,
        random_state,
    )
    figure = figure if figure is not None else go.Figure()

    # Plot all nodes
    _leaf_kwargs = dict(
        x=[],
        y=[],
        text=[],
        marker_size=3,
        marker_color="black",
        mode="markers",
        showlegend=False,
        hoverinfo="text",
    )
    # NOTE: setting marker_size=0 has no effect for some reason?
    _node_kwargs = dict(
        x=[],
        y=[],
        text=[],
        marker_size=0.1,
        marker_color="black",
        mode="markers",
        showlegend=False,
        hoverinfo="text",
    )
    _leaf_kwargs.update(leaf_kwargs or {})
    _node_kwargs.update(internal_node_kwargs or {})
    # First pass: uncolored nodes (colored ones are drawn in a second trace).
    for node, (x, y) in node_coords.items():
        if node in node_colors:
            continue
        text = f"<b>NODE</b><br>{node}"
        if is_polar:
            x, y = utilities.polar_to_cartesian(x, y)
        if tree.is_leaf(node):
            _leaf_kwargs["x"].append(x)
            _leaf_kwargs["y"].append(y)
            _leaf_kwargs["text"].append(text)
        else:
            _node_kwargs["x"].append(x)
            _node_kwargs["y"].append(y)
            _node_kwargs["text"].append(text)
    figure.add_trace(go.Scatter(**_leaf_kwargs))
    figure.add_trace(go.Scatter(**_node_kwargs))

    # Second pass: nodes with an explicit color, one color per marker.
    _leaf_colors = []
    _node_colors = []
    _leaf_kwargs.update({"x": [], "y": [], "text": []})
    _node_kwargs.update({"x": [], "y": [], "text": []})
    for node, color in node_colors.items():
        x, y = node_coords[node]
        text = f"<b>NODE</b><br>{node}"
        if is_polar:
            x, y = utilities.polar_to_cartesian(x, y)
        if tree.is_leaf(node):
            _leaf_kwargs["x"].append(x)
            _leaf_kwargs["y"].append(y)
            _leaf_kwargs["text"].append(text)
            _leaf_colors.append(_to_plotly_color(color))
        else:
            _node_kwargs["x"].append(x)
            _node_kwargs["y"].append(y)
            _node_kwargs["text"].append(text)
            _node_colors.append(_to_plotly_color(color))

    _leaf_kwargs["marker_color"] = _leaf_colors
    _node_kwargs["marker_color"] = _node_colors
    figure.add_trace(go.Scatter(**_leaf_kwargs))
    figure.add_trace(go.Scatter(**_node_kwargs))

    # Plot all branches
    _branch_kwargs = dict(
        x=[],
        y=[],
        text=[],
        line_color="black",
        line_width=1,
        mode="lines",
        showlegend=False,
        hoverinfo="text",
    )
    _branch_kwargs.update(branch_kwargs or {})
    for branch, (xs, ys) in branch_coords.items():
        if branch in branch_colors:
            continue
        _branch_kwargs["x"], _branch_kwargs["y"] = xs, ys
        text = f"<b>BRANCH</b><br>{branch[0]}<br>{branch[1]}"
        if is_polar:
            (
                _branch_kwargs["x"],
                _branch_kwargs["y"],
            ) = utilities.polars_to_cartesians(xs, ys)
        _branch_kwargs["text"] = [text] * len(xs)
        figure.add_trace(go.Scatter(**_branch_kwargs))

    for branch, color in branch_colors.items():
        xs, ys = branch_coords[branch]
        _branch_kwargs["x"], _branch_kwargs["y"] = xs, ys
        _branch_kwargs["line_color"] = _to_plotly_color(color)
        text = f"<b>BRANCH</b><br>{branch[0]}<br>{branch[1]}"
        if is_polar:
            (
                _branch_kwargs["x"],
                _branch_kwargs["y"],
            ) = utilities.polars_to_cartesians(xs, ys)
        _branch_kwargs["text"] = [text] * len(xs)
        figure.add_trace(go.Scatter(**_branch_kwargs))

    # Colorstrips
    _colorstrip_kwargs = dict(
        x=[],
        y=[],
        text=[],
        line_width=0,
        fill="toself",
        mode="lines",
        showlegend=False,
        hoverinfo="text",
        hoveron="fills",
    )
    _colorstrip_kwargs.update(colorstrip_kwargs or {})
    for colorstrip in colorstrips:
        for xs, ys, c, text in colorstrip.values():
            _colorstrip_kwargs["x"], _colorstrip_kwargs["y"] = xs, ys
            _colorstrip_kwargs["fillcolor"] = mpl.colors.to_hex(c)
            if is_polar:
                (
                    _colorstrip_kwargs["x"],
                    _colorstrip_kwargs["y"],
                ) = utilities.polars_to_cartesians(xs, ys)
            _colorstrip_kwargs["text"] = text.replace("\n", "<br>")
            figure.add_trace(go.Scatter(**_colorstrip_kwargs))

    figure.update_layout(
        width=width,
        height=height,
        xaxis=dict(showgrid=False, visible=False),
        yaxis=dict(showgrid=False, visible=False),
        margin=dict(l=0, r=0, t=0, b=0),
    )
    return figure
Esempio n. 6
0
def plot_matplotlib(
    tree: CassiopeiaTree,
    depth_key: Optional[str] = None,
    meta_data: Optional[List[str]] = None,
    allele_table: Optional[pd.DataFrame] = None,
    indel_colors: Optional[pd.DataFrame] = None,
    indel_priors: Optional[pd.DataFrame] = None,
    orient: Union[Literal["up", "down", "left", "right"], float] = 90.0,
    extend_branches: bool = True,
    angled_branches: bool = True,
    add_root: bool = False,
    figsize: Tuple[float, float] = (7.0, 7.0),
    colorstrip_width: Optional[float] = None,
    colorstrip_spacing: Optional[float] = None,
    clade_colors: Optional[Dict[str, Tuple[float, float, float]]] = None,
    internal_node_kwargs: Optional[Dict] = None,
    leaf_kwargs: Optional[Dict] = None,
    branch_kwargs: Optional[Dict] = None,
    colorstrip_kwargs: Optional[Dict] = None,
    continuous_cmap: Union[str, mpl.colors.Colormap] = "viridis",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    categorical_cmap: Union[str, mpl.colors.Colormap] = "tab10",
    value_mapping: Optional[Dict[str, int]] = None,
    ax: Optional[plt.Axes] = None,
    random_state: Optional[np.random.RandomState] = None,
) -> Union[Tuple[plt.Figure, plt.Axes], plt.Axes]:
    """Generate a static plot of a tree using Matplotlib.

    The tree and all requested annotations are first placed with
    `place_tree_and_annotations`. Nodes, branches and colorstrips are then
    drawn onto `ax`. Polar placements (numeric `orient`) are converted to
    cartesian coordinates before drawing.

    Args:
        tree: The CassiopeiaTree to plot.
        depth_key: The node attribute to use as the depth of the nodes. If
            not provided, the distances from the root is used by calling
            `tree.get_distances`.
        meta_data: Meta data to plot alongside the tree, which must be columns
            in the CassiopeiaTree.cell_meta variable.
        allele_table: Allele table to plot alongside the tree.
        indel_colors: Color mapping to use for plotting the alleles for each
            cell. Only necessary if `allele_table` is specified.
        indel_priors: Prior probabilities for each indel. Only useful if an
            allele table is to be plotted and `indel_colors` is None.
        orient: The orientation of the tree. Valid arguments are `left`, `right`,
            `up`, `down` to display a rectangular plot (indicating the direction
            of going from root -> leaves) or any number, in which case the
            tree is placed in polar coordinates with the provided number used
            as an angle offset. Defaults to 90.
        extend_branches: Extend branch lengths such that the distance from the
            root to every node is the same. If `depth_key` is also provided, then
            only the leaf branches are extended to the deepest leaf.
        angled_branches: Display branches as angled, instead of as just a
            line from the parent to a child.
        add_root: Add a root node so that only one branch connects to the
            start of the tree. This node will have the name `synthetic_root`.
        figsize: Size of the plot. Defaults to (7., 7.,). Only used when a new
            figure is created (i.e. `ax` is None).
        colorstrip_width: Width of the colorstrip. Width is defined as the
            length in the direction of the leaves. Defaults to 5% of the tree
            depth.
        colorstrip_spacing: Space between consecutive colorstrips. Defaults to
            half of `colorstrip_width`.
        clade_colors: Dictionary containing internal node-color mappings. These
            colors will be used to color all the paths from this node to the
            leaves the provided color.
        internal_node_kwargs: Keyword arguments to pass to `plt.scatter` when
            plotting internal nodes.
        leaf_kwargs: Keyword arguments to pass to `plt.scatter` when
            plotting leaf nodes.
        branch_kwargs: Keyword arguments to pass to `plt.plot` when plotting
            branches.
        colorstrip_kwargs: Keyword arguments to pass to `plt.fill` when plotting
            colorstrips.
        continuous_cmap: Colormap to use for continuous variables. Defaults to
            `viridis`.
        vmin: Value representing the lower limit of the color scale. Only applied
            to continuous variables.
        vmax: Value representing the upper limit of the color scale. Only applied
            to continuous variables.
        categorical_cmap: Colormap to use for categorical variables. Defaults to
            `tab10`.
        value_mapping: An optional dictionary containing string values to their
            integer mappings. These mappings are used to assign colors by
            calling the `cmap` with the designated integer mapping. By default,
            the values are assigned pseudo-randomly (whatever order the set()
            operation returns). Only applied for categorical variables.
        ax: Matplotlib axis to place the tree. If not provided, a new figure is
            initialized.
        random_state: A random state for reproducibility

    Returns:
        If `ax` is provided, `ax` is returned. Otherwise, a tuple of (fig, ax)
            of the newly initialized figure and axis.

    Raises:
        PlottingError: Propagated from `place_tree_and_annotations`, e.g. when
            a `meta_data` item is not a column of `tree.cell_meta`.
    """
    is_polar = isinstance(orient, (float, int))
    (
        node_coords,
        branch_coords,
        node_colors,
        branch_colors,
        colorstrips,
    ) = place_tree_and_annotations(
        tree,
        depth_key,
        meta_data,
        allele_table,
        indel_colors,
        indel_priors,
        orient,
        extend_branches,
        angled_branches,
        add_root,
        colorstrip_width,
        colorstrip_spacing,
        clade_colors,
        continuous_cmap,
        vmin,
        vmax,
        categorical_cmap,
        value_mapping,
        random_state,
    )

    # Only create (and later return) a figure when the caller did not supply
    # an axis to draw on.
    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize, tight_layout=True)
    ax.set_axis_off()

    # Plot all nodes. First pass: nodes without a clade color, drawn with the
    # default (or user-provided) scatter kwargs.
    _leaf_kwargs = dict(x=[], y=[], s=5, c="black")
    _node_kwargs = dict(x=[], y=[], s=0, c="black")
    _leaf_kwargs.update(leaf_kwargs or {})
    _node_kwargs.update(internal_node_kwargs or {})
    for node, (x, y) in node_coords.items():
        if node in node_colors:
            continue
        if is_polar:
            x, y = utilities.polar_to_cartesian(x, y)
        if tree.is_leaf(node):
            _leaf_kwargs["x"].append(x)
            _leaf_kwargs["y"].append(y)
        else:
            _node_kwargs["x"].append(x)
            _node_kwargs["y"].append(y)
    ax.scatter(**_leaf_kwargs)
    ax.scatter(**_node_kwargs)

    # Second pass: clade-colored nodes. The same kwargs are reused with the
    # coordinate lists reset and `c` replaced by a per-point color list.
    _leaf_colors = []
    _node_colors = []
    _leaf_kwargs.update({"x": [], "y": []})
    _node_kwargs.update({"x": [], "y": []})
    for node, color in node_colors.items():
        x, y = node_coords[node]
        if is_polar:
            x, y = utilities.polar_to_cartesian(x, y)
        if tree.is_leaf(node):
            _leaf_kwargs["x"].append(x)
            _leaf_kwargs["y"].append(y)
            _leaf_colors.append(color)
        else:
            _node_kwargs["x"].append(x)
            _node_kwargs["y"].append(y)
            _node_colors.append(color)

    _leaf_kwargs["c"] = _leaf_colors
    _node_kwargs["c"] = _node_colors
    ax.scatter(**_leaf_kwargs)
    ax.scatter(**_node_kwargs)

    # Plot all branches: uncolored ones first, then clade-colored ones with
    # their color overriding `c`.
    _branch_kwargs = dict(linewidth=1, c="black")
    _branch_kwargs.update(branch_kwargs or {})
    for branch, (xs, ys) in branch_coords.items():
        if branch in branch_colors:
            continue
        if is_polar:
            xs, ys = utilities.polars_to_cartesians(xs, ys)
        ax.plot(xs, ys, **_branch_kwargs)

    for branch, color in branch_colors.items():
        _branch_kwargs["c"] = color
        xs, ys = branch_coords[branch]
        if is_polar:
            xs, ys = utilities.polars_to_cartesians(xs, ys)
        ax.plot(xs, ys, **_branch_kwargs)

    # Colorstrips
    _colorstrip_kwargs = dict(linewidth=0)
    _colorstrip_kwargs.update(colorstrip_kwargs or {})
    for colorstrip in colorstrips:
        # Last element is text, but this can not be shown in static plotting.
        for xs, ys, c, _ in colorstrip.values():
            _colorstrip_kwargs["c"] = c
            if is_polar:
                xs, ys = utilities.polars_to_cartesians(xs, ys)
            ax.fill(xs, ys, **_colorstrip_kwargs)

    return (fig, ax) if fig is not None else ax
Esempio n. 7
0
def place_tree_and_annotations(
    tree: CassiopeiaTree,
    depth_key: Optional[str] = None,
    meta_data: Optional[List[str]] = None,
    allele_table: Optional[pd.DataFrame] = None,
    indel_colors: Optional[pd.DataFrame] = None,
    indel_priors: Optional[pd.DataFrame] = None,
    orient: Union[Literal["up", "down", "left", "right"], float] = 90.0,
    extend_branches: bool = True,
    angled_branches: bool = True,
    add_root: bool = False,
    colorstrip_width: Optional[float] = None,
    colorstrip_spacing: Optional[float] = None,
    clade_colors: Optional[Dict[str, Tuple[float, float, float]]] = None,
    continuous_cmap: Union[str, mpl.colors.Colormap] = "viridis",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    categorical_cmap: Union[str, mpl.colors.Colormap] = "tab10",
    value_mapping: Optional[Dict[str, int]] = None,
    random_state: Optional[np.random.RandomState] = None,
) -> Tuple[Dict, Dict, Dict, Dict, List]:
    """Helper function to place the tree and all requested annotations.

    Args:
        tree: The CassiopeiaTree to plot.
        depth_key: The node attribute to use as the depth of the nodes. If
            not provided, the distances from the root is used by calling
            `tree.get_distances`.
        meta_data: Meta data to plot alongside the tree, which must be columns
            in the CassiopeiaTree.cell_meta variable. Numeric columns are
            drawn as continuous colorstrips, string columns as categorical
            colorstrips.
        allele_table: Alleletable to plot alongside the tree.
        indel_colors: Color mapping to use for plotting the alleles for each
            cell. Only necessary if `allele_table` is specified.
        indel_priors: Prior probabilities for each indel. Only useful if an
            allele table is to be plotted and `indel_colors` is None.
        orient: The orientation of the tree. Valid arguments are `left`, `right`,
            `up`, `down` to display a rectangular plot (indicating the direction
            of going from root -> leaves) or any number, in which case the
            tree is placed in polar coordinates with the provided number used
            as an angle offset. Defaults to 90.
        extend_branches: Extend branch lengths such that the distance from the
            root to every node is the same. If `depth_key` is also provided, then
            only the leaf branches are extended to the deepest leaf.
        angled_branches: Display branches as angled, instead of as just a
            line from the parent to a child.
        add_root: Add a root node so that only one branch connects to the
            start of the tree. This node will have the name `synthetic_root`.
        colorstrip_width: Width of the colorstrip. Width is defined as the
            length in the direction of the leaves. Defaults to 5% of the tree
            depth.
        colorstrip_spacing: Space between consecutive colorstrips. Defaults to
            half of `colorstrip_width`. An explicit 0 is respected.
        clade_colors: Dictionary containing internal node-color mappings. These
            colors will be used to color all the paths from this node to the
            leaves the provided color.
        continuous_cmap: Colormap to use for continuous variables. Defaults to
            `viridis`.
        vmin: Value representing the lower limit of the color scale. Only applied
            to continuous variables.
        vmax: Value representing the upper limit of the color scale. Only applied
            to continuous variables.
        categorical_cmap: Colormap to use for categorical variables. Defaults to
            `tab10`.
        value_mapping: An optional dictionary containing string values to their
            integer mappings. These mappings are used to assign colors by
            calling the `cmap` with the designated integer mapping. By default,
            the values are assigned pseudo-randomly (whatever order the set()
            operation returns). Only applied for categorical variables.
        random_state: A random state for reproducibility

    Returns:
        Four dictionaries (node coordinates, branch coordinates, node
            colors, branch colors) and a list of colorstrips.

    Raises:
        PlottingError: If a `meta_data` item is not a column of
            `tree.cell_meta` (or the tree has no cell meta), or if the column
            is neither numeric nor string typed.
    """
    meta_data = meta_data or []

    # Place tree on the appropriate coordinate system.
    node_coords, branch_coords = utilities.place_tree(
        tree,
        depth_key=depth_key,
        orient=orient,
        extend_branches=extend_branches,
        angled_branches=angled_branches,
        add_root=add_root,
    )

    # Compute first set of anchor coords, which are just the coordinates of
    # all the leaves. Each colorstrip helper below returns updated anchors,
    # which are passed to the next helper so strips are placed one after
    # another.
    anchor_coords = {
        node: coords
        for node, coords in node_coords.items() if tree.is_leaf(node)
    }
    is_polar = isinstance(orient, (float, int))
    loc = "polar" if is_polar else orient
    tight_width, tight_height = compute_colorstrip_size(
        node_coords, anchor_coords, loc)
    # Explicit None checks (rather than `or`) so that a caller-supplied 0 is
    # honored; spacing defaults to half of the *effective* width, as
    # documented.
    width = tight_width if colorstrip_width is None else colorstrip_width
    spacing = width / 2 if colorstrip_spacing is None else colorstrip_spacing

    # Place indel heatmap
    colorstrips = []
    if allele_table is not None:
        heatmap, anchor_coords = create_indel_heatmap(
            allele_table,
            anchor_coords,
            width,
            tight_height,
            spacing,
            loc,
            indel_colors,
            indel_priors,
            random_state,
        )
        colorstrips.extend(heatmap)

    # Any other annotations
    for meta_item in meta_data:
        if tree.cell_meta is None or meta_item not in tree.cell_meta.columns:
            raise PlottingError(
                "Meta data item not in CassiopeiaTree cell meta.")

        values = tree.cell_meta[meta_item]
        if pd.api.types.is_numeric_dtype(values):
            colorstrip, anchor_coords = create_continuous_colorstrip(
                values.to_dict(),
                anchor_coords,
                width,
                tight_height,
                spacing,
                loc,
                continuous_cmap,
                vmin,
                vmax,
            )
        elif pd.api.types.is_string_dtype(values):
            colorstrip, anchor_coords = create_categorical_colorstrip(
                values.to_dict(),
                anchor_coords,
                width,
                tight_height,
                spacing,
                loc,
                categorical_cmap,
                value_mapping,
            )
        else:
            # Previously such columns fell through both branches, leaving
            # `colorstrip` unbound or re-appending the previous item's strip.
            raise PlottingError(
                f"Meta data item `{meta_item}` has unsupported dtype "
                f"`{values.dtype}`; expected a numeric or string column "
                "(e.g. convert categorical columns with `.astype(str)`).")
        colorstrips.append(colorstrip)

    # Clade colors
    node_colors = {}
    branch_colors = {}
    if clade_colors:
        node_colors, branch_colors = create_clade_colors(tree, clade_colors)
    return node_coords, branch_coords, node_colors, branch_colors, colorstrips