Пример #1
0
    def test_graph_data(self):
        """Build a small GraphData, check its size properties and converters."""
        n_nodes, n_node_feats = 5, 32
        n_edges, n_edge_feats = 6, 32
        # Row 0 holds edge sources, row 1 edge destinations.
        edges = np.array([
            [0, 1, 2, 2, 3, 4],
            [1, 2, 0, 3, 4, 0],
        ])

        graph = GraphData(
            node_features=np.random.random_sample((n_nodes, n_node_feats)),
            edge_index=edges,
            edge_features=np.random.random_sample((n_edges, n_edge_feats)),
            node_pos_features=None,
        )

        for attr_name, expected in (
            ("num_nodes", n_nodes),
            ("num_node_features", n_node_feats),
            ("num_edges", n_edges),
            ("num_edge_features", n_edge_feats),
        ):
            assert getattr(graph, attr_name) == expected

        # Conversion to a PyTorch Geometric graph.
        converted_pyg = graph.to_pyg_graph()
        from torch_geometric.data import Data
        assert isinstance(converted_pyg, Data)

        # Conversion to a DGL graph.
        converted_dgl = graph.to_dgl_graph()
        from dgl import DGLGraph
        assert isinstance(converted_dgl, DGLGraph)
Пример #2
0
    def test_batch_graph_data(self):
        """Batch three small graphs and check the aggregated sizes."""
        node_counts = [3, 4, 5]
        edge_counts = [2, 4, 5]
        node_dim = edge_dim = 32
        edge_indices = [
            np.array([[0, 1], [1, 2]]),
            np.array([[0, 1, 2, 3], [1, 2, 0, 2]]),
            np.array([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]),
        ]

        graphs = []
        for n_nodes, n_edges, index in zip(node_counts, edge_counts,
                                           edge_indices):
            graphs.append(
                GraphData(
                    node_features=np.random.random_sample((n_nodes, node_dim)),
                    edge_index=index,
                    edge_features=np.random.random_sample((n_edges, edge_dim)),
                    node_pos_features=None,
                ))
        batch = BatchGraphData(graphs)

        total_nodes = sum(node_counts)
        assert batch.num_nodes == total_nodes
        assert batch.num_node_features == node_dim
        assert batch.num_edges == sum(edge_counts)
        assert batch.num_edge_features == edge_dim
        # graph_index is expected to have one entry per node in the batch.
        assert batch.graph_index.shape == (total_nodes,)
Пример #3
0
    def _featurize(self, datapoint: PymatgenStructure, **kwargs) -> GraphData:
        """Calculate crystal graph features from a pymatgen structure.

        Parameters
        ----------
        datapoint: pymatgen.core.Structure
          A periodic crystal composed of a lattice and a sequence of atomic
          sites with 3D coordinates and elements.
        **kwargs
          May contain the deprecated ``struct`` keyword. It is used in place
          of ``datapoint`` when ``datapoint`` is None, and a
          DeprecationWarning is emitted.

        Returns
        -------
        graph: GraphData
          A crystal graph with CGCNN style features.
        """
        if 'struct' in kwargs and datapoint is None:
            import warnings
            datapoint = kwargs.get("struct")
            # Warn rather than raise: raising made the assignment above dead
            # code, so the deprecated keyword could never actually be used.
            warnings.warn(
                'Struct is being phased out as a parameter, please pass "datapoint" instead.',
                DeprecationWarning)

        node_features = self._get_node_features(datapoint)
        edge_index, edge_features = self._get_edge_features_and_index(
            datapoint)
        return GraphData(node_features, edge_index, edge_features)
Пример #4
0
  def _featurize(self, datapoint: RDKitMol, **kwargs) -> GraphData:
    """Calculate molecule graph features from RDKit mol object.

    Parameters
    ----------
    datapoint: rdkit.Chem.rdchem.Mol
      RDKit mol object.
    **kwargs
      May contain the deprecated ``mol`` keyword. When present, it replaces
      ``datapoint`` and a DeprecationWarning is emitted.

    Returns
    -------
    graph: GraphData
      A molecule graph with PAGTN atom and edge features.
    """
    if 'mol' in kwargs:
      import warnings
      datapoint = kwargs.get("mol")
      # Warn rather than raise: raising made the assignment above dead code.
      warnings.warn(
          'Mol is being phased out as a parameter, please pass "datapoint" instead.',
          DeprecationWarning)

    # Builtin ``float`` replaces the ``np.float`` alias removed in NumPy 1.24.
    node_features = np.asarray(
        [self._pagtn_atom_featurizer(atom) for atom in datapoint.GetAtoms()],
        dtype=float)
    edge_index, edge_features = self._pagtn_edge_featurizer(datapoint)
    return GraphData(node_features, edge_index, edge_features)
Пример #5
0
    def _featurize(self, datapoint: PymatgenStructure, **kwargs) -> GraphData:
        """Build a site graph from a surface configuration.

        Parameters
        ----------
        datapoint: PymatgenStructure
          Pymatgen Structure object of the surface configuration. It also requires
          site_properties attribute with "Sitetypes" (Active or spectator site) and
          "oss" (Species of Active site from the list of self.aos and "-1" for
          spectator sites).
        **kwargs
          May contain the deprecated ``structure`` keyword. It is used in place
          of ``datapoint`` when ``datapoint`` is None, and a
          DeprecationWarning is emitted.

        Returns
        -------
        graph: GraphData
          Node features and all edges for each node in different permutations.
        """
        if 'structure' in kwargs and datapoint is None:
            import warnings
            datapoint = kwargs.get("structure")
            # Warn rather than raise: raising made the assignment above dead code.
            warnings.warn(
                'Structure is being phased out as a parameter, please pass "datapoint" instead.',
                DeprecationWarning)

        xSites, xNSs = self.setup_env.read_datum(datapoint)
        config_size = xNSs.shape
        # Each site index is repeated once per entry in dims 2 and 3 of xNSs,
        # matching the element count per site after xNSs is flattened
        # (assumes xNSs has at least 4 dims, as the indexing requires).
        v = np.arange(0, len(xSites)).repeat(config_size[2] * config_size[3])
        u = xNSs.flatten()
        return GraphData(node_features=xSites, edge_index=np.array([u, v]))
    def _featurize(self, mol: RDKitMol) -> GraphData:
        """Calculate molecule graph features from RDKit mol object.

        Parameters
        ----------
        mol: rdkit.Chem.rdchem.Mol
          RDKit mol object. When ``self.use_partial_charge`` is set and the
          first atom has no ``_GasteigerCharge`` property, Gasteiger charges
          are computed on this object.

        Returns
        -------
        graph: GraphData
          A molecule graph in which every bond appears as two directed edges.
          Edge features are attached only when ``self.use_edges`` is set.

        Raises
        ------
        ImportError
          If partial charges must be computed and RDKit cannot be imported.
        """
        if self.use_partial_charge:
            try:
                mol.GetAtomWithIdx(0).GetProp('_GasteigerCharge')
            except Exception:
                # Partial charges were not computed yet (or atom 0 could not
                # be read), so compute them now.
                try:
                    from rdkit.Chem import AllChem
                except ModuleNotFoundError as err:
                    raise ImportError(
                        "This class requires RDKit to be installed.") from err
                AllChem.ComputeGasteigerCharges(mol)

        # construct atom (node) features
        h_bond_infos = construct_hydrogen_bonding_info(mol)
        atom_features = np.asarray(
            [
                _construct_atom_feature(atom, h_bond_infos, self.use_chirality,
                                        self.use_partial_charge)
                for atom in mol.GetAtoms()
            ],
            dtype=float,
        )

        # construct edge index (and optionally edge features) in one pass;
        # each bond yields the directed pair start->end and end->start.
        src, dest, edge_rows = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            src += [start, end]
            dest += [end, start]
            if self.use_edges:
                # One feature row per directed edge, hence the duplication.
                edge_rows += 2 * [_construct_bond_feature(bond)]

        bond_features = None  # default: no edge features
        if self.use_edges:
            bond_features = np.asarray(edge_rows, dtype=float)

        return GraphData(node_features=atom_features,
                         edge_index=np.asarray([src, dest], dtype=int),
                         edge_features=bond_features)
    def _featurize(self, mol: RDKitMol) -> GraphData:
        """Calculate molecule graph features from RDKit mol object.

        Parameters
        ----------
        mol: rdkit.Chem.rdchem.Mol
          RDKit mol object. If the first atom has no ``_GasteigerCharge``
          property, Gasteiger charges are computed on this object.

        Returns
        -------
        graph: GraphData
          A molecule graph in which every bond appears as two directed edges.
          When ``self.add_self_edges`` is set, one self-loop per atom is
          appended with an all-zero edge feature row.
        """
        from rdkit import Chem
        from rdkit.Chem import AllChem

        # construct atom and bond features
        try:
            mol.GetAtomWithIdx(0).GetProp('_GasteigerCharge')
        except Exception:
            # Partial charges were not computed yet (or atom 0 could not be read).
            AllChem.ComputeGasteigerCharges(mol)

        h_bond_infos = construct_hydrogen_bonding_info(mol)
        sssr = Chem.GetSymmSSSR(mol)

        # construct atom (node) features; builtin ``float``/``int`` replace the
        # ``np.float``/``np.int`` aliases that NumPy 1.24 removed.
        atom_features = np.array(
            [
                _construct_atom_feature(atom, h_bond_infos, sssr)
                for atom in mol.GetAtoms()
            ],
            dtype=float,
        )

        # construct edge (bond) information as directed pairs
        src, dest, bond_features = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            src += [start, end]
            dest += [end, start]
            bond_features += 2 * [_construct_bond_feature(bond)]

        if self.add_self_edges:
            num_atoms = mol.GetNumAtoms()
            src += list(range(num_atoms))
            dest += list(range(num_atoms))
            # Dummy (all-zero) features for the self-loops.
            # NOTE(review): the row length is taken from the first bond, so a
            # molecule without bonds raises IndexError here — confirm whether
            # such inputs must be supported.
            bond_fea_length = len(bond_features[0])
            bond_features += num_atoms * [[0] * bond_fea_length]

        return GraphData(node_features=atom_features,
                         edge_index=np.array([src, dest], dtype=int),
                         edge_features=np.array(bond_features, dtype=float))
Пример #8
0
    def test_invalid_graph_data(self):
        """GraphData must reject malformed constructor arguments."""
        # node_features given as a plain list instead of an ndarray.
        with self.assertRaises(ValueError):
            GraphData(
                node_features=list(np.random.random_sample((5, 32))),
                edge_index=np.array([
                    [0, 1, 2, 2, 3, 4],
                    [1, 2, 0, 3, 4, 0],
                ]),
            )

        # Five nodes (indices 0-4), but the edge index refers to node 5.
        with self.assertRaises(ValueError):
            GraphData(
                node_features=np.random.random_sample((5, 32)),
                edge_index=np.array([
                    [0, 1, 2, 2, 3, 4],
                    [1, 2, 0, 3, 4, 5],
                ]),
            )

        # edge_index has three rows instead of two.
        with self.assertRaises(ValueError):
            GraphData(
                node_features=np.random.random_sample((5, 5)),
                edge_index=np.array([
                    [0, 1, 2, 2, 3, 4],
                    [1, 2, 0, 3, 4, 0],
                    [2, 2, 1, 4, 0, 3],
                ]),
            )

        # edge_index omitted entirely.
        with self.assertRaises(TypeError):
            GraphData(node_features=np.random.random_sample((5, 32)))
Пример #9
0
    def test_graph_data(self):
        """GraphData keeps an extra keyword attribute (``z``) through
        construction, ``str()`` and conversion to PyG / DGL graphs."""
        n_nodes, n_node_feats = 5, 32
        n_edges, n_edge_feats = 6, 32
        nodes = np.random.random_sample((n_nodes, n_node_feats))
        edges = np.random.random_sample((n_edges, n_edge_feats))
        index = np.array([
            [0, 1, 2, 2, 3, 4],
            [1, 2, 0, 3, 4, 0],
        ])
        # ``z`` is supplied through GraphData's extra keyword arguments.
        z = np.random.random(5)

        graph = GraphData(
            node_features=nodes,
            edge_index=index,
            edge_features=edges,
            node_pos_features=None,
            z=z,
        )

        assert (graph.num_nodes, graph.num_node_features) == (n_nodes,
                                                              n_node_feats)
        assert (graph.num_edges, graph.num_edge_features) == (n_edges,
                                                              n_edge_feats)
        assert graph.z.shape == z.shape
        expected_repr = 'GraphData(node_features=[5, 32], edge_index=[2, 6], edge_features=[6, 32], z=[5])'
        assert str(graph) == expected_repr

        # PyTorch Geometric conversion must carry ``z`` along.
        pyg = graph.to_pyg_graph()
        from torch_geometric.data import Data
        assert isinstance(pyg, Data)
        assert tuple(pyg.z.shape) == z.shape

        # DGL conversion.
        dgl_graph = graph.to_dgl_graph()
        from dgl import DGLGraph
        assert isinstance(dgl_graph, DGLGraph)
    def _featurize(self, mol: RDKitMol) -> GraphData:
        """Calculate molecule graph features from RDKit mol object.

        Parameters
        ----------
        mol: rdkit.Chem.rdchem.Mol
          RDKit mol object.

        Returns
        -------
        graph: GraphData
          A molecule graph with PAGTN atom and edge features.
        """
        # Builtin ``float`` replaces the ``np.float`` alias removed in NumPy 1.24.
        node_features = np.asarray(
            [self._pagtn_atom_featurizer(atom) for atom in mol.GetAtoms()],
            dtype=float)
        edge_index, edge_features = self._pagtn_edge_featurizer(mol)
        return GraphData(node_features, edge_index, edge_features)
Пример #11
0
    def _featurize(self, struct: PymatgenStructure) -> GraphData:
        """Build a CGCNN-style crystal graph from a pymatgen structure.

        Parameters
        ----------
        struct: pymatgen.Structure
          A periodic crystal composed of a lattice and a sequence of atomic
          sites with 3D coordinates and elements.

        Returns
        -------
        graph: GraphData
          Node features from ``self._get_node_features`` plus the edge index
          and edge features from ``self._get_edge_features_and_index``.
        """
        atoms = self._get_node_features(struct)
        neighbors, neighbor_feats = self._get_edge_features_and_index(struct)
        return GraphData(atoms, neighbors, neighbor_feats)
Пример #12
0
  def _featurize(self, structure: PymatgenStructure) -> GraphData:
    """Build a site graph from a surface configuration.

    Parameters
    ----------
    structure: PymatgenStructure
      Pymatgen Structure object of the surface configuration. It also requires
      site_properties attribute with "Sitetypes" (Active or spectator site) and
      "oss" (Species of Active site from the list of self.aos and "-1" for
      spectator sites).

    Returns
    -------
    graph: GraphData
      Node features and all edges for each node in different permutations.
    """
    site_feats, neighbor_sites = self.setup_env.read_datum(structure)
    shape = neighbor_sites.shape
    # Entries per site in the flattened neighbor array (dims 2 and 3).
    per_site = shape[2] * shape[3]
    site_idx = np.repeat(np.arange(0, len(site_feats)), per_site)
    neighbor_idx = neighbor_sites.flatten()
    return GraphData(node_features=site_feats,
                     edge_index=np.array([neighbor_idx, site_idx]))
Пример #13
0
    def _featurize(self, datapoint: RDKitMol, **kwargs) -> GraphData:
        """Calculate molecule graph features from RDKit mol object.

        Parameters
        ----------
        datapoint: rdkit.Chem.rdchem.Mol
          RDKit mol object; must contain more than one atom.
        **kwargs
          May contain the deprecated ``mol`` keyword. When present, it replaces
          ``datapoint`` and a DeprecationWarning is emitted.

        Returns
        -------
        graph: GraphData
          A molecule graph in which every bond appears as two directed edges.
          Edge features are attached only when ``self.use_edges`` is set.

        Raises
        ------
        AssertionError
          If the molecule has one atom or fewer (unless assertions are disabled).
        ImportError
          If partial charges must be computed and RDKit cannot be imported.
        """
        if 'mol' in kwargs:
            import warnings
            datapoint = kwargs.get("mol")
            # Warn rather than raise: raising made the assignment above dead code.
            warnings.warn(
                'Mol is being phased out as a parameter, please pass "datapoint" instead.',
                DeprecationWarning)

        # Checked after the deprecated keyword is resolved so that it applies to
        # the molecule actually being featurized.
        assert datapoint.GetNumAtoms(
        ) > 1, "More than one atom should be present in the molecule for this featurizer to work."

        if self.use_partial_charge:
            try:
                datapoint.GetAtomWithIdx(0).GetProp('_GasteigerCharge')
            except Exception:
                # Partial charges were not computed yet, so compute them now.
                try:
                    from rdkit.Chem import AllChem
                except ModuleNotFoundError as err:
                    raise ImportError(
                        "This class requires RDKit to be installed.") from err
                AllChem.ComputeGasteigerCharges(datapoint)

        # construct atom (node) features
        h_bond_infos = construct_hydrogen_bonding_info(datapoint)
        atom_features = np.asarray(
            [
                _construct_atom_feature(atom, h_bond_infos, self.use_chirality,
                                        self.use_partial_charge)
                for atom in datapoint.GetAtoms()
            ],
            dtype=float,
        )

        # construct edge index (and optionally edge features) in one pass;
        # each bond yields the directed pair start->end and end->start.
        src, dest, edge_rows = [], [], []
        for bond in datapoint.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            src += [start, end]
            dest += [end, start]
            if self.use_edges:
                # One feature row per directed edge, hence the duplication.
                edge_rows += 2 * [_construct_bond_feature(bond)]

        bond_features = None  # default: no edge features
        if self.use_edges:
            bond_features = np.asarray(edge_rows, dtype=float)

        return GraphData(node_features=atom_features,
                         edge_index=np.asarray([src, dest], dtype=int),
                         edge_features=bond_features)