def test_two_expansions_cyclic(mock_get_actions, mock_create_root, mock_stock):
    """
    Test the building of this tree:
                root
                  |
                child 1
                  |
                child 2

    Creating child 2 should be rejected because child 2 is the same molecule
    as the root, so only root and child 1 end up in the tree.
    """
    finder = AiZynthFinder()
    cyclic_smiles = "COc1cc2cc(-c3ccc(OC(C)=O)c(OC(C)=O)c3)[n+](C)c(C)c2cc1OC"
    root_mol = mock_create_root(cyclic_smiles, finder.config)

    first_step_smiles = ["COc1cc2cc(-c3ccc(O)c(OC(C)=O)c3)[n+](C)c(C)c2cc1OC"]
    # Expanding child 1 leads straight back to the root molecule
    second_step_smiles = [cyclic_smiles]

    child1_mol, *_ = mock_get_actions(
        root_mol, (cyclic_smiles,), [first_step_smiles], [0.3]
    )
    _cyclic_child_mol, *_ = mock_get_actions(
        child1_mol[0],
        tuple(first_step_smiles),
        [second_step_smiles],
        [0.3],
    )
    mock_stock(finder.config)

    finder.target_mol = root_mol
    finder.config.iteration_limit = 1
    finder.tree_search()

    nodes = list(finder.tree.graph())
    assert len(nodes) == 2
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    assert finder.search_stats["iterations"] == 1
def test_three_expansions_no_reactants_second_level(
    mock_get_actions, mock_create_root, mock_stock
):
    r"""
    Test the following scenario:
                    root
                /           \
            child 1        child 2
              |               |
      grandchild 1 (+)   grandchild 2 (*)

        - child 1 will be selected first for expansion (iteration 1)
        - grandchild 1 will be selected next,
        - it has no children that can be expanded (marked by +)
        -- end of iteration 1
        - child 2 will be selected for expansion (iteration 2)
        - grandchild 2 will be selected next and it is in stock (marked by *)
        -- a solution is found and the tree search is terminated
        * nodes in tree will be root, child1, grandchild 1, child2, grandchild 2
    """
    finder = AiZynthFinder()
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(root_smi, finder.config)
    child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"]
    grandchild1_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    grandchild2_smi = ["N#Cc1cccc(N)c1", "O=C(Cl)c1ccc(F)c(F)c1"]
    child1_mol, child2_mol = mock_get_actions(
        root_mol,
        tuple([root_smi]),
        [child1_smi, child2_smi],
        [0.3, 0.1],
    )
    grandchild1_mol, *_ = mock_get_actions(
        child1_mol[1],
        tuple(child1_smi),
        [grandchild1_smi],
        [0.3],
    )
    # State of grandchild 1: the child 1 molecules that were not expanded,
    # followed by the reactants from expanding child1_mol[1]
    smiles_state1 = [child1_smi[0], child1_smi[2]] + grandchild1_smi
    # Will try to expand grandchild 1; the [None] action presumably yields no
    # reactants (TODO confirm against the mock_get_actions fixture)
    mock_get_actions(
        grandchild1_mol[1],
        tuple(smiles_state1),
        [None],
        [0.3],
    )
    grandchild2_mol, *_ = mock_get_actions(
        child2_mol[1],
        tuple(child2_smi),
        [grandchild2_smi],
        [0.3],
    )
    mock_stock(
        finder.config, child1_mol[0], child1_mol[2], grandchild1_mol[0], *grandchild2_mol
    )
    finder.target_mol = root_mol
    finder.config.return_first = True

    finder.tree_search()

    nodes = list(finder.tree.graph())
    assert len(nodes) == 5
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    assert nodes[2].state.mols == [child1_mol[0], child1_mol[2]] + grandchild1_mol
    assert nodes[3].state.mols == child2_mol
    assert nodes[4].state.mols == [child2_mol[0]] + grandchild2_mol
    assert finder.search_stats["iterations"] == 2
def test_two_expansions(mock_get_actions, mock_create_root, mock_stock):
    """
    Test the building of this tree:
                root
                  |
                child 1
                  |
                child 2

    The untouched child 1 molecules and all child 2 molecules are in stock,
    and with ``return_first`` set the search is expected to take a single
    iteration.
    """
    finder = AiZynthFinder()
    target_smiles = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(target_smiles, finder.config)

    first_step_smiles = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    second_step_smiles = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]

    child1_mol, *_ = mock_get_actions(
        root_mol, (target_smiles,), [first_step_smiles], [0.3]
    )
    # Only the middle molecule of child 1 is expanded further
    child2_mol, *_ = mock_get_actions(
        child1_mol[1], tuple(first_step_smiles), [second_step_smiles], [0.3]
    )
    mock_stock(finder.config, child1_mol[0], child1_mol[2], *child2_mol)

    finder.target_mol = root_mol
    finder.config.return_first = True
    finder.tree_search()

    nodes = list(finder.tree.graph())
    expected_states = [
        [root_mol],
        child1_mol,
        [child1_mol[0], child1_mol[2]] + child2_mol,
    ]
    assert len(nodes) == len(expected_states)
    for node, expected in zip(nodes, expected_states):
        assert node.state.mols == expected
    assert finder.search_stats["iterations"] == 1
def test_one_expansion(mock_get_actions, mock_create_root, mock_stock):
    """
    Test the building of this tree:
                root
                  |
                child 1

    The search is run twice on the same target. The first run stops at the
    first solved route. The second run re-prepares the tree, disables
    ``return_first`` and uses a fixed iteration limit.
    """
    finder = AiZynthFinder()
    target_smiles = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(target_smiles, finder.config)
    reactant_smiles = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    child1_mol, *_ = mock_get_actions(
        root_mol, (target_smiles,), [reactant_smiles], [0.3]
    )
    mock_stock(finder.config, *child1_mol)
    finder.target_mol = root_mol

    # First run: stop as soon as a solved route is found
    finder.config.return_first = True
    finder.tree_search()

    nodes = list(finder.tree.graph())
    assert len(nodes) == 2
    assert [node.state.mols for node in nodes] == [[root_mol], child1_mol]
    assert finder.search_stats["iterations"] == 1
    assert finder.search_stats["returned_first"]

    # Second run: no early exit, so every allowed iteration is used
    finder.config.return_first = False
    finder.config.iteration_limit = 45
    finder.prepare_tree()
    finder.tree_search()

    assert len(finder.tree.graph()) == 2
    assert finder.search_stats["iterations"] == 45
    assert not finder.search_stats["returned_first"]
def _process_single_smiles(
    smiles: str,
    finder: AiZynthFinder,
    output_name: str,
    do_clustering: bool,
    route_distance_model: str = None,
) -> None:
    """
    Search routes for a single target and report the outcome.

    The extracted routes are dumped as JSON. The route scores and the search
    statistics are written to the logger.

    :param smiles: the SMILES of the target molecule
    :param finder: the finder used for the search
    :param output_name: the JSON file to write the routes to; "trees.json" is used when empty
    :param do_clustering: if True, the statistics are passed to the clustering helper
        before being logged
    :param route_distance_model: forwarded as ``model_path`` to the clustering helper
    """
    json_path = output_name if output_name else "trees.json"

    finder.target_smiles = smiles
    finder.prepare_tree()
    finder.tree_search(show_progress=True)
    finder.build_routes()

    with open(json_path, "w") as handle:
        json.dump(finder.routes.dicts, handle, indent=2)
    logger().info(f"Trees saved to {json_path}")

    formatted_scores = ["%.4f" % score for score in finder.routes.scores]
    score_text = ", ".join(formatted_scores)
    logger().info(f"Scores for best routes: {score_text}")

    stats = finder.extract_statistics()
    if do_clustering:
        _do_clustering(
            finder, stats, detailed_results=False, model_path=route_distance_model
        )

    stat_lines = []
    for key, value in stats.items():
        stat_lines.append(f"{key.replace('_', ' ')}: {value}")
    logger().info("\n".join(stat_lines))
def test_two_expansions_no_reactants_second_child(
    mock_get_actions, mock_create_root, mock_stock
):
    r"""
    Test the following scenario:
                  root
              /          \
          child 1      child 2 (+)
            |
      grandchild 1 (*)

        - child 1 will be selected first for expansion (iteration 1)
        - grandchild 1 will be selected next and it is in stock (marked by *)
        -- end of iteration 1
        - child 2 will be selected for expansion (iteration 2)
        - it has no children that can be expanded (marked with +)
        -- will continue to iterate until reached number of iteration (set 10 in the test)
        * nodes in tree will be root, child1, grandchild 1, child2
    """
    finder = AiZynthFinder()
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(root_smi, finder.config)
    child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"]
    grandchild1_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    child1_mol, child2_mol = mock_get_actions(
        root_mol,
        tuple([root_smi]),
        [child1_smi, child2_smi],
        [0.3, 0.1],
    )
    grandchild1_mol, *_ = mock_get_actions(
        child1_mol[1],
        tuple(child1_smi),
        [grandchild1_smi],
        [0.3],
    )
    # Will try to expand child 2; the [None] action presumably yields no
    # reactants (TODO confirm against the mock_get_actions fixture)
    mock_get_actions(
        child2_mol[1],
        tuple(child2_smi),
        [None],
        [0.3],
    )
    mock_stock(finder.config, child1_mol[0], child1_mol[2], *grandchild1_mol)
    finder.target_mol = root_mol
    finder.config.iteration_limit = 10

    finder.tree_search()

    nodes = list(finder.tree.graph())
    assert len(nodes) == 4
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    assert nodes[2].state.mols == [child1_mol[0], child1_mol[2]] + grandchild1_mol
    assert nodes[3].state.mols == child2_mol
    assert finder.search_stats["iterations"] == 10
def process_smiles_list(finder: AiZynthFinder, target_smiles: List[str]) -> List[Dict[str, str]]:
    """
    Run a tree search for each target SMILES and collect the search statistics.

    The tree is explicitly re-prepared for every target, as the other batch
    helpers do. This ensures each search starts from a fresh tree for the
    new target instead of continuing the previous one.

    :param finder: the finder used for all searches
    :param target_smiles: the SMILES of the targets, searched in order
    :return: one entry per target, as returned by ``finder.extract_statistics()``
    """
    stats_list: List[Dict[str, str]] = []
    for smiles in target_smiles:
        finder.target_smiles = smiles
        finder.prepare_tree()
        finder.tree_search()
        finder.build_routes()
        stats_list.append(finder.extract_statistics())
    return stats_list
def test_three_expansions_not_solved(
    default_config, mock_get_actions, mock_create_root, mock_stock
):
    """
    Test the building of this tree:
                root
                  |
                child 1
                  |
                child 2
                  |
                child 3

        - child 3 state is not solved (not in stock)
    """
    finder = AiZynthFinder()
    target_smiles = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(target_smiles, finder.config)

    step1_smiles = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    step2_smiles = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    step3_smiles = ["O=C(Cl)c1ccccc1"]

    child1_mol, *_ = mock_get_actions(root_mol, (target_smiles,), [step1_smiles], [0.3])
    child2_mol, *_ = mock_get_actions(
        child1_mol[1], tuple(step1_smiles), [step2_smiles], [0.3]
    )
    # State of child 2: the untouched child 1 molecules followed by the step 2 reactants
    state2_smiles = (step1_smiles[0], step1_smiles[2], *step2_smiles)
    child3_mol, *_ = mock_get_actions(
        child2_mol[1], state2_smiles, [step3_smiles], [0.3]
    )
    # child 3's molecule is deliberately absent from the stock
    mock_stock(finder.config, child1_mol[0], child1_mol[2], child2_mol[0])

    finder.target_mol = root_mol
    finder.config.return_first = True
    finder.config.max_transforms = 2
    finder.config.iteration_limit = 15
    finder.tree_search()

    nodes = list(finder.tree.graph())
    assert len(nodes) == 4
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    assert nodes[2].state.mols == [child1_mol[0], child1_mol[2]] + child2_mol
    final_state = [child1_mol[0], child1_mol[2], child2_mol[0]] + child3_mol
    assert nodes[3].state.mols == final_state
    assert not nodes[3].state.is_solved
    assert finder.search_stats["iterations"] == 15
def test_two_expansions_two_children(mock_get_actions, mock_create_root, mock_stock):
    r"""
    Test the building of this tree:
                  root
              /          \
          child 1      child 2
            |             |
      grandchild 1   grandchild 2
    """
    finder = AiZynthFinder()
    root_smi = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(root_smi, finder.config)
    child1_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]
    child2_smi = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F"]
    grandchild_smi = ["N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1"]
    child1_mol, child2_mol = mock_get_actions(
        root_mol,
        tuple([root_smi]),
        [child1_smi, child2_smi],
        [0.3, 0.1],
    )
    # Both children are expanded into the same set of reactants
    grandchild1_mol, *_ = mock_get_actions(
        child1_mol[1],
        tuple(child1_smi),
        [grandchild_smi],
        [0.3],
    )
    grandchild2_mol, *_ = mock_get_actions(
        child2_mol[1],
        tuple(child2_smi),
        [grandchild_smi],
        [0.3],
    )
    mock_stock(finder.config, child1_mol[0], child1_mol[2], *grandchild1_mol)
    finder.target_mol = root_mol

    finder.tree_search()

    nodes = list(finder.tree.graph())
    assert len(nodes) == 5
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    assert nodes[2].state.mols == [child1_mol[0], child1_mol[2]] + grandchild1_mol
    assert nodes[3].state.mols == child2_mol
    assert nodes[4].state.mols == [child2_mol[0]] + grandchild2_mol
    # No iteration limit is set in this test; the default is asserted to be 100
    assert finder.search_stats["iterations"] == 100
def _process_multi_smiles(
    filename: str,
    finder: AiZynthFinder,
    output_name: str,
    do_clustering: bool,
    route_distance_model: str = None,
) -> None:
    """
    Run a tree search for every SMILES in a file and save all results as HDF5.

    :param filename: path to a text file with one target SMILES per line;
        blank lines are skipped
    :param finder: the finder used for all searches
    :param output_name: the HDF5 file to write; "output.hdf5" is used when empty
    :param do_clustering: if True, the statistics of every target are passed to
        the clustering helper before being stored
    :param route_distance_model: forwarded as ``model_path`` to the clustering
        helper, the same way the single-SMILES path does it
    """
    output_name = output_name or "output.hdf5"
    with open(filename, "r") as fileobj:
        # Skip blank lines: an empty line would otherwise be searched as an empty target
        smiles = [line.strip() for line in fileobj if line.strip()]

    results = defaultdict(list)
    for smi in smiles:
        finder.target_smiles = smi
        finder.prepare_tree()
        search_time = finder.tree_search()
        finder.build_routes()
        stats = finder.extract_statistics()

        logger().info(f"Done with {smi} in {search_time:.3} s")
        if do_clustering:
            _do_clustering(
                finder,
                stats,
                detailed_results=True,
                model_path=route_distance_model,
            )
        for key, value in stats.items():
            results[key].append(value)
        results["top_scores"].append(
            ", ".join("%.4f" % score for score in finder.routes.scores)
        )
        results["trees"].append(finder.routes.dicts)

    data = pd.DataFrame.from_dict(results)
    with warnings.catch_warnings():  # This will suppress a PerformanceWarning
        warnings.simplefilter("ignore")
        data.to_hdf(output_name, key="table", mode="w")
    logger().info(f"Output saved to {output_name}")
def generate_images(smiles, out_dir, config, images_per_result=3):
    """
    Search routes for one target and write the route images to ``out_dir``.

    Each route image is saved as ``route<NNN>.png``. Every ``images_per_result``
    consecutive route images are then combined into ``result<i>.png`` with
    ``concat_images``.

    :param smiles: the SMILES of the target molecule
    :param out_dir: the directory to write images to; created (with parents) if missing
    :param config: the configuration file passed to ``AiZynthFinder``
    :param images_per_result: how many route images go into each combined image
    :raises ValueError: if ``images_per_result`` is smaller than 1
    """
    if images_per_result < 1:
        raise ValueError("images_per_result must be at least 1")

    finder = AiZynthFinder(configfile=config)
    finder.stock.select("zinc")
    finder.expansion_policy.select("uspto")
    finder.filter_policy.select("uspto")

    # makedirs also creates missing parent directories and tolerates an existing one
    os.makedirs(out_dir, exist_ok=True)

    finder.target_smiles = smiles
    finder.tree_search()
    finder.build_routes()

    # Track the files written by this call instead of globbing out_dir, which
    # would also pick up stale route*.png files from earlier runs (and sort
    # them lexicographically, misordering indices beyond 999).
    saved_paths = []
    for n, image in enumerate(finder.routes.images or []):
        path = f"{out_dir}/route{n:03d}.png"
        image.save(path)
        saved_paths.append(path)

    for i, imgs in enumerate(chunked(saved_paths, images_per_result)):
        concat_images(imgs, f"{out_dir}/result{i}.png")
def test_two_expansions_prune_cyclic(mock_get_actions, mock_create_root, mock_stock):
    """
    Test the building of this tree with cycle pruning disabled:
                root
                  |
                child 1
                  |
                child 2  (same molecule as root)

    Child 2 will not be rejected, but the tree search will not end normally.
    The KeyError it raises is therefore caught, and the test asserts on the
    partial tree. The tree ends up with four nodes; only the first three are
    checked here.
    """
    finder = AiZynthFinder()
    root_smi = "COc1cc2cc(-c3ccc(OC(C)=O)c(OC(C)=O)c3)[n+](C)c(C)c2cc1OC"
    root_mol = mock_create_root(root_smi, finder.config)
    child1_smi = ["COc1cc2cc(-c3ccc(O)c(OC(C)=O)c3)[n+](C)c(C)c2cc1OC"]
    child2_smi = ["COc1cc2cc(-c3ccc(OC(C)=O)c(OC(C)=O)c3)[n+](C)c(C)c2cc1OC"]
    child1_mol, *_ = mock_get_actions(root_mol, tuple([root_smi]), [child1_smi], [0.3])
    child2_mol, *_ = mock_get_actions(
        child1_mol[0],
        tuple(child1_smi),
        [child2_smi],
        [0.3],
    )
    mock_stock(finder.config)
    finder.target_mol = root_mol
    finder.config.iteration_limit = 1
    finder.config.prune_cycles_in_search = False

    # The search is expected to fail part-way; the tree built so far is inspected below
    try:
        finder.tree_search()
    except KeyError:
        pass

    nodes = list(finder.tree.graph())
    assert len(nodes) == 4
    assert nodes[0].state.mols == [root_mol]
    assert nodes[1].state.mols == child1_mol
    assert nodes[2].state.mols == child2_mol
    assert finder.search_stats["iterations"] == 1
def test_two_expansions_no_expandable_root(
    mock_get_actions, mock_create_root, mock_stock
):
    """
    Test the following scenario:
                root
                  |
              child 1 (+)

        - child 1 will be selected first for expansion (iteration 1)
        - it has no children that can be expanded (marked by +)
        -- end of iteration 1
        - iteration 2 starts but selecting a leaf will raise an exception
        -- will continue to iterate until reached number of iteration (set 10 in the test)
        * nodes in tree will be root, child 1
    """
    finder = AiZynthFinder()
    target_smiles = "CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    root_mol = mock_create_root(target_smiles, finder.config)
    reactant_smiles = ["CN1CCC(Cl)CC1", "N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F", "O"]

    child1_mol, *_ = mock_get_actions(
        root_mol, (target_smiles,), [reactant_smiles], [0.3]
    )
    # Will try to expand child 1; the [None] action presumably yields no reactants
    mock_get_actions(child1_mol[1], tuple(reactant_smiles), [None], [0.3])
    mock_stock(finder.config, child1_mol[0], child1_mol[2])

    finder.target_mol = root_mol
    finder.config.return_first = True
    finder.config.iteration_limit = 10
    finder.tree_search()

    nodes = list(finder.tree.graph())
    expected_states = [[root_mol], child1_mol]
    assert len(nodes) == len(expected_states)
    for node, expected in zip(nodes, expected_states):
        assert node.state.mols == expected
    assert finder.search_stats["iterations"] == 10
class AiZynthApp:
    """
    Interface class to be used in a Jupyter Notebook.
    Provides a basic GUI to set up and analyze the tree search.

    Should be instantiated with the path of a yaml file with configuration:

    .. code-block::

        from aizynthfinder.interfaces import AiZynthApp
        configfile = "/path/to/configfile.yaml"
        app = AiZynthApp(configfile)

    :ivar finder: the finder instance
    :vartype finder: AiZynthFinder

    :param configfile: the path to yaml file with configuration
    :type configfile: str
    :param setup: if True will create and display the GUI on instantiation, defaults to True
    :type setup: bool, optional
    """

    def __init__(self, configfile, setup=True):
        setup_logger(logging.INFO)
        self.finder = AiZynthFinder(configfile=configfile)
        self._input = dict()
        self._output = dict()
        self._buttons = dict()
        if setup:
            self.setup()

    def setup(self):
        """
        Create the widgets and display the GUI.
        This is typically done on instantiation, but this method is for more advanced uses.
        """
        self._create_input_widgets()
        self._create_search_widgets()
        self._create_route_widgets()

    def _create_input_widgets(self):
        self._input["smiles"] = Text(description="SMILES", continuous_update=False)
        self._input["smiles"].observe(self._show_mol, names="value")
        display(self._input["smiles"])
        self._output["smiles"] = Output(
            layout={"border": "1px solid silver", "width": "50%", "height": "180px"}
        )
        display(self._output["smiles"])

        self._input["stocks"] = [
            Checkbox(value=True, description=key, layout={"justify": "left"})
            for key in self.finder.stock.available_stocks()
        ]
        box_stocks = VBox(
            [Label("Stocks")] + self._input["stocks"],
            layout={"border": "1px solid silver"},
        )

        self._input["policy"] = widgets.Dropdown(
            options=self.finder.policy.available_policies(),
            description="Neural Policy:",
            style={"description_width": "initial"},
        )

        max_time_box = self._make_slider_input("time_limit", "Time (min)", 1, 120)
        self._input["time_limit"].value = self.finder.config.time_limit / 60
        max_iter_box = self._make_slider_input(
            "iteration_limit", "Max Iterations", 100, 2000
        )
        self._input["iteration_limit"].value = self.finder.config.iteration_limit
        self._input["return_first"] = widgets.Checkbox(
            value=self.finder.config.return_first,
            description="Return first solved route",
        )
        vbox = VBox(
            [
                self._input["policy"],
                max_time_box,
                max_iter_box,
                self._input["return_first"],
            ]
        )
        box_options = HBox([box_stocks, vbox])

        self._input["C"] = FloatText(description="C", value=self.finder.config.C)
        self._input["max_transforms"] = BoundedIntText(
            description="Max steps for substrates",
            min=1,
            max=6,
            value=self.finder.config.max_transforms,
            style={"description_width": "initial"},
        )
        self._input["cutoff_cumulative"] = BoundedFloatText(
            description="Policy cutoff cumulative",
            min=0,
            max=1,
            value=self.finder.config.cutoff_cumulative,
            style={"description_width": "initial"},
        )
        self._input["cutoff_number"] = BoundedIntText(
            description="Policy cutoff number",
            min=1,
            max=1000,
            value=self.finder.config.cutoff_number,
            style={"description_width": "initial"},
        )
        self._input["exclude_target_from_stock"] = widgets.Checkbox(
            value=self.finder.config.exclude_target_from_stock,
            description="Exclude target from stock",
        )
        box_advanced = VBox(
            [
                self._input["C"],
                self._input["max_transforms"],
                self._input["cutoff_cumulative"],
                self._input["cutoff_number"],
                self._input["exclude_target_from_stock"],
            ]
        )

        children = [box_options, box_advanced]
        tab = widgets.Tab()
        tab.children = children
        tab.set_title(0, "Options")
        tab.set_title(1, "Advanced")
        display(tab)

    def _create_route_widgets(self):
        self._buttons["show_routes"] = Button(description="Show Reactions")
        self._buttons["show_routes"].on_click(self._on_display_button_clicked)
        self._input["route"] = Dropdown(options=[], description="Routes: ",)
        self._input["route"].observe(self._on_change_route_option)
        display(HBox([self._buttons["show_routes"], self._input["route"]]))

        self._output["routes"] = widgets.Output(
            layout={"border": "1px solid silver", "width": "99%"}
        )
        display(self._output["routes"])

    def _create_search_widgets(self):
        self._buttons["execute"] = Button(description="Run Search")
        self._buttons["execute"].on_click(self._on_exec_button_clicked)

        self._buttons["extend"] = widgets.Button(description="Extend Search")
        self._buttons["extend"].on_click(self._on_extend_button_clicked)
        display(HBox([self._buttons["execute"], self._buttons["extend"]]))

        self._output["tree_search"] = widgets.Output(
            layout={
                "border": "1px solid silver",
                "width": "99%",
                "height": "300px",
                "overflow": "auto",
            }
        )
        display(self._output["tree_search"])

    def _make_slider_input(self, label, description, min_val, max_val):
        # An IntSlider and an IntText kept in sync; the text box is stored as
        # self._input[label] so callers can read and set the value
        label_widget = Label(description)
        slider = widgets.IntSlider(
            continuous_update=True, min=min_val, max=max_val, readout=False
        )
        self._input[label] = IntText(continuous_update=True, layout={"width": "80px"})
        widgets.link((self._input[label], "value"), (slider, "value"))
        return HBox([label_widget, slider, self._input[label]])

    def _on_change_route_option(self, change):
        if change["name"] != "index":
            return
        self._show_route(self._input["route"].index)

    def _on_exec_button_clicked(self, _):
        self._toggle_button(False)
        try:
            self._prepare_search()
            self._tree_search()
        finally:
            # Re-enable the buttons even if the search fails, so the GUI is not left locked
            self._toggle_button(True)

    def _on_extend_button_clicked(self, _):
        self._toggle_button(False)
        try:
            self._tree_search()
        finally:
            self._toggle_button(True)

    def _on_display_button_clicked(self, _):
        self._toggle_button(False)
        try:
            self.finder.build_routes()
            self.finder.routes.make_images()
            self._input["route"].options = [
                f"Option {i}" for i, _ in enumerate(self.finder.routes, 1)
            ]
            self._show_route(0)
        finally:
            self._toggle_button(True)

    def _prepare_search(self):
        self._output["tree_search"].clear_output()
        with self._output["tree_search"]:
            selected_stocks = [
                cb.description for cb in self._input["stocks"] if cb.value
            ]
            self.finder.stock.select_stocks(selected_stocks)
            self.finder.policy.select_policy(self._input["policy"].value)
            self.finder.config.update(
                **{
                    "C": self._input["C"].value,
                    "max_transforms": self._input["max_transforms"].value,
                    "cutoff_cumulative": self._input["cutoff_cumulative"].value,
                    "cutoff_number": int(self._input["cutoff_number"].value),
                    "return_first": self._input["return_first"].value,
                    # The GUI shows minutes; the config stores seconds
                    "time_limit": self._input["time_limit"].value * 60,
                    "iteration_limit": self._input["iteration_limit"].value,
                    "exclude_target_from_stock": self._input[
                        "exclude_target_from_stock"
                    ].value,
                }
            )

            smiles = self._input["smiles"].value
            print("Setting target molecule with smiles: %s" % smiles)
            self.finder.target_smiles = smiles
            self.finder.prepare_tree()

    def _show_mol(self, change):
        self._output["smiles"].clear_output()
        with self._output["smiles"]:
            mol = Chem.MolFromSmiles(change["new"])
            display(mol)

    def _show_route(self, index):
        # Nothing to show before routes have been built, or for an invalid index
        if (
            index is None
            or self.finder.routes is None
            or index >= len(self.finder.routes)
        ):
            return

        node = self.finder.routes[index]["node"]
        state = node.state
        status = "Solved" if state.is_solved else "Not Solved"

        self._output["routes"].clear_output()
        with self._output["routes"]:
            display(HTML("<H2>%s" % status))
            display("Route Score: %0.3F" % state.score)
            display(HTML("<H2>Compounds to Procure"))
            display(state.to_image())
            display(HTML("<H2>Steps"))
            display(self.finder.routes[index]["image"])

    def _toggle_button(self, on):
        for button in self._buttons.values():
            button.disabled = not on

    def _tree_search(self):
        with self._output["tree_search"]:
            self.finder.tree_search(show_progress=True)
            print("Tree search completed.")
class AiZynthApp:
    """
    Interface class to be used in a Jupyter Notebook.
    Provides a basic GUI to set up and analyze the tree search.

    Should be instantiated with the path of a yaml file with configuration:

    .. code-block::

        from aizynthfinder.interfaces import AiZynthApp
        configfile = "/path/to/configfile.yaml"
        app = AiZynthApp(configfile)

    :ivar finder: the finder instance
    :param configfile: the path to yaml file with configuration
    :param setup: if True will create and display the GUI on instantiation, defaults to True
    """

    def __init__(self, configfile: str, setup: bool = True) -> None:
        setup_logger(logging.INFO)
        self.finder = AiZynthFinder(configfile=configfile)
        self._input: StrDict = dict()
        self._output: StrDict = dict()
        self._buttons: StrDict = dict()
        if setup:
            self.setup()

    def setup(self) -> None:
        """
        Create the widgets and display the GUI.
        This is typically done on instantiation, but this method is for more advanced uses.
        """
        self._create_input_widgets()
        self._create_search_widgets()
        self._create_route_widgets()

    def _create_input_widgets(self) -> None:
        self._input["smiles"] = Text(description="SMILES", continuous_update=False)
        self._input["smiles"].observe(self._show_mol, names="value")
        display(self._input["smiles"])
        self._output["smiles"] = Output(
            layout={"border": "1px solid silver", "width": "50%", "height": "180px"}
        )
        display(self._output["smiles"])

        self._input["stocks"] = [
            Checkbox(
                value=True,
                description=key,
                style={"description_width": "initial"},
                layout={"justify": "left"},
            )
            for key in self.finder.stock.items
        ]

        # One checkbox + limit box per atom; the pairs are read back in _prepare_search
        list_ = [Label("Limit atom occurrences")]
        self._input["stocks_atom_count_on"] = []
        self._input["stocks_atom_count"] = []
        current_criteria = self.finder.stock.stop_criteria.get("counts", {})
        for atom in ["C", "O", "N"]:
            chk_box = Checkbox(
                value=atom in current_criteria,
                description=atom,
                layout={"justify": "left", "width": "80px"},
                style={"description_width": "initial"},
            )
            self._input["stocks_atom_count_on"].append(chk_box)
            inpt = BoundedIntText(
                value=current_criteria.get(atom, 0),
                min=0,
                layout={"width": "80px"},
            )
            self._input["stocks_atom_count"].append(inpt)
            list_.append(HBox([chk_box, inpt]))

        box_stocks = VBox(
            [Label("Stocks")] + self._input["stocks"] + list_,
            layout={"border": "1px solid silver"},
        )

        self._input["policy"] = widgets.Dropdown(
            options=self.finder.expansion_policy.items,
            description="Expansion Policy:",
            style={"description_width": "initial"},
        )

        self._input["filter"] = widgets.Dropdown(
            options=["None"] + self.finder.filter_policy.items,
            description="Filter Policy:",
            style={"description_width": "initial"},
        )

        max_time_box = self._make_slider_input("time_limit", "Time (min)", 1, 120)
        self._input["time_limit"].value = self.finder.config.time_limit / 60
        max_iter_box = self._make_slider_input(
            "iteration_limit", "Max Iterations", 100, 2000
        )
        self._input["iteration_limit"].value = self.finder.config.iteration_limit
        self._input["return_first"] = widgets.Checkbox(
            value=self.finder.config.return_first,
            description="Return first solved route",
        )
        vbox = VBox(
            [
                self._input["policy"],
                self._input["filter"],
                max_time_box,
                max_iter_box,
                self._input["return_first"],
            ]
        )
        box_options = HBox([box_stocks, vbox])

        self._input["C"] = FloatText(description="C", value=self.finder.config.C)
        self._input["max_transforms"] = BoundedIntText(
            description="Max steps for substrates",
            min=1,
            max=6,
            value=self.finder.config.max_transforms,
            style={"description_width": "initial"},
        )
        self._input["cutoff_cumulative"] = BoundedFloatText(
            description="Policy cutoff cumulative",
            min=0,
            max=1,
            value=self.finder.config.cutoff_cumulative,
            style={"description_width": "initial"},
        )
        self._input["cutoff_number"] = BoundedIntText(
            description="Policy cutoff number",
            min=1,
            max=1000,
            value=self.finder.config.cutoff_number,
            style={"description_width": "initial"},
        )
        self._input["filter_cutoff"] = BoundedFloatText(
            description="Filter cutoff",
            min=0,
            max=1,
            value=self.finder.config.filter_cutoff,
            style={"description_width": "initial"},
        )
        self._input["exclude_target_from_stock"] = widgets.Checkbox(
            value=self.finder.config.exclude_target_from_stock,
            description="Exclude target from stock",
        )
        box_advanced = VBox(
            [
                self._input["C"],
                self._input["max_transforms"],
                self._input["cutoff_cumulative"],
                self._input["cutoff_number"],
                self._input["filter_cutoff"],
                self._input["exclude_target_from_stock"],
            ]
        )

        children = [box_options, box_advanced]
        tab = widgets.Tab()
        tab.children = children
        tab.set_title(0, "Options")
        tab.set_title(1, "Advanced")
        display(tab)

    def _create_route_widgets(self) -> None:
        self._input["scorer"] = widgets.Dropdown(
            options=self.finder.scorers.names(),
            description="Reorder by:",
            style={"description_width": "initial"},
        )
        self._input["scorer"].observe(self._on_change_scorer)
        self._buttons["show_routes"] = Button(description="Show Reactions")
        self._buttons["show_routes"].on_click(self._on_display_button_clicked)
        self._input["route"] = Dropdown(
            options=[],
            description="Routes: ",
        )
        self._input["route"].observe(self._on_change_route_option)
        display(
            HBox(
                [
                    self._buttons["show_routes"],
                    self._input["route"],
                    self._input["scorer"],
                ]
            )
        )

        self._output["routes"] = widgets.Output(
            layout={"border": "1px solid silver", "width": "99%"}
        )
        display(self._output["routes"])

    def _create_search_widgets(self) -> None:
        self._buttons["execute"] = Button(description="Run Search")
        self._buttons["execute"].on_click(self._on_exec_button_clicked)

        self._buttons["extend"] = widgets.Button(description="Extend Search")
        self._buttons["extend"].on_click(self._on_extend_button_clicked)
        display(HBox([self._buttons["execute"], self._buttons["extend"]]))

        self._output["tree_search"] = widgets.Output(
            layout={
                "border": "1px solid silver",
                "width": "99%",
                "height": "300px",
                "overflow": "auto",
            }
        )
        display(self._output["tree_search"])

    def _make_slider_input(self, label, description, min_val, max_val) -> HBox:
        # An IntSlider and an IntText kept in sync; the text box is stored as
        # self._input[label] so callers can read and set the value
        label_widget = Label(description)
        slider = widgets.IntSlider(
            continuous_update=True, min=min_val, max=max_val, readout=False
        )
        self._input[label] = IntText(continuous_update=True, layout={"width": "80px"})
        widgets.link((self._input[label], "value"), (slider, "value"))
        return HBox([label_widget, slider, self._input[label]])

    def _on_change_route_option(self, change) -> None:
        if change["name"] != "index":
            return
        self._show_route(self._input["route"].index)

    def _on_change_scorer(self, change) -> None:
        if self.finder.routes is None or change["name"] != "index":
            return
        scorer = self.finder.scorers[self._input["scorer"].value]
        self.finder.routes.rescore(scorer)
        self._show_route(self._input["route"].index)

    def _on_exec_button_clicked(self, _) -> None:
        self._toggle_button(False)
        try:
            self._prepare_search()
            self._tree_search()
        finally:
            # Re-enable the buttons even if the search fails, so the GUI is not left locked
            self._toggle_button(True)

    def _on_extend_button_clicked(self, _) -> None:
        self._toggle_button(False)
        try:
            self._tree_search()
        finally:
            self._toggle_button(True)

    def _on_display_button_clicked(self, _) -> None:
        self._toggle_button(False)
        try:
            self.finder.build_routes()
            self.finder.routes.make_images()
            self.finder.routes.compute_scores(*self.finder.scorers.objects())
            self._input["route"].options = [
                f"Option {i}" for i, _ in enumerate(self.finder.routes, 1)  # type: ignore
            ]
            self._show_route(0)
        finally:
            self._toggle_button(True)

    def _prepare_search(self) -> None:
        self._output["tree_search"].clear_output()
        with self._output["tree_search"]:
            selected_stocks = [
                cb.description for cb in self._input["stocks"] if cb.value
            ]
            self.finder.stock.select(selected_stocks)
            atom_count_limits = {}
            for cb_input, value_input in zip(
                self._input["stocks_atom_count_on"], self._input["stocks_atom_count"]
            ):
                if cb_input.value:
                    atom_count_limits[cb_input.description] = value_input.value
            self.finder.stock.set_stop_criteria({"counts": atom_count_limits})
            self.finder.expansion_policy.select(self._input["policy"].value)
            if self._input["filter"].value == "None":
                self.finder.filter_policy.deselect()
            else:
                # Use the filter dropdown here (previously the expansion policy
                # value was passed by mistake)
                self.finder.filter_policy.select(self._input["filter"].value)
            self.finder.config.update(
                **{
                    "C": self._input["C"].value,
                    "max_transforms": self._input["max_transforms"].value,
                    "cutoff_cumulative": self._input["cutoff_cumulative"].value,
                    "cutoff_number": int(self._input["cutoff_number"].value),
                    "return_first": self._input["return_first"].value,
                    # The GUI shows minutes; the config stores seconds
                    "time_limit": self._input["time_limit"].value * 60,
                    "iteration_limit": self._input["iteration_limit"].value,
                    "filter_cutoff": self._input["filter_cutoff"].value,
                    "exclude_target_from_stock": self._input[
                        "exclude_target_from_stock"
                    ].value,
                }
            )

            smiles = self._input["smiles"].value
            print("Setting target molecule with smiles: %s" % smiles)
            self.finder.target_smiles = smiles
            self.finder.prepare_tree()

    def _show_mol(self, change) -> None:
        self._output["smiles"].clear_output()
        with self._output["smiles"]:
            mol = Chem.MolFromSmiles(change["new"])
            display(mol)

    def _show_route(self, index) -> None:
        # Nothing to show before routes have been built, or for an invalid index
        if (
            index is None
            or self.finder.routes is None
            or index >= len(self.finder.routes)
        ):
            return

        route = self.finder.routes[index]
        state = route["node"].state
        status = "Solved" if state.is_solved else "Not Solved"

        self._output["routes"].clear_output()
        with self._output["routes"]:
            display(HTML("<H2>%s" % status))
            table_content = "".join(
                f"<tr><td>{name}</td><td>{score:.4f}</td></tr>"
                for name, score in route["all_score"].items()
            )
            display(HTML(f"<table>{table_content}</table>"))
            display(HTML("<H2>Compounds to Procure"))
            display(state.to_image())
            display(HTML("<H2>Steps"))
            display(self.finder.routes[index]["image"])

    def _toggle_button(self, on) -> None:
        for button in self._buttons.values():
            button.disabled = not on

    def _tree_search(self) -> None:
        with self._output["tree_search"]:
            self.finder.tree_search(show_progress=True)
            print("Tree search completed.")