Example #1
0
def test_reaction_hash(load_reaction_tree):
    """
    Check ``hash_reactions`` on (at most) the first four reactions of the
    tree loaded from ``branched_route.json``, once with the default
    arguments and once with ``sort=False``.
    """
    tree_dict = load_reaction_tree("branched_route.json")
    tree = ReactionTree.from_dict(tree_dict)
    first_reactions = list(tree.reactions())[:4]

    # Default arguments
    default_hash = hash_reactions(first_reactions)
    assert (
        default_hash
        == "359045e74d757c7895304337c855817748b9eefe0e1e680258d4574e"
    )

    # Sorting explicitly switched off gives a different, pinned value
    unsorted_hash = hash_reactions(first_reactions, sort=False)
    assert (
        unsorted_hash
        == "d0cf86e9a5e3a8539964ae62dab51952f64db8c84d750a3cc5b381a6"
    )
Example #2
0
def test_reaction_hash(setup_linear_reaction_tree):
    """
    Check ``hash_reactions`` on (at most) the first four reactions of the
    tree built by the ``setup_linear_reaction_tree`` fixture, once with the
    default arguments and once with ``sort=False``.
    """
    tree = setup_linear_reaction_tree()
    first_reactions = list(tree.reactions())[:4]

    # Default arguments
    default_hash = hash_reactions(first_reactions)
    assert (
        default_hash
        == "4e4ca9d7d2fc47ed3fa43a1dfb9abcd14c58f56ff73942dd1d0b8176"
    )

    # Sorting explicitly switched off gives a different, pinned value
    unsorted_hash = hash_reactions(first_reactions, sort=False)
    assert (
        unsorted_hash
        == "567c23da4673b8b2519aeafda9b26ae949ad3e24f570968ee5f80878"
    )
    def find(reaction_tree):
        """
        Find the repeating patterns and mark the nodes.

        For every eligible starting reaction, the reactions returned by
        ``_list_reactions`` are hashed in consecutive, non-overlapping pairs.
        As long as a pair has the same hash as the pair that follows it, both
        reactions of the pair are hidden and the tree is flagged as having
        repeating patterns. Sequences of fewer than five reactions are ignored.

        :param reaction_tree: the reaction tree to process
        :type reaction_tree: ReactionTree
        """
        for start_reaction in reaction_tree.reactions():
            # We are only interested in starting at the very first reaction,
            # i.e. none of its first-listed reactants may have a non-empty
            # entry in the graph
            skip_reaction = any(
                reaction_tree.graph[reactant]
                for reactant in start_reaction.reactants[0]
            )
            if skip_reaction:
                continue

            chain = _RepeatingPatternIdentifier._list_reactions(
                reaction_tree, start_reaction)
            if len(chain) < 5:
                continue

            # One hash per pair (0, 1), (2, 3), ...; an unpaired last reaction is left out
            pair_hashes = []
            for first_rxn, second_rxn in zip(chain[:-1:2], chain[1::2]):
                pair_hashes.append(
                    hash_reactions([first_rxn, second_rxn], sort=False))

            neighbours = zip(pair_hashes, pair_hashes[1:])
            for pair_idx, (this_hash, next_hash) in enumerate(neighbours):
                # Stopping at the first mismatch prevents removing repeating
                # patterns in the middle of a route
                if this_hash != next_hash:
                    break
                offset = pair_idx * 2
                _RepeatingPatternIdentifier._hide_reaction(
                    reaction_tree, chain[offset])
                _RepeatingPatternIdentifier._hide_reaction(
                    reaction_tree, chain[offset + 1])
                reaction_tree.has_repeating_patterns = True
Example #4
0
    def _collect_top_items(
        items: Union[Sequence[MctsNode], Sequence[ReactionTree]],
        scores: Sequence[float],
        reactions: Sequence[Union[Iterable[RetroReaction],
                                  Iterable[FixedRetroReaction]]],
        min_return: int,
        max_return: int = None,
    ) -> Tuple[Union[Sequence[MctsNode], Sequence[ReactionTree]],
               Sequence[float]]:
        """
        Pick the leading items, keeping only one item per reaction hash.

        The three sequences are walked in parallel. An item is kept unless the
        hash of its reactions has been seen already. Once ``min_return`` items
        are kept, the walk stops at the first score that is lower than the
        score of the last kept item, so items tying with that score are still
        included. The walk also stops when ``max_return`` items are kept.

        NOTE(review): the stop criterion presumably expects the inputs to be
        ordered by decreasing score - confirm against the callers.

        :param items: the candidate items
        :param scores: the score of each item
        :param reactions: the reactions of each item, used for the hash
        :param min_return: the minimum number of items to keep
        :param max_return: the maximum number of items to keep, no limit if falsy
        :return: the kept items and their scores; the inputs are returned
                 as they are if there are no more than ``min_return`` items
        """
        # Too few candidates to make a selection: hand everything back untouched
        if len(items) <= min_return:
            return items, scores

        kept_items: List[Any] = []
        kept_scores = []
        known_hashes = set()
        previous_score = 1e16
        for score, item, actions in zip(scores, items, reactions):
            have_enough = len(kept_items) >= min_return
            if have_enough and score < previous_score:
                break

            key = hash_reactions(actions)
            if key not in known_hashes:
                known_hashes.add(key)
                kept_items.append(item)
                kept_scores.append(score)
                previous_score = score

                # The cap is only checked right after an item has been added
                if max_return and len(kept_items) == max_return:
                    break

        return kept_items, kept_scores
Example #5
0
    def sort_nodes(self, min_return=5, max_return=25):
        """
        Sort the nodes with the scorer and select the best scoring routes.

        Routes with a reaction hash that has already been seen are filtered
        away. At least ``min_return`` routes are returned when available, and
        routes that have the same score as the last selected one are included
        as well, up to ``max_return`` routes. If there are no more than
        ``min_return`` nodes, the sorted nodes are returned without filtering.

        :param min_return: the minimum number of routes to return, defaults to 5
        :type min_return: int, optional
        :param max_return: the maximum number of routes to return, defaults to 25;
                           a falsy value disables the limit
        :type max_return: int, optional
        :return: the selected nodes and their scores
        :rtype: tuple of (list of Node, list of float)
        """
        all_nodes = self._all_nodes()
        ranked_nodes, ranked_scores = self.scorer.sort(all_nodes)

        # Too few candidates to make a selection: return the sorted result as is
        if len(all_nodes) <= min_return:
            return ranked_nodes, ranked_scores

        kept_nodes = []
        kept_scores = []
        known_hashes = set()
        previous_score = 1e16
        for score, node in zip(ranked_scores, ranked_nodes):
            have_enough = len(kept_nodes) >= min_return
            if have_enough and score < previous_score:
                break

            # Only the actions of the route are needed for the hash
            actions, _ = self.search_tree.route_to_node(node)
            key = hash_reactions(actions)
            if key not in known_hashes:
                known_hashes.add(key)
                kept_nodes.append(node)
                kept_scores.append(score)
                previous_score = score

                # The cap is only checked right after a node has been added
                if max_return and len(kept_nodes) == max_return:
                    break

        return kept_nodes, kept_scores
    def _collect_top_items(
        items: Union[Sequence[MctsNode], Sequence[ReactionTree]],
        scores: Sequence[float],
        reactions: Sequence[Union[Iterable[RetroReaction],
                                  Iterable[FixedRetroReaction]]],
        selection,
    ) -> Tuple[Union[Sequence[MctsNode], Sequence[ReactionTree]],
               Sequence[float]]:
        """
        Pick the leading items, keeping only one item per reaction hash.

        The limits are read from ``selection.nmin`` and ``selection.nmax``.
        If ``selection.return_all`` is set and at least one item is solved,
        both limits are replaced by the number of solved items.

        The three sequences are walked in parallel. An item is kept unless the
        hash of its reactions has been seen already. Once the minimum number
        of items is kept, the walk stops at the first score that is lower than
        the score of the last kept item, so items tying with that score are
        still included. The walk also stops when the maximum is reached.

        NOTE(review): the stop criterion presumably expects the inputs to be
        ordered by decreasing score - confirm against the callers.

        :param items: the candidate items
        :param scores: the score of each item
        :param reactions: the reactions of each item, used for the hash
        :param selection: object providing ``nmin``, ``nmax`` and ``return_all``
        :return: the kept items and their scores; the inputs are returned
                 as they are if there are no more than ``selection.nmin`` items
        """
        # Too few candidates to make a selection: hand everything back untouched
        if len(items) <= selection.nmin:
            return items, scores

        max_return = selection.nmax
        min_return = selection.nmin
        if selection.return_all:
            solved_count = 0
            for candidate in items:
                solved_count += int(candidate.is_solved)
            # Without any solved item the configured limits stay in force
            if solved_count:
                max_return = min_return = solved_count

        kept_items: List[Any] = []
        kept_scores = []
        known_hashes = set()
        previous_score = 1e16
        for score, item, actions in zip(scores, items, reactions):
            have_enough = len(kept_items) >= min_return
            if have_enough and score < previous_score:
                break

            key = hash_reactions(actions)
            if key not in known_hashes:
                known_hashes.add(key)
                kept_items.append(item)
                kept_scores.append(score)
                previous_score = score

                # The cap is only checked right after an item has been added
                if max_return and len(kept_items) == max_return:
                    break

        return kept_items, kept_scores