Exemple #1
0
    def apply(self) -> Tuple[Tuple[TreeMolecule, ...], ...]:
        """
        Run the reaction SMARTS on the molecule using RDChiral

        Each outcome from RDChiral is split on "." into separate SMILES, and a
        ``TreeMolecule`` is created with sanitization for each of them. If a
        ``MoleculeException`` is raised for any molecule of an outcome, that
        outcome is left out of the result. A ``RuntimeError`` from RDChiral is
        logged on debug level and treated as if there were no outcomes.

        The result is also stored in ``self._reactants``.

        :return: the products of the reaction (reactants for retro templates)
        """
        template = rdc.rdchiralReaction(self.smarts)
        target = rdc.rdchiralReactants(self.mol.smiles)
        try:
            raw_outcomes = rdc.rdchiralRun(template, target)
        except RuntimeError as err:
            logger().debug(
                f"Runtime error in RDChiral with template {self.smarts} on {self.mol.smiles}\n{err}"
            )
            raw_outcomes = []

        # Convert the RDChiral strings to a tuple of tuples of molecules,
        # to maintain compatibility with the RDKit outcome format
        valid_outcomes = []
        for outcome_smiles in raw_outcomes:
            try:
                molecules = tuple(
                    TreeMolecule(parent=self.mol, smiles=smiles, sanitize=True)
                    for smiles in outcome_smiles.split(".")
                )
            except MoleculeException:
                # At least one molecule could not be created, drop the outcome
                continue
            valid_outcomes.append(molecules)

        self._reactants = tuple(valid_outcomes)
        return self._reactants
Exemple #2
0
def get_mongo_client(
    host: str = "localhost",
    port: int = 27017,
    user: str = None,
    password: str = None,
    tls_certs_path: str = "",
) -> MongoClient:
    """
    A helper function to create and reuse MongoClient

    The client is only setup once. Therefore if this function is called a second
    time with different parameters, it would still return the first client.

    Credentials are only added to the connection URI when a password is given.
    The username and password are percent-escaped before they are inserted
    in the URI, so they should be passed in as plain, unescaped text.

    :param host: the host
    :param port: the host port
    :param user: username, defaults to None
    :param password: password, defaults to None
    :param tls_certs_path: the path to TLS certificates if to be used, defaults to ""
    :return: the MongoDB client
    """
    global _CLIENT
    from urllib.parse import quote_plus

    if _CLIENT is None:
        params = {}
        if tls_certs_path:
            params.update({"ssl": "true", "ssl_ca_certs": tls_certs_path})
        if password:
            # Reserved characters such as "@", ":" or "/" in the credentials
            # would otherwise corrupt the URI
            cred_str = f"{quote_plus(str(user))}:{quote_plus(str(password))}@"
        else:
            cred_str = ""
        uri = f"mongodb://{cred_str}{host}:{port}/?{urlencode(params)}"
        # Only host and port are logged so that credentials do not end up in the log
        logger().debug(f"Connecting to MongoDB on {host}:{port}")
        _CLIENT = MongoClient(uri)
    return _CLIENT
Exemple #3
0
def _process_multi_smiles(filename: str, finder: AiZynthFinder,
                          output_name: str, do_clustering: bool) -> None:
    """
    Run a tree search for every SMILES in a file and save the results

    Each non-blank line of the file is used as a target. For every target
    the statistics extracted from the finder, the scores of the routes
    and the route dictionaries are collected, and the collected data is
    finally written as a table to an HDF5 file.

    :param filename: the path to a file with one SMILES per line
    :param finder: the finder used to perform the searches
    :param output_name: the path of the HDF5 output, "output.hdf5" is used if empty
    :param do_clustering: if True, ``_do_clustering`` is called for each target
    """
    output_name = output_name or "output.hdf5"
    with open(filename, "r") as fileobj:
        stripped_lines = (line.strip() for line in fileobj)
        # Blank lines are skipped so that an empty string is never used as a target
        smiles = [smi for smi in stripped_lines if smi]

    results = defaultdict(list)
    for smi in smiles:
        finder.target_smiles = smi
        finder.prepare_tree()
        search_time = finder.tree_search()
        finder.build_routes()
        stats = finder.extract_statistics()

        logger().info(f"Done with {smi} in {search_time:.3} s")
        if do_clustering:
            _do_clustering(finder, stats, detailed_results=True)
        for key, value in stats.items():
            results[key].append(value)
        results["top_scores"].append(", ".join(
            "%.4f" % score for score in finder.routes.scores))
        results["trees"].append(finder.routes.dicts)

    data = pd.DataFrame.from_dict(results)
    # This will suppress a PerformanceWarning; note that all other warnings
    # raised while writing are ignored as well
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        data.to_hdf(output_name, key="table", mode="w")
    logger().info(f"Output saved to {output_name}")
Exemple #4
0
    def __init__(
        self,
        state: MctsState,
        owner: MctsSearchTree,
        config: Configuration,
        parent: MctsNode = None,
    ):
        """
        Create a node that holds a state and belongs to a search tree

        The node starts out unexpanded and without any children.

        :param state: the state of the node
        :param owner: the tree that the node belongs to
        :param config: the configuration, providing the expansion and filter policies
        :param parent: the parent of the node, defaults to None
        """
        self._logger = logger()
        self._state = state
        self._config = config
        self._expansion_policy = config.expansion_policy
        self._filter_policy = config.filter_policy
        self.tree = owner
        self._parent = parent
        self.is_expanded: bool = False
        # A node with a terminal state cannot be expanded
        self.is_expandable: bool = not self.state.is_terminal

        # Child bookkeeping, all lists are empty until children are added
        self._children: List[Optional[MctsNode]] = []
        self._children_actions: List[RetroReaction] = []
        self._children_priors: List[float] = []
        self._children_values: List[float] = []
        self._children_visitations: List[int] = []

        # InChI keys of the expandable molecules of this state,
        # merged with the blacklist of the parent if there is one
        own_keys = {mol.inchi_key for mol in state.expandable_mols}
        self.blacklist = own_keys.union(parent.blacklist) if parent else own_keys
Exemple #5
0
 def __init__(self, config):
     """
     Create an object with no loaded policies and no selection

     :param config: the configuration, its ``stock`` attribute is stored on the object
     """
     self._logger = logger()
     self._config = config
     self._stock = config.stock
     self._policies = dict()
     # Nothing is selected or loaded at creation
     self._selected_policy = None
     self._policy_model = None
     self._templates = None
 def __init__(self, key: str, config: Configuration, **kwargs: Any) -> None:
     """
     Check the keyword arguments and store the key and the configuration

     :param key: the key of the object
     :param config: the configuration
     :param kwargs: must contain every name listed in ``self._required_kwargs``
     :raises PolicyException: if any of the required keyword arguments is missing
     """
     if not all(name in kwargs for name in self._required_kwargs):
         class_name = self.__class__.__name__
         required = ", ".join(self._required_kwargs)
         raise PolicyException(
             f"A {class_name} class needs to be initiated "
             f"with keyword arguments: {required}")
     self.key = key
     self._config = config
     self._logger = logger()
Exemple #7
0
 def __post_init__(self) -> None:
     """
     Create the collaborating objects of the configuration

     The stock is created before the policies and the scorer collection,
     which all receive this configuration object when they are created.
     """
     self._logger = logger()
     self._properties: StrDict = dict()
     self.stock = Stock()
     self.expansion_policy = ExpansionPolicy(self)
     self.filter_policy = FilterPolicy(self)
     self.scorers = ScorerCollection(self)
     self.molecule_cost = MoleculeCost()
Exemple #8
0
def start_processes(inputs, log_prefix, cmd_callback, poll_freq=5):
    """
    Start a number of background processes and wait for them
    to complete.

    The standard output and standard error is saved to a log file
    named ``{log_prefix}{index}.log``, where the index starts at 1.

    The command to start for each process is given by the ``cmd_callback``
    function that takes the index of the process and an item of the input
    as arguments.

    The log files are closed when the function exits, also when an exception
    is raised while starting or waiting for the processes.

    :param inputs: an iterable of input to the processes
    :type inputs: iterable
    :param log_prefix: the prefix to the log file of each processes
    :type log_prefix: str
    :param cmd_callback: function that creates the process commands
    :type cmd_callback: function
    :param poll_freq: the time in seconds to sleep between checks of whether the processes are completed
    :type poll_freq: int, optional
    """
    import contextlib

    processes = []
    output_fileobjs = []
    # The stack takes ownership of every opened log file and closes them
    # all when the block is left, regardless of how it is left
    with contextlib.ExitStack() as stack:
        for index, iinput in enumerate(inputs, 1):
            fileobj = stack.enter_context(open(f"{log_prefix}{index}.log", "w"))
            output_fileobjs.append(fileobj)
            cmd = cmd_callback(index, iinput)
            process = subprocess.Popen(cmd,
                                       stdout=fileobj,
                                       stderr=subprocess.STDOUT)
            processes.append(process)
            logger().info(f"Started background task with pid={process.pid}")

        logger().info("Waiting for background tasks to complete...")
        not_finished = True
        while not_finished:
            time.sleep(poll_freq)
            not_finished = False
            for process, fileobj in zip(processes, output_fileobjs):
                fileobj.flush()
                # poll() returns None as long as the process is running
                if process.poll() is None:
                    not_finished = True
Exemple #9
0
    def __init__(self):
        """
        Create a configuration from the default settings

        The defaults are read from the ``config.yml`` file in the directory
        given by ``data_path()`` and applied with ``_update_from_config``
        before the stock and the policy are created.
        """
        self._properties = dict()
        default_path = os.path.join(data_path(), "config.yml")
        with open(default_path, "r") as fileobj:
            content = fileobj.read()
        defaults = yaml.load(content, Loader=yaml.SafeLoader)
        self._update_from_config(defaults)

        self.stock = Stock()
        # The policy receives this configuration object
        self.policy = Policy(self)
        self._logger = logger()
Exemple #10
0
    def __post_init__(self) -> None:
        """
        Load the default settings and create the collaborating objects

        The defaults are read from the ``config.yml`` file in the directory
        given by ``data_path()`` and applied with ``_update_from_config``
        before the stock, the policies and the scorer collection are created.
        """
        self._properties: StrDict = dict()
        default_path = os.path.join(data_path(), "config.yml")
        with open(default_path, "r") as fileobj:
            content = fileobj.read()
        defaults = yaml.load(content, Loader=yaml.SafeLoader)
        self._update_from_config(defaults)

        self.stock = Stock()
        # The policies and the scorers receive this configuration object
        self.expansion_policy = ExpansionPolicy(self)
        self.filter_policy = FilterPolicy(self)
        self.scorers = ScorerCollection(self)
        self._logger = logger()
Exemple #11
0
def _process_single_smiles(
    smiles: str,
    finder: AiZynthFinder,
    output_name: str,
    do_clustering: bool,
    route_distance_model: str = None,
) -> None:
    """
    Run a tree search for a single SMILES and report the outcome

    The route dictionaries are saved as JSON, and the scores of the routes
    and the statistics of the search are written to the log.

    :param smiles: the target SMILES
    :param finder: the finder used to perform the search
    :param output_name: the path of the JSON output, "trees.json" is used if empty
    :param do_clustering: if True, ``_do_clustering`` is called before the statistics are logged
    :param route_distance_model: passed as ``model_path`` to ``_do_clustering``, defaults to None
    """
    output_name = output_name or "trees.json"

    finder.target_smiles = smiles
    finder.prepare_tree()
    finder.tree_search(show_progress=True)
    finder.build_routes()

    with open(output_name, "w") as fileobj:
        json.dump(finder.routes.dicts, fileobj, indent=2)
    logger().info(f"Trees saved to {output_name}")

    formatted_scores = ["%.4f" % score for score in finder.routes.scores]
    scores = ", ".join(formatted_scores)
    logger().info(f"Scores for best routes: {scores}")

    stats = finder.extract_statistics()
    if do_clustering:
        _do_clustering(
            finder,
            stats,
            detailed_results=False,
            model_path=route_distance_model,
        )
    # One "name: value" line per statistic, with underscores shown as spaces
    stats_lines = [
        f"{key.replace('_', ' ')}: {value}" for key, value in stats.items()
    ]
    logger().info("\n".join(stats_lines))
    def __init__(self,
                 configfile: str = None,
                 configdict: StrDict = None) -> None:
        """
        Create the object and its configuration

        The configuration is loaded from ``configfile`` if that is given,
        otherwise from ``configdict`` if that is given, and otherwise a
        default ``Configuration`` is created.

        :param configfile: the path to a configuration file, defaults to None
        :param configdict: a dictionary with configuration settings, defaults to None
        """
        self._logger = logger()

        if configfile:
            config = Configuration.from_file(configfile)
        elif configdict:
            config = Configuration.from_dict(configdict)
        else:
            config = Configuration()
        self.config = config

        # Shortcuts to the policies of the configuration
        self.expansion_policy = self.config.expansion_policy
        self.filter_policy = self.config.filter_policy
    def __init__(self, configfile=None):
        """
        Create the object and its configuration

        The configuration is loaded from ``configfile`` if that is given,
        otherwise a default ``Configuration`` is created.

        :param configfile: the path to a configuration file, defaults to None
        """
        self._logger = logger()

        self.config = (
            Configuration.from_file(configfile) if configfile else Configuration()
        )

        # Shortcuts to objects owned by the configuration
        self.policy = self.config.policy
        self.stock = self.config.stock

        # Nothing has been searched yet
        self._target_mol = None
        self.tree = None
        self.routes = None
        self.analysis = None
        self.search_stats = dict()
Exemple #14
0
    def __init__(self, state, owner, config, parent=None):
        """
        Create a node that holds a state and belongs to a search tree

        The node starts out unexpanded and without any children.

        :param state: the state of the node
        :param owner: the tree that the node belongs to
        :param config: the configuration, providing the policy
        :param parent: the parent of the node, defaults to None
        """
        self._logger = logger()
        self.state = state
        self.parent = parent
        self._tree = owner
        self._config = config
        self._policy = config.policy
        self.is_expanded = False
        # A node with a terminal state cannot be expanded
        self.is_expandable = not self.state.is_terminal

        # Child bookkeeping, all lists are empty until children are added
        self._children = list()
        self._children_actions = list()
        self._children_priors = list()
        self._children_values = list()
        self._children_visitations = list()
Exemple #15
0
    def __init__(self, configfile: str = None, configdict: StrDict = None) -> None:
        """
        Create the object and its configuration

        The configuration is loaded from ``configfile`` if that is given,
        otherwise from ``configdict`` if that is given, and otherwise a
        default ``Configuration`` is created.

        :param configfile: the path to a configuration file, defaults to None
        :param configdict: a dictionary with configuration settings, defaults to None
        """
        self._logger = logger()

        if configfile:
            config = Configuration.from_file(configfile)
        elif configdict:
            config = Configuration.from_dict(configdict)
        else:
            config = Configuration()
        self.config = config

        # Shortcuts to objects owned by the configuration
        self.expansion_policy = self.config.expansion_policy
        self.filter_policy = self.config.filter_policy
        self.stock = self.config.stock
        self.scorers = self.config.scorers

        # Nothing has been searched yet, the route collection starts out empty
        self._target_mol: Optional[Molecule] = None
        self.tree: Optional[SearchTree] = None
        self.analysis: Optional[TreeAnalysis] = None
        self.routes = RouteCollection([])
        self.search_stats: StrDict = {}
    def __init__(self, configfile=None, configdict=None):
        """
        Create the object and its configuration

        The configuration is loaded from ``configfile`` if that is given,
        otherwise from ``configdict`` if that is given, and otherwise a
        default ``Configuration`` is created.

        :param configfile: the path to a configuration file, defaults to None
        :param configdict: a dictionary with configuration settings, defaults to None
        """
        self._logger = logger()

        if configfile:
            config = Configuration.from_file(configfile)
        elif configdict:
            config = Configuration.from_dict(configdict)
        else:
            config = Configuration()
        self.config = config

        # Shortcuts to objects owned by the configuration
        self.expansion_policy = self.config.expansion_policy
        self.filter_policy = self.config.filter_policy
        self.stock = self.config.stock
        self.scorers = self.config.scorers

        # Nothing has been searched yet
        self._target_mol = None
        self.tree = None
        self.routes = None
        self.analysis = None
        self.search_stats = dict()
Exemple #17
0
    def __init__(
        self, state: State, owner: SearchTree, config: Configuration, parent=None
    ):
        """
        Create a node that holds a state and belongs to a search tree

        The node starts out unexpanded and without any children.

        :param state: the state of the node
        :param owner: the tree that the node belongs to
        :param config: the configuration, providing the expansion and filter policies
        :param parent: the parent of the node, defaults to None
        """
        self._logger = logger()
        self.state = state
        self.parent = parent
        self.tree = owner
        self._config = config
        self._expansion_policy = config.expansion_policy
        self._filter_policy = config.filter_policy
        self.is_expanded: bool = False
        # A node with a terminal state cannot be expanded
        self.is_expandable: bool = not self.state.is_terminal

        # Child bookkeeping, all lists are empty until children are added
        self._children: List[Optional[Node]] = []
        self._children_actions: List[RetroReaction] = []
        self._children_priors: List[float] = []
        self._children_values: List[float] = []
        self._children_visitations: List[int] = []
    def __init__(self,
                 reaction_tree,
                 content=TreeContent.MOLECULES,
                 exhaustive_limit=20):
        """
        Create the trees for a reaction tree

        If the tree count from ``_inspect_tree`` is not larger than
        ``exhaustive_limit``, the trees are created with ``_create_all_trees``,
        otherwise a single tree is created with ``_create_tree_recursively``.
        No trees are created if there is no root.

        :param reaction_tree: the reaction tree, its ``graph`` attribute is stored
        :param content: the content of the trees, defaults to ``TreeContent.MOLECULES``
        :param exhaustive_limit: the largest tree count for which all trees are created
        """
        self._logger = logger()
        # String input is converted automatically by the TreeContent call
        self._content = TreeContent(content)
        self._graph = reaction_tree.graph
        self._root = self._make_root(reaction_tree)

        self._trees = []
        tree_count, node_index_list = self._inspect_tree()
        self._tree_count = tree_count
        self._node_index_list = node_index_list
        self._enumeration = self._tree_count <= exhaustive_limit

        if self._root:
            if self._enumeration:
                self._create_all_trees()
            else:
                single_tree = self._create_tree_recursively(self._root)
                self._trees.append(single_tree)
Exemple #19
0
def _process_single_smiles(smiles, finder, output_name):
    """
    Run a tree search for a single SMILES and report the outcome

    The route dictionaries are saved as JSON, and the scores of the routes
    and the statistics of the search are written to the log.

    :param smiles: the target SMILES
    :param finder: the finder used to perform the search
    :param output_name: the path of the JSON output, "trees.json" is used if empty
    """
    output_name = output_name or "trees.json"

    finder.target_smiles = smiles
    finder.prepare_tree()
    finder.tree_search(show_progress=True)
    finder.build_routes()

    with open(output_name, "w") as fileobj:
        json.dump(finder.routes.dicts, fileobj, indent=2)
    logger().info(f"Trees saved to {output_name}")

    formatted_scores = ["%.4f" % score for score in finder.routes.scores]
    scores = ", ".join(formatted_scores)
    logger().info(f"Scores for best routes: {scores}")

    stats = finder.extract_statistics()
    # One "name: value" line per statistic, with underscores shown as spaces
    stats_lines = [
        f"{key.replace('_', ' ')}: {value}" for key, value in stats.items()
    ]
    logger().info("\n".join(stats_lines))
 def __init__(self):
     """
     Create a collection with no items and an empty selection
     """
     self._logger = logger()
     self._items: StrDict = dict()
     self._selection: List[str] = list()
Exemple #21
0
 def __init__(self):
     """
     Create a collection with no items and an empty selection
     """
     self._logger = logger()
     self._items = dict()
     self._selection = list()
Exemple #22
0
    prediction_service_pb2_grpc,
)
from tensorflow.keras.metrics import top_k_categorical_accuracy
from tensorflow.keras.models import load_model as load_keras_model

from aizynthfinder.utils.logging import logger

# Top-k accuracy metrics with k fixed. A partial object has no __name__
# of its own, so one is assigned explicitly.
top10_acc = functools.partial(top_k_categorical_accuracy, k=10)
top10_acc.__name__ = "top10_acc"
top50_acc = functools.partial(top_k_categorical_accuracy, k=50)
top50_acc.__name__ = "top50_acc"

# Mapping from name to object for the metrics above and the tensorflow module
CUSTOM_OBJECTS = dict(top10_acc=top10_acc, top50_acc=top50_acc, tf=tf)

_logger = logger()

# Settings read from the environment, each one is None if the variable is not set
_environment = os.environ
TF_SERVING_HOST = _environment.get("TF_SERVING_HOST")
TF_SERVING_REST_PORT = _environment.get("TF_SERVING_REST_PORT")
TF_SERVING_GRPC_PORT = _environment.get("TF_SERVING_GRPC_PORT")


def load_model(source, key, use_remote_models):
    """
    Load model from a configuration specification.

    If `use_remote_models` is True, tries to load:
      1. A Tensorflow server through gRPC
      2. A Tensorflow server through REST API
      3. A local model
    otherwise it just loads the local model