예제 #1
0
    def test_graph_property_edge_case(self):
        """Property counts for sparsely annotated networkx-backed graphs.

        Part one attaches only node features, so the node-feature width is
        expected to be 2 and every other feature/label count to be 0.
        Part two attaches float-typed node, edge and graph labels, each of
        which is expected to count as exactly one label.
        """
        as_float = torch.FloatTensor

        # --- only node features attached ---
        feat_nx, node_x, _, _, _, _, _, _ = simple_networkx_graph()
        Graph.add_node_attr(feat_nx, "node_feature", node_x)
        feat_graph = Graph(feat_nx)
        self.assertEqual(feat_graph.num_nodes, feat_nx.number_of_nodes())
        self.assertEqual(feat_graph.num_edges, feat_nx.number_of_edges())
        self.assertEqual(feat_graph.num_node_features, 2)
        for prop in ("num_edge_features", "num_graph_features",
                     "num_node_labels", "num_edge_labels",
                     "num_graph_labels"):
            self.assertEqual(getattr(feat_graph, prop), 0)

        # --- float-typed labels only ---
        label_nx, _, node_y, _, edge_lbl, _, _, graph_lbl = (
            simple_networkx_graph())
        Graph.add_edge_attr(label_nx, "edge_label", edge_lbl.type(as_float))
        Graph.add_node_attr(label_nx, "node_label", node_y.type(as_float))
        Graph.add_graph_attr(label_nx, "graph_label", graph_lbl.type(as_float))
        label_graph = Graph(label_nx)
        for prop in ("num_node_labels", "num_edge_labels", "num_graph_labels"):
            self.assertEqual(getattr(label_graph, prop), 1)
예제 #2
0
    def test_graph_basics(self):
        """A tensor-constructed ``Graph`` exposes the expected keys.

        Builds a directed graph from raw tensors, checks with ``in`` that
        each expected key is present, and checks that iterating the graph
        yields exactly as many items as there are expected keys.
        """
        _, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
            simple_networkx_graph()
        )

        dg = Graph(
            node_feature=x, node_label=y, edge_index=edge_index,
            edge_feature=edge_x, edge_label=edge_y,
            graph_feature=graph_x, graph_label=graph_y, directed=True
        )

        # Trailing comma on every entry so adding a key can never silently
        # concatenate two adjacent string literals.
        expected_keys = [
            "directed",
            "node_feature",
            "node_label",
            "edge_feature",
            "edge_label",
            "graph_feature",
            "graph_label",
            "edge_index",
            "edge_label_index",
            "node_label_index",
        ]
        for key in expected_keys:
            self.assertIn(key, dg)
        self.assertEqual(len(list(dg)), len(expected_keys))
예제 #3
0
    def test_graph_property_general(self):
        """Keys and size/feature/label counts of a tensor-constructed graph.

        Feature widths are expected to be 2 at node, edge and graph level;
        label counts are expected to equal ``max(label) + 1``.
        """
        nx_ref, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
            simple_networkx_graph())

        graph = Graph(
            node_feature=x, node_label=y,
            edge_index=edge_index, edge_feature=edge_x, edge_label=edge_y,
            graph_feature=graph_x, graph_label=graph_y,
            directed=True,
        )

        expected_keys = [
            "directed", "edge_feature", "edge_index", "edge_label",
            "edge_label_index", "graph_feature", "graph_label", "is_train",
            "node_feature", "node_label", "node_label_index",
        ]
        self.assertEqual(sorted(graph.keys), expected_keys)

        # Sizes should match the reference networkx graph returned alongside
        # the tensors.
        self.assertEqual(graph.num_nodes, nx_ref.number_of_nodes())
        self.assertEqual(graph.num_edges, nx_ref.number_of_edges())

        for prop in ("num_node_features", "num_edge_features",
                     "num_graph_features"):
            self.assertEqual(getattr(graph, prop), 2)

        for prop, labels in (("num_node_labels", y),
                             ("num_edge_labels", edge_y),
                             ("num_graph_labels", graph_y)):
            self.assertEqual(getattr(graph, prop),
                             np.max(labels.data.numpy()) + 1)
예제 #4
0
    def test_add_feature_nx(self):
        """``Graph.add_*_attr`` stores tensors on a networkx graph in place.

        After attaching features and labels at edge, node and graph level,
        every edge and node must carry both keys, a feature of length 2 and
        a label whose ``.item()`` is exactly ``int``; the graph-level entries
        must match the attached tensors element-wise.
        """
        nx_graph, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
            simple_networkx_graph())
        for key, value in (("edge_feature", edge_x), ("edge_label", edge_y)):
            Graph.add_edge_attr(nx_graph, key, value)
        for key, value in (("node_feature", x), ("node_label", y)):
            Graph.add_node_attr(nx_graph, key, value)
        for key, value in (("graph_feature", graph_x),
                           ("graph_label", graph_y)):
            Graph.add_graph_attr(nx_graph, key, value)

        self.assertEqual(len(nx_graph.edges.data()), edge_index.shape[1])
        for _, _, attrs in nx_graph.edges.data():
            self.assertIn("edge_feature", attrs)
            self.assertIn("edge_label", attrs)
            self.assertEqual(len(attrs["edge_feature"]), 2)
            self.assertIs(type(attrs["edge_label"].item()), int)

        for _, attrs in nx_graph.nodes.data():
            self.assertIn("node_feature", attrs)
            self.assertIn("node_label", attrs)
            self.assertEqual(len(attrs["node_feature"]), 2)
            self.assertIs(type(attrs["node_label"].item()), int)

        # Count of matching elements: 2 feature entries, 1 label entry.
        feature_hits = nx_graph.graph.get("graph_feature").eq(graph_x)
        self.assertEqual(feature_hits.sum().item(), 2)
        label_hits = nx_graph.graph.get("graph_label").eq(graph_y)
        self.assertEqual(label_hits.sum().item(), 1)
예제 #5
0
    def test_split_edge_case(self):
        """Default node/edge/link_pred splits have 80%/10%/remainder sizes.

        The graph is built from a networkx graph carrying only node and edge
        labels.
        """
        nx_graph, _, y, _, edge_y, _, _, _ = simple_networkx_graph()
        Graph.add_node_attr(nx_graph, "node_label", y)
        Graph.add_edge_attr(nx_graph, "edge_label", edge_y)
        graph = Graph(nx_graph)

        def head_sizes(total):
            # First two part sizes: floor of 80% and of 10% of ``total``.
            return int(total * 0.8), int(total * 0.1)

        # Node split, measured on node_label_index.
        node_parts = graph.split()
        node_total = graph.num_nodes
        train, val = head_sizes(node_total)
        for idx, size in enumerate((train, val, node_total - train - val)):
            self.assertEqual(node_parts[idx].node_label_index.shape[0], size)

        # Edge split, measured on edge_label_index columns.
        edge_parts = graph.split(task="edge")
        edge_total = graph.num_edges
        train, val = head_sizes(edge_total)
        for idx, size in enumerate((train, val, edge_total - train - val)):
            self.assertEqual(edge_parts[idx].edge_label_index.shape[1], size)

        # Link-prediction split: the first two sizes reuse the edge count read
        # above; the remainder is computed from the current edge count.
        link_parts = graph.split(task="link_pred")
        train, val = head_sizes(edge_total)
        link_sizes = (train, val, graph.num_edges - train - val)
        for idx, size in enumerate(link_sizes):
            self.assertEqual(link_parts[idx].edge_label_index.shape[1], size)
예제 #6
0
 def test_graph_basics(self):
     """A ``Graph`` wrapping a fully annotated networkx graph.

     Checks directedness, ``len``, that each of the ten expected keys is
     present, and that iterating the graph yields ten items.
     """
     nx_graph, x, y, edge_x, edge_y, _, graph_x, graph_y = (
         simple_networkx_graph()
     )
     attachments = (
         (Graph.add_edge_attr, "edge_feature", edge_x),
         (Graph.add_edge_attr, "edge_label", edge_y),
         (Graph.add_node_attr, "node_feature", x),
         (Graph.add_node_attr, "node_label", y),
         (Graph.add_graph_attr, "graph_feature", graph_x),
         (Graph.add_graph_attr, "graph_label", graph_y),
     )
     for add_attr, key, value in attachments:
         add_attr(nx_graph, key, value)

     graph = Graph(nx_graph)
     self.assertTrue(graph.is_directed())
     self.assertEqual(graph.is_undirected(), False)
     self.assertEqual(len(graph), 10)

     expected_keys = (
         'G', 'node_feature', 'node_label', 'edge_feature', 'edge_label',
         'graph_feature', 'graph_label', 'edge_index', 'edge_label_index',
         'node_label_index',
     )
     for key in expected_keys:
         self.assertIn(key, graph)
     self.assertEqual(sum(1 for _ in graph), 10)
예제 #7
0
    def test_dataset_property(self):
        """Dataset-wide feature/label counts for two networkx-backed graphs.

        ``H`` is a copy of ``G`` whose graph label is overwritten with
        ``tensor([1])``. ``num_labels`` is checked for the default (node)
        task and for the edge, link_pred and graph tasks.

        NOTE(review): in this file this method is immediately followed by
        another ``test_dataset_property``. If both live in the same class,
        the later definition replaces this one and this test never runs;
        consider giving one of them a distinct name.
        """
        nx_g, x, y, edge_x, edge_y, _, graph_x, graph_y = (
            simple_networkx_graph()
        )
        for key, value in (("edge_feature", edge_x), ("edge_label", edge_y)):
            Graph.add_edge_attr(nx_g, key, value)
        for key, value in (("node_feature", x), ("node_label", y)):
            Graph.add_node_attr(nx_g, key, value)
        for key, value in (("graph_feature", graph_x),
                           ("graph_label", graph_y)):
            Graph.add_graph_attr(nx_g, key, value)
        nx_h = nx_g.copy()
        Graph.add_graph_attr(nx_h, "graph_label", torch.tensor([1]))

        graphs = GraphDataset.list_to_graphs([nx_g, nx_h])
        dataset = GraphDataset(graphs)
        expected = (
            ("num_node_labels", 5),
            ("num_node_features", 2),
            ("num_edge_labels", 4),
            ("num_edge_features", 2),
            ("num_graph_labels", 2),
            ("num_graph_features", 2),
            ("num_labels", 5),  # default (node) task
        )
        for prop, count in expected:
            self.assertEqual(getattr(dataset, prop), count)

        for task, count in (("edge", 4), ("link_pred", 4), ("graph", 2)):
            self.assertEqual(GraphDataset(graphs, task=task).num_labels, count)
    def test_dataset_property(self):
        """Dataset-wide feature/label counts for two tensor-backed graphs.

        The second graph is a deep copy of the first with ``graph_label``
        replaced by ``tensor([1])``. ``num_labels`` is checked for the
        default (node) task and for the edge, link_pred and graph tasks.
        """
        _, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
            simple_networkx_graph())
        first = Graph(node_feature=x,
                      node_label=y,
                      edge_index=edge_index,
                      edge_feature=edge_x,
                      edge_label=edge_y,
                      graph_feature=graph_x,
                      graph_label=graph_y,
                      directed=True)
        second = deepcopy(first)
        second.graph_label = torch.tensor([1])
        graphs = [first, second]

        dataset = GraphDataset(graphs)
        expected_counts = (
            ("num_node_labels", 5),
            ("num_node_features", 2),
            ("num_edge_labels", 4),
            ("num_edge_features", 2),
            ("num_graph_labels", 1),
            ("num_graph_features", 2),
            ("num_labels", 5),  # default (node) task
        )
        for prop, count in expected_counts:
            self.assertEqual(getattr(dataset, prop), count)

        for task, count in (("edge", 4), ("link_pred", 5), ("graph", 1)):
            dataset = GraphDataset(graphs, task=task)
            self.assertEqual(dataset.num_labels, count)
예제 #9
0
    def test_graph_property_general(self):
        """Keys and size/feature/label counts of a networkx-backed graph.

        All six attributes are attached to the networkx graph first. The
        resulting ``Graph`` is expected to list its keys in a fixed order,
        have 2-wide features at every level, and report ``max(label) + 1``
        label classes.
        """
        nx_graph, x, y, edge_x, edge_y, _, graph_x, graph_y = (
            simple_networkx_graph()
        )
        for add_attr, key, value in (
            (Graph.add_edge_attr, "edge_feature", edge_x),
            (Graph.add_edge_attr, "edge_label", edge_y),
            (Graph.add_node_attr, "node_feature", x),
            (Graph.add_node_attr, "node_label", y),
            (Graph.add_graph_attr, "graph_feature", graph_x),
            (Graph.add_graph_attr, "graph_label", graph_y),
        ):
            add_attr(nx_graph, key, value)

        graph = Graph(nx_graph)
        # Compared as a plain list, so key order is part of what is checked.
        self.assertEqual(
            graph.keys,
            ['G', 'node_feature', 'node_label', 'edge_feature', 'edge_label',
             'graph_feature', 'graph_label', 'edge_index', 'edge_label_index',
             'node_label_index'],
        )
        self.assertEqual(graph.num_nodes, nx_graph.number_of_nodes())
        self.assertEqual(graph.num_edges, nx_graph.number_of_edges())
        for width_prop in ("num_node_features", "num_edge_features",
                           "num_graph_features"):
            self.assertEqual(getattr(graph, width_prop), 2)
        for label_prop, labels in (("num_node_labels", y),
                                   ("num_edge_labels", edge_y),
                                   ("num_graph_labels", graph_y)):
            self.assertEqual(getattr(graph, label_prop),
                             np.max(labels.data.numpy()) + 1)
예제 #10
0
    def test_transform(self):
        """``apply_tensor`` on a networkx-backed graph.

        The edge, node and graph features are snapshotted, then transformed
        by identity, ``+ 10``, ``+ 100`` and ``* 2`` in turn. After each step
        every feature is compared element-wise with the value expected from
        its snapshot.
        """
        nx_graph, x, y, edge_x, edge_y, _, graph_x, graph_y = (
            simple_networkx_graph()
        )
        for add_attr, key, value in (
            (Graph.add_edge_attr, "edge_feature", edge_x),
            (Graph.add_edge_attr, "edge_label", edge_y),
            (Graph.add_node_attr, "node_feature", x),
            (Graph.add_node_attr, "node_label", y),
            (Graph.add_graph_attr, "graph_feature", graph_x),
            (Graph.add_graph_attr, "graph_label", graph_y),
        ):
            add_attr(nx_graph, key, value)

        graph = Graph(nx_graph)

        keys = ("edge_feature", "node_feature", "graph_feature")
        snapshots = {key: getattr(graph, key).clone() for key in keys}

        # Each step: (transform, assertion, expected value from snapshot).
        steps = (
            (lambda t: t, self.assertTrue, lambda s: s),
            # After +10, not every element may still equal the snapshot.
            (lambda t: t + 10, self.assertFalse, lambda s: s),
            (lambda t: t + 100, self.assertTrue, lambda s: s + 10 + 100),
            (lambda t: t * 2, self.assertTrue,
             lambda s: (s + 10 + 100) * 2),
        )
        for transform, check, expected in steps:
            graph.apply_tensor(transform, *keys)
            for key in keys:
                current = getattr(graph, key)
                check(torch.all(current.eq(expected(snapshots[key]))))
예제 #11
0
    def test_split_edge_case(self):
        """Split sizes for a directed graph built from labels and edges only.

        Node, edge and link_pred splits are each expected to give floor(80%)
        and floor(10%) of the items to the first two parts and the remainder
        to the third.
        """
        _, _, y, _, edge_y, edge_index, _, _ = simple_networkx_graph()
        graph = Graph(
            node_label=y,
            edge_label=edge_y,
            edge_index=edge_index,
            directed=True
        )

        def expected_sizes(total):
            first = int(total * 0.8)
            second = int(total * 0.1)
            return first, second, total - first - second

        node_parts = graph.split()
        for idx, size in enumerate(expected_sizes(graph.num_nodes)):
            self.assertEqual(node_parts[idx].node_label_index.shape[0], size)

        # Edge-level tasks share the same size rule, measured on columns of
        # edge_label_index and using the edge count read after each split.
        for task in ("edge", "link_pred"):
            edge_parts = graph.split(task=task)
            for idx, size in enumerate(expected_sizes(graph.num_edges)):
                self.assertEqual(
                    edge_parts[idx].edge_label_index.shape[1], size)
예제 #12
0
    def test_transform(self):
        """``apply_tensor`` on a tensor-constructed directed graph.

        Snapshots the edge, node and graph features, applies identity,
        ``+ 10``, ``+ 100`` and ``* 2`` in turn, and after each step compares
        every feature element-wise with the value expected from its snapshot.
        """
        _, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
            simple_networkx_graph()
        )
        graph = Graph(
            node_feature=x, node_label=y, edge_index=edge_index,
            edge_feature=edge_x, edge_label=edge_y,
            graph_feature=graph_x, graph_label=graph_y, directed=True
        )

        feature_keys = ("edge_feature", "node_feature", "graph_feature")
        originals = [getattr(graph, key).clone() for key in feature_keys]

        def compare_all(check, expected_from):
            # Apply `check` to "every element equals the expected value" for
            # each feature, in edge/node/graph order.
            for key, original in zip(feature_keys, originals):
                current = getattr(graph, key)
                check(torch.all(current.eq(expected_from(original))))

        graph.apply_tensor(lambda t: t, *feature_keys)
        compare_all(self.assertTrue, lambda o: o)

        graph.apply_tensor(lambda t: t + 10, *feature_keys)
        compare_all(self.assertFalse, lambda o: o)

        graph.apply_tensor(lambda t: t + 100, *feature_keys)
        compare_all(self.assertTrue, lambda o: o + 10 + 100)

        graph.apply_tensor(lambda t: t * 2, *feature_keys)
        compare_all(self.assertTrue, lambda o: (o + 10 + 100) * 2)
예제 #13
0
 def test_dataset_basic(self):
     """A dataset of an annotated graph and its deep copy has length 2."""
     nx_graph, x, y, edge_x, edge_y, _, graph_x, graph_y = (
         simple_networkx_graph())
     for add_attr, key, value in (
             (Graph.add_edge_attr, "edge_feature", edge_x),
             (Graph.add_edge_attr, "edge_label", edge_y),
             (Graph.add_node_attr, "node_feature", x),
             (Graph.add_node_attr, "node_label", y),
             (Graph.add_graph_attr, "graph_feature", graph_x),
             (Graph.add_graph_attr, "graph_label", graph_y)):
         add_attr(nx_graph, key, value)
     nx_copy = deepcopy(nx_graph)
     dataset = GraphDataset(GraphDataset.list_to_graphs([nx_graph, nx_copy]))
     self.assertEqual(len(dataset), 2)
예제 #14
0
    def test_graph_property_edge_case(self):
        """Float-typed node, edge and graph labels each count as one label."""
        _, x, y, _, edge_y, edge_index, _, graph_y = simple_networkx_graph()
        to_float = torch.FloatTensor
        graph = Graph(
            node_feature=x,
            node_label=y.type(to_float),
            edge_index=edge_index,
            edge_label=edge_y.type(to_float),
            graph_label=graph_y.type(to_float),
            directed=True,
        )
        for prop in ("num_node_labels", "num_edge_labels", "num_graph_labels"):
            self.assertEqual(getattr(graph, prop), 1)
예제 #15
0
 def test_repr(self):
     """``repr`` of a networkx-backed graph summarises every stored key."""
     nx_graph, x, y, edge_x, edge_y, _, graph_x, graph_y = (
         simple_networkx_graph())
     for add_attr, key, value in (
             (Graph.add_edge_attr, "edge_feature", edge_x),
             (Graph.add_edge_attr, "edge_label", edge_y),
             (Graph.add_node_attr, "node_feature", x),
             (Graph.add_node_attr, "node_label", y),
             (Graph.add_graph_attr, "graph_feature", graph_x),
             (Graph.add_graph_attr, "graph_label", graph_y)):
         add_attr(nx_graph, key, value)
     graph = Graph(nx_graph)
     expected = (
         "Graph(G=[], edge_feature=[17, 2], "
         "edge_index=[2, 17], edge_label=[17], edge_label_index=[2, 17], "
         "graph_feature=[1, 2], graph_label=[1], "
         "node_feature=[10, 2], node_label=[10], node_label_index=[10])"
     )
     self.assertEqual(repr(graph), expected)
예제 #16
0
 def test_repr(self):
     """``repr`` of a tensor-constructed directed graph summarises its keys."""
     _, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
         simple_networkx_graph())
     graph = Graph(node_feature=x, node_label=y, edge_index=edge_index,
                   edge_feature=edge_x, edge_label=edge_y,
                   graph_feature=graph_x, graph_label=graph_y,
                   directed=True)
     expected = (
         "Graph(directed=[1], edge_feature=[17, 2], "
         "edge_index=[2, 17], edge_label=[17], edge_label_index=[2, 17], "
         "graph_feature=[1, 2], graph_label=[1], is_train=[1], "
         "node_feature=[10, 2], node_label=[10], node_label_index=[10])"
     )
     self.assertEqual(repr(graph), expected)
예제 #17
0
 def test_batch_basic(self):
     """Batching a graph with its deep copy.

     The batch should report two graphs, and its ``node_feature`` should be
     twice as long as that of a single graph.
     """
     nx_graph, x, y, edge_x, edge_y, _, graph_x, graph_y = (
         simple_networkx_graph())
     for add_attr, key, value in (
             (Graph.add_edge_attr, "edge_feature", edge_x),
             (Graph.add_edge_attr, "edge_label", edge_y),
             (Graph.add_node_attr, "node_feature", x),
             (Graph.add_node_attr, "node_label", y),
             (Graph.add_graph_attr, "graph_feature", graph_x),
             (Graph.add_graph_attr, "graph_label", graph_y)):
         add_attr(nx_graph, key, value)
     nx_copy = deepcopy(nx_graph)
     graphs = [Graph(nx_graph), Graph(nx_copy)]
     batch = Batch.from_data_list(graphs)
     self.assertEqual(batch.num_graphs, 2)
     single_len = len(graphs[0].node_feature)
     self.assertEqual(len(batch.node_feature), 2 * single_len)
    def test_dataset_basic(self):
        """A dataset of a tensor-backed graph and its deep copy has length 2."""
        _, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
            simple_networkx_graph())

        original = Graph(
            node_feature=x, node_label=y, edge_index=edge_index,
            edge_feature=edge_x, edge_label=edge_y,
            graph_feature=graph_x, graph_label=graph_y,
            directed=True,
        )
        dataset = GraphDataset([original, deepcopy(original)])
        self.assertEqual(len(dataset), 2)
예제 #19
0
    def test_clone(self):
        """``Graph.clone`` of a tensor-constructed graph.

        The clone must report the same sizes, feature widths and label
        counts, keep the same key order, and hold a different
        ``edge_index`` object from the source graph.
        """
        _, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
            simple_networkx_graph()
        )

        dg = Graph(
            node_feature=x, node_label=y, edge_index=edge_index,
            edge_feature=edge_x, edge_label=edge_y,
            graph_feature=graph_x, graph_label=graph_y, directed=True
        )
        dg1 = dg.clone()
        for prop in ("num_nodes", "num_edges", "num_node_features",
                     "num_edge_features", "num_node_labels",
                     "num_edge_labels"):
            # msg names the property so a failure is self-explanatory.
            self.assertEqual(getattr(dg, prop), getattr(dg1, prop), msg=prop)
        # assertIsNot / assertEqual report the offending values on failure,
        # unlike assertTrue(<bool expression>).
        self.assertIsNot(dg.edge_index, dg1.edge_index)
        self.assertEqual(tuple(dg.keys), tuple(dg1.keys))
예제 #20
0
    def test_split_edge_case(self):
        """Split sizes for a graph wrapping the bare networkx graph.

        With ``rest = total - 3``, the first two parts are expected to hold
        ``1 + floor(0.8 * rest)`` and ``1 + floor(0.1 * rest)`` items and the
        third part whatever is left (this reads as one item reserved per
        part, but that is inferred from the formula, not shown here).
        """
        nx_graph, _, _, _, _, _, _, _ = simple_networkx_graph()
        graph = Graph(nx_graph)

        def expected_sizes(total):
            rest = total - 3
            first = 1 + int(rest * 0.8)
            second = 1 + int(rest * 0.1)
            return first, second, total - first - second

        node_parts = graph.split()
        for idx, size in enumerate(expected_sizes(graph.num_nodes)):
            self.assertEqual(node_parts[idx].node_label_index.shape[0], size)

        # Edge-level tasks use the edge count read after each split.
        for task in ("edge", "link_pred"):
            edge_parts = graph.split(task=task)
            for idx, size in enumerate(expected_sizes(graph.num_edges)):
                self.assertEqual(
                    edge_parts[idx].edge_label_index.shape[1], size)
예제 #21
0
    def test_clone(self):
        """``Graph.clone`` of a networkx-backed graph.

        The clone must report the same sizes, feature widths and label
        counts, keep the same key order, and hold ``G`` and ``edge_index``
        objects distinct from those of the source graph.
        """
        G, x, y, edge_x, edge_y, _, graph_x, graph_y = (
            simple_networkx_graph())
        Graph.add_edge_attr(G, "edge_feature", edge_x)
        Graph.add_edge_attr(G, "edge_label", edge_y)
        Graph.add_node_attr(G, "node_feature", x)
        Graph.add_node_attr(G, "node_label", y)
        Graph.add_graph_attr(G, "graph_feature", graph_x)
        Graph.add_graph_attr(G, "graph_label", graph_y)

        dg = Graph(G)
        dg1 = dg.clone()
        for prop in ("num_nodes", "num_edges", "num_node_features",
                     "num_edge_features", "num_node_labels",
                     "num_edge_labels"):
            # msg names the property so a failure is self-explanatory.
            self.assertEqual(getattr(dg, prop), getattr(dg1, prop), msg=prop)
        # assertIsNot / assertEqual report the offending values on failure,
        # unlike assertTrue(<bool expression>).
        self.assertIsNot(dg.G, dg1.G)
        self.assertIsNot(dg.edge_index, dg1.edge_index)
        self.assertEqual(tuple(dg.keys), tuple(dg1.keys))
    def test_dataset_split_custom(self):
        # transductive split with node task (self defined dataset)
        G, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
            simple_networkx_graph())
        Graph.add_edge_attr(G, "edge_feature", edge_x)
        Graph.add_edge_attr(G, "edge_label", edge_y)
        Graph.add_node_attr(G, "node_feature", x)
        Graph.add_node_attr(G, "node_label", y)
        Graph.add_graph_attr(G, "graph_feature", graph_x)
        Graph.add_graph_attr(G, "graph_label", graph_y)

        num_nodes = len(list(G.nodes))
        nodes_train = torch.tensor(list(G.nodes)[:int(0.3 * num_nodes)])
        nodes_val = torch.tensor(
            list(G.nodes)[int(0.3 * num_nodes):int(0.6 * num_nodes)])
        nodes_test = torch.tensor(list(G.nodes)[int(0.6 * num_nodes):])

        graph_train = Graph(node_feature=x,
                            node_label=y,
                            edge_index=edge_index,
                            node_label_index=nodes_train,
                            directed=True)
        graph_val = Graph(node_feature=x,
                          node_label=y,
                          edge_index=edge_index,
                          node_label_index=nodes_val,
                          directed=True)
        graph_test = Graph(node_feature=x,
                           node_label=y,
                           edge_index=edge_index,
                           node_label_index=nodes_test,
                           directed=True)

        graphs_train = [graph_train]
        graphs_val = [graph_val]
        graphs_test = [graph_test]

        dataset_train, dataset_val, dataset_test = (GraphDataset(graphs_train,
                                                                 task='node'),
                                                    GraphDataset(graphs_val,
                                                                 task='node'),
                                                    GraphDataset(graphs_test,
                                                                 task='node'))

        self.assertEqual(dataset_train[0].node_label_index.tolist(),
                         list(range(int(0.3 * num_nodes))))
        self.assertEqual(
            dataset_val[0].node_label_index.tolist(),
            list(range(int(0.3 * num_nodes), int(0.6 * num_nodes))))
        self.assertEqual(dataset_test[0].node_label_index.tolist(),
                         list(range(int(0.6 * num_nodes), num_nodes)))

        # transductive split with link_pred task (train/val split)
        edges = list(G.edges)
        num_edges = len(edges)
        edges_train = edges[:int(0.7 * num_edges)]
        edges_val = edges[int(0.7 * num_edges):]
        link_size_list = [len(edges_train), len(edges_val)]

        # generate pseudo pos and neg edges, they may overlap here
        train_pos = torch.LongTensor(edges_train).permute(1, 0)
        val_pos = torch.LongTensor(edges_val).permute(1, 0)
        val_neg = torch.randint(high=10, size=val_pos.shape, dtype=torch.int64)
        val_neg_double = torch.cat((val_neg, val_neg), dim=1)

        num_train = len(edges_train)
        num_val = len(edges_val)

        graph_train = Graph(node_feature=x,
                            edge_index=edge_index,
                            edge_feature=edge_x,
                            directed=True,
                            edge_label_index=train_pos)

        graph_val = Graph(node_feature=x,
                          edge_index=edge_index,
                          edge_feature=edge_x,
                          directed=True,
                          edge_label_index=val_pos,
                          negative_edge=val_neg_double)

        graphs_train = [graph_train]
        graphs_val = [graph_val]

        dataset_train, dataset_val = (GraphDataset(graphs_train,
                                                   task='link_pred',
                                                   resample_negatives=True),
                                      GraphDataset(
                                          graphs_val,
                                          task='link_pred',
                                          edge_negative_sampling_ratio=2))

        self.assertEqual(dataset_train[0].edge_label_index.shape[1],
                         2 * link_size_list[0])
        self.assertEqual(dataset_train[0].edge_label.shape[0],
                         2 * link_size_list[0])
        self.assertEqual(dataset_val[0].edge_label_index.shape[1],
                         val_pos.shape[1] + val_neg_double.shape[1])
        self.assertEqual(dataset_val[0].edge_label.shape[0],
                         val_pos.shape[1] + val_neg_double.shape[1])
        self.assertTrue(
            torch.equal(dataset_train[0].edge_label_index[:, :num_train],
                        train_pos))
        self.assertTrue(
            torch.equal(dataset_val[0].edge_label_index[:, :num_val], val_pos))
        self.assertTrue(
            torch.equal(dataset_val[0].edge_label_index[:, num_val:],
                        val_neg_double))

        dataset_train.resample_negatives = False
        self.assertTrue(
            torch.equal(dataset_train[0].edge_label_index,
                        dataset_train[0].edge_label_index))

        # transductive split with link_pred task with edge label
        edge_label_train = torch.LongTensor([1, 2, 3, 2, 1, 1, 2, 3, 2, 0, 0])
        edge_label_val = torch.LongTensor([1, 2, 3, 2, 1, 0])

        graph_train = Graph(node_feature=x,
                            edge_index=edge_index,
                            directed=True,
                            edge_label_index=train_pos,
                            edge_label=edge_label_train)

        graph_val = Graph(node_feature=x,
                          edge_index=edge_index,
                          directed=True,
                          edge_label_index=val_pos,
                          negative_edge=val_neg,
                          edge_label=edge_label_val)

        graphs_train = [graph_train]
        graphs_val = [graph_val]

        dataset_train, dataset_val = (GraphDataset(graphs_train,
                                                   task='link_pred'),
                                      GraphDataset(graphs_val,
                                                   task='link_pred'))

        self.assertTrue(
            torch.equal(dataset_train[0].edge_label_index,
                        dataset_train[0].edge_label_index))

        self.assertTrue(
            torch.equal(dataset_train[0].edge_label[:num_train],
                        edge_label_train))

        self.assertTrue(
            torch.equal(dataset_val[0].edge_label[:num_val], edge_label_val))

        # Multiple graph tensor backend link prediction (inductive)
        pyg_dataset = Planetoid('./cora', 'Cora')
        x = pyg_dataset[0].x
        y = pyg_dataset[0].y
        edge_index = pyg_dataset[0].edge_index
        row, col = edge_index
        mask = row < col
        row, col = row[mask], col[mask]
        edge_index = torch.stack([row, col], dim=0)
        edge_index = torch.cat(
            [edge_index, torch.flip(edge_index, [0])], dim=1)

        graphs = [
            Graph(node_feature=x,
                  node_label=y,
                  edge_index=edge_index,
                  directed=False)
        ]
        graphs = [copy.deepcopy(graphs[0]) for _ in range(10)]

        edge_label_index = graphs[0].edge_label_index
        dataset = GraphDataset(graphs,
                               task='link_pred',
                               edge_message_ratio=0.6,
                               edge_train_mode="all")
        datasets = {}
        datasets['train'], datasets['val'], datasets['test'] = dataset.split(
            transductive=False, split_ratio=[0.85, 0.05, 0.1])
        edge_label_index_split = (
            datasets['train'][0].edge_label_index[:,
                                                  0:edge_label_index.shape[1]])

        self.assertTrue(torch.equal(edge_label_index, edge_label_index_split))

        # transductive split with node task (pytorch geometric dataset)
        pyg_dataset = Planetoid("./cora", "Cora")
        ds = pyg_to_dicts(pyg_dataset, task="cora")
        graphs = [Graph(**item) for item in ds]
        split_ratio = [0.3, 0.3, 0.4]
        node_size_list = [0 for i in range(len(split_ratio))]
        for graph in graphs:
            custom_splits = [[] for i in range(len(split_ratio))]
            split_offset = 0
            num_nodes = graph.num_nodes
            shuffled_node_indices = torch.randperm(graph.num_nodes)
            for i, split_ratio_i in enumerate(split_ratio):
                if i != len(split_ratio) - 1:
                    num_split_i = int(split_ratio_i * num_nodes)
                    nodes_split_i = (
                        shuffled_node_indices[split_offset:split_offset +
                                              num_split_i])
                    split_offset += num_split_i
                else:
                    nodes_split_i = shuffled_node_indices[split_offset:]

                custom_splits[i] = nodes_split_i
                node_size_list[i] += len(nodes_split_i)
            graph.custom = {"general_splits": custom_splits}

        node_feature = graphs[0].node_feature
        edge_index = graphs[0].edge_index
        directed = graphs[0].directed

        graph_train = Graph(
            node_feature=node_feature,
            edge_index=edge_index,
            directed=directed,
            node_label_index=graphs[0].custom["general_splits"][0])

        graph_val = Graph(
            node_feature=node_feature,
            edge_index=edge_index,
            directed=directed,
            node_label_index=graphs[0].custom["general_splits"][1])

        graph_test = Graph(
            node_feature=node_feature,
            edge_index=edge_index,
            directed=directed,
            node_label_index=graphs[0].custom["general_splits"][2])

        train_dataset = GraphDataset([graph_train], task="node")
        val_dataset = GraphDataset([graph_val], task="node")
        test_dataset = GraphDataset([graph_test], task="node")

        self.assertEqual(len(train_dataset[0].node_label_index),
                         node_size_list[0])
        self.assertEqual(len(val_dataset[0].node_label_index),
                         node_size_list[1])
        self.assertEqual(len(test_dataset[0].node_label_index),
                         node_size_list[2])

        # transductive split with edge task
        pyg_dataset = Planetoid("./cora", "Cora")
        graphs_g = GraphDataset.pyg_to_graphs(pyg_dataset)
        ds = pyg_to_dicts(pyg_dataset, task="cora")
        graphs = [Graph(**item) for item in ds]
        split_ratio = [0.3, 0.3, 0.4]
        edge_size_list = [0 for i in range(len(split_ratio))]
        for i, graph in enumerate(graphs):
            custom_splits = [[] for i in range(len(split_ratio))]
            split_offset = 0
            edges = list(graphs_g[i].G.edges)
            num_edges = graph.num_edges
            random.shuffle(edges)
            for i, split_ratio_i in enumerate(split_ratio):
                if i != len(split_ratio) - 1:
                    num_split_i = int(split_ratio_i * num_edges)
                    edges_split_i = (edges[split_offset:split_offset +
                                           num_split_i])
                    split_offset += num_split_i
                else:
                    edges_split_i = edges[split_offset:]

                custom_splits[i] = edges_split_i
                edge_size_list[i] += len(edges_split_i)
            graph.custom = {"general_splits": custom_splits}

        node_feature = graphs[0].node_feature
        edge_index = graphs[0].edge_index
        directed = graphs[0].directed

        train_index = torch.tensor(
            graphs[0].custom["general_splits"][0]).permute(1, 0)
        train_index = torch.cat((train_index, train_index), dim=1)
        val_index = torch.tensor(
            graphs[0].custom["general_splits"][1]).permute(1, 0)
        val_index = torch.cat((val_index, val_index), dim=1)
        test_index = torch.tensor(
            graphs[0].custom["general_splits"][2]).permute(1, 0)
        test_index = torch.cat((test_index, test_index), dim=1)

        graph_train = Graph(node_feature=node_feature,
                            edge_index=edge_index,
                            directed=directed,
                            edge_label_index=train_index)

        graph_val = Graph(node_feature=node_feature,
                          edge_index=edge_index,
                          directed=directed,
                          edge_label_index=val_index)

        graph_test = Graph(node_feature=node_feature,
                           edge_index=edge_index,
                           directed=directed,
                           edge_label_index=test_index)

        train_dataset = GraphDataset([graph_train], task="edge")
        val_dataset = GraphDataset([graph_val], task="edge")
        test_dataset = GraphDataset([graph_test], task="edge")

        self.assertEqual(train_dataset[0].edge_label_index.shape[1],
                         2 * edge_size_list[0])
        self.assertEqual(val_dataset[0].edge_label_index.shape[1],
                         2 * edge_size_list[1])
        self.assertEqual(test_dataset[0].edge_label_index.shape[1],
                         2 * edge_size_list[2])

        # inductive split with graph task
        pyg_dataset = TUDataset("./enzymes", "ENZYMES")
        ds = pyg_to_dicts(pyg_dataset)
        graphs = [Graph(**item) for item in ds]
        num_graphs = len(graphs)
        split_ratio = [0.3, 0.3, 0.4]
        graph_size_list = []
        split_offset = 0
        custom_split_graphs = []
        for i, split_ratio_i in enumerate(split_ratio):
            if i != len(split_ratio) - 1:
                num_split_i = int(split_ratio_i * num_graphs)
                custom_split_graphs.append(graphs[split_offset:split_offset +
                                                  num_split_i])
                split_offset += num_split_i
                graph_size_list.append(num_split_i)
            else:
                custom_split_graphs.append(graphs[split_offset:])
                graph_size_list.append(len(graphs[split_offset:]))
        dataset = GraphDataset(graphs,
                               task="graph",
                               custom_split_graphs=custom_split_graphs)
        split_res = dataset.split(transductive=False)
        self.assertEqual(graph_size_list[0], len(split_res[0]))
        self.assertEqual(graph_size_list[1], len(split_res[1]))
        self.assertEqual(graph_size_list[2], len(split_res[2]))