def test_two_expansions_cyclic(mock_get_actions, mock_create_root, mock_stock):
    """
    Test that a cyclic expansion is rejected during the search.

    The mocked expansion policy would build this tree:
                root
                  |
                child 1
                  |
                child 2
    but creating child 2 must be rejected because child 2 has the same
    molecule as the root, so only root and child 1 end up in the tree.
    """
    finder = AiZynthFinder()
    root_smi = "COc1cc2cc(-c3ccc(OC(C)=O)c(OC(C)=O)c3)[n+](C)c(C)c2cc1OC"
    root_mol = mock_create_root(root_smi, finder.config)
    child1_smi = ["COc1cc2cc(-c3ccc(O)c(OC(C)=O)c3)[n+](C)c(C)c2cc1OC"]
    # Identical to root_smi: expanding child 1 would lead straight back to the root
    child2_smi = ["COc1cc2cc(-c3ccc(OC(C)=O)c(OC(C)=O)c3)[n+](C)c(C)c2cc1OC"]
    child1_mol, *_ = mock_get_actions(root_mol, tuple([root_smi]),
                                      [child1_smi], [0.3])
    # child2_mol is not asserted on; the call only registers the cyclic expansion
    child2_mol, *_ = mock_get_actions(
        child1_mol[0],
        tuple(child1_smi),
        [child2_smi],
        [0.3],
    )
    # No molecules are passed, so presumably nothing is in stock
    mock_stock(finder.config)
    finder.target_mol = root_mol
    finder.config.iteration_limit = 1

    finder.tree_search()

    nodes = list(finder.tree.graph())
    # Child 2 was rejected, so only root and child 1 remain
    assert len(nodes) == 2
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    assert finder.search_stats["iterations"] == 1
def test_three_expansions_no_reactants_second_level(
    mock_get_actions, mock_create_root, mock_stock
):
    r"""
    Test the following scenario:
                root
            /           \
        child 1         child 2
           |               |
        grandchild 1 (+) grandchild 2 (*)

        - child 1 will be selected first for expansion (iteration 1)
        - grandchild 1 will be selected next,
        - it has no children that can be expanded (marked by +)
        -- end of iteration 1
        - child 2 will be selected for expansion  (iteration 2)
        - grandchild 2 will be selected next and it is in stock (marked by *)
        -- a solution is found and the tree search is terminated
        * nodes in tree will be root, child1, grandchild 1, child2, grandchild 2
    """
    finder = AiZynthFinder()
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(root_smi, finder.config)
    child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"]
    grandchild1_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    grandchild2_smi = ["N#Cc1cccc(N)c1", "O=C(Cl)c1ccc(F)c(F)c1"]
    # Root expands into two children; the higher prior (0.3) makes child 1 go first
    child1_mol, child2_mol = mock_get_actions(
        root_mol, tuple([root_smi]), [child1_smi, child2_smi], [0.3, 0.1],
    )
    # Expanding the second molecule of child 1 gives grandchild 1
    grandchild1_mol, *_ = mock_get_actions(
        child1_mol[1], tuple(child1_smi), [grandchild1_smi], [0.3],
    )
    # Grandchild 1's state: the untouched molecules of child 1 plus the new reactants
    smiles_state1 = [child1_smi[0], child1_smi[2]] + grandchild1_smi
    mock_get_actions(
        grandchild1_mol[1], tuple(smiles_state1), [None], [0.3],
    )  # Will try to expand grandchild 1
    grandchild2_mol, *_ = mock_get_actions(
        child2_mol[1], tuple(child2_smi), [grandchild2_smi], [0.3],
    )
    # grandchild1_mol[1] is deliberately left out of stock so grandchild 1 needs expanding
    mock_stock(
        finder.config,
        child1_mol[0],
        child1_mol[2],
        grandchild1_mol[0],
        *grandchild2_mol
    )
    finder.target_mol = root_mol
    finder.config.return_first = True

    finder.tree_search()

    nodes = list(finder.tree.graph())
    assert len(nodes) == 5
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    assert nodes[2].state.mols == [child1_mol[0], child1_mol[2]] + grandchild1_mol
    assert nodes[3].state.mols == child2_mol
    assert nodes[4].state.mols == [child2_mol[0]] + grandchild2_mol
    assert finder.search_stats["iterations"] == 2
def test_two_expansions(mock_get_actions, mock_create_root, mock_stock):
    """
    Test the building of this tree:
                root
                  |
                child 1
                  |
                child 2

    Every molecule in child 2's state is in stock, so with ``return_first``
    the search stops after a single iteration.
    """
    finder = AiZynthFinder()
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(root_smi, finder.config)
    child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    child2_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    child1_mol, *_ = mock_get_actions(root_mol, tuple([root_smi]),
                                      [child1_smi], [0.3])
    # Expanding the second molecule of child 1 gives child 2
    child2_mol, *_ = mock_get_actions(
        child1_mol[1],
        tuple(child1_smi),
        [child2_smi],
        [0.3],
    )
    # Stock covers the untouched molecules of child 1 and all of child 2
    mock_stock(finder.config, child1_mol[0], child1_mol[2], *child2_mol)
    finder.target_mol = root_mol
    finder.config.return_first = True

    finder.tree_search()

    nodes = list(finder.tree.graph())
    assert len(nodes) == 3
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    # Child 2's state replaces child1_mol[1] with its reactants
    assert nodes[2].state.mols == [child1_mol[0], child1_mol[2]] + child2_mol
    assert finder.search_stats["iterations"] == 1
def test_one_expansion(mock_get_actions, mock_create_root, mock_stock):
    """
    Test the building of this tree:
                root
                  |
                child 1

    All molecules of child 1 are in stock. The search is run twice: first
    with ``return_first`` (stops after one iteration), then on a freshly
    prepared tree with an iteration limit of 45 (runs all 45 iterations
    without adding more nodes).
    """
    finder = AiZynthFinder()
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(root_smi, finder.config)
    child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    child1_mol, *_ = mock_get_actions(root_mol, tuple([root_smi]), [child1_smi], [0.3])
    mock_stock(finder.config, *child1_mol)
    finder.target_mol = root_mol

    # Test first with return_first
    finder.config.return_first = True
    finder.tree_search()

    nodes = list(finder.tree.graph())
    assert len(nodes) == 2
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    assert finder.search_stats["iterations"] == 1
    assert finder.search_stats["returned_first"]

    # then test with iteration limit
    finder.config.return_first = False
    finder.config.iteration_limit = 45
    # Start over from a fresh tree so the second search is independent of the first
    finder.prepare_tree()
    finder.tree_search()

    assert len(finder.tree.graph()) == 2
    assert finder.search_stats["iterations"] == 45
    assert not finder.search_stats["returned_first"]
Exemple #5
0
def _process_single_smiles(
    smiles: str,
    finder: AiZynthFinder,
    output_name: str,
    do_clustering: bool,
    route_distance_model: str = None,
) -> None:
    """
    Run a tree search for one target and write the resulting routes to JSON.

    :param smiles: the SMILES of the target molecule
    :param finder: the finder that performs the search
    :param output_name: path of the JSON file to write; "trees.json" when empty
    :param do_clustering: if True, cluster the routes and add the result to the statistics
    :param route_distance_model: optional model path forwarded to the clustering
    """
    json_path = output_name if output_name else "trees.json"

    finder.target_smiles = smiles
    finder.prepare_tree()
    finder.tree_search(show_progress=True)
    finder.build_routes()

    with open(json_path, "w") as handle:
        json.dump(finder.routes.dicts, handle, indent=2)
    logger().info(f"Trees saved to {json_path}")

    formatted_scores = [("%.4f" % score) for score in finder.routes.scores]
    logger().info(f"Scores for best routes: {', '.join(formatted_scores)}")

    stats = finder.extract_statistics()
    if do_clustering:
        # Clustering adds its results to ``stats`` before they are logged
        _do_clustering(
            finder,
            stats,
            detailed_results=False,
            model_path=route_distance_model,
        )
    stat_lines = [f"{key.replace('_', ' ')}: {value}" for key, value in stats.items()]
    logger().info("\n".join(stat_lines))
def test_two_expansions_no_reactants_second_child(mock_get_actions,
                                                  mock_create_root,
                                                  mock_stock):
    r"""
    Test the following scenario:
                root
            /           \
        child 1        child 2 (+)
            |
        grandchild 1 (*)

        - child 1 will be selected first for expansion (iteration 1)
        - grandchild 1 will be selected next and it is in stock (marked by *)
        -- end of iteration 1
        - child 2 will be selected for expansion  (iteration 2)
        - it has no children that can be expanded (marked with +)
        -- will continue to iterate until reached number of iteration (set 10 in the test)
        * nodes in tree will be root, child1, grandchild 1, child2
    """
    finder = AiZynthFinder()
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(root_smi, finder.config)
    child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"]
    grandchild1_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    # Higher prior (0.3) makes child 1 the first child selected
    child1_mol, child2_mol = mock_get_actions(
        root_mol,
        tuple([root_smi]),
        [child1_smi, child2_smi],
        [0.3, 0.1],
    )
    grandchild1_mol, *_ = mock_get_actions(
        child1_mol[1],
        tuple(child1_smi),
        [grandchild1_smi],
        [0.3],
    )
    # A ``None`` action list: expanding child 2 yields no reactants
    mock_get_actions(
        child2_mol[1],
        tuple(child2_smi),
        [None],
        [0.3],
    )  # Will try to expand child2
    mock_stock(finder.config, child1_mol[0], child1_mol[2], *grandchild1_mol)
    finder.target_mol = root_mol
    finder.config.iteration_limit = 10

    finder.tree_search()

    nodes = list(finder.tree.graph())
    assert len(nodes) == 4
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    assert nodes[2].state.mols == [child1_mol[0], child1_mol[2]
                                   ] + grandchild1_mol
    assert nodes[3].state.mols == child2_mol
    # return_first is not set, so all iterations up to the limit are run
    assert finder.search_stats["iterations"] == 10
Exemple #7
0
def process_smiles_list(finder: AiZynthFinder,
                        target_smiles: List[str]) -> List[Dict[str, str]]:
    """
    Search each target in turn with the same finder and collect its statistics.

    :param finder: the finder used for every search
    :param target_smiles: the SMILES of the targets, processed in order
    :return: one statistics dictionary per target, in input order
    """
    def _search(smiles: str) -> Dict[str, str]:
        # Search one target and return the finder's statistics for it
        finder.target_smiles = smiles
        finder.tree_search()
        finder.build_routes()
        return finder.extract_statistics()

    return [_search(smiles) for smiles in target_smiles]
def test_three_expansions_not_solved(default_config, mock_get_actions,
                                     mock_create_root, mock_stock):
    """
    Test the building of this tree:
                root
                  |
                child 1
                  |
                child 2
                  |
                child 3
        - child 3 state is not solved (not in stock)
    """
    # NOTE(review): the default_config fixture is requested but not used in this body
    finder = AiZynthFinder()
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(root_smi, finder.config)
    child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    child2_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    child3_smi = ["O=C(Cl)c1ccccc1"]
    child1_mol, *_ = mock_get_actions(root_mol, tuple([root_smi]),
                                      [child1_smi], [0.3])
    child2_mol, *_ = mock_get_actions(
        child1_mol[1],
        tuple(child1_smi),
        [child2_smi],
        [0.3],
    )
    # Child 2's state: untouched molecules of child 1 plus child 2's reactants
    smiles_state2 = [child1_smi[0], child1_smi[2]] + child2_smi
    child3_mol, *_ = mock_get_actions(
        child2_mol[1],
        tuple(smiles_state2),
        [child3_smi],
        [0.3],
    )
    # child3_mol is never put in stock, so child 3's state stays unsolved
    mock_stock(finder.config, child1_mol[0], child1_mol[2], child2_mol[0])
    finder.target_mol = root_mol
    finder.config.return_first = True
    # presumably limits the depth so child 3 cannot be expanded further — confirm
    finder.config.max_transforms = 2
    finder.config.iteration_limit = 15

    finder.tree_search()

    nodes = list(finder.tree.graph())
    assert len(nodes) == 4
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    assert nodes[2].state.mols == [child1_mol[0], child1_mol[2]] + child2_mol
    expected_list = [child1_mol[0], child1_mol[2], child2_mol[0]] + child3_mol
    assert nodes[3].state.mols == expected_list
    assert not nodes[3].state.is_solved
    # No solution is found, so return_first never triggers and all 15 iterations run
    assert finder.search_stats["iterations"] == 15
def test_two_expansions_two_children(mock_get_actions, mock_create_root,
                                     mock_stock):
    r"""
    Test the building of this tree:
                root
            /           \
        child 1        child 2
            |             |
        grandchild 1   grandchild 2
    """
    finder = AiZynthFinder()
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(root_smi, finder.config)
    child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"]
    # Both children expand into the same reactants
    grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    child1_mol, child2_mol = mock_get_actions(
        root_mol,
        tuple([root_smi]),
        [child1_smi, child2_smi],
        [0.3, 0.1],
    )
    # Not unpacked here: element [0] holds the molecules of the single action
    grandchild1_mol = mock_get_actions(
        child1_mol[1],
        tuple(child1_smi),
        [grandchild_smi],
        [0.3],
    )
    grandchild2_mol = mock_get_actions(
        child2_mol[1],
        tuple(child2_smi),
        [grandchild_smi],
        [0.3],
    )
    mock_stock(finder.config, child1_mol[0], child1_mol[2],
               *grandchild1_mol[0])
    finder.target_mol = root_mol

    finder.tree_search()

    nodes = list(finder.tree.graph())
    assert len(nodes) == 5
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    assert nodes[2].state.mols == [child1_mol[0], child1_mol[2]
                                   ] + grandchild1_mol[0]
    assert nodes[3].state.mols == child2_mol
    assert nodes[4].state.mols == [child2_mol[0]] + grandchild2_mol[0]
    # No iteration_limit is set here; presumably 100 is the configured default — confirm
    assert finder.search_stats["iterations"] == 100
Exemple #10
0
def _process_multi_smiles(filename: str, finder: AiZynthFinder,
                          output_name: str, do_clustering: bool) -> None:
    """
    Run a tree search for every SMILES in a file and save the results to HDF5.

    :param filename: path to a text file with one target SMILES per line;
        blank lines are skipped
    :param finder: the finder used to perform the searches
    :param output_name: path of the HDF5 file to write; "output.hdf5" is used
        when empty
    :param do_clustering: if True, cluster the routes of each target and add
        the result to that target's statistics
    """
    output_name = output_name or "output.hdf5"
    with open(filename, "r") as fileobj:
        # Skip blank lines (e.g. trailing empty lines) instead of running a
        # search on an empty target.
        smiles = [line.strip() for line in fileobj if line.strip()]

    results = defaultdict(list)
    for smi in smiles:
        finder.target_smiles = smi
        finder.prepare_tree()
        search_time = finder.tree_search()
        finder.build_routes()
        stats = finder.extract_statistics()

        logger().info(f"Done with {smi} in {search_time:.3} s")
        if do_clustering:
            _do_clustering(finder, stats, detailed_results=True)
        # One row per target: every statistic becomes a column
        for key, value in stats.items():
            results[key].append(value)
        results["top_scores"].append(", ".join(
            "%.4f" % score for score in finder.routes.scores))
        results["trees"].append(finder.routes.dicts)

    data = pd.DataFrame.from_dict(results)
    with warnings.catch_warnings():  # This will suppress a PerformanceWarning
        warnings.simplefilter("ignore")
        data.to_hdf(output_name, key="table", mode="w")
    logger().info(f"Output saved to {output_name}")
Exemple #11
0
def generate_images(smiles, out_dir, config):
    """
    Search routes for a target and save the route images to a directory.

    Each route is written to ``route{n:03d}.png``, and each group of three
    consecutive route images is combined into ``result{i}.png``.

    :param smiles: the SMILES of the target molecule
    :param out_dir: directory for the images; created (with parents) if missing
    :param config: path to the finder's configuration file
    """
    finder = AiZynthFinder(configfile=config)
    finder.stock.select("zinc")
    finder.expansion_policy.select("uspto")
    finder.filter_policy.select("uspto")

    # makedirs also creates missing parent directories and tolerates an
    # existing directory
    os.makedirs(out_dir, exist_ok=True)
    finder.target_smiles = smiles
    finder.tree_search()
    finder.build_routes()
    if not finder.routes.images:
        return

    # Track the files written by this run. Globbing the directory would also
    # pick up stale route images left over from earlier runs.
    route_paths = []
    for n, image in enumerate(finder.routes.images):
        path = f"{out_dir}/route{n:03d}.png"
        image.save(path)
        route_paths.append(path)

    # Combine every 3 route images into a single result image
    for i, imgs in enumerate(chunked(route_paths, 3)):
        concat_images(imgs, f"{out_dir}/result{i}.png")
def test_two_expansions_prune_cyclic(mock_get_actions, mock_create_root,
                                     mock_stock):
    """
    Test the building of this tree with cycle pruning disabled:
                root
                  |
                child 1
                  |
                child 2
    Child 2 (identical to the root) will not be rejected, but the tree search
    will not end cleanly, so a KeyError is caught and we assert on what we got.
    """
    finder = AiZynthFinder()
    root_smi = "COc1cc2cc(-c3ccc(OC(C)=O)c(OC(C)=O)c3)[n+](C)c(C)c2cc1OC"
    root_mol = mock_create_root(root_smi, finder.config)
    child1_smi = ["COc1cc2cc(-c3ccc(O)c(OC(C)=O)c3)[n+](C)c(C)c2cc1OC"]
    # Identical to root_smi: creates a cycle back to the root
    child2_smi = ["COc1cc2cc(-c3ccc(OC(C)=O)c(OC(C)=O)c3)[n+](C)c(C)c2cc1OC"]
    child1_mol, *_ = mock_get_actions(root_mol, tuple([root_smi]),
                                      [child1_smi], [0.3])
    child2_mol, *_ = mock_get_actions(
        child1_mol[0],
        tuple(child1_smi),
        [child2_smi],
        [0.3],
    )
    mock_stock(finder.config)
    finder.target_mol = root_mol
    finder.config.iteration_limit = 1
    finder.config.prune_cycles_in_search = False

    # The search is expected to fail part-way; the partial tree is inspected below
    try:
        finder.tree_search()
    except KeyError:
        pass

    nodes = list(finder.tree.graph())
    # NOTE(review): four nodes are expected, but only the first three are
    # checked; presumably the fourth comes from expanding child 2 (same as the
    # root) again — confirm
    assert len(nodes) == 4
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    assert nodes[2].state.mols == child2_mol
    assert finder.search_stats["iterations"] == 1
def test_two_expansions_no_expandable_root(mock_get_actions, mock_create_root,
                                           mock_stock):
    """
    Test the following scenario:
                root
                  |
              child 1 (+)

        - child 1 will be selected first for expansion (iteration 1)
        - it has no children that can be expanded (marked by +)
        -- end of iteration 1
        - iteration 2 starts but selecting a leaf will raise an exception
        -- will continue to iterate until reached number of iteration (set 10 in the test)
        * nodes in tree will be root, child 1
    """
    finder = AiZynthFinder()
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(root_smi, finder.config)
    child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    child1_mol, *_ = mock_get_actions(root_mol, tuple([root_smi]),
                                      [child1_smi], [0.3])
    # A ``None`` action list: expanding child 1 yields no reactants
    mock_get_actions(
        child1_mol[1],
        tuple(child1_smi),
        [None],
        [0.3],
    )  # Will try to expand child1
    # child1_mol[1] is left out of stock, so child 1 is not solved
    mock_stock(finder.config, child1_mol[0], child1_mol[2])
    finder.target_mol = root_mol
    finder.config.return_first = True
    finder.config.iteration_limit = 10

    finder.tree_search()

    nodes = list(finder.tree.graph())
    assert len(nodes) == 2
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    # Never solved, so return_first does not stop the search early
    assert finder.search_stats["iterations"] == 10
Exemple #14
0
class AiZynthApp:
    """
    Interface class to be used in a Jupyter Notebook.
    Provides a basic GUI to setup and analyze the tree search.

    Should be instantiated with the path of a yaml file with configuration:

    .. code-block::

        from aizynthfinder.interfaces import AiZynthApp
        configfile = "/path/to/configfile.yaml"
        app = AiZynthApp(configfile)

    :ivar finder: the finder instance
    :vartype finder: AiZynthFinder

    :param configfile: the path to yaml file with configuration
    :type configfile: str
    :param setup: if True will create and display the GUI on instantiation, defaults to True
    :type setup: bool, optional
    """

    def __init__(self, configfile, setup=True):
        setup_logger(logging.INFO)
        self.finder = AiZynthFinder(configfile=configfile)
        # Widgets keyed by name: user inputs, output areas and action buttons
        self._input = dict()
        self._output = dict()
        self._buttons = dict()
        if setup:
            self.setup()

    def setup(self):
        """
        Create the widgets and display the GUI.
        This is typically done on instantiation, but this method
        is for more advanced uses.
        """
        self._create_input_widgets()
        self._create_search_widgets()
        self._create_route_widgets()

    def _create_input_widgets(self):
        """Create and display the SMILES input and the Options/Advanced tabs."""
        self._input["smiles"] = Text(description="SMILES", continuous_update=False)
        self._input["smiles"].observe(self._show_mol, names="value")
        display(self._input["smiles"])
        self._output["smiles"] = Output(
            layout={"border": "1px solid silver", "width": "50%", "height": "180px"}
        )
        display(self._output["smiles"])

        self._input["stocks"] = [
            Checkbox(value=True, description=key, layout={"justify": "left"})
            for key in self.finder.stock.available_stocks()
        ]
        box_stocks = VBox(
            [Label("Stocks")] + self._input["stocks"],
            layout={"border": "1px solid silver"},
        )

        self._input["policy"] = widgets.Dropdown(
            options=self.finder.policy.available_policies(),
            description="Neural Policy:",
            style={"description_width": "initial"},
        )

        max_time_box = self._make_slider_input("time_limit", "Time (min)", 1, 120)
        # The config stores the time limit in seconds; the GUI shows minutes
        self._input["time_limit"].value = self.finder.config.time_limit / 60
        max_iter_box = self._make_slider_input(
            "iteration_limit", "Max Iterations", 100, 2000
        )
        self._input["iteration_limit"].value = self.finder.config.iteration_limit
        self._input["return_first"] = widgets.Checkbox(
            value=self.finder.config.return_first,
            description="Return first solved route",
        )
        vbox = VBox(
            [
                self._input["policy"],
                max_time_box,
                max_iter_box,
                self._input["return_first"],
            ]
        )
        box_options = HBox([box_stocks, vbox])

        self._input["C"] = FloatText(description="C", value=self.finder.config.C)
        self._input["max_transforms"] = BoundedIntText(
            description="Max steps for substrates",
            min=1,
            max=6,
            value=self.finder.config.max_transforms,
            style={"description_width": "initial"},
        )
        self._input["cutoff_cumulative"] = BoundedFloatText(
            description="Policy cutoff cumulative",
            min=0,
            max=1,
            value=self.finder.config.cutoff_cumulative,
            style={"description_width": "initial"},
        )
        self._input["cutoff_number"] = BoundedIntText(
            description="Policy cutoff number",
            min=1,
            max=1000,
            value=self.finder.config.cutoff_number,
            style={"description_width": "initial"},
        )
        self._input["exclude_target_from_stock"] = widgets.Checkbox(
            value=self.finder.config.exclude_target_from_stock,
            description="Exclude target from stock",
        )
        box_advanced = VBox(
            [
                self._input["C"],
                self._input["max_transforms"],
                self._input["cutoff_cumulative"],
                self._input["cutoff_number"],
                self._input["exclude_target_from_stock"],
            ]
        )

        children = [box_options, box_advanced]
        tab = widgets.Tab()
        tab.children = children
        tab.set_title(0, "Options")
        tab.set_title(1, "Advanced")
        display(tab)

    def _create_route_widgets(self):
        """Create and display the route button, the route selector and the route output."""
        self._buttons["show_routes"] = Button(description="Show Reactions")
        self._buttons["show_routes"].on_click(self._on_display_button_clicked)
        self._input["route"] = Dropdown(options=[], description="Routes: ",)
        self._input["route"].observe(self._on_change_route_option)
        display(HBox([self._buttons["show_routes"], self._input["route"]]))

        self._output["routes"] = widgets.Output(
            layout={"border": "1px solid silver", "width": "99%"}
        )
        display(self._output["routes"])

    def _create_search_widgets(self):
        """Create and display the run/extend buttons and the search log output."""
        self._buttons["execute"] = Button(description="Run Search")
        self._buttons["execute"].on_click(self._on_exec_button_clicked)

        self._buttons["extend"] = widgets.Button(description="Extend Search")
        self._buttons["extend"].on_click(self._on_extend_button_clicked)
        display(HBox([self._buttons["execute"], self._buttons["extend"]]))

        self._output["tree_search"] = widgets.Output(
            layout={
                "border": "1px solid silver",
                "width": "99%",
                "height": "300px",
                "overflow": "auto",
            }
        )
        display(self._output["tree_search"])

    def _make_slider_input(self, label, description, min_val, max_val):
        """
        Create a slider linked to an integer text box.

        The text box is stored as ``self._input[label]``; the returned HBox
        holds the label, the slider and the text box.
        """
        label_widget = Label(description)
        slider = widgets.IntSlider(
            continuous_update=True, min=min_val, max=max_val, readout=False
        )
        self._input[label] = IntText(continuous_update=True, layout={"width": "80px"})
        widgets.link((self._input[label], "value"), (slider, "value"))
        return HBox([label_widget, slider, self._input[label]])

    def _on_change_route_option(self, change):
        # The observer receives changes of every trait; only react to the index
        if change["name"] != "index":
            return
        self._show_route(self._input["route"].index)

    def _on_exec_button_clicked(self, _):
        self._toggle_button(False)
        try:
            self._prepare_search()
            self._tree_search()
        finally:
            # Re-enable the buttons even if the search raised; otherwise the
            # GUI would be left permanently disabled.
            self._toggle_button(True)

    def _on_extend_button_clicked(self, _):
        self._toggle_button(False)
        try:
            self._tree_search()
        finally:
            self._toggle_button(True)

    def _on_display_button_clicked(self, _):
        self._toggle_button(False)
        try:
            self.finder.build_routes()
            self.finder.routes.make_images()
            self._input["route"].options = [
                f"Option {i}" for i, _ in enumerate(self.finder.routes, 1)
            ]
            self._show_route(0)
        finally:
            self._toggle_button(True)

    def _prepare_search(self):
        """Copy the GUI settings into the finder and prepare a tree for the target."""
        self._output["tree_search"].clear_output()
        with self._output["tree_search"]:
            selected_stocks = [
                cb.description for cb in self._input["stocks"] if cb.value
            ]
            self.finder.stock.select_stocks(selected_stocks)
            self.finder.policy.select_policy(self._input["policy"].value)
            self.finder.config.update(
                **{
                    "C": self._input["C"].value,
                    "max_transforms": self._input["max_transforms"].value,
                    "cutoff_cumulative": self._input["cutoff_cumulative"].value,
                    "cutoff_number": int(self._input["cutoff_number"].value),
                    "return_first": self._input["return_first"].value,
                    "time_limit": self._input["time_limit"].value * 60,
                    "iteration_limit": self._input["iteration_limit"].value,
                    "exclude_target_from_stock": self._input[
                        "exclude_target_from_stock"
                    ].value,
                }
            )

            smiles = self._input["smiles"].value
            print("Setting target molecule with smiles: %s" % smiles)
            self.finder.target_smiles = smiles
            self.finder.prepare_tree()

    def _show_mol(self, change):
        """Render the molecule for the newly entered SMILES."""
        self._output["smiles"].clear_output()
        with self._output["smiles"]:
            mol = Chem.MolFromSmiles(change["new"])
            display(mol)

    def _show_route(self, index):
        """Display the route at ``index``; no-op for ``None`` or an out-of-range index."""
        if index is None or index >= len(self.finder.routes):
            return

        node = self.finder.routes[index]["node"]
        state = node.state
        status = "Solved" if state.is_solved else "Not Solved"

        self._output["routes"].clear_output()
        with self._output["routes"]:
            display(HTML("<H2>%s</H2>" % status))
            display("Route Score: %0.3F" % state.score)
            display(HTML("<H2>Compounds to Procure</H2>"))
            display(state.to_image())
            display(HTML("<H2>Steps</H2>"))
            display(self.finder.routes[index]["image"])

    def _toggle_button(self, on):
        """Enable (``on=True``) or disable all buttons."""
        for button in self._buttons.values():
            button.disabled = not on

    def _tree_search(self):
        """Run (or continue) the tree search, logging into the search output area."""
        with self._output["tree_search"]:
            self.finder.tree_search(show_progress=True)
            print("Tree search completed.")
class AiZynthApp:
    """
    Interface class to be used in a Jupyter Notebook.
    Provides a basic GUI to setup and analyze the tree search.

    Should be instantiated with the path of a yaml file with configuration:

    .. code-block::

        from aizynthfinder.interfaces import AiZynthApp
        configfile = "/path/to/configfile.yaml"
        app = AiZynthApp(configfile)

    :ivar finder: the finder instance

    :param configfile: the path to yaml file with configuration
    :param setup: if True will create and display the GUI on instantiation, defaults to True
    """

    def __init__(self, configfile: str, setup: bool = True) -> None:
        setup_logger(logging.INFO)
        self.finder = AiZynthFinder(configfile=configfile)
        # Widget registries, keyed by the short names used throughout this class
        self._input: StrDict = dict()
        self._output: StrDict = dict()
        self._buttons: StrDict = dict()
        if setup:
            self.setup()

    def setup(self) -> None:
        """
        Create the widgets and display the GUI.
        This is typically done on instantiation, but this method
        is for more advanced uses.
        """
        self._create_input_widgets()
        self._create_search_widgets()
        self._create_route_widgets()

    def _create_input_widgets(self) -> None:
        """Create and display the SMILES input and the tabbed search options."""
        self._input["smiles"] = Text(description="SMILES", continuous_update=False)
        self._input["smiles"].observe(self._show_mol, names="value")
        display(self._input["smiles"])
        self._output["smiles"] = Output(
            layout={"border": "1px solid silver", "width": "50%", "height": "180px"}
        )
        display(self._output["smiles"])

        box_options = HBox([self._create_stock_box(), self._create_options_box()])
        box_advanced = self._create_advanced_box()

        tab = widgets.Tab()
        tab.children = [box_options, box_advanced]
        tab.set_title(0, "Options")
        tab.set_title(1, "Advanced")
        display(tab)

    def _create_stock_box(self) -> VBox:
        """Create the stock checkboxes and the atom-count stop-criteria inputs."""
        self._input["stocks"] = [
            Checkbox(
                value=True,
                description=key,
                style={"description_width": "initial"},
                layout={"justify": "left"},
            )
            for key in self.finder.stock.items
        ]

        atom_rows = [Label("Limit atom occurrences")]
        self._input["stocks_atom_count_on"] = []
        self._input["stocks_atom_count"] = []
        # Pre-populate from any atom-count criteria already set on the stock
        current_criteria = self.finder.stock.stop_criteria.get("counts", {})
        for atom in ["C", "O", "N"]:
            chk_box = Checkbox(
                value=atom in current_criteria,
                description=atom,
                layout={"justify": "left", "width": "80px"},
                style={"description_width": "initial"},
            )
            self._input["stocks_atom_count_on"].append(chk_box)
            inpt = BoundedIntText(
                value=current_criteria.get(atom, 0),
                min=0,
                layout={"width": "80px"},
            )
            self._input["stocks_atom_count"].append(inpt)
            atom_rows.append(HBox([chk_box, inpt]))
        return VBox(
            [Label("Stocks")] + self._input["stocks"] + atom_rows,
            layout={"border": "1px solid silver"},
        )

    def _create_options_box(self) -> VBox:
        """Create the policy selectors, search limits and the return-first toggle."""
        self._input["policy"] = widgets.Dropdown(
            options=self.finder.expansion_policy.items,
            description="Expansion Policy:",
            style={"description_width": "initial"},
        )

        self._input["filter"] = widgets.Dropdown(
            options=["None"] + self.finder.filter_policy.items,
            description="Filter Policy:",
            style={"description_width": "initial"},
        )

        max_time_box = self._make_slider_input("time_limit", "Time (min)", 1, 120)
        # The config holds seconds while the GUI works in minutes
        self._input["time_limit"].value = self.finder.config.time_limit / 60
        max_iter_box = self._make_slider_input(
            "iteration_limit", "Max Iterations", 100, 2000
        )
        self._input["iteration_limit"].value = self.finder.config.iteration_limit
        self._input["return_first"] = widgets.Checkbox(
            value=self.finder.config.return_first,
            description="Return first solved route",
        )
        return VBox(
            [
                self._input["policy"],
                self._input["filter"],
                max_time_box,
                max_iter_box,
                self._input["return_first"],
            ]
        )

    def _create_advanced_box(self) -> VBox:
        """Create the inputs shown on the "Advanced" tab."""
        self._input["C"] = FloatText(description="C", value=self.finder.config.C)
        self._input["max_transforms"] = BoundedIntText(
            description="Max steps for substrates",
            min=1,
            max=6,
            value=self.finder.config.max_transforms,
            style={"description_width": "initial"},
        )
        self._input["cutoff_cumulative"] = BoundedFloatText(
            description="Policy cutoff cumulative",
            min=0,
            max=1,
            value=self.finder.config.cutoff_cumulative,
            style={"description_width": "initial"},
        )
        self._input["cutoff_number"] = BoundedIntText(
            description="Policy cutoff number",
            min=1,
            max=1000,
            value=self.finder.config.cutoff_number,
            style={"description_width": "initial"},
        )
        self._input["filter_cutoff"] = BoundedFloatText(
            description="Filter cutoff",
            min=0,
            max=1,
            value=self.finder.config.filter_cutoff,
            style={"description_width": "initial"},
        )
        self._input["exclude_target_from_stock"] = widgets.Checkbox(
            value=self.finder.config.exclude_target_from_stock,
            description="Exclude target from stock",
        )
        return VBox(
            [
                self._input["C"],
                self._input["max_transforms"],
                self._input["cutoff_cumulative"],
                self._input["cutoff_number"],
                self._input["filter_cutoff"],
                self._input["exclude_target_from_stock"],
            ]
        )

    def _create_route_widgets(self) -> None:
        """Create and display the route browsing controls and output area."""
        self._input["scorer"] = widgets.Dropdown(
            options=self.finder.scorers.names(),
            description="Reorder by:",
            style={"description_width": "initial"},
        )
        self._input["scorer"].observe(self._on_change_scorer)
        self._buttons["show_routes"] = Button(description="Show Reactions")
        self._buttons["show_routes"].on_click(self._on_display_button_clicked)
        self._input["route"] = Dropdown(
            options=[],
            description="Routes: ",
        )
        self._input["route"].observe(self._on_change_route_option)
        display(
            HBox(
                [
                    self._buttons["show_routes"],
                    self._input["route"],
                    self._input["scorer"],
                ]
            )
        )

        self._output["routes"] = widgets.Output(
            layout={"border": "1px solid silver", "width": "99%"}
        )
        display(self._output["routes"])

    def _create_search_widgets(self) -> None:
        """Create and display the run/extend buttons and the search log area."""
        self._buttons["execute"] = Button(description="Run Search")
        self._buttons["execute"].on_click(self._on_exec_button_clicked)

        self._buttons["extend"] = widgets.Button(description="Extend Search")
        self._buttons["extend"].on_click(self._on_extend_button_clicked)
        display(HBox([self._buttons["execute"], self._buttons["extend"]]))

        self._output["tree_search"] = widgets.Output(
            layout={
                "border": "1px solid silver",
                "width": "99%",
                "height": "300px",
                "overflow": "auto",
            }
        )
        display(self._output["tree_search"])

    def _make_slider_input(self, label, description, min_val, max_val) -> HBox:
        """
        Create a slider linked to an integer text box.

        The text box is stored as ``self._input[label]``.
        """
        label_widget = Label(description)
        slider = widgets.IntSlider(
            continuous_update=True, min=min_val, max=max_val, readout=False
        )
        self._input[label] = IntText(continuous_update=True, layout={"width": "80px"})
        widgets.link((self._input[label], "value"), (slider, "value"))
        return HBox([label_widget, slider, self._input[label]])

    def _on_change_route_option(self, change) -> None:
        """Show the newly selected route when the route dropdown index changes."""
        if change["name"] != "index":
            return
        self._show_route(self._input["route"].index)

    def _on_change_scorer(self, change) -> None:
        """Re-score the routes with the selected scorer and refresh the shown route."""
        if self.finder.routes is None or change["name"] != "index":
            return
        scorer = self.finder.scorers[self._input["scorer"].value]
        self.finder.routes.rescore(scorer)
        self._show_route(self._input["route"].index)

    def _on_exec_button_clicked(self, _) -> None:
        """Prepare a fresh search from the GUI settings and run it."""
        self._toggle_button(False)
        # Re-enable the buttons even if the search fails, so the GUI is not locked
        try:
            self._prepare_search()
            self._tree_search()
        finally:
            self._toggle_button(True)

    def _on_extend_button_clicked(self, _) -> None:
        """Continue the tree search on the current tree."""
        self._toggle_button(False)
        try:
            self._tree_search()
        finally:
            self._toggle_button(True)

    def _on_display_button_clicked(self, _) -> None:
        """Build, draw and score the routes, then show the first one."""
        self._toggle_button(False)
        try:
            self.finder.build_routes()
            self.finder.routes.make_images()
            self.finder.routes.compute_scores(*self.finder.scorers.objects())
            self._input["route"].options = [
                f"Option {i}" for i, _ in enumerate(self.finder.routes, 1)  # type: ignore
            ]
            self._show_route(0)
        finally:
            self._toggle_button(True)

    def _prepare_search(self) -> None:
        """Push the GUI settings into the finder and set up a new tree for the target."""
        self._output["tree_search"].clear_output()
        with self._output["tree_search"]:
            selected_stocks = [
                cb.description for cb in self._input["stocks"] if cb.value
            ]
            self.finder.stock.select(selected_stocks)
            atom_count_limits = {}
            for cb_input, value_input in zip(
                self._input["stocks_atom_count_on"], self._input["stocks_atom_count"]
            ):
                if cb_input.value:
                    atom_count_limits[cb_input.description] = value_input.value
            self.finder.stock.set_stop_criteria({"counts": atom_count_limits})
            self.finder.expansion_policy.select(self._input["policy"].value)
            if self._input["filter"].value == "None":
                self.finder.filter_policy.deselect()
            else:
                # Select the filter chosen in the filter dropdown, not the expansion policy
                self.finder.filter_policy.select(self._input["filter"].value)
            self.finder.config.update(
                **{
                    "C": self._input["C"].value,
                    "max_transforms": self._input["max_transforms"].value,
                    "cutoff_cumulative": self._input["cutoff_cumulative"].value,
                    "cutoff_number": int(self._input["cutoff_number"].value),
                    "return_first": self._input["return_first"].value,
                    "time_limit": self._input["time_limit"].value * 60,
                    "iteration_limit": self._input["iteration_limit"].value,
                    "filter_cutoff": self._input["filter_cutoff"].value,
                    "exclude_target_from_stock": self._input[
                        "exclude_target_from_stock"
                    ].value,
                }
            )

            smiles = self._input["smiles"].value
            print("Setting target molecule with smiles: %s" % smiles)
            self.finder.target_smiles = smiles
            self.finder.prepare_tree()

    def _show_mol(self, change) -> None:
        """Render the molecule for the newly entered SMILES."""
        self._output["smiles"].clear_output()
        with self._output["smiles"]:
            mol = Chem.MolFromSmiles(change["new"])
            display(mol)

    def _show_route(self, index) -> None:
        """
        Display the route at position ``index`` with its scores, leaves and image.

        Nothing is shown if ``index`` is None, if no routes have been built
        or if ``index`` is out of range.
        """
        if (
            index is None
            or self.finder.routes is None
            or index >= len(self.finder.routes)
        ):
            return

        route = self.finder.routes[index]
        state = route["node"].state
        status = "Solved" if state.is_solved else "Not Solved"

        self._output["routes"].clear_output()
        with self._output["routes"]:
            display(HTML("<H2>%s" % status))
            table_content = "".join(
                f"<tr><td>{name}</td><td>{score:.4f}</td></tr>"
                for name, score in route["all_score"].items()
            )
            display(HTML(f"<table>{table_content}</table>"))
            display(HTML("<H2>Compounds to Procure"))
            display(state.to_image())
            display(HTML("<H2>Steps"))
            display(route["image"])

    def _toggle_button(self, on) -> None:
        """Enable all buttons if ``on`` is truthy, otherwise disable them."""
        for button in self._buttons.values():
            button.disabled = not on

    def _tree_search(self) -> None:
        """Run the tree search, capturing its progress output in the search widget."""
        with self._output["tree_search"]:
            self.finder.tree_search(show_progress=True)
            print("Tree search completed.")