Example #1
0
def test_rescore_collection_for_trees(default_config, setup_linear_reaction_tree):
    """
    Rescoring a scored collection switches the primary score to the new scorer.

    The state score computed first must still be available in ``all_scores``
    next to the newly computed number of reactions.
    """
    collection = RouteCollection(reaction_trees=[setup_linear_reaction_tree()])
    collection.compute_scores(StateScorer(default_config))

    collection.rescore(NumberOfReactionsScorer())

    assert collection.scores[0] == 2
    first_route_scores = collection.all_scores[0]
    assert np.round(first_route_scores["state score"], 3) == 0.994
    assert first_route_scores["number of reactions"] == 2
def test_clustering_collection_timeout(load_reaction_tree):
    """With a zero timeout, clustering yields no labels and leaves clusters unset."""
    trees = [
        ReactionTree.from_dict(load_reaction_tree("routes_for_clustering.json", idx))
        for idx in range(3)
    ]
    collection = RouteCollection(reaction_trees=trees)

    labels = collection.cluster(n_clusters=1, timeout=0)

    assert len(labels) == 0
    assert collection.clusters is None
Example #3
0
def test_create_clustering_gui(mocker, load_reaction_tree):
    """Creating a ClusteringGui for a route collection must call ``display``."""
    trees = [
        ReactionTree.from_dict(load_reaction_tree("routes_for_clustering.json", idx))
        for idx in range(3)
    ]
    collection = RouteCollection(reaction_trees=trees)
    # Patch ``display`` where the clustering GUI module looks it up, so the
    # real display function is replaced by a mock we can inspect.
    display_mock = mocker.patch("aizynthfinder.interfaces.gui.clustering.display")

    ClusteringGui(collection)

    display_mock.assert_called()
Example #4
0
def test_rescore_collection_for_trees(default_config, mock_stock, load_reaction_tree):
    """
    Rescore the sample route with the number-of-reactions scorer.

    The stock is mocked with the route's leaf molecules. After rescoring, the
    primary score is the reaction count, and the earlier state score is kept
    in ``all_scores``.
    """
    stock_smiles = ("N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1", "CN1CCC(Cl)CC1", "O")
    mock_stock(default_config, *stock_smiles)
    tree = ReactionTree.from_dict(load_reaction_tree("sample_reaction.json"))
    routes = RouteCollection(reaction_trees=[tree])
    routes.compute_scores(StateScorer(default_config))

    routes.rescore(NumberOfReactionsScorer())

    assert routes.scores[0] == 2
    route_scores = routes.all_scores[0]
    assert np.round(route_scores["state score"], 3) == 0.994
    assert route_scores["number of reactions"] == 2
Example #5
0
def test_compute_new_score_for_trees(default_config, setup_linear_reaction_tree):
    """
    ``compute_scores`` fills ``all_scores`` without setting the primary score.

    A collection built directly from reaction trees starts with no node, a
    NaN primary score and an empty per-scorer dictionary. After computing two
    scores, both appear in ``all_scores``, while the primary score stays NaN.
    """
    rt = setup_linear_reaction_tree()
    routes = RouteCollection(reaction_trees=[rt])

    assert routes.nodes[0] is None
    # NaN is a value, not a singleton: ``x is np.nan`` only holds if that very
    # object was stored, so check the value with np.isnan instead.
    assert np.isnan(routes.scores[0])
    assert routes.all_scores[0] == {}

    routes.compute_scores(StateScorer(default_config), NumberOfReactionsScorer())

    assert np.isnan(routes.scores[0])
    assert np.round(routes.all_scores[0]["state score"], 3) == 0.994
    assert routes.all_scores[0]["number of reactions"] == 2
Example #6
0
def test_clustering_collection(load_reaction_tree):
    """
    Cluster three routes and check the resulting partition.

    The second and third routes must end up in the first cluster, and the
    first route alone in the second cluster.
    """
    trees = [
        ReactionTree.from_dict(load_reaction_tree("routes_for_clustering.json", idx))
        for idx in range(3)
    ]
    collection = RouteCollection(reaction_trees=trees)

    collection.cluster(n_clusters=1)

    clusters = collection.clusters
    all_trees = collection.reaction_trees
    assert len(clusters) == 2
    assert clusters[0].reaction_trees == all_trees[1:3]
    assert clusters[1].reaction_trees == [all_trees[0]]
    def prepare_tree(self) -> None:
        """
        Get the search tree ready for a new search.

        Clears the stock's exclusion list. If configured, the target molecule
        is excluded from the stock when it is present there. The search tree
        is then set up via ``_setup_search_tree``, and any analysis or routes
        left over from a previous search are reset.

        :raises ValueError: if the target molecule was not set
        """
        if not self.target_mol:
            raise ValueError("No target molecule set")

        stock = self.stock
        stock.reset_exclusion_list()

        # The stock lookup only happens when the config flag is truthy
        # (short-circuit ``and``), matching the order of the checks.
        exclude_target = (
            self.config.exclude_target_from_stock and self.target_mol in stock
        )
        if exclude_target:
            stock.exclude(self.target_mol)
            self._logger.debug("Excluding the target compound from the stock")

        self._setup_search_tree()
        # Reset results left over from any previous search
        self.analysis = None
        self.routes = RouteCollection([])
Example #8
0
def test_create_combine_tree_to_visjs(load_reaction_tree, tmpdir):
    """
    Write the combined tree of three routes as a vis.js page to a tar file.

    The archive must contain the ``route.html`` page and eight PNG images.
    """
    trees = [
        ReactionTree.from_dict(load_reaction_tree("routes_for_clustering.json", idx))
        for idx in range(3)
    ]
    collection = RouteCollection(reaction_trees=trees)
    archive_path = str(tmpdir / "routes.tar")
    combined = collection.combined_reaction_trees()

    combined.to_visjs_page(archive_path)

    assert os.path.exists(archive_path)
    with TarFile(archive_path) as archive:
        member_names = archive.getnames()
        assert "./route.html" in member_names
        png_count = sum(1 for name in member_names if name.endswith(".png"))
        assert png_count == 8
Example #9
0
def test_distance_collection(load_reaction_tree):
    """
    Check the route distance matrices of a three-route collection.

    Recomputing with ``recreate=True`` must reproduce the default matrix
    exactly. Computing with ``content="molecules"`` must give a different
    3x3 matrix with known off-diagonal distances.
    """
    collection = RouteCollection(reaction_trees=[
        ReactionTree.from_dict(
            load_reaction_tree("routes_for_clustering.json", idx))
        for idx in range(3)
    ])

    dist_mat1 = collection.distance_matrix()
    dist_mat2 = collection.distance_matrix(recreate=True)

    # Compare element-wise: a zero *sum* of differences can hide entries that
    # differ in opposite directions and cancel each other out.
    assert (dist_mat1 == dist_mat2).all()

    dist_mat3 = collection.distance_matrix(content="molecules")

    # Likewise, "different" must not depend on differences not cancelling.
    assert (dist_mat1 != dist_mat3).any()
    assert dist_mat3.shape == (3, 3)
    # pytest.approx wraps the expected value (documented usage); with abs=1e-2
    # the tolerance is the same as wrapping the actual value.
    assert dist_mat3[0, 1] == pytest.approx(2.6522, abs=1e-2)
    assert dist_mat3[0, 2] == pytest.approx(3.0779, abs=1e-2)
    assert dist_mat3[2, 1] == pytest.approx(0.7483, abs=1e-2)
Example #10
0
def test_create_combine_tree_dict_from_json(load_reaction_tree):
    """
    Serialise the combined tree of three routes and compare it to a reference.

    Spot-checks the branching structure of the dictionary before comparing it
    with ``combined_example_tree.json`` as a whole.
    """
    trees = [
        ReactionTree.from_dict(load_reaction_tree("routes_for_clustering.json", idx))
        for idx in range(3)
    ]
    collection = RouteCollection(reaction_trees=trees)
    expected = load_reaction_tree("combined_example_tree.json")

    combined_dict = collection.combined_reaction_trees().to_dict()

    top_children = combined_dict["children"]
    assert len(top_children) == 2
    assert top_children[0]["is_reaction"]
    assert len(top_children[0]["children"]) == 2
    second_branch = top_children[1]["children"]
    assert len(second_branch) == 2
    nested_children = second_branch[1]["children"]
    assert len(nested_children) == 2
    assert nested_children[0]["is_reaction"]
    assert combined_dict == expected
Example #11
0
    def __init__(self, configfile: str = None, configdict: StrDict = None) -> None:
        """
        Set up the finder from a configuration and initialise its empty state.

        The configuration source is chosen in priority order: ``configfile``
        if truthy, else ``configdict`` if truthy, else a default
        ``Configuration()``.

        :param configfile: path to a configuration file
        :param configdict: configuration given as a dictionary
        """
        self._logger = logger()

        if configfile:
            config = Configuration.from_file(configfile)
        elif configdict:
            config = Configuration.from_dict(configdict)
        else:
            config = Configuration()
        self.config = config

        # These components are taken directly from the configuration object
        self.expansion_policy = config.expansion_policy
        self.filter_policy = config.filter_policy
        self.stock = config.stock
        self.scorers = config.scorers

        # Search state; empty until a target is set and a search is run
        self.tree: Optional[SearchTree] = None
        self._target_mol: Optional[Molecule] = None
        self.search_stats: StrDict = {}
        self.routes = RouteCollection([])
        self.analysis: Optional[TreeAnalysis] = None