def test_hetero_graph_basics(self):
        """Check the size and type accessors of a rebuilt ``HeteroGraph``.

        The graph is first created from ``generate_simple_hete_graph()`` and
        then re-created from that object's node/edge features, labels and
        edge index, so every assertion runs against the second object.
        """
        source = HeteroGraph(generate_simple_hete_graph())
        hete = HeteroGraph(
            node_feature=source.node_feature,
            node_label=source.node_label,
            edge_feature=source.edge_feature,
            edge_label=source.edge_label,
            edge_index=source.edge_index,
            directed=True,
        )

        # Feature dimensions and node counts, per type.
        for node_type, feature_dim in (('n1', 10), ('n2', 12)):
            self.assertEqual(hete.num_node_features(node_type), feature_dim)
        for message_type, feature_dim in ((('n1', 'e1', 'n1'), 8),
                                          (('n1', 'e2', 'n2'), 12)):
            self.assertEqual(hete.num_edge_features(message_type),
                             feature_dim)
        for node_type, node_count in (('n1', 4), ('n2', 5)):
            self.assertEqual(hete.num_nodes(node_type), node_count)
        self.assertEqual(len(hete.node_types), 2)
        self.assertEqual(len(hete.edge_types), 2)

        # Message types and label counts.
        all_message_types = hete.message_types
        self.assertEqual(len(all_message_types), 7)
        for node_type in ('n1', 'n2'):
            self.assertEqual(hete.num_node_labels(node_type), 2)
        for message_type in (('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')):
            self.assertEqual(hete.num_edge_labels(message_type), 2)
        self.assertEqual(hete.num_edges(all_message_types[0]), 3)
        self.assertEqual(len(hete.node_label_index), 2)
    def test_resample_disjoint_heterogeneous(self):
        """Two reads of the same training graph must have equal label shapes.

        The dataset is created with ``resample_disjoint=True`` and a period
        of 1; the first training graph is fetched twice and, for every
        message type, the edge label index width and the edge label shape of
        both fetches are compared.
        """
        source = HeteroGraph(generate_dense_hete_dataset())
        hete = HeteroGraph(
            node_feature=source.node_feature,
            node_label=source.node_label,
            edge_feature=source.edge_feature,
            edge_label=source.edge_label,
            edge_index=source.edge_index,
            directed=True,
        )
        dataset = GraphDataset(
            [hete],
            task="link_pred",
            edge_train_mode="disjoint",
            edge_message_ratio=0.8,
            resample_disjoint=True,
            resample_disjoint_period=1,
        )
        train_split, _val_split, _test_split = dataset.split(
            split_ratio=[0.5, 0.2, 0.3])

        # Index the same element twice on purpose: each access is a
        # separate fetch of the training graph.
        first_fetch = train_split[0]
        second_fetch = train_split[0]

        for message_type in first_fetch.edge_index:
            first_width = first_fetch.edge_label_index[message_type].shape[1]
            second_width = (
                second_fetch.edge_label_index[message_type].shape[1])
            self.assertEqual(first_width, second_width)

            first_shape = first_fetch.edge_label[message_type].shape
            second_shape = second_fetch.edge_label[message_type].shape
            self.assertEqual(first_shape, second_shape)
    def test_hetero_graph_batch(self):
        """Batch 30 clones of a rebuilt ``HeteroGraph`` in groups of 3.

        Asserts the number of batches produced by the loader and that every
        batch reports ``num_graphs == 3``.
        """
        num_graphs, batch_size = 30, 3

        source = HeteroGraph(generate_simple_hete_graph())
        hete = HeteroGraph(node_feature=source.node_feature,
                           node_label=source.node_label,
                           edge_feature=source.edge_feature,
                           edge_label=source.edge_label,
                           edge_index=source.edge_index,
                           directed=True)

        clones = [hete.clone() for _ in range(num_graphs)]
        loader = DataLoader(clones,
                            collate_fn=Batch.collate(),
                            batch_size=batch_size,
                            shuffle=True)

        self.assertEqual(len(loader), math.ceil(num_graphs / batch_size))
        for batch in loader:
            self.assertEqual(batch.num_graphs, batch_size)
Exemple #4
0
    def _custom_hete_split_link_pred_disjoint(self, graph_train):
        """Build the disjoint link-prediction training graph.

        Edges of ``graph_train.G`` are grouped by message type
        ``(head node type, edge type, tail node type)``. For a message type
        that has objective edges in ``graph_train.disjoint_split``, the
        message edges are the remaining edges of that type; for a message
        type without objective edges, all of its edges are used both as
        message edges and as objective edges.

        Args:
            graph_train: Graph exposing ``G``, ``disjoint_split``,
                ``negative_edges`` and ``_edge_subgraph_with_isonodes``.

        Returns:
            A new ``HeteroGraph`` built from the message edges, on which
            ``_create_label_link_pred`` has been called with the objective
            edges.
        """
        objective_edges = graph_train.disjoint_split

        node_type_of = {
            node_id: attrs["node_type"]
            for node_id, attrs in graph_train.G.nodes(data=True)
        }

        def group_by_message_type(edges):
            # Bucket edges under their (head type, edge type, tail type) key.
            grouped = {}
            for edge in edges:
                edge_type = edge[-1]["edge_type"]
                head_type = node_type_of[edge[0]]
                tail_type = node_type_of[edge[1]]
                key = (head_type, edge_type, tail_type)
                grouped.setdefault(key, []).append(edge)
            return grouped

        edges_by_type = group_by_message_type(graph_train.G.edges(data=True))
        objective_by_type = group_by_message_type(objective_edges)

        message_edges = []
        for message_type, typed_edges in edges_by_type.items():
            if message_type not in objective_by_type:
                message_edges += typed_edges
                continue
            # Compare edges without their attribute dict (the last element).
            all_keys = [edge[:-1] for edge in typed_edges]
            objective_keys = [
                edge[:-1] for edge in objective_by_type[message_type]
            ]
            remaining_keys = set(all_keys) - set(objective_keys)
            message_edges += [
                (key[0], key[1], graph_train.G.edges[key[0], key[1]])
                for key in remaining_keys
            ]

        # Message types without any objective edge contribute all their
        # edges as objectives.
        # NOTE(review): when disjoint_split is a list, ``+=`` extends that
        # list in place on the incoming graph - confirm this is intended.
        for message_type, typed_edges in edges_by_type.items():
            if message_type not in objective_by_type:
                objective_edges += typed_edges

        message_subgraph = graph_train._edge_subgraph_with_isonodes(
            graph_train.G,
            message_edges,
        )
        graph_train = HeteroGraph(message_subgraph,
                                  negative_edges=graph_train.negative_edges)

        graph_train._create_label_link_pred(
            graph_train, objective_edges, list(graph_train.G.nodes(data=True)))

        return graph_train
 def test_hetero_graph_none(self):
     """Every message type has edge type ``None`` without edge types.

     The source graph is generated with ``add_edge_type=False`` and the
     ``HeteroGraph`` is rebuilt from its features, labels and edge index
     before the message types are inspected.
     """
     source = HeteroGraph(generate_simple_hete_graph(add_edge_type=False))
     hete = HeteroGraph(
         node_feature=source.node_feature,
         node_label=source.node_label,
         edge_feature=source.edge_feature,
         edge_label=source.edge_label,
         edge_index=source.edge_index,
         directed=True,
     )
     expected_edge_type = None
     for message_type in hete.message_types:
         # Position 1 of a message type is its edge type.
         self.assertEqual(message_type[1], expected_edge_type)
Exemple #6
0
    def test_hetero_multigraph_split(self):
        """Check ``HeteroGraph.split`` sizes on a dense hetero multigraph.

        For the node and edge tasks the expected sizes per type are
        ``1 + int(0.8 * (n - 3))``, ``1 + int(0.1 * (n - 3))`` and the
        remainder. For link prediction with ratio ``[0.5, 0.3, 0.2]`` the
        expected sizes are summed over all edge types and compared with the
        length of each split's ``edge_label``.
        """
        hete = HeteroGraph(generate_dense_hete_multigraph())

        # node
        node_splits = hete.split(task='node')
        for node_type in hete.node_label_index:
            total = len(hete.node_label_index[node_type])
            spare = total - 3
            first = 1 + int(spare * 0.8)
            second = 1 + int(spare * 0.1)
            expected = (first, second, total - first - second)
            for split_idx, size in enumerate(expected):
                self.assertEqual(
                    len(node_splits[split_idx].node_label_index[node_type]),
                    size)

        # edge
        edge_splits = hete.split(task='edge')
        for edge_type in hete.edge_label_index:
            total = int(hete.edge_label_index[edge_type].shape[1])
            spare = total - 3
            first = 1 + int(spare * 0.8)
            second = 1 + int(spare * 0.1)
            expected = (first, second, total - first - second)
            for split_idx, size in enumerate(expected):
                self.assertEqual(
                    edge_splits[split_idx].edge_label_index[edge_type]
                    .shape[1],
                    size)

        # link prediction
        link_splits = hete.split(task='link_pred',
                                 split_ratio=[0.5, 0.3, 0.2])
        # Accumulate the expected edge count of each split over all types.
        expected_train = 0
        expected_val = 0
        expected_test = 0
        for label_index in hete.edge_label_index.values():
            total = label_index.shape[1]
            spare = total - 3
            train_part = int(0.5 * spare)
            val_part = int(0.3 * spare)
            expected_train += 1 + train_part
            expected_val += 1 + val_part
            expected_test += total - 2 - train_part - val_part

        self.assertEqual(len(link_splits[0].edge_label), expected_train)
        self.assertEqual(len(link_splits[1].edge_label), expected_val)
        self.assertEqual(len(link_splits[2].edge_label), expected_test)
Exemple #7
0
    def test_hetero_graph_batch(self):
        """Batch 30 clones of a simple ``HeteroGraph`` in groups of 3.

        Asserts the number of batches produced by the loader and that every
        batch reports ``num_graphs == 3``.
        """
        num_graphs, batch_size = 30, 3
        hete = HeteroGraph(generate_simple_hete_graph())

        clones = [hete.clone() for _ in range(num_graphs)]
        loader = DataLoader(
            clones,
            collate_fn=Batch.collate(),
            batch_size=batch_size,
            shuffle=True,
        )

        self.assertEqual(len(loader), math.ceil(num_graphs / batch_size))
        for batch in loader:
            self.assertEqual(batch.num_graphs, batch_size)
Exemple #8
0
def main():
    """Load a pickled graph, build a HeteroGraph and train a HeteroNet.

    Reads the run configuration from ``arg_parse()``, loads the graph at
    ``args.data_path``, counts the distinct ``e_label`` values, transforms
    the graph with ``WN_transform``, splits it for transductive link
    prediction (0.8 / 0.1 / 0.1) and hands the loaders, model and optimizer
    to ``train``.
    """
    args = arg_parse()

    edge_train_mode = args.mode
    print(f'edge train mode: {edge_train_mode}')

    # NOTE(review): read_gpickle unpickles the file - only load trusted data.
    raw_graph = nx.read_gpickle(args.data_path)
    print(raw_graph.number_of_edges())
    print('Each node has node ID (n_id). Example: ', raw_graph.nodes[0])
    print(
        'Each edge has edge ID (id) and categorical label (e_label). Example: ',
        raw_graph[0][5871])

    # Count the distinct edge labels. A list with ``not in`` is kept on
    # purpose: membership is decided by ``==``, which does not require the
    # labels to be hashable by value.
    seen_labels = []
    for u, v, edge_key in raw_graph.edges:
        label = raw_graph[u][v][edge_key]['e_label']
        if label not in seen_labels:
            seen_labels.append(label)
    # labels are consecutive (0-17)
    num_edge_types = len(seen_labels)

    transformed = WN_transform(raw_graph, num_edge_types)
    # The nodes in the graph have the features: node_feature and node_type
    # (just one node type "n1" here). Show the first one only.
    for node in transformed.nodes(data=True):
        print(node)
        break
    # The edges in the graph have the features: edge_feature and edge_type
    # ("0" - "17" here). Show the first one only.
    for edge in transformed.edges(data=True):
        print(edge)
        break

    hete = HeteroGraph(transformed)

    dataset = GraphDataset([hete], task='link_pred')
    dataset_train, dataset_val, dataset_test = dataset.split(
        transductive=True, split_ratio=[0.8, 0.1, 0.1])

    def make_loader(split):
        # One graph per batch, collated with the Batch helper.
        return DataLoader(split, collate_fn=Batch.collate(), batch_size=1)

    dataloaders = {
        'train': make_loader(dataset_train),
        'val': make_loader(dataset_val),
        'test': make_loader(dataset_test),
    }

    hidden_size = 32
    model = HeteroNet(hete, hidden_size, 0.2).to(args.device)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=0.001,
        weight_decay=5e-4,
    )

    train(model, dataloaders, optimizer, args)
Exemple #9
0
    def _custom_hete_split_link_pred(self):
        """Create per-split link-prediction graphs from ``general_splits``.

        For every graph in ``self.graphs``:

        * the train graph is the subgraph of the train edges (keeping the
          graph's ``disjoint_split`` and ``negative_edges``);
        * the validation graph is a shallow copy of the train graph;
        * when there are three splits, the test graph is the subgraph of
          the train plus validation edges.

        ``_create_label_link_pred`` is then called on each graph with the
        edges of its own split.

        Returns:
            A list with one list of graphs per split
            (``[train_graphs, val_graphs]`` or with ``test_graphs`` added).
        """
        split_num = len(self.graphs[0].general_splits)
        has_test = split_num == 3
        split_graphs = [[] for _ in range(split_num)]

        for graph in self.graphs:
            train_source = copy.copy(graph)
            edges_train = train_source.general_splits[0]
            edges_val = train_source.general_splits[1]

            train_subgraph = train_source._edge_subgraph_with_isonodes(
                train_source.G,
                edges_train,
            )
            graph_train = HeteroGraph(
                train_subgraph,
                disjoint_split=train_source.disjoint_split,
                negative_edges=train_source.negative_edges,
            )

            graph_val = copy.copy(graph_train)
            if has_test:
                test_source = copy.copy(graph)
                edges_test = graph.general_splits[2]
                test_subgraph = test_source._edge_subgraph_with_isonodes(
                    test_source.G, edges_train + edges_val)
                graph_test = HeteroGraph(
                    test_subgraph,
                    negative_edges=test_source.negative_edges,
                )

            # Labels are created in train, val, test order.
            graph_train._create_label_link_pred(
                graph_train, edges_train, list(graph_train.G.nodes(data=True)))
            graph_val._create_label_link_pred(
                graph_val, edges_val, list(graph_val.G.nodes(data=True)))

            split_graphs[0].append(graph_train)
            split_graphs[1].append(graph_val)

            if has_test:
                graph_test._create_label_link_pred(
                    graph_test, edges_test,
                    list(graph_test.G.nodes(data=True)))
                split_graphs[2].append(graph_test)

        return split_graphs
Exemple #10
0
    def test_hetero_graph_basics(self):
        """Check the ``get_*`` size accessors of a simple ``HeteroGraph``."""
        hete = HeteroGraph(generate_simple_hete_graph())

        # Feature dimensions and node counts, per type.
        for node_type, feature_dim in (('n1', 10), ('n2', 12)):
            self.assertEqual(hete.get_num_node_features(node_type),
                             feature_dim)
        for edge_type, feature_dim in (('e1', 8), ('e2', 12)):
            self.assertEqual(hete.get_num_edge_features(edge_type),
                             feature_dim)
        for node_type, node_count in (('n1', 4), ('n2', 5)):
            self.assertEqual(hete.get_num_nodes(node_type), node_count)
        self.assertEqual(len(hete.node_types), 2)
        self.assertEqual(len(hete.edge_types), 2)

        # Message types and label counts.
        all_message_types = hete.message_types
        self.assertEqual(len(all_message_types), 7)
        for node_type in ('n1', 'n2'):
            self.assertEqual(hete.get_num_node_labels(node_type), 2)
        for edge_type in ('e1', 'e2'):
            self.assertEqual(hete.get_num_edge_labels(edge_type), 3)
        self.assertEqual(hete.get_num_edges(all_message_types[0]), 3)
        self.assertEqual(len(hete.node_label_index), 2)
    def test_dataset_hetero_graph_split(self):
        """Check ``GraphDataset.split`` sizes on a heterogeneous graph.

        Scenarios, in order: node task (all types, ``split_types`` as a list
        and as a string), edge task (all types and selected types),
        link_pred (all types and selected types), link_pred with disjoint
        training edges, and link_pred with ``edge_split_mode="approximate"``
        (all types and selected types).

        In the per-type scenarios a type that is not listed in
        ``split_types`` is expected to keep its full size in every split.
        In the "approximate" scenarios only the totals over all edge types
        are compared.
        """
        source = HeteroGraph(generate_dense_hete_dataset())
        hete = HeteroGraph(node_feature=source.node_feature,
                           node_label=source.node_label,
                           edge_feature=source.edge_feature,
                           edge_label=source.edge_label,
                           edge_index=source.edge_index,
                           directed=True)

        def ratio_sizes(total, first_ratio=0.8, second_ratio=0.1):
            # Sizes of the three splits: two truncated shares + remainder.
            first = int(first_ratio * total)
            second = int(second_ratio * total)
            return first, second, total - first - second

        def doubled(sizes):
            # The link_pred expectations are twice the per-split edge count.
            return tuple(2 * size for size in sizes)

        def link_sizes(total):
            # Expected link_pred label-index widths for the default ratios.
            return doubled(ratio_sizes(total))

        def disjoint_link_sizes(total):
            # Expected widths for split_ratio=[0.6, 0.2, 0.2] with
            # edge_message_ratio=0.5: int(0.5 * train) of the training edges
            # are subtracted before doubling (presumably the message edges -
            # TODO confirm against GraphDataset).
            train, val, test = ratio_sizes(total, 0.6, 0.2)
            return doubled((train - int(0.5 * train), val, test))

        def num_edges_in(label_index):
            # Number of edges stored in an edge label index.
            return label_index.shape[1]

        def assert_split_sizes(split_res, attr, size_of, expected_for,
                               split_types=None):
            # Compare, per type, the size of ``attr`` in each of the three
            # splits with the expected sizes. Types outside ``split_types``
            # must keep their full size in every split.
            reference = getattr(hete, attr)
            for type_key in reference:
                total = size_of(reference[type_key])
                if split_types is None or type_key in split_types:
                    expected = expected_for(total)
                else:
                    expected = (total, total, total)
                for split_idx, size in enumerate(expected):
                    actual = getattr(split_res[split_idx][0], attr)[type_key]
                    self.assertEqual(size_of(actual), size)

        def split_totals(split_res):
            # Total edge label index width of each split over all edge
            # types; a type missing from a split contributes nothing.
            totals = [0, 0, 0]
            for edge_type in hete.edge_label_index:
                for split_idx in range(3):
                    label_index = split_res[split_idx][0].edge_label_index
                    if edge_type in label_index:
                        totals[split_idx] += label_index[edge_type].shape[1]
            return totals

        def assert_totals(split_res, expected):
            # Compare the (train, val, test) totals with ``expected``.
            for actual, size in zip(split_totals(split_res), expected):
                self.assertEqual(actual, size)

        # node
        split_res = GraphDataset([hete], task="node").split()
        assert_split_sizes(split_res, "node_label_index", len, ratio_sizes)

        # node with specified split type, first as a list, then as a string
        for node_split_types in (["n1"], "n1"):
            dataset = GraphDataset([hete], task="node")
            split_res = dataset.split(split_types=node_split_types)
            assert_split_sizes(split_res, "node_label_index", len,
                               ratio_sizes, split_types=node_split_types)

        # edge
        split_res = GraphDataset([hete], task="edge").split()
        assert_split_sizes(split_res, "edge_label_index", num_edges_in,
                           ratio_sizes)

        # edge with specified split type
        dataset = GraphDataset([hete], task="edge")
        edge_split_types = [("n1", "e1", "n1"), ("n1", "e2", "n2")]
        split_res = dataset.split(split_types=edge_split_types)
        assert_split_sizes(split_res, "edge_label_index", num_edges_in,
                           ratio_sizes, split_types=edge_split_types)

        # link_pred
        dataset = GraphDataset([hete], task="link_pred")
        split_res = dataset.split(transductive=True)
        assert_split_sizes(split_res, "edge_label_index", num_edges_in,
                           link_sizes)

        # link_pred with specified split type
        dataset = GraphDataset([hete], task="link_pred")
        link_split_types = [("n1", "e1", "n1"), ("n1", "e2", "n2")]
        split_res = dataset.split(transductive=True,
                                  split_types=link_split_types)
        assert_split_sizes(split_res, "edge_label_index", num_edges_in,
                           link_sizes, split_types=link_split_types)

        # link_pred + disjoint
        dataset = GraphDataset(
            [hete],
            task="link_pred",
            edge_train_mode="disjoint",
            edge_message_ratio=0.5,
        )
        split_res = dataset.split(
            transductive=True,
            split_ratio=[0.6, 0.2, 0.2],
        )
        assert_split_sizes(split_res, "edge_label_index", num_edges_in,
                           disjoint_link_sizes)

        # link pred with edge_split_mode set to "approximate": only the
        # totals over all edge types are checked.
        dataset = GraphDataset([hete],
                               task="link_pred",
                               edge_split_mode="approximate")
        split_res = dataset.split(transductive=True)
        num_edges = sum(
            num_edges_in(hete.edge_label_index[edge_type])
            for edge_type in hete.edge_label_index)
        assert_totals(split_res, link_sizes(num_edges))

        # link pred with specified types and edge_split_mode set to
        # "approximate"
        dataset = GraphDataset(
            [hete],
            task="link_pred",
            edge_split_mode="approximate",
        )
        link_split_types = [("n1", "e1", "n1"), ("n1", "e2", "n2")]
        split_res = dataset.split(
            transductive=True,
            split_types=link_split_types,
        )
        num_split_type_edges = 0
        num_non_split_type_edges = 0
        for edge_type in hete.edge_label_index:
            edge_count = num_edges_in(hete.edge_label_index[edge_type])
            if edge_type in link_split_types:
                num_split_type_edges += edge_count
            else:
                num_non_split_type_edges += edge_count

        # Edges of the types that are not split are expected in full in
        # every split, on top of the doubled share of the split types.
        expected_totals = tuple(
            size + num_non_split_type_edges
            for size in link_sizes(num_split_type_edges))
        assert_totals(split_res, expected_totals)
    node_feature["citeseer_node"] = citeseer_x

    node_label = {}
    node_label["cora_node"] = cora_y
    node_label["citeseer_node"] = citeseer_y

    # prepare undirected edge_indx
    for message_type in edge_index:
        edge_index[message_type] = torch.cat([
            edge_index[message_type],
            torch.flip(edge_index[message_type], [0])
        ],
                                             dim=1)

    hete = HeteroGraph(node_feature=node_feature,
                       node_label=node_label,
                       edge_index=edge_index,
                       directed=False)
    print(
        f"Heterogeneous graph {hete.num_nodes()} nodes, {hete.num_edges()} edges"
    )

    dataset = GraphDataset([hete], task='node')
    dataset_train, dataset_val, dataset_test = dataset.split(
        transductive=True, split_ratio=[0.8, 0.1, 0.1])
    train_loader = DataLoader(dataset_train,
                              collate_fn=Batch.collate(),
                              batch_size=16)
    val_loader = DataLoader(dataset_val,
                            collate_fn=Batch.collate(),
                            batch_size=16)
    test_loader = DataLoader(dataset_test,
                pred = logits[node_type][node_idx]
                pred = pred.max(1)[1]
                acc += pred.eq(batch.node_label[node_type][node_idx].to(device)).sum().item()
                total += pred.size(0)
            acc /= total
            accs.append(acc)
    if accs[1] > best_val:
        best_val = accs[1]
        best_model = copy.deepcopy(model)
    return accs

# Script entry point: build one heterogeneous graph from the Cora and
# CiteSeer Planetoid datasets, split it for the node task and create the
# loaders and the model.
if __name__ == "__main__":
    # Load both datasets into ./cora and ./citeseer and merge their first
    # graph with concatenate_citeseer_cora (defined elsewhere in the file).
    cora_pyg = Planetoid('./cora', 'Cora')
    citeseer_pyg = Planetoid('./citeseer', 'CiteSeer')
    G = concatenate_citeseer_cora(cora_pyg[0], citeseer_pyg[0])
    hete = HeteroGraph(G)
    # NOTE(review): num_nodes / num_edges are formatted without being
    # called - confirm they are attributes/properties in this version.
    print("Heterogeneous graph {} nodes, {} edges".format(hete.num_nodes, hete.num_edges))

    # Transductive 0.8 / 0.1 / 0.1 split for the node task.
    dataset = GraphDataset([hete], task='node')
    dataset_train, dataset_val, dataset_test = dataset.split(transductive=True,
                                                            split_ratio=[0.8, 0.1, 0.1])
    # One loader per split, all with batch_size=16 and the Batch collate
    # function.
    train_loader = DataLoader(dataset_train, collate_fn=Batch.collate(),
                        batch_size=16)
    val_loader = DataLoader(dataset_val, collate_fn=Batch.collate(),
                        batch_size=16)
    test_loader = DataLoader(dataset_test, collate_fn=Batch.collate(),
                        batch_size=16)
    # Order is train, val, test.
    loaders = [train_loader, val_loader, test_loader]

    hidden_size = 32
    # `device` is not defined in this block; it must come from module scope.
    model = HeteroNet(hete, hidden_size, 0.5).to(device)
Exemple #14
0
    def test_dataset_hetero_graph_split(self):
        """Check the split sizes ``GraphDataset.split`` yields for a hetero graph.

        Covers the node, edge and link_pred tasks, each with and without
        ``split_types``, plus the disjoint edge train mode and the
        ``"approximate"`` edge split mode.  ``split_res[0]``, ``split_res[1]``
        and ``split_res[2]`` are the train, validation and test datasets and
        ``split_res[i][0]`` is the single graph each of them holds.

        Types listed in ``split_types`` (or all types when it is omitted) are
        expected to follow ``expected_sizes``; every other type is expected
        to keep its full size in all three splits.
        """
        G = generate_dense_hete_dataset()
        hete = HeteroGraph(G)

        def expected_sizes(total, train_ratio=0.8, val_ratio=0.1):
            """Return the expected (train, val, test) sizes for ``total`` items.

            Three items are set aside; train and val each get one of them
            plus their ratio share of the remaining ``total - 3`` and test
            receives whatever is left.
            """
            reduced = total - 3
            train = 1 + int(reduced * train_ratio)
            val = 1 + int(reduced * val_ratio)
            return train, val, total - train - val

        def node_count(graph, node_type):
            """Number of labelled nodes of ``node_type`` in ``graph``."""
            return len(graph.node_label_index[node_type])

        def edge_count(graph, edge_type):
            """Number of label edges (columns) of ``edge_type`` in ``graph``."""
            return graph.edge_label_index[edge_type].shape[1]

        def check_per_type(split_res, type_keys, count,
                           split_types=None, scale=1):
            """Assert the per-type size of every split.

            ``scale`` multiplies the expected size of the types that are
            split (the link_pred expectations are twice the plain ones);
            types outside ``split_types`` must keep their original size.
            """
            for type_key in type_keys:
                total = count(hete, type_key)
                if split_types is None or type_key in split_types:
                    expected = [scale * size for size in expected_sizes(total)]
                else:
                    expected = [total] * 3
                for idx, size in enumerate(expected):
                    self.assertEqual(count(split_res[idx][0], type_key), size)

        def summed_edge_counts(split_res):
            """Sum the label-edge counts over all edge types, per split."""
            totals = [0, 0, 0]
            for edge_type in hete.edge_label_index:
                for idx in range(3):
                    label_index = split_res[idx][0].edge_label_index
                    # A type may be missing from a split in approximate mode.
                    if edge_type in label_index:
                        totals[idx] += label_index[edge_type].shape[1]
            return totals

        # node
        dataset = GraphDataset([hete], task='node')
        split_res = dataset.split()
        check_per_type(split_res, hete.node_label_index, node_count)

        # node with specified split type
        dataset = GraphDataset([hete], task='node')
        node_split_types = ['n1']
        split_res = dataset.split(split_types=node_split_types)
        check_per_type(split_res, hete.node_label_index, node_count,
                       split_types=node_split_types)

        # node with specified split type (string mode)
        dataset = GraphDataset([hete], task='node')
        node_split_types = 'n1'
        split_res = dataset.split(split_types=node_split_types)
        # Wrap the string in a list so membership is type equality rather
        # than a substring test against 'n1'.
        check_per_type(split_res, hete.node_label_index, node_count,
                       split_types=[node_split_types])

        # edge
        dataset = GraphDataset([hete], task='edge')
        split_res = dataset.split()
        check_per_type(split_res, hete.edge_label_index, edge_count)

        # edge with specified split type
        dataset = GraphDataset([hete], task='edge')
        edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]
        split_res = dataset.split(split_types=edge_split_types)
        check_per_type(split_res, hete.edge_label_index, edge_count,
                       split_types=edge_split_types)

        # link_pred (each split is expected to hold twice the plain size)
        dataset = GraphDataset([hete], task='link_pred')
        split_res = dataset.split(transductive=True)
        check_per_type(split_res, hete.edge_label_index, edge_count, scale=2)

        # link_pred with specified split type
        dataset = GraphDataset([hete], task='link_pred')
        link_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]
        split_res = dataset.split(transductive=True,
                                  split_types=link_split_types)
        check_per_type(split_res, hete.edge_label_index, edge_count,
                       split_types=link_split_types, scale=2)

        # link_pred + disjoint
        dataset = GraphDataset([hete], task='link_pred',
                               edge_train_mode='disjoint',
                               edge_message_ratio=0.5)
        split_res = dataset.split(transductive=True,
                                  split_ratio=[0.6, 0.2, 0.2])
        for edge_type in hete.edge_label_index:
            num_edges = edge_count(hete, edge_type)
            train, val, _ = expected_sizes(num_edges, 0.6, 0.2)

            # Expected train label edges after the edge_message_ratio=0.5
            # share is taken out (formula carried over unchanged).
            edge_0 = 2 * (train - (1 + int(0.5 * (train - 2))))
            self.assertEqual(edge_count(split_res[0][0], edge_type), edge_0)

            edge_1 = 2 * val
            self.assertEqual(edge_count(split_res[1][0], edge_type), edge_1)

            edge_2 = 2 * num_edges - 2 * train - edge_1
            self.assertEqual(edge_count(split_res[2][0], edge_type), edge_2)

        # link pred with edge_split_mode set to "approximate": only the
        # totals over all edge types are expected to follow the ratios
        dataset = GraphDataset([hete], task='link_pred',
                               edge_split_mode="approximate")
        split_res = dataset.split(transductive=True)
        num_edges = sum(
            edge_count(hete, edge_type) for edge_type in hete.edge_label_index
        )
        train_total, val_total, test_total = summed_edge_counts(split_res)
        expected = [2 * size for size in expected_sizes(num_edges)]
        self.assertEqual(train_total, expected[0])
        self.assertEqual(val_total, expected[1])
        self.assertEqual(test_total, expected[2])

        # link pred with specified types and edge_split_mode set to
        # "approximate"
        dataset = GraphDataset([hete], task='link_pred',
                               edge_split_mode="approximate")
        link_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]
        split_res = dataset.split(transductive=True,
                                  split_types=link_split_types)

        num_split_type_edges = 0
        num_non_split_type_edges = 0
        for edge_type in hete.edge_label_index:
            if edge_type in link_split_types:
                num_split_type_edges += edge_count(hete, edge_type)
            else:
                num_non_split_type_edges += edge_count(hete, edge_type)

        train_total, val_total, test_total = summed_edge_counts(split_res)
        # Types that are not split contribute their full size to every split.
        expected = [
            2 * size + num_non_split_type_edges
            for size in expected_sizes(num_split_type_edges)
        ]
        self.assertEqual(train_total, expected[0])
        self.assertEqual(val_total, expected[1])
        self.assertEqual(test_total, expected[2])
Exemple #15
0
def main():
    """Train a heterogeneous GNN for link prediction on a pickled graph.

    Parses the command-line arguments, loads the NetworkX graph stored at
    ``args.data_path``, converts it with ``WN_transform`` into a
    ``HeteroGraph``, splits it transductively 80/10/10 and passes the three
    data loaders to ``train``.
    """
    args = arg_parse()

    edge_train_mode = args.mode
    print('edge train mode: {}'.format(edge_train_mode))

    # NOTE(review): read_gpickle unpickles args.data_path, so only point it
    # at trusted files; the function was also removed in NetworkX 3.0.
    G = nx.read_gpickle(args.data_path)
    print(G.number_of_edges())
    print('Each node has node ID (n_id). Example: ', G.nodes[0])
    print(
        'Each edge has edge ID (id) and categorical label (e_label). Example: ',
        G[0][5871])

    # find num edge types: count the distinct edge labels
    labels = {G[u][v][edge_key]['e_label'] for u, v, edge_key in G.edges}
    # labels are consecutive (0-17)
    num_edge_types = len(labels)

    H = WN_transform(G, num_edge_types)
    # The nodes in the graph have the features: node_feature and node_type (just one node type "n1" here)
    for node in H.nodes(data=True):
        print(node)
        break
    # The edges in the graph have the features: edge_feature and edge_type ("0" - "17" here)
    for edge in H.edges(data=True):
        print(edge)
        break

    # Rebuild the graph from its tensors only.
    hetero = HeteroGraph(H)
    hetero = HeteroGraph(edge_index=hetero.edge_index,
                         edge_feature=hetero.edge_feature,
                         node_feature=hetero.node_feature,
                         directed=hetero.is_directed())

    # edge_message_ratio is only passed along in disjoint mode.
    dataset_kwargs = {'task': 'link_pred', 'edge_train_mode': edge_train_mode}
    if edge_train_mode == "disjoint":
        dataset_kwargs['edge_message_ratio'] = args.edge_message_ratio
    dataset = GraphDataset([hetero], **dataset_kwargs)

    dataset_train, dataset_val, dataset_test = dataset.split(
        transductive=True, split_ratio=[0.8, 0.1, 0.1])
    splits = (
        ('train', dataset_train),
        ('val', dataset_val),
        ('test', dataset_test),
    )
    dataloaders = {
        name: DataLoader(split, collate_fn=Batch.collate(), batch_size=1)
        for name, split in splits
    }

    hidden_size = args.hidden_dim
    conv1, conv2 = generate_2convs_link_pred_layers(hetero, HeteroSAGEConv,
                                                    hidden_size)
    model = HeteroGNN(conv1, conv2, hetero, hidden_size).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    t_accu, v_accu, e_accu = train(model, dataloaders, optimizer, args)
Exemple #16
0
def main():
    """Build the knockout link-prediction dataset and print sanity checks.

    Reads the PPI and CMap data, assigns the type-1 out-edges of every
    knockout node at random to the disjoint/train, validation or train
    edge lists, wraps the graph in a ``HeteroGraph`` with those custom
    splits, applies ``cmap_transform`` and splits the dataset.

    NOTE: the function currently stops at the ``exit()`` call after the
    dimension printout, so the model/training code below it never runs.
    """
    writer = SummaryWriter()
    args = arg_parse()

    edge_train_mode = args.mode
    print('edge train mode: {}'.format(edge_train_mode))

    ppi_graph = read_ppi_data(args.ppi_path)

    mode = 'mixed'
    if mode == 'ppi':
        message_passing_graph = ppi_graph
        cmap_graph, knockout_nodes = read_cmap_data(args.data_path)
    elif mode == 'mixed':
        message_passing_graph, knockout_nodes = (
            read_cmap_data(args.data_path, ppi_graph)
        )

    print('Each node has gene ID. Example: ', message_passing_graph.nodes['ADPGK'])
    print('Each edge has de direction. Example', message_passing_graph['ADPGK']['IL1B'])
    print('Total num edges: ', message_passing_graph.number_of_edges())

    def typed_out_edges(u):
        """Return ``(u, v, key)`` for each out-edge of ``u`` with edge_type 1."""
        return [
            (u, v, edge_key)
            for v in message_passing_graph.successors(u)
            for edge_key in message_passing_graph[u][v]
            if message_passing_graph[u][v][edge_key]['edge_type'] == 1
        ]

    # disjoint edge label
    disjoint_split_ratio = 0.1
    val_ratio = 0.1
    disjoint_edge_label_index = []
    val_edges = []

    # newly edited
    train_edges = []
    for u in knockout_nodes:
        rand_num = np.random.rand()
        # cmap is not a multigraph
        edges = typed_out_edges(u)
        if rand_num < disjoint_split_ratio:
            # add all edges (cmap only) into edge label index; they are
            # also part of the training edges
            disjoint_edge_label_index.extend(edges)
            train_edges.extend(edges)
        elif rand_num < disjoint_split_ratio + val_ratio:
            val_edges.extend(edges)
        else:
            train_edges.extend(edges)

    # add default node types for message_passing_graph
    for node in message_passing_graph.nodes:
        message_passing_graph.nodes[node]['node_type'] = 0

    print('Num edges to predict: ', len(disjoint_edge_label_index))
    print('Num edges in val: ', len(val_edges))
    print('Num edges in train: ', len(train_edges))

    graph = HeteroGraph(
        message_passing_graph,
        custom={
            "general_splits": [
                train_edges,
                val_edges
            ],
            "disjoint_split": disjoint_edge_label_index,
            "task": "link_pred"
        }
    )

    graphs = [graph]
    graphDataset = GraphDataset(
        graphs,
        task="link_pred",
        edge_train_mode="disjoint"
    )

    # Transform dataset
    # de direction (currently using homogeneous graph)
    num_edge_types = 2

    graphDataset = graphDataset.apply_transform(
        cmap_transform, num_edge_types=num_edge_types, deep_copy=False
    )
    print('Number of node features: ', graphDataset.num_node_features())

    # split dataset
    dataset = {}
    dataset['train'], dataset['val'] = graphDataset.split(transductive=True)

    # sanity check
    print(f"dataset['train'][0].edge_label_index.keys(): {dataset['train'][0].edge_label_index.keys()}")
    print(f"dataset['train'][0].edge_label_index[(0, 1, 0)].shape[1]: {dataset['train'][0].edge_label_index[(0, 1, 0)].shape[1]}")
    print(f"dataset['val'][0].edge_label_index.keys(): {dataset['val'][0].edge_label_index.keys()}")
    print(f"dataset['val'][0].edge_label_index[(0, 1, 0)].shape[1]: {dataset['val'][0].edge_label_index[(0, 1, 0)].shape[1]}")
    print(f"len(list(dataset['train'][0].G.edges)): {len(list(dataset['train'][0].G.edges))}")
    print(f"len(list(dataset['val'][0].G.edges)): {len(list(dataset['val'][0].G.edges))}")
    print(f"list(dataset['train'][0].G.edges)[:10]: {list(dataset['train'][0].G.edges)[:10]}")
    print(f"list(dataset['val'][0].G.edges)[:10]: {list(dataset['val'][0].G.edges)[:10]}")

    # node feature dimension
    input_dim = dataset['train'].num_node_features()
    edge_feat_dim = dataset['train'].num_edge_features()
    num_classes = dataset['train'].num_edge_labels()
    print(
        'Node feature dim: {}; edge feature dim: {}; num classes: {}.'.format(
            input_dim, edge_feat_dim, num_classes
        )
    )
    # NOTE(review): this exit() stops the script after the sanity output, so
    # everything below is unreachable. Looks like a debugging leftover;
    # confirm before removing it, since that would start training.
    exit()

    # relation type is both used for edge features and edge labels
    model = Net(input_dim, edge_feat_dim, num_classes, args).to(args.device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.001, weight_decay=5e-3
    )
    follow_batch = []  # e.g., follow_batch = ['edge_index']

    dataloaders = {
        split: DataLoader(
            ds, collate_fn=Batch.collate(follow_batch),
            batch_size=1, shuffle=(split == 'train')
        )
        for split, ds in dataset.items()
    }
    print('Graphs after split: ')
    for key, dataloader in dataloaders.items():
        for batch in dataloader:
            print(key, ': ', batch)

    train(model, dataloaders, optimizer, args, writer=writer)
    def test_hetero_multigraph_split(self):
        """Verify the sizes of ``HeteroGraph.split`` on a dense multigraph.

        The graph is rebuilt from its tensors as a directed ``HeteroGraph``
        and split for the node, edge and link_pred tasks.  Node and edge
        splits are compared per type against an 80/10/10 division; the
        link_pred split (ratios 0.5/0.3/0.2) is compared on the number of
        edge labels summed over all message types.
        """
        multigraph = generate_dense_hete_multigraph()
        hete = HeteroGraph(multigraph)
        hete = HeteroGraph(node_feature=hete.node_feature,
                           node_label=hete.node_label,
                           edge_feature=hete.edge_feature,
                           edge_label=hete.edge_label,
                           edge_index=hete.edge_index,
                           directed=True)

        def partition(total, first_ratio, second_ratio):
            """Truncate the first two shares; the third takes the remainder."""
            first = int(total * first_ratio)
            second = int(total * second_ratio)
            return first, second, total - first - second

        # node
        node_splits = hete.split(task='node')
        for node_type in hete.node_label_index:
            wanted = partition(len(hete.node_label_index[node_type]), 0.8, 0.1)
            for position, size in enumerate(wanted):
                self.assertEqual(
                    len(node_splits[position].node_label_index[node_type]),
                    size,
                )

        # edge
        edge_splits = hete.split(task='edge')
        for edge_type in hete.edge_label_index:
            wanted = partition(
                int(hete.edge_label_index[edge_type].shape[1]), 0.8, 0.1
            )
            for position, size in enumerate(wanted):
                self.assertEqual(
                    edge_splits[position].edge_label_index[edge_type].shape[1],
                    size,
                )

        # link prediction
        link_splits = hete.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
        # expected number of edges in each split, accumulated over all types
        wanted_totals = [0, 0, 0]
        for label_index in hete.edge_label_index.values():
            shares = partition(label_index.shape[1], 0.5, 0.3)
            for position, size in enumerate(shares):
                wanted_totals[position] += size

        # actual number of edge labels in each split, over all message types
        label_totals = [
            sum(
                link_splits[position].edge_label[message_type].shape[0]
                for message_type in link_splits[position].edge_label
            )
            for position in range(3)
        ]
        for got, wanted in zip(label_totals, wanted_totals):
            self.assertEqual(got, wanted)
    def test_hetero_graph_split(self):
        """Check ``HeteroGraph.split`` sizes on directed and undirected graphs.

        For each graph the node and edge tasks are split with and without
        ``split_types`` and the link_pred task is split with ratios
        0.5/0.3/0.2.  Types that are split are expected to be divided with
        truncated ratio shares (the last split takes the remainder); types
        left out of ``split_types`` are expected to keep their full size in
        every split.
        """

        def ratio_sizes(total, first_ratio=0.8, second_ratio=0.1):
            """Return the expected three split sizes for ``total`` items."""
            first = int(total * first_ratio)
            second = int(total * second_ratio)
            return first, second, total - first - second

        def check_node_split(graph, splits, split_types=None):
            """Assert the per-type node_label_index length of every split."""
            for node_type in graph.node_label_index:
                total = len(graph.node_label_index[node_type])
                if split_types is None or node_type in split_types:
                    expected = ratio_sizes(total)
                else:
                    expected = (total,) * 3
                for idx, size in enumerate(expected):
                    self.assertEqual(
                        len(splits[idx].node_label_index[node_type]),
                        size,
                    )

        def check_edge_split(graph, splits, edge_total, split_types=None):
            """Assert the per-type edge_label_index width of every split.

            ``edge_total`` maps an edge type to the count that is divided
            by ratio; unsplit types are compared with the graph's own
            ``edge_label_index`` width.
            """
            for edge_type in graph.edge_label_index:
                if split_types is None or edge_type in split_types:
                    expected = ratio_sizes(int(edge_total(edge_type)))
                else:
                    unsplit = graph.edge_label_index[edge_type].shape[1]
                    expected = (unsplit,) * 3
                for idx, size in enumerate(expected):
                    self.assertEqual(
                        splits[idx].edge_label_index[edge_type].shape[1],
                        size,
                    )

        def check_link_split(graph, splits, undirected):
            """Assert the per-type edge_label count of every link_pred split.

            For the undirected graph the expectation halves the
            edge_label_index width before applying the ratios and doubles
            each resulting size.
            """
            for key, val in graph.edge_label_index.items():
                num_edges = val.shape[1]
                scale = 1
                if undirected:
                    num_edges = int(num_edges / 2)
                    scale = 2
                expected = ratio_sizes(num_edges, 0.5, 0.3)
                for idx, size in enumerate(expected):
                    self.assertEqual(
                        splits[idx].edge_label[key].shape[0],
                        scale * size,
                    )

        def check_graph(graph, undirected):
            """Run the node, edge and link_pred split checks on ``graph``."""
            # node
            check_node_split(graph, graph.split())

            # node with specified split type
            node_split_types = ['n1']
            check_node_split(
                graph,
                graph.split(split_types=node_split_types),
                node_split_types,
            )

            # The directed expectations read the count from
            # edge_label_index, the undirected ones from num_edges().
            if undirected:
                edge_total = graph.num_edges
            else:
                def edge_total(edge_type):
                    return graph.edge_label_index[edge_type].shape[1]

            # edge
            check_edge_split(graph, graph.split(task='edge'), edge_total)

            # edge with specified split type
            edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]
            check_edge_split(
                graph,
                graph.split(task='edge', split_types=edge_split_types),
                edge_total,
                edge_split_types,
            )

            # link_pred
            check_link_split(
                graph,
                graph.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2]),
                undirected,
            )

        # directed G
        G = generate_dense_hete_graph()
        hete = HeteroGraph(G)
        hete = HeteroGraph(
            node_feature=hete.node_feature,
            node_label=hete.node_label,
            edge_feature=hete.edge_feature,
            edge_label=hete.edge_label,
            edge_index=hete.edge_index,
        )
        check_graph(hete, undirected=False)

        # undirected G
        G = generate_dense_hete_graph(directed=False)
        hete = HeteroGraph(G)
        hete = HeteroGraph(node_feature=hete.node_feature,
                           node_label=hete.node_label,
                           edge_feature=hete.edge_feature,
                           edge_label=hete.edge_label,
                           edge_index=hete.edge_index,
                           directed=False)
        check_graph(hete, undirected=True)
    def test_hetero_multigraph_split(self):
        """Verify split sizes for a dense heterogeneous multigraph.

        A ``HeteroGraph`` is built from ``generate_dense_hete_multigraph()``
        and then rebuilt from its own tensors as a directed graph.  That
        graph is split for the ``node``, ``edge`` and ``link_pred`` tasks and
        the size of every part is compared with the expectation
        ``1 + int((n - 3) * ratio)`` for the first two parts, the last part
        taking whatever is left of ``n``.
        """
        source = HeteroGraph(generate_dense_hete_multigraph())
        hete = HeteroGraph(
            node_feature=source.node_feature,
            node_label=source.node_label,
            edge_feature=source.edge_feature,
            edge_label=source.edge_label,
            edge_index=source.edge_index,
            directed=True,
        )

        # node: sizes are checked separately for every node type
        # (0.8 / 0.1 are presumably the default split ratios - TODO confirm)
        node_parts = hete.split(task='node')
        for node_type in hete.node_label_index:
            total = len(hete.node_label_index[node_type])
            spare = total - 3
            first = 1 + int(spare * 0.8)
            second = 1 + int(spare * 0.1)
            wanted = (first, second, total - first - second)
            for part_id, size in enumerate(wanted):
                self.assertEqual(
                    len(node_parts[part_id].node_label_index[node_type]),
                    size,
                )

        # edge: same expectation, counted over the columns of the index
        edge_parts = hete.split(task='edge')
        for edge_type in hete.edge_label_index:
            total = int(hete.edge_label_index[edge_type].shape[1])
            spare = total - 3
            first = 1 + int(spare * 0.8)
            second = 1 + int(spare * 0.1)
            wanted = (first, second, total - first - second)
            for part_id, size in enumerate(wanted):
                self.assertEqual(
                    edge_parts[part_id].edge_label_index[edge_type].shape[1],
                    size,
                )

        # link prediction with explicit 0.5 / 0.3 / 0.2 ratios
        link_parts = hete.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
        # expected number of edges per part, summed over all message types
        wanted_totals = [0, 0, 0]
        for _, index in hete.edge_label_index.items():
            total = index.shape[1]
            spare = total - 3
            first_share = int(0.5 * spare)
            second_share = int(0.3 * spare)
            wanted_totals[0] += 1 + first_share
            wanted_totals[1] += 1 + second_share
            wanted_totals[2] += total - 2 - first_share - second_share

        # actual number of labelled edges per part, summed the same way;
        # all three sums are gathered before anything is asserted
        actual_totals = []
        for part_id in range(3):
            labels = link_parts[part_id].edge_label
            actual_totals.append(
                sum([labels[message_type].shape[0] for message_type in labels])
            )

        for actual, wanted_total in zip(actual_totals, wanted_totals):
            self.assertEqual(actual, wanted_total)
Exemple #20
0
 def test_hetero_graph_none(self):
     """Check that every message type has ``None`` at index 1.

     The graph comes from ``generate_simple_hete_graph(no_edge_type=True)``;
     index 1 of a message type is presumably the edge-type slot of a
     (head node type, edge type, tail node type) triple - TODO confirm.
     """
     G = generate_simple_hete_graph(no_edge_type=True)
     hete = HeteroGraph(G)
     for message_type in hete.message_types:
         # assertIsNone checks identity instead of relying on ``==``
         self.assertIsNone(message_type[1])
Exemple #21
0
    def test_hetero_graph_split(self):
        """Verify split sizes of a dense heterogeneous graph.

        Covers the default ``split()`` call, a split restricted to the node
        type ``'n1'``, the ``edge`` task (with and without ``split_types``)
        and the ``link_pred`` task.  Types that are split are expected to
        yield ``1 + int((n - 3) * ratio)`` elements in the first two parts
        and the remainder in the last one; types left out of ``split_types``
        are expected to keep their full size in every part.
        """
        hete = HeteroGraph(generate_dense_hete_graph())

        def part_sizes(total):
            # sizes expected for the 0.8 / 0.1 / remainder partition
            spare = total - 3
            first = 1 + int(spare * 0.8)
            second = 1 + int(spare * 0.1)
            return first, second, total - first - second

        # default split: node label indices are checked for every node type
        node_parts = hete.split()
        for node_type in hete.node_label_index:
            wanted = part_sizes(len(hete.node_label_index[node_type]))
            for part_id, size in enumerate(wanted):
                self.assertEqual(
                    len(node_parts[part_id].node_label_index[node_type]),
                    size)

        # node with specified split type
        node_split_types = ['n1']
        node_parts = hete.split(split_types=node_split_types)
        for node_type in hete.node_label_index:
            total = len(hete.node_label_index[node_type])
            if node_type not in node_split_types:
                # untouched types keep their full size in all three parts
                wanted = (total, total, total)
            else:
                wanted = part_sizes(total)
            for part_id, size in enumerate(wanted):
                self.assertEqual(
                    len(node_parts[part_id].node_label_index[node_type]),
                    size)

        # edge
        edge_parts = hete.split(task='edge')
        for edge_type in hete.edge_label_index:
            wanted = part_sizes(
                int(hete.edge_label_index[edge_type].shape[1]))
            for part_id, size in enumerate(wanted):
                self.assertEqual(
                    edge_parts[part_id].edge_label_index[edge_type].shape[1],
                    size)

        # edge with specified split type
        edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]
        edge_parts = hete.split(task='edge', split_types=edge_split_types)
        for edge_type in hete.edge_label_index:
            if edge_type not in edge_split_types:
                full = hete.edge_label_index[edge_type].shape[1]
                wanted = (full, full, full)
            else:
                wanted = part_sizes(
                    int(hete.edge_label_index[edge_type].shape[1]))
            for part_id, size in enumerate(wanted):
                self.assertEqual(
                    edge_parts[part_id].edge_label_index[edge_type].shape[1],
                    size)

        # link_pred
        link_parts = hete.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
        # expected number of edges per part, summed over all message types
        wanted_totals = [0, 0, 0]
        for _, index in hete.edge_label_index.items():
            total = index.shape[1]
            spare = total - 3
            first_share = int(0.5 * spare)
            second_share = int(0.3 * spare)
            wanted_totals[0] += 1 + first_share
            wanted_totals[1] += 1 + second_share
            wanted_totals[2] += total - 2 - first_share - second_share

        # NOTE(review): len(edge_label) equals the summed edge count only if
        # edge_label is a flat container in this library version - confirm.
        for part_id, wanted_total in enumerate(wanted_totals):
            self.assertEqual(len(link_parts[part_id].edge_label),
                             wanted_total)
Exemple #22
0
    def test_hetero_graph_split(self):
        """Verify split sizes of a dense heterogeneous graph.

        The same sequence of checks runs twice: first on the graph returned
        by ``generate_dense_hete_graph()`` and then on the one returned by
        ``generate_dense_hete_graph(directed=False)``.  Each run covers

        * the default ``split()`` call (node label indices),
        * a split restricted to the node type ``'n1'``,
        * the ``edge`` task, with and without ``split_types``,
        * the ``link_pred`` task with ratios 0.5 / 0.3 / 0.2.

        Types that are split are expected to yield
        ``1 + int((n - 3) * ratio)`` elements in the first two parts and the
        remainder in the last one; types left out of ``split_types`` are
        expected to keep their full size in every part.  Only the link_pred
        expectation differs between the two runs.
        """

        def part_sizes(total, first_ratio=0.8, second_ratio=0.1):
            # Expected sizes of the three parts for `total` elements.
            spare = total - 3
            first = 1 + int(spare * first_ratio)
            second = 1 + int(spare * second_ratio)
            return first, second, total - first - second

        def directed_link_sizes(num_labels):
            # Each labelled edge is counted once.
            return part_sizes(num_labels, 0.5, 0.3)

        def undirected_link_sizes(num_labels):
            # Sizes are computed on half of the labels and then doubled;
            # presumably every undirected edge is stored in both
            # directions - TODO confirm.
            first, second, _ = part_sizes(int(num_labels / 2), 0.5, 0.3)
            first, second = 2 * first, 2 * second
            return first, second, num_labels - first - second

        def check_node_parts(graph, parts, split_types=None):
            # Compare node label index sizes of `parts` with expectations;
            # `split_types=None` means every node type was split.
            for node_type in graph.node_label_index:
                total = len(graph.node_label_index[node_type])
                if split_types is None or node_type in split_types:
                    wanted = part_sizes(total)
                else:
                    wanted = (total, total, total)
                for part_id, size in enumerate(wanted):
                    self.assertEqual(
                        len(parts[part_id].node_label_index[node_type]),
                        size,
                    )

        def check_edge_parts(graph, parts, split_types=None):
            # Compare edge label index widths of `parts` with expectations;
            # `split_types=None` means every edge type was split.
            for edge_type in graph.edge_label_index:
                total = int(graph.edge_label_index[edge_type].shape[1])
                if split_types is None or edge_type in split_types:
                    wanted = part_sizes(total)
                else:
                    wanted = (total, total, total)
                for part_id, size in enumerate(wanted):
                    self.assertEqual(
                        parts[part_id].edge_label_index[edge_type].shape[1],
                        size,
                    )

        def check_link_parts(graph, parts, link_sizes):
            # Compare edge label counts per message type with `link_sizes`.
            for key, index in graph.edge_label_index.items():
                wanted = link_sizes(index.shape[1])
                for part_id, size in enumerate(wanted):
                    self.assertEqual(
                        parts[part_id].edge_label[key].shape[0],
                        size,
                    )

        def check_all_tasks(graph, link_sizes):
            # Run every split in the original order: each split is created
            # immediately before its own assertions.
            check_node_parts(graph, graph.split())

            # node with specified split type
            node_split_types = ['n1']
            check_node_parts(
                graph,
                graph.split(split_types=node_split_types),
                node_split_types,
            )

            # edge
            check_edge_parts(graph, graph.split(task='edge'))

            # edge with specified split type
            edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]
            check_edge_parts(
                graph,
                graph.split(task='edge', split_types=edge_split_types),
                edge_split_types,
            )

            # link_pred
            check_link_parts(
                graph,
                graph.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2]),
                link_sizes,
            )

        # directed G
        G = generate_dense_hete_graph()
        check_all_tasks(HeteroGraph(G), directed_link_sizes)

        # undirected G
        G = generate_dense_hete_graph(directed=False)
        check_all_tasks(HeteroGraph(G), undirected_link_sizes)
    def test_secure_split_heterogeneous(self):
        """Verify ``GraphDataset.split()`` sizes on a small hetero graph.

        A directed ``HeteroGraph`` holding only node labels, edge labels and
        the edge index is wrapped in a one-graph ``GraphDataset`` and split
        for the ``node``, ``edge`` and ``link_pred`` tasks.  For every type
        the first two splits are expected to hold
        ``1 + int((n - 3) * ratio)`` elements (ratios 0.8 and 0.1) and the
        last split the remainder; for ``link_pred`` the expected counts are
        doubled.
        """
        source = HeteroGraph(generate_simple_small_hete_graph())
        graph = HeteroGraph(node_label=source.node_label,
                            edge_index=source.edge_index,
                            edge_label=source.edge_label,
                            directed=True)
        graphs = [graph]

        def leading_sizes(count):
            # expected sizes of the first two splits for `count` elements
            spare = count - 3
            return 1 + int(spare * 0.8), 1 + int(spare * 0.1)

        # node task
        split_res = GraphDataset(graphs, task="node").split()
        for node_type in graph.node_label_index:
            count = graph.node_label_index[node_type].shape[0]
            first, second = leading_sizes(count)
            wanted = (first, second, count - first - second)
            for i, size in enumerate(wanted):
                self.assertEqual(
                    split_res[i][0].node_label_index[node_type].shape[0],
                    size)
                self.assertEqual(
                    split_res[i][0].node_label[node_type].shape[0],
                    size)

        # edge task
        split_res = GraphDataset(graphs, task="edge").split()
        for message_type in graph.edge_label_index:
            count = graph.edge_label_index[message_type].shape[1]
            first, second = leading_sizes(count)
            wanted = (first, second, count - first - second)
            for i, size in enumerate(wanted):
                self.assertEqual(
                    split_res[i][0].edge_label_index[message_type].shape[1],
                    size)
                self.assertEqual(
                    split_res[i][0].edge_label[message_type].shape[0],
                    size)

        # link_pred task: every expected count is twice the edge-task one
        # (presumably because negative edges are added - TODO confirm)
        split_res = GraphDataset(graphs, task="link_pred").split()
        for message_type in graph.edge_label_index:
            count = graph.edge_label_index[message_type].shape[1]
            first, second = leading_sizes(count)
            first, second = 2 * first, 2 * second
            wanted = (first, second, 2 * count - first - second)
            for i, size in enumerate(wanted):
                self.assertEqual(
                    split_res[i][0].edge_label_index[message_type].shape[1],
                    size)
                self.assertEqual(
                    split_res[i][0].edge_label[message_type].shape[0],
                    size)
Exemple #24
0
    # Dictionary of node features
    node_feature = {}
    node_feature["paper"] = data['feature']

    # Dictionary of node labels
    node_label = {}
    node_label["paper"] = data['label']

    # Load the train, validation and test indices
    train_idx = {"paper": data['train_idx'].to(args.device)}
    val_idx = {"paper": data['val_idx'].to(args.device)}
    test_idx = {"paper": data['test_idx'].to(args.device)}

    # Construct a deepsnap tensor backend HeteroGraph
    hetero_graph = HeteroGraph(node_feature=node_feature,
                               node_label=node_label,
                               edge_index=edge_index,
                               directed=True)

    print(
        f"ACM heterogeneous graph: {hetero_graph.num_nodes()} nodes, {hetero_graph.num_edges()} edges"
    )

    # Node feature and node label to device
    for key in hetero_graph.node_feature:
        hetero_graph.node_feature[key] = hetero_graph.node_feature[key].to(
            args.device)
    for key in hetero_graph.node_label:
        hetero_graph.node_label[key] = hetero_graph.node_label[key].to(
            args.device)

    # Edge_index to sparse tensor and to device