def calc_exact_log_full_joint(
    tree: CassiopeiaTree,
    mutation_rate: float,
    birth_rate: float,
    sampling_probability: float,
) -> float:
    """
    Exact log full joint probability computation.

    This method is used for testing the implementation of the model.

    Computes the log full joint probability density of the observed tree
    topology, state vectors, and branch lengths. In other words:
    log P(branch lengths, character states, tree topology)
    Integrating this function allows computing the marginals and hence
    the posteriors of the times of any internal node in the tree.

    Note that this method is only fast enough for small trees. Its
    run time scales exponentially with the number of internal nodes of the
    tree.

    The result is accumulated edge by edge as the sum of two terms:
    a birth-process-with-subsampling term (which depends on whether the
    child of the edge is a leaf) and a mutation-process term.

    Args:
        tree: The CassiopeiaTree containing the tree topology and all
            character states. The function works on a deep copy of it.
        mutation_rate: The mutation rate of the model.
        birth_rate: The birth rate of the model.
        sampling_probability: The sampling probability of the model.

    Returns:
        log P(branch lengths, character states, tree topology), or
        -inf if the mutation term of some edge evaluates to nan.
    """
    tree = deepcopy(tree)
    log_joint = 0.0
    lam = birth_rate
    r = mutation_rate
    q_inv = (1.0 - sampling_probability) / sampling_probability
    lg = np.log
    e = np.exp
    T = tree.get_max_depth_of_tree()
    # Loop invariants: the root time and the leaf set do not change while
    # iterating, so look them up once instead of once per edge.
    root_time = tree.get_time(tree.root)
    leaves = set(tree.leaves)
    for parent, child in tree.edges:
        t = tree.get_branch_length(parent, child)
        # Birth process with subsampling likelihood
        h = T - tree.get_time(parent) + root_time
        h_tilde = T - tree.get_time(child) + root_time
        if child in leaves:
            # "Easy" case
            assert h_tilde == 0
            log_joint += (2.0 * lg(q_inv + 1.0) + lam * h -
                          2.0 * lg(q_inv + e(lam * h)) +
                          lg(sampling_probability))
        else:
            log_joint += (lg(lam) + lam * h -
                          2.0 * lg(q_inv + e(lam * h)) +
                          2.0 * lg(q_inv + e(lam * h_tilde)) -
                          lam * h_tilde)
        # Mutation process likelihood
        cuts = len(
            tree.get_mutations_along_edge(parent,
                                          child,
                                          treat_missing_as_mutations=False))
        uncuts = tree.get_character_states(child).count(0)
        # Care must be taken here, we might get a nan (e.g. log(0) * 0
        # when the branch length is 0 and there are no cuts).
        log_cut_term = lg(1 - e(-t * r)) * cuts
        if np.isnan(log_cut_term):
            return -np.inf
        log_joint += ((-t * r) * uncuts + log_cut_term +
                      lg(binom(cuts + uncuts, cuts)))
    return log_joint
# Example no. 2
# 0
    def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None:
        r"""
        MLE under a model of IID memoryless CRISPR/Cas9 mutations.

        The optimization variables are one scalar per node (named
        ``r_X_t_<node>``). The root's variable is constrained to 0, all
        leaves are constrained to share the same value, and every edge must
        be at least ``minimum_branch_length`` times the leaf value long.
        After solving, the common leaf value is stored as the mutation
        rate and every node's value is divided by it to obtain the node
        times that are written back to ``tree`` via ``tree.set_times``.

        On success this method sets ``self._mutation_rate`` and
        ``self._log_likelihood``.

        The only caveat is that this method raises an IIDExponentialMLEError
        if the underlying convex optimization solver fails, or a
        ValueError if the character matrix is degenerate (fully mutated,
        or fully unmutated).

        Args:
            tree: The tree whose node times are estimated and set in place.

        Raises:
            IIDExponentialMLEError: If the solver raises, returns no
                solution, or yields a mutation rate outside [1e-8, 15.0].
            ValueError: If the character matrix has no mutations or is
                fully mutated.
        """
        # Extract parameters
        minimum_branch_length = self._minimum_branch_length
        solver = self._solver
        verbose = self._verbose

        # # # # # Check that the character has at least one mutation # # # # #
        if (tree.character_matrix == 0).all().all():
            raise ValueError(
                "The character matrix has no mutations. Please check your data."
            )

        # # # # # Check that the character is not saturated # # # # #
        if (tree.character_matrix != 0).all().all():
            raise ValueError(
                "The character matrix is fully mutated. The MLE does not "
                "exist. Please check your data.")

        # # # # # Create variables of the optimization problem # # # # #
        r_X_t_variables = {
            node_id: cp.Variable(name=f"r_X_t_{node_id}")
            for node_id in tree.nodes
        }

        # # # # # Create constraints of the optimization problem # # # # #
        a_leaf = tree.leaves[0]
        root = tree.root
        root_has_time_0_constraint = [r_X_t_variables[root] == 0]
        minimum_branch_length_constraints = [
            r_X_t_variables[child] >= r_X_t_variables[parent] +
            minimum_branch_length * r_X_t_variables[a_leaf]
            for (parent, child) in tree.edges
        ]
        ultrametric_constraints = [
            r_X_t_variables[leaf] == r_X_t_variables[a_leaf]
            for leaf in tree.leaves if leaf != a_leaf
        ]
        all_constraints = (root_has_time_0_constraint +
                           minimum_branch_length_constraints +
                           ultrametric_constraints)

        # # # # # Compute the log-likelihood # # # # #
        stability_eps = 1e-5  # Keeps the log argument away from 0.
        log_likelihood = 0
        for (parent, child) in tree.edges:
            edge_length = r_X_t_variables[child] - r_X_t_variables[parent]
            num_unmutated = len(
                tree.get_unmutated_characters_along_edge(parent, child))
            num_mutated = len(
                tree.get_mutations_along_edge(
                    parent, child, treat_missing_as_mutations=False))
            log_likelihood += num_unmutated * (-edge_length)
            log_likelihood += num_mutated * cp.log(
                1 - cp.exp(-edge_length - stability_eps))

        # # # # # Solve the problem # # # # #
        obj = cp.Maximize(log_likelihood)
        prob = cp.Problem(obj, all_constraints)
        try:
            prob.solve(solver=solver, verbose=verbose)
        except cp.SolverError as err:  # pragma: no cover
            raise IIDExponentialMLEError("Third-party solver failed") from err

        # # # # # Extract the mutation rate # # # # #
        # A solver can return without raising and still have no solution,
        # in which case the variable values are None. Report that as the
        # documented error type instead of letting float(None) raise a
        # TypeError.
        leaf_value = r_X_t_variables[a_leaf].value
        if leaf_value is None:
            raise IIDExponentialMLEError(
                f"The solver did not return a solution "
                f"(status: {prob.status}).")
        self._mutation_rate = float(leaf_value)
        if self._mutation_rate < 1e-8 or self._mutation_rate > 15.0:
            raise IIDExponentialMLEError(
                "The solver failed when it shouldn't have.")

        # # # # # Extract the log-likelihood # # # # #
        log_likelihood = float(log_likelihood.value)
        if np.isnan(log_likelihood):
            log_likelihood = -np.inf
        self._log_likelihood = log_likelihood

        # # # # # Populate the tree with the estimated branch lengths # # # # #
        times = {
            node: float(r_X_t_variables[node].value) / self._mutation_rate
            for node in tree.nodes
        }
        # Make sure that the root has time 0 (avoid epsilons)
        times[tree.root] = 0.0
        # We smooth out epsilons that might make a parent's time greater
        # than its child (which can happen if minimum_branch_length=0)
        for (parent, child) in tree.depth_first_traverse_edges():
            times[child] = max(times[parent], times[child])
        tree.set_times(times)
0
def calculate_parsimony(
    tree: CassiopeiaTree,
    infer_ancestral_characters: bool = False,
    treat_missing_as_mutation: bool = False,
) -> int:
    """
    Calculates the number of mutations that have occurred on a tree.

    Calculates the parsimony, defined as the number of character/state
    mutations that occur on edges of the tree, from the character state
    annotations at the nodes. A mutation is said to have occurred on an
    edge if a state is present at a character at the child node and this
    state is not in the parent node.

    If `infer_ancestral_characters` is set to True, then
    `tree.reconstruct_ancestral_characters()` is called on the given tree
    before counting, so that the internal nodes' character states are
    inferred from the current character states at the leaves (by
    Camin-Sokal Parsimony, according to the original documentation of this
    function). Use `tree.set_character_states_at_leaves` to use a
    different layer to infer ancestral states. Otherwise, the current
    annotations at the internal states are used. If
    `treat_missing_as_mutation` is set to True, then transitions from a
    non-missing state to a missing state are counted in the parsimony
    calculation. Otherwise, they are not included.

    Args:
        tree: The tree to calculate parsimony over
        infer_ancestral_characters: Whether to infer the ancestral
            characters states of the tree
        treat_missing_as_mutation: Whether to treat missing states as
            mutations

    Returns:
        The number of mutations that have occurred on the tree

    Raises:
        TreeMetricError if the root or any other node of the tree has an
            empty list of character states
    """
    if infer_ancestral_characters:
        tree.reconstruct_ancestral_characters()

    # Shared by the root check and the per-edge check below.
    empty_internal_node_message = (
        "Character states empty at internal node. Annotate"
        " character states or infer ancestral characters by"
        " setting infer_ancestral_characters=True.")

    # The root is never the child of an edge, so it is checked separately.
    if tree.get_character_states(tree.root) == []:
        raise TreeMetricError(empty_internal_node_message)

    parsimony = 0
    for parent, child in tree.depth_first_traverse_edges():
        if tree.get_character_states(child) == []:
            if tree.is_leaf(child):
                raise TreeMetricError(
                    "Character states have not been initialized at leaves."
                    " Use set_character_states_at_leaves or populate_tree"
                    " with the character matrix that specifies the leaf"
                    " character states.")
            raise TreeMetricError(empty_internal_node_message)

        parsimony += len(
            tree.get_mutations_along_edge(
                parent,
                child,
                treat_missing_as_mutations=treat_missing_as_mutation))

    return parsimony