Beispiel #1
0
def test_clustering_collection_timeout(load_reaction_tree):
    """Expect no labels and no clusters when clustering is given a zero timeout."""
    trees = []
    for index in range(3):
        tree_dict = load_reaction_tree("routes_for_clustering.json", index)
        trees.append(ReactionTree.from_dict(tree_dict))
    route_collection = RouteCollection(reaction_trees=trees)

    labels = route_collection.cluster(n_clusters=1, timeout=0)

    assert len(labels) == 0
    assert route_collection.clusters is None
Beispiel #2
0
def test_compute_new_score_for_trees(default_config, setup_linear_reaction_tree):
    """
    Check scoring of a collection that is created directly from a reaction tree.

    Before scoring, the collection is expected to have no node, a NaN main
    score and no stored scores. After ``compute_scores`` the main score is
    expected to still be NaN, while the results of the individual scorers
    are available in ``all_scores``.
    """
    rt = setup_linear_reaction_tree()
    routes = RouteCollection(reaction_trees=[rt])

    assert routes.nodes[0] is None
    # NaN is checked by value: an identity check against np.nan only passes
    # when the very same float object happens to be stored.
    assert np.isnan(routes.scores[0])
    assert routes.all_scores[0] == {}

    routes.compute_scores(StateScorer(default_config), NumberOfReactionsScorer())

    assert np.isnan(routes.scores[0])
    assert np.round(routes.all_scores[0]["state score"], 3) == 0.994
    assert routes.all_scores[0]["number of reactions"] == 2
Beispiel #3
0
def test_clustering_collection(load_reaction_tree):
    """Cluster three loaded routes and check how the trees are grouped."""
    trees = [
        ReactionTree.from_dict(
            load_reaction_tree("routes_for_clustering.json", index))
        for index in (0, 1, 2)
    ]
    route_collection = RouteCollection(reaction_trees=trees)

    route_collection.cluster(n_clusters=1)

    clusters = route_collection.clusters
    all_trees = route_collection.reaction_trees
    assert len(clusters) == 2
    # The first cluster is expected to hold the last two trees,
    # the second cluster only the first tree.
    assert clusters[0].reaction_trees == all_trees[1:3]
    assert clusters[1].reaction_trees == [all_trees[0]]
Beispiel #4
0
def test_create_combine_tree_dict_from_tree(mock_stock, default_config,
                                            load_reaction_tree,
                                            shared_datadir):
    """Combine the routes extracted from a stored search tree and compare with a reference."""
    stock_smiles = (
        "Nc1ccc(NC(=S)Nc2ccccc2)cc1",
        "Cc1ccc2nc3ccccc3c(Cl)c2c1",
        "Nc1ccc(N)cc1",
        "S=C=Nc1ccccc1",
        "Cc1ccc2nc3ccccc3c(N)c2c1",
        "Nc1ccc(Br)cc1",
    )
    mock_stock(default_config, *stock_smiles)
    tree_path = shared_datadir / "tree_for_clustering.json"
    search_tree = SearchTree.from_json(tree_path, default_config)
    tree_analysis = TreeAnalysis(search_tree)
    route_collection = RouteCollection.from_analysis(tree_analysis, 3)
    reference = load_reaction_tree("combined_example_tree.json")

    combined = route_collection.combined_reaction_trees().to_dict()

    # Walk down the combined tree, checking the shape level by level
    top_level = combined["children"]
    assert len(top_level) == 2
    assert top_level[0]["is_reaction"]
    assert len(top_level[0]["children"]) == 2
    second_branch = top_level[1]["children"]
    assert len(second_branch) == 2
    nested = second_branch[0]["children"]
    assert len(nested) == 2
    assert nested[0]["is_reaction"]
    assert combined == reference
Beispiel #5
0
def test_create_route_collection_andor_tree(setup_analysis_andor_tree):
    """Expect three routes, none of them tied to a node, from the and-or tree analysis."""
    tree_analysis = setup_analysis_andor_tree()

    route_collection = RouteCollection.from_analysis(tree_analysis)

    assert len(route_collection) == 3
    assert route_collection.nodes == [None] * 3
Beispiel #6
0
def test_create_route_collection_full(setup_analysis, mocker):
    """Build a route collection from an analysis and check scores, sizes and lazy content."""
    analysis, _unused = setup_analysis()

    collection = RouteCollection.from_analysis(analysis, 5)

    assert len(collection) == 7
    # Check a few of the routes: (rounded score, size of the reaction tree graph)
    expected_routes = ((0.994, 8), (0.681, 5))
    for index, (score, graph_size) in enumerate(expected_routes):
        assert np.round(collection.scores[index], 3) == score
        assert len(collection.reaction_trees[index].graph) == graph_size

    top_route = collection[0]
    for key in ("dict", "json", "image"):
        assert key not in top_route

    patch_targets = (
        "aizynthfinder.analysis.ReactionTree.to_dict",
        "aizynthfinder.analysis.json.dumps",
        "aizynthfinder.utils.image.GraphvizReactionGraph.to_image",
        "aizynthfinder.utils.image.GraphvizReactionGraph.add_molecule",
    )
    for target in patch_targets:
        mocker.patch(target)

    # Just see that the code does not crash, does not verify content
    assert len(collection.images) == 7
    assert len(collection.dicts) == 7
    assert len(collection.jsons) == 7
Beispiel #7
0
def test_compute_new_score_for_trees(default_config, mock_stock,
                                     load_reaction_tree):
    """
    Check scoring of a collection that is created directly from a reaction tree.

    Before scoring, the collection is expected to have no node, a NaN main
    score and no stored scores. After ``compute_scores`` the main score is
    expected to still be NaN, while the results of the individual scorers
    are available in ``all_scores``.
    """
    mock_stock(default_config, "N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1",
               "CN1CCC(Cl)CC1", "O")
    rt = ReactionTree.from_dict(load_reaction_tree("sample_reaction.json"))
    routes = RouteCollection(reaction_trees=[rt])

    assert routes.nodes[0] is None
    # NaN is checked by value: an identity check against np.nan only passes
    # when the very same float object happens to be stored.
    assert np.isnan(routes.scores[0])
    assert routes.all_scores[0] == {}

    routes.compute_scores(StateScorer(default_config),
                          NumberOfReactionsScorer())

    assert np.isnan(routes.scores[0])
    assert np.round(routes.all_scores[0]["state score"], 3) == 0.994
    assert routes.all_scores[0]["number of reactions"] == 2
Beispiel #8
0
    def prepare_tree(self) -> None:
        """
        Reset the search state and create a fresh search tree

        :raises ValueError: if no target molecule has been set
        """
        target = self.target_mol
        if not target:
            raise ValueError("No target molecule set")

        self.stock.reset_exclusion_list()
        # Only hide the target from the stock when the configuration asks for it
        if self.config.exclude_target_from_stock:
            if target in self.stock:
                self.stock.exclude(target)
                self._logger.debug("Excluding the target compound from the stock")

        self._setup_search_tree()
        # Results of any previous search are no longer valid
        self.analysis = None
        self.routes = RouteCollection([])
Beispiel #9
0
def test_create_combine_tree_to_visjs(load_reaction_tree, tmpdir):
    """Export combined routes to a tar archive and check the files it contains."""
    trees = [
        ReactionTree.from_dict(load_reaction_tree("routes_for_clustering.json", index))
        for index in range(3)
    ]
    route_collection = RouteCollection(reaction_trees=trees)
    tar_filename = str(tmpdir / "routes.tar")
    combined_tree = route_collection.combined_reaction_trees()

    combined_tree.to_visjs_page(tar_filename)

    assert os.path.exists(tar_filename)
    with TarFile(tar_filename) as archive:
        member_names = archive.getnames()
        assert "./route.html" in member_names
        png_count = sum(1 for name in member_names if name.endswith(".png"))
        assert png_count == 8
Beispiel #10
0
def test_dict_with_scores(setup_analysis):
    """Expect ``dict_with_scores`` to add scores without touching the plain dictionaries."""
    analysis, _unused = setup_analysis()
    collection = RouteCollection.from_analysis(analysis, 5)

    scored_dicts = collection.dict_with_scores()

    assert "scores" not in collection.dicts[0]
    first_scored = scored_dicts[0]
    assert "scores" in first_scored
    assert np.round(first_scored["scores"]["state score"], 3) == 0.994
Beispiel #11
0
def test_create_combine_tree_dict_from_json(load_reaction_tree):
    """Combine three loaded routes and compare the result with a reference dictionary."""
    trees = [
        ReactionTree.from_dict(load_reaction_tree("routes_for_clustering.json", index))
        for index in range(3)
    ]
    route_collection = RouteCollection(reaction_trees=trees)
    reference = load_reaction_tree("combined_example_tree.json")

    combined = route_collection.combined_reaction_trees().to_dict()

    # Walk down the combined tree, checking the shape level by level
    top_level = combined["children"]
    assert len(top_level) == 2
    assert top_level[0]["is_reaction"]
    assert len(top_level[0]["children"]) == 2
    second_branch = top_level[1]["children"]
    assert len(second_branch) == 2
    nested = second_branch[1]["children"]
    assert len(nested) == 2
    assert nested[0]["is_reaction"]
    assert combined == reference
Beispiel #12
0
def test_distance_collection(load_reaction_tree):
    """Compare distance matrices computed with different settings."""
    trees = []
    for index in range(3):
        tree_dict = load_reaction_tree("routes_for_clustering.json", index)
        trees.append(ReactionTree.from_dict(tree_dict))
    route_collection = RouteCollection(reaction_trees=trees)

    default_matrix = route_collection.distance_matrix()
    recreated_matrix = route_collection.distance_matrix(recreate=True)

    assert (default_matrix - recreated_matrix).sum() == 0

    molecule_matrix = route_collection.distance_matrix(content="molecules")

    assert (default_matrix - molecule_matrix).sum() != 0
    assert len(molecule_matrix) == 3
    # (row, column, expected distance)
    expected_distances = (
        (0, 1, 2.6522),
        (0, 2, 3.0779),
        (2, 1, 0.7483),
    )
    for row, column, distance in expected_distances:
        assert pytest.approx(molecule_matrix[row, column], abs=1e-2) == distance
Beispiel #13
0
    def __init__(self, configfile: str = None, configdict: StrDict = None) -> None:
        """
        Create the finder and set up its configuration

        The configuration is read from ``configfile`` when given, otherwise
        from ``configdict`` when given, and otherwise a default one is created.

        :param configfile: the path to a configuration file, defaults to None
        :param configdict: the configuration as a dictionary, defaults to None
        """
        self._logger = logger()

        if not (configfile or configdict):
            config = Configuration()
        elif configfile:
            config = Configuration.from_file(configfile)
        else:
            config = Configuration.from_dict(configdict)
        self.config = config

        # Shortcuts to objects owned by the configuration
        self.expansion_policy = config.expansion_policy
        self.filter_policy = config.filter_policy
        self.stock = config.stock
        self.scorers = config.scorers

        # Nothing has been searched yet
        self.tree: Optional[SearchTree] = None
        self._target_mol: Optional[Molecule] = None
        self.search_stats: StrDict = {}
        self.routes = RouteCollection([])
        self.analysis: Optional[TreeAnalysis] = None
Beispiel #14
0
def test_create_clustering_gui(mocker, load_reaction_tree):
    """Expect the patched ``display`` function to be called when the GUI is created."""
    trees = []
    for index in range(3):
        tree_dict = load_reaction_tree("routes_for_clustering.json", index)
        trees.append(ReactionTree.from_dict(tree_dict))
    route_collection = RouteCollection(reaction_trees=trees)
    display_mock = mocker.patch(
        "aizynthfinder.interfaces.gui.clustering.display")

    ClusteringGui(route_collection)

    display_mock.assert_called()
    def build_routes(self, min_nodes=5):
        """
        Analyse the search tree and extract the routes from it

        Call this once the tree search has finished; it stores the tree
        analysis in ``analysis`` and the extracted routes in ``routes``.

        :param min_nodes: the minimum number of top-ranked nodes to consider, defaults to 5
        :type min_nodes: int, optional
        """
        tree_analysis = TreeAnalysis(self.tree)
        self.analysis = tree_analysis
        self.routes = RouteCollection.from_analysis(tree_analysis, min_nodes)
Beispiel #16
0
def test_rescore_collection(setup_analysis):
    """Rescore the routes with the number-of-reactions scorer and check the first two routes."""
    analysis, _unused = setup_analysis()
    collection = RouteCollection.from_analysis(analysis, 5)

    collection.rescore(NumberOfReactionsScorer())

    # Expected rounded state score of the first two routes after rescoring
    for index, state_score in enumerate((0.681, 0.523)):
        stored_scores = collection.all_scores[index]
        assert collection.scores[index] == 1
        assert np.round(stored_scores["state score"], 3) == state_score
        assert stored_scores["number of reactions"] == 1
Beispiel #17
0
def test_compute_new_score(setup_analysis):
    """Compute an additional score and check that the main score is left unchanged."""
    analysis, _unused = setup_analysis()
    collection = RouteCollection.from_analysis(analysis, 5)

    collection.compute_scores(NumberOfReactionsScorer())

    # (rounded state score, number of reactions) of the first two routes
    expected_routes = ((0.994, 2), (0.681, 1))
    for index, (state_score, n_reactions) in enumerate(expected_routes):
        stored_scores = collection.all_scores[index]
        assert np.round(collection.scores[index], 3) == state_score
        assert np.round(stored_scores["state score"], 3) == state_score
        assert stored_scores["number of reactions"] == n_reactions
    def build_routes(self, min_nodes=5, scorer="state score"):
        """
        Analyse the search tree and extract the routes from it

        Call this once the tree search has finished; it stores the tree
        analysis in ``analysis`` and the extracted routes in ``routes``.

        :param min_nodes: the minimum number of top-ranked nodes to consider, defaults to 5
        :type min_nodes: int, optional
        :param scorer: the key used to look up the scoring object in ``scorers``
        :type scorer: str, optional
        """
        node_scorer = self.scorers[scorer]
        tree_analysis = TreeAnalysis(self.tree, scorer=node_scorer)
        self.analysis = tree_analysis
        self.routes = RouteCollection.from_analysis(tree_analysis, min_nodes)
Beispiel #19
0
def test_rescore_collection_for_trees(default_config, setup_linear_reaction_tree):
    """Rescore a tree-based collection and check the main and the stored scores."""
    reaction_tree = setup_linear_reaction_tree()
    collection = RouteCollection(reaction_trees=[reaction_tree])
    collection.compute_scores(StateScorer(default_config))

    collection.rescore(NumberOfReactionsScorer())

    stored_scores = collection.all_scores[0]
    assert collection.scores[0] == 2
    assert np.round(stored_scores["state score"], 3) == 0.994
    assert stored_scores["number of reactions"] == 2
Beispiel #20
0
    def build_routes(self, min_nodes: int = 5, scorer: str = "state score") -> None:
        """
        Analyse the search tree and extract the routes from it

        Call this once the tree search has finished; it stores the tree
        analysis in ``analysis`` and the extracted routes in ``routes``.

        :param min_nodes: the minimum number of top-ranked nodes to consider, defaults to 5
        :param scorer: the key used to look up the scoring object in ``scorers``
        :raises ValueError: if the search tree not initialized
        """
        tree = self.tree
        if not tree:
            raise ValueError("Search tree not initialized")

        node_scorer = self.scorers[scorer]
        tree_analysis = TreeAnalysis(tree, scorer=node_scorer)
        self.analysis = tree_analysis
        self.routes = RouteCollection.from_analysis(tree_analysis, min_nodes)
    def build_routes(self,
                     selection: RouteSelectionArguments = None,
                     scorer: str = "state score") -> None:
        """
        Analyse the search tree and extract the routes from it

        Call this once the tree search has finished; it stores the tree
        analysis in ``analysis`` and the extracted routes in ``routes``.

        :param selection: the selection criteria for the routes
        :param scorer: the key used to look up the scoring object in ``scorers``
        :raises ValueError: if the search tree not initialized
        """
        tree = self.tree
        if not tree:
            raise ValueError("Search tree not initialized")

        node_scorer = self.scorers[scorer]
        tree_analysis = TreeAnalysis(tree, scorer=node_scorer)
        self.analysis = tree_analysis
        self.routes = RouteCollection.from_analysis(tree_analysis, selection)
Beispiel #22
0
def test_create_route_collection(setup_complete_tree, mocker):
    """Build a route collection from a complete tree and inspect the top-ranked route."""
    tree, nodes = setup_complete_tree
    tree_analysis = TreeAnalysis(tree)
    for target in ("aizynthfinder.analysis.ReactionTree.to_dict",
                   "aizynthfinder.analysis.json.dumps"):
        mocker.patch(target)

    collection = RouteCollection.from_analysis(tree_analysis, 5)

    assert len(collection) == 3
    top_route = collection[0]
    assert top_route["score"] == 0.99
    assert top_route["node"] is nodes[2]
    reaction_count = sum(
        1 for node in top_route["reaction_tree"].graph
        if isinstance(node, Reaction)
    )
    assert reaction_count == 2

    # Just see that the code does not crash, does not verify content
    assert len(collection.images) == 3
    assert len(collection.dicts) == 3
    assert len(collection.jsons) == 3
Beispiel #23
0
class AiZynthFinder:
    """
    Public API to the aizynthfinder tool

    If instantiated with the path to a yaml file or dictionary of settings
    the stocks and policy networks are loaded directly.
    Otherwise, the user is responsible for loading them prior to
    executing the tree search.

    :ivar config: the configuration of the search
    :ivar expansion_policy: the expansion policy model
    :ivar filter_policy: the filter policy model
    :ivar stock: the stock
    :ivar scorers: the loaded scores
    :ivar tree: the search tree
    :ivar analysis: the tree analysis
    :ivar routes: the top-ranked routes
    :ivar search_stats: statistics of the latest search

    :param configfile: the path to yaml file with configuration (has priority over configdict), defaults to None
    :param configdict: the config as a dictionary source, defaults to None
    """
    def __init__(self,
                 configfile: str = None,
                 configdict: StrDict = None) -> None:
        """
        Create the finder and set up its configuration

        The configuration is read from ``configfile`` when given, otherwise
        from ``configdict`` when given, and otherwise a default one is created.
        """
        self._logger = logger()

        if not (configfile or configdict):
            config = Configuration()
        elif configfile:
            config = Configuration.from_file(configfile)
        else:
            config = Configuration.from_dict(configdict)
        self.config = config

        # Shortcuts to objects owned by the configuration
        self.expansion_policy = config.expansion_policy
        self.filter_policy = config.filter_policy
        self.stock = config.stock
        self.scorers = config.scorers

        # Nothing has been searched yet
        self.tree: Optional[Union[MctsSearchTree, AndOrSearchTreeBase]] = None
        self._target_mol: Optional[Molecule] = None
        self.search_stats: StrDict = {}
        self.routes = RouteCollection([])
        self.analysis: Optional[TreeAnalysis] = None

    @property
    def target_smiles(self) -> str:
        """The SMILES of the target molecule, or an empty string if no target is set."""
        target = self._target_mol
        return target.smiles if target else ""

    @target_smiles.setter
    def target_smiles(self, smiles: str) -> None:
        # Delegate to the ``target_mol`` setter with a molecule built from the SMILES
        molecule = Molecule(smiles=smiles)
        self.target_mol = molecule

    @property
    def target_mol(self) -> Optional[Molecule]:
        """The target molecule of the route prediction, if one has been set"""
        return self._target_mol

    @target_mol.setter
    def target_mol(self, mol: Molecule) -> None:
        self._target_mol = mol
        # Drop the existing tree whenever a target is assigned
        self.tree = None

    def build_routes(self,
                     min_nodes: int = 5,
                     scorer: str = "state score") -> None:
        """
        Analyse the search tree and extract the routes from it

        Call this once the tree search has finished; it stores the tree
        analysis in ``analysis`` and the extracted routes in ``routes``.

        :param min_nodes: the minimum number of top-ranked nodes to consider, defaults to 5
        :param scorer: the key used to look up the scoring object in ``scorers``
        :raises ValueError: if the search tree not initialized
        """
        tree = self.tree
        if not tree:
            raise ValueError("Search tree not initialized")

        node_scorer = self.scorers[scorer]
        tree_analysis = TreeAnalysis(tree, scorer=node_scorer)
        self.analysis = tree_analysis
        self.routes = RouteCollection.from_analysis(tree_analysis, min_nodes)

    def extract_statistics(self) -> StrDict:
        """
        Collect statistics of the latest search as a dictionary

        :return: the target, timings and tree statistics, or an empty
            dictionary if no analysis is available
        """
        analysis = self.analysis
        if not analysis:
            return {}

        search_stats = self.search_stats
        stats = dict(target=self.target_smiles, search_time=search_stats["time"])
        # These entries default to 0 when they are absent from the search statistics
        for key in ("first_solution_time", "first_solution_iteration"):
            stats[key] = search_stats.get(key, 0)
        stats.update(analysis.tree_statistics())
        return stats

    def prepare_tree(self) -> None:
        """
        Reset the search state and create a fresh search tree

        :raises ValueError: if no target molecule has been set
        """
        target = self.target_mol
        if not target:
            raise ValueError("No target molecule set")

        self.stock.reset_exclusion_list()
        # Only hide the target from the stock when the configuration asks for it
        if self.config.exclude_target_from_stock:
            if target in self.stock:
                self.stock.exclude(target)
                self._logger.debug("Excluding the target compound from the stock")

        self._setup_search_tree()
        # Results of any previous search are no longer valid
        self.analysis = None
        self.routes = RouteCollection([])

    @deprecated(version="2.1.0", reason="Not supported anymore")
    def run_from_json(self, params: StrDict) -> StrDict:
        """
        Configure and run a complete search from a dictionary of parameters

        The stock, the policies and the configuration are set from ``params``,
        then the tree is prepared and searched and the routes are built.

        :param params: the parameters of the tree search
        :return: dictionary with the settings used and the extracted routes
        """
        self.stock.select(params["stocks"])
        fallback_policy = params.get("policies", "")
        self.expansion_policy.select(params.get("policy", fallback_policy))
        if "filter" not in params:
            self.filter_policy.deselect()
        else:
            self.filter_policy.select(params["filter"])

        for name in ("C", "max_transforms", "cutoff_cumulative",
                     "cutoff_number"):
            setattr(self.config, name, params[name])
        # Assigned through the property setter, between the two groups of settings
        self.target_smiles = params["smiles"]
        for name in ("return_first", "time_limit", "iteration_limit",
                     "exclude_target_from_stock", "filter_cutoff"):
            setattr(self.config, name, params[name])

        self.prepare_tree()
        self.tree_search()
        self.build_routes()

        score_trees = params.get("score_trees", False)
        if score_trees:
            self.routes.compute_scores(*self.scorers.objects())
        request = self._get_settings()
        if score_trees:
            trees = self.routes.dict_with_scores()
        else:
            trees = self.routes.dicts
        return {"request": request, "trees": trees}

    def tree_search(self, show_progress: bool = False) -> float:
        """
        Perform the actual tree search

        The tree is prepared first if it does not exist. Iterations are run
        until the time limit or the iteration limit of the configuration is
        reached, the tree signals ``StopIteration``, or a solution is found
        while ``return_first`` is set. Statistics are stored in ``search_stats``.

        :param show_progress: if True, shows a progress bar
        :return: the time past in seconds
        """
        if not self.tree:
            self.prepare_tree()
        assert (self.tree is not None
                )  # This is for type checking, prepare_tree is creating it.
        self.search_stats = {"returned_first": False, "iterations": 0}

        time0 = time.time()
        i = 1
        self._logger.debug("Starting search")
        time_past = time.time() - time0

        # Compared against None explicitly below, so the checks do not depend
        # on the truthiness of the progress bar object.
        pbar = None
        if show_progress:
            pbar = tqdm(total=self.config.iteration_limit, leave=False)

        try:
            while (time_past < self.config.time_limit
                   and i <= self.config.iteration_limit):
                if pbar is not None:
                    pbar.update(1)
                self.search_stats["iterations"] += 1

                try:
                    is_solved = self.tree.one_iteration()
                except StopIteration:
                    break

                # Only the first solution is recorded
                if is_solved and "first_solution_time" not in self.search_stats:
                    self.search_stats["first_solution_time"] = time.time() - time0
                    self.search_stats["first_solution_iteration"] = i

                if self.config.return_first and is_solved:
                    self._logger.debug("Found first solved route")
                    self.search_stats["returned_first"] = True
                    break
                i = i + 1
                time_past = time.time() - time0
        finally:
            # Close the bar even when an iteration raises unexpectedly
            if pbar is not None:
                pbar.close()

        time_past = time.time() - time0
        self._logger.debug("Search completed")
        self.search_stats["time"] = time_past
        return time_past

    def _get_settings(self) -> StrDict:
        """
        Collect the current settings in a dictionary

        :return: the stock and policy selections, the configuration values
            and the target SMILES
        """
        # To be backward-compatible, a single selected policy is reported
        # under "policy" and anything else under "policies"
        selection = self.expansion_policy.selection
        if not selection or len(selection) != 1:
            policy_key, policy_value = "policies", selection  # type: ignore
        else:
            policy_key, policy_value = "policy", selection[0]

        settings = {"stocks": self.stock.selection, policy_key: policy_value}
        for name in ("C", "max_transforms", "cutoff_cumulative",
                     "cutoff_number"):
            settings[name] = getattr(self.config, name)
        settings["smiles"] = self.target_smiles
        for name in ("return_first", "time_limit", "iteration_limit",
                     "exclude_target_from_stock", "filter_cutoff"):
            settings[name] = getattr(self.config, name)

        filter_selection = self.filter_policy.selection
        if filter_selection:
            settings["filter"] = filter_selection
        return settings

    def _setup_search_tree(self):
        """
        Create the search tree for the current target

        The configured search algorithm is either "mcts" (case-insensitive)
        or the dotted path ``module.ClassName`` of a class that is imported
        dynamically and instantiated with the target SMILES and the configuration.

        :raises ValueError: if the search algorithm is not a dotted path, the
            module cannot be imported or the class is not found in the module
        """
        self._logger.debug("Defining tree root: %s" % self.target_smiles)
        algorithm = self.config.search_algorithm
        if algorithm.lower() == "mcts":
            self.tree = MctsSearchTree(root_smiles=self.target_smiles,
                                       config=self.config)
            return

        # Split on the last dot; without a dot there is no module to import
        module_name, separator, cls_name = algorithm.rpartition(".")
        if not separator:
            raise ValueError(
                f"Search algorithm {algorithm!r} is neither 'mcts' nor a "
                "dotted path to a class")

        try:
            module_obj = importlib.import_module(module_name)
        except ImportError as err:
            # Keep the original import failure as the cause for debugging
            raise ValueError(
                f"Could not import module {module_name}") from err

        if not hasattr(module_obj, cls_name):
            raise ValueError(
                f"Could not identify class {cls_name} in module")

        self.tree = getattr(module_obj, cls_name)(
            root_smiles=self.target_smiles, config=self.config)