Ejemplo n.º 1
0
def test_find_repetetive_patterns_created_tree_no_patterns(
        default_config, mock_stock, shared_datadir):
    """Reaction trees built from non-repetitive search trees are not flagged.

    Two stored search trees are loaded one after the other. For each, the
    reaction tree created from the analysis must have a falsy
    ``has_repeating_patterns``. For the short tree it is additionally
    checked that no graph node carries a truthy ``"hide"`` attribute.
    """
    mock_stock(default_config, Molecule(smiles="CC"), Molecule(smiles="CCCO"))

    # Short tree first (3 nodes, 1 reaction)
    short_path = shared_datadir / "tree_without_repetition.json"
    short_tree = SearchTree.from_json(short_path, default_config)
    rt = ReactionTree.from_analysis(TreeAnalysis(short_tree))

    assert not rt.has_repeating_patterns
    hidden = []
    for node in rt.graph:
        if rt.graph.nodes[node].get("hide", False):
            hidden.append(node)
    assert hidden == []

    # Then a longer tree
    longer_path = shared_datadir / "tree_without_repetition_longer.json"
    longer_tree = SearchTree.from_json(longer_path, default_config)
    rt = ReactionTree.from_analysis(TreeAnalysis(longer_tree))

    assert not rt.has_repeating_patterns
Ejemplo n.º 2
0
def test_find_repetetive_patterns_created_tree(default_config, mock_stock,
                                               shared_datadir):
    """Reaction trees built from repetitive search trees are flagged.

    For each stored tree, the reaction tree created from the analysis must
    report ``has_repeating_patterns`` and the number of graph nodes with a
    truthy ``"hide"`` attribute must equal the expected count.
    """
    mock_stock(default_config, Molecule(smiles="CC"), Molecule(smiles="C"))

    # (tree file, expected number of hidden nodes), checked in this order:
    # first the tree with 2 repetitive units, then the one with 3.
    scenarios = (
        ("tree_with_repetition.json", 5),
        ("tree_with_3_repetitions.json", 10),
    )
    for tree_file, expected_hidden in scenarios:
        search_tree = SearchTree.from_json(
            shared_datadir / tree_file, default_config)
        rt = ReactionTree.from_analysis(TreeAnalysis(search_tree))

        assert rt.has_repeating_patterns
        hidden = [
            item for item in rt.graph
            if rt.graph.nodes[item].get("hide", False)
        ]
        assert len(hidden) == expected_hidden
Ejemplo n.º 3
0
def test_template_occurence_scorer_no_metadata(shared_datadir, default_config):
    """The average-template-occurrence score of the second graph node is 0.

    The tree is used exactly as loaded; this test itself does not set any
    ``library_occurence`` metadata on the actions.
    """
    tree_path = shared_datadir / "tree_without_repetition.json"
    graph_nodes = list(
        SearchTree.from_json(tree_path, default_config).graph())
    score = AverageTemplateOccurenceScorer()(graph_nodes[1])

    assert score == 0
Ejemplo n.º 4
0
def test_number_of_reaction_scorer_node(shared_datadir, default_config):
    """The number-of-reactions score of the second graph node is 1."""
    tree_path = shared_datadir / "tree_without_repetition.json"
    graph_nodes = list(
        SearchTree.from_json(tree_path, default_config).graph())
    score = NumberOfReactionsScorer()(graph_nodes[1])

    assert score == 1
Ejemplo n.º 5
0
def test_scoring_branch_mcts_tree_in_stock(shared_datadir, default_config,
                                           mock_stock):
    """Score the last node of the branching tree with eight stock molecules.

    Eight SMILES strings are registered through ``mock_stock`` before the
    tree is loaded; each scorer is then applied to the last node of the
    search-tree graph and compared against a fixed expected value.
    """
    in_stock = (
        "CC(C)(C)CO",
        "CC(C)(C)OC(=O)N(CCCl)CCCl",
        "N#CCc1cccc(O)c1F",
        "O=[N+]([O-])c1ccccc1F",
        "O=C1CCC(=O)N1Br",
        "O=C=Nc1csc(C(F)(F)F)n1",
        "CCC[Sn](Cl)(CCC)CCC",
        "COc1ccc2ncsc2c1",
    )
    mock_stock(default_config, *in_stock)
    tree_path = shared_datadir / "tree_with_branching.json"
    search_tree = SearchTree.from_json(tree_path, default_config)
    leaf = list(search_tree.graph())[-1]

    state_score = StateScorer(default_config)(leaf)
    assert pytest.approx(state_score, abs=1e-3) == 0.950
    assert NumberOfReactionsScorer()(leaf) == 14
    assert NumberOfPrecursorsScorer(default_config)(leaf) == 8
    assert NumberOfPrecursorsInStockScorer(default_config)(leaf) == 8
    assert PriceSumScorer(default_config)(leaf) == 8
    cost_score = RouteCostScorer(default_config)(leaf)
    assert pytest.approx(cost_score, abs=1e-3) == 77.4797
Ejemplo n.º 6
0
def test_create_combine_tree_dict_from_tree(mock_stock, default_config,
                                            load_reaction_tree,
                                            shared_datadir):
    """Combine the routes of a collection and compare with a stored tree.

    A route collection is created from the analysis of the clustering tree
    (second argument ``3``), its combined reaction trees are converted to a
    dictionary, and both the structure of the first levels and the complete
    dictionary are checked against ``combined_example_tree.json``.
    """
    in_stock = (
        "Nc1ccc(NC(=S)Nc2ccccc2)cc1",
        "Cc1ccc2nc3ccccc3c(Cl)c2c1",
        "Nc1ccc(N)cc1",
        "S=C=Nc1ccccc1",
        "Cc1ccc2nc3ccccc3c(N)c2c1",
        "Nc1ccc(Br)cc1",
    )
    mock_stock(default_config, *in_stock)
    tree_path = shared_datadir / "tree_for_clustering.json"
    search_tree = SearchTree.from_json(tree_path, default_config)
    collection = RouteCollection.from_analysis(TreeAnalysis(search_tree), 3)
    expected = load_reaction_tree("combined_example_tree.json")

    combined_dict = collection.combined_reaction_trees().to_dict()

    # Walk down the first levels step by step instead of chaining indices.
    top_children = combined_dict["children"]
    assert len(top_children) == 2
    first_child = top_children[0]
    assert first_child["is_reaction"]
    assert len(first_child["children"]) == 2
    second_grandchildren = top_children[1]["children"]
    assert len(second_grandchildren) == 2
    deeper_children = second_grandchildren[0]["children"]
    assert len(deeper_children) == 2
    assert deeper_children[0]["is_reaction"]
    assert combined_dict == expected
Ejemplo n.º 7
0
def test_template_occurence_scorer(shared_datadir, default_config):
    """Setting ``library_occurence`` on one action changes the child's score.

    The metadata of the action stored under ``nodes[0][nodes[1]]`` is given
    ``library_occurence = 5``; the scorer must then return 0 for the first
    graph node and 5 for the second.
    """
    tree_path = shared_datadir / "tree_without_repetition.json"
    graph_nodes = list(
        SearchTree.from_json(tree_path, default_config).graph())
    first, second = graph_nodes[0], graph_nodes[1]
    action = first[second]["action"]
    action.metadata["library_occurence"] = 5
    scorer = AverageTemplateOccurenceScorer()

    assert scorer(first) == 0
    assert scorer(second) == 5
Ejemplo n.º 8
0
def test_scoring_branched_mcts_tree(shared_datadir, default_config):
    """Score the last node of the branching tree without calling mock_stock.

    Each scorer is applied to the last node of the search-tree graph and
    compared against a fixed expected value.
    """
    tree_path = shared_datadir / "tree_with_branching.json"
    search_tree = SearchTree.from_json(tree_path, default_config)
    leaf = list(search_tree.graph())[-1]

    state_score = StateScorer()(leaf)
    assert pytest.approx(state_score, abs=1e-6) == 0.00012363
    assert NumberOfReactionsScorer()(leaf) == 14
    assert NumberOfPrecursorsScorer(default_config)(leaf) == 8
    assert NumberOfPrecursorsInStockScorer(default_config)(leaf) == 0
def test_serialize_deserialize_tree(
    fresh_tree,
    generate_root,
    simple_actions,
    mock_expansion_policy,
    default_config,
    mocker,
    tmpdir,
):
    """Round-trip a small search tree through ``serialize`` and ``from_json``.

    ``json.dump`` and ``json.load`` in ``aizynthfinder.mcts.mcts`` are
    patched. The serialization half checks the payload handed to
    ``json.dump``; the deserialization half feeds the same payload back
    through the patched ``json.load`` and compares the rebuilt root and
    its single child with the originals.
    """
    root = generate_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    fresh_tree.root = root
    # The return value is not needed here; the call is kept because it is
    # presumably what configures the mocked expansion policy used by
    # ``root.expand()`` below.
    mock_expansion_policy(root.state.mols[0])
    root.expand()
    child = root.promising_child()
    mocked_json_dump = mocker.patch("aizynthfinder.mcts.mcts.json.dump")
    serializer = MoleculeSerializer()
    filename = str(tmpdir / "dummy.json")

    # Test serialization

    fresh_tree.serialize(filename)

    expected_dict = {
        "tree": root.serialize(serializer),
        "molecules": serializer.store,
    }
    mocked_json_dump.assert_called_once_with(
        expected_dict, mocker.ANY, indent=mocker.ANY)

    # Test deserialization

    mocker.patch("aizynthfinder.mcts.mcts.json.load",
                 return_value=expected_dict)

    new_tree = SearchTree.from_json(filename, default_config)
    root_new = new_tree.root
    assert len(root_new.children()) == 1

    new_child = root_new.children()[0]
    view_keys = ("values", "priors", "visitations")
    for key in view_keys:
        assert root_new.children_view()[key] == root.children_view()[key]
    assert root_new.is_expanded
    for key in view_keys:
        assert new_child.children_view()[key] == child.children_view()[key]
    assert not new_child.is_expanded
    assert str(root_new.state) == str(root.state)
    assert str(new_child.state) == str(child.state)
Ejemplo n.º 10
0
def test_sort(shared_datadir, default_config, mock_stock):
    """``StateScorer.sort`` returns the second graph node before the first.

    The scores returned alongside the sorted nodes are compared after
    rounding to four decimals.
    """
    mock_stock(default_config, "CCCO", "CC")
    tree_path = shared_datadir / "tree_without_repetition.json"
    search_tree = SearchTree.from_json(tree_path, default_config)
    nodes = list(search_tree.graph())

    sorted_nodes, scores, _ = StateScorer(default_config).sort(nodes)

    rounded_scores = [np.round(value, 4) for value in scores]
    assert rounded_scores == [0.9976, 0.0491]
    assert sorted_nodes == [nodes[1], nodes[0]]
Ejemplo n.º 11
0
def setup_analysis(default_config, shared_datadir, tmpdir, mock_stock):
    """Return a factory producing a ``TreeAnalysis`` of the full search tree.

    The gzipped tree in ``shared_datadir`` is decompressed into ``tmpdir``
    and loaded once. The returned callable accepts an optional ``scorer``
    and returns a ``(TreeAnalysis, nodes)`` tuple, where ``nodes`` is the
    list of graph nodes collected when the tree was loaded.
    """
    mock_stock(
        default_config, "N#Cc1cccc(N)c1F", "O=C(Cl)c1ccc(F)cc1", "CN1CCC(Cl)CC1", "O"
    )

    archive_path = shared_datadir / "full_search_tree.json.gz"
    json_path = tmpdir / "full_search_tree.json"
    # Unpack the archive into a plain JSON file for SearchTree.from_json
    with gzip.open(archive_path, "rb") as packed, \
            open(json_path, "wb") as unpacked:
        unpacked.write(packed.read())
    tree = SearchTree.from_json(json_path, default_config)
    nodes = list(tree.graph())

    def wrapper(scorer=None):
        analysis = TreeAnalysis(tree, scorer=scorer)
        return analysis, nodes

    return wrapper
Ejemplo n.º 12
0
def test_route_node_depth_from_analysis(default_config, mock_stock,
                                        shared_datadir):
    """Check the depths reported by a reaction tree built from an analysis.

    The first three molecules must be at depths 0, 2 and 2, the first
    reaction at depth 1, and every molecule's depth must equal twice the
    ``"transform"`` attribute of its graph node.
    """
    mock_stock(default_config, Molecule(smiles="CC"), Molecule(smiles="CCCO"))
    tree_path = shared_datadir / "tree_without_repetition.json"
    search_tree = SearchTree.from_json(tree_path, default_config)
    rt = ReactionTree.from_analysis(TreeAnalysis(search_tree))

    mols = list(rt.molecules())

    for index, expected_depth in enumerate((0, 2, 2)):
        assert rt.depth(mols[index]) == expected_depth

    first_reaction = list(rt.reactions())[0]
    assert rt.depth(first_reaction) == 1

    for molecule in rt.molecules():
        transform = rt.graph.nodes[molecule]["transform"]
        assert rt.depth(molecule) == 2 * transform