def test_clustering_collection_timeout(load_reaction_tree):
    """A zero timeout aborts clustering: no labels returned, no clusters stored."""
    trees = []
    for idx in range(3):
        tree_dict = load_reaction_tree("routes_for_clustering.json", idx)
        trees.append(ReactionTree.from_dict(tree_dict))
    collection = RouteCollection(reaction_trees=trees)

    cluster_labels = collection.cluster(n_clusters=1, timeout=0)

    assert len(cluster_labels) == 0
    assert collection.clusters is None
def test_compute_new_score_for_trees(default_config, setup_linear_reaction_tree):
    """compute_scores fills all_scores without touching the primary scores."""
    routes = RouteCollection(reaction_trees=[setup_linear_reaction_tree()])

    # A collection built from bare trees starts without nodes or scores
    assert routes.nodes[0] is None
    assert routes.scores[0] is np.nan
    assert routes.all_scores[0] == {}

    routes.compute_scores(StateScorer(default_config), NumberOfReactionsScorer())

    # Primary score untouched; each scorer's value is recorded per route
    assert routes.scores[0] is np.nan
    assert np.round(routes.all_scores[0]["state score"], 3) == 0.994
    assert routes.all_scores[0]["number of reactions"] == 2
def test_clustering_collection(load_reaction_tree):
    """Clustering three routes into (at least) one cluster yields two groups."""
    trees = [
        ReactionTree.from_dict(load_reaction_tree("routes_for_clustering.json", idx))
        for idx in range(3)
    ]
    collection = RouteCollection(reaction_trees=trees)

    collection.cluster(n_clusters=1)

    assert len(collection.clusters) == 2
    # Routes 1 and 2 end up together, route 0 forms its own cluster
    first, second = collection.clusters
    assert first.reaction_trees == collection.reaction_trees[1:3]
    assert second.reaction_trees == [collection.reaction_trees[0]]
def test_create_combine_tree_dict_from_tree(
    mock_stock, default_config, load_reaction_tree, shared_datadir
):
    """Combining routes extracted from a search tree matches the reference JSON."""
    stock_smiles = (
        "Nc1ccc(NC(=S)Nc2ccccc2)cc1",
        "Cc1ccc2nc3ccccc3c(Cl)c2c1",
        "Nc1ccc(N)cc1",
        "S=C=Nc1ccccc1",
        "Cc1ccc2nc3ccccc3c(N)c2c1",
        "Nc1ccc(Br)cc1",
    )
    mock_stock(default_config, *stock_smiles)
    search_tree = SearchTree.from_json(
        shared_datadir / "tree_for_clustering.json", default_config
    )
    collection = RouteCollection.from_analysis(TreeAnalysis(search_tree), 3)
    expected = load_reaction_tree("combined_example_tree.json")

    combined_dict = collection.combined_reaction_trees().to_dict()

    assert len(combined_dict["children"]) == 2
    assert combined_dict["children"][0]["is_reaction"]
    assert len(combined_dict["children"][0]["children"]) == 2
    second_child = combined_dict["children"][1]
    assert len(second_child["children"]) == 2
    assert len(second_child["children"][0]["children"]) == 2
    assert second_child["children"][0]["children"][0]["is_reaction"]
    assert combined_dict == expected
def test_create_route_collection_andor_tree(setup_analysis_andor_tree):
    """Routes built from an AND/OR tree have no MCTS nodes attached."""
    routes = RouteCollection.from_analysis(setup_analysis_andor_tree())

    assert len(routes) == 3
    assert routes.nodes == [None] * 3
def test_create_route_collection_full(setup_analysis, mocker):
    """from_analysis yields scored routes and lazily creates dicts/jsons/images."""
    analysis, _ = setup_analysis()

    routes = RouteCollection.from_analysis(analysis, 5)

    assert len(routes) == 7
    # Spot-check the two top-ranked routes
    assert np.round(routes.scores[0], 3) == 0.994
    assert len(routes.reaction_trees[0].graph) == 8
    assert np.round(routes.scores[1], 3) == 0.681
    assert len(routes.reaction_trees[1].graph) == 5
    # Lazy representations are not created before first access
    for key in ("dict", "json", "image"):
        assert key not in routes[0]

    mocker.patch("aizynthfinder.analysis.ReactionTree.to_dict")
    mocker.patch("aizynthfinder.analysis.json.dumps")
    mocker.patch("aizynthfinder.utils.image.GraphvizReactionGraph.to_image")
    mocker.patch("aizynthfinder.utils.image.GraphvizReactionGraph.add_molecule")
    # Just see that the code does not crash, does not verify content
    assert len(routes.images) == 7
    assert len(routes.dicts) == 7
    assert len(routes.jsons) == 7
def test_compute_new_score_for_trees(default_config, mock_stock, load_reaction_tree):
    """compute_scores on a loaded tree records scorer values, primary score stays NaN."""
    stock_smiles = ("N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1", "CN1CCC(Cl)CC1", "O")
    mock_stock(default_config, *stock_smiles)
    tree = ReactionTree.from_dict(load_reaction_tree("sample_reaction.json"))
    routes = RouteCollection(reaction_trees=[tree])

    # A collection built from bare trees starts without nodes or scores
    assert routes.nodes[0] is None
    assert routes.scores[0] is np.nan
    assert routes.all_scores[0] == {}

    routes.compute_scores(StateScorer(default_config), NumberOfReactionsScorer())

    assert routes.scores[0] is np.nan
    assert np.round(routes.all_scores[0]["state score"], 3) == 0.994
    assert routes.all_scores[0]["number of reactions"] == 2
def prepare_tree(self) -> None:
    """
    Setup the tree for searching

    :raises ValueError: if the target molecule was not set
    """
    if not self.target_mol:
        raise ValueError("No target molecule set")

    self.stock.reset_exclusion_list()
    # Optionally hide the target from the stock so the search cannot trivially stop
    should_exclude = (
        self.config.exclude_target_from_stock and self.target_mol in self.stock
    )
    if should_exclude:
        self.stock.exclude(self.target_mol)
        self._logger.debug("Excluding the target compound from the stock")

    self._setup_search_tree()
    # Reset results from any previous search
    self.analysis = None
    self.routes = RouteCollection([])
def test_create_combine_tree_to_visjs(load_reaction_tree, tmpdir):
    """The combined tree can be exported as a visjs tar archive with images."""
    trees = [
        ReactionTree.from_dict(load_reaction_tree("routes_for_clustering.json", idx))
        for idx in range(3)
    ]
    collection = RouteCollection(reaction_trees=trees)
    tar_filename = str(tmpdir / "routes.tar")

    collection.combined_reaction_trees().to_visjs_page(tar_filename)

    assert os.path.exists(tar_filename)
    with TarFile(tar_filename) as tarobj:
        names = tarobj.getnames()
        assert "./route.html" in names
        png_names = [name for name in names if name.endswith(".png")]
        assert len(png_names) == 8
def test_dict_with_scores(setup_analysis):
    """dict_with_scores adds a 'scores' key without mutating the plain dicts."""
    analysis, _ = setup_analysis()
    routes = RouteCollection.from_analysis(analysis, 5)

    dicts = routes.dict_with_scores()

    assert "scores" not in routes.dicts[0]
    assert "scores" in dicts[0]
    assert np.round(dicts[0]["scores"]["state score"], 3) == 0.994
def test_create_combine_tree_dict_from_json(load_reaction_tree):
    """Combining routes loaded from JSON matches the reference combined tree."""
    trees = [
        ReactionTree.from_dict(load_reaction_tree("routes_for_clustering.json", idx))
        for idx in range(3)
    ]
    collection = RouteCollection(reaction_trees=trees)
    expected = load_reaction_tree("combined_example_tree.json")

    combined_dict = collection.combined_reaction_trees().to_dict()

    assert len(combined_dict["children"]) == 2
    assert combined_dict["children"][0]["is_reaction"]
    assert len(combined_dict["children"][0]["children"]) == 2
    second_child = combined_dict["children"][1]
    assert len(second_child["children"]) == 2
    assert len(second_child["children"][1]["children"]) == 2
    assert second_child["children"][1]["children"][0]["is_reaction"]
    assert combined_dict == expected
def test_distance_collection(load_reaction_tree):
    """Distance matrices are cached, reproducible, and content-dependent."""
    trees = [
        ReactionTree.from_dict(load_reaction_tree("routes_for_clustering.json", idx))
        for idx in range(3)
    ]
    collection = RouteCollection(reaction_trees=trees)

    dist_mat1 = collection.distance_matrix()
    dist_mat2 = collection.distance_matrix(recreate=True)
    # Recomputing from scratch gives identical distances
    assert (dist_mat1 - dist_mat2).sum() == 0

    dist_mat3 = collection.distance_matrix(content="molecules")
    # Molecule-based distances differ from the default content
    assert (dist_mat1 - dist_mat3).sum() != 0
    assert len(dist_mat3) == 3
    assert pytest.approx(dist_mat3[0, 1], abs=1e-2) == 2.6522
    assert pytest.approx(dist_mat3[0, 2], abs=1e-2) == 3.0779
    assert pytest.approx(dist_mat3[2, 1], abs=1e-2) == 0.7483
def __init__(self, configfile: str = None, configdict: StrDict = None) -> None:
    """Load the configuration and expose its policies, stock and scorers."""
    self._logger = logger()

    # A config file takes priority over a config dictionary
    if configfile:
        self.config = Configuration.from_file(configfile)
    elif configdict:
        self.config = Configuration.from_dict(configdict)
    else:
        self.config = Configuration()

    # Convenience references into the configuration
    self.expansion_policy = self.config.expansion_policy
    self.filter_policy = self.config.filter_policy
    self.stock = self.config.stock
    self.scorers = self.config.scorers

    # Search state, populated by the tree search and route building
    self.tree: Optional[SearchTree] = None
    self._target_mol: Optional[Molecule] = None
    self.search_stats: StrDict = dict()
    self.routes = RouteCollection([])
    self.analysis: Optional[TreeAnalysis] = None
def test_create_clustering_gui(mocker, load_reaction_tree):
    """Instantiating the clustering GUI triggers at least one display call."""
    trees = [
        ReactionTree.from_dict(load_reaction_tree("routes_for_clustering.json", idx))
        for idx in range(3)
    ]
    display_patch = mocker.patch("aizynthfinder.interfaces.gui.clustering.display")

    ClusteringGui(RouteCollection(reaction_trees=trees))

    display_patch.assert_called()
def build_routes(self, min_nodes=5):
    """
    Build reaction routes

    This is necessary to call after the tree search has completed in order
    to extract results from the tree search.

    :param min_nodes: the minimum number of top-ranked nodes to consider, defaults to 5
    :type min_nodes: int, optional
    """
    # Analyse the finished search tree, then extract the top-ranked routes
    analysis = TreeAnalysis(self.tree)
    self.analysis = analysis
    self.routes = RouteCollection.from_analysis(analysis, min_nodes)
def test_rescore_collection(setup_analysis):
    """rescore swaps the primary score while keeping old values in all_scores."""
    analysis, _ = setup_analysis()
    routes = RouteCollection.from_analysis(analysis, 5)

    routes.rescore(NumberOfReactionsScorer())

    for idx, state_score in ((0, 0.681), (1, 0.523)):
        assert routes.scores[idx] == 1
        assert np.round(routes.all_scores[idx]["state score"], 3) == state_score
        assert routes.all_scores[idx]["number of reactions"] == 1
def test_compute_new_score(setup_analysis):
    """compute_scores appends new scorer values without changing primary scores."""
    analysis, _ = setup_analysis()
    routes = RouteCollection.from_analysis(analysis, 5)

    routes.compute_scores(NumberOfReactionsScorer())

    expectations = ((0, 0.994, 2), (1, 0.681, 1))
    for idx, state_score, n_reactions in expectations:
        assert np.round(routes.scores[idx], 3) == state_score
        assert np.round(routes.all_scores[idx]["state score"], 3) == state_score
        assert routes.all_scores[idx]["number of reactions"] == n_reactions
def build_routes(self, min_nodes=5, scorer="state score"):
    """
    Build reaction routes

    This is necessary to call after the tree search has completed in order
    to extract results from the tree search.

    :param min_nodes: the minimum number of top-ranked nodes to consider, defaults to 5
    :type min_nodes: int, optional
    :param scorer: the object used to score the nodes
    :type scorer: str, optional
    """
    # Look up the scorer object by name, then analyse the finished search tree
    scorer_obj = self.scorers[scorer]
    analysis = TreeAnalysis(self.tree, scorer=scorer_obj)
    self.analysis = analysis
    self.routes = RouteCollection.from_analysis(analysis, min_nodes)
def test_rescore_collection_for_trees(default_config, setup_linear_reaction_tree):
    """rescore on a tree-based collection promotes the new scorer to primary."""
    routes = RouteCollection(reaction_trees=[setup_linear_reaction_tree()])
    routes.compute_scores(StateScorer(default_config))

    routes.rescore(NumberOfReactionsScorer())

    assert routes.scores[0] == 2
    assert np.round(routes.all_scores[0]["state score"], 3) == 0.994
    assert routes.all_scores[0]["number of reactions"] == 2
def build_routes(self, min_nodes: int = 5, scorer: str = "state score") -> None:
    """
    Build reaction routes

    This is necessary to call after the tree search has completed in order
    to extract results from the tree search.

    :param min_nodes: the minimum number of top-ranked nodes to consider, defaults to 5
    :param scorer: a reference to the object used to score the nodes
    :raises ValueError: if the search tree not initialized
    """
    # Guard: routes can only be built once a search tree exists
    if not self.tree:
        raise ValueError("Search tree not initialized")

    analysis = TreeAnalysis(self.tree, scorer=self.scorers[scorer])
    self.analysis = analysis
    self.routes = RouteCollection.from_analysis(analysis, min_nodes)
def build_routes(self, selection: RouteSelectionArguments = None, scorer: str = "state score") -> None:
    """
    Build reaction routes

    This is necessary to call after the tree search has completed in order
    to extract results from the tree search.

    :param selection: the selection criteria for the routes
    :param scorer: a reference to the object used to score the nodes
    :raises ValueError: if the search tree not initialized
    """
    # Guard: routes can only be built once a search tree exists
    if not self.tree:
        raise ValueError("Search tree not initialized")

    analysis = TreeAnalysis(self.tree, scorer=self.scorers[scorer])
    self.analysis = analysis
    self.routes = RouteCollection.from_analysis(analysis, selection)
def test_create_route_collection(setup_complete_tree, mocker):
    """Routes from a complete tree carry scores, nodes and reaction trees."""
    tree, nodes = setup_complete_tree
    analysis = TreeAnalysis(tree)
    mocker.patch("aizynthfinder.analysis.ReactionTree.to_dict")
    mocker.patch("aizynthfinder.analysis.json.dumps")

    routes = RouteCollection.from_analysis(analysis, 5)

    assert len(routes) == 3
    top_route = routes[0]
    assert top_route["score"] == 0.99
    assert top_route["node"] is nodes[2]
    reaction_nodes = [
        node
        for node in top_route["reaction_tree"].graph
        if isinstance(node, Reaction)
    ]
    assert len(reaction_nodes) == 2
    # Just see that the code does not crash, does not verify content
    assert len(routes.images) == 3
    assert len(routes.dicts) == 3
    assert len(routes.jsons) == 3
class AiZynthFinder:
    """
    Public API to the aizynthfinder tool

    If instantiated with the path to a yaml file or dictionary of settings
    the stocks and policy networks are loaded directly.
    Otherwise, the user is responsible for loading them prior to
    executing the tree search.

    :ivar config: the configuration of the search
    :ivar expansion_policy: the expansion policy model
    :ivar filter_policy: the filter policy model
    :ivar stock: the stock
    :ivar scorers: the loaded scores
    :ivar tree: the search tree
    :ivar analysis: the tree analysis
    :ivar routes: the top-ranked routes
    :ivar search_stats: statistics of the latest search

    :param configfile: the path to yaml file with configuration (has priority over configdict), defaults to None
    :param configdict: the config as a dictionary source, defaults to None
    """

    def __init__(self, configfile: str = None, configdict: StrDict = None) -> None:
        self._logger = logger()
        # A config file has priority over a config dictionary
        if configfile:
            self.config = Configuration.from_file(configfile)
        elif configdict:
            self.config = Configuration.from_dict(configdict)
        else:
            self.config = Configuration()

        # Convenience references into the configuration
        self.expansion_policy = self.config.expansion_policy
        self.filter_policy = self.config.filter_policy
        self.stock = self.config.stock
        self.scorers = self.config.scorers

        # Search state, populated by prepare_tree/tree_search/build_routes
        self.tree: Optional[Union[MctsSearchTree, AndOrSearchTreeBase]] = None
        self._target_mol: Optional[Molecule] = None
        self.search_stats: StrDict = dict()
        self.routes = RouteCollection([])
        self.analysis: Optional[TreeAnalysis] = None

    @property
    def target_smiles(self) -> str:
        """The SMILES representation of the molecule to predict routes on."""
        if not self._target_mol:
            return ""
        return self._target_mol.smiles

    @target_smiles.setter
    def target_smiles(self, smiles: str) -> None:
        self.target_mol = Molecule(smiles=smiles)

    @property
    def target_mol(self) -> Optional[Molecule]:
        """The molecule to predict routes on"""
        return self._target_mol

    @target_mol.setter
    def target_mol(self, mol: Molecule) -> None:
        # Changing the target invalidates any previously built search tree
        self.tree = None
        self._target_mol = mol

    def build_routes(self, min_nodes: int = 5, scorer: str = "state score") -> None:
        """
        Build reaction routes

        This is necessary to call after the tree search has completed in order
        to extract results from the tree search.

        :param min_nodes: the minimum number of top-ranked nodes to consider, defaults to 5
        :param scorer: a reference to the object used to score the nodes
        :raises ValueError: if the search tree not initialized
        """
        if not self.tree:
            raise ValueError("Search tree not initialized")
        self.analysis = TreeAnalysis(self.tree, scorer=self.scorers[scorer])
        self.routes = RouteCollection.from_analysis(self.analysis, min_nodes)

    def extract_statistics(self) -> StrDict:
        """Extracts tree statistics as a dictionary"""
        if not self.analysis:
            return {}
        stats = {
            "target": self.target_smiles,
            "search_time": self.search_stats["time"],
            "first_solution_time": self.search_stats.get("first_solution_time", 0),
            "first_solution_iteration": self.search_stats.get(
                "first_solution_iteration", 0
            ),
        }
        stats.update(self.analysis.tree_statistics())
        return stats

    def prepare_tree(self) -> None:
        """
        Setup the tree for searching

        :raises ValueError: if the target molecule was not set
        """
        if not self.target_mol:
            raise ValueError("No target molecule set")
        self.stock.reset_exclusion_list()
        # Optionally hide the target from the stock so the search cannot trivially stop
        if self.config.exclude_target_from_stock and self.target_mol in self.stock:
            self.stock.exclude(self.target_mol)
            self._logger.debug("Excluding the target compound from the stock")
        self._setup_search_tree()
        # Reset results from any previous search
        self.analysis = None
        self.routes = RouteCollection([])

    @deprecated(version="2.1.0", reason="Not supported anymore")
    def run_from_json(self, params: StrDict) -> StrDict:
        """
        Run a search tree by reading settings from a JSON

        :param params: the parameters of the tree search
        :return: dictionary with all settings and top scored routes
        """
        self.stock.select(params["stocks"])
        # "policy" is the legacy single-policy key; "policies" the newer one
        self.expansion_policy.select(params.get("policy", params.get("policies", "")))
        if "filter" in params:
            self.filter_policy.select(params["filter"])
        else:
            self.filter_policy.deselect()
        self.config.C = params["C"]
        self.config.max_transforms = params["max_transforms"]
        self.config.cutoff_cumulative = params["cutoff_cumulative"]
        self.config.cutoff_number = params["cutoff_number"]
        self.target_smiles = params["smiles"]
        self.config.return_first = params["return_first"]
        self.config.time_limit = params["time_limit"]
        self.config.iteration_limit = params["iteration_limit"]
        self.config.exclude_target_from_stock = params["exclude_target_from_stock"]
        self.config.filter_cutoff = params["filter_cutoff"]

        self.prepare_tree()
        self.tree_search()
        self.build_routes()
        if not params.get("score_trees", False):
            return {
                "request": self._get_settings(),
                "trees": self.routes.dicts,
            }
        # Score the routes with every loaded scorer before returning
        self.routes.compute_scores(*self.scorers.objects())
        return {
            "request": self._get_settings(),
            "trees": self.routes.dict_with_scores(),
        }

    def tree_search(self, show_progress: bool = False) -> float:
        """
        Perform the actual tree search

        :param show_progress: if True, shows a progress bar
        :return: the time past in seconds
        """
        if not self.tree:
            self.prepare_tree()
        assert (
            self.tree is not None
        )  # This is for type checking, prepare_tree is creating it.
        self.search_stats = {"returned_first": False, "iterations": 0}
        time0 = time.time()
        i = 1
        self._logger.debug("Starting search")
        time_past = time.time() - time0
        if show_progress:
            pbar = tqdm(total=self.config.iteration_limit, leave=False)
        # Iterate until the time limit or iteration limit is exhausted
        while time_past < self.config.time_limit and i <= self.config.iteration_limit:
            if show_progress:
                pbar.update(1)
            self.search_stats["iterations"] += 1
            try:
                is_solved = self.tree.one_iteration()
            except StopIteration:
                # The tree signals that there is nothing more to expand
                break
            # Record when/where the first solution was found
            if is_solved and "first_solution_time" not in self.search_stats:
                self.search_stats["first_solution_time"] = time.time() - time0
                self.search_stats["first_solution_iteration"] = i
            if self.config.return_first and is_solved:
                self._logger.debug("Found first solved route")
                self.search_stats["returned_first"] = True
                break
            i = i + 1
            time_past = time.time() - time0
        if show_progress:
            pbar.close()
        time_past = time.time() - time0
        self._logger.debug("Search completed")
        self.search_stats["time"] = time_past
        return time_past

    def _get_settings(self) -> StrDict:
        """Get the current settings as a dictionary"""
        # To be backward-compatible
        if (self.expansion_policy.selection
                and len(self.expansion_policy.selection) == 1):
            policy_value = self.expansion_policy.selection[0]
            policy_key = "policy"
        else:
            policy_value = self.expansion_policy.selection  # type: ignore
            policy_key = "policies"
        dict_ = {
            "stocks": self.stock.selection,
            policy_key: policy_value,
            "C": self.config.C,
            "max_transforms": self.config.max_transforms,
            "cutoff_cumulative": self.config.cutoff_cumulative,
            "cutoff_number": self.config.cutoff_number,
            "smiles": self.target_smiles,
            "return_first": self.config.return_first,
            "time_limit": self.config.time_limit,
            "iteration_limit": self.config.iteration_limit,
            "exclude_target_from_stock": self.config.exclude_target_from_stock,
            "filter_cutoff": self.config.filter_cutoff,
        }
        if self.filter_policy.selection:
            dict_["filter"] = self.filter_policy.selection
        return dict_

    def _setup_search_tree(self):
        """Instantiate the search tree: built-in MCTS or a user-specified class."""
        self._logger.debug("Defining tree root: %s" % self.target_smiles)
        if self.config.search_algorithm.lower() == "mcts":
            self.tree = MctsSearchTree(root_smiles=self.target_smiles, config=self.config)
        else:
            # search_algorithm holds a dotted path "module.ClassName" to import
            module_name, cls_name = self.config.search_algorithm.rsplit(".", maxsplit=1)
            try:
                module_obj = importlib.import_module(module_name)
            except ImportError:
                raise ValueError(f"Could not import module {module_name}")
            if not hasattr(module_obj, cls_name):
                raise ValueError(f"Could not identify class {cls_name} in module")
            self.tree: AndOrSearchTreeBase = getattr(module_obj, cls_name)(
                root_smiles=self.target_smiles, config=self.config)