def simulate_tree(self, ) -> CassiopeiaTree:
        """Simulates a complete binary tree of depth ``self.depth``.

        A balanced binary tree is built with networkx, and an extra "root"
        node is attached above its node 0 to stand for the initiating cell.
        Every node is relabeled with a consecutive string integer: "0" for the
        initiating cell, then "1", "2", ... following the balanced tree's node
        iteration order. Node times are finally divided by ``self.depth + 1``,
        the number of edges from the initiating cell down to any leaf.

        Returns:
            A CassiopeiaTree with the tree topology initialized with the
            simulated tree
        """
        complete_tree = nx.balanced_tree(
            2, self.depth, create_using=nx.DiGraph
        )

        # Build the relabeling before the extra root is attached, so the
        # balanced-tree nodes are numbered from 1 upward.
        relabeling = {"root": "0"}
        for position, node in enumerate(complete_tree.nodes, start=1):
            relabeling[node] = str(position)

        # The initiating cell sits directly above the balanced tree's root.
        complete_tree.add_edge("root", 0)
        nx.relabel_nodes(complete_tree, relabeling, copy=False)
        cassiopeia_tree = CassiopeiaTree(tree=complete_tree)

        # Normalize node times by the root-to-leaf edge count.
        path_length = self.depth + 1
        scaled_times = {}
        for node in cassiopeia_tree.nodes:
            scaled_times[node] = cassiopeia_tree.get_time(node) / path_length
        cassiopeia_tree.set_times(scaled_times)
        return cassiopeia_tree
示例#2
0
def extract_tree_statistics(
    tree: CassiopeiaTree, ) -> Tuple[List[float], int, bool]:
    """A helper function for testing simulated trees.

    Outputs the total lived time for each extant lineage, the number of extant
    lineages, and whether the tree has the expected node degrees (to ensure
    unifurcations were collapsed).

    The root is exempt from the degree check. It is identified by identity
    (``tree.root``) rather than by assuming it is the first entry of
    ``tree.nodes``, so the check is correct regardless of node ordering.

    Args:
        tree: The tree to test

    Returns:
        The total time lived for each leaf (in ``tree.nodes`` order), the
        number of leaves, and whether every non-root node has out-degree
        0 or 2
    """
    root = tree.root
    times = []
    correct_degrees = True
    for node in tree.nodes:
        if tree.is_leaf(node):
            times.append(tree.get_time(node))
        # The root may legitimately have a single child (presumably the
        # implicit root edge of simulated trees), so it is not checked.
        if node != root and len(tree.children(node)) not in (0, 2):
            correct_degrees = False

    return times, len(times), correct_degrees
示例#3
0
 def test_bad_number_of_samples(self):
     tree_with_matrix = CassiopeiaTree(
         tree=self.test_network, character_matrix=self.character_matrix
     )
     tree_without_matrix = CassiopeiaTree(tree=self.test_network)

     # number_of_merges=10 and number_of_merges=0 are both expected to be
     # rejected with LeafSubsamplerError.
     for merge_count in (10, 0):
         with self.assertRaises(LeafSubsamplerError):
             SupercellularSampler(
                 number_of_merges=merge_count
             ).subsample_leaves(tree_with_matrix)

     # A tree lacking a character matrix is expected to raise
     # CassiopeiaTreeError.
     with self.assertRaises(CassiopeiaTreeError):
         SupercellularSampler(number_of_merges=2).subsample_leaves(
             tree_without_matrix
         )
    def test_subsample_balanced_tree(self):
        # Depth-3 balanced binary tree (nodes 0-14) relabeled to "node<i>",
        # with an extra "node15" placed above "node0" as the root.
        balanced_tree = nx.balanced_tree(2, 3, create_using=nx.DiGraph)
        balanced_tree = nx.relabel_nodes(
            balanced_tree, {i: f"node{i}" for i in balanced_tree.nodes}
        )
        # add_edge creates "node15" implicitly; no separate add_node needed.
        balanced_tree.add_edge("node15", "node0")
        tree = CassiopeiaTree(tree=balanced_tree)

        # Each case reseeds NumPy with 10 so the sampled leaves, and thus
        # the resulting edges, are fixed.
        cases = [
            (
                {"number_of_leaves": 3},
                {
                    ("node15", "node8"),
                    ("node15", "node5"),
                    ("node5", "node11"),
                    ("node5", "node12"),
                },
            ),
            (
                {"ratio": 0.65},
                {
                    ("node15", "node2"),
                    ("node15", "node3"),
                    ("node2", "node14"),
                    ("node2", "node5"),
                    ("node5", "node11"),
                    ("node5", "node12"),
                    ("node3", "node7"),
                    ("node3", "node8"),
                },
            ),
        ]
        for sampler_kwargs, expected_edges in cases:
            np.random.seed(10)
            uni = UniformLeafSubsampler(**sampler_kwargs)
            res = uni.subsample_leaves(
                tree=tree, keep_singular_root_edge=False
            )
            self.assertEqual(set(res.edges), expected_edges)
 def test_bad_number_of_samples(self):
     tree = CassiopeiaTree(
         tree=nx.balanced_tree(2, 3, create_using=nx.DiGraph)
     )
     # Asking for zero leaves, or for a vanishingly small ratio of the
     # tree's leaves, is expected to raise LeafSubsamplerError.
     for bad_kwargs in ({"number_of_leaves": 0}, {"ratio": 0.0001}):
         with self.assertRaises(LeafSubsamplerError):
             UniformLeafSubsampler(**bad_kwargs).subsample_leaves(tree)
示例#6
0
    def setUp(self):
        # Local helper: build a DataFrame whose rows are the dict's keys.
        def by_row(rows, columns):
            return pd.DataFrame.from_dict(
                rows, orient="index", columns=columns
            )

        characters = ["x1", "x2", "x3"]

        # ---- Basic case: character matrix + precomputed dissimilarity map
        basic_character_matrix = by_row(
            {
                "a": [0, 1, 2],
                "b": [1, 1, 2],
                "c": [2, 2, 2],
                "d": [1, 1, 1],
                "e": [0, 0, 0],
            },
            characters,
        )
        basic_delta = by_row(
            {
                "a": [0, 17, 21, 31, 23],
                "b": [17, 0, 30, 34, 21],
                "c": [21, 30, 0, 28, 39],
                "d": [31, 34, 28, 0, 43],
                "e": [23, 21, 39, 43, 0],
            },
            ["a", "b", "c", "d", "e"],
        )
        self.basic_dissimilarity_map = basic_delta
        self.basic_tree = CassiopeiaTree(
            character_matrix=basic_character_matrix,
            dissimilarity_map=basic_delta,
        )
        self.upgma_solver = UPGMASolver()

        # ---- Lineage-tracing matrix without a precomputed map; the solver
        # is given weighted_hamming_distance.
        lineage_matrix = by_row(
            {
                "a": [1, 1, 0],
                "b": [1, 2, 0],
                "c": [1, 2, 1],
                "d": [2, 0, 0],
                "e": [2, 0, 2],
            },
            characters,
        )
        self.pp_tree = CassiopeiaTree(character_matrix=lineage_matrix)
        self.upgma_solver_delta = UPGMASolver(
            dissimilarity_function=(
                dissimilarity_functions.weighted_hamming_distance
            )
        )

        # ---- Duplicate rows ("e" and "f") plus missing data (presumably
        # the -1 entries).
        duplicated_matrix = by_row(
            {
                "a": [1, -1, 0],
                "b": [1, 2, 1],
                "c": [1, -1, 1],
                "d": [2, 0, -1],
                "e": [2, 0, 2],
                "f": [2, 0, 2],
            },
            characters,
        )
        self.duplicate_tree = CassiopeiaTree(
            character_matrix=duplicated_matrix
        )

        # ---- Same lineage-tracing matrix, now with per-character priors.
        character_priors = {
            0: {1: 0.5, 2: 0.5},
            1: {1: 0.2, 2: 0.8},
            2: {1: 0.3, 2: 0.7},
        }
        self.pp_tree_priors = CassiopeiaTree(
            character_matrix=lineage_matrix, priors=character_priors
        )
        self.upgma_solver_modified = UPGMASolver(
            dissimilarity_function=(
                dissimilarity_functions.weighted_hamming_distance
            )
        )
示例#7
0
class TestUPGMASolver(unittest.TestCase):
    """Tests for UPGMASolver on small hand-built character matrices."""

    def setUp(self):

        # --------------------- Basic UPGMA ---------------------
        # Character matrix paired with a precomputed dissimilarity map.
        cm = pd.DataFrame.from_dict(
            {
                "a": [0, 1, 2],
                "b": [1, 1, 2],
                "c": [2, 2, 2],
                "d": [1, 1, 1],
                "e": [0, 0, 0],
            },
            orient="index",
            columns=["x1", "x2", "x3"],
        )

        delta = pd.DataFrame.from_dict(
            {
                "a": [0, 17, 21, 31, 23],
                "b": [17, 0, 30, 34, 21],
                "c": [21, 30, 0, 28, 39],
                "d": [31, 34, 28, 0, 43],
                "e": [23, 21, 39, 43, 0],
            },
            orient="index",
            columns=["a", "b", "c", "d", "e"],
        )

        self.basic_dissimilarity_map = delta
        self.basic_tree = CassiopeiaTree(character_matrix=cm,
                                         dissimilarity_map=delta)

        self.upgma_solver = UPGMASolver()

        # ---------------- Lineage-tracing UPGMA ----------------
        # No precomputed dissimilarity map; the solver is given
        # weighted_hamming_distance.

        pp_cm = pd.DataFrame.from_dict(
            {
                "a": [1, 1, 0],
                "b": [1, 2, 0],
                "c": [1, 2, 1],
                "d": [2, 0, 0],
                "e": [2, 0, 2],
            },
            orient="index",
            columns=["x1", "x2", "x3"],
        )

        self.pp_tree = CassiopeiaTree(character_matrix=pp_cm)

        self.upgma_solver_delta = UPGMASolver(
            dissimilarity_function=dissimilarity_functions.
            weighted_hamming_distance)

        # ------------- CM with Duplicates and Missing Data -----------------
        duplicates_cm = pd.DataFrame.from_dict(
            {
                "a": [1, -1, 0],
                "b": [1, 2, 1],
                "c": [1, -1, 1],
                "d": [2, 0, -1],
                "e": [2, 0, 2],
                "f": [2, 0, 2],
            },
            orient="index",
            columns=["x1", "x2", "x3"],
        )

        self.duplicate_tree = CassiopeiaTree(character_matrix=duplicates_cm)

        # -------------  Hamming dissimilarity with weights  ------------
        priors = {
            0: {1: 0.5, 2: 0.5},
            1: {1: 0.2, 2: 0.8},
            2: {1: 0.3, 2: 0.7},
        }
        self.pp_tree_priors = CassiopeiaTree(character_matrix=pp_cm,
                                             priors=priors)
        self.upgma_solver_modified = UPGMASolver(
            dissimilarity_function=dissimilarity_functions.
            weighted_hamming_distance)

    def _assert_same_triplets(self, leaves, expected_tree, observed_tree):
        """Assert every 3-leaf subset is resolved identically in both trees.

        Args:
            leaves: Leaf names to draw triplets from.
            expected_tree: Reference topology.
            observed_tree: Topology produced by the solver.
        """
        for triplet in itertools.combinations(leaves, 3):
            self.assertEqual(
                find_triplet_structure(triplet, expected_tree),
                find_triplet_structure(triplet, observed_tree),
            )

    def _assert_maps_equal(self, observed, expected):
        """Assert ``observed`` equals ``expected`` on all of its label pairs.

        Only labels present in ``expected.index`` are compared.
        """
        for sample in expected.index:
            for sample2 in expected.index:
                self.assertEqual(
                    observed.loc[sample, sample2],
                    expected.loc[sample, sample2],
                )

    def test_constructor(self):

        self.assertIsNotNone(self.upgma_solver_delta.dissimilarity_function)
        self.assertIsNotNone(self.basic_tree.get_dissimilarity_map())

    def test_find_cherry(self):

        cherry = self.upgma_solver.find_cherry(
            self.basic_dissimilarity_map.values)
        delta = self.basic_dissimilarity_map
        node_i, node_j = (delta.index[cherry[0]], delta.index[cherry[1]])

        # "a"/"b" have the smallest off-diagonal entry (17) in the map.
        self.assertIn((node_i, node_j), [("a", "b"), ("b", "a")])

    def test_update_dissimilarity_map(self):

        delta = self.basic_dissimilarity_map

        # First merge: the cherry is expected to be (a, b).
        cherry = self.upgma_solver.find_cherry(delta.values)
        node_i, node_j = (delta.index[cherry[0]], delta.index[cherry[1]])

        delta = self.upgma_solver.update_dissimilarity_map(
            delta, (node_i, node_j), "ab")

        expected_delta = pd.DataFrame.from_dict(
            {
                "ab": [0, 25.5, 32.5, 22],
                "c": [25.5, 0, 28, 39],
                "d": [32.5, 28, 0, 43],
                "e": [22, 39, 43, 0],
            },
            orient="index",
            columns=["ab", "c", "d", "e"],
        )
        self._assert_maps_equal(delta, expected_delta)

        # Second merge on the reduced map: the new cluster is named "abe".
        cherry = self.upgma_solver.find_cherry(delta.values)
        node_i, node_j = (delta.index[cherry[0]], delta.index[cherry[1]])

        delta = self.upgma_solver.update_dissimilarity_map(
            delta, (node_i, node_j), "abe")

        expected_delta = pd.DataFrame.from_dict(
            {
                "abe": [0, 30, 36],
                "c": [30, 0, 28],
                "d": [36, 28, 0]
            },
            orient="index",
            columns=["abe", "c", "d"],
        )
        self._assert_maps_equal(delta, expected_delta)

    def test_basic_solver(self):

        self.upgma_solver.solve(self.basic_tree)

        # test leaves exist in tree
        _leaves = self.basic_tree.leaves

        self.assertEqual(len(_leaves), self.basic_dissimilarity_map.shape[0])
        for _leaf in _leaves:
            self.assertIn(_leaf, self.basic_dissimilarity_map.index.values)

        # test for expected number of edges
        edges = list(self.basic_tree.edges)
        self.assertEqual(len(edges), 8)

        # test relationships between samples
        expected_tree = nx.DiGraph()
        expected_tree.add_edges_from([
            ("5", "a"),
            ("5", "b"),
            ("6", "5"),
            ("6", "e"),
            ("7", "c"),
            ("7", "d"),
            ("root", "6"),
            ("root", "7"),
        ])

        observed_tree = self.basic_tree.get_tree_topology()
        self._assert_same_triplets(["a", "b", "c", "d", "e"], expected_tree,
                                   observed_tree)

        # compare tree distances (edge counts) for every pair of leaves
        observed_tree = observed_tree.to_undirected()
        expected_tree = expected_tree.to_undirected()
        for sample1, sample2 in itertools.combinations(_leaves, 2):
            self.assertEqual(
                nx.shortest_path_length(observed_tree, sample1, sample2),
                nx.shortest_path_length(expected_tree, sample1, sample2),
            )

    def test_upgma_solver_weights(self):
        self.upgma_solver_modified.solve(self.pp_tree_priors)
        initial_d_map = self.pp_tree_priors.get_dissimilarity_map()
        # Built from character x2's priors (0.2 and 0.8) over 3 characters.
        # Compared approximately because the solver may sum the log terms
        # in a different order.
        expected_dissimilarity = (-np.log(0.2) - np.log(0.8)) / 3
        self.assertAlmostEqual(initial_d_map.loc["a", "b"],
                               expected_dissimilarity)

        observed_tree = self.pp_tree_priors.get_tree_topology()

        expected_tree = nx.DiGraph()
        expected_tree.add_edges_from([
            ("root", "a"),
            ("root", "7"),
            ("7", "8"),
            ("7", "9"),
            ("8", "d"),
            ("8", "e"),
            ("9", "b"),
            ("9", "c"),
        ])
        self._assert_same_triplets(["a", "b", "c", "d", "e"], expected_tree,
                                   observed_tree)

        # With collapse_mutationless_edges=True, internal node "7" is
        # expected to disappear and "8"/"9" to hang directly off the root.
        self.upgma_solver_modified.solve(self.pp_tree_priors,
                                         collapse_mutationless_edges=True)
        observed_tree = self.pp_tree_priors.get_tree_topology()

        expected_tree = nx.DiGraph()
        expected_tree.add_edges_from([
            ("root", "a"),
            ("root", "8"),
            ("root", "9"),
            ("8", "d"),
            ("8", "e"),
            ("9", "b"),
            ("9", "c"),
        ])
        self._assert_same_triplets(["a", "b", "c", "d", "e"], expected_tree,
                                   observed_tree)

    def test_pp_solver(self):
        self.upgma_solver_delta.solve(self.pp_tree)
        initial_d_map = self.pp_tree.get_dissimilarity_map()
        expected_dissimilarity = 1 / 3
        self.assertAlmostEqual(initial_d_map.loc["d", "e"],
                               expected_dissimilarity)

        observed_tree = self.pp_tree.get_tree_topology()

        # NOTE(review): the edge ("9", "7") gives node "7" a second parent
        # besides "root", so this is not a tree; it looks like a typo in the
        # expected topology. Confirm how find_triplet_structure handles it.
        expected_tree = nx.DiGraph()
        expected_tree.add_edges_from([
            ("root", "8"),
            ("root", "7"),
            ("9", "7"),
            ("7", "6"),
            ("7", "a"),
            ("6", "b"),
            ("6", "c"),
            ("8", "e"),
            ("8", "d"),
        ])
        self._assert_same_triplets(["a", "b", "c", "d", "e"], expected_tree,
                                   observed_tree)

        # Solving the same tree again should reproduce the same triplets.
        self.upgma_solver_delta.solve(self.pp_tree)
        observed_tree = self.pp_tree.get_tree_topology()
        self._assert_same_triplets(["a", "b", "c", "d", "e"], expected_tree,
                                   observed_tree)

    def test_duplicate(self):
        # In this case, we see that the missing data can break up a duplicate
        # pair if the behavior is to ignore missing data

        self.upgma_solver_delta.solve(self.duplicate_tree)
        observed_tree = self.duplicate_tree.get_tree_topology()
        initial_d_map = self.duplicate_tree.get_dissimilarity_map()
        expected_dissimilarity = 1.5
        self.assertAlmostEqual(initial_d_map.loc["b", "d"],
                               expected_dissimilarity)

        expected_tree = nx.DiGraph()
        expected_tree.add_edges_from([
            ("root", "9"),
            ("root", "8"),
            ("9", "a"),
            ("9", "6"),
            ("6", "b"),
            ("6", "c"),
            ("8", "7"),
            ("8", "f"),
            ("7", "d"),
            ("7", "e"),
        ])
        self._assert_same_triplets(["a", "b", "c", "d", "e", "f"],
                                   expected_tree, observed_tree)

    def test_subsample_custom_tree(self):
        # NOTE(review): this test exercises UniformLeafSubsampler, not
        # UPGMASolver; it likely belongs with the leaf-subsampler tests.
        custom_tree = nx.DiGraph()
        custom_tree.add_nodes_from(["node" + str(i) for i in range(17)])
        custom_tree.add_edges_from(
            [
                ("node16", "node0"),
                ("node0", "node1"),
                ("node0", "node2"),
                ("node1", "node3"),
                ("node1", "node4"),
                ("node2", "node5"),
                ("node2", "node6"),
                ("node4", "node7"),
                ("node4", "node8"),
                ("node6", "node9"),
                ("node6", "node10"),
                ("node7", "node11"),
                ("node11", "node12"),
                ("node11", "node13"),
                ("node9", "node14"),
                ("node9", "node15"),
            ]
        )
        tree = CassiopeiaTree(tree=custom_tree)
        for u, v in tree.edges:
            tree.set_branch_length(u, v, 1.5)

        np.random.seed(10)
        uni = UniformLeafSubsampler(ratio=0.5)
        res = uni.subsample_leaves(tree=tree)

        # Collapsed unifurcations merge branch lengths (multiples of 1.5).
        expected_edges = {
            ("node16", "node0"): 1.5,
            ("node0", "node1"): 1.5,
            ("node0", "node5"): 3.0,
            ("node1", "node3"): 1.5,
            ("node1", "node11"): 4.5,
            ("node11", "node12"): 1.5,
            ("node11", "node13"): 1.5,
        }
        self.assertEqual(set(res.edges), set(expected_edges.keys()))
        for u, v in res.edges:
            self.assertEqual(
                res.get_branch_length(u, v), expected_edges[(u, v)]
            )

        expected_times = {
            "node16": 0.0,
            "node0": 1.5,
            "node1": 3.0,
            "node5": 4.5,
            "node3": 4.5,
            "node11": 7.5,
            "node12": 9.0,
            "node13": 9.0,
        }
        for u in res.nodes:
            self.assertEqual(res.get_time(u), expected_times[u])

        np.random.seed(11)
        uni = UniformLeafSubsampler(number_of_leaves=6)
        res = uni.subsample_leaves(tree=tree, keep_singular_root_edge=True)

        expected_edges = [
            ("node16", "node0"),
            ("node0", "node1"),
            ("node0", "node2"),
            ("node1", "node3"),
            ("node1", "node11"),
            ("node11", "node12"),
            ("node11", "node13"),
            ("node2", "node5"),
            ("node2", "node6"),
            ("node6", "node10"),
            ("node6", "node15"),
        ]
        self.assertEqual(set(res.edges), set(expected_edges))
示例#9
0
    def simulate_tree(self, ) -> CassiopeiaTree:
        """Simulates trees from a general birth/death process with fitness.

        A forward-time birth/death process is simulated by tracking a series of
        lineages and sampling event waiting times for each lineage. Each
        lineage draws death waiting times from the same distribution, but
        maintains its own birth scale parameter that determines the shape of
        its birth waiting time distribution. At each division event, fitness
        mutation events are sampled, and the birth scale parameter is scaled by
        their multiplicative coefficients. This updated birth scale is passed
        on to successors.

        If ``self.random_seed`` is not None (including 0), NumPy's global
        random state is seeded with it before simulating.

        Returns:
            A CassiopeiaTree with the tree topology initialized with the
            simulated tree

        Raises:
            TreeSimulatorError if all lineages die before a stopping condition
        """
        def node_name_generator() -> Generator[str, None, None]:
            """Generates unique node names for the tree."""
            i = 0
            while True:
                yield str(i)
                i += 1

        names = node_name_generator()

        # Set the seed. Compare against None rather than relying on
        # truthiness, so an explicit seed of 0 still makes runs reproducible.
        if self.random_seed is not None:
            np.random.seed(self.random_seed)

        # Instantiate the implicit root
        tree = nx.DiGraph()
        root = next(names)
        tree.add_node(root)
        tree.nodes[root]["birth_scale"] = self.initial_birth_scale
        tree.nodes[root]["time"] = 0
        current_lineages = PriorityQueue()
        # Records the nodes that are observed at the end of the experiment
        observed_nodes = []
        starting_lineage = {
            "id": root,
            "birth_scale": self.initial_birth_scale,
            "total_time": 0,
            "active": True,
        }

        # Sample the waiting time until the first division
        self.sample_lineage_event(starting_lineage, current_lineages, tree,
                                  names, observed_nodes)

        # Perform the process until there are no active extant lineages left
        while not current_lineages.empty():
            # If number of extant lineages is the stopping criterion, at the
            # first instance of having n extant tips, stop the experiment
            # and set the total lineage time for each lineage to be equal to
            # the minimum, to produce ultrametric trees. Also, the birth_scale
            # parameter of each leaf is rolled back to equal its parent's.
            if self.num_extant:
                if current_lineages.qsize() == self.num_extant:
                    remaining_lineages = []
                    while not current_lineages.empty():
                        _, _, lineage = current_lineages.get()
                        remaining_lineages.append(lineage)
                    # Lineages were drained in queue-priority order, so the
                    # first one is assumed to carry the minimum total_time.
                    min_total_time = remaining_lineages[0]["total_time"]
                    for lineage in remaining_lineages:
                        parent = list(tree.predecessors(lineage["id"]))[0]
                        tree.nodes[lineage["id"]]["time"] += (
                            min_total_time - lineage["total_time"])
                        tree.nodes[lineage["id"]]["birth_scale"] = tree.nodes[
                            parent]["birth_scale"]
                        observed_nodes.append(lineage["id"])
                    break
            # Pop the minimum age lineage to simulate forward time
            _, _, lineage = current_lineages.get()
            # If the lineage is no longer active, just remove it from the queue.
            # This represents the time at which the lineage dies.
            if lineage["active"]:
                # Presumably one call per daughter lineage of the division.
                for _ in range(2):
                    self.sample_lineage_event(lineage, current_lineages, tree,
                                              names, observed_nodes)

        cassiopeia_tree = CassiopeiaTree(tree=tree)
        time_dictionary = {
            node: tree.nodes[node]["time"] for node in tree.nodes
        }
        cassiopeia_tree.set_times(time_dictionary)

        # Prune dead lineages and collapse resulting unifurcations. Leaves are
        # collected in tree order (not via a set difference) so the removal
        # order does not depend on per-process string hash randomization.
        observed = set(observed_nodes)
        to_remove = [
            leaf for leaf in cassiopeia_tree.leaves if leaf not in observed
        ]
        cassiopeia_tree.remove_leaves_and_prune_lineages(to_remove)
        if self.collapse_unifurcations and len(cassiopeia_tree.nodes) > 1:
            # "1" is the first node named after the implicit root "0";
            # presumably the root's single child, so the root edge is kept.
            cassiopeia_tree.collapse_unifurcations(source="1")

        # If only implicit root remains after pruning dead lineages, error
        if len(cassiopeia_tree.nodes) == 1:
            raise TreeSimulatorError(
                "All lineages died before stopping condition")

        return cassiopeia_tree
示例#10
0
    def setUp(self):
        # Local helper: build a DataFrame whose rows are the dict's keys.
        def by_row(rows, columns):
            return pd.DataFrame.from_dict(
                rows, orient="index", columns=columns
            )

        characters = ["x1", "x2", "x3"]
        hamming_similarity = (
            dissimilarity_functions.hamming_similarity_without_missing
        )

        # ---- Basic case: character matrix + precomputed similarity map
        basic_character_matrix = by_row(
            {
                "a": [0, 1, 2],
                "b": [1, 1, 2],
                "c": [2, 2, 2],
                "d": [1, 1, 1],
                "e": [0, 0, 0],
            },
            characters,
        )
        basic_similarities = by_row(
            {
                "a": [0, 2, 1, 1, 0],
                "b": [2, 0, 1, 2, 0],
                "c": [1, 1, 0, 0, 0],
                "d": [1, 2, 0, 0, 0],
                "e": [0, 0, 0, 0, 0],
            },
            ["a", "b", "c", "d", "e"],
        )
        self.basic_similarity_map = basic_similarities
        # The similarity map is handed to the tree via dissimilarity_map.
        self.basic_tree = CassiopeiaTree(
            character_matrix=basic_character_matrix,
            dissimilarity_map=basic_similarities,
        )

        self.smj_solver = SharedMutationJoiningSolver(
            similarity_function=hamming_similarity
        )
        # Wraps the similarity in cluster_dissimilarity via partial;
        # presumably exercises the non-numba path, per the attribute name.
        self.smj_solver_no_numba = SharedMutationJoiningSolver(
            similarity_function=partial(
                dissimilarity_functions.cluster_dissimilarity,
                hamming_similarity,
            )
        )

        # ---- Lineage-tracing matrix without a precomputed map
        lineage_matrix = by_row(
            {
                "a": [1, 2, 2],
                "b": [1, 2, 1],
                "c": [1, 2, 0],
                "d": [2, 0, 0],
                "e": [2, 0, 2],
            },
            characters,
        )
        self.pp_tree = CassiopeiaTree(character_matrix=lineage_matrix)
        self.smj_solver_pp = SharedMutationJoiningSolver(
            similarity_function=hamming_similarity
        )

        # ---- Duplicate rows ("c" and "e") plus missing data (presumably
        # the -1 entries).
        duplicated_matrix = by_row(
            {
                "a": [1, -1, 0],
                "b": [2, -1, 2],
                "c": [2, 0, 2],
                "d": [2, 0, -1],
                "e": [2, 0, 2],
                "f": [2, -1, 2],
            },
            characters,
        )
        self.duplicate_tree = CassiopeiaTree(
            character_matrix=duplicated_matrix
        )

        # ---- Same lineage-tracing matrix, now with per-character priors.
        character_priors = {
            0: {1: 0.5, 2: 0.5},
            1: {1: 0.2, 2: 0.8},
            2: {1: 0.9, 2: 0.1},
        }
        self.pp_tree_priors = CassiopeiaTree(
            character_matrix=lineage_matrix, priors=character_priors
        )
        self.smj_solver_modified_pp = SharedMutationJoiningSolver(
            similarity_function=hamming_similarity
        )
示例#11
0
class TestSharedMutationJoiningSolver(unittest.TestCase):
    """Unit tests for SharedMutationJoiningSolver.

    Covers solver construction (numba compilation vs. pure-Python fallback),
    cherry selection, similarity-map construction with prior-derived weights,
    in-place similarity-map / character-matrix updates, and end-to-end solves
    checked via triplet structures and leaf-to-leaf path lengths.
    """

    def setUp(self):
        """Builds the character matrices, trees and solvers used by the tests.

        unittest runs this before every test method, so each test gets fresh
        CassiopeiaTree objects.
        """

        # --------------------- General NJ ---------------------
        cm = pd.DataFrame.from_dict(
            {
                "a": [0, 1, 2],
                "b": [1, 1, 2],
                "c": [2, 2, 2],
                "d": [1, 1, 1],
                "e": [0, 0, 0],
            },
            orient="index",
            columns=["x1", "x2", "x3"],
        )

        # Hand-specified pairwise similarities between the samples of `cm`.
        # It is stored on the tree as its `dissimilarity_map` and also kept as
        # `basic_similarity_map` so tests can check the solver left it intact.
        delta = pd.DataFrame.from_dict(
            {
                "a": [0, 2, 1, 1, 0],
                "b": [2, 0, 1, 2, 0],
                "c": [1, 1, 0, 0, 0],
                "d": [1, 2, 0, 0, 0],
                "e": [0, 0, 0, 0, 0],
            },
            orient="index",
            columns=["a", "b", "c", "d", "e"],
        )

        self.basic_similarity_map = delta
        self.basic_tree = CassiopeiaTree(character_matrix=cm,
                                         dissimilarity_map=delta)

        self.smj_solver = SharedMutationJoiningSolver(
            similarity_function=dissimilarity_functions.
            hamming_similarity_without_missing)
        # Wrapping the similarity function in functools.partial is the case
        # test_init expects to fall back to pure Python (no numba).
        self.smj_solver_no_numba = SharedMutationJoiningSolver(
            similarity_function=partial(
                dissimilarity_functions.cluster_dissimilarity,
                dissimilarity_functions.hamming_similarity_without_missing,
            ))

        # ---------------- Lineage Tracing NJ ----------------

        pp_cm = pd.DataFrame.from_dict(
            {
                "a": [1, 2, 2],
                "b": [1, 2, 1],
                "c": [1, 2, 0],
                "d": [2, 0, 0],
                "e": [2, 0, 2],
            },
            orient="index",
            columns=["x1", "x2", "x3"],
        )

        # No dissimilarity map is supplied here; test_pp_solver checks that
        # none is left on the tree after solving.
        self.pp_tree = CassiopeiaTree(character_matrix=pp_cm)

        self.smj_solver_pp = SharedMutationJoiningSolver(
            similarity_function=dissimilarity_functions.
            hamming_similarity_without_missing)

        # ------------- CM with Duplicates and Missing Data -----------------------
        # -1 entries are treated as missing data by this test suite (see the
        # comment in test_duplicate).
        duplicates_cm = pd.DataFrame.from_dict(
            {
                "a": [1, -1, 0],
                "b": [2, -1, 2],
                "c": [2, 0, 2],
                "d": [2, 0, -1],
                "e": [2, 0, 2],
                "f": [2, -1, 2],
            },
            orient="index",
            columns=["x1", "x2", "x3"],
        )

        self.duplicate_tree = CassiopeiaTree(character_matrix=duplicates_cm)

        # ------------- Hamming similarity with weights ------------
        # Priors are {character index: {state: probability}}; the expected
        # values in test_create_similarity_map use -log(probability) of the
        # shared states as weights.
        priors = {
            0: {
                1: 0.5,
                2: 0.5
            },
            1: {
                1: 0.2,
                2: 0.8
            },
            2: {
                1: 0.9,
                2: 0.1
            }
        }
        self.pp_tree_priors = CassiopeiaTree(character_matrix=pp_cm,
                                             priors=priors)
        self.smj_solver_modified_pp = SharedMutationJoiningSolver(
            similarity_function=dissimilarity_functions.
            hamming_similarity_without_missing)

    def test_init(self):
        """Checks when the solver numba-compiles its similarity functions.

        A plain similarity function should yield numba CPUDispatcher objects
        for both the similarity function and the private similarity-map
        updater. A functools.partial should instead emit a
        SharedMutationJoiningSolverWarning and leave both uncompiled.
        """
        # This should numbaize
        solver = SharedMutationJoiningSolver(
            similarity_function=dissimilarity_functions.
            hamming_similarity_without_missing)
        self.assertTrue(
            isinstance(solver.nb_similarity_function,
                       numba.core.registry.CPUDispatcher))
        # `__update_similarity_map` is a name-mangled private attribute, so it
        # is reached from outside the class via its mangled name.
        self.assertTrue(
            isinstance(
                solver._SharedMutationJoiningSolver__update_similarity_map,
                numba.core.registry.CPUDispatcher,
            ))

        # This shouldn't numbaize
        with self.assertWarns(SharedMutationJoiningSolverWarning):
            solver = SharedMutationJoiningSolver(similarity_function=partial(
                dissimilarity_functions.cluster_dissimilarity,
                dissimilarity_functions.hamming_similarity_without_missing,
            ))
            self.assertFalse(
                isinstance(
                    solver.nb_similarity_function,
                    numba.core.registry.CPUDispatcher,
                ))
            self.assertFalse(
                isinstance(
                    solver._SharedMutationJoiningSolver__update_similarity_map,
                    numba.core.registry.CPUDispatcher,
                ))

    def test_find_cherry(self):
        """Checks find_cherry picks ("a", "b") (either order) from delta."""
        cherry = self.smj_solver.find_cherry(self.basic_similarity_map.values)
        delta = self.basic_similarity_map
        # find_cherry works on the raw array and returns positional indices,
        # which are mapped back to sample names through the DataFrame index.
        node_i, node_j = (delta.index[cherry[0]], delta.index[cherry[1]])

        self.assertIn((node_i, node_j), [("a", "b"), ("b", "a")])

    def test_create_similarity_map(self):
        """Checks weighted Hamming similarities computed from priors."""
        character_matrix = self.pp_tree_priors.character_matrix.copy()
        # Turn the priors into weights with the "negative_log" transform.
        weights = solver_utilities.transform_priors(self.pp_tree_priors.priors,
                                                    "negative_log")

        similarity_map = data_utilities.compute_dissimilarity_map(
            character_matrix.to_numpy(),
            character_matrix.shape[0],
            dissimilarity_functions.hamming_similarity_without_missing,
            weights,
            self.pp_tree_priors.missing_state_indicator,
        )

        # The result is passed through squareform to obtain the square
        # sample-by-sample matrix wrapped in a DataFrame below.
        similarity_map = scipy.spatial.distance.squareform(similarity_map)

        similarity_map = pd.DataFrame(
            similarity_map,
            index=character_matrix.index,
            columns=character_matrix.index,
        )

        # "a" [1, 2, 2] and "b" [1, 2, 1] share state 1 at x1 (prior 0.5) and
        # state 2 at x2 (prior 0.8).
        # NOTE(review): exact float equality; assertAlmostEqual would be more
        # robust to summation-order differences.
        expected_similarity = -np.log(0.5) - np.log(0.8)
        self.assertEqual(similarity_map.loc["a", "b"], expected_similarity)
        # "a" [1, 2, 2] and "e" [2, 0, 2] share only state 2 at x3 (prior 0.1).
        expected_similarity = -np.log(0.1)
        self.assertEqual(similarity_map.loc["a", "e"], expected_similarity)

    def test_update_similarity_map_and_character_matrix(self):
        """Checks two successive merges update the map and the matrix.

        The similarity map is returned, while `cm` is modified in place: the
        merged samples are replaced by a single row named after the merge.
        """
        nb_similarity = numba.jit(
            dissimilarity_functions.hamming_similarity_without_missing,
            nopython=True,
        )
        # Empty typed weights dict; presumably this means "unweighted"
        # similarity - TODO confirm against the solver implementation.
        nb_weights = numba.typed.Dict.empty(
            numba.types.int64,
            numba.types.DictType(numba.types.int64, numba.types.float64),
        )

        cm = self.basic_tree.character_matrix.copy()
        delta = self.basic_similarity_map

        cherry = self.smj_solver.find_cherry(delta.values)
        node_i, node_j = (delta.index[cherry[0]], delta.index[cherry[1]])

        delta = self.smj_solver.update_similarity_map_and_character_matrix(
            cm,
            nb_similarity,
            delta, (node_i, node_j),
            "ab",
            weights=nb_weights)

        expected_delta = pd.DataFrame.from_dict(
            {
                "ab": [0, 1, 1, 0],
                "c": [1, 0, 0, 0],
                "d": [1, 0, 0, 0],
                "e": [0, 0, 0, 0],
            },
            orient="index",
            columns=["ab", "c", "d", "e"],
        )

        for sample in expected_delta.index:
            for sample2 in expected_delta.index:
                self.assertEqual(
                    delta.loc[sample, sample2],
                    expected_delta.loc[sample, sample2],
                )

        cherry = self.smj_solver.find_cherry(delta.values)
        node_i, node_j = (delta.index[cherry[0]], delta.index[cherry[1]])

        delta = self.smj_solver.update_similarity_map_and_character_matrix(
            cm,
            nb_similarity,
            delta,
            (node_i, node_j),
            "abc",
            weights=nb_weights,
        )

        expected_delta = pd.DataFrame.from_dict(
            {
                "abc": [0, 0, 0],
                "d": [0, 0, 0],
                "e": [0, 0, 0]
            },
            orient="index",
            columns=["abc", "d", "e"],
        )

        for sample in expected_delta.index:
            for sample2 in expected_delta.index:
                self.assertEqual(
                    delta.loc[sample, sample2],
                    expected_delta.loc[sample, sample2],
                )

        # "abc" keeps only the state shared by a [0, 1, 2], b [1, 1, 2] and
        # c [2, 2, 2] (2 at x3); the differing entries become 0.
        expected_cm = pd.DataFrame.from_dict(
            {
                "abc": [0, 0, 2],
                "d": [1, 1, 1],
                "e": [0, 0, 0]
            },
            orient="index",
            columns=["x1", "x2", "x3"],
        )

        for sample in expected_cm.index:
            for col in expected_cm.columns:
                self.assertEqual(cm.loc[sample, col], expected_cm.loc[sample,
                                                                      col])

    def test_basic_solver(self):
        """Solves the basic tree with the numba solver and checks the result.

        Verifies that the inputs are unmodified, the leaves match the samples,
        and the topology matches the expected tree via triplet structures and
        leaf-to-leaf path lengths.
        """
        self.smj_solver.solve(self.basic_tree)

        # test that the dissimilarity map and character matrix were not altered
        cm = pd.DataFrame.from_dict(
            {
                "a": [0, 1, 2],
                "b": [1, 1, 2],
                "c": [2, 2, 2],
                "d": [1, 1, 1],
                "e": [0, 0, 0],
            },
            orient="index",
            columns=["x1", "x2", "x3"],
        )
        for i in self.basic_similarity_map.index:
            for j in self.basic_similarity_map.columns:
                self.assertEqual(
                    self.basic_similarity_map.loc[i, j],
                    self.basic_tree.get_dissimilarity_map().loc[i, j],
                )
        for i in self.basic_tree.character_matrix.index:
            for j in self.basic_tree.character_matrix.columns:
                self.assertEqual(cm.loc[i, j],
                                 self.basic_tree.character_matrix.loc[i, j])

        # test leaves exist in tree
        _leaves = self.basic_tree.leaves

        self.assertEqual(len(_leaves), self.basic_similarity_map.shape[0])
        for _leaf in _leaves:
            self.assertIn(_leaf, self.basic_similarity_map.index.values)

        # test for expected number of edges
        # (the expected tree below has 5 leaves + 4 internal nodes = 8 edges)
        edges = list(self.basic_tree.edges)
        self.assertEqual(len(edges), 8)

        # test relationships between samples
        expected_tree = nx.DiGraph()
        expected_tree.add_edges_from([
            ("5", "a"),
            ("5", "b"),
            ("6", "5"),
            ("6", "c"),
            ("7", "d"),
            ("7", "e"),
            ("8", "6"),
            ("8", "7"),
        ])

        observed_tree = self.basic_tree.get_tree_topology()
        triplets = itertools.combinations(["a", "b", "c", "d", "e"], 3)
        for triplet in triplets:

            expected_triplet = find_triplet_structure(triplet, expected_tree)
            observed_triplet = find_triplet_structure(triplet, observed_tree)
            self.assertEqual(expected_triplet, observed_triplet)

        # compare tree distances
        observed_tree = observed_tree.to_undirected()
        expected_tree = expected_tree.to_undirected()
        for i in range(len(_leaves)):
            sample1 = _leaves[i]
            for j in range(i + 1, len(_leaves)):
                sample2 = _leaves[j]
                self.assertEqual(
                    nx.shortest_path_length(observed_tree, sample1, sample2),
                    nx.shortest_path_length(expected_tree, sample1, sample2),
                )

    def test_solver_no_numba(self):
        """Same checks as test_basic_solver, using the non-numba solver."""
        # NOTE(review): this body duplicates test_basic_solver apart from the
        # solver used; a shared helper taking the solver would remove that.
        self.smj_solver_no_numba.solve(self.basic_tree)

        # test that the dissimilarity map and character matrix were not altered
        cm = pd.DataFrame.from_dict(
            {
                "a": [0, 1, 2],
                "b": [1, 1, 2],
                "c": [2, 2, 2],
                "d": [1, 1, 1],
                "e": [0, 0, 0],
            },
            orient="index",
            columns=["x1", "x2", "x3"],
        )
        for i in self.basic_similarity_map.index:
            for j in self.basic_similarity_map.columns:
                self.assertEqual(
                    self.basic_similarity_map.loc[i, j],
                    self.basic_tree.get_dissimilarity_map().loc[i, j],
                )
        for i in self.basic_tree.character_matrix.index:
            for j in self.basic_tree.character_matrix.columns:
                self.assertEqual(cm.loc[i, j],
                                 self.basic_tree.character_matrix.loc[i, j])

        # test leaves exist in tree
        _leaves = self.basic_tree.leaves

        self.assertEqual(len(_leaves), self.basic_similarity_map.shape[0])
        for _leaf in _leaves:
            self.assertIn(_leaf, self.basic_similarity_map.index.values)

        # test for expected number of edges
        edges = list(self.basic_tree.edges)
        self.assertEqual(len(edges), 8)

        # test relationships between samples
        expected_tree = nx.DiGraph()
        expected_tree.add_edges_from([
            ("5", "a"),
            ("5", "b"),
            ("6", "5"),
            ("6", "c"),
            ("7", "d"),
            ("7", "e"),
            ("8", "6"),
            ("8", "7"),
        ])

        observed_tree = self.basic_tree.get_tree_topology()
        triplets = itertools.combinations(["a", "b", "c", "d", "e"], 3)
        for triplet in triplets:

            expected_triplet = find_triplet_structure(triplet, expected_tree)
            observed_triplet = find_triplet_structure(triplet, observed_tree)
            self.assertEqual(expected_triplet, observed_triplet)

        # compare tree distances
        observed_tree = observed_tree.to_undirected()
        expected_tree = expected_tree.to_undirected()
        for i in range(len(_leaves)):
            sample1 = _leaves[i]
            for j in range(i + 1, len(_leaves)):
                sample2 = _leaves[j]
                self.assertEqual(
                    nx.shortest_path_length(observed_tree, sample1, sample2),
                    nx.shortest_path_length(expected_tree, sample1, sample2),
                )

    def test_smj_solver_weights(self):
        """Checks the topology recovered from the prior-weighted tree."""
        self.smj_solver_modified_pp.solve(self.pp_tree_priors)
        observed_tree = self.pp_tree_priors.get_tree_topology()

        expected_tree = nx.DiGraph()
        expected_tree.add_edges_from([
            ("5", "a"),
            ("5", "e"),
            ("6", "b"),
            ("6", "c"),
            ("7", "5"),
            ("7", "d"),
            ("8", "6"),
            ("8", "7"),
        ])

        triplets = itertools.combinations(["a", "b", "c", "d", "e"], 3)
        for triplet in triplets:
            expected_triplet = find_triplet_structure(triplet, expected_tree)
            observed_triplet = find_triplet_structure(triplet, observed_tree)
            self.assertEqual(expected_triplet, observed_triplet)

        # NOTE(review): the rest of this test solves `self.pp_tree` (not the
        # priors tree) and builds `expected_tree`, but never compares it with
        # any observed topology, so it asserts nothing. Its topology also
        # differs from the one test_pp_solver expects for `self.pp_tree`.
        # Confirm the intent before adding an assertion.
        self.smj_solver_pp.solve(self.pp_tree,
                                 collapse_mutationless_edges=True)
        expected_tree = nx.DiGraph()
        expected_tree.add_edges_from([
            ("5", "a"),
            ("5", "e"),
            ("6", "b"),
            ("6", "c"),
            ("8", "5"),
            ("8", "d"),
            ("8", "6"),
        ])

    def test_pp_solver(self):
        """Solves the lineage-tracing tree, with and without collapsing."""
        self.smj_solver_pp.solve(self.pp_tree)
        observed_tree = self.pp_tree.get_tree_topology()

        pp_cm = pd.DataFrame.from_dict(
            {
                "a": [1, 2, 2],
                "b": [1, 2, 1],
                "c": [1, 2, 0],
                "d": [2, 0, 0],
                "e": [2, 0, 2],
            },
            orient="index",
            columns=["x1", "x2", "x3"],
        )
        # Solving must not leave a dissimilarity map on the tree or modify
        # the character matrix.
        self.assertIsNone(self.pp_tree.get_dissimilarity_map())
        for i in self.pp_tree.character_matrix.index:
            for j in self.pp_tree.character_matrix.columns:
                self.assertEqual(pp_cm.loc[i, j],
                                 self.pp_tree.character_matrix.loc[i, j])

        expected_tree = nx.DiGraph()
        expected_tree.add_edges_from([
            ("5", "a"),
            ("5", "b"),
            ("6", "5"),
            ("6", "c"),
            ("7", "d"),
            ("7", "e"),
            ("8", "6"),
            ("8", "7"),
        ])

        triplets = itertools.combinations(["a", "b", "c", "d", "e"], 3)
        for triplet in triplets:
            expected_triplet = find_triplet_structure(triplet, expected_tree)
            observed_triplet = find_triplet_structure(triplet, observed_tree)
            self.assertEqual(expected_triplet, observed_triplet)

        self.smj_solver_pp.solve(self.pp_tree,
                                 collapse_mutationless_edges=True)
        observed_tree = self.pp_tree.get_tree_topology()
        # NOTE(review): `triplets` is a one-shot iterator already exhausted by
        # the loop above, so this loop runs zero times and the collapsed tree
        # is never checked. Re-creating the combinations (or using list(...))
        # would activate the check, but the expected structures may need
        # updating if collapsing creates multifurcations - confirm.
        for triplet in triplets:
            expected_triplet = find_triplet_structure(triplet, expected_tree)
            observed_triplet = find_triplet_structure(triplet, observed_tree)
            self.assertEqual(expected_triplet, observed_triplet)

    def test_duplicate(self):
        """Checks the topology recovered for samples with duplicates/missing."""
        # In this case, we see that the missing data can break up a duplicate
        # pair if the behavior is to ignore missing data

        self.smj_solver_pp.solve(self.duplicate_tree)
        observed_tree = self.duplicate_tree.get_tree_topology()

        expected_tree = nx.DiGraph()
        expected_tree.add_edges_from([
            ("5", "b"),
            ("5", "c"),
            ("6", "e"),
            ("6", "f"),
            ("7", "5"),
            ("7", "6"),
            ("8", "7"),
            ("8", "d"),
            ("9", "8"),
            ("9", "a"),
        ])
        triplets = itertools.combinations(["a", "b", "c", "d", "e", "f"], 3)
        for triplet in triplets:
            expected_triplet = find_triplet_structure(triplet, expected_tree)
            observed_triplet = find_triplet_structure(triplet, observed_tree)
            self.assertEqual(expected_triplet, observed_triplet)
def estimate_mutation_rate(
    tree: CassiopeiaTree,
    continuous: bool = True,
    assume_root_implicit_branch: bool = True,
    layer: Optional[str] = None,
) -> float:
    """Estimates the mutation rate from a tree and character matrix.

    Infers the mutation rate using the proportion of the cell/character
    entries in the leaves that are in a mutated (non-0, i.e. not uncut) state
    and the node depth/the total time of the tree. The mutation rate is either
    a continuous or per-generation rate according to which lineages accumulate
    mutations.

    In estimating the mutation rate, we use the observed proportion of mutated
    entries in the character matrix as an estimate of the probability that a
    mutation occurs on a lineage. Using this probability, we can then infer
    the mutation rate.

    This function attempts to consume the observed mutation proportion as
    `mutated_proportion` in `tree.parameters`. If this field is not populated,
    it is inferred using `get_proportion_of_mutation`.

    In the case where the rate is per-generation (probability a mutation occurs
    on an edge), it is estimated using:

        mutated proportion = 1 - (1 - mutation_rate) ^ (average depth of tree)

    In the case when the rate is continuous, it is estimated using:

        mutated proportion = ExponentialCDF(average time of tree, mutation rate)

    Note that these naive estimates perform better when the tree is ultrametric
    in depth or time. The average depth/lineage time of the tree is used as a
    proxy for the depth/total time when the tree is not ultrametric.

    In the inference, we need to consider whether to assume an implicit root,
    specified by `assume_root_implicit_branch`. In the case where the tree does
    not have a single leading edge from the root representing the progenitor
    cell before cell division begins, this additional edge is added to the
    total time in calculating the estimate if `assume_root_implicit_branch` is
    True. The implicit edge counts as one extra generation in the discrete
    case, and as the mean branch length over all edges in the continuous case.

    Args:
        tree: The CassiopeiaTree specifying the tree and the character matrix
        continuous: Whether to calculate a continuous mutation rate,
            accounting for branch lengths. Otherwise, calculates a
            discrete mutation rate using the node depths
        assume_root_implicit_branch: Whether to assume that there is an
            implicit branch leading from the root, if it doesn't exist
        layer: Layer to use for character matrix. If this is None,
            then the current `character_matrix` variable will be used.

    Returns:
        The estimated mutation rate. In the continuous case a mutated
        proportion of exactly 1 yields an infinite rate.

    Raises:
        ParameterEstimateError if the mutated proportion (taken from
            `tree.parameters["mutated_proportion"]` or inferred) is not
            between 0 and 1
    """
    if "mutated_proportion" in tree.parameters:
        mutation_proportion = tree.parameters["mutated_proportion"]
    else:
        mutation_proportion = get_proportion_of_mutation(tree, layer)
    if mutation_proportion < 0 or mutation_proportion > 1:
        raise ParameterEstimateError(
            "Mutation proportion must be between 0 and 1."
        )

    # The implicit leading edge is only added when the root does not already
    # have a single child edge playing that role.
    add_implicit_branch = (
        assume_root_implicit_branch and len(tree.children(tree.root)) != 1
    )

    if not continuous:
        mean_depth = tree.get_mean_depth_of_tree()
        if add_implicit_branch:
            mean_depth += 1
        # Invert p = 1 - (1 - rate) ** depth for the per-generation rate.
        return 1 - (1 - mutation_proportion) ** (1 / mean_depth)

    times = tree.get_times()
    mean_time = np.mean([times[leaf] for leaf in tree.leaves])
    if add_implicit_branch:
        # Approximate the implicit edge by the mean branch length of the tree.
        mean_time += np.mean(
            [tree.get_branch_length(u, v) for u, v in tree.edges]
        )
    # Invert the exponential CDF p = 1 - exp(-rate * time) for the rate.
    return -np.log(1 - mutation_proportion) / mean_time
def estimate_missing_data_rates(
    tree: CassiopeiaTree,
    continuous: bool = True,
    assume_root_implicit_branch: bool = True,
    stochastic_missing_probability: Optional[float] = None,
    heritable_missing_rate: Optional[float] = None,
    layer: Optional[str] = None,
) -> Tuple[float, float]:
    """
    Estimates both missing data parameters given one of the two from a tree.

    The stochastic missing probability is the probability that any given
    cell/character pair acquires stochastic missing data in the character
    matrix due to low-capture in single-cell RNA sequencing. The heritable
    missing rate is either a continuous or per-generation rate according to
    which lineages accumulate heritable missing data events, such as
    transcriptional silencing or resection.

    In most instances, the two types of missing data are convolved and we
    cannot determine whether any single occurrence of missing data is due to
    stochastic or heritable missing data. We assume both contribute to the
    total amount of missing data as:

        total missing proportion = heritable proportion + stochastic proportion
            - heritable proportion * stochastic proportion

    This function attempts to consume the amount of missing data (the total
    missing proportion) as `missing_proportion` in `tree.parameters`, inferring
    it using `get_proportion_of_missing_data` if it is not populated.

    Additionally, as the two types of data are convolved, we need to know the
    contribution of one of the types of missing data in order to estimate the
    other. This function attempts to consume the heritable missing rate as
    `heritable_missing_rate` in `tree.parameters` and the stochastic missing
    probability as `stochastic_missing_probability` in `tree.parameters`.
    If they are not provided on the tree, then they may be provided as
    function arguments. If neither or both parameters are provided by either of
    these methods, the function errors.

    In estimating the heritable missing rate from the stochastic missing data
    probability, we take the proportion of stochastic missing data in the
    character matrix as equal to the stochastic probability. Then using the
    total observed proportion of missing data as well as the estimated
    proportion of stochastic missing data we can estimate the proportion
    of heritable missing data using the expression above. Finally, we use the
    heritable proportion as an estimate of the probability a lineage acquires
    a missing data event by the end of the phylogeny, and using this
    probability we can estimate the rate.

    In the case where the rate is per-generation (probability a heritable
    missing data event occurs on an edge), it is estimated using:

        heritable missing proportion =
            1 - (1 - heritable missing rate) ^ (average depth of tree)

    In the case where the rate is continuous, it is estimated using:

        heritable_missing_proportion =
            ExponentialCDF(average time of tree, heritable missing rate)

    Note that these naive estimates perform better when the tree is ultrametric
    in depth or time. The average depth/lineage time of the tree is used as a
    proxy for the depth/total time when the tree is not ultrametric.

    In calculating the heritable proportion from the heritable missing rate,
    we need to consider whether to assume an implicit root. This is specified
    by `assume_root_implicit_branch`. In the case where the tree does not have
    a single leading edge from the root representing the progenitor cell before
    cell division begins, this additional edge is added to the total time in
    calculating the estimate if `assume_root_implicit_branch` is True.

    In estimating the stochastic missing probability from the heritable missing
    rate, we calculate the expected proportion of heritable missing data using
    the heritable rate in the same way, and then as above use the total
    proportion of missing data to estimate the stochastic proportion, which we
    assume is equal to the probability.

    Args:
        tree: The CassiopeiaTree specifying the tree and the character matrix
        continuous: Whether to calculate a continuous missing rate,
            accounting for branch lengths. Otherwise, calculates a
            discrete missing rate based on the number of generations
        assume_root_implicit_branch: Whether to assume that there is an
            implicit branch leading from the root, if it doesn't exist
        stochastic_missing_probability: The stochastic missing probability.
            Will override the value on the tree. Observed probabilities of
            stochastic missing data range between 10-20%
        heritable_missing_rate: The heritable missing rate. Will override the
            value on the tree
        layer: Layer to use for character matrix. If this is None,
            then the current `character_matrix` variable will be used.

    Returns:
        The stochastic missing probability and heritable missing rate. One of
        these will be the parameter as provided, the other will be an estimate

    Raises:
        ParameterEstimateError if the `total_missing_proportion`,
            `stochastic_missing_probability`, or `heritable_missing_rate` that
            are provided have invalid values (a stochastic probability or a
            per-generation heritable rate must be in [0, 1)), or if both or
            neither of `stochastic_missing_probability`, and
            `heritable_missing_rate` are provided. ParameterEstimateWarning
            (raised as an exception, not emitted via `warnings.warn`) if the
            estimated parameter is negative
    """

    def mean_lineage_extent() -> float:
        """Mean leaf depth (discrete) or mean leaf time (continuous).

        When `assume_root_implicit_branch` is True and the root does not have
        exactly one child, an implicit leading edge is added: one extra
        generation in the discrete case, or the mean branch length over all
        edges in the continuous case.
        """
        add_implicit_branch = (
            assume_root_implicit_branch
            and len(tree.children(tree.root)) != 1
        )
        if not continuous:
            mean_depth = tree.get_mean_depth_of_tree()
            if add_implicit_branch:
                mean_depth += 1
            return mean_depth
        times = tree.get_times()
        mean_time = np.mean([times[leaf] for leaf in tree.leaves])
        if add_implicit_branch:
            mean_time += np.mean(
                [tree.get_branch_length(u, v) for u, v in tree.edges]
            )
        return mean_time

    if "missing_proportion" not in tree.parameters:
        total_missing_proportion = get_proportion_of_missing_data(tree, layer)
    else:
        total_missing_proportion = tree.parameters["missing_proportion"]
    if total_missing_proportion < 0 or total_missing_proportion > 1:
        raise ParameterEstimateError(
            "Missing proportion must be between 0 and 1."
        )

    # Explicit arguments take precedence over values stored on the tree.
    if (
        stochastic_missing_probability is None
        and "stochastic_missing_probability" in tree.parameters
    ):
        stochastic_missing_probability = tree.parameters[
            "stochastic_missing_probability"
        ]
    if (
        heritable_missing_rate is None
        and "heritable_missing_rate" in tree.parameters
    ):
        heritable_missing_rate = tree.parameters["heritable_missing_rate"]

    # Exactly one of the two parameters must be known; otherwise the two
    # kinds of missing data cannot be separated.
    if (
        heritable_missing_rate is None
        and stochastic_missing_probability is None
    ):
        raise ParameterEstimateError(
            "Neither `heritable_missing_rate` nor "
            "`stochastic_missing_probability` were provided as arguments or "
            "found in `tree.parameters`. Please provide one of these "
            "parameters, otherwise they are convolved and cannot be estimated"
        )

    if (
        heritable_missing_rate is not None
        and stochastic_missing_probability is not None
    ):
        raise ParameterEstimateError(
            "Both `heritable_missing_rate` and `stochastic_missing_probability`"
            " were provided as parameters or found in `tree.parameters`. "
            "Please only supply one of the two"
        )

    if heritable_missing_rate is None:
        # Estimate the heritable rate from the stochastic probability.
        if stochastic_missing_probability < 0:
            raise ParameterEstimateError(
                "Stochastic missing data rate must be > 0."
            )
        # A probability of exactly 1 would divide by zero below.
        if stochastic_missing_probability >= 1:
            raise ParameterEstimateError(
                "Stochastic missing data rate must be < 1."
            )

        # Inverting total = h + s - h * s gives the proportion of lineages
        # free of heritable missing data: 1 - h = (1 - total) / (1 - s).
        heritable_free_proportion = (1 - total_missing_proportion) / (
            1 - stochastic_missing_probability
        )
        if not continuous:
            # Solve 1 - h = (1 - rate) ** depth for the per-generation rate.
            heritable_missing_rate = 1 - heritable_free_proportion ** (
                1 / mean_lineage_extent()
            )
        else:
            # Solve 1 - h = exp(-rate * time) for the continuous rate.
            heritable_missing_rate = (
                -np.log(heritable_free_proportion) / mean_lineage_extent()
            )
    else:
        # Estimate the stochastic probability from the heritable rate.
        if heritable_missing_rate < 0:
            raise ParameterEstimateError(
                "Heritable missing data rate must be > 0."
            )
        # A per-generation rate of exactly 1 makes the heritable proportion 1
        # (for any positive depth) and would divide by zero below.
        if not continuous and heritable_missing_rate >= 1:
            raise ParameterEstimateError(
                "Per-generation heritable missing data rate must be < 1."
            )

        mean_extent = mean_lineage_extent()
        if not continuous:
            heritable_proportion = 1 - (1 - heritable_missing_rate) ** (
                mean_extent
            )
        else:
            heritable_proportion = 1 - np.exp(
                -heritable_missing_rate * mean_extent
            )

        # Invert total = h + s - h * s for the stochastic probability s.
        stochastic_missing_probability = (
            total_missing_proportion - heritable_proportion
        ) / (1 - heritable_proportion)

    if stochastic_missing_probability < 0:
        raise ParameterEstimateWarning(
            "Estimate of the stochastic missing probability using this "
            "heritable rate resulted in a negative stochastic missing "
            "probability. It may be that this heritable rate is too high."
        )

    if heritable_missing_rate < 0:
        raise ParameterEstimateWarning(
            "Estimate of the heritable rate using this stochastic missing "
            "probability resulted in a negative heritable rate. It may be that "
            "this stochastic missing probability is too high."
        )

    return stochastic_missing_probability, heritable_missing_rate
示例#14
0
    def test_subsample_balanced_tree(self):
        """SupercellularSampler(number_of_merges=2) on the balanced test tree.

        The sampler is run twice from the same RNG seed: once with the default
        duplicate handling and once with ``collapse_duplicates=False``. Both
        runs are expected to yield the same edge set; the expected character
        matrices differ only in the rows of the two merged cells.
        """
        tree = CassiopeiaTree(
            tree=self.test_network, character_matrix=self.character_matrix
        )

        # Leaves that are not merged; their rows are identical in both runs.
        # Insertion order matters: it fixes the row order of the expected frame.
        untouched_rows = {
            "node7": [1, 1, 0, 0, 0, 0, 0, 0],
            "node9": [1, 1, 1, 0, 0, 0, 0, 0],
            "node11": [1, 1, 1, 1, 0, 0, 0, 0],
            "node13": [1, 1, 1, 1, 1, 0, 0, 0],
            "node17": [1, 1, 1, 1, 1, 1, 1, 0],
            "node6": [2, 2, 0, 0, 0, 0, 0, 0],
        }

        # Both runs produce this same set of edges.
        expected_edges = {
            ("node0", "node3-node5"),
            ("node0", "node4"),
            ("node0", "node6"),
            ("node4", "node7"),
            ("node4", "node8"),
            ("node8", "node9"),
            ("node8", "node10"),
            ("node10", "node11"),
            ("node10", "node12"),
            ("node12", "node13"),
            ("node12", "node14"),
            ("node14", "node17"),
            ("node14", "node18-node15"),
        }

        def check(result, merged_rows):
            """Assert ``result`` matches the expected matrix and edge set."""
            expected = pd.DataFrame.from_dict(
                {**untouched_rows, **merged_rows}, orient="index"
            )
            pd.testing.assert_frame_equal(expected, result.character_matrix)
            self.assertEqual(set(result.edges), expected_edges)

        # Default run: each merged-cell entry is a collapsed tuple, e.g. (1,).
        np.random.seed(10)
        sampler = SupercellularSampler(number_of_merges=2)
        check(
            sampler.subsample_leaves(tree=tree),
            {
                "node18-node15": [(1,)] * 6 + [(0, 1)] * 2,
                "node3-node5": [(1, 2)] + [(0,)] * 7,
            },
        )

        # Same seed with duplicates kept: tuples hold one entry per merged cell.
        np.random.seed(10)
        check(
            sampler.subsample_leaves(tree=tree, collapse_duplicates=False),
            {
                "node18-node15": [(1, 1)] * 6 + [(1, 0)] * 2,
                "node3-node5": [(1, 2)] + [(0, 0)] * 7,
            },
        )