def test_find_repetetive_patterns_created_tree_no_patterns(
    default_config, mock_stock, shared_datadir
):
    """Routes extracted from trees without repetition are not flagged as repetitive."""
    mock_stock(default_config, Molecule(smiles="CC"), Molecule(smiles="CCCO"))

    # Short tree first (3 nodes, 1 reaction)
    short_tree = SearchTree.from_json(
        shared_datadir / "tree_without_repetition.json", default_config
    )
    short_route = ReactionTree.from_analysis(TreeAnalysis(short_tree))

    assert not short_route.has_repeating_patterns
    # No node of the route graph may carry a truthy "hide" attribute
    hidden = [
        name
        for name in short_route.graph
        if short_route.graph.nodes[name].get("hide", False)
    ]
    assert len(hidden) == 0

    # Then a longer tree
    long_tree = SearchTree.from_json(
        shared_datadir / "tree_without_repetition_longer.json", default_config
    )
    long_route = ReactionTree.from_analysis(TreeAnalysis(long_tree))

    assert not long_route.has_repeating_patterns
def test_find_repetetive_patterns_created_tree(default_config, mock_stock, shared_datadir):
    """Routes with repetitive units are flagged and the expected number of nodes is hidden."""
    mock_stock(default_config, Molecule(smiles="CC"), Molecule(smiles="C"))

    # (tree file, expected number of hidden nodes)
    cases = (
        ("tree_with_repetition.json", 5),  # 2 repetitive units
        ("tree_with_3_repetitions.json", 10),  # 3 repetitive units
    )

    for filename, expected_hidden in cases:
        tree = SearchTree.from_json(shared_datadir / filename, default_config)
        route = ReactionTree.from_analysis(TreeAnalysis(tree))

        assert route.has_repeating_patterns
        hidden = [
            name for name in route.graph if route.graph.nodes[name].get("hide", False)
        ]
        assert len(hidden) == expected_hidden
def prepare_tree(self):
    """
    Setup the tree for searching

    Resets the exclusion list of the stock, optionally excludes the target
    from the stock, creates a new search tree rooted at the target SMILES
    and clears any previously stored analysis and routes.

    :raises ValueError: if no target molecule has been set
    """
    # Fail early with a clear message rather than with a less obvious
    # error further down when the target is dereferenced.
    if self.target_mol is None:
        raise ValueError("No target molecule set")

    self.stock.reset_exclusion_list()
    if self.config.exclude_target_from_stock and self.target_mol in self.stock:
        self.stock.exclude(self.target_mol)
        self._logger.debug("Excluding the target compound from the stock")

    # Pass the argument to the logger so formatting is deferred until needed
    self._logger.debug("Defining tree root: %s", self.target_smiles)
    self.tree = SearchTree(root_smiles=self.target_smiles, config=self.config)
    self.analysis = None
    self.routes = None
def test_create_combine_tree_dict_from_tree(
    mock_stock, default_config, load_reaction_tree, shared_datadir
):
    """The combined reaction tree of a route collection matches the stored example."""
    stock_smiles = (
        "Nc1ccc(NC(=S)Nc2ccccc2)cc1",
        "Cc1ccc2nc3ccccc3c(Cl)c2c1",
        "Nc1ccc(N)cc1",
        "S=C=Nc1ccccc1",
        "Cc1ccc2nc3ccccc3c(N)c2c1",
        "Nc1ccc(Br)cc1",
    )
    mock_stock(default_config, *stock_smiles)

    tree = SearchTree.from_json(shared_datadir / "tree_for_clustering.json", default_config)
    collection = RouteCollection.from_analysis(TreeAnalysis(tree), 3)
    expected = load_reaction_tree("combined_example_tree.json")

    combined_dict = collection.combined_reaction_trees().to_dict()

    # Spot-check the structure before comparing the full dictionaries
    top_children = combined_dict["children"]
    assert len(top_children) == 2
    assert top_children[0]["is_reaction"]
    assert len(top_children[0]["children"]) == 2

    second_branch = top_children[1]["children"]
    assert len(second_branch) == 2
    assert len(second_branch[0]["children"]) == 2
    assert second_branch[0]["children"][0]["is_reaction"]

    assert combined_dict == expected
def test_number_of_reaction_scorer_node(shared_datadir, default_config):
    """The second node of the short tree is scored as one reaction."""
    tree_file = shared_datadir / "tree_without_repetition.json"
    tree = SearchTree.from_json(tree_file, default_config)
    second_node = list(tree.graph())[1]

    count_reactions = NumberOfReactionsScorer()

    assert count_reactions(second_node) == 1
def test_template_occurence_scorer_no_metadata(shared_datadir, default_config):
    """Without occurrence metadata on the loaded tree the average occurrence score is zero."""
    tree_file = shared_datadir / "tree_without_repetition.json"
    tree = SearchTree.from_json(tree_file, default_config)
    second_node = list(tree.graph())[1]

    occurrence_scorer = AverageTemplateOccurenceScorer()

    assert occurrence_scorer(second_node) == 0
def test_scoring_branch_mcts_tree_in_stock(shared_datadir, default_config, mock_stock):
    """Score the last node of the branching tree with all its precursors mocked into stock."""
    stock_smiles = (
        "CC(C)(C)CO",
        "CC(C)(C)OC(=O)N(CCCl)CCCl",
        "N#CCc1cccc(O)c1F",
        "O=[N+]([O-])c1ccccc1F",
        "O=C1CCC(=O)N1Br",
        "O=C=Nc1csc(C(F)(F)F)n1",
        "CCC[Sn](Cl)(CCC)CCC",
        "COc1ccc2ncsc2c1",
    )
    mock_stock(default_config, *stock_smiles)
    tree = SearchTree.from_json(shared_datadir / "tree_with_branching.json", default_config)
    last_node = list(tree.graph())[-1]

    state_score = StateScorer(default_config)(last_node)
    assert pytest.approx(state_score, abs=1e-3) == 0.950

    assert NumberOfReactionsScorer()(last_node) == 14
    assert NumberOfPrecursorsScorer(default_config)(last_node) == 8
    assert NumberOfPrecursorsInStockScorer(default_config)(last_node) == 8
    assert PriceSumScorer(default_config)(last_node) == 8

    cost_score = RouteCostScorer(default_config)(last_node)
    assert pytest.approx(cost_score, abs=1e-3) == 77.4797
def prepare_tree(self) -> None:
    """
    Setup the tree for searching

    Resets the exclusion list of the stock, optionally excludes the target
    from the stock, creates a new search tree rooted at the target SMILES,
    clears the stored analysis and replaces the routes with an empty
    collection.

    :raises ValueError: if the target molecule was not set
    """
    if not self.target_mol:
        raise ValueError("No target molecule set")

    self.stock.reset_exclusion_list()
    if self.config.exclude_target_from_stock and self.target_mol in self.stock:
        self.stock.exclude(self.target_mol)
        self._logger.debug("Excluding the target compound from the stock")

    # Pass the argument to the logger so formatting is deferred until needed
    self._logger.debug("Defining tree root: %s", self.target_smiles)
    self.tree = SearchTree(root_smiles=self.target_smiles, config=self.config)
    self.analysis = None
    self.routes = RouteCollection([])
def test_scoring_branched_mcts_tree(shared_datadir, default_config):
    """Score the last node of the branching tree without mocking any stock."""
    tree_file = shared_datadir / "tree_with_branching.json"
    tree = SearchTree.from_json(tree_file, default_config)
    last_node = list(tree.graph())[-1]

    state_score = StateScorer()(last_node)
    assert pytest.approx(state_score, abs=1e-6) == 0.00012363

    assert NumberOfReactionsScorer()(last_node) == 14
    assert NumberOfPrecursorsScorer(default_config)(last_node) == 8
    assert NumberOfPrecursorsInStockScorer(default_config)(last_node) == 0
def test_template_occurence_scorer(shared_datadir, default_config):
    """The average occurrence score reflects metadata set on the action leading to a node."""
    tree_file = shared_datadir / "tree_without_repetition.json"
    tree = SearchTree.from_json(tree_file, default_config)
    all_nodes = list(tree.graph())
    first_node, second_node = all_nodes[0], all_nodes[1]

    # Annotate the action connecting the first node to the second one
    first_node[second_node]["action"].metadata["library_occurence"] = 5

    occurrence_scorer = AverageTemplateOccurenceScorer()

    assert occurrence_scorer(first_node) == 0
    assert occurrence_scorer(second_node) == 5
def test_scorers_one_mcts_node(default_config):
    """All scorers give the expected value for the root of a freshly created tree."""
    root_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1"
    root_node = SearchTree(default_config, root_smiles=root_smiles).root

    state_score = StateScorer(default_config)(root_node)
    assert pytest.approx(state_score, abs=1e-3) == 0.0497

    assert NumberOfReactionsScorer(default_config)(root_node) == 0
    assert NumberOfPrecursorsScorer(default_config)(root_node) == 1
    assert NumberOfPrecursorsInStockScorer(default_config)(root_node) == 0
    assert PriceSumScorer(default_config)(root_node) == 10
    assert RouteCostScorer(default_config)(root_node) == 10
def test_serialize_deserialize_tree(
    fresh_tree,
    generate_root,
    simple_actions,
    mock_expansion_policy,
    default_config,
    mocker,
    tmpdir,
):
    """
    Round-trip a small search tree through ``serialize`` and ``from_json``.

    ``json.dump`` and ``json.load`` are patched: the dumped payload is
    checked against the expected dictionary, and that same dictionary is
    fed back when the tree is loaded again.
    """
    root = generate_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    fresh_tree.root = root
    # The returned actions and priors are not needed here.
    # NOTE(review): presumably this call primes the mocked expansion policy
    # for the molecule - verify against the fixture.
    mock_expansion_policy(root.state.mols[0])
    root.expand()
    child = root.promising_child()
    mocked_json_dump = mocker.patch("aizynthfinder.mcts.mcts.json.dump")
    serializer = MoleculeSerializer()
    filename = str(tmpdir / "dummy.json")

    # Test serialization
    fresh_tree.serialize(filename)

    expected_dict = {"tree": root.serialize(serializer), "molecules": serializer.store}
    mocked_json_dump.assert_called_once_with(
        expected_dict, mocker.ANY, indent=mocker.ANY
    )

    # Test deserialization
    mocker.patch("aizynthfinder.mcts.mcts.json.load", return_value=expected_dict)

    new_tree = SearchTree.from_json(filename, default_config)
    root_new = new_tree.root
    assert len(root_new.children()) == 1

    new_child = root_new.children()[0]
    view_keys = ("values", "priors", "visitations")

    for key in view_keys:
        assert root_new.children_view()[key] == root.children_view()[key]
    assert root_new.is_expanded

    for key in view_keys:
        assert new_child.children_view()[key] == child.children_view()[key]
    assert not new_child.is_expanded

    assert str(root_new.state) == str(root.state)
    assert str(new_child.state) == str(child.state)
def test_sort(shared_datadir, default_config, mock_stock):
    """Sorting with the state scorer puts the second node of the short tree first."""
    mock_stock(default_config, "CCCO", "CC")
    tree_file = shared_datadir / "tree_without_repetition.json"
    tree = SearchTree.from_json(tree_file, default_config)
    all_nodes = list(tree.graph())
    state_scorer = StateScorer(default_config)

    ranked_nodes, ranked_scores, _ = state_scorer.sort(all_nodes)

    rounded_scores = [np.round(value, 4) for value in ranked_scores]
    assert rounded_scores == [0.9976, 0.0491]
    assert ranked_nodes == [all_nodes[1], all_nodes[0]]
def setup_analysis(default_config, shared_datadir, tmpdir, mock_stock):
    """
    Prepare a factory for tree analyses of the full search tree.

    The gzipped tree is unpacked into the temporary directory and loaded
    once; the returned callable creates a ``TreeAnalysis`` for that tree
    with the given scorer and returns it together with the list of nodes.
    """
    mock_stock(
        default_config, "N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1", "CN1CCC(Cl)CC1", "O"
    )

    # Unpack the compressed tree so that it can be read as plain JSON
    with gzip.open(shared_datadir / "full_search_tree.json.gz", "rb") as packed, open(
        tmpdir / "full_search_tree.json", "wb"
    ) as unpacked:
        unpacked.write(packed.read())

    tree = SearchTree.from_json(tmpdir / "full_search_tree.json", default_config)
    nodes = list(tree.graph())

    def wrapper(scorer=None):
        analysis = TreeAnalysis(tree, scorer=scorer)
        return analysis, nodes

    return wrapper
def test_route_node_depth_from_analysis(default_config, mock_stock, shared_datadir):
    """Depths of molecules and reactions in a route built from the short tree."""
    mock_stock(default_config, Molecule(smiles="CC"), Molecule(smiles="CCCO"))
    tree = SearchTree.from_json(
        shared_datadir / "tree_without_repetition.json", default_config
    )
    route = ReactionTree.from_analysis(TreeAnalysis(tree))

    molecules = list(route.molecules())

    assert route.depth(molecules[0]) == 0
    assert route.depth(molecules[1]) == 2
    assert route.depth(molecules[2]) == 2

    reactions = list(route.reactions())
    assert route.depth(reactions[0]) == 1

    # The depth of a molecule is twice its "transform" attribute
    for molecule in route.molecules():
        transform = route.graph.nodes[molecule]["transform"]
        assert route.depth(molecule) == 2 * transform
def fresh_tree(default_config):
    """Return a new search tree using the given config and ``root_smiles`` set to None."""
    tree = SearchTree(config=default_config, root_smiles=None)
    return tree
class AiZynthFinder:
    """
    Public API to the aizynthfinder tool

    If instantiated with the path to a yaml file or dictionary of settings
    the stocks and policy networks are loaded directly.
    Otherwise, the user is responsible for loading them prior to
    executing the tree search.

    :ivar config: the configuration of the search
    :vartype config: Configuration
    :ivar expansion_policy: the expansion policy taken from the configuration
    :ivar filter_policy: the filter policy taken from the configuration
    :ivar stock: the stock
    :vartype stock: Stock
    :ivar scorers: the scorers taken from the configuration
    :ivar tree: the search tree
    :vartype tree: SearchTree
    :ivar analysis: the tree analysis
    :vartype analysis: TreeAnalysis
    :ivar routes: the top-ranked routes
    :vartype routes: RouteCollection
    :ivar search_stats: statistics of the latest search: time, number of iterations and if it returned first solution
    :vartype search_stats: dict

    :param configfile: the path to yaml file with configuration (has priority over configdict), defaults to None
    :type configfile: str, optional
    :param configdict: the config as a dictionary source, defaults to None
    :type configdict: dict, optional
    """

    def __init__(self, configfile=None, configdict=None):
        self._logger = logger()

        # The file takes priority over the dictionary
        if configfile:
            self.config = Configuration.from_file(configfile)
        elif configdict:
            self.config = Configuration.from_dict(configdict)
        else:
            self.config = Configuration()

        self.expansion_policy = self.config.expansion_policy
        self.filter_policy = self.config.filter_policy
        self.stock = self.config.stock
        self.scorers = self.config.scorers
        self.tree = None
        self._target_mol = None
        self.search_stats = {}
        self.routes = None
        self.analysis = None

    @property
    def target_smiles(self):
        """
        The SMILES representation of the molecule to predict routes on.

        :return: the SMILES
        :rtype: str
        """
        return self._target_mol.smiles

    @target_smiles.setter
    def target_smiles(self, smiles):
        self.target_mol = Molecule(smiles=smiles)

    @property
    def target_mol(self):
        """
        The molecule to predict routes on

        :return: the molecule
        :rtype: Molecule
        """
        return self._target_mol

    @target_mol.setter
    def target_mol(self, mol):
        # A new target invalidates the current search tree
        self.tree = None
        self._target_mol = mol

    def build_routes(self, min_nodes=5, scorer="state score"):
        """
        Build reaction routes

        This is necessary to call after the tree search has completed in order
        to extract results from the tree search.

        :param min_nodes: the minimum number of top-ranked nodes to consider, defaults to 5
        :type min_nodes: int, optional
        :param scorer: the name of the scorer used to score the nodes, looked up in ``scorers``
        :type scorer: str, optional
        :raises ValueError: if the search tree has not been created
        """
        # Without a tree there is nothing to analyse; say so explicitly
        if self.tree is None:
            raise ValueError("Search tree not initialized")

        self.analysis = TreeAnalysis(self.tree, scorer=self.scorers[scorer])
        self.routes = RouteCollection.from_analysis(self.analysis, min_nodes)

    def extract_statistics(self):
        """
        Extracts tree statistics as a dictionary

        :return: the target, the search time and the tree statistics, or an
                 empty dictionary if no analysis is available
        :rtype: dict
        """
        if not self.analysis:
            return {}
        stats = {"target": self.target_smiles, "search_time": self.search_stats["time"]}
        stats.update(self.analysis.tree_statistics())
        return stats

    def prepare_tree(self):
        """
        Setup the tree for searching

        Resets the exclusion list of the stock, optionally excludes the target
        from the stock, creates a new search tree rooted at the target and
        clears any previously stored analysis and routes.

        :raises ValueError: if no target molecule has been set
        """
        # Fail early with a clear message instead of an AttributeError
        # raised when the SMILES of a missing molecule is requested.
        if self.target_mol is None:
            raise ValueError("No target molecule set")

        self.stock.reset_exclusion_list()
        if self.config.exclude_target_from_stock and self.target_mol in self.stock:
            self.stock.exclude(self.target_mol)
            self._logger.debug("Excluding the target compound from the stock")

        # Pass the argument to the logger so formatting is deferred until needed
        self._logger.debug("Defining tree root: %s", self.target_smiles)
        self.tree = SearchTree(root_smiles=self.target_smiles, config=self.config)
        self.analysis = None
        self.routes = None

    @deprecated(version="2.1.0", reason="Not supported anymore")
    def run_from_json(self, params):
        """
        Run a search tree by reading settings from a JSON

        :param params: the parameters of the tree search
        :type params: dict
        :return: dictionary with all settings and top scored routes
        :rtype: dict
        """
        self.stock.select(params["stocks"])
        # Accept both the singular and the plural key
        self.expansion_policy.select(
            params.get("policy", params.get("policies", ""))
        )
        if "filter" in params:
            self.filter_policy.select(params["filter"])
        else:
            self.filter_policy.deselect()
        self.config.C = params["C"]
        self.config.max_transforms = params["max_transforms"]
        self.config.cutoff_cumulative = params["cutoff_cumulative"]
        self.config.cutoff_number = params["cutoff_number"]
        self.target_smiles = params["smiles"]
        self.config.return_first = params["return_first"]
        self.config.time_limit = params["time_limit"]
        self.config.iteration_limit = params["iteration_limit"]
        self.config.exclude_target_from_stock = params["exclude_target_from_stock"]
        self.config.filter_cutoff = params["filter_cutoff"]

        self.prepare_tree()
        self.tree_search()
        self.build_routes()
        if not params.get("score_trees", False):
            return {
                "request": self._get_settings(),
                "trees": self.routes.dicts,
            }

        self.routes.compute_scores(*self.scorers.objects())
        return {
            "request": self._get_settings(),
            "trees": self.routes.dict_with_scores(),
        }

    def tree_search(self, show_progress=False):
        """
        Perform the actual tree search

        The search stops when the time limit or the iteration limit of the
        configuration is reached, or, if ``return_first`` is set, as soon as
        a solved state is found.

        :param show_progress: if True, shows a progress bar
        :type show_progress: bool
        :return: the elapsed time in seconds
        :rtype: float
        """
        if not self.tree:
            self.prepare_tree()

        self.search_stats = {"returned_first": False, "iterations": 0}

        time0 = time.time()
        i = 1
        self._logger.debug("Starting search")
        time_past = time.time() - time0

        pbar = tqdm(total=self.config.iteration_limit) if show_progress else None
        try:
            while (
                time_past < self.config.time_limit
                and i <= self.config.iteration_limit
            ):
                if pbar is not None:
                    pbar.update(1)
                self.search_stats["iterations"] += 1

                leaf = self.tree.select_leaf()
                leaf.expand()
                # NOTE(review): relies on promising_child() making the leaf
                # terminal when it returns nothing - confirm, otherwise this
                # loop would not end.
                while not leaf.is_terminal():
                    child = leaf.promising_child()
                    if child:
                        child.expand()
                        leaf = child
                self.tree.backpropagate(leaf, leaf.state.score)
                if self.config.return_first and leaf.state.is_solved:
                    self._logger.debug("Found first solved route")
                    self.search_stats["returned_first"] = True
                    # Include the final iteration in the reported time
                    time_past = time.time() - time0
                    break
                i = i + 1
                time_past = time.time() - time0
        finally:
            # Close the progress bar even if the search raised
            if pbar is not None:
                pbar.close()

        self._logger.debug("Search completed")
        self.search_stats["time"] = time_past
        return time_past

    def _get_settings(self):
        """
        Get the current settings as a dictionary

        :return: the settings, using the same keys as accepted by ``run_from_json``
        :rtype: dict
        """
        # To be backward-compatible
        if len(self.expansion_policy.selection) == 1:
            policy_value = self.expansion_policy.selection[0]
            policy_key = "policy"
        else:
            policy_value = self.expansion_policy.selection
            policy_key = "policies"

        dict_ = {
            "stocks": self.stock.selection,
            policy_key: policy_value,
            "C": self.config.C,
            "max_transforms": self.config.max_transforms,
            "cutoff_cumulative": self.config.cutoff_cumulative,
            "cutoff_number": self.config.cutoff_number,
            "smiles": self.target_smiles,
            "return_first": self.config.return_first,
            "time_limit": self.config.time_limit,
            "iteration_limit": self.config.iteration_limit,
            "exclude_target_from_stock": self.config.exclude_target_from_stock,
            "filter_cutoff": self.config.filter_cutoff,
        }
        # The filter is only reported when one is selected
        if self.filter_policy.selection:
            dict_["filter"] = self.filter_policy.selection
        return dict_