Пример #1
0
    def test_hetero_multigraph_split(self):
        """Check split sizes of a heterogeneous multigraph built from networkx.

        Node and edge splits (default ratios 0.8 / 0.1) are checked per type.
        The link-prediction split (ratios 0.5 / 0.3 / 0.2) is checked on the
        label counts summed over all edge types.
        """
        def expected_sizes(total, first_ratio=0.8, second_ratio=0.1):
            # Expected sizes of the three splits for ``total`` items: the
            # first two are ``1 + int((total - 3) * ratio)`` and the third
            # takes the remainder (presumably every split is guaranteed at
            # least one item -- TODO confirm against the split code).
            remainder = total - 3
            first = 1 + int(remainder * first_ratio)
            second = 1 + int(remainder * second_ratio)
            return first, second, total - first - second

        def assert_sizes(actual, expected):
            # Compare observed and expected sizes split by split, in order.
            for got, want in zip(actual, expected):
                self.assertEqual(got, want)

        hete = HeteroGraph(generate_dense_hete_multigraph())

        # node
        node_splits = hete.split(task='node')
        for node_type in hete.node_label_index:
            assert_sizes(
                [len(node_splits[i].node_label_index[node_type])
                 for i in range(3)],
                expected_sizes(len(hete.node_label_index[node_type])),
            )

        # edge
        edge_splits = hete.split(task='edge')
        for edge_type in hete.edge_label_index:
            assert_sizes(
                [edge_splits[i].edge_label_index[edge_type].shape[1]
                 for i in range(3)],
                expected_sizes(int(hete.edge_label_index[edge_type].shape[1])),
            )

        # link prediction: expected label counts are summed over edge types
        link_splits = hete.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
        expected_totals = [0, 0, 0]
        for label_index in hete.edge_label_index.values():
            per_type = expected_sizes(label_index.shape[1], 0.5, 0.3)
            expected_totals = [acc + size for acc, size
                               in zip(expected_totals, per_type)]
        assert_sizes(
            [len(link_splits[i].edge_label) for i in range(3)],
            expected_totals,
        )
Пример #2
0
    def test_hetero_graph_split(self):
        """Check split sizes of a heterogeneous graph built from networkx.

        The test covers four kinds of split:

        * node splits over all types, then over ``'n1'`` only;
        * edge splits over all types, then over two selected types;
        * a link-prediction split whose label counts are summed over all
          edge types.

        Types left out of ``split_types`` are expected to keep their full
        size in every split.
        """
        def expected_sizes(total, first_ratio=0.8, second_ratio=0.1):
            # First two splits get ``1 + int((total - 3) * ratio)`` items and
            # the third takes the remainder.
            remainder = total - 3
            first = 1 + int(remainder * first_ratio)
            second = 1 + int(remainder * second_ratio)
            return first, second, total - first - second

        def assert_sizes(actual, expected):
            # Compare observed and expected sizes split by split, in order.
            for got, want in zip(actual, expected):
                self.assertEqual(got, want)

        def node_sizes(splits, node_type):
            # Label-node count of ``node_type`` in each of the three splits.
            return [len(splits[i].node_label_index[node_type])
                    for i in range(3)]

        def edge_sizes(splits, edge_type):
            # Label-edge count of ``edge_type`` in each of the three splits.
            return [splits[i].edge_label_index[edge_type].shape[1]
                    for i in range(3)]

        hete = HeteroGraph(generate_dense_hete_graph())

        # node
        splits = hete.split()
        for node_type in hete.node_label_index:
            assert_sizes(
                node_sizes(splits, node_type),
                expected_sizes(len(hete.node_label_index[node_type])),
            )

        # node with specified split type: other types stay whole
        node_split_types = ['n1']
        splits = hete.split(split_types=node_split_types)
        for node_type in hete.node_label_index:
            total = len(hete.node_label_index[node_type])
            if node_type in node_split_types:
                expected = expected_sizes(total)
            else:
                expected = (total, total, total)
            assert_sizes(node_sizes(splits, node_type), expected)

        # edge
        splits = hete.split(task='edge')
        for edge_type in hete.edge_label_index:
            assert_sizes(
                edge_sizes(splits, edge_type),
                expected_sizes(int(hete.edge_label_index[edge_type].shape[1])),
            )

        # edge with specified split type: other types stay whole
        edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]
        splits = hete.split(task='edge', split_types=edge_split_types)
        for edge_type in hete.edge_label_index:
            total = hete.edge_label_index[edge_type].shape[1]
            if edge_type in edge_split_types:
                expected = expected_sizes(int(total))
            else:
                expected = (total, total, total)
            assert_sizes(edge_sizes(splits, edge_type), expected)

        # link_pred: expected label counts are summed over all edge types
        splits = hete.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
        expected_totals = [0, 0, 0]
        for label_index in hete.edge_label_index.values():
            per_type = expected_sizes(label_index.shape[1], 0.5, 0.3)
            expected_totals = [acc + size for acc, size
                               in zip(expected_totals, per_type)]
        assert_sizes(
            [len(splits[i].edge_label) for i in range(3)],
            expected_totals,
        )
    def test_hetero_graph_split(self):
        """Check split sizes of directed and undirected heterogeneous graphs.

        Each graph is rebuilt from the tensors of a networkx-derived
        ``HeteroGraph``. It is then split four ways:

        * by node, over all types and then over ``'n1'`` only;
        * by edge, over all types and then over two selected types;
        * for link prediction.

        Expected sizes are ``int(total * ratio)`` for the first two splits,
        and the third split receives the remainder.
        """
        def expected_sizes(total, first_ratio=0.8, second_ratio=0.1):
            # Expected sizes of the three splits for ``total`` items.
            first = int(total * first_ratio)
            second = int(total * second_ratio)
            return first, second, total - first - second

        def assert_sizes(actual, expected):
            # Compare observed and expected sizes split by split, in order.
            for got, want in zip(actual, expected):
                self.assertEqual(got, want)

        def node_sizes(splits, node_type):
            # Label-node count of ``node_type`` in each of the three splits.
            return [len(splits[i].node_label_index[node_type])
                    for i in range(3)]

        def edge_sizes(splits, edge_type):
            # Label-edge count of ``edge_type`` in each of the three splits.
            return [splits[i].edge_label_index[edge_type].shape[1]
                    for i in range(3)]

        def link_sizes(splits, edge_type):
            # Link-label count of ``edge_type`` in each of the three splits.
            return [splits[i].edge_label[edge_type].shape[0]
                    for i in range(3)]

        def rebuild(nx_graph, **kwargs):
            # Round-trip ``nx_graph`` through HeteroGraph's tensor constructor.
            source = HeteroGraph(nx_graph)
            return HeteroGraph(
                node_feature=source.node_feature,
                node_label=source.node_label,
                edge_feature=source.edge_feature,
                edge_label=source.edge_label,
                edge_index=source.edge_index,
                **kwargs,
            )

        def label_edge_count(graph, edge_type):
            # Total label edges of ``edge_type``.
            return int(graph.edge_label_index[edge_type].shape[1])

        def message_edge_count(graph, edge_type):
            # Total edges of ``edge_type`` as reported by ``num_edges``.
            return int(graph.num_edges(edge_type))

        def check_node_and_edge_splits(graph, count_edges):
            # ``count_edges(graph, edge_type)`` is the total the ratios are
            # applied to for split edge types.
            node_split_types = ['n1']
            edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]

            # node
            splits = graph.split()
            for node_type in graph.node_label_index:
                assert_sizes(
                    node_sizes(splits, node_type),
                    expected_sizes(len(graph.node_label_index[node_type])),
                )

            # node with specified split type: other types stay whole
            splits = graph.split(split_types=node_split_types)
            for node_type in graph.node_label_index:
                total = len(graph.node_label_index[node_type])
                if node_type in node_split_types:
                    expected = expected_sizes(total)
                else:
                    expected = (total, total, total)
                assert_sizes(node_sizes(splits, node_type), expected)

            # edge
            splits = graph.split(task='edge')
            for edge_type in graph.edge_label_index:
                assert_sizes(
                    edge_sizes(splits, edge_type),
                    expected_sizes(count_edges(graph, edge_type)),
                )

            # edge with specified split type: other types stay whole
            splits = graph.split(task='edge', split_types=edge_split_types)
            for edge_type in graph.edge_label_index:
                if edge_type in edge_split_types:
                    expected = expected_sizes(count_edges(graph, edge_type))
                else:
                    whole = graph.edge_label_index[edge_type].shape[1]
                    expected = (whole, whole, whole)
                assert_sizes(edge_sizes(splits, edge_type), expected)

        # directed G
        directed = rebuild(generate_dense_hete_graph())
        check_node_and_edge_splits(directed, label_edge_count)

        # link_pred: checked per edge type
        splits = directed.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
        for key, label_index in directed.edge_label_index.items():
            assert_sizes(
                link_sizes(splits, key),
                expected_sizes(label_index.shape[1], 0.5, 0.3),
            )

        # undirected G
        undirected = rebuild(generate_dense_hete_graph(directed=False),
                             directed=False)
        check_node_and_edge_splits(undirected, message_edge_count)

        # link_pred: ratios are applied to half of the label columns and the
        # result is doubled (presumably each undirected edge is stored in
        # both directions -- TODO confirm)
        splits = undirected.split(
            task='link_pred', split_ratio=[0.5, 0.3, 0.2])
        for key, label_index in undirected.edge_label_index.items():
            halves = expected_sizes(int(label_index.shape[1] / 2), 0.5, 0.3)
            assert_sizes(link_sizes(splits, key),
                         [2 * size for size in halves])
    def test_hetero_multigraph_split(self):
        """Check split sizes of a directed heterogeneous multigraph.

        The graph is rebuilt from the tensors of a networkx-derived
        ``HeteroGraph`` with ``directed=True``. Node and edge splits use the
        default ratios and are checked per type. The link-prediction split
        (ratios 0.5 / 0.3 / 0.2) is checked on label counts summed over all
        message types.

        Expected sizes are ``int(total * ratio)`` for the first two splits,
        and the third split receives the remainder.
        """
        def expected_sizes(total, first_ratio=0.8, second_ratio=0.1):
            # Expected sizes of the three splits for ``total`` items.
            first = int(total * first_ratio)
            second = int(total * second_ratio)
            return first, second, total - first - second

        def assert_sizes(actual, expected):
            # Compare observed and expected sizes split by split, in order.
            for got, want in zip(actual, expected):
                self.assertEqual(got, want)

        source = HeteroGraph(generate_dense_hete_multigraph())
        hete = HeteroGraph(node_feature=source.node_feature,
                           node_label=source.node_label,
                           edge_feature=source.edge_feature,
                           edge_label=source.edge_label,
                           edge_index=source.edge_index,
                           directed=True)

        # node
        node_splits = hete.split(task='node')
        for node_type in hete.node_label_index:
            assert_sizes(
                [len(node_splits[i].node_label_index[node_type])
                 for i in range(3)],
                expected_sizes(len(hete.node_label_index[node_type])),
            )

        # edge
        edge_splits = hete.split(task='edge')
        for edge_type in hete.edge_label_index:
            assert_sizes(
                [edge_splits[i].edge_label_index[edge_type].shape[1]
                 for i in range(3)],
                expected_sizes(int(hete.edge_label_index[edge_type].shape[1])),
            )

        # link prediction: compare totals summed over all message types
        link_splits = hete.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
        expected_totals = [0, 0, 0]
        for label_index in hete.edge_label_index.values():
            per_type = expected_sizes(label_index.shape[1], 0.5, 0.3)
            expected_totals = [acc + size for acc, size
                               in zip(expected_totals, per_type)]
        observed_totals = [
            sum(link_splits[i].edge_label[message_type].shape[0]
                for message_type in link_splits[i].edge_label)
            for i in range(3)
        ]
        assert_sizes(observed_totals, expected_totals)
    def test_hetero_graph_split(self):
        """Check split sizes of directed and undirected heterogeneous graphs.

        Each graph is rebuilt from the tensors of a networkx-derived
        ``HeteroGraph``. It is then split four ways:

        * by node, over all types and then over ``'n1'`` only;
        * by edge, over all types and then over two selected types;
        * for link prediction, checked per edge type.

        Types left out of ``split_types`` are expected to keep their full
        size in every split.
        """
        def expected_sizes(total, first_ratio=0.8, second_ratio=0.1):
            # First two splits get ``1 + int((total - 3) * ratio)`` items and
            # the third takes the remainder (presumably every split is
            # guaranteed at least one item -- TODO confirm).
            remainder = total - 3
            first = 1 + int(remainder * first_ratio)
            second = 1 + int(remainder * second_ratio)
            return first, second, total - first - second

        def assert_sizes(actual, expected):
            # Compare observed and expected sizes split by split, in order.
            for got, want in zip(actual, expected):
                self.assertEqual(got, want)

        def node_sizes(splits, node_type):
            # Label-node count of ``node_type`` in each of the three splits.
            return [len(splits[i].node_label_index[node_type])
                    for i in range(3)]

        def edge_sizes(splits, edge_type):
            # Label-edge count of ``edge_type`` in each of the three splits.
            return [splits[i].edge_label_index[edge_type].shape[1]
                    for i in range(3)]

        def rebuild(nx_graph, **kwargs):
            # Round-trip ``nx_graph`` through HeteroGraph's tensor constructor.
            source = HeteroGraph(nx_graph)
            return HeteroGraph(
                node_feature=source.node_feature,
                node_label=source.node_label,
                edge_feature=source.edge_feature,
                edge_label=source.edge_label,
                edge_index=source.edge_index,
                **kwargs,
            )

        def check_node_and_edge_splits(graph):
            node_split_types = ['n1']
            edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]

            # node
            splits = graph.split()
            for node_type in graph.node_label_index:
                assert_sizes(
                    node_sizes(splits, node_type),
                    expected_sizes(len(graph.node_label_index[node_type])),
                )

            # node with specified split type: other types stay whole
            splits = graph.split(split_types=node_split_types)
            for node_type in graph.node_label_index:
                total = len(graph.node_label_index[node_type])
                if node_type in node_split_types:
                    expected = expected_sizes(total)
                else:
                    expected = (total, total, total)
                assert_sizes(node_sizes(splits, node_type), expected)

            # edge
            splits = graph.split(task='edge')
            for edge_type in graph.edge_label_index:
                total = int(graph.edge_label_index[edge_type].shape[1])
                assert_sizes(edge_sizes(splits, edge_type),
                             expected_sizes(total))

            # edge with specified split type: other types stay whole
            splits = graph.split(task='edge', split_types=edge_split_types)
            for edge_type in graph.edge_label_index:
                total = graph.edge_label_index[edge_type].shape[1]
                if edge_type in edge_split_types:
                    expected = expected_sizes(int(total))
                else:
                    expected = (total, total, total)
                assert_sizes(edge_sizes(splits, edge_type), expected)

        def check_link_splits(graph, sizes_for):
            # ``sizes_for(n)`` maps an edge type's label-column count to the
            # expected (train, val, test) link-label counts.
            splits = graph.split(task='link_pred', split_ratio=[0.5, 0.3, 0.2])
            for key, label_index in graph.edge_label_index.items():
                assert_sizes(
                    [splits[i].edge_label[key].shape[0] for i in range(3)],
                    sizes_for(label_index.shape[1]),
                )

        def directed_link_sizes(total):
            return expected_sizes(total, 0.5, 0.3)

        def undirected_link_sizes(total):
            # Train/val sizes are computed on half of the label columns and
            # doubled (presumably each undirected edge is stored in both
            # directions -- TODO confirm); test gets whatever is left.
            half_train, half_val, _ = expected_sizes(int(total / 2), 0.5, 0.3)
            train, val = 2 * half_train, 2 * half_val
            return train, val, total - train - val

        # directed G
        directed = rebuild(generate_dense_hete_graph())
        check_node_and_edge_splits(directed)
        check_link_splits(directed, directed_link_sizes)

        # undirected G
        undirected = rebuild(generate_dense_hete_graph(directed=False),
                             directed=False)
        check_node_and_edge_splits(undirected)
        check_link_splits(undirected, undirected_link_sizes)