def test_graph_data(self):
    # Small random graph: 5 nodes and 6 edges, 32 features per node/edge.
    n_nodes, n_node_feats = 5, 32
    n_edges, n_edge_feats = 6, 32
    node_feats = np.random.random_sample((n_nodes, n_node_feats))
    edge_feats = np.random.random_sample((n_edges, n_edge_feats))
    edges = np.array([
        [0, 1, 2, 2, 3, 4],
        [1, 2, 0, 3, 4, 0],
    ])
    graph = GraphData(node_features=node_feats,
                      edge_index=edges,
                      edge_features=edge_feats,
                      node_pos_features=None)

    # The size properties reported by GraphData must match the inputs.
    assert (graph.num_nodes, graph.num_node_features) == (n_nodes,
                                                          n_node_feats)
    assert (graph.num_edges, graph.num_edge_features) == (n_edges,
                                                          n_edge_feats)

    # Conversion to PyTorch Geometric and DGL graph objects.
    pyg_graph = graph.to_pyg_graph()
    from torch_geometric.data import Data
    assert isinstance(pyg_graph, Data)

    dgl_graph = graph.to_dgl_graph()
    from dgl import DGLGraph
    assert isinstance(dgl_graph, DGLGraph)
def test_batch_graph_data(self):
    # Three small graphs of different sizes, batched together.
    num_nodes_list, num_edge_list = [3, 4, 5], [2, 4, 5]
    num_node_features, num_edge_features = 32, 32
    edge_index_list = [
        np.array([[0, 1], [1, 2]]),
        np.array([[0, 1, 2, 3], [1, 2, 0, 2]]),
        np.array([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]),
    ]

    graph_list = []
    for n_nodes, n_edges, edges in zip(num_nodes_list, num_edge_list,
                                       edge_index_list):
        # Node features are drawn before edge features for every graph.
        node_feats = np.random.random_sample((n_nodes, num_node_features))
        edge_feats = np.random.random_sample((n_edges, num_edge_features))
        graph_list.append(
            GraphData(node_features=node_feats,
                      edge_index=edges,
                      edge_features=edge_feats,
                      node_pos_features=None))

    batch = BatchGraphData(graph_list)
    total_nodes = sum(num_nodes_list)
    total_edges = sum(num_edge_list)

    # Totals are summed across graphs; feature widths are shared.
    assert batch.num_nodes == total_nodes
    assert batch.num_node_features == num_node_features
    assert batch.num_edges == total_edges
    assert batch.num_edge_features == num_edge_features
    # One graph-membership entry per node in the batch.
    assert batch.graph_index.shape == (total_nodes,)
def _featurize(self, datapoint: PymatgenStructure, **kwargs) -> GraphData:
    """
    Calculate crystal graph features from pymatgen structure.

    Parameters
    ----------
    datapoint: pymatgen.core.Structure
        A periodic crystal composed of a lattice and a sequence of atomic
        sites with 3D coordinates and elements.
    **kwargs
        Deprecated: ``struct`` is still accepted in place of ``datapoint``
        when ``datapoint`` is None; a DeprecationWarning is emitted.

    Returns
    -------
    graph: GraphData
        A crystal graph with CGCNN style features.
    """
    if 'struct' in kwargs and datapoint is None:
        import warnings
        datapoint = kwargs.get("struct")
        # Warn rather than raise. Raising would throw away the structure that
        # was just recovered from ``struct`` and make this fallback useless.
        warnings.warn(
            'Struct is being phased out as a parameter, please pass "datapoint" instead.',
            DeprecationWarning,
            stacklevel=2)

    node_features = self._get_node_features(datapoint)
    edge_index, edge_features = self._get_edge_features_and_index(datapoint)
    return GraphData(node_features, edge_index, edge_features)
def _featurize(self, datapoint: RDKitMol, **kwargs) -> GraphData:
    """Calculate molecule graph features from RDKit mol object.

    Parameters
    ----------
    datapoint: rdkit.Chem.rdchem.Mol
        RDKit mol object.
    **kwargs
        Deprecated: ``mol`` is still accepted in place of ``datapoint``.
        When present it takes precedence and a DeprecationWarning is emitted.

    Returns
    -------
    graph: GraphData
        A molecule graph with some features.
    """
    if 'mol' in kwargs:
        import warnings
        datapoint = kwargs.get("mol")
        # Warn rather than raise, so the molecule recovered from ``mol`` is
        # actually featurized.
        warnings.warn(
            'Mol is being phased out as a parameter, please pass "datapoint" instead.',
            DeprecationWarning,
            stacklevel=2)

    # ``np.float`` was removed in NumPy 1.24. The builtin ``float`` maps to
    # the same float64 dtype.
    node_features = np.asarray(
        [self._pagtn_atom_featurizer(atom) for atom in datapoint.GetAtoms()],
        dtype=float)
    edge_index, edge_features = self._pagtn_edge_featurizer(datapoint)
    return GraphData(node_features, edge_index, edge_features)
def _featurize(self, datapoint: PymatgenStructure, **kwargs) -> GraphData:
    """
    Build the lattice graph of a surface configuration.

    Parameters
    ----------
    datapoint: PymatgenStructure
        Pymatgen Structure object of the surface configuration. It also
        requires site_properties attribute with "Sitetypes" (Active or
        spectator site) and "oss" (Species of Active site from the list of
        self.aos and "-1" for spectator sites).
    **kwargs
        Deprecated: ``structure`` is still accepted in place of
        ``datapoint`` when ``datapoint`` is None; a DeprecationWarning is
        emitted.

    Returns
    -------
    graph: GraphData
        Node features, and all edges for each node in different permutations.
    """
    if 'structure' in kwargs and datapoint is None:
        import warnings
        datapoint = kwargs.get("structure")
        # Warn rather than raise, so the recovered structure is featurized.
        warnings.warn(
            'Structure is being phased out as a parameter, please pass "datapoint" instead.',
            DeprecationWarning,
            stacklevel=2)

    xSites, xNSs = self.setup_env.read_datum(datapoint)
    config_size = xNSs.shape
    # v repeats each site index config_size[2] * config_size[3] times.
    # u is the flattened xNSs (presumably neighbour-site indices; confirm
    # against read_datum). Together they form the two rows of edge_index.
    v = np.arange(0, len(xSites)).repeat(config_size[2] * config_size[3])
    u = xNSs.flatten()
    return GraphData(node_features=xSites, edge_index=np.array([u, v]))
def _featurize(self, mol: RDKitMol) -> GraphData:
    """Calculate molecule graph features from RDKit mol object.

    Parameters
    ----------
    mol: rdkit.Chem.rdchem.Mol
        RDKit mol object. When ``self.use_partial_charge`` is set and
        Gasteiger charges are missing, they are computed on this object
        in place.

    Returns
    -------
    graph: GraphData
        A molecule graph with some features. Each bond contributes two
        directed edges (start->end and end->start). Edge features are None
        unless ``self.use_edges`` is set.

    Raises
    ------
    ImportError
        If partial charges must be computed and RDKit cannot be imported.
    """
    if self.use_partial_charge:
        try:
            mol.GetAtomWithIdx(0).GetProp('_GasteigerCharge')
        except Exception:
            # Partial charges were not computed yet. ``except Exception``
            # replaces a bare ``except:`` so KeyboardInterrupt/SystemExit
            # still propagate.
            try:
                from rdkit.Chem import AllChem
            except ModuleNotFoundError as err:
                raise ImportError(
                    "This class requires RDKit to be installed.") from err
            AllChem.ComputeGasteigerCharges(mol)

    # construct atom (node) features
    h_bond_infos = construct_hydrogen_bonding_info(mol)
    atom_features = np.asarray(
        [
            _construct_atom_feature(atom, h_bond_infos, self.use_chirality,
                                    self.use_partial_charge)
            for atom in mol.GetAtoms()
        ],
        dtype=float,
    )

    # Construct the edge index (and, optionally, edge features) in a single
    # pass over the bonds. Each bond is added in both directions.
    src, dest = [], []
    features = []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        src += [start, end]
        dest += [end, start]
        if self.use_edges:
            features += 2 * [_construct_bond_feature(bond)]

    # default None when edge features are disabled
    bond_features = np.asarray(features,
                               dtype=float) if self.use_edges else None

    return GraphData(node_features=atom_features,
                     edge_index=np.asarray([src, dest], dtype=int),
                     edge_features=bond_features)
def _featurize(self, mol: RDKitMol) -> GraphData:
    """Calculate molecule graph features from RDKit mol object.

    Parameters
    ----------
    mol: rdkit.Chem.rdchem.Mol
        RDKit mol object. Gasteiger charges are computed on it in place if
        they are missing.

    Returns
    -------
    graph: GraphData
        A molecule graph with some features. Each bond contributes two
        directed edges. When ``self.add_self_edges`` is set, one self-loop
        per atom is appended with an all-zero edge feature vector.
    """
    from rdkit import Chem
    from rdkit.Chem import AllChem

    # Ensure partial charges exist. ``except Exception`` replaces a bare
    # ``except:`` so KeyboardInterrupt/SystemExit are not swallowed.
    try:
        mol.GetAtomWithIdx(0).GetProp('_GasteigerCharge')
    except Exception:
        AllChem.ComputeGasteigerCharges(mol)

    h_bond_infos = construct_hydrogen_bonding_info(mol)
    sssr = Chem.GetSymmSSSR(mol)

    # Atom (node) features. ``np.float``/``np.int`` were removed in NumPy
    # 1.24, so builtin ``float``/``int`` dtypes are used throughout.
    atom_features = np.array(
        [
            _construct_atom_feature(atom, h_bond_infos, sssr)
            for atom in mol.GetAtoms()
        ],
        dtype=float,
    )

    # Edge (bond) information: each bond is stored in both directions.
    src, dest, bond_features = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        src += [start, end]
        dest += [end, start]
        bond_features += 2 * [_construct_bond_feature(bond)]

    if self.add_self_edges:
        num_atoms = mol.GetNumAtoms()
        self_loops = list(range(num_atoms))
        src += self_loops
        dest += self_loops
        # NOTE(review): bond_features[0] raises IndexError for a molecule
        # with no bonds (e.g. a single atom). The dummy feature width is
        # only known from a real bond.
        bond_fea_length = len(bond_features[0])
        bond_features += num_atoms * [[0] * bond_fea_length]

    return GraphData(node_features=atom_features,
                     edge_index=np.array([src, dest], dtype=int),
                     edge_features=np.array(bond_features, dtype=float))
def test_invalid_graph_data(self):
    """GraphData must reject malformed inputs."""
    valid_edges = np.array([
        [0, 1, 2, 2, 3, 4],
        [1, 2, 0, 3, 4, 0],
    ])

    # node_features passed as a Python list instead of an ndarray.
    with self.assertRaises(ValueError):
        features_as_list = list(np.random.random_sample((5, 32)))
        GraphData(node_features=features_as_list, edge_index=valid_edges)

    # edge_index contains index 5 although only 5 nodes (0-4) exist.
    with self.assertRaises(ValueError):
        feats = np.random.random_sample((5, 32))
        out_of_range_edges = np.array([
            [0, 1, 2, 2, 3, 4],
            [1, 2, 0, 3, 4, 5],
        ])
        GraphData(node_features=feats, edge_index=out_of_range_edges)

    # edge_index has three rows instead of two.
    with self.assertRaises(ValueError):
        feats = np.random.random_sample((5, 5))
        three_row_edges = np.array([
            [0, 1, 2, 2, 3, 4],
            [1, 2, 0, 3, 4, 0],
            [2, 2, 1, 4, 0, 3],
        ])
        GraphData(node_features=feats, edge_index=three_row_edges)

    # edge_index omitted entirely.
    with self.assertRaises(TypeError):
        feats = np.random.random_sample((5, 32))
        GraphData(node_features=feats)
def test_graph_data(self):
    # 5 nodes and 6 edges with 32 features each, plus an extra per-node
    # array ``z`` passed through GraphData's **kwargs.
    n_nodes, n_node_feats = 5, 32
    n_edges, n_edge_feats = 6, 32
    node_feats = np.random.random_sample((n_nodes, n_node_feats))
    edge_feats = np.random.random_sample((n_edges, n_edge_feats))
    edges = np.array([
        [0, 1, 2, 2, 3, 4],
        [1, 2, 0, 3, 4, 0],
    ])
    z = np.random.random(5)

    graph = GraphData(node_features=node_feats,
                      edge_index=edges,
                      edge_features=edge_feats,
                      node_pos_features=None,
                      z=z)

    # Reported sizes match the inputs, and the extra kwarg is exposed.
    assert (graph.num_nodes, graph.num_node_features) == (n_nodes,
                                                          n_node_feats)
    assert (graph.num_edges, graph.num_edge_features) == (n_edges,
                                                          n_edge_feats)
    assert graph.z.shape == z.shape
    expected_repr = 'GraphData(node_features=[5, 32], edge_index=[2, 6], edge_features=[6, 32], z=[5])'
    assert str(graph) == expected_repr

    # PyTorch Geometric conversion keeps the extra ``z`` attribute.
    pyg_graph = graph.to_pyg_graph()
    from torch_geometric.data import Data
    assert isinstance(pyg_graph, Data)
    assert tuple(pyg_graph.z.shape) == z.shape

    dgl_graph = graph.to_dgl_graph()
    from dgl import DGLGraph
    assert isinstance(dgl_graph, DGLGraph)
def _featurize(self, mol: RDKitMol) -> GraphData:
    """Calculate molecule graph features from RDKit mol object.

    Parameters
    ----------
    mol: rdkit.Chem.rdchem.Mol
        RDKit mol object.

    Returns
    -------
    graph: GraphData
        A molecule graph whose node features come from
        ``self._pagtn_atom_featurizer`` and whose edge index/features come
        from ``self._pagtn_edge_featurizer``.
    """
    # ``np.float`` was removed in NumPy 1.24. The builtin ``float`` maps to
    # the same float64 dtype.
    node_features = np.asarray(
        [self._pagtn_atom_featurizer(atom) for atom in mol.GetAtoms()],
        dtype=float)
    edge_index, edge_features = self._pagtn_edge_featurizer(mol)
    return GraphData(node_features, edge_index, edge_features)
def _featurize(self, struct: PymatgenStructure) -> GraphData:
    """
    Turn a pymatgen structure into a CGCNN-style crystal graph.

    Parameters
    ----------
    struct: pymatgen.Structure
        Periodic crystal: a lattice plus atomic sites, each with 3D
        coordinates and an element.

    Returns
    -------
    graph: GraphData
        Graph whose node features come from ``self._get_node_features``
        and whose edge index/features come from
        ``self._get_edge_features_and_index``.
    """
    nodes = self._get_node_features(struct)
    index, edge_attrs = self._get_edge_features_and_index(struct)
    return GraphData(nodes, index, edge_attrs)
def _featurize(self, structure: PymatgenStructure) -> GraphData:
    """
    Convert a surface configuration into a lattice graph.

    Parameters
    ----------
    structure: PymatgenStructure
        Pymatgen Structure of the surface configuration. Its site_properties
        must include "Sitetypes" (Active or spectator site) and "oss"
        (species of the active site taken from self.aos, or "-1" for
        spectator sites).

    Returns
    -------
    graph: GraphData
        Node features plus every edge of each node across the different
        permutations.
    """
    sites, ns_array = self.setup_env.read_datum(structure)
    dims = ns_array.shape
    # Every site index appears dims[2] * dims[3] times in the second row.
    repeats_per_site = dims[2] * dims[3]
    site_ids = np.arange(len(sites)).repeat(repeats_per_site)
    flat_ns = ns_array.flatten()
    return GraphData(node_features=sites,
                     edge_index=np.array([flat_ns, site_ids]))
def _featurize(self, datapoint: RDKitMol, **kwargs) -> GraphData:
    """Calculate molecule graph features from RDKit mol object.

    Parameters
    ----------
    datapoint: rdkit.Chem.rdchem.Mol
        RDKit mol object with more than one atom. When
        ``self.use_partial_charge`` is set and Gasteiger charges are
        missing, they are computed on this object in place.
    **kwargs
        Deprecated: ``mol`` is still accepted in place of ``datapoint``.
        When present it takes precedence and a DeprecationWarning is emitted.

    Returns
    -------
    graph: GraphData
        A molecule graph with some features. Each bond contributes two
        directed edges. Edge features are None unless ``self.use_edges`` is
        set.

    Raises
    ------
    AssertionError
        If the molecule has one atom or fewer.
    ImportError
        If partial charges must be computed and RDKit cannot be imported.
    """
    if 'mol' in kwargs:
        import warnings
        datapoint = kwargs.get("mol")
        # Warn rather than raise, so the molecule recovered from ``mol`` is
        # actually featurized.
        warnings.warn(
            'Mol is being phased out as a parameter, please pass "datapoint" instead.',
            DeprecationWarning,
            stacklevel=2)

    # Validate only after the deprecated ``mol`` keyword is resolved.
    # Otherwise a caller passing ``mol`` with datapoint=None fails here.
    assert datapoint.GetNumAtoms(
    ) > 1, "More than one atom should be present in the molecule for this featurizer to work."

    if self.use_partial_charge:
        try:
            datapoint.GetAtomWithIdx(0).GetProp('_GasteigerCharge')
        except Exception:
            # Partial charges were not computed yet. ``except Exception``
            # replaces a bare ``except:`` so KeyboardInterrupt/SystemExit
            # still propagate.
            try:
                from rdkit.Chem import AllChem
            except ModuleNotFoundError as err:
                raise ImportError(
                    "This class requires RDKit to be installed.") from err
            AllChem.ComputeGasteigerCharges(datapoint)

    # construct atom (node) features
    h_bond_infos = construct_hydrogen_bonding_info(datapoint)
    atom_features = np.asarray(
        [
            _construct_atom_feature(atom, h_bond_infos, self.use_chirality,
                                    self.use_partial_charge)
            for atom in datapoint.GetAtoms()
        ],
        dtype=float,
    )

    # Construct the edge index (and, optionally, edge features) in a single
    # pass over the bonds. Each bond is added in both directions.
    src, dest = [], []
    features = []
    for bond in datapoint.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        src += [start, end]
        dest += [end, start]
        if self.use_edges:
            features += 2 * [_construct_bond_feature(bond)]

    # default None when edge features are disabled
    bond_features = np.asarray(features,
                               dtype=float) if self.use_edges else None

    return GraphData(node_features=atom_features,
                     edge_index=np.asarray([src, dest], dtype=int),
                     edge_features=bond_features)