def __init__(self, filename, config):
    """
    Create the tree by deserializing it from a JSON file.

    The file is expected to hold a dictionary with a ``"molecules"`` entry,
    which is handed to a ``MoleculeDeserializer``, and a ``"tree"`` entry,
    which is handed to ``AndOrNode`` to build the root.

    :param filename: the path to the JSON file
    :param config: the configuration, forwarded to the parent class and the root node
    """
    super().__init__(config)
    self._mol_nodes = []
    # JSON text is UTF-8 by specification; state it explicitly so that the
    # result does not depend on the locale encoding of the platform.
    with open(filename, "r", encoding="utf-8") as fileobj:
        dict_ = json.load(fileobj)
    mol_deser = MoleculeDeserializer(dict_["molecules"])
    # The tree itself is passed last to the root node constructor.
    self.root = AndOrNode(dict_["tree"], config, mol_deser, self)
def test_deserialize_node(generate_root, simple_actions, mock_policy, default_config):
    """Round-trip an expanded root node through serialization and compare with the original."""
    mol_serializer = MoleculeSerializer()
    root = generate_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    # The returned lists are not inspected here; the call itself is kept.
    # NOTE(review): presumably it primes the mocked policy used by expand() -- confirm against the fixture.
    action_list, prior_list = mock_policy(root.state.mols[0])
    root.expand()
    child = root.promising_child()
    serialized_root = root.serialize(mol_serializer)

    mol_deserializer = MoleculeDeserializer(mol_serializer.store)
    root_new = Node.from_dict(serialized_root, None, default_config, mol_deserializer)

    assert len(root_new.children()) == 1
    new_child = root_new.children()[0]

    view_keys = ("values", "priors", "visitations")

    # The restored root must expose the same child statistics and stay expanded.
    for key in view_keys:
        assert root_new.children_view()[key] == root.children_view()[key]
    assert root_new.is_expanded

    # The restored child must match the original child and stay unexpanded.
    for key in view_keys:
        assert new_child.children_view()[key] == child.children_view()[key]
    assert not new_child.is_expanded

    for restored, original in ((root_new, root), (new_child, child)):
        assert str(restored.state) == str(original.state)
def test_chaining():
    """Serializing a child molecule must also make its parent available after deserialization."""
    mol_serializer = MoleculeSerializer()
    parent_mol = TreeMolecule(parent=None, smiles="CCC", transform=1)
    child_mol = TreeMolecule(smiles="CCO", parent=parent_mol)
    parent_key = id(parent_mol)

    # Only the child is serialized explicitly.
    child_key = serializer_key = mol_serializer[child_mol]

    mol_deserializer = MoleculeDeserializer(mol_serializer.store)

    assert mol_deserializer[child_key].smiles == child_mol.smiles
    assert mol_deserializer[parent_key].smiles == parent_mol.smiles
    # The deserialized objects must not have the identities used as store keys.
    assert id(mol_deserializer[serializer_key]) != child_key
    assert id(mol_deserializer[parent_key]) != parent_key
def from_json(cls, filename: str, config: Configuration) -> "SearchTree":
    """
    Create a new search tree by deserialization from a JSON file

    The file is expected to hold a dictionary with a ``"molecules"`` entry,
    used to set up the molecule deserializer, and a ``"tree"`` entry holding
    the serialized root node.

    :param filename: the path to the JSON file
    :param config: the configuration of the search
    :return: a deserialized tree
    """
    tree = SearchTree(config)
    # JSON text is UTF-8 by specification; state it explicitly so that the
    # result does not depend on the locale encoding of the platform.
    with open(filename, "r", encoding="utf-8") as fileobj:
        dict_ = json.load(fileobj)
    mol_deser = MoleculeDeserializer(dict_["molecules"])
    tree.root = Node.from_dict(dict_["tree"], tree, config, mol_deser)
    return tree
def from_dict(
    cls,
    dict_: StrDict,
    config: Configuration,
    molecules: MoleculeDeserializer,
) -> "State":
    """
    Deserialize a state from its dictionary representation

    The ``"mols"`` entry of the dictionary is resolved to tree molecules
    through the given molecule deserializer.

    :param dict_: the serialized state
    :type dict_: dict
    :param config: settings of the tree search algorithm
    :type config: Configuration
    :param molecules: the deserialized molecules
    :type molecules: MoleculeDeserializer
    :return: a deserialized state
    :rtype: State
    """
    return State(molecules.get_tree_molecules(dict_["mols"]), config)
def test_serialize_deserialize_state(default_config):
    """A state with a single molecule must survive a serialize/deserialize round trip."""
    mol = TreeMolecule(parent=None, smiles="CCC", transform=1)
    original_state = State([mol], default_config)

    # Serialization: the state stores one molecule reference, keyed by id().
    mol_serializer = MoleculeSerializer()
    serialized = original_state.serialize(mol_serializer)
    assert len(serialized["mols"]) == 1
    assert serialized["mols"][0] == id(mol)

    # Deserialization: the restored state must match the original one.
    mol_deserializer = MoleculeDeserializer(mol_serializer.store)
    restored_state = State.from_dict(serialized, default_config, mol_deserializer)
    assert len(restored_state.mols) == 1
    assert restored_state.mols[0] == original_state.mols[0]
    assert restored_state.in_stock_list == original_state.in_stock_list
    assert restored_state.score == original_state.score
def from_dict(
    cls,
    dict_: StrDict,
    tree: SearchTree,
    config: Configuration,
    molecules: MoleculeDeserializer,
    parent: "Node" = None,
) -> "Node":
    """
    Deserialize a node, and recursively its children, from a dictionary

    :param dict_: the serialized node
    :param tree: the search tree
    :param config: settings of the tree search algorithm
    :param molecules: the deserialized molecules
    :param parent: the parent node
    :return: a deserialized node
    """
    node = Node(
        state=State.from_dict(dict_["state"], config, molecules),
        owner=tree,
        config=config,
        parent=parent,
    )

    # Restore the flags and the per-child statistics as stored.
    node.is_expanded = dict_["is_expanded"]
    node.is_expandable = dict_["is_expandable"]
    node._children_values = dict_["children_values"]
    node._children_priors = dict_["children_priors"]
    node._children_visitations = dict_["children_visitations"]

    # Rebuild one reaction per stored action; "metadata" is optional.
    node._children_actions = [
        RetroReaction(
            molecules.get_tree_molecules([action["mol"]])[0],
            action["smarts"],
            action["index"],
            action.get("metadata", {}),
        )
        for action in dict_["children_actions"]
    ]

    # A falsy entry in the serialized children list is restored as None.
    children = []
    for child in dict_["children"]:
        if child:
            children.append(Node.from_dict(child, tree, config, molecules, parent=node))
        else:
            children.append(None)
    node._children = children
    return node
def test_deserialize_tree_mols():
    """Tree molecules must be restored with their smiles, transform and parent link."""
    parent_key = 123
    child_key = 234
    store = {
        parent_key: {
            "smiles": "CCC",
            "class": "TreeMolecule",
            "parent": None,
            "transform": 1,
        },
        child_key: {
            "smiles": "CCO",
            "class": "TreeMolecule",
            "parent": parent_key,
            "transform": 2,
        },
    }

    deserializer = MoleculeDeserializer(store)

    assert deserializer[parent_key].smiles == "CCC"
    assert deserializer[child_key].smiles == "CCO"
    assert deserializer[parent_key].parent is None
    # The parent reference must resolve to the very same deserialized object.
    assert deserializer[child_key].parent is deserializer[parent_key]
    assert deserializer[parent_key].transform == 1
    assert deserializer[child_key].transform == 2
def test_deserialize_single_mol():
    """A store with a single plain molecule must be restored with its smiles."""
    deserializer = MoleculeDeserializer(
        {123: {"smiles": "CCC", "class": "Molecule"}}
    )
    restored = deserializer[123]
    assert restored.smiles == "CCC"