Example #1
0
    def test_graph_data_without_kwargs(self):
        """Build a GraphData with no extra keyword attributes and check it.

        Verifies the node/edge counts and feature widths reported by the
        object, and that it converts to a PyG ``Data`` and a ``DGLGraph``.

        Named distinctly because another method in this class is called
        ``test_graph_data``; two methods with the same name in one class body
        means the later one rebinds the attribute and this test would never
        be collected.
        """
        num_nodes, num_node_features = 5, 32
        num_edges, num_edge_features = 6, 32
        node_features = np.random.random_sample((num_nodes, num_node_features))
        edge_features = np.random.random_sample((num_edges, num_edge_features))
        # 2 x num_edges index array: one column per edge.
        edge_index = np.array([
            [0, 1, 2, 2, 3, 4],
            [1, 2, 0, 3, 4, 0],
        ])
        node_pos_features = None

        graph = GraphData(node_features=node_features,
                          edge_index=edge_index,
                          edge_features=edge_features,
                          node_pos_features=node_pos_features)

        assert graph.num_nodes == num_nodes
        assert graph.num_node_features == num_node_features
        assert graph.num_edges == num_edges
        assert graph.num_edge_features == num_edge_features

        # Conversion targets are optional dependencies, hence the local imports.
        pyg_graph = graph.to_pyg_graph()
        from torch_geometric.data import Data
        assert isinstance(pyg_graph, Data)

        dgl_graph = graph.to_dgl_graph()
        from dgl import DGLGraph
        assert isinstance(dgl_graph, DGLGraph)
    def test_graph_data(self):
        """Build a GraphData carrying an extra keyword attribute ``z``.

        Verifies the reported sizes, that ``z`` is stored as an attribute and
        shown in ``str(graph)``, that ``z`` survives conversion to a PyG
        ``Data``, and that the DGL conversion returns a ``DGLGraph``.
        """
        num_nodes, num_node_features = 5, 32
        num_edges, num_edge_features = 6, 32
        node_features = np.random.random_sample((num_nodes, num_node_features))
        edge_features = np.random.random_sample((num_edges, num_edge_features))
        # 2 x num_edges index array: one column per edge.
        edge_index = np.array([
            [0, 1, 2, 2, 3, 4],
            [1, 2, 0, 3, 4, 0],
        ])
        node_pos_features = None
        # Extra attribute passed through GraphData's **kwargs; sized from
        # num_nodes so it stays consistent if the fixture size changes.
        z = np.random.random(num_nodes)

        graph = GraphData(node_features=node_features,
                          edge_index=edge_index,
                          edge_features=edge_features,
                          node_pos_features=node_pos_features,
                          z=z)

        assert graph.num_nodes == num_nodes
        assert graph.num_node_features == num_node_features
        assert graph.num_edges == num_edges
        assert graph.num_edge_features == num_edge_features
        assert graph.z.shape == z.shape

        # node_pos_features is None here and is expected to be omitted from
        # the repr; the other fields are rendered with their shapes.
        expected_repr = (
            f'GraphData(node_features=[{num_nodes}, {num_node_features}], '
            f'edge_index=[2, {num_edges}], '
            f'edge_features=[{num_edges}, {num_edge_features}], '
            f'z=[{num_nodes}])')
        assert str(graph) == expected_repr

        # Conversion targets are optional dependencies, hence the local imports.
        pyg_graph = graph.to_pyg_graph()
        from torch_geometric.data import Data
        assert isinstance(pyg_graph, Data)
        # The extra attribute must be carried over to the PyG object.
        assert tuple(pyg_graph.z.shape) == z.shape

        dgl_graph = graph.to_dgl_graph()
        from dgl import DGLGraph
        assert isinstance(dgl_graph, DGLGraph)