def test_resample_disjoint_heterogeneous(self):
        """Fetch the same training graph twice and compare label shapes.

        A heterogeneous graph is wrapped in a ``link_pred`` dataset using
        ``edge_train_mode="disjoint"`` with ``resample_disjoint=True`` and
        ``resample_disjoint_period=1``. After splitting, the first training
        graph is read twice; for every message type in ``edge_index`` the two
        reads must agree on the number of columns of ``edge_label_index`` and
        on the shape of ``edge_label``.
        """
        base = HeteroGraph(generate_dense_hete_dataset())
        # Rebuild the graph from its raw attributes as a directed graph.
        hete = HeteroGraph(
            node_feature=base.node_feature,
            node_label=base.node_label,
            edge_feature=base.edge_feature,
            edge_label=base.edge_label,
            edge_index=base.edge_index,
            directed=True,
        )
        dataset = GraphDataset(
            [hete],
            task="link_pred",
            edge_train_mode="disjoint",
            edge_message_ratio=0.8,
            resample_disjoint=True,
            resample_disjoint_period=1,
        )
        train_set, _, _ = dataset.split(split_ratio=[0.5, 0.2, 0.3])

        # Two separate reads of index 0.
        # NOTE(review): presumably each read may trigger a fresh disjoint
        # resample given resample_disjoint_period=1 — confirm in GraphDataset.
        first_read = train_set[0]
        second_read = train_set[0]

        for message_type in first_read.edge_index:
            first_index = first_read.edge_label_index[message_type]
            second_index = second_read.edge_label_index[message_type]
            self.assertEqual(first_index.shape[1], second_index.shape[1])

            first_label = first_read.edge_label[message_type]
            second_label = second_read.edge_label[message_type]
            self.assertEqual(first_label.shape, second_label.shape)
    def test_dataset_hetero_graph_split(self):
        """Check the label counts produced by splitting a heterogeneous graph.

        The same graph is split under several configurations (``node``,
        ``edge`` and ``link_pred`` tasks, optionally restricted to some
        ``split_types``, with the ``disjoint`` training mode, and with
        ``edge_split_mode="approximate"``). For each configuration the number
        of label entries found in the first graph of each of the three
        resulting splits is compared with the value derived from the split
        ratios. Types that are not listed in ``split_types`` are expected to
        keep their full count in all three splits.
        """
        G = generate_dense_hete_dataset()
        hete = HeteroGraph(G)
        hete = HeteroGraph(node_feature=hete.node_feature,
                           node_label=hete.node_label,
                           edge_feature=hete.edge_feature,
                           edge_label=hete.edge_label,
                           edge_index=hete.edge_index,
                           directed=True)

        def node_count(graph, node_type):
            # Length of the node label index stored for ``node_type``.
            return len(graph.node_label_index[node_type])

        def edge_count(graph, edge_type):
            # Number of columns of the edge label index for ``edge_type``.
            return graph.edge_label_index[edge_type].shape[1]

        def three_way(total, ratios=(0.8, 0.1), scale=1):
            # The first two parts are truncated products; the third part
            # receives the remainder so the parts always sum to ``total``.
            first = int(ratios[0] * total)
            second = int(ratios[1] * total)
            return (scale * first, scale * second,
                    scale * (total - first - second))

        def link_pred_sizes(total):
            # The link_pred expectations are twice the plain split sizes.
            # NOTE(review): presumably one sampled negative edge per positive
            # edge — confirm against GraphDataset.
            return three_way(total, scale=2)

        def disjoint_sizes(total):
            # With edge_message_ratio=0.5, int(0.5 * first) edges of the first
            # split are subtracted before doubling.
            first, second, third = three_way(total, ratios=(0.6, 0.2))
            return (2 * (first - int(0.5 * first)), 2 * second, 2 * third)

        def as_collection(split_types):
            # A bare string would make ``in`` a substring test, so wrap it.
            if isinstance(split_types, str):
                return [split_types]
            return split_types

        def check_split(split_res, label_index, count, expected_for,
                        split_types=None):
            # Compare, type by type, the count in the first graph of each of
            # the three splits with the expected size.
            selected = as_collection(split_types)
            for type_name in label_index:
                total = count(hete, type_name)
                if selected is None or type_name in selected:
                    expected = expected_for(total)
                else:
                    # Types outside ``split_types`` are expected unchanged.
                    expected = (total, total, total)
                for position, want in enumerate(expected):
                    self.assertEqual(
                        count(split_res[position][0], type_name), want)

        def check_summed_split(split_res, split_types=None):
            # Compare totals summed over all edge types instead of per-type
            # counts. Edge types absent from a split graph contribute nothing.
            split_total = 0
            other_total = 0
            per_split = [0, 0, 0]
            for edge_type in hete.edge_label_index:
                num_edges = edge_count(hete, edge_type)
                if split_types is None or edge_type in split_types:
                    split_total += num_edges
                else:
                    other_total += num_edges
                for position in range(3):
                    graph = split_res[position][0]
                    if edge_type in graph.edge_label_index:
                        per_split[position] += edge_count(graph, edge_type)

            # Edges of types outside ``split_types`` are added in full to the
            # expected size of every split.
            expected = [size + other_total
                        for size in link_pred_sizes(split_total)]
            for position, want in enumerate(expected):
                self.assertEqual(per_split[position], want)

        # node
        dataset = GraphDataset([hete], task="node")
        split_res = dataset.split()
        check_split(split_res, hete.node_label_index, node_count, three_way)

        # node with specified split type
        dataset = GraphDataset([hete], task="node")
        node_split_types = ["n1"]
        split_res = dataset.split(split_types=node_split_types)
        check_split(split_res, hete.node_label_index, node_count, three_way,
                    split_types=node_split_types)

        # node with specified split type (string mode)
        dataset = GraphDataset([hete], task="node")
        node_split_types = "n1"
        split_res = dataset.split(split_types=node_split_types)
        check_split(split_res, hete.node_label_index, node_count, three_way,
                    split_types=node_split_types)

        # edge
        dataset = GraphDataset([hete], task="edge")
        split_res = dataset.split()
        check_split(split_res, hete.edge_label_index, edge_count, three_way)

        # edge with specified split type
        dataset = GraphDataset([hete], task="edge")
        edge_split_types = [("n1", "e1", "n1"), ("n1", "e2", "n2")]
        split_res = dataset.split(split_types=edge_split_types)
        check_split(split_res, hete.edge_label_index, edge_count, three_way,
                    split_types=edge_split_types)

        # link_pred
        dataset = GraphDataset([hete], task="link_pred")
        split_res = dataset.split(transductive=True)
        check_split(split_res, hete.edge_label_index, edge_count,
                    link_pred_sizes)

        # link_pred with specified split type
        dataset = GraphDataset([hete], task="link_pred")
        link_split_types = [("n1", "e1", "n1"), ("n1", "e2", "n2")]
        split_res = dataset.split(transductive=True,
                                  split_types=link_split_types)
        check_split(split_res, hete.edge_label_index, edge_count,
                    link_pred_sizes, split_types=link_split_types)

        # link_pred + disjoint
        dataset = GraphDataset(
            [hete],
            task="link_pred",
            edge_train_mode="disjoint",
            edge_message_ratio=0.5,
        )
        split_res = dataset.split(
            transductive=True,
            split_ratio=[0.6, 0.2, 0.2],
        )
        check_split(split_res, hete.edge_label_index, edge_count,
                    disjoint_sizes)

        # link_pred with edge_split_mode set to "approximate": only the
        # totals summed over all edge types are compared with the ratios.
        # NOTE(review): this comment previously said "exact" while the code
        # passes "approximate" — confirm which mode was intended.
        dataset = GraphDataset([hete],
                               task="link_pred",
                               edge_split_mode="approximate")
        split_res = dataset.split(transductive=True)
        check_summed_split(split_res)

        # link_pred with specified types and edge_split_mode set to
        # "approximate" (same NOTE(review) as above regarding "exact").
        dataset = GraphDataset(
            [hete],
            task="link_pred",
            edge_split_mode="approximate",
        )
        link_split_types = [("n1", "e1", "n1"), ("n1", "e2", "n2")]
        split_res = dataset.split(
            transductive=True,
            split_types=link_split_types,
        )
        check_summed_split(split_res, split_types=link_split_types)