コード例 #1
0
 def wrapper(root_smiles, config):
     """
     Install a root node for ``root_smiles`` on ``mocked_create_root``.

     A ``TreeMolecule`` is built from the SMILES, placed alone in a ``State``,
     and wrapped in a ``Node`` with ``owner=None``; that node becomes
     ``mocked_create_root.return_value``.

     :param root_smiles: the SMILES handed to ``TreeMolecule``
     :param config: the configuration passed to both ``State`` and ``Node``
     :return: the molecule created from ``root_smiles``
     """
     root_mol = TreeMolecule(parent=None, transform=0, smiles=root_smiles)
     root_node = Node(
         state=State(mols=[root_mol], config=config),
         owner=None,
         config=config,
     )
     # mocked_create_root comes from the enclosing scope (presumably a mock
     # patched over the root factory; confirm against the fixture).
     mocked_create_root.return_value = root_node
     return root_mol
コード例 #2
0
ファイル: mcts.py プロジェクト: lkjoutlook/aizynthfinder-1
 def __init__(self, config, root_smiles=None):
     """
     Initialise the tree, creating a root node only when a SMILES is given.

     :param config: kept as ``_config`` and forwarded to ``Node.create_root``
     :param root_smiles: when truthy, the SMILES used to create the root node;
         otherwise ``root`` is set to ``None``
     """
     # The root is assigned before the remaining attributes, as the tree
     # itself is passed to ``Node.create_root``.
     self.root = (
         Node.create_root(smiles=root_smiles, tree=self, config=config)
         if root_smiles
         else None
     )
     self._config = config
     self._graph = None
コード例 #3
0
def test_deserialize_node(generate_root, simple_actions, mock_policy, default_config):
    """
    Check that a node rebuilt with ``Node.from_dict`` matches the serialized one.

    The root is expanded, one child is selected, and the root is serialized;
    the deserialized root and its single child are then compared with the
    originals on their children views, expansion flags and state strings.
    """
    view_keys = ("values", "priors", "visitations")

    mol_serializer = MoleculeSerializer()
    original_root = generate_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    # The returned lists are not used below; the call is kept for whatever
    # setup the fixture performs (presumably priming the mocked policy).
    _actions, _priors = mock_policy(original_root.state.mols[0])
    original_root.expand()
    original_child = original_root.promising_child()
    serialized_root = original_root.serialize(mol_serializer)
    mol_deserializer = MoleculeDeserializer(mol_serializer.store)

    restored_root = Node.from_dict(
        serialized_root, None, default_config, mol_deserializer
    )
    assert len(restored_root.children()) == 1

    restored_child = restored_root.children()[0]

    # Root level: same statistics, and still flagged as expanded.
    for key in view_keys:
        assert restored_root.children_view()[key] == original_root.children_view()[key]
    assert restored_root.is_expanded

    # Child level: same statistics, and not expanded.
    for key in view_keys:
        assert (
            restored_child.children_view()[key] == original_child.children_view()[key]
        )
    assert not restored_child.is_expanded

    assert str(restored_root.state) == str(original_root.state)
    assert str(restored_child.state) == str(original_child.state)
コード例 #4
0
ファイル: mcts.py プロジェクト: naisuu/aizynthfinder
 def __init__(self,
              config: Configuration,
              root_smiles: Optional[str] = None) -> None:
     """
     Initialise the tree, creating a root node only when a SMILES is given.

     :param config: stored as ``config`` and forwarded to ``Node.create_root``
     :param root_smiles: when truthy, the SMILES used to create the root node;
         otherwise ``root`` is set to ``None``
     """
     # The root is assigned before the remaining attributes, as the tree
     # itself is passed to ``Node.create_root``.
     self.root: Optional[Node] = (
         Node.create_root(smiles=root_smiles, tree=self, config=config)
         if root_smiles
         else None
     )
     self.config = config
     self._graph: Optional[nx.DiGraph] = None
コード例 #5
0
ファイル: mcts.py プロジェクト: naisuu/aizynthfinder
    def from_json(cls, filename: str, config: Configuration) -> "SearchTree":
        """
        Create a new search tree by deserialization from a JSON file

        The file is expected to hold an object with a ``"molecules"`` entry,
        used to build the molecule deserializer, and a ``"tree"`` entry,
        from which the root node is rebuilt.

        :param filename: the path to the JSON file
        :param config: the configuration of the search
        :return: a deserialized tree
        """
        tree = SearchTree(config)
        # Decode explicitly as UTF-8 so reading does not depend on the
        # platform's locale encoding.
        with open(filename, "r", encoding="utf-8") as fileobj:
            dict_ = json.load(fileobj)
        mol_deser = MoleculeDeserializer(dict_["molecules"])
        tree.root = Node.from_dict(dict_["tree"], tree, config, mol_deser)
        return tree
コード例 #6
0
 def wrapper(smiles):
     """
     Create a root node for ``smiles`` that is not attached to any tree.

     :param smiles: the SMILES passed to ``Node.create_root``
     :return: the node returned by ``Node.create_root``
     """
     # default_config comes from the enclosing scope.
     root_node = Node.create_root(smiles, tree=None, config=default_config)
     return root_node