def test_sort(default_config, setup_linear_mcts, setup_branched_mcts):
    """Sorting MCTS nodes by state score puts the higher-scoring node first."""
    _linear_tree, linear_node = setup_linear_mcts()
    _branched_tree, branched_node = setup_branched_mcts()
    state_scorer = StateScorer(default_config)

    ranked, ranked_scores, _ = state_scorer.sort([branched_node, linear_node])

    rounded = [np.round(value, 4) for value in ranked_scores]
    assert rounded == [0.994, 0.9866]
    assert ranked == [linear_node, branched_node]
def __init__(self, search_tree, scorer=None):
    """
    Store the search tree and the scorer to use for the analysis.

    :param search_tree: the search tree to analyse
    :param scorer: the object used to score nodes; when omitted a default
        ``StateScorer()`` is created
    """
    self.search_tree = search_tree
    if scorer is not None:
        self.scorer = scorer
        return

    # Imported lazily because a module-level import would be circular
    from aizynthfinder.context.scoring import StateScorer

    self.scorer = StateScorer()
def test_sort(shared_datadir, default_config, mock_stock):
    """Sorting all nodes of a stored tree ranks the second node first."""
    mock_stock(default_config, "CCCO", "CC")
    tree_file = shared_datadir / "tree_without_repetition.json"
    search_tree = SearchTree.from_json(tree_file, default_config)
    all_nodes = list(search_tree.graph())

    state_scorer = StateScorer(default_config)
    ranked, ranked_scores, _ = state_scorer.sort(all_nodes)

    assert [np.round(value, 4) for value in ranked_scores] == [0.9976, 0.0491]
    assert ranked == [all_nodes[1], all_nodes[0]]
def test_scoring_branch_mcts_tree_in_stock(shared_datadir, default_config, mock_stock):
    """Scorers applied to the last node of the branched tree with a mocked stock."""
    stock_smiles = (
        "CC(C)(C)CO",
        "CC(C)(C)OC(=O)N(CCCl)CCCl",
        "N#CCc1cccc(O)c1F",
        "O=[N+]([O-])c1ccccc1F",
        "O=C1CCC(=O)N1Br",
        "O=C=Nc1csc(C(F)(F)F)n1",
        "CCC[Sn](Cl)(CCC)CCC",
        "COc1ccc2ncsc2c1",
    )
    mock_stock(default_config, *stock_smiles)
    search_tree = SearchTree.from_json(
        shared_datadir / "tree_with_branching.json", default_config
    )
    last_node = list(search_tree.graph())[-1]

    assert StateScorer(default_config)(last_node) == pytest.approx(0.950, abs=1e-3)
    assert NumberOfReactionsScorer()(last_node) == 14
    for scorer_cls in (
        NumberOfPrecursorsScorer,
        NumberOfPrecursorsInStockScorer,
        PriceSumScorer,
    ):
        assert scorer_cls(default_config)(last_node) == 8
    route_cost = RouteCostScorer(default_config)(last_node)
    assert route_cost == pytest.approx(77.4797, abs=1e-3)
def test_add_scorer_to_collection(default_config):
    """A scorer removed from the collection can be added back with ``load``."""
    scorers = ScorerCollection(default_config)
    scorer_name = "state score"

    del scorers[scorer_name]
    scorers.load(StateScorer(default_config))

    assert scorer_name in scorers.names()
def test_state_scorer_tree(load_reaction_tree, default_config, mock_stock):
    """State score of the sample reaction tree with its precursors mocked in stock."""
    stock_smiles = ("N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1", "CN1CCC(Cl)CC1", "O")
    mock_stock(default_config, *stock_smiles)
    route_dict = load_reaction_tree("sample_reaction.json")
    reaction_tree = ReactionTree.from_dict(route_dict)

    score = StateScorer(default_config)(reaction_tree)

    assert round(score, 4) == 0.994
def test_scoring_branched_route(load_reaction_tree, default_config):
    """Scorers applied to the branched route loaded from JSON (default stock)."""
    route = ReactionTree.from_dict(load_reaction_tree("branched_route.json"))

    assert StateScorer(default_config)(route) == pytest.approx(0.00012363, abs=1e-6)
    expected_counts = (
        (NumberOfReactionsScorer, 14),
        (NumberOfPrecursorsScorer, 8),
        (NumberOfPrecursorsInStockScorer, 0),
    )
    for scorer_cls, expected in expected_counts:
        assert scorer_cls(default_config)(route) == expected
def test_state_scorer_trees(default_config, setup_linear_reaction_tree):
    """Scoring a list of trees returns one state score per tree."""
    route = setup_linear_reaction_tree()
    results = StateScorer(default_config)([route, route])

    for position in (0, 1):
        assert round(results[position], 4) == 0.994
def test_scoring_branched_mcts_tree(shared_datadir, default_config):
    """Scorers applied to the last node of the branched tree (default stock)."""
    tree_path = shared_datadir / "tree_with_branching.json"
    search_tree = SearchTree.from_json(tree_path, default_config)
    last_node = list(search_tree.graph())[-1]

    assert StateScorer()(last_node) == pytest.approx(0.00012363, abs=1e-6)
    assert NumberOfReactionsScorer()(last_node) == 14
    assert NumberOfPrecursorsScorer(default_config)(last_node) == 8
    assert NumberOfPrecursorsInStockScorer(default_config)(last_node) == 0
def test_state_scorer_nodes(generate_root, default_config):
    """Scoring a list of root nodes returns one state score per node."""
    root_node = generate_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    state_scorer = StateScorer(default_config)

    results = state_scorer([root_node, root_node])

    assert repr(state_scorer) == "state score"
    for position in (0, 1):
        assert round(results[position], 4) == 0.0491
def test_scorers_one_mcts_node(default_config):
    """Scorers applied to the root node of a freshly created search tree."""
    search_tree = SearchTree(default_config, root_smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    root = search_tree.root

    assert StateScorer(default_config)(root) == pytest.approx(0.0497, abs=1e-3)
    expectations = [
        (NumberOfReactionsScorer, 0),
        (NumberOfPrecursorsScorer, 1),
        (NumberOfPrecursorsInStockScorer, 0),
        (PriceSumScorer, 10),
        (RouteCostScorer, 10),
    ]
    for scorer_cls, expected in expectations:
        assert scorer_cls(default_config)(root) == expected
def test_rescore_collection_for_trees(default_config, setup_linear_reaction_tree):
    """Rescoring a tree-only collection replaces the primary score and keeps all scores."""
    collection = RouteCollection(reaction_trees=[setup_linear_reaction_tree()])
    collection.compute_scores(StateScorer(default_config))
    collection.rescore(NumberOfReactionsScorer())

    assert collection.scores[0] == 2
    first_scores = collection.all_scores[0]
    assert np.round(first_scores["state score"], 3) == 0.994
    assert first_scores["number of reactions"] == 2
def test_scoring_branched_route(default_config, setup_branched_reaction_tree):
    """Scorers applied to the branched reaction tree fixture."""
    route = setup_branched_reaction_tree()

    assert StateScorer(default_config)(route) == pytest.approx(0.9866, abs=1e-4)
    assert NumberOfReactionsScorer()(route) == 4
    for scorer_cls in (
        NumberOfPrecursorsScorer,
        NumberOfPrecursorsInStockScorer,
        PriceSumScorer,
    ):
        assert scorer_cls(default_config)(route) == 5
    route_cost = RouteCostScorer(default_config)(route)
    assert route_cost == pytest.approx(13.6563, abs=1e-4)
def __init__(
    self,
    search_tree: Union[MctsSearchTree, AndOrSearchTreeBase],
    scorer: Scorer = None,
) -> None:
    """
    Store the search tree and the scorer to use for the analysis.

    :param search_tree: the search tree to analyse
    :param scorer: the object used to score nodes; when omitted a
        ``StateScorer`` built from ``search_tree.config`` is used
    """
    self.search_tree = search_tree
    self.scorer: Scorer = (
        StateScorer(search_tree.config) if scorer is None else scorer
    )
def test_state_scorer_nodes(setup_linear_mcts, setup_branched_mcts, default_config):
    """Scoring a list of MCTS nodes returns one state score per node, in order."""
    _linear_tree, linear_node = setup_linear_mcts()
    _branched_tree, branched_node = setup_branched_mcts()
    state_scorer = StateScorer(default_config)

    results = state_scorer([linear_node, branched_node])

    assert repr(state_scorer) == "state score"
    for position, target in enumerate((0.994, 0.9866)):
        assert round(results[position], 4) == target
def test_scorers_tree_one_node_route(default_config):
    """Scorers applied to a reaction tree that holds only its root molecule."""
    route = ReactionTree()
    route.root = UniqueMolecule(smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    route.graph.add_node(route.root)

    assert StateScorer(default_config)(route) == pytest.approx(0.0497, abs=1e-3)
    checks = (
        (NumberOfReactionsScorer, 0),
        (NumberOfPrecursorsScorer, 1),
        (NumberOfPrecursorsInStockScorer, 0),
        (PriceSumScorer, 10),
        (RouteCostScorer, 10),
    )
    for scorer_cls, expected in checks:
        assert scorer_cls(default_config)(route) == expected
def test_scoring_branched_route_not_in_stock(default_config, setup_branched_reaction_tree):
    """Scorers on the branched tree fixture built with the ``"O"`` argument.

    Presumably ``"O"`` names a precursor left out of stock; the precursor
    count in stock (4 of 5) asserted below is consistent with that.
    """
    route = setup_branched_reaction_tree("O")

    assert StateScorer(default_config)(route) == pytest.approx(0.7966, abs=1e-4)
    assert NumberOfReactionsScorer()(route) == 4
    expectations = (
        (NumberOfPrecursorsScorer, 5),
        (NumberOfPrecursorsInStockScorer, 4),
        (PriceSumScorer, 14),
    )
    for scorer_cls, expected in expectations:
        assert scorer_cls(default_config)(route) == expected
    route_cost = RouteCostScorer(default_config)(route)
    assert route_cost == pytest.approx(31.2344, abs=1e-4)
def test_rescore_collection_for_trees(default_config, mock_stock, load_reaction_tree):
    """Rescoring a collection built from a JSON route replaces the primary score."""
    stock_smiles = ("N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1", "CN1CCC(Cl)CC1", "O")
    mock_stock(default_config, *stock_smiles)
    route = ReactionTree.from_dict(load_reaction_tree("sample_reaction.json"))

    collection = RouteCollection(reaction_trees=[route])
    collection.compute_scores(StateScorer(default_config))
    collection.rescore(NumberOfReactionsScorer())

    assert collection.scores[0] == 2
    first_scores = collection.all_scores[0]
    assert np.round(first_scores["state score"], 3) == 0.994
    assert first_scores["number of reactions"] == 2
def test_compute_new_score_for_trees(default_config, setup_linear_reaction_tree):
    """``compute_scores`` fills ``all_scores`` but leaves the primary score as NaN.

    A collection built only from reaction trees has no nodes and starts with a
    NaN primary score and empty ``all_scores``; computing extra scores must not
    overwrite the primary score.
    """
    rt = setup_linear_reaction_tree()
    routes = RouteCollection(reaction_trees=[rt])

    assert routes.nodes[0] is None
    # Check NaN by value: ``is np.nan`` only passes when the stored object is
    # that exact singleton, not for any other NaN (e.g. float("nan")).
    assert np.isnan(routes.scores[0])
    assert routes.all_scores[0] == {}

    routes.compute_scores(StateScorer(default_config), NumberOfReactionsScorer())

    assert np.isnan(routes.scores[0])
    assert np.round(routes.all_scores[0]["state score"], 3) == 0.994
    assert routes.all_scores[0]["number of reactions"] == 2
def test_scoring_branched_route_in_stock(load_reaction_tree, default_config, mock_stock):
    """Scorers applied to the branched JSON route with its precursors mocked in stock."""
    stock_smiles = (
        "CC(C)(C)CO",
        "CC(C)(C)OC(=O)N(CCCl)CCCl",
        "N#CCc1cccc(O)c1F",
        "O=[N+]([O-])c1ccccc1F",
        "O=C1CCC(=O)N1Br",
        "O=C=Nc1csc(C(F)(F)F)n1",
        "CCC[Sn](Cl)(CCC)CCC",
        "COc1ccc2ncsc2c1",
    )
    mock_stock(default_config, *stock_smiles)
    route = ReactionTree.from_dict(load_reaction_tree("branched_route.json"))

    assert StateScorer(default_config)(route) == pytest.approx(0.950, abs=1e-3)
    for scorer_cls, expected in (
        (NumberOfReactionsScorer, 14),
        (NumberOfPrecursorsScorer, 8),
        (NumberOfPrecursorsInStockScorer, 8),
        (PriceSumScorer, 8),
    ):
        assert scorer_cls(default_config)(route) == expected
    route_cost = RouteCostScorer(default_config)(route)
    assert route_cost == pytest.approx(77.4797, abs=1e-3)
def test_state_scorer_node(default_config, setup_linear_mcts):
    """State scorer has the expected name and scores a single MCTS node."""
    _linear_tree, linear_node = setup_linear_mcts()
    state_scorer = StateScorer(default_config)

    assert repr(state_scorer) == "state score"
    node_score = state_scorer(linear_node)
    assert round(node_score, 4) == 0.994
def test_state_scorer_tree(default_config, setup_linear_reaction_tree):
    """State score of the linear reaction tree fixture."""
    route = setup_linear_reaction_tree()
    route_score = StateScorer(default_config)(route)
    assert round(route_score, 4) == 0.994
class TreeAnalysis:
    """
    Class that encapsulates various analyses that can be performed on a search tree.

    :ivar scorer: the object used to score the nodes
    :vartype scorer: Scorer
    :ivar search_tree: the search tree
    :vartype search_tree: SearchTree

    :param search_tree: the search tree to do the analysis on
    :type search_tree: SearchTree
    :param scorer: the object used to score the nodes, defaults to StateScorer
    :type scorer: Scorer, optional
    """

    def __init__(self, search_tree, scorer=None):
        self.search_tree = search_tree
        if scorer is None:
            # Do import here to avoid circular imports
            from aizynthfinder.context.scoring import StateScorer

            self.scorer = StateScorer()
        else:
            self.scorer = scorer

    def best_node(self):
        """
        Returns the node with the highest score.
        If several nodes have the same score, the one that ``scorer.sort``
        places first is returned.

        :return: the top scoring node
        :rtype: Node
        """
        nodes = self._all_nodes()
        sorted_nodes, _ = self.scorer.sort(nodes)
        return sorted_nodes[0]

    def sort_nodes(self, min_return=5, max_return=25):
        """
        Sort and select the nodes, so that the best scoring routes are returned.
        The algorithm filters away identical routes and returns at minimum the number
        specified. If multiple alternative routes have the same score as the n'th route,
        they will be included and returned.

        If there are at most ``min_return`` candidate nodes, all of them are returned
        in sorted order, without filtering identical routes.

        :param min_return: the minimum number of routes to return, defaults to 5
        :type min_return: int, optional
        :param max_return: the maximum number of routes to return, defaults to 25;
            a falsy value disables the upper limit
        :type max_return: int, optional
        :return: the selected nodes and their scores
        :rtype: tuple of (list of Node, list of float)
        """
        nodes = self._all_nodes()
        sorted_nodes, sorted_scores = self.scorer.sort(nodes)

        if len(nodes) <= min_return:
            return sorted_nodes, sorted_scores

        seen_hashes = set()
        best_nodes = []
        best_scores = []
        last_score = 1e16
        for score, node in zip(sorted_scores, sorted_nodes):
            # Once enough routes are collected, only continue while the score is
            # not lower than the last accepted one (i.e. include ties)
            if len(best_nodes) >= min_return and score < last_score:
                break
            route_actions, _ = self.search_tree.route_to_node(node)
            route_hash = hash_reactions(route_actions)

            # Skip routes whose reactions are identical to an already selected route
            if route_hash in seen_hashes:
                continue
            seen_hashes.add(route_hash)
            best_nodes.append(node)
            best_scores.append(score)
            last_score = score

            if max_return and len(best_nodes) == max_return:
                break

        return best_nodes, best_scores

    def tree_statistics(self):
        """
        Returns statistics of the tree

        Currently it returns the number of nodes, the maximum number of transforms,
        the maximum number of children, the number of leaves and solved leaves,
        the top score, whether the top score route is solved, the number of molecules
        in the top score node, information on precursors, and how many edges were
        created by each named policy

        :return: the statistics
        :rtype: dict
        """
        top_node = self.best_node()
        top_state = top_node.state
        nodes = list(self.search_tree.graph())
        mols_in_stock = ", ".join(
            mol.smiles
            for mol, instock in zip(top_state.mols, top_state.in_stock_list)
            if instock
        )
        mols_not_in_stock = ", ".join(
            mol.smiles
            for mol, instock in zip(top_state.mols, top_state.in_stock_list)
            if not instock
        )

        # Count, per policy name stored in the action metadata, how many
        # parent -> child edges of the tree it produced
        policy_used_counts = defaultdict(int)
        for node in nodes:
            for child in node.children():
                policy_used = node[child]["action"].metadata.get("policy_name")
                if policy_used:
                    policy_used_counts[policy_used] += 1

        return {
            "number_of_nodes": len(nodes),
            "max_transforms": max(node.state.max_transforms for node in nodes),
            "max_children": max(len(node.children()) for node in nodes),
            "number_of_leafs": sum(1 for node in nodes if not node.children()),
            "number_of_solved_leafs": sum(
                1 for node in nodes if not node.children() and node.state.is_solved
            ),
            "top_score": self.scorer(top_node),
            "is_solved": top_state.is_solved,
            "number_of_steps": top_state.max_transforms,
            "number_of_precursors": len(top_state.mols),
            "number_of_precursors_in_stock": sum(top_state.in_stock_list),
            "precursors_in_stock": mols_in_stock,
            "precursors_not_in_stock": mols_not_in_stock,
            "policy_used_counts": dict(policy_used_counts),
        }

    def _all_nodes(self):
        # Kept for backwards compatibility: the state scorer considers every node
        # in the tree, whereas any other scorer only considers leaf nodes.
        # This should be investigated further.
        if repr(self.scorer) == "state score":
            return list(self.search_tree.graph())
        return [node for node in self.search_tree.graph() if not node.children()]
def test_state_scorer_node(generate_root, default_config):
    """State scorer has the expected name and scores a generated root node."""
    root_node = generate_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    state_scorer = StateScorer(default_config)

    assert repr(state_scorer) == "state score"
    root_score = state_scorer(root_node)
    assert round(root_score, 4) == 0.0491