コード例 #1
0
def test_sort(default_config, setup_linear_mcts, setup_branched_mcts):
    """StateScorer.sort orders the linear-tree node ahead of the branched one."""
    _linear_tree, linear_node = setup_linear_mcts()
    _branched_tree, branched_node = setup_branched_mcts()
    state_scorer = StateScorer(default_config)

    # Pass the nodes in reverse order so the sort has to swap them.
    ordered, raw_scores, _indices = state_scorer.sort([branched_node, linear_node])

    rounded = [np.round(value, 4) for value in raw_scores]
    assert rounded == [0.994, 0.9866]
    assert ordered == [linear_node, branched_node]
コード例 #2
0
def test_sort(shared_datadir, default_config, mock_stock):
    """StateScorer.sort ranks the nodes of a deserialized tree by state score."""
    mock_stock(default_config, "CCCO", "CC")
    tree_file = shared_datadir / "tree_without_repetition.json"
    tree = SearchTree.from_json(tree_file, default_config)
    all_nodes = list(tree.graph())
    state_scorer = StateScorer(default_config)

    ordered, raw_scores, _indices = state_scorer.sort(all_nodes)

    rounded = [np.round(value, 4) for value in raw_scores]
    assert rounded == [0.9976, 0.0491]
    # The second node of the graph is expected to score highest.
    expected_order = [all_nodes[1], all_nodes[0]]
    assert ordered == expected_order
コード例 #3
0
class TreeAnalysis:
    """
    Class that encapsulates various analyses that can be
    performed on a search tree.

    :ivar scorer: the object used to score the nodes
    :vartype scorer: Scorer
    :ivar search_tree: the search tree
    :vartype search_tree: SearchTree

    :param search_tree: the search tree to do the analysis on
    :type search_tree: SearchTree
    :param scorer: the object used to score the nodes, defaults to StateScorer
    :type scorer: Scorer, optional
    """

    def __init__(self, search_tree, scorer=None):
        self.search_tree = search_tree
        if scorer is None:
            # Do import here to avoid circular imports
            from aizynthfinder.context.scoring import StateScorer

            # NOTE(review): other call sites construct StateScorer with a
            # configuration object; confirm a no-argument call is valid for
            # the installed version.
            self.scorer = StateScorer()
        else:
            self.scorer = scorer

    def best_node(self):
        """
        Returns the node with the highest score.
        If several nodes have the same score, the one placed first by the
        scorer's sort is returned.

        :return: the top scoring node
        :rtype: Node
        """
        sorted_nodes, _ = self._sorted_nodes_and_scores(self._all_nodes())
        return sorted_nodes[0]

    def sort_nodes(self, min_return=5, max_return=25):
        """
        Sort and select the nodes, so that the best scoring routes are returned.

        The algorithm filters away identical routes and returns at minimum the
        number specified. If multiple alternative routes have the same score as
        the n'th route, they will be included and returned. When there are at
        most ``min_return`` candidate nodes, all of them are returned in sorted
        order without filtering identical routes.

        :param min_return: the minimum number of routes to return, defaults to 5
        :type min_return: int, optional
        :param max_return: the maximum number of routes to return, defaults to 25.
            A falsy value (e.g. None or 0) disables the upper limit.
        :type max_return: int, optional
        :return: the selected nodes and their scores
        :rtype: tuple of (list of Node, list of float)
        """
        nodes = self._all_nodes()
        sorted_nodes, sorted_scores = self._sorted_nodes_and_scores(nodes)

        # Few candidates: return everything as sorted (no de-duplication).
        if len(nodes) <= min_return:
            return sorted_nodes, sorted_scores

        seen_hashes = set()
        best_nodes = []
        best_scores = []
        # Sentinel larger than any real score, so the tie check below cannot
        # trigger before a first route has been accepted.
        last_score = float("inf")
        for score, node in zip(sorted_scores, sorted_nodes):
            # Once enough routes are collected, continue only while the score
            # ties the last accepted route.
            if len(best_nodes) >= min_return and score < last_score:
                break
            route_actions, _ = self.search_tree.route_to_node(node)
            route_hash = hash_reactions(route_actions)

            # Skip nodes whose route has already been selected.
            if route_hash in seen_hashes:
                continue
            seen_hashes.add(route_hash)
            best_nodes.append(node)
            best_scores.append(score)
            last_score = score

            if max_return and len(best_nodes) == max_return:
                break

        return best_nodes, best_scores

    def tree_statistics(self):
        """
        Returns statistics of the tree.

        Currently it returns the number of nodes, the maximum number of
        transforms, the maximum number of children, the number of leaves and
        solved leaves, the top score, whether the top scoring route is solved,
        the number of steps and molecules in the top scoring node, information
        on its precursors, and how many times each expansion policy was used.

        :return: the statistics
        :rtype: dict
        """
        top_node = self.best_node()
        top_state = top_node.state
        nodes = list(self.search_tree.graph())
        mols_in_stock = ", ".join(
            mol.smiles
            for mol, instock in zip(top_state.mols, top_state.in_stock_list)
            if instock
        )
        mols_not_in_stock = ", ".join(
            mol.smiles
            for mol, instock in zip(top_state.mols, top_state.in_stock_list)
            if not instock
        )

        # Count, over all parent->child edges, the policy recorded in the
        # action metadata (edges without a policy name are ignored).
        policy_used_counts = defaultdict(int)
        for node in nodes:
            for child in node.children():
                policy_used = node[child]["action"].metadata.get("policy_name")
                if policy_used:
                    policy_used_counts[policy_used] += 1

        return {
            "number_of_nodes": len(nodes),
            "max_transforms": max(node.state.max_transforms for node in nodes),
            "max_children": max(len(node.children()) for node in nodes),
            "number_of_leafs": sum(1 for node in nodes if not node.children()),
            "number_of_solved_leafs": sum(
                1 for node in nodes if not node.children() and node.state.is_solved
            ),
            "top_score": self.scorer(top_node),
            "is_solved": top_state.is_solved,
            "number_of_steps": top_state.max_transforms,
            "number_of_precursors": len(top_state.mols),
            "number_of_precursors_in_stock": sum(top_state.in_stock_list),
            "precursors_in_stock": mols_in_stock,
            "precursors_not_in_stock": mols_not_in_stock,
            "policy_used_counts": dict(policy_used_counts),
        }

    def _all_nodes(self):
        # Candidate nodes for scoring: every node in the tree when the scorer
        # identifies itself as "state score", otherwise only the leaf nodes.
        # This is to keep backwards compatibility; it should be investigated further.
        if repr(self.scorer) == "state score":
            return list(self.search_tree.graph())
        return [node for node in self.search_tree.graph() if not node.children()]

    def _sorted_nodes_and_scores(self, nodes):
        # Sort ``nodes`` with the scorer and return (sorted nodes, scores).
        # Depending on the scorer version, ``sort`` returns either two or three
        # elements; only the first two are used so that either shape works.
        sorted_nodes, sorted_scores = self.scorer.sort(nodes)[:2]
        return sorted_nodes, sorted_scores