def _process_multi_smiles(
    filename: str,
    finder: AiZynthFinder,
    output_name: str,
    do_clustering: bool,
    route_distance_model: str = None,
) -> None:
    """
    Run a tree search for every SMILES in a text file and save the collected
    results as a table in an HDF5 file.

    Each row of the table holds the statistics returned by
    ``finder.extract_statistics()`` for one target, plus the formatted route
    scores ("top_scores") and the route dictionaries ("trees").

    :param filename: path to a text file with one SMILES per line; blank lines are skipped
    :param finder: the finder used for all searches
    :param output_name: path of the HDF5 file; "output.hdf5" is used when empty
    :param do_clustering: if True, ``_do_clustering`` is run on each target's statistics
    :param route_distance_model: forwarded to ``_do_clustering`` as ``model_path``,
        in the same way as ``_process_single_smiles`` does
    """
    output_name = output_name or "output.hdf5"
    with open(filename, "r") as fileobj:
        # Skip blank lines, which would otherwise be searched as an empty target.
        smiles = [line.strip() for line in fileobj if line.strip()]

    results = defaultdict(list)
    for smi in smiles:
        finder.target_smiles = smi
        finder.prepare_tree()
        search_time = finder.tree_search()
        finder.build_routes()
        stats = finder.extract_statistics()

        # ".3f" always gives fixed-point output; a bare ".3" precision switches
        # to scientific notation for values of 100 and above.
        logger().info(f"Done with {smi} in {search_time:.3f} s")
        if do_clustering:
            _do_clustering(
                finder, stats, detailed_results=True, model_path=route_distance_model
            )

        for key, value in stats.items():
            results[key].append(value)
        results["top_scores"].append(
            ", ".join(f"{score:.4f}" for score in finder.routes.scores)
        )
        results["trees"].append(finder.routes.dicts)

    data = pd.DataFrame.from_dict(results)
    with warnings.catch_warnings():  # This will suppress a PerformanceWarning
        warnings.simplefilter("ignore")
        data.to_hdf(output_name, key="table", mode="w")
    logger().info(f"Output saved to {output_name}")
def _process_single_smiles(
    smiles: str,
    finder: AiZynthFinder,
    output_name: str,
    do_clustering: bool,
    route_distance_model: str = None,
) -> None:
    """
    Search routes for a single target molecule and report the outcome.

    The routes found are written to a JSON file (indented by 2), after which
    the route scores and the search statistics are logged.

    :param smiles: SMILES of the target molecule
    :param finder: the finder that performs the search
    :param output_name: JSON file to write the routes to; "trees.json" is used when empty
    :param do_clustering: when True, ``_do_clustering`` is run on the statistics
        before they are logged
    :param route_distance_model: forwarded to ``_do_clustering`` as ``model_path``
    """
    if not output_name:
        output_name = "trees.json"

    # Run the search for the requested target.
    finder.target_smiles = smiles
    finder.prepare_tree()
    finder.tree_search(show_progress=True)
    finder.build_routes()

    # Persist the routes before anything else is reported.
    with open(output_name, "w") as json_file:
        json.dump(finder.routes.dicts, json_file, indent=2)
    logger().info(f"Trees saved to {output_name}")

    score_text = ", ".join(f"{value:.4f}" for value in finder.routes.scores)
    logger().info(f"Scores for best routes: {score_text}")

    statistics = finder.extract_statistics()
    if do_clustering:
        _do_clustering(
            finder,
            statistics,
            detailed_results=False,
            model_path=route_distance_model,
        )

    summary_lines = [
        f"{key.replace('_', ' ')}: {value}" for key, value in statistics.items()
    ]
    logger().info("\n".join(summary_lines))
def process_smiles_list(finder: AiZynthFinder, target_smiles: List[str]) -> List[Dict[str, str]]:
    """
    Run a tree search for each SMILES in turn and collect the search statistics.

    :param finder: the configured finder; it is reused for every SMILES and its
        target is overwritten on each iteration
    :param target_smiles: the SMILES strings to search for
    :return: one statistics dictionary (from ``finder.extract_statistics()``)
        per SMILES, in input order
    """
    stats_list: List[Dict[str, str]] = []
    for smiles in target_smiles:
        finder.target_smiles = smiles
        # Prepare a tree for the new target explicitly, as the other search
        # drivers in this module do. Without it, the search may reuse the
        # previous target's tree (depends on AiZynthFinder internals - confirm).
        finder.prepare_tree()
        finder.tree_search()
        finder.build_routes()
        stats_list.append(finder.extract_statistics())
    return stats_list
def generate_images(smiles, out_dir, config):
    """
    Search routes for one SMILES and save the route images plus combined images.

    A new finder is created from ``config`` with the "zinc" stock and the
    "uspto" expansion and filter policies. Each route image is saved as
    ``<out_dir>/routeNNN.png``. Consecutive groups of three route images are
    then passed to ``concat_images`` to produce ``<out_dir>/result<i>.png``.

    :param smiles: SMILES of the target molecule
    :param out_dir: output directory; created (including parents) if missing
    :param config: path to the configuration file for the finder
    """
    finder = AiZynthFinder(configfile=config)
    finder.stock.select("zinc")
    finder.expansion_policy.select("uspto")
    finder.filter_policy.select("uspto")
    os.makedirs(out_dir, exist_ok=True)

    finder.target_smiles = smiles
    finder.prepare_tree()
    finder.tree_search()
    finder.build_routes()

    # Keep track of the files written by this run instead of globbing the
    # directory, so stale route images from earlier runs are not combined.
    route_files = []
    for n, image in enumerate(finder.routes.images or []):
        path = f"{out_dir}/route{n:03d}.png"
        image.save(path)
        route_files.append(path)

    # Combine 3 route images at a time into a single result image.
    for i, imgs in enumerate(chunked(route_files, 3)):
        concat_images(imgs, f"{out_dir}/result{i}.png")
class AiZynthApp:
    """
    Interface class to be used in a Jupyter Notebook.
    Provides a basic GUI to setup and analyze the tree search.

    Should be instantiated with the path of a yaml file with configuration:

    .. code-block::

        from aizynthfinder.interfaces import AiZynthApp
        configfile = "/path/to/configfile.yaml"
        app = AiZynthApp(configfile)

    :ivar finder: the finder instance
    :vartype finder: AiZynthFinder

    :param configfile: the path to yaml file with configuration
    :type configfile: str
    :param setup: if True will create and display the GUI on instantiation, defaults to True
    :type setup: bool, optional
    """

    def __init__(self, configfile, setup=True):
        setup_logger(logging.INFO)
        self.finder = AiZynthFinder(configfile=configfile)
        # Widgets are registered by name so that the callbacks can look them up.
        self._input = dict()
        self._output = dict()
        self._buttons = dict()
        if setup:
            self.setup()

    def setup(self):
        """
        Create the widgets and display the GUI.
        This is typically done on instantiation, but this method is for more advanced uses.
        """
        self._create_input_widgets()
        self._create_search_widgets()
        self._create_route_widgets()

    def _create_input_widgets(self):
        """Create and display the SMILES input and the search option tabs."""
        self._input["smiles"] = Text(description="SMILES", continuous_update=False)
        self._input["smiles"].observe(self._show_mol, names="value")
        display(self._input["smiles"])
        self._output["smiles"] = Output(
            layout={"border": "1px solid silver", "width": "50%", "height": "180px"}
        )
        display(self._output["smiles"])

        self._input["stocks"] = [
            Checkbox(value=True, description=key, layout={"justify": "left"})
            for key in self.finder.stock.available_stocks()
        ]
        box_stocks = VBox(
            [Label("Stocks")] + self._input["stocks"],
            layout={"border": "1px solid silver"},
        )

        self._input["policy"] = widgets.Dropdown(
            options=self.finder.policy.available_policies(),
            description="Neural Policy:",
            style={"description_width": "initial"},
        )

        # The GUI shows the time limit in minutes; the config stores seconds.
        max_time_box = self._make_slider_input("time_limit", "Time (min)", 1, 120)
        self._input["time_limit"].value = self.finder.config.time_limit / 60
        max_iter_box = self._make_slider_input(
            "iteration_limit", "Max Iterations", 100, 2000
        )
        self._input["iteration_limit"].value = self.finder.config.iteration_limit
        self._input["return_first"] = widgets.Checkbox(
            value=self.finder.config.return_first,
            description="Return first solved route",
        )
        vbox = VBox(
            [
                self._input["policy"],
                max_time_box,
                max_iter_box,
                self._input["return_first"],
            ]
        )
        box_options = HBox([box_stocks, vbox])

        self._input["C"] = FloatText(description="C", value=self.finder.config.C)
        self._input["max_transforms"] = BoundedIntText(
            description="Max steps for substrates",
            min=1,
            max=6,
            value=self.finder.config.max_transforms,
            style={"description_width": "initial"},
        )
        self._input["cutoff_cumulative"] = BoundedFloatText(
            description="Policy cutoff cumulative",
            min=0,
            max=1,
            value=self.finder.config.cutoff_cumulative,
            style={"description_width": "initial"},
        )
        self._input["cutoff_number"] = BoundedIntText(
            description="Policy cutoff number",
            min=1,
            max=1000,
            value=self.finder.config.cutoff_number,
            style={"description_width": "initial"},
        )
        self._input["exclude_target_from_stock"] = widgets.Checkbox(
            value=self.finder.config.exclude_target_from_stock,
            description="Exclude target from stock",
        )
        box_advanced = VBox(
            [
                self._input["C"],
                self._input["max_transforms"],
                self._input["cutoff_cumulative"],
                self._input["cutoff_number"],
                self._input["exclude_target_from_stock"],
            ]
        )

        children = [box_options, box_advanced]
        tab = widgets.Tab()
        tab.children = children
        tab.set_title(0, "Options")
        tab.set_title(1, "Advanced")
        display(tab)

    def _create_route_widgets(self):
        """Create and display the route selection controls and route output."""
        self._buttons["show_routes"] = Button(description="Show Reactions")
        self._buttons["show_routes"].on_click(self._on_display_button_clicked)
        self._input["route"] = Dropdown(options=[], description="Routes: ",)
        self._input["route"].observe(self._on_change_route_option)
        display(HBox([self._buttons["show_routes"], self._input["route"]]))
        self._output["routes"] = widgets.Output(
            layout={"border": "1px solid silver", "width": "99%"}
        )
        display(self._output["routes"])

    def _create_search_widgets(self):
        """Create and display the search buttons and the search log output."""
        self._buttons["execute"] = Button(description="Run Search")
        self._buttons["execute"].on_click(self._on_exec_button_clicked)
        self._buttons["extend"] = widgets.Button(description="Extend Search")
        self._buttons["extend"].on_click(self._on_extend_button_clicked)
        display(HBox([self._buttons["execute"], self._buttons["extend"]]))
        self._output["tree_search"] = widgets.Output(
            layout={
                "border": "1px solid silver",
                "width": "99%",
                "height": "300px",
                "overflow": "auto",
            }
        )
        display(self._output["tree_search"])

    def _make_slider_input(self, label, description, min_val, max_val):
        """Create an integer slider linked to a text box registered as ``self._input[label]``."""
        label_widget = Label(description)
        slider = widgets.IntSlider(
            continuous_update=True, min=min_val, max=max_val, readout=False
        )
        self._input[label] = IntText(continuous_update=True, layout={"width": "80px"})
        widgets.link((self._input[label], "value"), (slider, "value"))
        return HBox([label_widget, slider, self._input[label]])

    def _on_change_route_option(self, change):
        # The observer receives every trait change; only react to selection changes.
        if change["name"] != "index":
            return
        self._show_route(self._input["route"].index)

    def _on_exec_button_clicked(self, _):
        self._toggle_button(False)
        try:
            self._prepare_search()
            self._tree_search()
        finally:
            # Re-enable the buttons even when the search raises; otherwise an
            # error leaves every button disabled.
            self._toggle_button(True)

    def _on_extend_button_clicked(self, _):
        self._toggle_button(False)
        try:
            self._tree_search()
        finally:
            self._toggle_button(True)

    def _on_display_button_clicked(self, _):
        self._toggle_button(False)
        try:
            self.finder.build_routes()
            self.finder.routes.make_images()
            self._input["route"].options = [
                f"Option {i}" for i, _ in enumerate(self.finder.routes, 1)
            ]
            self._show_route(0)
        finally:
            self._toggle_button(True)

    def _prepare_search(self):
        """Copy the GUI settings to the finder and prepare the search tree."""
        self._output["tree_search"].clear_output()
        with self._output["tree_search"]:
            selected_stocks = [
                cb.description for cb in self._input["stocks"] if cb.value
            ]
            self.finder.stock.select_stocks(selected_stocks)
            self.finder.policy.select_policy(self._input["policy"].value)
            self.finder.config.update(
                **{
                    "C": self._input["C"].value,
                    "max_transforms": self._input["max_transforms"].value,
                    "cutoff_cumulative": self._input["cutoff_cumulative"].value,
                    "cutoff_number": int(self._input["cutoff_number"].value),
                    "return_first": self._input["return_first"].value,
                    # Minutes in the GUI, seconds in the configuration.
                    "time_limit": self._input["time_limit"].value * 60,
                    "iteration_limit": self._input["iteration_limit"].value,
                    "exclude_target_from_stock": self._input[
                        "exclude_target_from_stock"
                    ].value,
                }
            )
            smiles = self._input["smiles"].value
            print("Setting target molecule with smiles: %s" % smiles)
            self.finder.target_smiles = smiles
            self.finder.prepare_tree()

    def _show_mol(self, change):
        """Render the molecule for the newly entered SMILES."""
        self._output["smiles"].clear_output()
        with self._output["smiles"]:
            mol = Chem.MolFromSmiles(change["new"])
            display(mol)

    def _show_route(self, index):
        """Display status, score, starting materials and image of route ``index``."""
        if (
            index is None
            or self.finder.routes is None
            or index >= len(self.finder.routes)
        ):
            return

        node = self.finder.routes[index]["node"]
        state = node.state
        status = "Solved" if state.is_solved else "Not Solved"

        self._output["routes"].clear_output()
        with self._output["routes"]:
            display(HTML("<H2>%s" % status))
            display("Route Score: %0.3F" % state.score)
            display(HTML("<H2>Compounds to Procure"))
            display(state.to_image())
            display(HTML("<H2>Steps"))
            display(self.finder.routes[index]["image"])

    def _toggle_button(self, on):
        """Enable (``on=True``) or disable all registered buttons."""
        for button in self._buttons.values():
            button.disabled = not on

    def _tree_search(self):
        """Run the tree search, sending its progress output to the search log widget."""
        with self._output["tree_search"]:
            self.finder.tree_search(show_progress=True)
            print("Tree search completed.")
class AiZynthApp:
    """
    Interface class to be used in a Jupyter Notebook.
    Provides a basic GUI to setup and analyze the tree search.

    Should be instantiated with the path of a yaml file with configuration:

    .. code-block::

        from aizynthfinder.interfaces import AiZynthApp
        configfile = "/path/to/configfile.yaml"
        app = AiZynthApp(configfile)

    :ivar finder: the finder instance

    :param configfile: the path to yaml file with configuration
    :param setup: if True will create and display the GUI on instantiation, defaults to True
    """

    def __init__(self, configfile: str, setup: bool = True) -> None:
        setup_logger(logging.INFO)
        self.finder = AiZynthFinder(configfile=configfile)
        # Widgets are registered by name so that the callbacks can look them up.
        self._input: StrDict = dict()
        self._output: StrDict = dict()
        self._buttons: StrDict = dict()
        if setup:
            self.setup()

    def setup(self) -> None:
        """
        Create the widgets and display the GUI.
        This is typically done on instantiation, but this method is for more advanced uses.
        """
        self._create_input_widgets()
        self._create_search_widgets()
        self._create_route_widgets()

    def _create_input_widgets(self) -> None:
        """Create and display the SMILES input and the search option tabs."""
        self._input["smiles"] = Text(description="SMILES", continuous_update=False)
        self._input["smiles"].observe(self._show_mol, names="value")
        display(self._input["smiles"])
        self._output["smiles"] = Output(
            layout={"border": "1px solid silver", "width": "50%", "height": "180px"}
        )
        display(self._output["smiles"])

        self._input["stocks"] = [
            Checkbox(
                value=True,
                description=key,
                style={"description_width": "initial"},
                layout={"justify": "left"},
            )
            for key in self.finder.stock.items
        ]

        # One checkbox + limit pair per atom; pre-filled from the stock's
        # current "counts" stop criteria.
        list_ = [Label("Limit atom occurrences")]
        self._input["stocks_atom_count_on"] = []
        self._input["stocks_atom_count"] = []
        current_criteria = self.finder.stock.stop_criteria.get("counts", {})
        for atom in ["C", "O", "N"]:
            chk_box = Checkbox(
                value=atom in current_criteria,
                description=atom,
                layout={"justify": "left", "width": "80px"},
                style={"description_width": "initial"},
            )
            self._input["stocks_atom_count_on"].append(chk_box)
            inpt = BoundedIntText(
                value=current_criteria.get(atom, 0),
                min=0,
                layout={"width": "80px"},
            )
            self._input["stocks_atom_count"].append(inpt)
            list_.append(HBox([chk_box, inpt]))
        box_stocks = VBox(
            [Label("Stocks")] + self._input["stocks"] + list_,
            layout={"border": "1px solid silver"},
        )

        self._input["policy"] = widgets.Dropdown(
            options=self.finder.expansion_policy.items,
            description="Expansion Policy:",
            style={"description_width": "initial"},
        )

        # "None" is a sentinel option meaning that no filter policy is used.
        self._input["filter"] = widgets.Dropdown(
            options=["None"] + self.finder.filter_policy.items,
            description="Filter Policy:",
            style={"description_width": "initial"},
        )

        # The GUI shows the time limit in minutes; the config stores seconds.
        max_time_box = self._make_slider_input("time_limit", "Time (min)", 1, 120)
        self._input["time_limit"].value = self.finder.config.time_limit / 60
        max_iter_box = self._make_slider_input(
            "iteration_limit", "Max Iterations", 100, 2000
        )
        self._input["iteration_limit"].value = self.finder.config.iteration_limit
        self._input["return_first"] = widgets.Checkbox(
            value=self.finder.config.return_first,
            description="Return first solved route",
        )
        vbox = VBox(
            [
                self._input["policy"],
                self._input["filter"],
                max_time_box,
                max_iter_box,
                self._input["return_first"],
            ]
        )
        box_options = HBox([box_stocks, vbox])

        self._input["C"] = FloatText(description="C", value=self.finder.config.C)
        self._input["max_transforms"] = BoundedIntText(
            description="Max steps for substrates",
            min=1,
            max=6,
            value=self.finder.config.max_transforms,
            style={"description_width": "initial"},
        )
        self._input["cutoff_cumulative"] = BoundedFloatText(
            description="Policy cutoff cumulative",
            min=0,
            max=1,
            value=self.finder.config.cutoff_cumulative,
            style={"description_width": "initial"},
        )
        self._input["cutoff_number"] = BoundedIntText(
            description="Policy cutoff number",
            min=1,
            max=1000,
            value=self.finder.config.cutoff_number,
            style={"description_width": "initial"},
        )
        self._input["filter_cutoff"] = BoundedFloatText(
            description="Filter cutoff",
            min=0,
            max=1,
            value=self.finder.config.filter_cutoff,
            style={"description_width": "initial"},
        )
        self._input["exclude_target_from_stock"] = widgets.Checkbox(
            value=self.finder.config.exclude_target_from_stock,
            description="Exclude target from stock",
        )
        box_advanced = VBox(
            [
                self._input["C"],
                self._input["max_transforms"],
                self._input["cutoff_cumulative"],
                self._input["cutoff_number"],
                self._input["filter_cutoff"],
                self._input["exclude_target_from_stock"],
            ]
        )

        children = [box_options, box_advanced]
        tab = widgets.Tab()
        tab.children = children
        tab.set_title(0, "Options")
        tab.set_title(1, "Advanced")
        display(tab)

    def _create_route_widgets(self) -> None:
        """Create and display route selection, scorer selection and route output."""
        self._input["scorer"] = widgets.Dropdown(
            options=self.finder.scorers.names(),
            description="Reorder by:",
            style={"description_width": "initial"},
        )
        self._input["scorer"].observe(self._on_change_scorer)
        self._buttons["show_routes"] = Button(description="Show Reactions")
        self._buttons["show_routes"].on_click(self._on_display_button_clicked)
        self._input["route"] = Dropdown(
            options=[],
            description="Routes: ",
        )
        self._input["route"].observe(self._on_change_route_option)
        display(
            HBox(
                [
                    self._buttons["show_routes"],
                    self._input["route"],
                    self._input["scorer"],
                ]
            )
        )

        self._output["routes"] = widgets.Output(
            layout={"border": "1px solid silver", "width": "99%"}
        )
        display(self._output["routes"])

    def _create_search_widgets(self) -> None:
        """Create and display the search buttons and the search log output."""
        self._buttons["execute"] = Button(description="Run Search")
        self._buttons["execute"].on_click(self._on_exec_button_clicked)
        self._buttons["extend"] = widgets.Button(description="Extend Search")
        self._buttons["extend"].on_click(self._on_extend_button_clicked)
        display(HBox([self._buttons["execute"], self._buttons["extend"]]))
        self._output["tree_search"] = widgets.Output(
            layout={
                "border": "1px solid silver",
                "width": "99%",
                "height": "300px",
                "overflow": "auto",
            }
        )
        display(self._output["tree_search"])

    def _make_slider_input(self, label, description, min_val, max_val) -> HBox:
        """Create an integer slider linked to a text box registered as ``self._input[label]``."""
        label_widget = Label(description)
        slider = widgets.IntSlider(
            continuous_update=True, min=min_val, max=max_val, readout=False
        )
        self._input[label] = IntText(continuous_update=True, layout={"width": "80px"})
        widgets.link((self._input[label], "value"), (slider, "value"))
        return HBox([label_widget, slider, self._input[label]])

    def _on_change_route_option(self, change) -> None:
        # The observer receives every trait change; only react to selection changes.
        if change["name"] != "index":
            return
        self._show_route(self._input["route"].index)

    def _on_change_scorer(self, change) -> None:
        # Nothing to rescore until routes have been built.
        if self.finder.routes is None or change["name"] != "index":
            return
        scorer = self.finder.scorers[self._input["scorer"].value]
        self.finder.routes.rescore(scorer)
        self._show_route(self._input["route"].index)

    def _on_exec_button_clicked(self, _) -> None:
        self._toggle_button(False)
        try:
            self._prepare_search()
            self._tree_search()
        finally:
            # Re-enable the buttons even when the search raises; otherwise an
            # error leaves every button disabled.
            self._toggle_button(True)

    def _on_extend_button_clicked(self, _) -> None:
        self._toggle_button(False)
        try:
            self._tree_search()
        finally:
            self._toggle_button(True)

    def _on_display_button_clicked(self, _) -> None:
        self._toggle_button(False)
        try:
            self.finder.build_routes()
            self.finder.routes.make_images()
            self.finder.routes.compute_scores(*self.finder.scorers.objects())
            self._input["route"].options = [
                f"Option {i}" for i, _ in enumerate(self.finder.routes, 1)  # type: ignore
            ]
            self._show_route(0)
        finally:
            self._toggle_button(True)

    def _prepare_search(self) -> None:
        """Copy the GUI settings to the finder and prepare the search tree."""
        self._output["tree_search"].clear_output()
        with self._output["tree_search"]:
            selected_stocks = [
                cb.description for cb in self._input["stocks"] if cb.value
            ]
            self.finder.stock.select(selected_stocks)

            # Only atoms whose checkbox is ticked get an occurrence limit.
            atom_count_limits = {}
            for cb_input, value_input in zip(
                self._input["stocks_atom_count_on"], self._input["stocks_atom_count"]
            ):
                if cb_input.value:
                    atom_count_limits[cb_input.description] = value_input.value
            self.finder.stock.set_stop_criteria({"counts": atom_count_limits})

            self.finder.expansion_policy.select(self._input["policy"].value)
            # The filter policy is chosen from its own dropdown, not the
            # expansion policy dropdown.
            if self._input["filter"].value == "None":
                self.finder.filter_policy.deselect()
            else:
                self.finder.filter_policy.select(self._input["filter"].value)

            self.finder.config.update(
                **{
                    "C": self._input["C"].value,
                    "max_transforms": self._input["max_transforms"].value,
                    "cutoff_cumulative": self._input["cutoff_cumulative"].value,
                    "cutoff_number": int(self._input["cutoff_number"].value),
                    "return_first": self._input["return_first"].value,
                    # Minutes in the GUI, seconds in the configuration.
                    "time_limit": self._input["time_limit"].value * 60,
                    "iteration_limit": self._input["iteration_limit"].value,
                    "filter_cutoff": self._input["filter_cutoff"].value,
                    "exclude_target_from_stock": self._input[
                        "exclude_target_from_stock"
                    ].value,
                }
            )

            smiles = self._input["smiles"].value
            print("Setting target molecule with smiles: %s" % smiles)
            self.finder.target_smiles = smiles
            self.finder.prepare_tree()

    def _show_mol(self, change) -> None:
        """Render the molecule for the newly entered SMILES."""
        self._output["smiles"].clear_output()
        with self._output["smiles"]:
            mol = Chem.MolFromSmiles(change["new"])
            display(mol)

    def _show_route(self, index) -> None:
        """Display status, all scores, starting materials and image of route ``index``."""
        if (
            index is None
            or self.finder.routes is None
            or index >= len(self.finder.routes)
        ):
            return

        route = self.finder.routes[index]
        state = route["node"].state
        status = "Solved" if state.is_solved else "Not Solved"

        self._output["routes"].clear_output()
        with self._output["routes"]:
            display(HTML("<H2>%s" % status))
            table_content = "".join(
                f"<tr><td>{name}</td><td>{score:.4f}</td></tr>"
                for name, score in route["all_score"].items()
            )
            display(HTML(f"<table>{table_content}</table>"))
            display(HTML("<H2>Compounds to Procure"))
            display(state.to_image())
            display(HTML("<H2>Steps"))
            display(self.finder.routes[index]["image"])

    def _toggle_button(self, on) -> None:
        """Enable (``on=True``) or disable all registered buttons."""
        for button in self._buttons.values():
            button.disabled = not on

    def _tree_search(self) -> None:
        """Run the tree search, sending its progress output to the search log widget."""
        with self._output["tree_search"]:
            self.finder.tree_search(show_progress=True)
            print("Tree search completed.")