Esempio n. 1
0
def test_find_repetetive_patterns_created_tree_no_patterns(
        default_config, mock_stock, shared_datadir):
    """
    Routes built from trees without repetition must not be flagged as
    having repeating patterns; the short tree must also have no hidden nodes.
    """
    mock_stock(default_config, Molecule(smiles="CC"), Molecule(smiles="CCCO"))

    # Short tree first (3 nodes, 1 reaction)
    short_tree_path = shared_datadir / "tree_without_repetition.json"
    short_tree = SearchTree.from_json(short_tree_path, default_config)
    route = ReactionTree.from_analysis(TreeAnalysis(short_tree))

    assert not route.has_repeating_patterns
    # Collect every graph node that carries a truthy "hide" attribute
    hidden = []
    for node in route.graph:
        if route.graph.nodes[node].get("hide", False):
            hidden.append(node)
    assert len(hidden) == 0

    # Then something longer
    long_tree_path = shared_datadir / "tree_without_repetition_longer.json"
    long_tree = SearchTree.from_json(long_tree_path, default_config)
    route = ReactionTree.from_analysis(TreeAnalysis(long_tree))

    assert not route.has_repeating_patterns
Esempio n. 2
0
def test_find_repetetive_patterns_created_tree(default_config, mock_stock,
                                               shared_datadir):
    """
    Routes built from trees with repetitive units must be flagged as having
    repeating patterns and must hide the expected number of nodes.
    """

    def route_from_file(filename):
        # Load the serialized search tree and turn it into a reaction tree
        tree = SearchTree.from_json(shared_datadir / filename, default_config)
        return ReactionTree.from_analysis(TreeAnalysis(tree))

    def hidden_nodes_of(route):
        # Nodes whose "hide" attribute is set to a truthy value
        return [
            node for node in route.graph
            if route.graph.nodes[node].get("hide", False)
        ]

    mock_stock(default_config, Molecule(smiles="CC"), Molecule(smiles="C"))

    # One with 2 repetitive units
    route = route_from_file("tree_with_repetition.json")

    assert route.has_repeating_patterns
    assert len(hidden_nodes_of(route)) == 5

    # One with 3 repetitive units
    route = route_from_file("tree_with_3_repetitions.json")

    assert route.has_repeating_patterns
    assert len(hidden_nodes_of(route)) == 10
Esempio n. 3
0
    def prepare_tree(self):
        """
        Setup the tree for searching

        Resets the exclusion list of the stock, excludes the target molecule
        from the stock when the configuration asks for it and the target is
        in the stock, creates a new search tree rooted at the target SMILES
        and clears any previously stored analysis and routes.
        """
        self.stock.reset_exclusion_list()
        if self.config.exclude_target_from_stock and self.target_mol in self.stock:
            self.stock.exclude(self.target_mol)
            self._logger.debug("Excluding the target compound from the stock")

        # Hand the SMILES to the logger as an argument so that the message is
        # only formatted when the debug record is actually emitted
        self._logger.debug("Defining tree root: %s", self.target_smiles)
        self.tree = SearchTree(root_smiles=self.target_smiles, config=self.config)
        # Results from an earlier search are no longer valid for the new tree
        self.analysis = None
        self.routes = None
Esempio n. 4
0
def test_create_combine_tree_dict_from_tree(mock_stock, default_config,
                                            load_reaction_tree,
                                            shared_datadir):
    """
    Combine the routes extracted from a search tree and compare the
    resulting dictionary with a stored reference.
    """
    stock_smiles = [
        "Nc1ccc(NC(=S)Nc2ccccc2)cc1",
        "Cc1ccc2nc3ccccc3c(Cl)c2c1",
        "Nc1ccc(N)cc1",
        "S=C=Nc1ccccc1",
        "Cc1ccc2nc3ccccc3c(N)c2c1",
        "Nc1ccc(Br)cc1",
    ]
    mock_stock(default_config, *stock_smiles)
    tree_path = shared_datadir / "tree_for_clustering.json"
    tree = SearchTree.from_json(tree_path, default_config)
    routes = RouteCollection.from_analysis(TreeAnalysis(tree), 3)
    reference = load_reaction_tree("combined_example_tree.json")

    combined = routes.combined_reaction_trees().to_dict()

    # Spot-check the structure of the top levels before the full comparison
    top_children = combined["children"]
    assert len(top_children) == 2
    assert top_children[0]["is_reaction"]
    assert len(top_children[0]["children"]) == 2
    second_branch = top_children[1]["children"]
    assert len(second_branch) == 2
    assert len(second_branch[0]["children"]) == 2
    assert second_branch[0]["children"][0]["is_reaction"]
    assert combined == reference
Esempio n. 5
0
def test_number_of_reaction_scorer_node(shared_datadir, default_config):
    """The second node of the loaded tree is scored as one reaction."""
    tree_path = shared_datadir / "tree_without_repetition.json"
    tree = SearchTree.from_json(tree_path, default_config)
    graph_nodes = list(tree.graph())
    second_node = graph_nodes[1]
    count_reactions = NumberOfReactionsScorer()

    assert count_reactions(second_node) == 1
Esempio n. 6
0
def test_template_occurence_scorer_no_metadata(shared_datadir, default_config):
    """Without any metadata set by the test, the second node scores zero."""
    tree_path = shared_datadir / "tree_without_repetition.json"
    tree = SearchTree.from_json(tree_path, default_config)
    graph_nodes = list(tree.graph())
    second_node = graph_nodes[1]
    occurence_scorer = AverageTemplateOccurenceScorer()

    assert occurence_scorer(second_node) == 0
Esempio n. 7
0
def test_scoring_branch_mcts_tree_in_stock(shared_datadir, default_config,
                                           mock_stock):
    """
    Score the last node of a branched tree with several scorers when the
    listed molecules have been put in the mocked stock.
    """
    in_stock = (
        "CC(C)(C)CO",
        "CC(C)(C)OC(=O)N(CCCl)CCCl",
        "N#CCc1cccc(O)c1F",
        "O=[N+]([O-])c1ccccc1F",
        "O=C1CCC(=O)N1Br",
        "O=C=Nc1csc(C(F)(F)F)n1",
        "CCC[Sn](Cl)(CCC)CCC",
        "COc1ccc2ncsc2c1",
    )
    mock_stock(default_config, *in_stock)
    tree_path = shared_datadir / "tree_with_branching.json"
    tree = SearchTree.from_json(tree_path, default_config)
    last_node = list(tree.graph())[-1]

    state_score = StateScorer(default_config)(last_node)
    assert pytest.approx(state_score, abs=1e-3) == 0.950
    assert NumberOfReactionsScorer()(last_node) == 14
    assert NumberOfPrecursorsScorer(default_config)(last_node) == 8
    assert NumberOfPrecursorsInStockScorer(default_config)(last_node) == 8
    assert PriceSumScorer(default_config)(last_node) == 8
    route_cost = RouteCostScorer(default_config)(last_node)
    assert pytest.approx(route_cost, abs=1e-3) == 77.4797
Esempio n. 8
0
    def prepare_tree(self) -> None:
        """
        Setup the tree for searching

        Resets the exclusion list of the stock, excludes the target molecule
        from the stock when the configuration asks for it and the target is
        in the stock, creates a new search tree rooted at the target SMILES,
        clears the stored analysis and replaces the routes with an empty
        collection.

        :raises ValueError: if the target molecule was not set
        """
        if not self.target_mol:
            raise ValueError("No target molecule set")

        self.stock.reset_exclusion_list()
        if self.config.exclude_target_from_stock and self.target_mol in self.stock:
            self.stock.exclude(self.target_mol)
            self._logger.debug("Excluding the target compound from the stock")

        # Hand the SMILES to the logger as an argument so that the message is
        # only formatted when the debug record is actually emitted
        self._logger.debug("Defining tree root: %s", self.target_smiles)
        self.tree = SearchTree(root_smiles=self.target_smiles, config=self.config)
        # Results from an earlier search are no longer valid for the new tree
        self.analysis = None
        self.routes = RouteCollection([])
Esempio n. 9
0
def test_scoring_branched_mcts_tree(shared_datadir, default_config):
    """Score the last node of a branched tree without mocking any stock."""
    tree_path = shared_datadir / "tree_with_branching.json"
    tree = SearchTree.from_json(tree_path, default_config)
    last_node = list(tree.graph())[-1]

    state_score = StateScorer()(last_node)
    assert pytest.approx(state_score, abs=1e-6) == 0.00012363
    assert NumberOfReactionsScorer()(last_node) == 14
    assert NumberOfPrecursorsScorer(default_config)(last_node) == 8
    assert NumberOfPrecursorsInStockScorer(default_config)(last_node) == 0
Esempio n. 10
0
def test_template_occurence_scorer(shared_datadir, default_config):
    """
    After setting ``library_occurence`` to 5 on the action between the first
    two nodes, the first node scores 0 and the second node scores 5.
    """
    tree_path = shared_datadir / "tree_without_repetition.json"
    tree = SearchTree.from_json(tree_path, default_config)
    graph_nodes = list(tree.graph())
    first_node = graph_nodes[0]
    second_node = graph_nodes[1]
    # Inject the metadata on the action stored for the first -> second edge
    action = first_node[second_node]["action"]
    action.metadata["library_occurence"] = 5
    occurence_scorer = AverageTemplateOccurenceScorer()

    assert occurence_scorer(first_node) == 0
    assert occurence_scorer(second_node) == 5
Esempio n. 11
0
def test_scorers_one_mcts_node(default_config):
    """Apply each scorer to the root node of a freshly created tree."""
    target = "CCCCOc1ccc(CC(=O)N(C)O)cc1"
    root_node = SearchTree(default_config, root_smiles=target).root

    state_score = StateScorer(default_config)(root_node)
    assert pytest.approx(state_score, abs=1e-3) == 0.0497
    assert NumberOfReactionsScorer(default_config)(root_node) == 0
    assert NumberOfPrecursorsScorer(default_config)(root_node) == 1
    assert NumberOfPrecursorsInStockScorer(default_config)(root_node) == 0
    assert PriceSumScorer(default_config)(root_node) == 10
    assert RouteCostScorer(default_config)(root_node) == 10
def test_serialize_deserialize_tree(
    fresh_tree,
    generate_root,
    simple_actions,
    mock_expansion_policy,
    default_config,
    mocker,
    tmpdir,
):
    """
    Round-trip a small search tree through serialization.

    ``json.dump`` and ``json.load`` in ``aizynthfinder.mcts.mcts`` are patched,
    so this test itself writes or reads no real JSON; the dictionary
    handed to ``json.dump`` is checked and then fed back through
    ``json.load`` when the tree is re-created with ``SearchTree.from_json``.
    """
    root = generate_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    fresh_tree.root = root
    # The returned actions and priors are not needed here; the call is kept
    # because it was made before expanding -- presumably it installs the
    # mocked policy for this molecule (verify against the fixture)
    mock_expansion_policy(root.state.mols[0])
    root.expand()
    child = root.promising_child()
    mocked_json_dump = mocker.patch("aizynthfinder.mcts.mcts.json.dump")
    serializer = MoleculeSerializer()
    filename = str(tmpdir / "dummy.json")

    # Test serialization

    fresh_tree.serialize(filename)

    expected_dict = {
        "tree": root.serialize(serializer),
        "molecules": serializer.store
    }
    mocked_json_dump.assert_called_once_with(expected_dict,
                                             mocker.ANY,
                                             indent=mocker.ANY)

    # Test deserialization

    mocker.patch("aizynthfinder.mcts.mcts.json.load",
                 return_value=expected_dict)

    new_tree = SearchTree.from_json(filename, default_config)
    root_new = new_tree.root
    assert len(root_new.children()) == 1

    new_child = root_new.children()[0]
    view_keys = ("values", "priors", "visitations")
    for key in view_keys:
        assert root_new.children_view()[key] == root.children_view()[key]
    assert root_new.is_expanded
    for key in view_keys:
        assert new_child.children_view()[key] == child.children_view()[key]
    assert not new_child.is_expanded
    assert str(root_new.state) == str(root.state)
    assert str(new_child.state) == str(child.state)
Esempio n. 13
0
def test_sort(shared_datadir, default_config, mock_stock):
    """Sorting with the state scorer puts the second node before the first."""
    mock_stock(default_config, "CCCO", "CC")
    tree_path = shared_datadir / "tree_without_repetition.json"
    tree = SearchTree.from_json(tree_path, default_config)
    graph_nodes = list(tree.graph())
    state_scorer = StateScorer(default_config)

    ordered_nodes, ordered_scores, _ = state_scorer.sort(graph_nodes)

    # Compare the scores rounded to four decimals
    rounded_scores = [np.round(value, 4) for value in ordered_scores]
    assert rounded_scores == [0.9976, 0.0491]
    assert ordered_nodes == [graph_nodes[1], graph_nodes[0]]
Esempio n. 14
0
def setup_analysis(default_config, shared_datadir, tmpdir, mock_stock):
    """
    Decompress the stored full search tree into ``tmpdir``, load it and
    return a factory that creates a ``TreeAnalysis`` for that tree.

    The returned callable takes an optional ``scorer`` and gives back a tuple
    of the analysis object and the list of nodes of the loaded tree.
    """
    stock_smiles = (
        "N#Cc1cccc(N)c1F",
        "O=C(Cl)c1ccc(F)cc1",
        "CN1CCC(Cl)CC1",
        "O",
    )
    mock_stock(default_config, *stock_smiles)

    compressed_path = shared_datadir / "full_search_tree.json.gz"
    plain_path = tmpdir / "full_search_tree.json"
    # Write the decompressed bytes to a plain JSON file for SearchTree.from_json
    with gzip.open(compressed_path, "rb") as source, \
            open(plain_path, "wb") as target:
        target.write(source.read())
    tree = SearchTree.from_json(tmpdir / "full_search_tree.json", default_config)
    nodes = list(tree.graph())

    def wrapper(scorer=None):
        analysis = TreeAnalysis(tree, scorer=scorer)
        return analysis, nodes

    return wrapper
Esempio n. 15
0
def test_route_node_depth_from_analysis(default_config, mock_stock,
                                        shared_datadir):
    """
    Check the depths reported by a reaction tree for its molecules and
    its first reaction.
    """
    mock_stock(default_config, Molecule(smiles="CC"), Molecule(smiles="CCCO"))
    tree_path = shared_datadir / "tree_without_repetition.json"
    tree = SearchTree.from_json(tree_path, default_config)
    route = ReactionTree.from_analysis(TreeAnalysis(tree))

    molecules = list(route.molecules())

    assert route.depth(molecules[0]) == 0
    assert route.depth(molecules[1]) == 2
    assert route.depth(molecules[2]) == 2

    reactions = list(route.reactions())
    assert route.depth(reactions[0]) == 1

    # Each molecule's depth is twice its stored "transform" attribute
    for molecule in route.molecules():
        transform = route.graph.nodes[molecule]["transform"]
        assert route.depth(molecule) == 2 * transform
Esempio n. 16
0
def fresh_tree(default_config):
    """Create a search tree with the given configuration and no root SMILES."""
    tree = SearchTree(config=default_config, root_smiles=None)
    return tree
Esempio n. 17
0
class AiZynthFinder:
    """
    Public API to the aizynthfinder tool

    If instantiated with the path to a yaml file or dictionary of settings
    the stocks and policy networks are loaded directly.
    Otherwise, the user is responsible for loading them prior to
    executing the tree search.

    :ivar config: the configuration of the search
    :vartype config: Configuration
    :ivar policy: the policy model
    :vartype policy: Policy
    :ivar stock: the stock
    :vartype stock: Stock
    :ivar tree: the search tree
    :vartype tree: SearchTree
    :ivar analysis: the tree analysis
    :vartype analysis: TreeAnalysis
    :ivar routes: the top-ranked routes
    :vartype routes: RouteCollection
    :ivar search_stats: statistics of the latest search: time, number of iterations and if it returned first solution
    :vartype search_stats: dict

    :param configfile: the path to yaml file with configuration (has priority over configdict), defaults to None
    :type configfile: str, optional
    :param configdict: the config as a dictionary source, defaults to None
    :type configdict: dict, optional
    """
    def __init__(self, configfile=None, configdict=None):
        self._logger = logger()

        # The file takes priority over the dictionary; with neither given
        # a default configuration is created
        if configfile:
            self.config = Configuration.from_file(configfile)
        elif configdict:
            self.config = Configuration.from_dict(configdict)
        else:
            self.config = Configuration()

        # Convenience aliases to objects owned by the configuration
        self.expansion_policy = self.config.expansion_policy
        self.filter_policy = self.config.filter_policy
        self.stock = self.config.stock
        self.scorers = self.config.scorers
        self.tree = None
        self._target_mol = None
        self.search_stats = {}
        self.routes = None
        self.analysis = None

    @property
    def target_smiles(self):
        """
        The SMILES representation of the molecule to predict routes on.

        Setting this property creates a new target molecule from the SMILES.

        :return: the SMILES
        :rtype: str
        """
        return self._target_mol.smiles

    @target_smiles.setter
    def target_smiles(self, smiles):
        self.target_mol = Molecule(smiles=smiles)

    @property
    def target_mol(self):
        """
        The molecule to predict routes on

        Setting this property discards the current search tree.

        :return: the molecule
        :rtype: Molecule
        """
        return self._target_mol

    @target_mol.setter
    def target_mol(self, mol):
        # A tree built for a previous target is not valid for the new one
        self.tree = None
        self._target_mol = mol

    def build_routes(self, min_nodes=5, scorer="state score"):
        """
        Build reaction routes

        This is necessary to call after the tree search has completed in order
        to extract results from the tree search.

        :param min_nodes: the minimum number of top-ranked nodes to consider, defaults to 5
        :type min_nodes: int, optional
        :param scorer: the name used to look up the scorer in ``self.scorers``, defaults to "state score"
        :type scorer: str, optional
        """
        self.analysis = TreeAnalysis(self.tree, scorer=self.scorers[scorer])
        self.routes = RouteCollection.from_analysis(self.analysis, min_nodes)

    def extract_statistics(self):
        """
        Extracts tree statistics as a dictionary

        :return: the target SMILES, the search time and the statistics of the
            tree analysis, or an empty dictionary if no analysis is available
        :rtype: dict
        """
        if not self.analysis:
            return {}
        stats = {
            "target": self.target_smiles,
            "search_time": self.search_stats["time"]
        }
        stats.update(self.analysis.tree_statistics())
        return stats

    def prepare_tree(self):
        """
        Setup the tree for searching

        Resets the exclusion list of the stock, excludes the target molecule
        from the stock when the configuration asks for it and the target is
        in the stock, creates a new search tree rooted at the target SMILES
        and clears any previously stored analysis and routes.
        """
        self.stock.reset_exclusion_list()
        if self.config.exclude_target_from_stock and self.target_mol in self.stock:
            self.stock.exclude(self.target_mol)
            self._logger.debug("Excluding the target compound from the stock")

        # Hand the SMILES to the logger as an argument so that the message is
        # only formatted when the debug record is actually emitted
        self._logger.debug("Defining tree root: %s", self.target_smiles)
        self.tree = SearchTree(root_smiles=self.target_smiles,
                               config=self.config)
        self.analysis = None
        self.routes = None

    @deprecated(version="2.1.0", reason="Not supported anymore")
    def run_from_json(self, params):
        """
        Run a search tree by reading settings from a JSON

        The settings are applied to the stock, the policies and the
        configuration, then the tree is prepared and searched and the routes
        are built.

        :param params: the parameters of the tree search
        :type params: dict
        :return: dictionary with all settings and top scored routes
        :rtype: dict
        """
        self.stock.select(params["stocks"])
        # Accept both the "policy" key and the "policies" key
        self.expansion_policy.select(
            params.get("policy", params.get("policies", "")))
        if "filter" in params:
            self.filter_policy.select(params["filter"])
        else:
            self.filter_policy.deselect()
        self.config.C = params["C"]
        self.config.max_transforms = params["max_transforms"]
        self.config.cutoff_cumulative = params["cutoff_cumulative"]
        self.config.cutoff_number = params["cutoff_number"]
        self.target_smiles = params["smiles"]
        self.config.return_first = params["return_first"]
        self.config.time_limit = params["time_limit"]
        self.config.iteration_limit = params["iteration_limit"]
        self.config.exclude_target_from_stock = params[
            "exclude_target_from_stock"]
        self.config.filter_cutoff = params["filter_cutoff"]

        self.prepare_tree()
        self.tree_search()
        self.build_routes()
        # Without "score_trees" the routes are returned without extra scores
        if not params.get("score_trees", False):
            return {
                "request": self._get_settings(),
                "trees": self.routes.dicts,
            }

        self.routes.compute_scores(*self.scorers.objects())
        return {
            "request": self._get_settings(),
            "trees": self.routes.dict_with_scores(),
        }

    def tree_search(self, show_progress=False):
        """
        Perform the actual tree search

        The tree is prepared first if it does not exist. The search iterates
        until the time limit or the iteration limit of the configuration is
        reached, or until the first solved route is found when
        ``return_first`` is set in the configuration.

        :param show_progress: if True, shows a progress bar
        :type show_progress: bool
        :return: the elapsed search time in seconds
        :rtype: float
        """
        if not self.tree:
            self.prepare_tree()
        self.search_stats = {"returned_first": False, "iterations": 0}

        time0 = time.time()
        i = 1
        self._logger.debug("Starting search")
        time_past = time.time() - time0

        if show_progress:
            pbar = tqdm(total=self.config.iteration_limit)

        # The finally-clause makes sure the progress bar is closed even if
        # an exception escapes from the search loop
        try:
            while time_past < self.config.time_limit and i <= self.config.iteration_limit:
                if show_progress:
                    pbar.update(1)
                self.search_stats["iterations"] += 1

                leaf = self.tree.select_leaf()
                leaf.expand()
                # Keep descending through promising children until the
                # current node reports that it is terminal
                while not leaf.is_terminal():
                    child = leaf.promising_child()
                    if child:
                        child.expand()
                        leaf = child
                self.tree.backpropagate(leaf, leaf.state.score)
                if self.config.return_first and leaf.state.is_solved:
                    self._logger.debug("Found first solved route")
                    self.search_stats["returned_first"] = True
                    break
                i = i + 1
                time_past = time.time() - time0
        finally:
            if show_progress:
                pbar.close()

        self._logger.debug("Search completed")
        self.search_stats["time"] = time_past
        return time_past

    def _get_settings(self):
        """
        Get the current settings as a dictionary

        :return: the selected stocks and policies together with the search
            settings of the configuration and the target SMILES
        :rtype: dict
        """
        # To be backward-compatible
        if len(self.expansion_policy.selection) == 1:
            policy_value = self.expansion_policy.selection[0]
            policy_key = "policy"
        else:
            policy_value = self.expansion_policy.selection
            policy_key = "policies"

        dict_ = {
            "stocks": self.stock.selection,
            policy_key: policy_value,
            "C": self.config.C,
            "max_transforms": self.config.max_transforms,
            "cutoff_cumulative": self.config.cutoff_cumulative,
            "cutoff_number": self.config.cutoff_number,
            "smiles": self.target_smiles,
            "return_first": self.config.return_first,
            "time_limit": self.config.time_limit,
            "iteration_limit": self.config.iteration_limit,
            "exclude_target_from_stock": self.config.exclude_target_from_stock,
            "filter_cutoff": self.config.filter_cutoff,
        }
        # The filter is only reported when one is selected
        if self.filter_policy.selection:
            dict_["filter"] = self.filter_policy.selection
        return dict_