Exemplo n.º 1
0
def test_sort(default_config, setup_linear_mcts, setup_branched_mcts):
    """Sorting MCTS nodes with the state scorer ranks them by state score."""
    _linear_tree, linear_node = setup_linear_mcts()
    _branched_tree, branched_node = setup_branched_mcts()
    state_scorer = StateScorer(default_config)

    # Pass the nodes in reverse of the expected ranking so the sort must reorder.
    ordered_nodes, ordered_scores, _ = state_scorer.sort([branched_node, linear_node])

    rounded_scores = [np.round(value, 4) for value in ordered_scores]
    assert rounded_scores == [0.994, 0.9866]
    assert ordered_nodes == [linear_node, branched_node]
Exemplo n.º 2
0
    def __init__(self, search_tree, scorer=None):
        """
        Store the search tree and the scorer used for the analysis.

        :param search_tree: the search tree to analyse
        :param scorer: the scorer to use; a ``StateScorer`` is created when omitted
        """
        self.search_tree = search_tree
        if scorer is not None:
            self.scorer = scorer
            return

        # Imported lazily to avoid a circular import at module load time.
        from aizynthfinder.context.scoring import StateScorer

        self.scorer = StateScorer()
Exemplo n.º 3
0
def test_sort(shared_datadir, default_config, mock_stock):
    """Sorting all nodes of a loaded search tree ranks them by state score."""
    mock_stock(default_config, "CCCO", "CC")
    tree_file = shared_datadir / "tree_without_repetition.json"
    search_tree = SearchTree.from_json(tree_file, default_config)
    all_nodes = list(search_tree.graph())
    state_scorer = StateScorer(default_config)

    ordered_nodes, ordered_scores, _ = state_scorer.sort(all_nodes)

    assert [np.round(value, 4) for value in ordered_scores] == [0.9976, 0.0491]
    # The second node in graph order scores higher, so it is ranked first.
    assert ordered_nodes == [all_nodes[1], all_nodes[0]]
Exemplo n.º 4
0
def test_scoring_branch_mcts_tree_in_stock(shared_datadir, default_config,
                                           mock_stock):
    """
    Check every scorer on the final node of the branched MCTS tree when
    the listed molecules are mocked as being in stock.
    """
    stock_smiles = (
        "CC(C)(C)CO",
        "CC(C)(C)OC(=O)N(CCCl)CCCl",
        "N#CCc1cccc(O)c1F",
        "O=[N+]([O-])c1ccccc1F",
        "O=C1CCC(=O)N1Br",
        "O=C=Nc1csc(C(F)(F)F)n1",
        "CCC[Sn](Cl)(CCC)CCC",
        "COc1ccc2ncsc2c1",
    )
    mock_stock(default_config, *stock_smiles)
    tree_path = shared_datadir / "tree_with_branching.json"
    search_tree = SearchTree.from_json(tree_path, default_config)
    all_nodes = list(search_tree.graph())
    final_node = all_nodes[-1]

    state_score = StateScorer(default_config)(final_node)
    assert pytest.approx(state_score, abs=1e-3) == 0.950
    assert NumberOfReactionsScorer()(final_node) == 14
    for scorer_cls, expected in (
        (NumberOfPrecursorsScorer, 8),
        (NumberOfPrecursorsInStockScorer, 8),
        (PriceSumScorer, 8),
    ):
        assert scorer_cls(default_config)(final_node) == expected
    route_cost = RouteCostScorer(default_config)(final_node)
    assert pytest.approx(route_cost, abs=1e-3) == 77.4797
Exemplo n.º 5
0
def test_add_scorer_to_collection(default_config):
    """A scorer removed from the collection can be registered again with ``load``."""
    collection = ScorerCollection(default_config)
    del collection["state score"]

    replacement = StateScorer(default_config)
    collection.load(replacement)

    registered_names = collection.names()
    assert "state score" in registered_names
Exemplo n.º 6
0
def test_state_scorer_tree(load_reaction_tree, default_config, mock_stock):
    """State score of the sample reaction tree with a mocked stock."""
    in_stock = ("N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1", "CN1CCC(Cl)CC1", "O")
    mock_stock(default_config, *in_stock)
    tree_dict = load_reaction_tree("sample_reaction.json")
    reaction_tree = ReactionTree.from_dict(tree_dict)
    state_scorer = StateScorer(default_config)

    tree_score = state_scorer(reaction_tree)
    assert round(tree_score, 4) == 0.994
Exemplo n.º 7
0
def test_scoring_branched_route(load_reaction_tree, default_config):
    """Check the scorers on the branched route loaded from JSON (no stock mocked)."""
    route_dict = load_reaction_tree("branched_route.json")
    route = ReactionTree.from_dict(route_dict)

    state_score = StateScorer(default_config)(route)
    assert pytest.approx(state_score, abs=1e-6) == 0.00012363
    for scorer_cls, expected in (
        (NumberOfReactionsScorer, 14),
        (NumberOfPrecursorsScorer, 8),
        (NumberOfPrecursorsInStockScorer, 0),
    ):
        assert scorer_cls(default_config)(route) == expected
Exemplo n.º 8
0
def test_state_scorer_trees(default_config, setup_linear_reaction_tree):
    """Calling the state scorer on a list of trees scores each tree in the list."""
    linear_tree = setup_linear_reaction_tree()
    state_scorer = StateScorer(default_config)

    tree_scores = state_scorer([linear_tree, linear_tree])

    for position in (0, 1):
        assert round(tree_scores[position], 4) == 0.994
Exemplo n.º 9
0
def test_scoring_branched_mcts_tree(shared_datadir, default_config):
    """Check the scorers on the final node of the branched MCTS tree."""
    tree_path = shared_datadir / "tree_with_branching.json"
    search_tree = SearchTree.from_json(tree_path, default_config)
    all_nodes = list(search_tree.graph())
    final_node = all_nodes[-1]

    assert pytest.approx(StateScorer()(final_node), abs=1e-6) == 0.00012363
    assert NumberOfReactionsScorer()(final_node) == 14
    precursor_count = NumberOfPrecursorsScorer(default_config)(final_node)
    assert precursor_count == 8
    in_stock_count = NumberOfPrecursorsInStockScorer(default_config)(final_node)
    assert in_stock_count == 0
Exemplo n.º 10
0
def test_state_scorer_nodes(generate_root, default_config):
    """Score a list of generated root nodes and check the scorer's repr."""
    root_node = generate_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    state_scorer = StateScorer(default_config)

    node_scores = state_scorer([root_node, root_node])

    assert repr(state_scorer) == "state score"
    for idx in range(2):
        assert round(node_scores[idx], 4) == 0.0491
Exemplo n.º 11
0
def test_scorers_one_mcts_node(default_config):
    """Check every scorer on the root node of a freshly created search tree."""
    search_tree = SearchTree(
        default_config, root_smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    root = search_tree.root

    state_score = StateScorer(default_config)(root)
    assert pytest.approx(state_score, abs=1e-3) == 0.0497
    for scorer_cls, expected in (
        (NumberOfReactionsScorer, 0),
        (NumberOfPrecursorsScorer, 1),
        (NumberOfPrecursorsInStockScorer, 0),
        (PriceSumScorer, 10),
        (RouteCostScorer, 10),
    ):
        assert scorer_cls(default_config)(root) == expected
Exemplo n.º 12
0
def test_rescore_collection_for_trees(default_config, setup_linear_reaction_tree):
    """Rescoring makes the new scorer primary while keeping the state score."""
    linear_tree = setup_linear_reaction_tree()
    collection = RouteCollection(reaction_trees=[linear_tree])
    collection.compute_scores(StateScorer(default_config))

    collection.rescore(NumberOfReactionsScorer())

    assert collection.scores[0] == 2
    first_route_scores = collection.all_scores[0]
    assert np.round(first_route_scores["state score"], 3) == 0.994
    assert first_route_scores["number of reactions"] == 2
Exemplo n.º 13
0
def test_scoring_branched_route(default_config, setup_branched_reaction_tree):
    """Check every scorer on the branched reaction tree fixture."""
    route = setup_branched_reaction_tree()

    state_score = StateScorer(default_config)(route)
    assert pytest.approx(state_score, abs=1e-4) == 0.9866
    assert NumberOfReactionsScorer()(route) == 4
    for scorer_cls, expected in (
        (NumberOfPrecursorsScorer, 5),
        (NumberOfPrecursorsInStockScorer, 5),
        (PriceSumScorer, 5),
    ):
        assert scorer_cls(default_config)(route) == expected
    route_cost = RouteCostScorer(default_config)(route)
    assert pytest.approx(route_cost, abs=1e-4) == 13.6563
Exemplo n.º 14
0
 def __init__(
     self,
     search_tree: Union[MctsSearchTree, AndOrSearchTreeBase],
     scorer: Scorer = None,
 ) -> None:
     """
     Set up the analysis for a search tree.

     :param search_tree: the search tree to analyse
     :param scorer: the scorer used to rank nodes; when omitted, a
         ``StateScorer`` built from ``search_tree.config`` is used
     """
     self.search_tree = search_tree
     chosen_scorer: Scorer = (
         scorer if scorer is not None else StateScorer(search_tree.config)
     )
     self.scorer = chosen_scorer
Exemplo n.º 15
0
def test_state_scorer_nodes(setup_linear_mcts, setup_branched_mcts,
                            default_config):
    """Score a list of MCTS nodes taken from two different fixture trees."""
    _linear_tree, linear_node = setup_linear_mcts()
    _branched_tree, branched_node = setup_branched_mcts()
    state_scorer = StateScorer(default_config)

    node_scores = state_scorer([linear_node, branched_node])

    assert repr(state_scorer) == "state score"
    for idx, expected in enumerate((0.994, 0.9866)):
        assert round(node_scores[idx], 4) == expected
Exemplo n.º 16
0
def test_scorers_tree_one_node_route(default_config):
    """Check every scorer on a reaction tree made of a single molecule node."""
    single_route = ReactionTree()
    single_route.root = UniqueMolecule(smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    single_route.graph.add_node(single_route.root)

    state_score = StateScorer(default_config)(single_route)
    assert pytest.approx(state_score, abs=1e-3) == 0.0497
    for scorer_cls, expected in (
        (NumberOfReactionsScorer, 0),
        (NumberOfPrecursorsScorer, 1),
        (NumberOfPrecursorsInStockScorer, 0),
        (PriceSumScorer, 10),
        (RouteCostScorer, 10),
    ):
        assert scorer_cls(default_config)(single_route) == expected
Exemplo n.º 17
0
def test_scoring_branched_route_not_in_stock(default_config,
                                             setup_branched_reaction_tree):
    """
    Check every scorer on the branched route fixture built with ``"O"`` —
    presumably the precursor left out of stock, since only 4 of 5 are in stock.
    """
    route = setup_branched_reaction_tree("O")

    state_score = StateScorer(default_config)(route)
    assert pytest.approx(state_score, abs=1e-4) == 0.7966
    assert NumberOfReactionsScorer()(route) == 4
    for scorer_cls, expected in (
        (NumberOfPrecursorsScorer, 5),
        (NumberOfPrecursorsInStockScorer, 4),
        (PriceSumScorer, 14),
    ):
        assert scorer_cls(default_config)(route) == expected
    route_cost = RouteCostScorer(default_config)(route)
    assert pytest.approx(route_cost, abs=1e-4) == 31.2344
Exemplo n.º 18
0
def test_rescore_collection_for_trees(default_config, mock_stock,
                                      load_reaction_tree):
    """Rescoring makes the new scorer primary while keeping the state score."""
    in_stock = ("N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1", "CN1CCC(Cl)CC1", "O")
    mock_stock(default_config, *in_stock)
    route = ReactionTree.from_dict(load_reaction_tree("sample_reaction.json"))
    collection = RouteCollection(reaction_trees=[route])
    collection.compute_scores(StateScorer(default_config))

    collection.rescore(NumberOfReactionsScorer())

    assert collection.scores[0] == 2
    first_route_scores = collection.all_scores[0]
    assert np.round(first_route_scores["state score"], 3) == 0.994
    assert first_route_scores["number of reactions"] == 2
Exemplo n.º 19
0
def test_compute_new_score_for_trees(default_config, setup_linear_reaction_tree):
    """
    Computing scores on a collection built only from reaction trees fills
    ``all_scores`` but leaves the primary ``scores`` entry as NaN.
    """
    rt = setup_linear_reaction_tree()
    routes = RouteCollection(reaction_trees=[rt])

    assert routes.nodes[0] is None
    # Test for NaN by value: an identity check against ``np.nan`` passes only
    # when the collection stores that exact object, not any other NaN float.
    assert np.isnan(routes.scores[0])
    assert routes.all_scores[0] == {}

    routes.compute_scores(StateScorer(default_config), NumberOfReactionsScorer())

    assert np.isnan(routes.scores[0])
    assert np.round(routes.all_scores[0]["state score"], 3) == 0.994
    assert routes.all_scores[0]["number of reactions"] == 2
Exemplo n.º 20
0
def test_scoring_branched_route_in_stock(load_reaction_tree, default_config,
                                         mock_stock):
    """
    Check every scorer on the branched route loaded from JSON when the
    listed molecules are mocked as being in stock.
    """
    stock_smiles = (
        "CC(C)(C)CO",
        "CC(C)(C)OC(=O)N(CCCl)CCCl",
        "N#CCc1cccc(O)c1F",
        "O=[N+]([O-])c1ccccc1F",
        "O=C1CCC(=O)N1Br",
        "O=C=Nc1csc(C(F)(F)F)n1",
        "CCC[Sn](Cl)(CCC)CCC",
        "COc1ccc2ncsc2c1",
    )
    mock_stock(default_config, *stock_smiles)
    route = ReactionTree.from_dict(load_reaction_tree("branched_route.json"))

    state_score = StateScorer(default_config)(route)
    assert pytest.approx(state_score, abs=1e-3) == 0.950
    for scorer_cls, expected in (
        (NumberOfReactionsScorer, 14),
        (NumberOfPrecursorsScorer, 8),
        (NumberOfPrecursorsInStockScorer, 8),
        (PriceSumScorer, 8),
    ):
        assert scorer_cls(default_config)(route) == expected
    route_cost = RouteCostScorer(default_config)(route)
    assert pytest.approx(route_cost, abs=1e-3) == 77.4797
Exemplo n.º 21
0
def test_state_scorer_node(default_config, setup_linear_mcts):
    """State score and repr for the node returned by the linear MCTS fixture."""
    _tree, mcts_node = setup_linear_mcts()
    state_scorer = StateScorer(default_config)

    assert repr(state_scorer) == "state score"
    node_score = state_scorer(mcts_node)
    assert round(node_score, 4) == 0.994
Exemplo n.º 22
0
def test_state_scorer_tree(default_config, setup_linear_reaction_tree):
    """State score of the linear reaction tree fixture."""
    linear_tree = setup_linear_reaction_tree()
    state_scorer = StateScorer(default_config)

    tree_score = state_scorer(linear_tree)
    assert round(tree_score, 4) == 0.994
Exemplo n.º 23
0
class TreeAnalysis:
    """
    Class that encapsulates various analyses that can be
    performed on a search tree.

    :ivar scorer: the object used to score the nodes
    :vartype scorer: Scorer
    :ivar search_tree: the search tree
    :vartype search_tree: SearchTree

    :param search_tree: the search tree to do the analysis on
    :type search_tree: SearchTree
    :param scorer: the object used to score the nodes, defaults to StateScorer
    :type scorer: Scorer, optional
    """

    def __init__(self, search_tree, scorer=None):
        self.search_tree = search_tree
        if scorer is None:
            # Do import here to avoid circular imports
            from aizynthfinder.context.scoring import StateScorer

            self.scorer = StateScorer()
        else:
            self.scorer = scorer

    def best_node(self):
        """
        Returns the best-ranked node, i.e. the first node in the scorer's
        sort order. If several nodes share the best score, the first of
        them is returned.

        :return: the top scoring node
        :rtype: Node
        """
        nodes = self._all_nodes()
        sorted_nodes, _ = self.scorer.sort(nodes)
        return sorted_nodes[0]

    def sort_nodes(self, min_return=5, max_return=25):
        """
        Sort and select the nodes, so that the best scoring routes are returned.

        Identical routes (same hash of the reactions leading to the node) are
        filtered away, and at least ``min_return`` routes are returned when
        available. If several routes share the score of the ``min_return``'th
        route, they are all included, up to ``max_return`` routes in total.
        When there are at most ``min_return`` candidate nodes, all of them are
        returned without filtering.

        :param min_return: the minimum number of routes to return, defaults to 5
        :type min_return: int, optional
        :param max_return: the maximum number of routes to return, defaults to 25;
            a falsy value disables the limit
        :type max_return: int, optional
        :return: the selected nodes and their scores
        :rtype: tuple of (list of Node, list of float)
        """
        nodes = self._all_nodes()
        sorted_nodes, sorted_scores = self.scorer.sort(nodes)

        if len(nodes) <= min_return:
            return sorted_nodes, sorted_scores

        seen_hashes = set()
        best_nodes = []
        best_scores = []
        last_score = None
        for score, node in zip(sorted_scores, sorted_nodes):
            # Scores arrive in the scorer's own sort order. Do not assume it is
            # descending: a scorer where lower is better would sort ascending.
            # The sequence is monotonic, so once ``min_return`` routes are kept,
            # any score differing from the last kept one ranks worse. Stop there.
            if len(best_nodes) >= min_return and score != last_score:
                break
            route_actions, _ = self.search_tree.route_to_node(node)
            route_hash = hash_reactions(route_actions)

            if route_hash in seen_hashes:
                continue
            seen_hashes.add(route_hash)
            best_nodes.append(node)
            best_scores.append(score)
            last_score = score

            if max_return and len(best_nodes) == max_return:
                break

        return best_nodes, best_scores

    def tree_statistics(self):
        """
        Returns statistics of the tree.

        Currently it returns the number of nodes, the maximum number of transforms,
        the maximum number of children, the number of leaves and solved leaves,
        the top score, whether the top scoring route is solved, the number of
        steps and precursors of the top scoring node, information on which
        precursors are (not) in stock, and how often each expansion policy was
        used.

        :return: the statistics
        :rtype: dict
        """
        top_node = self.best_node()
        top_state = top_node.state
        nodes = list(self.search_tree.graph())
        mols_in_stock = ", ".join(
            mol.smiles
            for mol, instock in zip(top_state.mols, top_state.in_stock_list)
            if instock)
        mols_not_in_stock = ", ".join(
            mol.smiles
            for mol, instock in zip(top_state.mols, top_state.in_stock_list)
            if not instock)

        # Count how many child nodes were created by each named policy.
        policy_used_counts = defaultdict(int)
        for node in nodes:
            for child in node.children():
                policy_used = node[child]["action"].metadata.get("policy_name")
                if policy_used:
                    policy_used_counts[policy_used] += 1

        return {
            "number_of_nodes":
            len(nodes),
            "max_transforms":
            max(node.state.max_transforms for node in nodes),
            "max_children":
            max(len(node.children()) for node in nodes),
            "number_of_leafs":
            sum(1 for node in nodes if not node.children()),
            "number_of_solved_leafs":
            sum(1 for node in nodes
                if not node.children() and node.state.is_solved),
            "top_score":
            self.scorer(top_node),
            "is_solved":
            top_state.is_solved,
            "number_of_steps":
            top_state.max_transforms,
            "number_of_precursors":
            len(top_state.mols),
            "number_of_precursors_in_stock":
            sum(top_state.in_stock_list),
            "precursors_in_stock":
            mols_in_stock,
            "precursors_not_in_stock":
            mols_not_in_stock,
            "policy_used_counts":
            dict(policy_used_counts),
        }

    def _all_nodes(self):
        # Kept for backwards compatibility: the state scorer ranks every node
        # in the tree, whereas other scorers only rank leaf nodes (nodes
        # without children). This should be investigated further.
        if repr(self.scorer) == "state score":
            return list(self.search_tree.graph())
        return [
            node for node in self.search_tree.graph() if not node.children()
        ]
Exemplo n.º 24
0
def test_state_scorer_node(generate_root, default_config):
    """State score and repr for a freshly generated root node."""
    target_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1"
    root_node = generate_root(target_smiles)
    state_scorer = StateScorer(default_config)

    assert repr(state_scorer) == "state score"
    root_score = state_scorer(root_node)
    assert round(root_score, 4) == 0.0491