Example #1
 def test_hand_solvable_problem_2(self, name, solver):
     """
     Tree topology is just a branch 0->1.
     There are two mutated characters and one unmutated character, i.e.:
         root [state = '000']
         |
         v
         child [state = '011']
     The solution can be verified by hand. The optimization problem is:
         max_{r * t0} log(exp(-r * t0)) + 2 * log(1 - exp(-r * t0))
     The solution is r * t0 = ln(3) ~ 1.098
     (Note that because the depth of the tree is fixed to 1, r * t0 = r * 1
     is the mutation rate.)
     """
     tree = nx.DiGraph()
     tree.add_nodes_from(["0", "1"])
     tree.add_edge("0", "1")
     tree = CassiopeiaTree(tree=tree)
     tree.set_all_character_states({"0": [0, 0, 0], "1": [0, 1, 1]})
     model = IIDExponentialMLE(minimum_branch_length=1e-4, solver=solver)
     model.estimate_branch_lengths(tree)
     self.assertAlmostEqual(tree.get_branch_length("0", "1"), 1.0, places=3)
     self.assertAlmostEqual(tree.get_time("1"), 1.0, places=3)
     self.assertAlmostEqual(tree.get_time("0"), 0.0, places=3)
     self.assertAlmostEqual(model.mutation_rate, np.log(3), places=3)
     self.assertAlmostEqual(model.log_likelihood, -1.910, places=3)
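The hand solution above can be cross-checked numerically outside of Cassiopeia. A minimal sketch, using only NumPy and SciPy (not part of the test suite), that maximizes the one-dimensional log-likelihood and recovers r * t0 = ln(3) ~ 1.098 and a log likelihood of about -1.910:

import numpy as np
from scipy.optimize import minimize_scalar

# Log-likelihood of one unmutated and two mutated characters on a branch of
# scaled length x = r * t0.
def log_likelihood(x):
    return -x + 2 * np.log(1 - np.exp(-x))

# Maximize by minimizing the negative log-likelihood over x > 0.
result = minimize_scalar(
    lambda x: -log_likelihood(x), bounds=(1e-8, 10), method="bounded"
)
print(result.x, np.log(3))       # both ~1.0986
print(log_likelihood(result.x))  # ~ -1.910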
Example #2
 def test_small_tree_with_one_mutation(self, name, solver):
     """
     Perfect binary tree with one mutation at node 6: should give very
     short edges 1->3, 1->4, 0->2.
     The problem can be solved by hand: it trivially reduces to a
     1-dimensional problem:
         max_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0))
     The solution is r * t0 = ln(1.5) ~ 0.405
     (Note that because the depth of the tree is fixed to 1, r * t0 = r * 1
     is the mutation rate.)
     """
     tree = nx.DiGraph()
     tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]),
     tree.add_edges_from(
         [
             ("0", "1"),
             ("0", "2"),
             ("1", "3"),
             ("1", "4"),
             ("2", "5"),
             ("2", "6"),
         ]
     )
     tree = CassiopeiaTree(tree=tree)
     tree.set_all_character_states(
         {
             "0": [0],
             "1": [0],
             "2": [0],
             "3": [0],
             "4": [0],
             "5": [0],
             "6": [1],
         }
     )
     # minimum_branch_length needs to be a small epsilon, or else the SCS solver fails.
     model = IIDExponentialMLE(minimum_branch_length=1e-4, solver=solver)
     model.estimate_branch_lengths(tree)
     self.assertAlmostEqual(tree.get_branch_length("0", "1"), 1.0, places=3)
     self.assertAlmostEqual(tree.get_branch_length("0", "2"), 0.0, places=3)
     self.assertAlmostEqual(tree.get_branch_length("1", "3"), 0.0, places=3)
     self.assertAlmostEqual(tree.get_branch_length("1", "4"), 0.0, places=3)
     self.assertAlmostEqual(tree.get_branch_length("2", "5"), 1.0, places=3)
     self.assertAlmostEqual(tree.get_branch_length("2", "6"), 1.0, places=3)
     self.assertAlmostEqual(model.log_likelihood, -1.910, places=3)
     self.assertAlmostEqual(model.mutation_rate, np.log(1.5), places=3)
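More generally, whenever the problem reduces to a single scaled branch with m unmutated and k mutated characters, setting the derivative of -m * x + k * log(1 - exp(-x)) to zero gives x = ln((k + m) / m). A small sketch (the helper below is hypothetical, not part of Cassiopeia) reproducing both hand-solved cases:

import numpy as np

def mle_scaled_branch_length(k: int, m: int) -> float:
    """MLE of r * t0 for k mutated and m unmutated characters on one branch."""
    return np.log((k + m) / m)

print(mle_scaled_branch_length(k=2, m=1))  # ln(3)   ~ 1.099 (Example #1)
print(mle_scaled_branch_length(k=1, m=2))  # ln(1.5) ~ 0.405 (Example #2)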
Example #3
 def test_subtree_collapses_when_no_mutations(self, name, solver):
     """
     A subtree with no mutations should collapse to branch lengths of 0.
     This reduces the problem to the same one as in
     'test_hand_solvable_problem_1'.
     """
     tree = nx.DiGraph()
     tree.add_nodes_from(["0", "1", "2", "3", "4"]),
     tree.add_edges_from([("0", "1"), ("1", "2"), ("1", "3"), ("0", "4")])
     tree = CassiopeiaTree(tree=tree)
     tree.set_all_character_states(
         {"0": [0], "1": [1], "2": [1], "3": [1], "4": [0]}
     )
     model = IIDExponentialMLE(minimum_branch_length=1e-4, solver=solver)
     model.estimate_branch_lengths(tree)
     self.assertAlmostEqual(model.log_likelihood, -1.386, places=3)
     self.assertAlmostEqual(tree.get_branch_length("0", "1"), 1.0, places=3)
     self.assertAlmostEqual(tree.get_branch_length("1", "2"), 0.0, places=3)
     self.assertAlmostEqual(tree.get_branch_length("1", "3"), 0.0, places=3)
     self.assertAlmostEqual(tree.get_branch_length("0", "4"), 1.0, places=3)
     self.assertAlmostEqual(model.mutation_rate, np.log(2), places=3)
Example #4
    def overlay_data(
        self,
        tree: CassiopeiaTree,
        attribute_key: str = "spatial",
    ):
        """Overlays spatial data onto the CassiopeiaTree via Brownian motion.

        Args:
            tree: The CassiopeiaTree to overlay spatial data onto.
            attribute_key: The name of the attribute to save the coordinates as.
                This also serves as the prefix of the coordinates saved into
                the `cell_meta` attribute as `{attribute_key}_i` where i is
                an integer from 0...`dim-1`.
        """
        # Using numpy arrays instead of tuples for easy vector operations
        locations = {tree.root: np.zeros(self.dim)}
        for parent, child in tree.depth_first_traverse_edges(source=tree.root):
            parent_location = locations[parent]
            branch_length = tree.get_branch_length(parent, child)

            locations[child] = parent_location + np.random.normal(
                scale=np.sqrt(2 * self.diffusion_coefficient * branch_length),
                size=self.dim,
            )

        # Scale if desired
        # Note that Python dictionaries preserve order since 3.6
        if self.scale_unit_area:
            all_coordinates = np.array(list(locations.values()))

            # Shift each dimension so that the smallest value is at 0.
            all_coordinates -= all_coordinates.min(axis=0)

            # Scale all dimensions (by the same value) so that all values are
            # between [0, 1]. We don't scale each dimension separately because
            # we want to retain the shape of the distribution.
            all_coordinates /= all_coordinates.max()
            locations = {
                node: coordinates
                for node, coordinates in zip(locations.keys(), all_coordinates)
            }

        # Set node attributes
        for node, loc in locations.items():
            tree.set_attribute(node, attribute_key, tuple(loc))

        # Set cell meta
        cell_meta = (tree.cell_meta.copy() if tree.cell_meta is not None else
                     pd.DataFrame(index=tree.leaves))
        columns = [f"{attribute_key}_{i}" for i in range(self.dim)]
        cell_meta[columns] = np.nan
        for leaf in tree.leaves:
            cell_meta.loc[leaf, columns] = locations[leaf]
        tree.cell_meta = cell_meta
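The core of overlay_data is the Brownian-motion step: each child's location is its parent's location plus a Gaussian displacement whose per-dimension variance is 2 * diffusion_coefficient * branch_length. A standalone sketch of that single sampling step (the dimension, diffusion coefficient, and branch length below are illustrative values, not defaults of the class):

import numpy as np

dim = 2
diffusion_coefficient = 0.5  # illustrative value
branch_length = 1.0          # illustrative value

parent_location = np.zeros(dim)
# Displacement per axis is Gaussian with standard deviation sqrt(2 * D * t).
child_location = parent_location + np.random.normal(
    scale=np.sqrt(2 * diffusion_coefficient * branch_length), size=dim
)
print(child_location)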
Example #5
    def test_minimum_branch_length(self, name, solver):
        """
        Test that the minimum branch length feature works.

        Same as test_small_tree_with_one_mutation, but now we constrain the
        minimum branch length. Should give very short edges 1->3, 1->4, 0->2
        and edges 0->1, 2->5, 2->6 close to 1.
        """
        tree = nx.DiGraph()
        tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"])
        tree.add_edges_from(
            [
                ("0", "1"),
                ("0", "2"),
                ("1", "3"),
                ("1", "4"),
                ("2", "5"),
                ("2", "6"),
            ]
        )
        tree = CassiopeiaTree(tree=tree)
        tree.set_all_character_states(
            {
                "0": [0],
                "1": [0],
                "2": [0],
                "3": [0],
                "4": [0],
                "5": [0],
                "6": [1],
            }
        )
        model = IIDExponentialMLE(minimum_branch_length=0.01, solver=solver)
        model.estimate_branch_lengths(tree)
        self.assertAlmostEqual(
            tree.get_branch_length("0", "1"), 0.990, places=3
        )
        self.assertAlmostEqual(
            tree.get_branch_length("0", "2"), 0.010, places=3
        )
        self.assertAlmostEqual(
            tree.get_branch_length("1", "3"), 0.010, places=3
        )
        self.assertAlmostEqual(
            tree.get_branch_length("1", "4"), 0.010, places=3
        )
        self.assertAlmostEqual(
            tree.get_branch_length("2", "5"), 0.990, places=3
        )
        self.assertAlmostEqual(
            tree.get_branch_length("2", "6"), 0.990, places=3
        )
        self.assertAlmostEqual(model.log_likelihood, -1.922, places=3)
        self.assertAlmostEqual(model.mutation_rate, 0.405, places=3)
Example #6
    def test_small_tree_regression(self, name, solver):
        """
        Perfect binary tree with "normal" amount of mutations on each edge.

        Regression test. Cannot be solved by hand. We just check that this
        solution never changes.
        """
        tree = nx.DiGraph()
        tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"])
        tree.add_edges_from(
            [
                ("0", "1"),
                ("0", "2"),
                ("1", "3"),
                ("1", "4"),
                ("2", "5"),
                ("2", "6"),
            ]
        )
        tree = CassiopeiaTree(tree=tree)
        tree.set_all_character_states(
            {
                "0": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                "1": [1, 0, 0, 0, 0, 0, 0, 0, 0, -1],
                "2": [0, 0, 0, 0, 0, 6, 0, 0, 0, -1],
                "3": [1, 2, 0, 0, 0, 0, 0, 0, 0, -1],
                "4": [1, 0, 3, 0, 0, 0, 0, 0, 0, -1],
                "5": [0, 0, 0, 0, 5, 6, 7, 0, 0, -1],
                "6": [0, 0, 0, 4, 0, 6, 0, 8, 9, -1],
            }
        )
        model = IIDExponentialMLE(minimum_branch_length=1e-4, solver=solver)
        model.estimate_branch_lengths(tree)
        self.assertAlmostEqual(model.mutation_rate, 0.378, places=3)
        self.assertAlmostEqual(
            tree.get_branch_length("0", "1"), 0.537, places=3
        )
        self.assertAlmostEqual(
            tree.get_branch_length("0", "2"), 0.219, places=3
        )
        self.assertAlmostEqual(
            tree.get_branch_length("1", "3"), 0.463, places=3
        )
        self.assertAlmostEqual(
            tree.get_branch_length("1", "4"), 0.463, places=3
        )
        self.assertAlmostEqual(
            tree.get_branch_length("2", "5"), 0.781, places=3
        )
        self.assertAlmostEqual(
            tree.get_branch_length("2", "6"), 0.781, places=3
        )
        self.assertAlmostEqual(model.log_likelihood, -22.689, places=3)
Example #7
def calc_exact_log_full_joint(
    tree: CassiopeiaTree,
    mutation_rate: float,
    birth_rate: float,
    sampling_probability: float,
) -> float:
    """
    Exact log full joint probability computation.

    This method is used for testing the implementation of the model.

    The log full joint probability density of the observed tree topology,
    state vectors, and branch lengths. In other words:
    log P(branch lengths, character states, tree topology)
    Integrating this function allows computing the marginals, and hence the
    posteriors, of the times of any internal node in the tree.

    Note that this method is only fast enough for small trees: its
    running time scales exponentially with the number of internal nodes of
    the tree.

    Args:
        tree: The CassiopeiaTree containing the tree topology and all
            character states.
        mutation_rate: The mutation rate of the model.
        birth_rate: The birth rate of the model.
        sampling_probability: The sampling probability of the model.

    Returns:
        log P(branch lengths, character states, tree topology)
    """
    tree = deepcopy(tree)
    ll = 0.0
    lam = birth_rate
    r = mutation_rate
    p = sampling_probability
    q_inv = (1.0 - p) / p
    lg = np.log
    e = np.exp
    b = binom
    T = tree.get_max_depth_of_tree()
    # The loop variables are named `parent`/`child` so that they do not shadow
    # `p` (the sampling probability) defined above.
    for (parent, child) in tree.edges:
        t = tree.get_branch_length(parent, child)
        # Birth process with subsampling likelihood
        h = T - tree.get_time(parent) + tree.get_time(tree.root)
        h_tilde = T - tree.get_time(child) + tree.get_time(tree.root)
        if child in tree.leaves:
            # "Easy" case
            assert h_tilde == 0
            ll += (2.0 * lg(q_inv + 1.0) + lam * h -
                   2.0 * lg(q_inv + e(lam * h)) + lg(sampling_probability))
        else:
            ll += (lg(lam) + lam * h - 2.0 * lg(q_inv + e(lam * h)) +
                   2.0 * lg(q_inv + e(lam * h_tilde)) - lam * h_tilde)
        # Mutation process likelihood
        cuts = len(
            tree.get_mutations_along_edge(
                parent, child, treat_missing_as_mutations=False))
        uncuts = tree.get_character_states(child).count(0)
        # Care must be taken here, since we might get a NaN
        if np.isnan(lg(1 - e(-t * r)) * cuts):
            return -np.inf
        ll += ((-t * r) * uncuts + lg(1 - e(-t * r)) * cuts +
               lg(b(cuts + uncuts, cuts)))
    return ll
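The mutation-process term inside the loop above is a binomial likelihood per edge: each of the cuts + uncuts characters still uncut at the parent independently mutates along the edge with probability 1 - exp(-r * t). A standalone sketch of just that term, with illustrative values for r, t, cuts, and uncuts:

import numpy as np
from scipy.special import binom

r, t = 0.5, 1.0      # illustrative mutation rate and branch length
cuts, uncuts = 2, 3  # characters that did / did not mutate along the edge

log_term = (
    (-t * r) * uncuts                     # uncut characters stayed uncut
    + np.log(1 - np.exp(-t * r)) * cuts   # cut characters mutated
    + np.log(binom(cuts + uncuts, cuts))  # ways to choose which ones mutated
)
print(log_term)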
Example #8
def calculate_likelihood_continuous(
    tree: CassiopeiaTree,
    use_internal_character_states: bool = False,
    layer: Optional[str] = None,
) -> float:
    """
    Calculates the log likelihood of a tree under a continuous process.

    A wrapper function for `get_lineage_tracing_parameters` and
    `log_likelihood_of_character` under a continuous model of lineage tracing.

    This function acquires the mutation rate, the heritable missing rate, and
    the stochastic missing probability from the tree using
    `get_lineage_tracing_parameters`. The rates are assumed to be instantaneous
    rates. Then, it calculates the log likelihood for each character using
    `log_likelihood_of_character` and, under the assumption that characters
    mutate independently, sums their log likelihoods to get the log likelihood
    of the tree.

    Here, branch lengths are used. We assume that the rates are instantaneous
    rates representing the frequency at which mutation and missing data events
    occur per unit of time. Under this continuous model, the waiting time until
    a mutation/missing data event is exponentially distributed, so the
    probability that an event has occurred by time t is given by the
    exponential CDF.

    Args:
        tree: The tree over which to calculate the likelihood
        use_internal_character_states: Indicates if internal node
            character states should be assumed to be specified exactly
        layer: Layer to use for the character matrix in estimating parameters.
            If this is None, then the current `character_matrix` variable will
            be used.

    Returns:
        The log likelihood of the tree given the observed character states.

    Raises:
        TreeMetricError if the tree priors are not populated, or if character
            state annotations are missing at a node.
    """

    if tree.priors is None:
        raise TreeMetricError(
            "Priors must be specified for this tree to calculate the"
            " likelihood.")

    for leaf in tree.leaves:
        if tree.get_character_states(leaf) == []:
            raise TreeMetricError(
                "Character states have not been initialized at leaves."
                " Use set_character_states_at_leaves or populate_tree"
                " with the character matrix that specifies the leaf"
                " character states.")

    if use_internal_character_states:
        for i in tree.internal_nodes:
            if tree.get_character_states(i) == []:
                raise TreeMetricError(
                    "Character states empty at internal node. Character"
                    " states must be annotated at each node if internal"
                    " character states are to be used.")

    (
        mutation_rate,
        heritable_missing_rate,
        stochastic_missing_probability,
    ) = get_lineage_tracing_parameters(
        tree,
        True,
        (not use_internal_character_states),
        layer,
    )

    mutation_probability_function_of_time = lambda t: 1 - np.exp(
        -mutation_rate * t)
    missing_probability_function_of_time = lambda t: 1 - np.exp(
        -heritable_missing_rate * t)
    implicit_root_branch_length = np.mean(
        [tree.get_branch_length(u, v) for u, v in tree.edges])

    return np.sum([
        log_likelihood_of_character(
            tree,
            character,
            use_internal_character_states,
            mutation_probability_function_of_time,
            missing_probability_function_of_time,
            stochastic_missing_probability,
            implicit_root_branch_length,
        ) for character in range(tree.n_character)
    ])
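Under this continuous model, the probability that a mutation (or a heritable missing-data event) has occurred by time t is the exponential CDF 1 - exp(-rate * t), which is exactly what the two lambda functions above compute. A tiny standalone sketch with an illustrative rate:

import numpy as np

mutation_rate = 0.4  # illustrative instantaneous rate

def mutation_probability(t: float) -> float:
    """P(at least one mutation by time t) under an exponential waiting time."""
    return 1 - np.exp(-mutation_rate * t)

for t in (0.1, 1.0, 5.0):
    print(t, mutation_probability(t))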
Example #9
def log_likelihood_of_character(
    tree: CassiopeiaTree,
    character: int,
    use_internal_character_states: bool,
    mutation_probability_function_of_time: Callable[[float], float],
    missing_probability_function_of_time: Callable[[float], float],
    stochastic_missing_probability: float,
    implicit_root_branch_length: float,
) -> float:
    """Calculates the log likelihood of a given character on the tree.

    Calculates the log likelihood of a tree given the states at a given
    character in the leaves using Felsenstein's Pruning Algorithm, which sets
    up a recursive relation between the likelihoods of states at nodes for this
    character. The likelihood L(s, n) at a given state s at a given node n is:

    L(s, n) = Π_{n'}(Σ_{s'}(P(s'|s) * L(s', n')))

    for all n' that are children of n, and s' in the state space, with
    P(s'|s) being the transition probability from s to s'. That is,
    the likelihood at a given state at a given node is the product of
    the likelihoods of the states at this character at the children scaled by
    the probability of the current state transitioning to those states. This
    includes the missing state, as specified by `tree.missing_state_indicator`.

    We assume here that mutations are irreversible. Once a character mutates to
    a certain state, that character cannot mutate again, with the exception
    that any non-missing state can still mutate to the missing state.
    `mutation_probability_function_of_time` is expected to be a function that
    determines the probability of a mutation occurring given an amount of time.
    To determine the probability of acquiring a given (non-missing) state once
    a mutation occurs, the priors of the tree are used. Likewise,
    `missing_probability_function_of_time` determines the probability of a
    missing data event occurring given an amount of time.

    The user can choose to use the character states annotated at internal
    nodes. If these are not used, then the likelihood is marginalized over
    all possible internal character states. If the actual internal states
    are not provided, then the root is assumed to have the unmutated state
    at each character. Additionally, it is assumed that there is a single
    branch leading from the root that represents the root's lifetime. If
    this branch does not exist and `use_internal_character_states` is set
    to False, then this branch is added with branch length equal to the
    average branch length of this tree.

    Args:
        tree: The tree on which to calculate the likelihood
        character: The index of the character to calculate the likelihood of
        use_internal_character_states: Indicates if internal node
            character states should be assumed to be specified exactly
        mutation_probability_function_of_time: The function defining the
            probability of a lineage acquiring a mutation within a given time
        missing_probability_function_of_time: The function defining the
            probability of a lineage acquiring heritable missing data within a
            given time
        stochastic_missing_probability: The probability that a cell/character
            pair acquires stochastic missing data at the end of the lineage
        implicit_root_branch_length: The length of the implicit root branch.
            Used if the implicit root needs to be added

    Returns:
        The log likelihood of the tree on one character
    """

    # This dictionary uses a nested dictionary structure. Each node is mapped
    # to a dictionary storing the log likelihood for each possible state
    # (only states that have non-zero likelihood are stored).
    likelihoods_at_nodes = {}

    # Perform a DFS to propagate the likelihood from the leaves
    for n in tree.depth_first_traverse_nodes(postorder=True):
        state_at_n = tree.get_character_states(n)
        # If states are observed, their likelihoods are set to 1 (log likelihood 0)
        if tree.is_leaf(n):
            likelihoods_at_nodes[n] = {state_at_n[character]: 0}
            continue

        possible_states = []
        # If internal character states are to be used, then the likelihoods
        # of all other states are ignored. Otherwise, marginalize over
        # only states that do not break irreversibility, as all states that
        # do have a likelihood of 0.
        if use_internal_character_states:
            possible_states = [state_at_n[character]]
        else:
            child_possible_states = []
            for c in [
                    set(likelihoods_at_nodes[child])
                    for child in tree.children(n)
            ]:
                if tree.missing_state_indicator not in c and "&" not in c:
                    child_possible_states.append(c)
            # "&" stands in for any non-missing state (including uncut), and
            # is a possible state when all children are missing, as any
            # state could have occurred at the parent if all missing data
            # events occurred independently. Used to avoid marginalizing
            # over the entire state space.
            if child_possible_states == []:
                possible_states = [
                    "&",
                    tree.missing_state_indicator,
                ]
            else:
                possible_states = list(
                    set.intersection(*child_possible_states))
                if 0 not in possible_states:
                    possible_states.append(0)

        # This stores the likelihood of each possible state at the current node
        likelihoods_per_state_at_n = {}

        # We calculate the likelihood of the states at the current node
        # according to the recurrence relation. For each state, we marginalize
        # over the likelihoods of the states that it could transition to in the
        # daughter nodes
        for s in possible_states:
            likelihood_for_s = 0
            for child in tree.children(n):
                likelihoods_for_s_marginalize_over_s_ = []
                for s_ in likelihoods_at_nodes[child]:
                    likelihood_s_ = (log_transition_probability(
                        tree,
                        character,
                        s,
                        s_,
                        tree.get_branch_length(n, child),
                        mutation_probability_function_of_time,
                        missing_probability_function_of_time,
                    ) + likelihoods_at_nodes[child][s_])
                    # Here we take into account the probability of
                    # stochastic missing data
                    if tree.is_leaf(child):
                        if (s_ == tree.missing_state_indicator
                                and s != tree.missing_state_indicator):
                            likelihood_s_ = np.log(
                                np.exp(likelihood_s_) +
                                (1 - missing_probability_function_of_time(
                                    tree.get_branch_length(n, child))) *
                                stochastic_missing_probability)
                        if s_ != tree.missing_state_indicator:
                            likelihood_s_ += np.log(
                                1 - stochastic_missing_probability)
                    likelihoods_for_s_marginalize_over_s_.append(likelihood_s_)
                likelihood_for_s += scipy.special.logsumexp(
                    np.array(likelihoods_for_s_marginalize_over_s_))
            likelihoods_per_state_at_n[s] = likelihood_for_s

        likelihoods_at_nodes[n] = likelihoods_per_state_at_n

    # If we are not to use the internal state annotations explicitly,
    # then we assume an implicit root where each state is the uncut state (0)
    # Thus, we marginalize over the transition from 0 in the implicit root
    # to all non-0 states in its child
    if not use_internal_character_states:
        # If the implicit root does not exist in the tree, then we impose it,
        # with the length of the branch being specified as
        # `implicit_root_branch_length`. Otherwise, we just use the existing
        # root with a singleton child as the implicit root
        if len(tree.children(tree.root)) != 1:

            likelihood_contribution_from_each_root_state = [
                log_transition_probability(
                    tree,
                    character,
                    0,
                    s_,
                    implicit_root_branch_length,
                    mutation_probability_function_of_time,
                    missing_probability_function_of_time,
                ) + likelihoods_at_nodes[tree.root][s_]
                for s_ in likelihoods_at_nodes[tree.root]
            ]
            likelihood_at_implicit_root = scipy.special.logsumexp(
                likelihood_contribution_from_each_root_state)

            return likelihood_at_implicit_root

        else:
            # Here we account for the edge case in which all of the leaves are
            # missing, in which case the root will have "&" in place of 0. The
            # likelihood at "&" will have the same likelihood as 0 based on the
            # transition rules regarding "&". As "&" is a placeholder when the
            # state is unknown, this can be thought of realizing "&" as 0.
            if 0 not in likelihoods_at_nodes[tree.root]:
                return likelihoods_at_nodes[tree.root]["&"]
            else:
                # Otherwise, we return the likelihood of the 0 state at the
                # existing implicit root
                return likelihoods_at_nodes[tree.root][0]

    # If we use the internal state annotations explicitly, then we return
    # the likelihood of the state annotated at this character at the root
    else:
        return list(likelihoods_at_nodes[tree.root].values())[0]
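For intuition about the recurrence L(s, n) = Π_{n'}(Σ_{s'}(P(s'|s) * L(s', n'))), here is a minimal, self-contained pruning sketch for a five-node tree under a simple irreversible 0 -> 1 model with a fixed per-branch mutation probability (no missing data, no priors). It is not the Cassiopeia implementation above, only an illustration of the same recursion in log space:

import numpy as np
from scipy.special import logsumexp

# A cherry (l1, l2) under internal node "a", plus leaf "l3" directly under root "r".
children = {"r": ["a", "l3"], "a": ["l1", "l2"], "l1": [], "l2": [], "l3": []}
leaf_state = {"l1": 1, "l2": 1, "l3": 0}

p_mut = 0.3  # illustrative per-branch probability of the irreversible 0 -> 1 mutation
# Transition probabilities P(child_state | parent_state).
P = {(0, 0): 1 - p_mut, (0, 1): p_mut, (1, 0): 0.0, (1, 1): 1.0}

def log_pruning(node):
    """Return {state: log L(state, node)} via Felsenstein's pruning recursion."""
    if not children[node]:
        # Leaf: log likelihood 0 (= log 1) for the observed state only.
        return {leaf_state[node]: 0.0}
    child_liks = [log_pruning(child) for child in children[node]]
    log_liks = {}
    for s in (0, 1):
        total = 0.0
        for liks in child_liks:
            terms = [np.log(P[(s, s_)]) + liks[s_]
                     for s_ in liks if P[(s, s_)] > 0]
            total += logsumexp(terms) if terms else -np.inf
        log_liks[s] = total
    return log_liks

# The root is assumed unmutated (state 0), as in the model above.
print(log_pruning("r")[0])  # log(((1 - p_mut) * p_mut**2 + p_mut) * (1 - p_mut))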