示例#1
0
 def __init__(self, filename, config):
     """
     Load a serialized tree from a JSON file.

     The file is expected to hold a ``"molecules"`` entry (fed to a
     ``MoleculeDeserializer``) and a ``"tree"`` entry (the serialized root,
     rebuilt as an ``AndOrNode``).

     :param filename: path to the JSON file
     :param config: settings passed to the base class and to the root node
     """
     super().__init__(config)
     # NOTE(review): populated elsewhere; its exact role is not visible here
     self._mol_nodes = []
     # JSON is UTF-8 by specification (RFC 8259); don't rely on the locale's
     # default encoding, which differs e.g. on Windows.
     with open(filename, "r", encoding="utf-8") as fileobj:
         dict_ = json.load(fileobj)
     mol_deser = MoleculeDeserializer(dict_["molecules"])
     # ``self`` is passed last, presumably as the owning tree of the root node
     self.root = AndOrNode(dict_["tree"], config, mol_deser, self)
def test_deserialize_node(generate_root, simple_actions, mock_policy, default_config):
    """A serialized node round-trips with its child statistics and states intact."""
    serializer = MoleculeSerializer()
    root = generate_root("CCCCOc1ccc(CC(=O)N(C)O)cc1")
    # Called only for its effect on the mocked policy; the returned actions and
    # priors are not needed (presumably the fixture wires them into ``expand``).
    mock_policy(root.state.mols[0])
    root.expand()
    child = root.promising_child()
    node_serialized = root.serialize(serializer)
    deserializer = MoleculeDeserializer(serializer.store)

    root_new = Node.from_dict(node_serialized, None, default_config, deserializer)
    assert len(root_new.children()) == 1

    new_child = root_new.children()[0]
    for key in ("values", "priors", "visitations"):
        assert root_new.children_view()[key] == root.children_view()[key]
    assert root_new.is_expanded
    for key in ("values", "priors", "visitations"):
        assert new_child.children_view()[key] == child.children_view()[key]
    assert not new_child.is_expanded
    assert str(root_new.state) == str(root.state)
    assert str(new_child.state) == str(child.state)
def test_chaining():
    """Serializing a child TreeMolecule also restores its parent and the link."""
    serializer = MoleculeSerializer()
    mol1 = TreeMolecule(parent=None, smiles="CCC", transform=1)
    mol2 = TreeMolecule(smiles="CCO", parent=mol1)

    # Only the child is serialized explicitly; its parent must be stored too
    id_ = serializer[mol2]

    deserializer = MoleculeDeserializer(serializer.store)
    new_mol2 = deserializer[id_]
    new_mol1 = deserializer[id(mol1)]

    assert new_mol2.smiles == mol2.smiles
    assert new_mol1.smiles == mol1.smiles
    # Deserialization must produce fresh objects, not hand back the originals
    assert new_mol2 is not mol2
    assert new_mol1 is not mol1
    # The chain itself: the restored child points at the restored parent
    assert new_mol2.parent is new_mol1
示例#4
0
    def from_json(cls, filename: str, config: Configuration) -> "SearchTree":
        """
        Create a new search tree by deserialization from a JSON file

        The file must contain a ``"molecules"`` entry (used to build a
        ``MoleculeDeserializer``) and a ``"tree"`` entry (the serialized root node).

        :param filename: the path to the JSON file
        :param config: the configuration of the search
        :return: a deserialized tree
        """
        # NOTE(review): always builds a base ``SearchTree`` rather than ``cls``
        tree = SearchTree(config)
        # JSON is UTF-8 by specification (RFC 8259); don't rely on the locale's
        # default encoding, which differs e.g. on Windows.
        with open(filename, "r", encoding="utf-8") as fileobj:
            dict_ = json.load(fileobj)
        mol_deser = MoleculeDeserializer(dict_["molecules"])
        tree.root = Node.from_dict(dict_["tree"], tree, config, mol_deser)
        return tree
示例#5
0
    def from_dict(cls, dict_: StrDict, config: Configuration,
                  molecules: MoleculeDeserializer) -> "State":
        """
        Rebuild a state from its serialized dictionary

        Only the ``"mols"`` entry is read: it lists molecule ids that are
        resolved to tree molecules through ``molecules``.

        :param dict_: the serialized state
        :param config: settings of the tree search algorithm
        :param molecules: the deserialized molecules
        :return: a new, deserialized state
        """
        return State(molecules.get_tree_molecules(dict_["mols"]), config)
def test_serialize_deserialize_state(default_config):
    """Round-trip a single-molecule State through the molecule (de)serializer."""
    molecule = TreeMolecule(parent=None, smiles="CCC", transform=1)
    original = State([molecule], default_config)
    serializer = MoleculeSerializer()

    payload = original.serialize(serializer)

    # The serialized state refers to its molecules by id()
    mol_ids = payload["mols"]
    assert len(mol_ids) == 1
    assert mol_ids[0] == id(molecule)

    restored = State.from_dict(
        payload, default_config, MoleculeDeserializer(serializer.store)
    )

    assert len(restored.mols) == 1
    assert restored.mols[0] == original.mols[0]
    for attr in ("in_stock_list", "score"):
        assert getattr(restored, attr) == getattr(original, attr)
示例#7
0
    def from_dict(
        cls,
        dict_: StrDict,
        tree: SearchTree,
        config: Configuration,
        molecules: MoleculeDeserializer,
        parent: "Node" = None,
    ) -> "Node":
        """
        Rebuild a node, and recursively its stored children, from a dictionary

        The state goes through ``State.from_dict``. The expansion flags and the
        per-child value/prior/visitation lists are copied as stored. Each stored
        action becomes a ``RetroReaction`` on its deserialized molecule. Falsy
        entries in ``dict_["children"]`` become ``None``.

        :param dict_: the serialized node
        :param tree: the search tree owning the node
        :param config: settings of the tree search algorithm
        :param molecules: the deserialized molecules
        :param parent: the parent node (``None`` by default, e.g. for the root)
        :return: a deserialized node
        """
        node = Node(
            state=State.from_dict(dict_["state"], config, molecules),
            owner=tree,
            config=config,
            parent=parent,
        )
        node.is_expanded = dict_["is_expanded"]
        node.is_expandable = dict_["is_expandable"]
        node._children_values = dict_["children_values"]
        node._children_priors = dict_["children_priors"]
        node._children_visitations = dict_["children_visitations"]

        # Each action refers to its molecule by id in the molecule store;
        # "metadata" is optional (presumably for older serializations).
        node._children_actions = [
            RetroReaction(
                molecules.get_tree_molecules([spec["mol"]])[0],
                spec["smarts"],
                spec["index"],
                spec.get("metadata", {}),
            )
            for spec in dict_["children_actions"]
        ]

        children = []
        for child_dict in dict_["children"]:
            if not child_dict:
                children.append(None)
                continue
            children.append(
                Node.from_dict(child_dict, tree, config, molecules, parent=node)
            )
        node._children = children
        return node
def test_deserialize_tree_mols():
    """Parent links and transforms of TreeMolecules survive deserialization."""
    root_id, child_id = 123, 234
    store = {
        root_id: {
            "smiles": "CCC",
            "class": "TreeMolecule",
            "parent": None,
            "transform": 1,
        },
        child_id: {
            "smiles": "CCO",
            "class": "TreeMolecule",
            "parent": root_id,
            "transform": 2,
        },
    }

    deserializer = MoleculeDeserializer(store)
    root_mol = deserializer[root_id]
    child_mol = deserializer[child_id]

    assert root_mol.smiles == "CCC"
    assert child_mol.smiles == "CCO"
    assert root_mol.parent is None
    # The child's parent id must resolve to the very same deserialized object
    assert child_mol.parent is root_mol
    assert root_mol.transform == 1
    assert child_mol.transform == 2
def test_deserialize_single_mol():
    """A plain Molecule entry is rebuilt with its SMILES intact."""
    mol_id = 123
    deserializer = MoleculeDeserializer({mol_id: {"smiles": "CCC", "class": "Molecule"}})

    assert deserializer[mol_id].smiles == "CCC"