예제 #1
0
def deepsnap_ego(args, pyg_dataset):
    """Benchmark the average time of the DeepSNAP ego-net transform.

    Converts ``pyg_dataset`` to DeepSNAP graphs, splits it inductively,
    applies ``ego_nets`` to every training batch, and prints the mean
    wall-clock time over ``args.num_runs`` runs.
    """
    task = "graph"
    total_time = 0
    for run_idx in range(args.num_runs):
        if args.print_run:
            print("Run {}".format(run_idx + 1))
        graphs = GraphDataset.pyg_to_graphs(
            pyg_dataset, verbose=True, netlib=netlib)
        dataset = GraphDataset(graphs, task=task)
        splits = dict(zip(
            ("train", "val", "test"),
            dataset.split(transductive=False,
                          split_ratio=[0.8, 0.1, 0.1],
                          shuffle=False)))
        loaders = {}
        for split_name, split_dataset in splits.items():
            loaders[split_name] = DataLoader(split_dataset,
                                             collate_fn=Batch.collate(),
                                             batch_size=1,
                                             shuffle=False)
        start = time.time()
        for batch in loaders['train']:
            batch = batch.apply_transform(ego_nets, update_tensor=True)
        total_time += time.time() - start
    print("DeepSNAP has average time: {}".format(total_time / args.num_runs))
    def test_resample_disjoint(self):
        """Resampled disjoint training splits keep edge label shapes stable.

        Fetches the same training graph twice (resample period 1) and
        checks that the number of label edges and the edge_label tensor
        are identical across resamples.
        """
        pyg_dataset = Planetoid("./cora", "Cora")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        base = graphs[0]
        # Rebuild the graph from raw tensors (tensor-backend Graph).
        rebuilt = Graph(node_label=base.node_label,
                        node_feature=base.node_feature,
                        edge_index=base.edge_index,
                        edge_feature=base.edge_feature,
                        directed=False)
        dataset = GraphDataset([rebuilt],
                               task="link_pred",
                               edge_train_mode="disjoint",
                               edge_message_ratio=0.8,
                               resample_disjoint=True,
                               resample_disjoint_period=1)
        train_split, _, _ = dataset.split(split_ratio=[0.5, 0.2, 0.3])
        first = train_split[0]
        second = train_split[0]

        self.assertEqual(first.edge_label_index.shape[1],
                         second.edge_label_index.shape[1])
        self.assertTrue(torch.equal(first.edge_label, second.edge_label))
예제 #3
0
 def test_torch_dataloader_collate(self):
     """Check batch counts and batch sizes when using torch DataLoader
     with Batch.collate on an inductively split graph dataset."""
     # graph classification example
     pyg_dataset = TUDataset('./enzymes', 'ENZYMES')
     graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
     dataset = GraphDataset(graphs, task="graph")
     train_batch_num = math.ceil(len(dataset) * 0.8 / 32)
     val_batch_num = math.ceil(len(dataset) * 0.1 / 32)
     test_batch_num = math.ceil(len(dataset) * 0.1 / 32)
     datasets = {}
     datasets['train'], datasets['val'], datasets['test'] = \
         dataset.split(transductive=False, split_ratio=[0.8, 0.1, 0.1])
     dataloaders = {
         split: DataLoader(dataset,
                           collate_fn=Batch.collate(),
                           batch_size=32,
                           shuffle=True)
         for split, dataset in datasets.items()
     }
     self.assertEqual(len(dataloaders['train']), train_batch_num)
     # Fixed: the original compared 'val' against test_batch_num and
     # 'test' against val_batch_num; the values are equal (both 10% of
     # the data) so behavior is unchanged, but the pairing is now correct.
     self.assertEqual(len(dataloaders['val']), val_batch_num)
     self.assertEqual(len(dataloaders['test']), test_batch_num)
     # Every batch except the final (possibly partial) one must be full.
     for i, data in enumerate(dataloaders['train']):
         if i != len(dataloaders['train']) - 1:
             self.assertEqual(data.num_graphs, 32)
     for i, data in enumerate(dataloaders['val']):
         if i != len(dataloaders['val']) - 1:
             self.assertEqual(data.num_graphs, 32)
     for i, data in enumerate(dataloaders['test']):
         if i != len(dataloaders['test']) - 1:
             self.assertEqual(data.num_graphs, 32)
예제 #4
0
def main():
    """Train a heterogeneous GNN for link prediction on a WordNet-style graph.

    Loads a pickled NetworkX multigraph, counts the distinct edge labels,
    converts the graph into a DeepSNAP HeteroGraph, splits it
    transductively, and trains a HeteroNet model.
    """
    args = arg_parse()

    edge_train_mode = args.mode
    print('edge train mode: {}'.format(edge_train_mode))

    G = nx.read_gpickle(args.data_path)
    print(G.number_of_edges())
    print('Each node has node ID (n_id). Example: ', G.nodes[0])
    print(
        'Each edge has edge ID (id) and categorical label (e_label). Example: ',
        G[0][5871])

    # Find the number of edge types by collecting distinct e_label values.
    # (The original also tracked an unused `max_label`; removed.)
    # labels are consecutive (0-17)
    labels = []
    for u, v, edge_key in G.edges:
        label = G[u][v][edge_key]['e_label']
        if label not in labels:
            labels.append(label)
    num_edge_types = len(labels)

    H = WN_transform(G, num_edge_types)
    # The nodes in the graph have the features: node_feature and node_type (just one node type "n1" here)
    for node in H.nodes(data=True):
        print(node)
        break
    # The edges in the graph have the features: edge_feature and edge_type ("0" - "17" here)
    for edge in H.edges(data=True):
        print(edge)
        break

    hete = HeteroGraph(H)

    dataset = GraphDataset([hete], task='link_pred')
    dataset_train, dataset_val, dataset_test = dataset.split(
        transductive=True, split_ratio=[0.8, 0.1, 0.1])
    # batch_size=1: the dataset holds a single, transductively split graph.
    dataloaders = {
        split: DataLoader(split_dataset,
                          collate_fn=Batch.collate(),
                          batch_size=1)
        for split, split_dataset in zip(
            ('train', 'val', 'test'),
            (dataset_train, dataset_val, dataset_test))
    }

    hidden_size = 32
    model = HeteroNet(hete, hidden_size, 0.2).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=0.001,
                                 weight_decay=5e-4)

    train(model, dataloaders, optimizer, args)
    def test_resample_disjoint_heterogeneous(self):
        """Disjoint resampling keeps per-type edge label shapes stable.

        Fetches the training HeteroGraph twice (resample period 1) and
        verifies matching edge_label_index widths and edge_label shapes
        for every message type.
        """
        G = generate_dense_hete_dataset()
        hete = HeteroGraph(G)
        # Rebuild the graph from raw tensors (tensor-backend HeteroGraph).
        hete = HeteroGraph(node_feature=hete.node_feature,
                           node_label=hete.node_label,
                           edge_feature=hete.edge_feature,
                           edge_label=hete.edge_label,
                           edge_index=hete.edge_index,
                           directed=True)
        dataset = GraphDataset([hete],
                               task="link_pred",
                               edge_train_mode="disjoint",
                               edge_message_ratio=0.8,
                               resample_disjoint=True,
                               resample_disjoint_period=1)
        train_split, _, _ = dataset.split(split_ratio=[0.5, 0.2, 0.3])
        first = train_split[0]
        second = train_split[0]

        for message_type in first.edge_index:
            self.assertEqual(
                first.edge_label_index[message_type].shape[1],
                second.edge_label_index[message_type].shape[1])
            self.assertEqual(first.edge_label[message_type].shape,
                             second.edge_label[message_type].shape)
예제 #6
0
    def test_pyg_to_graphs_global(self):
        """Switching the global graph backend via deepsnap.use changes the
        class of the internal graph object produced by pyg_to_graphs.

        Runs the same node-task split pipeline twice: once with the
        NetworkX backend (``nx``), once with ``sx`` — presumably SnapX,
        imported elsewhere in this file; confirm against the imports.
        """
        import deepsnap
        # Select NetworkX as the backend for subsequently created graphs.
        deepsnap.use(nx)

        pyg_dataset = Planetoid('./planetoid', "Cora")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        self.assertTrue(isinstance(graphs[0].G, nx.Graph))
        dataset = GraphDataset(graphs, task='node')
        # Expected split sizes for the default 80/10/10 node split.
        num_nodes = dataset.num_nodes[0]
        node_0 = int(0.8 * num_nodes)
        node_1 = int(0.1 * num_nodes)
        node_2 = num_nodes - node_0 - node_1
        train, val, test = dataset.split()
        self.assertTrue(isinstance(train[0].G, nx.Graph))
        self.assertTrue(isinstance(val[0].G, nx.Graph))
        self.assertTrue(isinstance(test[0].G, nx.Graph))
        self.assertEqual(train[0].node_label_index.shape[0], node_0)
        self.assertEqual(val[0].node_label_index.shape[0], node_1)
        self.assertEqual(test[0].node_label_index.shape[0], node_2)

        # Batched graphs must carry the same backend class through collate.
        train_loader = DataLoader(train,
                                  collate_fn=Batch.collate(),
                                  batch_size=1)
        for batch in train_loader:
            self.assertTrue(isinstance(batch.G[0], nx.Graph))

        # Repeat the identical pipeline with the alternative backend.
        deepsnap.use(sx)
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        self.assertTrue(isinstance(graphs[0].G, sx.Graph))
        dataset = GraphDataset(graphs, task='node')
        num_nodes = dataset.num_nodes[0]
        node_0 = int(0.8 * num_nodes)
        node_1 = int(0.1 * num_nodes)
        node_2 = num_nodes - node_0 - node_1
        train, val, test = dataset.split()
        self.assertTrue(isinstance(train[0].G, sx.Graph))
        # NOTE(review): mixes sx.Graph and sx.classes.graph.Graph —
        # presumably the same class re-exported; confirm in the sx package.
        self.assertTrue(isinstance(val[0].G, sx.classes.graph.Graph))
        self.assertTrue(isinstance(test[0].G, sx.classes.graph.Graph))
        self.assertEqual(train[0].node_label_index.shape[0], node_0)
        self.assertEqual(val[0].node_label_index.shape[0], node_1)
        self.assertEqual(test[0].node_label_index.shape[0], node_2)

        train_loader = DataLoader(train,
                                  collate_fn=Batch.collate(),
                                  batch_size=1)
        for batch in train_loader:
            self.assertTrue(isinstance(batch.G[0], sx.Graph))
예제 #7
0
def main():
    """Run link prediction on Cora using tensor-backend DeepSNAP graphs."""
    args = arg_parse()

    pyg_dataset = Planetoid('./cora', 'Cora', transform=T.TargetIndegree())

    # the input that we assume users have
    edge_train_mode = args.mode
    print('edge train mode: {}'.format(edge_train_mode))

    graphs = GraphDataset.pyg_to_graphs(pyg_dataset, tensor_backend=True)
    if args.multigraph:
        # Duplicate the single graph to exercise the multi-graph path.
        graphs = [copy.deepcopy(graphs[0]) for _ in range(10)]

    dataset = GraphDataset(graphs,
                           task='link_pred',
                           edge_message_ratio=args.edge_message_ratio,
                           edge_train_mode=edge_train_mode)
    print('Initial dataset: {}'.format(dataset))

    # split dataset: transductive for a single graph, inductive otherwise
    split_graphs = dataset.split(transductive=not args.multigraph,
                                 split_ratio=[0.85, 0.05, 0.1])
    datasets = dict(zip(('train', 'val', 'test'), split_graphs))

    print('after split')
    print('Train message-passing graph: {} nodes; {} edges.'.format(
            datasets['train'][0].num_nodes,
            datasets['train'][0].num_edges))
    print('Val message-passing graph: {} nodes; {} edges.'.format(
            datasets['val'][0].num_nodes,
            datasets['val'][0].num_edges))
    print('Test message-passing graph: {} nodes; {} edges.'.format(
            datasets['test'][0].num_nodes,
            datasets['test'][0].num_edges))

    # node feature dimension
    input_dim = datasets['train'].num_node_features
    # link prediction needs 2 classes (0, 1)
    num_classes = datasets['train'].num_edge_labels

    model = Net(input_dim, num_classes, args).to(args.device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=args.epochs)
    follow_batch = []  # e.g., follow_batch = ['edge_index']

    dataloaders = {}
    for split_name, split_dataset in datasets.items():
        dataloaders[split_name] = DataLoader(
            split_dataset,
            collate_fn=Batch.collate(follow_batch),
            batch_size=args.batch_size,
            shuffle=(split_name == 'train'))
    print('Graphs after split: ')
    for key, dataloader in dataloaders.items():
        for batch in dataloader:
            print(key, ': ', batch)

    train(model, dataloaders, optimizer, args, scheduler=scheduler)
예제 #8
0
def create_dataset():
    """Load, filter, transform and split the dataset configured in ``cfg``.

    Returns:
        list: the [train, val, test] dataset splits (order follows the
        custom ``splits`` dict for OGB, or ``cfg.dataset.split`` otherwise).
    """
    ## Load dataset
    time1 = time.time()
    if cfg.dataset.format == 'OGB':
        graphs, splits = load_dataset()
    else:
        graphs = load_dataset()

    ## Filter graphs
    time2 = time.time()
    min_node = filter_graphs()

    ## Create whole dataset
    if type(graphs) is GraphDataset:
        dataset = graphs
    else:
        dataset = GraphDataset(
            graphs,
            task=cfg.dataset.task,
            edge_train_mode=cfg.dataset.edge_train_mode,
            edge_message_ratio=cfg.dataset.edge_message_ratio,
            edge_negative_sampling_ratio=cfg.dataset.
            edge_negative_sampling_ratio,
            resample_disjoint=cfg.dataset.resample_disjoint,
            minimum_node_per_graph=min_node)

    ## Transform the whole dataset
    dataset = transform_before_split(dataset)

    ## Split dataset
    time3 = time.time()
    # Use custom data splits
    if cfg.dataset.format == 'OGB':
        datasets = []
        datasets.append(dataset[splits['train']])
        datasets.append(dataset[splits['valid']])
        datasets.append(dataset[splits['test']])
    # Use random split, supported by DeepSNAP
    else:
        datasets = dataset.split(transductive=cfg.dataset.transductive,
                                 split_ratio=cfg.dataset.split)
    # We only change the training negative sampling ratio; val/test always
    # evaluate with a 1:1 negative sampling ratio.
    # Fixed: the original assigned to the pre-split `dataset` (leaving the
    # loop index `i` unused), so the val/test splits were never updated.
    for i in range(1, len(datasets)):
        datasets[i].edge_negative_sampling_ratio = 1

    ## Transform each split dataset
    time4 = time.time()
    datasets = transform_after_split(datasets)

    time5 = time.time()
    logging.info('Load: {:.4}s, Before split: {:.4}s, '
                 'Split: {:.4}s, After split: {:.4}s'.format(
                     time2 - time1, time3 - time2, time4 - time3,
                     time5 - time4))

    return datasets
    def test_dataset_hetero_graph_split(self):
        """Exercise HeteroGraph splitting for every task / split-type combo.

        Covers the "node", "edge" and "link_pred" tasks, each with and
        without an explicit set of node/edge types to split (including a
        plain string for the node type), plus disjoint link_pred training
        and the "approximate" edge_split_mode. Throughout, split_res is
        [train, val, test] and split_res[k][0] is the single graph in
        split k.
        """
        G = generate_dense_hete_dataset()
        hete = HeteroGraph(G)
        # Rebuild from raw tensors (tensor-backend HeteroGraph).
        hete = HeteroGraph(node_feature=hete.node_feature,
                           node_label=hete.node_label,
                           edge_feature=hete.edge_feature,
                           edge_label=hete.edge_label,
                           edge_index=hete.edge_index,
                           directed=True)

        # node: default 80/10/10 split applies per node type.
        dataset = GraphDataset([hete], task="node")
        split_res = dataset.split()
        for node_type in hete.node_label_index:
            num_nodes = int(len(hete.node_label_index[node_type]))
            node_0 = int(num_nodes * 0.8)
            node_1 = int(num_nodes * 0.1)
            node_2 = num_nodes - node_0 - node_1

            self.assertEqual(
                len(split_res[0][0].node_label_index[node_type]),
                node_0,
            )

            self.assertEqual(
                len(split_res[1][0].node_label_index[node_type]),
                node_1,
            )

            self.assertEqual(
                len(split_res[2][0].node_label_index[node_type]),
                node_2,
            )

        # node with specified split type: only "n1" is split; every other
        # node type keeps all of its nodes in every split.
        dataset = GraphDataset([hete], task="node")
        node_split_types = ["n1"]
        split_res = dataset.split(split_types=node_split_types)
        for node_type in hete.node_label_index:
            if node_type in node_split_types:
                num_nodes = int(len(hete.node_label_index[node_type]))
                node_0 = int(num_nodes * 0.8)
                node_1 = int(num_nodes * 0.1)
                node_2 = num_nodes - node_0 - node_1
                self.assertEqual(
                    len(split_res[0][0].node_label_index[node_type]),
                    node_0,
                )

                self.assertEqual(
                    len(split_res[1][0].node_label_index[node_type]),
                    node_1,
                )

                self.assertEqual(
                    len(split_res[2][0].node_label_index[node_type]),
                    node_2,
                )
            else:
                num_nodes = int(len(hete.node_label_index[node_type]))
                self.assertEqual(
                    len(split_res[0][0].node_label_index[node_type]),
                    num_nodes,
                )

                self.assertEqual(
                    len(split_res[1][0].node_label_index[node_type]),
                    num_nodes,
                )

                self.assertEqual(
                    len(split_res[2][0].node_label_index[node_type]),
                    num_nodes,
                )

        # node with specified split type (string mode): passing "n1" as a
        # bare string must behave like ["n1"].
        dataset = GraphDataset([hete], task="node")
        node_split_types = "n1"
        split_res = dataset.split(split_types=node_split_types)
        for node_type in hete.node_label_index:
            if node_type in node_split_types:
                num_nodes = int(len(hete.node_label_index[node_type]))
                node_0 = int(num_nodes * 0.8)
                node_1 = int(num_nodes * 0.1)
                node_2 = num_nodes - node_0 - node_1
                self.assertEqual(
                    len(split_res[0][0].node_label_index[node_type]),
                    node_0,
                )

                self.assertEqual(
                    len(split_res[1][0].node_label_index[node_type]),
                    node_1,
                )

                self.assertEqual(
                    len(split_res[2][0].node_label_index[node_type]),
                    node_2,
                )
            else:
                num_nodes = int(len(hete.node_label_index[node_type]))
                self.assertEqual(
                    len(split_res[0][0].node_label_index[node_type]),
                    num_nodes,
                )

                self.assertEqual(
                    len(split_res[1][0].node_label_index[node_type]),
                    num_nodes,
                )

                self.assertEqual(
                    len(split_res[2][0].node_label_index[node_type]),
                    num_nodes,
                )

        # edge: default 80/10/10 split applies per edge type.
        dataset = GraphDataset([hete], task="edge")
        split_res = dataset.split()
        for edge_type in hete.edge_label_index:
            num_edges = hete.edge_label_index[edge_type].shape[1]
            edge_0 = int(num_edges * 0.8)
            edge_1 = int(num_edges * 0.1)
            edge_2 = num_edges - edge_0 - edge_1
            self.assertEqual(
                split_res[0][0].edge_label_index[edge_type].shape[1],
                edge_0,
            )

            self.assertEqual(
                split_res[1][0].edge_label_index[edge_type].shape[1],
                edge_1,
            )

            self.assertEqual(
                split_res[2][0].edge_label_index[edge_type].shape[1],
                edge_2,
            )

        # edge with specified split type: non-split types keep all edges.
        dataset = GraphDataset([hete], task="edge")
        edge_split_types = [("n1", "e1", "n1"), ("n1", "e2", "n2")]
        split_res = dataset.split(split_types=edge_split_types)
        for edge_type in hete.edge_label_index:
            if edge_type in edge_split_types:
                num_edges = hete.edge_label_index[edge_type].shape[1]
                edge_0 = int(num_edges * 0.8)
                edge_1 = int(num_edges * 0.1)
                edge_2 = num_edges - edge_0 - edge_1
                self.assertEqual(
                    split_res[0][0].edge_label_index[edge_type].shape[1],
                    edge_0,
                )

                self.assertEqual(
                    split_res[1][0].edge_label_index[edge_type].shape[1],
                    edge_1,
                )

                self.assertEqual(
                    split_res[2][0].edge_label_index[edge_type].shape[1],
                    edge_2,
                )
            else:
                num_edges = hete.edge_label_index[edge_type].shape[1]
                self.assertEqual(
                    split_res[0][0].edge_label_index[edge_type].shape[1],
                    num_edges,
                )

                self.assertEqual(
                    split_res[1][0].edge_label_index[edge_type].shape[1],
                    num_edges,
                )

                self.assertEqual(
                    split_res[2][0].edge_label_index[edge_type].shape[1],
                    num_edges,
                )

        # link_pred: the factor of 2 accounts for negative samples —
        # presumably a 1:1 negative sampling ratio; confirm against
        # GraphDataset's default edge_negative_sampling_ratio.
        dataset = GraphDataset([hete], task="link_pred")
        split_res = dataset.split(transductive=True)
        for edge_type in hete.edge_label_index:
            num_edges = hete.edge_label_index[edge_type].shape[1]
            edge_0 = 2 * int(0.8 * num_edges)
            edge_1 = 2 * int(0.1 * num_edges)
            edge_2 = 2 * (num_edges - int(0.8 * num_edges) -
                          int(0.1 * num_edges))
            self.assertEqual(
                split_res[0][0].edge_label_index[edge_type].shape[1], edge_0)
            self.assertEqual(
                split_res[1][0].edge_label_index[edge_type].shape[1], edge_1)
            self.assertEqual(
                split_res[2][0].edge_label_index[edge_type].shape[1], edge_2)

        # link_pred with specified split type
        dataset = GraphDataset([hete], task="link_pred")
        link_split_types = [("n1", "e1", "n1"), ("n1", "e2", "n2")]
        split_res = dataset.split(transductive=True,
                                  split_types=link_split_types)

        for edge_type in hete.edge_label_index:
            if edge_type in link_split_types:
                num_edges = hete.edge_label_index[edge_type].shape[1]
                edge_0 = 2 * int(0.8 * num_edges)
                edge_1 = 2 * int(0.1 * num_edges)
                edge_2 = 2 * (num_edges - int(0.8 * num_edges) -
                              int(0.1 * num_edges))
                self.assertEqual(
                    split_res[0][0].edge_label_index[edge_type].shape[1],
                    edge_0)
                self.assertEqual(
                    split_res[1][0].edge_label_index[edge_type].shape[1],
                    edge_1)
                self.assertEqual(
                    split_res[2][0].edge_label_index[edge_type].shape[1],
                    edge_2)
            else:
                num_edges = hete.edge_label_index[edge_type].shape[1]
                self.assertEqual(
                    split_res[0][0].edge_label_index[edge_type].shape[1],
                    num_edges)
                self.assertEqual(
                    split_res[1][0].edge_label_index[edge_type].shape[1],
                    num_edges)
                self.assertEqual(
                    split_res[2][0].edge_label_index[edge_type].shape[1],
                    num_edges)

        # link_pred + disjoint: in disjoint mode the training supervision
        # edges are the non-message remainder of the training split.
        dataset = GraphDataset(
            [hete],
            task="link_pred",
            edge_train_mode="disjoint",
            edge_message_ratio=0.5,
        )
        split_res = dataset.split(
            transductive=True,
            split_ratio=[0.6, 0.2, 0.2],
        )
        for edge_type in hete.edge_label_index:
            num_edges = hete.edge_label_index[edge_type].shape[1]
            edge_0 = int(0.6 * num_edges)
            edge_0 = 2 * (edge_0 - int(0.5 * edge_0))
            edge_1 = 2 * int(0.2 * num_edges)
            edge_2 = 2 * (num_edges - int(0.6 * num_edges) -
                          int(0.2 * num_edges))

            self.assertEqual(
                split_res[0][0].edge_label_index[edge_type].shape[1],
                edge_0,
            )
            self.assertEqual(
                split_res[1][0].edge_label_index[edge_type].shape[1],
                edge_1,
            )
            self.assertEqual(
                split_res[2][0].edge_label_index[edge_type].shape[1],
                edge_2,
            )

        # link pred with edge_split_mode set to "approximate": ratios are
        # enforced on the total edge count, not per edge type.
        dataset = GraphDataset([hete],
                               task="link_pred",
                               edge_split_mode="approximate")
        split_res = dataset.split(transductive=True)
        hete_link_train_edge_num = 0
        hete_link_test_edge_num = 0
        hete_link_val_edge_num = 0
        num_edges = 0
        for edge_type in hete.edge_label_index:
            num_edges += hete.edge_label_index[edge_type].shape[1]
            if edge_type in split_res[0][0].edge_label_index:
                hete_link_train_edge_num += (
                    split_res[0][0].edge_label_index[edge_type].shape[1])
            if edge_type in split_res[1][0].edge_label_index:
                hete_link_test_edge_num += (
                    split_res[1][0].edge_label_index[edge_type].shape[1])
            if edge_type in split_res[2][0].edge_label_index:
                hete_link_val_edge_num += (
                    split_res[2][0].edge_label_index[edge_type].shape[1])

        # num_edges_reduced = num_edges - 3
        edge_0 = 2 * int(0.8 * num_edges)
        edge_1 = 2 * int(0.1 * num_edges)
        edge_2 = 2 * (num_edges - int(0.8 * num_edges) - int(0.1 * num_edges))

        self.assertEqual(hete_link_train_edge_num, edge_0)
        self.assertEqual(hete_link_test_edge_num, edge_1)
        self.assertEqual(hete_link_val_edge_num, edge_2)
        # link pred with specified types and edge_split_mode "approximate":
        # only the selected types are split; the rest contribute all of
        # their edges to every split.
        dataset = GraphDataset(
            [hete],
            task="link_pred",
            edge_split_mode="approximate",
        )
        link_split_types = [("n1", "e1", "n1"), ("n1", "e2", "n2")]
        split_res = dataset.split(
            transductive=True,
            split_types=link_split_types,
        )
        hete_link_train_edge_num = 0
        hete_link_test_edge_num = 0
        hete_link_val_edge_num = 0

        num_split_type_edges = 0
        num_non_split_type_edges = 0
        for edge_type in hete.edge_label_index:
            if edge_type in link_split_types:
                num_split_type_edges += (
                    hete.edge_label_index[edge_type].shape[1])
            else:
                num_non_split_type_edges += (
                    hete.edge_label_index[edge_type].shape[1])
            if edge_type in split_res[0][0].edge_label_index:
                hete_link_train_edge_num += (
                    split_res[0][0].edge_label_index[edge_type].shape[1])
            if edge_type in split_res[1][0].edge_label_index:
                hete_link_test_edge_num += (
                    split_res[1][0].edge_label_index[edge_type].shape[1])
            if edge_type in split_res[2][0].edge_label_index:
                hete_link_val_edge_num += (
                    split_res[2][0].edge_label_index[edge_type].shape[1])

        # num_edges_reduced = num_split_type_edges - 3
        num_edges = num_split_type_edges
        edge_0 = 2 * int(0.8 * num_edges) + num_non_split_type_edges
        edge_1 = 2 * int(0.1 * num_edges) + num_non_split_type_edges
        edge_2 = 2 * (num_edges - int(0.8 * num_edges) -
                      int(0.1 * num_edges)) + num_non_split_type_edges

        self.assertEqual(hete_link_train_edge_num, edge_0)
        self.assertEqual(hete_link_test_edge_num, edge_1)
        self.assertEqual(hete_link_val_edge_num, edge_2)
    def test_dataset_split_custom(self):
        # transductive split with node task (self defined dataset)
        G, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
            simple_networkx_graph())
        Graph.add_edge_attr(G, "edge_feature", edge_x)
        Graph.add_edge_attr(G, "edge_label", edge_y)
        Graph.add_node_attr(G, "node_feature", x)
        Graph.add_node_attr(G, "node_label", y)
        Graph.add_graph_attr(G, "graph_feature", graph_x)
        Graph.add_graph_attr(G, "graph_label", graph_y)

        num_nodes = len(list(G.nodes))
        nodes_train = torch.tensor(list(G.nodes)[:int(0.3 * num_nodes)])
        nodes_val = torch.tensor(
            list(G.nodes)[int(0.3 * num_nodes):int(0.6 * num_nodes)])
        nodes_test = torch.tensor(list(G.nodes)[int(0.6 * num_nodes):])

        graph_train = Graph(node_feature=x,
                            node_label=y,
                            edge_index=edge_index,
                            node_label_index=nodes_train,
                            directed=True)
        graph_val = Graph(node_feature=x,
                          node_label=y,
                          edge_index=edge_index,
                          node_label_index=nodes_val,
                          directed=True)
        graph_test = Graph(node_feature=x,
                           node_label=y,
                           edge_index=edge_index,
                           node_label_index=nodes_test,
                           directed=True)

        graphs_train = [graph_train]
        graphs_val = [graph_val]
        graphs_test = [graph_test]

        dataset_train, dataset_val, dataset_test = (GraphDataset(graphs_train,
                                                                 task='node'),
                                                    GraphDataset(graphs_val,
                                                                 task='node'),
                                                    GraphDataset(graphs_test,
                                                                 task='node'))

        self.assertEqual(dataset_train[0].node_label_index.tolist(),
                         list(range(int(0.3 * num_nodes))))
        self.assertEqual(
            dataset_val[0].node_label_index.tolist(),
            list(range(int(0.3 * num_nodes), int(0.6 * num_nodes))))
        self.assertEqual(dataset_test[0].node_label_index.tolist(),
                         list(range(int(0.6 * num_nodes), num_nodes)))

        # transductive split with link_pred task (train/val split)
        edges = list(G.edges)
        num_edges = len(edges)
        edges_train = edges[:int(0.7 * num_edges)]
        edges_val = edges[int(0.7 * num_edges):]
        link_size_list = [len(edges_train), len(edges_val)]

        # generate pseudo pos and neg edges, they may overlap here
        train_pos = torch.LongTensor(edges_train).permute(1, 0)
        val_pos = torch.LongTensor(edges_val).permute(1, 0)
        val_neg = torch.randint(high=10, size=val_pos.shape, dtype=torch.int64)
        val_neg_double = torch.cat((val_neg, val_neg), dim=1)

        num_train = len(edges_train)
        num_val = len(edges_val)

        graph_train = Graph(node_feature=x,
                            edge_index=edge_index,
                            edge_feature=edge_x,
                            directed=True,
                            edge_label_index=train_pos)

        graph_val = Graph(node_feature=x,
                          edge_index=edge_index,
                          edge_feature=edge_x,
                          directed=True,
                          edge_label_index=val_pos,
                          negative_edge=val_neg_double)

        graphs_train = [graph_train]
        graphs_val = [graph_val]

        dataset_train, dataset_val = (GraphDataset(graphs_train,
                                                   task='link_pred',
                                                   resample_negatives=True),
                                      GraphDataset(
                                          graphs_val,
                                          task='link_pred',
                                          edge_negative_sampling_ratio=2))

        self.assertEqual(dataset_train[0].edge_label_index.shape[1],
                         2 * link_size_list[0])
        self.assertEqual(dataset_train[0].edge_label.shape[0],
                         2 * link_size_list[0])
        self.assertEqual(dataset_val[0].edge_label_index.shape[1],
                         val_pos.shape[1] + val_neg_double.shape[1])
        self.assertEqual(dataset_val[0].edge_label.shape[0],
                         val_pos.shape[1] + val_neg_double.shape[1])
        self.assertTrue(
            torch.equal(dataset_train[0].edge_label_index[:, :num_train],
                        train_pos))
        self.assertTrue(
            torch.equal(dataset_val[0].edge_label_index[:, :num_val], val_pos))
        self.assertTrue(
            torch.equal(dataset_val[0].edge_label_index[:, num_val:],
                        val_neg_double))

        dataset_train.resample_negatives = False
        self.assertTrue(
            torch.equal(dataset_train[0].edge_label_index,
                        dataset_train[0].edge_label_index))

        # transductive split with link_pred task with edge label
        edge_label_train = torch.LongTensor([1, 2, 3, 2, 1, 1, 2, 3, 2, 0, 0])
        edge_label_val = torch.LongTensor([1, 2, 3, 2, 1, 0])

        graph_train = Graph(node_feature=x,
                            edge_index=edge_index,
                            directed=True,
                            edge_label_index=train_pos,
                            edge_label=edge_label_train)

        graph_val = Graph(node_feature=x,
                          edge_index=edge_index,
                          directed=True,
                          edge_label_index=val_pos,
                          negative_edge=val_neg,
                          edge_label=edge_label_val)

        graphs_train = [graph_train]
        graphs_val = [graph_val]

        dataset_train, dataset_val = (GraphDataset(graphs_train,
                                                   task='link_pred'),
                                      GraphDataset(graphs_val,
                                                   task='link_pred'))

        self.assertTrue(
            torch.equal(dataset_train[0].edge_label_index,
                        dataset_train[0].edge_label_index))

        self.assertTrue(
            torch.equal(dataset_train[0].edge_label[:num_train],
                        edge_label_train))

        self.assertTrue(
            torch.equal(dataset_val[0].edge_label[:num_val], edge_label_val))

        # Multiple graph tensor backend link prediction (inductive)
        pyg_dataset = Planetoid('./cora', 'Cora')
        x = pyg_dataset[0].x
        y = pyg_dataset[0].y
        edge_index = pyg_dataset[0].edge_index
        row, col = edge_index
        mask = row < col
        row, col = row[mask], col[mask]
        edge_index = torch.stack([row, col], dim=0)
        edge_index = torch.cat(
            [edge_index, torch.flip(edge_index, [0])], dim=1)

        graphs = [
            Graph(node_feature=x,
                  node_label=y,
                  edge_index=edge_index,
                  directed=False)
        ]
        graphs = [copy.deepcopy(graphs[0]) for _ in range(10)]

        edge_label_index = graphs[0].edge_label_index
        dataset = GraphDataset(graphs,
                               task='link_pred',
                               edge_message_ratio=0.6,
                               edge_train_mode="all")
        datasets = {}
        datasets['train'], datasets['val'], datasets['test'] = dataset.split(
            transductive=False, split_ratio=[0.85, 0.05, 0.1])
        edge_label_index_split = (
            datasets['train'][0].edge_label_index[:,
                                                  0:edge_label_index.shape[1]])

        self.assertTrue(torch.equal(edge_label_index, edge_label_index_split))

        # transductive split with node task (pytorch geometric dataset)
        pyg_dataset = Planetoid("./cora", "Cora")
        ds = pyg_to_dicts(pyg_dataset, task="cora")
        graphs = [Graph(**item) for item in ds]
        split_ratio = [0.3, 0.3, 0.4]
        node_size_list = [0 for i in range(len(split_ratio))]
        for graph in graphs:
            custom_splits = [[] for i in range(len(split_ratio))]
            split_offset = 0
            num_nodes = graph.num_nodes
            shuffled_node_indices = torch.randperm(graph.num_nodes)
            for i, split_ratio_i in enumerate(split_ratio):
                if i != len(split_ratio) - 1:
                    num_split_i = int(split_ratio_i * num_nodes)
                    nodes_split_i = (
                        shuffled_node_indices[split_offset:split_offset +
                                              num_split_i])
                    split_offset += num_split_i
                else:
                    nodes_split_i = shuffled_node_indices[split_offset:]

                custom_splits[i] = nodes_split_i
                node_size_list[i] += len(nodes_split_i)
            graph.custom = {"general_splits": custom_splits}

        node_feature = graphs[0].node_feature
        edge_index = graphs[0].edge_index
        directed = graphs[0].directed

        graph_train = Graph(
            node_feature=node_feature,
            edge_index=edge_index,
            directed=directed,
            node_label_index=graphs[0].custom["general_splits"][0])

        graph_val = Graph(
            node_feature=node_feature,
            edge_index=edge_index,
            directed=directed,
            node_label_index=graphs[0].custom["general_splits"][1])

        graph_test = Graph(
            node_feature=node_feature,
            edge_index=edge_index,
            directed=directed,
            node_label_index=graphs[0].custom["general_splits"][2])

        train_dataset = GraphDataset([graph_train], task="node")
        val_dataset = GraphDataset([graph_val], task="node")
        test_dataset = GraphDataset([graph_test], task="node")

        self.assertEqual(len(train_dataset[0].node_label_index),
                         node_size_list[0])
        self.assertEqual(len(val_dataset[0].node_label_index),
                         node_size_list[1])
        self.assertEqual(len(test_dataset[0].node_label_index),
                         node_size_list[2])

        # transductive split with edge task
        pyg_dataset = Planetoid("./cora", "Cora")
        graphs_g = GraphDataset.pyg_to_graphs(pyg_dataset)
        ds = pyg_to_dicts(pyg_dataset, task="cora")
        graphs = [Graph(**item) for item in ds]
        split_ratio = [0.3, 0.3, 0.4]
        edge_size_list = [0 for i in range(len(split_ratio))]
        for i, graph in enumerate(graphs):
            custom_splits = [[] for i in range(len(split_ratio))]
            split_offset = 0
            edges = list(graphs_g[i].G.edges)
            num_edges = graph.num_edges
            random.shuffle(edges)
            for i, split_ratio_i in enumerate(split_ratio):
                if i != len(split_ratio) - 1:
                    num_split_i = int(split_ratio_i * num_edges)
                    edges_split_i = (edges[split_offset:split_offset +
                                           num_split_i])
                    split_offset += num_split_i
                else:
                    edges_split_i = edges[split_offset:]

                custom_splits[i] = edges_split_i
                edge_size_list[i] += len(edges_split_i)
            graph.custom = {"general_splits": custom_splits}

        node_feature = graphs[0].node_feature
        edge_index = graphs[0].edge_index
        directed = graphs[0].directed

        train_index = torch.tensor(
            graphs[0].custom["general_splits"][0]).permute(1, 0)
        train_index = torch.cat((train_index, train_index), dim=1)
        val_index = torch.tensor(
            graphs[0].custom["general_splits"][1]).permute(1, 0)
        val_index = torch.cat((val_index, val_index), dim=1)
        test_index = torch.tensor(
            graphs[0].custom["general_splits"][2]).permute(1, 0)
        test_index = torch.cat((test_index, test_index), dim=1)

        graph_train = Graph(node_feature=node_feature,
                            edge_index=edge_index,
                            directed=directed,
                            edge_label_index=train_index)

        graph_val = Graph(node_feature=node_feature,
                          edge_index=edge_index,
                          directed=directed,
                          edge_label_index=val_index)

        graph_test = Graph(node_feature=node_feature,
                           edge_index=edge_index,
                           directed=directed,
                           edge_label_index=test_index)

        train_dataset = GraphDataset([graph_train], task="edge")
        val_dataset = GraphDataset([graph_val], task="edge")
        test_dataset = GraphDataset([graph_test], task="edge")

        self.assertEqual(train_dataset[0].edge_label_index.shape[1],
                         2 * edge_size_list[0])
        self.assertEqual(val_dataset[0].edge_label_index.shape[1],
                         2 * edge_size_list[1])
        self.assertEqual(test_dataset[0].edge_label_index.shape[1],
                         2 * edge_size_list[2])

        # inductive split with graph task
        pyg_dataset = TUDataset("./enzymes", "ENZYMES")
        ds = pyg_to_dicts(pyg_dataset)
        graphs = [Graph(**item) for item in ds]
        num_graphs = len(graphs)
        split_ratio = [0.3, 0.3, 0.4]
        graph_size_list = []
        split_offset = 0
        custom_split_graphs = []
        for i, split_ratio_i in enumerate(split_ratio):
            if i != len(split_ratio) - 1:
                num_split_i = int(split_ratio_i * num_graphs)
                custom_split_graphs.append(graphs[split_offset:split_offset +
                                                  num_split_i])
                split_offset += num_split_i
                graph_size_list.append(num_split_i)
            else:
                custom_split_graphs.append(graphs[split_offset:])
                graph_size_list.append(len(graphs[split_offset:]))
        dataset = GraphDataset(graphs,
                               task="graph",
                               custom_split_graphs=custom_split_graphs)
        split_res = dataset.split(transductive=False)
        self.assertEqual(graph_size_list[0], len(split_res[0]))
        self.assertEqual(graph_size_list[1], len(split_res[1]))
        self.assertEqual(graph_size_list[2], len(split_res[2]))
    def test_dataset_split(self):
        """Validate default ``GraphDataset.split`` sizes across tasks.

        Exercises inductive splits (graph / link_pred tasks) and
        transductive splits (node / link_pred tasks) with the default
        [0.8, 0.1, 0.1] ratio, comparing each split's size against counts
        computed by hand with the same truncation arithmetic.
        """
        # inductively split with graph task
        pyg_dataset = TUDataset("./enzymes", "ENZYMES")
        ds = pyg_to_dicts(pyg_dataset)
        graphs = [Graph(**item) for item in ds]
        dataset = GraphDataset(graphs, task="graph")
        split_res = dataset.split(transductive=False)
        num_graphs = len(dataset)
        # Default 80/10/10 ratio; the test split takes whatever remains
        # after int() truncation of the first two.
        num_train = int(0.8 * num_graphs)
        num_val = int(0.1 * num_graphs)
        num_test = num_graphs - num_train - num_val
        self.assertEqual(num_train, len(split_res[0]))
        self.assertEqual(num_val, len(split_res[1]))
        self.assertEqual(num_test, len(split_res[2]))

        # inductively split with link_pred task
        # and default (`all`) edge_train_mode
        pyg_dataset = TUDataset("./enzymes", "ENZYMES")
        ds = pyg_to_dicts(pyg_dataset)
        graphs = [Graph(**item) for item in ds]
        dataset = GraphDataset(graphs, task="link_pred")
        split_res = dataset.split(transductive=False)
        num_graphs = len(dataset)
        num_train = int(0.8 * num_graphs)
        num_val = int(0.1 * num_graphs)
        num_test = num_graphs - num_train - num_val
        self.assertEqual(num_train, len(split_res[0]))
        self.assertEqual(num_val, len(split_res[1]))
        self.assertEqual(num_test, len(split_res[2]))

        # inductively split with link_pred task and `disjoint` edge_train_mode
        pyg_dataset = TUDataset("./enzymes", "ENZYMES")
        ds = pyg_to_dicts(pyg_dataset)
        graphs = [Graph(**item) for item in ds]
        dataset = GraphDataset(
            graphs,
            task="link_pred",
            edge_train_mode="disjoint",
        )
        split_res = dataset.split(transductive=False)
        num_graphs = len(dataset)
        num_train = int(0.8 * num_graphs)
        num_val = int(0.1 * num_graphs)
        num_test = num_graphs - num_train - num_val
        self.assertEqual(num_train, len(split_res[0]))
        self.assertEqual(num_val, len(split_res[1]))
        self.assertEqual(num_test, len(split_res[2]))

        # transductively split with node task
        pyg_dataset = Planetoid("./cora", "Cora")
        ds = pyg_to_dicts(pyg_dataset, task="cora")
        graphs = [Graph(**item) for item in ds]
        dataset = GraphDataset(graphs, task="node")
        num_nodes = dataset.num_nodes[0]
        num_edges = dataset.num_edges[0]
        node_0 = int(0.8 * num_nodes)
        node_1 = int(0.1 * num_nodes)
        node_2 = num_nodes - node_0 - node_1
        split_res = dataset.split()
        self.assertEqual(len(split_res[0][0].node_label_index), node_0)
        self.assertEqual(len(split_res[1][0].node_label_index), node_1)
        self.assertEqual(len(split_res[2][0].node_label_index), node_2)

        # transductively split with link_pred task
        # and default (`all`) edge_train_mode
        dataset = GraphDataset(graphs, task="link_pred")
        # Factor 2 * 2: NOTE(review) — presumably undirected-edge doubling
        # times default 1:1 negative sampling; confirm against GraphDataset
        # link_pred defaults.
        edge_0 = 2 * 2 * int(0.8 * num_edges)
        edge_1 = 2 * 2 * int(0.1 * num_edges)
        edge_2 = 2 * 2 * (num_edges - int(0.8 * num_edges) -
                          int(0.1 * num_edges))
        split_res = dataset.split()
        self.assertEqual(split_res[0][0].edge_label_index.shape[1], edge_0)
        self.assertEqual(split_res[1][0].edge_label_index.shape[1], edge_1)
        self.assertEqual(split_res[2][0].edge_label_index.shape[1], edge_2)

        # transductively split with link_pred task, `split` edge_train_mode
        # and 0.5 edge_message_ratio
        dataset = GraphDataset(
            graphs,
            task="link_pred",
            edge_train_mode="disjoint",
            edge_message_ratio=0.5,
        )
        split_res = dataset.split()
        # In disjoint mode the train split's supervision edges are what is
        # left after the 0.5 message-passing fraction is removed.
        edge_0 = 2 * int(0.8 * num_edges)
        edge_0 = 2 * (edge_0 - int(0.5 * edge_0))
        edge_1 = 2 * 2 * int(0.1 * num_edges)
        edge_2 = 2 * 2 * (num_edges - int(0.8 * num_edges) -
                          int(0.1 * num_edges))
        self.assertEqual(
            split_res[0][0].edge_label_index.shape[1],
            edge_0,
        )
        self.assertEqual(split_res[1][0].edge_label_index.shape[1], edge_1)
        self.assertEqual(split_res[2][0].edge_label_index.shape[1], edge_2)

        # transductively split with link_pred task
        # and specified edge_negative_sampling_ratio
        dataset = GraphDataset(graphs,
                               task="link_pred",
                               edge_negative_sampling_ratio=2)
        split_res = dataset.split()
        # Factor (2 + 1): two negatives per positive plus the positive itself.
        edge_0 = (2 + 1) * (2 * int(0.8 * num_edges))
        edge_1 = (2 + 1) * (2 * int(0.1 * num_edges))
        edge_2 = (2 + 1) * (
            2 * (num_edges - int(0.8 * num_edges) - int(0.1 * num_edges)))

        self.assertEqual(split_res[0][0].edge_label_index.shape[1], edge_0)
        self.assertEqual(split_res[1][0].edge_label_index.shape[1], edge_1)
        self.assertEqual(split_res[2][0].edge_label_index.shape[1], edge_2)
# Example #12
# 0
    def test_dataset_split_custom(self):
        """Validate ``GraphDataset.split`` with user-supplied custom splits.

        Builds per-split graph copies whose ``custom_split_index`` lists the
        nodes / edges belonging to each split (or per-split graph lists for
        the graph task), then checks that ``split`` with
        ``general_split_mode="custom"`` returns splits of the expected sizes
        for node, edge, link_pred and graph tasks.
        """
        # transductive split with node task
        pyg_dataset = Planetoid("./cora", "Cora")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        split_ratio = [0.3, 0.3, 0.4]
        split_graphs = [[] for _ in range(len(split_ratio))]
        node_size_list = [0 for _ in range(len(split_ratio))]
        for graph in graphs:
            split_offset = 0
            shuffled_node_indices = torch.randperm(graph.num_nodes)
            for i, split_ratio_i in enumerate(split_ratio):
                if i != len(split_ratio) - 1:
                    # `1 + int(ratio * (n - k))` mirrors the library's
                    # secure-split arithmetic so every split is non-empty.
                    num_split_i = (
                        1 +
                        int(
                            split_ratio_i *
                            (graph.num_nodes - len(split_ratio))
                        )
                    )
                    nodes_split_i = (
                        shuffled_node_indices[split_offset: split_offset + num_split_i]
                    )
                    split_offset += num_split_i
                else:
                    # Last split absorbs the remainder.
                    nodes_split_i = shuffled_node_indices[split_offset:]

                graph_new = copy.copy(graph)
                graph_new.custom_split_index = nodes_split_i
                split_graphs[i].append(graph_new)
                node_size_list[i] += len(nodes_split_i)

        dataset = GraphDataset(
            graphs, task="node", general_split_mode="custom",
            split_graphs=split_graphs
        )

        split_res = dataset.split(transductive=True)
        self.assertEqual(
            len(split_res[0][0].node_label_index),
            node_size_list[0]
        )
        self.assertEqual(
            len(split_res[1][0].node_label_index),
            node_size_list[1]
        )
        self.assertEqual(
            len(split_res[2][0].node_label_index),
            node_size_list[2]
        )

        # transductive split with edge task
        pyg_dataset = Planetoid("./cora", "Cora")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        split_ratio = [0.3, 0.3, 0.4]
        split_graphs = [[] for _ in range(len(split_ratio))]
        edge_size_list = [0 for _ in range(len(split_ratio))]
        for graph in graphs:
            split_offset = 0
            edges = list(graph.G.edges())
            random.shuffle(edges)
            for i, split_ratio_i in enumerate(split_ratio):
                if i != len(split_ratio) - 1:
                    num_split_i = (
                        1 +
                        int(
                            split_ratio_i
                            * (graph.num_edges - len(split_ratio))
                        )
                    )
                    edges_split_i = (
                        edges[split_offset: split_offset + num_split_i]
                    )
                    split_offset += num_split_i
                else:
                    edges_split_i = edges[split_offset:]
                graph_new = copy.copy(graph)
                graph_new.custom_split_index = edges_split_i

                split_graphs[i].append(graph_new)
                edge_size_list[i] += len(edges_split_i)

        dataset = GraphDataset(
            graphs, task="edge", general_split_mode="custom",
            split_graphs=split_graphs
        )
        split_res = dataset.split(transductive=True)
        # Factor 2: NOTE(review) — presumably undirected-edge doubling in
        # edge_label_index; confirm against GraphDataset edge-task behavior.
        self.assertEqual(
            split_res[0][0].edge_label_index.shape[1],
            2 * edge_size_list[0]
        )
        self.assertEqual(
            split_res[1][0].edge_label_index.shape[1],
            2 * edge_size_list[1]
        )
        self.assertEqual(
            split_res[2][0].edge_label_index.shape[1],
            2 * edge_size_list[2]
        )

        # transductive split with link_pred task
        pyg_dataset = Planetoid("./cora", "Cora")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        split_ratio = [0.3, 0.3, 0.4]
        split_graphs = [[] for _ in range(len(split_ratio))]
        link_size_list = [0 for _ in range(len(split_ratio))]

        for graph in graphs:
            split_offset = 0
            edges = list(graph.G.edges(data=True))
            random.shuffle(edges)
            num_edges_train = 1 + int(split_ratio[0] * (graph.num_edges - 3))
            # BUG FIX: the validation size must use split_ratio[1] (the val
            # ratio), not split_ratio[0]. The two happen to be equal here
            # (both 0.3), so observed behavior is unchanged.
            num_edges_val = 1 + int(split_ratio[1] * (graph.num_edges - 3))
            edges_train = edges[:num_edges_train]
            edges_val = edges[num_edges_train:num_edges_train + num_edges_val]
            edges_test = edges[num_edges_train + num_edges_val:]

            graph_train = copy.copy(graph)
            graph_test = copy.copy(graph)
            graph_val = copy.copy(graph)

            graph_train.custom_split_index = edges_train
            graph_val.custom_split_index = edges_val
            graph_test.custom_split_index = edges_test

            split_graphs[0].append(graph_train)
            split_graphs[1].append(graph_val)
            split_graphs[2].append(graph_test)
            link_size_list[0] += len(edges_train)
            link_size_list[1] += len(edges_val)
            link_size_list[2] += len(edges_test)

        dataset = GraphDataset(
            graphs, task="link_pred", general_split_mode="custom",
            split_graphs=split_graphs
        )
        split_res = dataset.split(transductive=True)
        # Factor 2 * 2: NOTE(review) — presumably undirected doubling times
        # default 1:1 negative sampling; confirm against link_pred defaults.
        self.assertEqual(
            split_res[0][0].edge_label_index.shape[1],
            2 * 2 * link_size_list[0]
        )
        self.assertEqual(
            split_res[1][0].edge_label_index.shape[1],
            2 * 2 * link_size_list[1]
        )
        self.assertEqual(
            split_res[2][0].edge_label_index.shape[1],
            2 * 2 * link_size_list[2]
        )

        # inductive split with graph task
        pyg_dataset = TUDataset("./enzymes", "ENZYMES")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        num_graphs = len(graphs)
        split_ratio = [0.3, 0.3, 0.4]
        graph_size_list = []
        split_offset = 0
        split_graphs = []
        for i, split_ratio_i in enumerate(split_ratio):
            if i != len(split_ratio) - 1:
                num_split_i = (
                    1 +
                    int(split_ratio_i * (num_graphs - len(split_ratio)))
                )
                split_graphs.append(
                    graphs[split_offset: split_offset + num_split_i]
                )
                split_offset += num_split_i
                graph_size_list.append(num_split_i)
            else:
                split_graphs.append(graphs[split_offset:])
                graph_size_list.append(len(graphs[split_offset:]))
        dataset = GraphDataset(
            graphs, task="graph", general_split_mode="custom",
            split_graphs=split_graphs
        )
        split_res = dataset.split(transductive=False)
        self.assertEqual(graph_size_list[0], len(split_res[0]))
        self.assertEqual(graph_size_list[1], len(split_res[1]))
        self.assertEqual(graph_size_list[2], len(split_res[2]))
# Example #13
# 0
    def test_dataset_split(self):
        """Validate ``GraphDataset.split`` sizes under secure splitting.

        Same coverage as the plain split test (inductive graph / link_pred,
        transductive node / edge / link_pred), but the expected counts use
        the ``1 + int(ratio * (n - 3))`` form. NOTE(review): this matches a
        split scheme that guarantees each of the three splits is non-empty —
        confirm against GraphDataset's split implementation.
        """
        # inductively split with graph task
        pyg_dataset = TUDataset("./enzymes", "ENZYMES")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        dataset = GraphDataset(graphs, task="graph")
        split_res = dataset.split(transductive=False)
        num_graphs = len(dataset)
        # Reserve one item per split (3 splits) before applying the ratios.
        num_graphs_reduced = num_graphs - 3
        num_train = 1 + int(num_graphs_reduced * 0.8)
        num_val = 1 + int(num_graphs_reduced * 0.1)
        num_test = num_graphs - num_train - num_val
        self.assertEqual(num_train, len(split_res[0]))
        self.assertEqual(num_val, len(split_res[1]))
        self.assertEqual(num_test, len(split_res[2]))

        # inductively split with link_pred task
        # and default (`all`) edge_train_mode
        pyg_dataset = TUDataset("./enzymes", "ENZYMES")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        dataset = GraphDataset(graphs, task="link_pred")
        split_res = dataset.split(transductive=False)
        num_graphs = len(dataset)
        num_graphs_reduced = num_graphs - 3
        num_train = 1 + int(num_graphs_reduced * 0.8)
        num_val = 1 + int(num_graphs_reduced * 0.1)
        num_test = num_graphs - num_train - num_val
        self.assertEqual(num_train, len(split_res[0]))
        self.assertEqual(num_val, len(split_res[1]))
        self.assertEqual(num_test, len(split_res[2]))

        # inductively split with link_pred task and `disjoint` edge_train_mode
        pyg_dataset = TUDataset("./enzymes", "ENZYMES")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        dataset = GraphDataset(
            graphs,
            task="link_pred",
            edge_train_mode="disjoint",
        )
        split_res = dataset.split(transductive=False)
        num_graphs = len(dataset)
        num_graphs_reduced = num_graphs - 3
        num_train = 1 + int(num_graphs_reduced * 0.8)
        num_val = 1 + int(num_graphs_reduced * 0.1)
        num_test = num_graphs - num_train - num_val
        self.assertEqual(num_train, len(split_res[0]))
        self.assertEqual(num_val, len(split_res[1]))
        self.assertEqual(num_test, len(split_res[2]))

        # transductively split with node task
        pyg_dataset = Planetoid("./cora", "Cora")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        dataset = GraphDataset(graphs, task="node")
        num_nodes = dataset.num_nodes[0]
        num_nodes_reduced = num_nodes - 3
        num_edges = dataset.num_edges[0]
        num_edges_reduced = num_edges - 3
        split_res = dataset.split()
        self.assertEqual(
            len(split_res[0][0].node_label_index),
            1 + int(0.8 * num_nodes_reduced)
        )
        self.assertEqual(
            len(split_res[1][0].node_label_index),
            1 + int(0.1 * num_nodes_reduced)
        )
        # Test split: total minus the two `1 +` reservations and truncations.
        self.assertEqual(
            len(split_res[2][0].node_label_index),
            num_nodes
            - 2
            - int(0.8 * num_nodes_reduced)
            - int(0.1 * num_nodes_reduced)
        )

        # transductively split with edge task
        dataset = GraphDataset(graphs, task="edge")
        split_res = dataset.split()
        # Factor 2: NOTE(review) — presumably undirected-edge doubling in
        # edge_label_index; confirm against GraphDataset edge-task behavior.
        edge_0 = 2 * (1 + int(0.8 * (num_edges_reduced)))
        self.assertEqual(
            split_res[0][0].edge_label_index.shape[1],
            edge_0,
        )
        edge_1 = 2 * (1 + int(0.1 * (num_edges_reduced)))
        self.assertEqual(
            split_res[1][0].edge_label_index.shape[1],
            edge_1,
        )
        self.assertEqual(
            split_res[2][0].edge_label_index.shape[1],
            2 * num_edges - edge_0 - edge_1,
        )

        # transductively split with link_pred task
        # and default (`all`) edge_train_mode
        dataset = GraphDataset(graphs, task="link_pred")
        split_res = dataset.split()
        # Factor 2 * 2: undirected doubling times default 1:1 negative
        # sampling (see edge_negative_sampling_ratio case below).
        self.assertEqual(
            split_res[0][0].edge_label_index.shape[1],
            2 * 2 * (1 + int(0.8 * (num_edges_reduced)))
        )
        self.assertEqual(
            split_res[1][0].edge_label_index.shape[1],
            2
            * 2 * (1 + (int(0.1 * (num_edges_reduced))))
        )
        self.assertEqual(
            split_res[2][0].edge_label_index.shape[1],
            2
            * 2
            * num_edges
            - 2
            * 2
            * (
                2
                + int(0.1 * num_edges_reduced)
                + int(0.8 * num_edges_reduced)
            )
        )

        # transductively split with link_pred task, `split` edge_train_mode
        # and 0.5 edge_message_ratio
        dataset = GraphDataset(
            graphs,
            task="link_pred",
            edge_train_mode="disjoint",
            edge_message_ratio=0.5,
        )
        split_res = dataset.split()
        # Disjoint mode: train supervision edges are what remains after the
        # 0.5 message-passing fraction (itself securely split) is removed.
        edge_0 = 2 * (1 + int(0.8 * num_edges_reduced))
        edge_0 = 2 * (edge_0 - (1 + int(0.5 * (edge_0 - 3))))
        self.assertEqual(
            split_res[0][0].edge_label_index.shape[1],
            edge_0,
        )
        edge_1 = 2 * 2 * (1 + int(0.1 * num_edges_reduced))
        self.assertEqual(split_res[1][0].edge_label_index.shape[1], edge_1)
        edge_2 = (
            2
            * 2
            * int(num_edges)
            - 2
            * 2 * (1 + int(0.8 * num_edges_reduced))
            - edge_1
        )

        self.assertEqual(split_res[2][0].edge_label_index.shape[1], edge_2)

        # transductively split with link_pred task
        # and specified edge_negative_sampling_ratio
        dataset = GraphDataset(
            graphs,
            task="link_pred",
            edge_negative_sampling_ratio=2
        )
        split_res = dataset.split()
        # Factor (2 + 1): two negatives per positive plus the positive.
        edge_0 = (2 + 1) * (2 * (1 + int(0.8 * num_edges_reduced)))
        self.assertEqual(split_res[0][0].edge_label_index.shape[1], edge_0)
        edge_1 = (2 + 1) * 2 * (1 + int(0.1 * num_edges_reduced))
        self.assertEqual(split_res[1][0].edge_label_index.shape[1], edge_1)
        edge_2 = (2 + 1) * 2 * int(num_edges) - edge_0 - edge_1
        self.assertEqual(split_res[2][0].edge_label_index.shape[1], edge_2)
    def test_secure_split_heterogeneous(self):
        """Secure splitting on a heterogeneous graph, per type.

        For each node type / message type, every split must carry the
        expected number of labels under the secure 80/10/10 scheme
        (``1 + int((n - 3) * ratio)`` for the first two splits, remainder
        for the third).
        """
        hete = HeteroGraph(generate_simple_small_hete_graph())
        graph = HeteroGraph(
            node_label=hete.node_label,
            edge_index=hete.edge_index,
            edge_label=hete.edge_label,
            directed=True,
        )

        def expected_sizes(total, scale):
            # Secure-split arithmetic: the first two splits each reserve one
            # item, the last split absorbs the remainder; `scale` accounts
            # for label-index widening (e.g. negatives in link_pred).
            first = scale * (1 + int((total - 3) * 0.8))
            second = scale * (1 + int((total - 3) * 0.1))
            return [first, second, scale * total - first - second]

        # node task
        splits = GraphDataset([graph], task="node").split()
        for node_type in graph.node_label_index:
            total = graph.node_label_index[node_type].shape[0]
            for split, size in zip(splits, expected_sizes(total, 1)):
                self.assertEqual(
                    split[0].node_label_index[node_type].shape[0], size)
                self.assertEqual(
                    split[0].node_label[node_type].shape[0], size)

        # edge task
        splits = GraphDataset([graph], task="edge").split()
        for message_type in graph.edge_label_index:
            total = graph.edge_label_index[message_type].shape[1]
            for split, size in zip(splits, expected_sizes(total, 1)):
                self.assertEqual(
                    split[0].edge_label_index[message_type].shape[1], size)
                self.assertEqual(
                    split[0].edge_label[message_type].shape[0], size)

        # link_pred task: label indices are doubled relative to edge counts
        splits = GraphDataset([graph], task="link_pred").split()
        for message_type in graph.edge_label_index:
            total = graph.edge_label_index[message_type].shape[1]
            for split, size in zip(splits, expected_sizes(total, 2)):
                self.assertEqual(
                    split[0].edge_label_index[message_type].shape[1], size)
                self.assertEqual(
                    split[0].edge_label[message_type].shape[0], size)
# Example #15
# 0
def main():
    """Train a heterogeneous GNN for link prediction on a WordNet-style graph.

    Loads a pickled networkx graph whose edges carry a categorical
    ``e_label``, converts it to a DeepSNAP ``HeteroGraph`` with one edge
    type per label, splits it transductively for link prediction, and
    trains a two-layer HeteroGNN.
    """
    args = arg_parse()

    edge_train_mode = args.mode
    print('edge train mode: {}'.format(edge_train_mode))

    G = nx.read_gpickle(args.data_path)
    print(G.number_of_edges())
    print('Each node has node ID (n_id). Example: ', G.nodes[0])
    print(
        'Each edge has edge ID (id) and categorical label (e_label). Example: ',
        G[0][5871])

    # Collect the distinct edge labels; a set gives O(1) membership tests
    # instead of scanning a list for every edge.
    # Labels are assumed to be consecutive integers (0-17).
    labels = set()
    for u, v, edge_key in G.edges:
        labels.add(G[u][v][edge_key]['e_label'])
    num_edge_types = len(labels)

    H = WN_transform(G, num_edge_types)
    # The nodes in the graph have the features: node_feature and node_type
    # (just one node type "n1" here).
    for node in H.nodes(data=True):
        print(node)
        break
    # The edges in the graph have the features: edge_feature and edge_type
    # ("0" - "17" here).
    for edge in H.edges(data=True):
        print(edge)
        break

    hetero = HeteroGraph(H)
    # Rebuild from the tensor attributes only, dropping the backing
    # networkx object before splitting.
    hetero = HeteroGraph(edge_index=hetero.edge_index,
                         edge_feature=hetero.edge_feature,
                         node_feature=hetero.node_feature,
                         directed=hetero.is_directed())

    # edge_message_ratio is only meaningful in disjoint training mode.
    if edge_train_mode == "disjoint":
        dataset = GraphDataset([hetero],
                               task='link_pred',
                               edge_train_mode=edge_train_mode,
                               edge_message_ratio=args.edge_message_ratio)
    else:
        dataset = GraphDataset(
            [hetero],
            task='link_pred',
            edge_train_mode=edge_train_mode,
        )

    dataset_train, dataset_val, dataset_test = dataset.split(
        transductive=True, split_ratio=[0.8, 0.1, 0.1])
    dataloaders = {
        split: DataLoader(ds, collate_fn=Batch.collate(), batch_size=1)
        for split, ds in [('train', dataset_train),
                          ('val', dataset_val),
                          ('test', dataset_test)]
    }

    hidden_size = args.hidden_dim
    conv1, conv2 = generate_2convs_link_pred_layers(hetero, HeteroSAGEConv,
                                                    hidden_size)
    model = HeteroGNN(conv1, conv2, hetero, hidden_size).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    train(model, dataloaders, optimizer, args)
예제 #16
0
def _cmap_out_edges(graph, u):
    """Return all outgoing CMAP edges of ``u`` as (u, v, key) triples.

    CMAP edges are identified by ``edge_type == 1``; PPI edges are
    excluded.
    """
    return [
        (u, v, edge_key)
        for v in graph.successors(u)
        for edge_key in graph[u][v]
        if graph[u][v][edge_key]['edge_type'] == 1
    ]


def main():
    """Set up disjoint link-prediction training on a CMAP/PPI graph.

    Reads the PPI and CMAP graphs, randomly assigns each knockout node's
    CMAP edges to the disjoint-supervision / validation / training sets,
    builds a custom-split ``GraphDataset``, and (after a leftover debug
    ``exit()``) would train the model.
    """
    writer = SummaryWriter()
    args = arg_parse()

    edge_train_mode = args.mode
    print('edge train mode: {}'.format(edge_train_mode))

    ppi_graph = read_ppi_data(args.ppi_path)

    # 'ppi' keeps message passing on the PPI graph alone; 'mixed' merges
    # the CMAP edges into it. Hard-coded to 'mixed' for this experiment.
    mode = 'mixed'
    if mode == 'ppi':
        message_passing_graph = ppi_graph
        cmap_graph, knockout_nodes = read_cmap_data(args.data_path)
    elif mode == 'mixed':
        message_passing_graph, knockout_nodes = (
            read_cmap_data(args.data_path, ppi_graph)
        )

    print('Each node has gene ID. Example: ', message_passing_graph.nodes['ADPGK'])
    print('Each edge has de direction. Example', message_passing_graph['ADPGK']['IL1B'])
    print('Total num edges: ', message_passing_graph.number_of_edges())

    # Per-knockout-node assignment probabilities for the supervision
    # (disjoint) and validation splits; everything else goes to train.
    disjoint_split_ratio = 0.1
    val_ratio = 0.1
    disjoint_edge_label_index = []
    val_edges = []
    train_edges = []

    for u in knockout_nodes:
        rand_num = np.random.rand()
        # cmap is not a multigraph, so the triples are unique per (u, v).
        cmap_edges = _cmap_out_edges(message_passing_graph, u)
        if rand_num < disjoint_split_ratio:
            # Supervision (label) edges; they also remain message-passing
            # edges of the training split.
            disjoint_edge_label_index.extend(cmap_edges)
            train_edges.extend(cmap_edges)
        elif rand_num < disjoint_split_ratio + val_ratio:
            val_edges.extend(cmap_edges)
        else:
            train_edges.extend(cmap_edges)

    # add default node types for message_passing_graph
    for node in message_passing_graph.nodes:
        message_passing_graph.nodes[node]['node_type'] = 0

    print('Num edges to predict: ', len(disjoint_edge_label_index))
    print('Num edges in val: ', len(val_edges))
    print('Num edges in train: ', len(train_edges))

    graph = HeteroGraph(
        message_passing_graph,
        custom={
            "general_splits": [
                train_edges,
                val_edges
            ],
            "disjoint_split": disjoint_edge_label_index,
            "task": "link_pred"
        }
    )

    graphs = [graph]
    graphDataset = GraphDataset(
        graphs,
        task="link_pred",
        edge_train_mode="disjoint"
    )

    # Transform dataset
    # de direction (currently using homogeneous graph)
    num_edge_types = 2

    graphDataset = graphDataset.apply_transform(
        cmap_transform, num_edge_types=num_edge_types, deep_copy=False
    )
    print('Number of node features: ', graphDataset.num_node_features())

    # split dataset
    dataset = {}
    dataset['train'], dataset['val'] = graphDataset.split(transductive=True)

    # sanity check
    print(f"dataset['train'][0].edge_label_index.keys(): {dataset['train'][0].edge_label_index.keys()}")
    print(f"dataset['train'][0].edge_label_index[(0, 1, 0)].shape[1]: {dataset['train'][0].edge_label_index[(0, 1, 0)].shape[1]}")
    print(f"dataset['val'][0].edge_label_index.keys(): {dataset['val'][0].edge_label_index.keys()}")
    print(f"dataset['val'][0].edge_label_index[(0, 1, 0)].shape[1]: {dataset['val'][0].edge_label_index[(0, 1, 0)].shape[1]}")
    print(f"len(list(dataset['train'][0].G.edges)): {len(list(dataset['train'][0].G.edges))}")
    print(f"len(list(dataset['val'][0].G.edges)): {len(list(dataset['val'][0].G.edges))}")
    print(f"list(dataset['train'][0].G.edges)[:10]: {list(dataset['train'][0].G.edges)[:10]}")
    print(f"list(dataset['val'][0].G.edges)[:10]: {list(dataset['val'][0].G.edges)[:10]}")


    # node feature dimension
    input_dim = dataset['train'].num_node_features()
    edge_feat_dim = dataset['train'].num_edge_features()
    num_classes = dataset['train'].num_edge_labels()
    print(
        'Node feature dim: {}; edge feature dim: {}; num classes: {}.'.format(
            input_dim, edge_feat_dim, num_classes
        )
    )
    # NOTE(review): leftover debug exit -- everything below is unreachable
    # until it is removed.
    exit()

    # relation type is both used for edge features and edge labels
    model = Net(input_dim, edge_feat_dim, num_classes, args).to(args.device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.001, weight_decay=5e-3
    )
    follow_batch = []  # e.g., follow_batch = ['edge_index']

    dataloaders = {
        split: DataLoader(
            ds, collate_fn=Batch.collate(follow_batch),
            batch_size=1, shuffle=(split == 'train')
        )
        for split, ds in dataset.items()
    }
    print('Graphs after split: ')
    for key, dataloader in dataloaders.items():
        for batch in dataloader:
            print(key, ': ', batch)

    train(model, dataloaders, optimizer, args, writer=writer)
예제 #17
0
    def test_dataset_split_custom(self):
        """Custom (user-supplied) splits across node/edge/link_pred/graph tasks.

        Covers transductive custom splits on a self-defined graph and
        multigraph (including custom disjoint splits), custom splits on
        PyG datasets, and an inductive custom split of whole graphs,
        checking the resulting label-index sizes in each case.
        """
        # transductive split with node task (self defined dataset)
        G, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
            simple_networkx_graph_alphabet())
        Graph.add_edge_attr(G, "edge_feature", edge_x)
        Graph.add_edge_attr(G, "edge_label", edge_y)
        Graph.add_node_attr(G, "node_feature", x)
        Graph.add_node_attr(G, "node_label", y)
        Graph.add_graph_attr(G, "graph_feature", graph_x)
        Graph.add_graph_attr(G, "graph_label", graph_y)

        num_nodes = len(list(G.nodes))
        nodes_train = list(G.nodes)[:int(0.3 * num_nodes)]
        nodes_val = list(G.nodes)[int(0.3 * num_nodes):int(0.6 * num_nodes)]
        nodes_test = list(G.nodes)[int(0.6 * num_nodes):]
        graph = Graph(G,
                      custom_splits=[nodes_train, nodes_val, nodes_test],
                      task="node")
        graphs = [graph]
        dataset = GraphDataset(
            graphs,
            task="node",
            general_split_mode="custom",
        )

        split_res = dataset.split(transductive=True)
        self.assertEqual(split_res[0][0].node_label_index,
                         list(range(int(0.3 * num_nodes))))
        self.assertEqual(
            split_res[1][0].node_label_index,
            list(range(int(0.3 * num_nodes), int(0.6 * num_nodes))))
        self.assertEqual(split_res[2][0].node_label_index,
                         list(range(int(0.6 * num_nodes), num_nodes)))

        # transductive split with link_pred task (disjoint mode) (self defined dataset)
        edges = list(G.edges)
        num_edges = len(edges)
        edges_train = edges[:int(0.3 * num_edges)]
        edges_train_disjoint = edges[:int(0.5 * 0.3 * num_edges)]
        edges_val = edges[int(0.3 * num_edges):int(0.6 * num_edges)]
        edges_test = edges[int(0.6 * num_edges):]
        link_size_list = [
            len(edges_train_disjoint),
            len(edges_val),
            len(edges_test)
        ]
        graph = Graph(G,
                      custom_splits=[edges_train, edges_val, edges_test],
                      custom_disjoint_split=edges_train_disjoint,
                      task="link_pred")

        graphs = [graph]

        dataset = GraphDataset(
            graphs,
            task="link_pred",
            edge_train_mode="disjoint",
            general_split_mode="custom",
            disjoint_split_mode="custom",
        )

        split_res = dataset.split(transductive=True)
        self.assertEqual(split_res[0][0].edge_label_index.shape[1],
                         2 * link_size_list[0])
        self.assertEqual(split_res[1][0].edge_label_index.shape[1],
                         2 * link_size_list[1])
        self.assertEqual(split_res[2][0].edge_label_index.shape[1],
                         2 * link_size_list[2])

        # transductive split with link_pred task (disjoint mode) (self defined disjoint data)
        edges = list(G.edges)
        num_edges = len(edges)
        edges_train = edges[:int(0.7 * num_edges)]
        edges_train_disjoint = edges[:int(0.5 * 0.7 * num_edges)]
        edges_val = edges[int(0.7 * num_edges):]
        link_size_list = [len(edges_train_disjoint), len(edges_val)]

        graph = Graph(G,
                      custom_splits=[
                          edges_train,
                          edges_val,
                      ],
                      custom_disjoint_split=edges_train_disjoint,
                      task="link_pred")

        graphs = [graph]

        dataset = GraphDataset(
            graphs,
            task="link_pred",
            edge_train_mode="disjoint",
            general_split_mode="custom",
            disjoint_split_mode="custom",
        )

        split_res = dataset.split(transductive=True)

        self.assertEqual(split_res[0][0].edge_label_index.shape[1],
                         2 * link_size_list[0])
        self.assertEqual(split_res[1][0].edge_label_index.shape[1],
                         2 * link_size_list[1])

        # transductive split with link_pred task (disjoint mode) (self defined disjoint data) (multigraph) (train/val split)
        G, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
            simple_networkx_multigraph())
        Graph.add_edge_attr(G, "edge_feature", edge_x)
        Graph.add_edge_attr(G, "edge_label", edge_y)
        Graph.add_node_attr(G, "node_feature", x)
        Graph.add_node_attr(G, "node_label", y)
        Graph.add_graph_attr(G, "graph_feature", graph_x)
        Graph.add_graph_attr(G, "graph_label", graph_y)
        edges = list(G.edges)
        num_edges = len(edges)
        edges_train = edges[:int(0.6 * num_edges)]
        edges_train_disjoint = edges[:int(0.6 * 0.2 * num_edges)]
        edges_val = edges[int(0.6 * num_edges):]
        link_size_list = [len(edges_train_disjoint), len(edges_val)]

        graph = Graph(G,
                      custom_splits=[
                          edges_train,
                          edges_val,
                      ],
                      custom_disjoint_split=edges_train_disjoint,
                      task="link_pred")

        graphs = [graph]

        dataset = GraphDataset(
            graphs,
            task="link_pred",
            edge_train_mode="disjoint",
            general_split_mode="custom",
            disjoint_split_mode="custom",
        )

        split_res = dataset.split(transductive=True)

        self.assertEqual(split_res[0][0].edge_label_index.shape[1],
                         2 * link_size_list[0])
        self.assertEqual(split_res[1][0].edge_label_index.shape[1],
                         2 * link_size_list[1])

        # transductive split with link_pred task (disjoint mode) (self defined disjoint data) (multigraph) (train/val/test split)
        G, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
            simple_networkx_multigraph())
        Graph.add_edge_attr(G, "edge_feature", edge_x)
        Graph.add_edge_attr(G, "edge_label", edge_y)
        Graph.add_node_attr(G, "node_feature", x)
        Graph.add_node_attr(G, "node_label", y)
        Graph.add_graph_attr(G, "graph_feature", graph_x)
        Graph.add_graph_attr(G, "graph_label", graph_y)

        edges = list(G.edges)
        num_edges = len(edges)
        edges_train = edges[:int(0.6 * num_edges)]
        edges_train_disjoint = edges[:int(0.6 * 0.2 * num_edges)]
        edges_val = edges[int(0.6 * num_edges):int(0.8 * num_edges)]
        edges_test = edges[int(0.8 * num_edges):]
        link_size_list = [
            len(edges_train_disjoint),
            len(edges_val),
            len(edges_test)
        ]

        graph = Graph(G,
                      custom_splits=[
                          edges_train,
                          edges_val,
                          edges_test,
                      ],
                      custom_disjoint_split=edges_train_disjoint,
                      task="link_pred")

        graphs = [graph]

        dataset = GraphDataset(
            graphs,
            task="link_pred",
            edge_train_mode="disjoint",
            general_split_mode="custom",
            disjoint_split_mode="custom",
        )

        split_res = dataset.split(transductive=True)

        self.assertEqual(split_res[0][0].edge_label_index.shape[1],
                         2 * link_size_list[0])
        self.assertEqual(split_res[1][0].edge_label_index.shape[1],
                         2 * link_size_list[1])
        self.assertEqual(split_res[2][0].edge_label_index.shape[1],
                         2 * link_size_list[2])

        # transductive split with node task (pytorch geometric dataset)
        pyg_dataset = Planetoid("./cora", "Cora")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        split_ratio = [0.3, 0.3, 0.4]

        node_size_list = [0 for i in range(len(split_ratio))]
        for graph in graphs:
            custom_splits = [[] for i in range(len(split_ratio))]
            split_offset = 0
            shuffled_node_indices = torch.randperm(graph.num_nodes)
            for i, split_ratio_i in enumerate(split_ratio):
                if i != len(split_ratio) - 1:
                    num_split_i = (1 +
                                   int(split_ratio_i *
                                       (graph.num_nodes - len(split_ratio))))
                    nodes_split_i = (
                        shuffled_node_indices[split_offset:split_offset +
                                              num_split_i])
                    split_offset += num_split_i
                else:
                    nodes_split_i = shuffled_node_indices[split_offset:]

                custom_splits[i] = nodes_split_i
                node_size_list[i] += len(nodes_split_i)
            graph.custom_splits = custom_splits

        dataset = GraphDataset(
            graphs,
            task="node",
            general_split_mode="custom",
        )

        split_res = dataset.split(transductive=True)
        self.assertEqual(len(split_res[0][0].node_label_index),
                         node_size_list[0])
        self.assertEqual(len(split_res[1][0].node_label_index),
                         node_size_list[1])
        self.assertEqual(len(split_res[2][0].node_label_index),
                         node_size_list[2])

        # transductive split with edge task
        pyg_dataset = Planetoid("./cora", "Cora")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        split_ratio = [0.3, 0.3, 0.4]
        edge_size_list = [0 for i in range(len(split_ratio))]
        for graph in graphs:
            custom_splits = [[] for i in range(len(split_ratio))]
            split_offset = 0
            edges = list(graph.G.edges)
            random.shuffle(edges)
            for i, split_ratio_i in enumerate(split_ratio):
                if i != len(split_ratio) - 1:
                    num_split_i = (1 +
                                   int(split_ratio_i *
                                       (graph.num_edges - len(split_ratio))))
                    edges_split_i = (edges[split_offset:split_offset +
                                           num_split_i])
                    split_offset += num_split_i
                else:
                    edges_split_i = edges[split_offset:]

                custom_splits[i] = edges_split_i
                edge_size_list[i] += len(edges_split_i)
            graph.custom_splits = custom_splits

        dataset = GraphDataset(
            graphs,
            task="edge",
            general_split_mode="custom",
        )
        split_res = dataset.split(transductive=True)
        self.assertEqual(split_res[0][0].edge_label_index.shape[1],
                         2 * edge_size_list[0])
        self.assertEqual(split_res[1][0].edge_label_index.shape[1],
                         2 * edge_size_list[1])
        self.assertEqual(split_res[2][0].edge_label_index.shape[1],
                         2 * edge_size_list[2])

        # transductive split with link_pred task
        pyg_dataset = Planetoid("./cora", "Cora")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        split_ratio = [0.3, 0.3, 0.4]
        link_size_list = [0 for i in range(len(split_ratio))]

        for graph in graphs:
            split_offset = 0
            edges = list(graph.G.edges)
            random.shuffle(edges)
            num_edges_train = 1 + int(split_ratio[0] * (graph.num_edges - 3))
            # use the val ratio here (was split_ratio[0]; equivalent only
            # because both ratios happen to be 0.3)
            num_edges_val = 1 + int(split_ratio[1] * (graph.num_edges - 3))
            edges_train = edges[:num_edges_train]
            edges_val = edges[num_edges_train:num_edges_train + num_edges_val]
            edges_test = edges[num_edges_train + num_edges_val:]

            custom_splits = [
                edges_train,
                edges_val,
                edges_test,
            ]
            graph.custom_splits = custom_splits

            link_size_list[0] += len(edges_train)
            link_size_list[1] += len(edges_val)
            link_size_list[2] += len(edges_test)

        dataset = GraphDataset(
            graphs,
            task="link_pred",
            general_split_mode="custom",
        )
        split_res = dataset.split(transductive=True)
        self.assertEqual(split_res[0][0].edge_label_index.shape[1],
                         2 * 2 * link_size_list[0])
        self.assertEqual(split_res[1][0].edge_label_index.shape[1],
                         2 * 2 * link_size_list[1])
        self.assertEqual(split_res[2][0].edge_label_index.shape[1],
                         2 * 2 * link_size_list[2])

        # inductive split with graph task
        pyg_dataset = TUDataset("./enzymes", "ENZYMES")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        num_graphs = len(graphs)
        split_ratio = [0.3, 0.3, 0.4]
        graph_size_list = []
        split_offset = 0
        custom_split_graphs = []
        for i, split_ratio_i in enumerate(split_ratio):
            if i != len(split_ratio) - 1:
                num_split_i = (1 + int(split_ratio_i *
                                       (num_graphs - len(split_ratio))))
                custom_split_graphs.append(graphs[split_offset:split_offset +
                                                  num_split_i])
                split_offset += num_split_i
                graph_size_list.append(num_split_i)
            else:
                custom_split_graphs.append(graphs[split_offset:])
                graph_size_list.append(len(graphs[split_offset:]))
        dataset = GraphDataset(
            graphs,
            task="graph",
            general_split_mode="custom",
            custom_split_graphs=custom_split_graphs,
        )
        split_res = dataset.split(transductive=False)
        self.assertEqual(graph_size_list[0], len(split_res[0]))
        self.assertEqual(graph_size_list[1], len(split_res[1]))
        self.assertEqual(graph_size_list[2], len(split_res[2]))

        # transductive split with link_pred task in `disjoint` edge_train_mode.
        pyg_dataset = Planetoid("./cora", "Cora")
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset)
        split_ratio = [0.3, 0.3, 0.4]
        link_size_list = [0 for i in range(len(split_ratio))]

        for graph in graphs:
            split_offset = 0
            edges = list(graph.G.edges)
            random.shuffle(edges)
            num_edges_train = 1 + int(split_ratio[0] * (graph.num_edges - 3))
            num_edges_train_disjoint = (1 + int(split_ratio[0] * 0.5 *
                                                (graph.num_edges - 3)))
            # use the val ratio here (was split_ratio[0]; equivalent only
            # because both ratios happen to be 0.3)
            num_edges_val = 1 + int(split_ratio[1] * (graph.num_edges - 3))

            edges_train = edges[:num_edges_train]
            edges_train_disjoint = edges[:num_edges_train_disjoint]
            edges_val = edges[num_edges_train:num_edges_train + num_edges_val]
            edges_test = edges[num_edges_train + num_edges_val:]

            custom_splits = [
                edges_train,
                edges_val,
                edges_test,
            ]
            graph.custom_splits = custom_splits
            graph.custom_disjoint_split = edges_train_disjoint

            link_size_list[0] += len(edges_train_disjoint)
            link_size_list[1] += len(edges_val)
            link_size_list[2] += len(edges_test)

        dataset = GraphDataset(
            graphs,
            task="link_pred",
            edge_train_mode="disjoint",
            general_split_mode="custom",
            disjoint_split_mode="custom",
        )
        split_res = dataset.split(transductive=True)
        self.assertEqual(split_res[0][0].edge_label_index.shape[1],
                         2 * 2 * link_size_list[0])
        self.assertEqual(split_res[1][0].edge_label_index.shape[1],
                         2 * 2 * link_size_list[1])
        self.assertEqual(split_res[2][0].edge_label_index.shape[1],
                         2 * 2 * link_size_list[2])
예제 #18
0
def main():
    """Link prediction on the BioSNAP-FF GO-term graph.

    Builds a networkx graph from the two MINER tsv files, wraps it in a
    DeepSNAP ``GraphDataset``, splits it for link prediction (optionally
    replicated as multiple graphs for an inductive split), and trains
    with SGD plus a cosine-annealed learning rate.
    """
    args = arg_parse()

    f = 'minerff.tsv'
    f2 = 'minerf.tsv'
    d = readFilePD(f, ['relation'])
    d2 = readFilePD(f2, ['namespace'])
    nxg = pdToNx3(d, d2, 'GO_id0', 'GO_id2', 'relation', 'GO_id1', 'namespace')
    dg = Graph(nxg)
    graphs = [dg]

    # the input that we assume users have
    edge_train_mode = args.mode
    print('edge train mode: {}'.format(edge_train_mode))

    # Replicate the graph to exercise the inductive (multi-graph) path.
    if args.multigraph:
        graphs = [copy.deepcopy(graphs[0]) for _ in range(10)]

    dataset = GraphDataset(graphs,
                           task='link_pred',
                           edge_message_ratio=args.edge_message_ratio,
                           edge_train_mode=edge_train_mode)
    print('Initial dataset: {}'.format(dataset))

    # split dataset: transductive for a single graph, inductive when
    # the graph was replicated above.
    datasets = {}
    datasets['train'], datasets['val'], datasets['test'] = dataset.split(
        transductive=not args.multigraph, split_ratio=[0.85, 0.05, 0.1])

    print('after split')
    print('Train message-passing graph: {} nodes; {} edges.'.format(
        datasets['train'][0].G.number_of_nodes(),
        datasets['train'][0].G.number_of_edges()))
    print('Val message-passing graph: {} nodes; {} edges.'.format(
        datasets['val'][0].G.number_of_nodes(),
        datasets['val'][0].G.number_of_edges()))
    print('Test message-passing graph: {} nodes; {} edges.'.format(
        datasets['test'][0].G.number_of_nodes(),
        datasets['test'][0].G.number_of_edges()))

    # Node feature dimension is hard-coded; presumably equals
    # datasets['train'].num_node_features for this dataset -- TODO confirm.
    input_dim = 47410
    # link prediction needs 2 classes (0, 1)
    num_classes = datasets['train'].num_edge_labels

    model = Net(input_dim, num_classes, args).to(args.device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=args.epochs)
    follow_batch = []  # e.g., follow_batch = ['edge_index']

    dataloaders = {
        split: DataLoader(ds,
                          collate_fn=Batch.collate(follow_batch),
                          batch_size=args.batch_size,
                          shuffle=(split == 'train'))
        for split, ds in datasets.items()
    }
    print('Graphs after split: ')
    for key, dataloader in dataloaders.items():
        for batch in dataloader:
            print(key, ': ', batch)

    train(model, dataloaders, optimizer, args, scheduler=scheduler)
예제 #19
0
if __name__ == "__main__":
    args = arg_parse()

    # Select the TU benchmark requested on the command line.
    if args.dataset == 'enzymes':
        pyg_dataset = TUDataset('./enzymes', 'ENZYMES')
    elif args.dataset == 'dd':
        pyg_dataset = TUDataset('./dd', 'DD')
    else:
        raise ValueError("Unsupported dataset.")

    graphs = GraphDataset.pyg_to_graphs(pyg_dataset)

    dataset = GraphDataset(graphs, task="graph")
    datasets = {}
    # Inductive split: whole graphs are assigned to train/val/test.
    datasets['train'], datasets['val'], datasets['test'] = dataset.split(
        transductive=False, split_ratio=[0.8, 0.1, 0.1])
    # Distinct loop variable so the full `dataset` above is not shadowed
    # inside the comprehension.
    dataloaders = {
        split: DataLoader(split_dataset,
                          collate_fn=Batch.collate(),
                          batch_size=args.batch_size,
                          shuffle=True)
        for split, split_dataset in datasets.items()
    }

    num_classes = datasets['train'].num_graph_labels
    num_node_features = datasets['train'].num_node_features

    train(dataloaders['train'], dataloaders['val'], dataloaders['test'], args,
          num_node_features, num_classes, args.device)
예제 #20
0
        import snap
        import snapx as netlib
        print("Use SnapX as the backend network library.")
    else:
        raise ValueError("{} network library is not supported.".format(
            args.netlib))

    if args.split == 'random':
        graphs = GraphDataset.pyg_to_graphs(pyg_dataset,
                                            verbose=True,
                                            fixed_split=False,
                                            netlib=netlib)
        dataset = GraphDataset(graphs,
                               task='node')  # node, edge, link_pred, graph
        dataset_train, dataset_val, dataset_test = dataset.split(
            transductive=True,
            split_ratio=[0.8, 0.1, 0.1])  # transductive split, inductive split
    else:
        graphs_train, graphs_val, graphs_test = \
            GraphDataset.pyg_to_graphs(pyg_dataset, verbose=True,
                    fixed_split=True, netlib=netlib)

        dataset_train, dataset_val, dataset_test = \
            GraphDataset(graphs_train, task='node'), GraphDataset(graphs_val,task='node'), \
            GraphDataset(graphs_test, task='node')

    train_loader = DataLoader(dataset_train,
                              collate_fn=Batch.collate(),
                              batch_size=16)  # basic data loader
    val_loader = DataLoader(dataset_val,
                            collate_fn=Batch.collate(),
    def test_secure_split(self):
        """Verify split() guarantees non-empty train/val/test partitions."""
        G = simple_networkx_small_graph()
        base = Graph(G)
        # Re-wrap as a tensor-backed directed graph (no networkx object).
        graph = Graph(node_label=base.node_label,
                      edge_index=base.edge_index,
                      edge_label=base.edge_label,
                      directed=True)
        graphs = [graph]

        def expected_sizes(total, factor=1):
            # Mirrors the "secure split" arithmetic: reserve 3 items, give
            # each of the first two splits at least one, remainder to last.
            reduced = total - 3
            first = factor * (1 + int(0.8 * reduced))
            second = factor * (1 + int(0.1 * reduced))
            return [first, second, factor * total - first - second]

        # node task
        dataset = GraphDataset(graphs, task="node")
        for res, size in zip(dataset.split(),
                             expected_sizes(dataset.num_nodes[0])):
            self.assertEqual(res[0].node_label_index.shape[0], size)
            self.assertEqual(res[0].node_label.shape[0], size)

        # edge task
        dataset = GraphDataset(graphs, task="edge")
        for res, size in zip(dataset.split(),
                             expected_sizes(dataset.num_edges[0])):
            self.assertEqual(res[0].edge_label_index.shape[1], size)
            self.assertEqual(res[0].edge_label.shape[0], size)

        # link_pred task: every split size is doubled.
        dataset = GraphDataset(graphs, task="link_pred")
        for res, size in zip(dataset.split(),
                             expected_sizes(dataset.num_edges[0], factor=2)):
            self.assertEqual(res[0].edge_label_index.shape[1], size)
            self.assertEqual(res[0].edge_label.shape[0], size)

        # graph task: inductive split over copies of the same graph.
        graphs = [deepcopy(graph) for _ in range(5)]
        dataset = GraphDataset(graphs, task="link_pred")
        n_graphs = len(dataset)
        n_reduced = n_graphs - 3
        n_train = 1 + int(n_reduced * 0.8)
        n_val = 1 + int(n_reduced * 0.1)
        n_test = n_graphs - n_train - n_val
        split_res = dataset.split(transductive=False)
        self.assertEqual(n_train, len(split_res[0]))
        self.assertEqual(n_val, len(split_res[1]))
        self.assertEqual(n_test, len(split_res[2]))
예제 #22
0
    def test_dataset_hetero_graph_split(self):
        """Verify split sizes for a heterogeneous graph across tasks.

        Covers the "node", "edge" and "link_pred" tasks, splitting either
        every type or only user-specified node/edge types (given both as a
        list and as a plain string), plus disjoint link_pred training and
        the "approximate" edge_split_mode.

        The expected counts throughout follow one splitting rule: each of
        the first two splits receives ``1 + int(ratio * (N - 3))`` items
        (guaranteeing at least one item per split) and the last split takes
        the remainder.
        """
        G = generate_dense_hete_dataset()
        hete = HeteroGraph(G)
        # node task: default 80/10/10 split applied to every node type
        dataset = GraphDataset([hete], task='node')
        split_res = dataset.split()
        for node_type in hete.node_label_index:
            num_nodes = int(len(hete.node_label_index[node_type]))
            num_nodes_reduced = num_nodes - 3
            node_0 = 1 + int(num_nodes_reduced * 0.8)
            node_1 = 1 + int(num_nodes_reduced * 0.1)
            node_2 = num_nodes - node_0 - node_1

            self.assertEqual(
                len(split_res[0][0].node_label_index[node_type]), node_0)

            self.assertEqual(
                len(split_res[1][0].node_label_index[node_type]), node_1)

            self.assertEqual(
                len(split_res[2][0].node_label_index[node_type]), node_2)

        # node with specified split type: non-listed types keep all their
        # nodes in every split
        dataset = GraphDataset([hete], task='node')
        node_split_types = ['n1']
        split_res = dataset.split(split_types=node_split_types)
        for node_type in hete.node_label_index:
            if node_type in node_split_types:
                num_nodes = int(len(hete.node_label_index[node_type]))
                num_nodes_reduced = num_nodes - 3
                node_0 = 1 + int(num_nodes_reduced * 0.8)
                node_1 = 1 + int(num_nodes_reduced * 0.1)
                node_2 = num_nodes - node_0 - node_1
                self.assertEqual(
                    len(split_res[0][0].node_label_index[node_type]), node_0)

                self.assertEqual(
                    len(split_res[1][0].node_label_index[node_type]), node_1)

                self.assertEqual(
                    len(split_res[2][0].node_label_index[node_type]), node_2)
            else:
                num_nodes = int(len(hete.node_label_index[node_type]))
                self.assertEqual(
                    len(split_res[0][0].node_label_index[node_type]), num_nodes)

                self.assertEqual(
                    len(split_res[1][0].node_label_index[node_type]), num_nodes)

                self.assertEqual(
                    len(split_res[2][0].node_label_index[node_type]), num_nodes)

        # node with specified split type (string mode): a bare string must
        # behave like a one-element list of types
        dataset = GraphDataset([hete], task='node')
        node_split_types = 'n1'
        split_res = dataset.split(split_types=node_split_types)
        for node_type in hete.node_label_index:
            if node_type in node_split_types:
                num_nodes = int(len(hete.node_label_index[node_type]))
                num_nodes_reduced = num_nodes - 3
                node_0 = 1 + int(num_nodes_reduced * 0.8)
                node_1 = 1 + int(num_nodes_reduced * 0.1)
                node_2 = num_nodes - node_0 - node_1
                self.assertEqual(
                    len(split_res[0][0].node_label_index[node_type]), node_0)

                self.assertEqual(
                    len(split_res[1][0].node_label_index[node_type]), node_1)

                self.assertEqual(
                    len(split_res[2][0].node_label_index[node_type]), node_2)
            else:
                num_nodes = int(len(hete.node_label_index[node_type]))
                self.assertEqual(
                    len(split_res[0][0].node_label_index[node_type]), num_nodes)

                self.assertEqual(
                    len(split_res[1][0].node_label_index[node_type]), num_nodes)

                self.assertEqual(
                    len(split_res[2][0].node_label_index[node_type]), num_nodes)

        # edge task: same split rule, counted per edge type on dim 1 of
        # edge_label_index
        dataset = GraphDataset([hete], task='edge')
        split_res = dataset.split()
        for edge_type in hete.edge_label_index:
            num_edges = hete.edge_label_index[edge_type].shape[1]
            num_edges_reduced = num_edges - 3
            edge_0 = 1 + int(num_edges_reduced * 0.8)
            edge_1 = 1 + int(num_edges_reduced * 0.1)
            edge_2 = num_edges - edge_0 - edge_1
            self.assertEqual(
                split_res[0][0].edge_label_index[edge_type].shape[1], edge_0)

            self.assertEqual(
                split_res[1][0].edge_label_index[edge_type].shape[1], edge_1)

            self.assertEqual(
                split_res[2][0].edge_label_index[edge_type].shape[1], edge_2)

        # edge with specified split type
        dataset = GraphDataset([hete], task='edge')
        edge_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]
        split_res = dataset.split(split_types=edge_split_types)
        for edge_type in hete.edge_label_index:
            if edge_type in edge_split_types:
                num_edges = hete.edge_label_index[edge_type].shape[1]
                num_edges_reduced = num_edges - 3
                edge_0 = 1 + int(num_edges_reduced * 0.8)
                edge_1 = 1 + int(num_edges_reduced * 0.1)
                edge_2 = num_edges - edge_0 - edge_1
                self.assertEqual(
                    split_res[0][0].edge_label_index[edge_type].shape[1], edge_0)

                self.assertEqual(
                    split_res[1][0].edge_label_index[edge_type].shape[1], edge_1)

                self.assertEqual(
                    split_res[2][0].edge_label_index[edge_type].shape[1], edge_2)
            else:
                num_edges = hete.edge_label_index[edge_type].shape[1]
                self.assertEqual(
                    split_res[0][0].edge_label_index[edge_type].shape[1], num_edges)

                self.assertEqual(
                    split_res[1][0].edge_label_index[edge_type].shape[1], num_edges)

                self.assertEqual(
                    split_res[2][0].edge_label_index[edge_type].shape[1], num_edges)

        # link_pred: expected counts carry a factor of 2 — presumably
        # positive edges plus sampled negatives (TODO confirm against
        # GraphDataset's default edge_negative_sampling_ratio)
        dataset = GraphDataset([hete], task='link_pred')
        split_res = dataset.split(transductive=True)
        for edge_type in hete.edge_label_index:
            num_edges = hete.edge_label_index[edge_type].shape[1]
            num_edges_reduced = num_edges - 3
            self.assertEqual(split_res[0][0].edge_label_index[edge_type].shape[1],
                             (2 * (1 + int(0.8 * (num_edges_reduced)))))
            self.assertEqual(split_res[1][0].edge_label_index[edge_type].shape[1],
                             (2 * (1 + (int(0.1 * (num_edges_reduced))))))
            self.assertEqual(split_res[2][0].edge_label_index[edge_type].shape[1],
                             2 * num_edges - 2 * (2 + int(0.1 * num_edges_reduced) +
                                                  int(0.8 * num_edges_reduced)))

        # link_pred with specified split type: non-listed types keep all
        # their edges (factor 1, no negatives) in every split
        dataset = GraphDataset([hete], task='link_pred')
        link_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]
        split_res = dataset.split(transductive=True, split_types=link_split_types)

        for edge_type in hete.edge_label_index:
            if edge_type in link_split_types:
                num_edges = hete.edge_label_index[edge_type].shape[1]
                num_edges_reduced = num_edges - 3
                self.assertEqual(split_res[0][0].edge_label_index[edge_type].shape[1],
                                 (2 * (1 + int(0.8 * (num_edges_reduced)))))
                self.assertEqual(split_res[1][0].edge_label_index[edge_type].shape[1],
                                 (2 * (1 + (int(0.1 * (num_edges_reduced))))))
                self.assertEqual(split_res[2][0].edge_label_index[edge_type].shape[1],
                                 2 * num_edges - 2 * (2 + int(0.1 * num_edges_reduced) +
                                                      int(0.8 * num_edges_reduced)))
            else:
                num_edges = hete.edge_label_index[edge_type].shape[1]
                self.assertEqual(split_res[0][0].edge_label_index[edge_type].shape[1],
                                 (1 * (0 + int(1.0 * (num_edges)))))
                self.assertEqual(split_res[1][0].edge_label_index[edge_type].shape[1],
                                 (1 * (0 + (int(1.0 * (num_edges))))))
                self.assertEqual(split_res[2][0].edge_label_index[edge_type].shape[1],
                                 1 * (0 + (int(1.0 * (num_edges)))))

        # link_pred + disjoint: with edge_message_ratio=0.5 the train-split
        # supervision edges are the train edges minus the message edges,
        # which is what the nested formula for edge_0 computes
        dataset = GraphDataset([hete], task='link_pred', edge_train_mode='disjoint', edge_message_ratio=0.5)
        split_res = dataset.split(transductive=True, split_ratio=[0.6, 0.2, 0.2])
        for edge_type in hete.edge_label_index:
            num_edges = hete.edge_label_index[edge_type].shape[1]
            num_edges_reduced = num_edges - 3
            edge_0 = (1 + int(0.6 * num_edges_reduced))
            edge_0 = 2 * (edge_0 - (1 + int(0.5 * (edge_0 - 2))))

            self.assertEqual(split_res[0][0].edge_label_index[edge_type].shape[1], edge_0)
            edge_1 = 2 * (1 + int(0.2 * num_edges_reduced))
            self.assertEqual(split_res[1][0].edge_label_index[edge_type].shape[1], edge_1)
            edge_2 = 2 * int(num_edges) - \
                (2 * (1 + int(0.6 * num_edges_reduced))) - edge_1
            self.assertEqual(split_res[2][0].edge_label_index[edge_type].shape[1], edge_2)

        # link pred with edge_split_mode set to "approximate": per-type
        # counts may vary, so only the totals summed over all edge types
        # are asserted (NOTE: the original comment wrongly said "exact")
        dataset = GraphDataset([hete], task='link_pred', edge_split_mode="approximate")
        split_res = dataset.split(transductive=True)
        hete_link_train_edge_num = 0
        hete_link_test_edge_num = 0
        hete_link_val_edge_num = 0
        num_edges = 0
        for edge_type in hete.edge_label_index:
            num_edges += hete.edge_label_index[edge_type].shape[1]
            if edge_type in split_res[0][0].edge_label_index:
                hete_link_train_edge_num += split_res[0][0].edge_label_index[edge_type].shape[1]
            if edge_type in split_res[1][0].edge_label_index:
                hete_link_test_edge_num += split_res[1][0].edge_label_index[edge_type].shape[1]
            if edge_type in split_res[2][0].edge_label_index:
                hete_link_val_edge_num += split_res[2][0].edge_label_index[edge_type].shape[1]

        num_edges_reduced = num_edges - 3
        self.assertEqual(hete_link_train_edge_num,
                         (2 * (1 + int(0.8 * (num_edges_reduced)))))
        self.assertEqual(hete_link_test_edge_num,
                         (2 * (1 + (int(0.1 * (num_edges_reduced))))))
        self.assertEqual(hete_link_val_edge_num,
                         2 * num_edges - 2 * (2 + int(0.1 * num_edges_reduced) +
                                              int(0.8 * num_edges_reduced)))

        # link pred with specified types and edge_split_mode set to
        # "approximate": split-type edges follow the split rule in total,
        # non-split-type edges contribute their full count to every split
        dataset = GraphDataset([hete], task='link_pred', edge_split_mode="approximate")
        link_split_types = [('n1', 'e1', 'n1'), ('n1', 'e2', 'n2')]
        split_res = dataset.split(transductive=True, split_types=link_split_types)
        hete_link_train_edge_num = 0
        hete_link_test_edge_num = 0
        hete_link_val_edge_num = 0

        num_split_type_edges = 0
        num_non_split_type_edges = 0
        for edge_type in hete.edge_label_index:
            if edge_type in link_split_types:
                num_split_type_edges += hete.edge_label_index[edge_type].shape[1]
            else:
                num_non_split_type_edges += hete.edge_label_index[edge_type].shape[1]
            if edge_type in split_res[0][0].edge_label_index:
                hete_link_train_edge_num += split_res[0][0].edge_label_index[edge_type].shape[1]
            if edge_type in split_res[1][0].edge_label_index:
                hete_link_test_edge_num += split_res[1][0].edge_label_index[edge_type].shape[1]
            if edge_type in split_res[2][0].edge_label_index:
                hete_link_val_edge_num += split_res[2][0].edge_label_index[edge_type].shape[1]

        num_edges_reduced = num_split_type_edges - 3
        edge_0 = 2 * (1 + int(0.8 * (num_edges_reduced))) + num_non_split_type_edges
        edge_1 = 2 * (1 + int(0.1 * (num_edges_reduced))) + num_non_split_type_edges
        edge_2 = 2 * num_split_type_edges - 2 * (2 + int(0.1 * num_edges_reduced) + \
                                                 int(0.8 * num_edges_reduced)) + num_non_split_type_edges

        self.assertEqual(hete_link_train_edge_num, edge_0)
        self.assertEqual(hete_link_test_edge_num, edge_1)
        self.assertEqual(hete_link_val_edge_num, edge_2)
예제 #23
0
def main():
    """Run the WordNet link-prediction pipeline.

    Loads the pickled WordNet graph, wraps it in a DeepSNAP ``link_pred``
    GraphDataset, applies the WN relation-type transform, splits it
    80/10/10 transductively, builds dataloaders, and trains the model.
    """
    args = arg_parse()

    edge_train_mode = args.mode
    print('edge train mode: {}'.format(edge_train_mode))

    WN_graph = nx.read_gpickle(args.data_path)
    print('Each node has node ID (n_id). Example: ', WN_graph.nodes[0])
    print(
        'Each edge has edge ID (id) and categorical label (e_label). Example: ',
        WN_graph[0][5871])

    # Since both feature and label are relation types,
    # Only the disjoint mode would make sense
    dataset = GraphDataset(
        [WN_graph],
        task='link_pred',
        edge_train_mode=edge_train_mode,
        edge_message_ratio=args.edge_message_ratio,
        edge_negative_sampling_ratio=args.neg_sampling_ratio)

    # Count distinct edge label values. A set gives O(1) membership instead
    # of the original O(n) list scan per edge (and the unused max_label
    # accumulator has been dropped).
    labels = set()
    for u, v, edge_key in WN_graph.edges:
        labels.add(WN_graph[u][v][edge_key]['e_label'])
    # labels are consecutive (0-17)
    num_edge_types = len(labels)

    print('Pre-transform: ', dataset[0])
    dataset = dataset.apply_transform(WN_transform,
                                      num_edge_types=num_edge_types,
                                      deep_copy=False)
    print('Post-transform: ', dataset[0])
    print('Initial data: {} nodes; {} edges.'.format(
        dataset[0].G.number_of_nodes(), dataset[0].G.number_of_edges()))
    print('Number of node features: {}'.format(dataset.num_node_features))

    # split dataset
    datasets = {}
    datasets['train'], datasets['val'], datasets['test'] = dataset.split(
        transductive=True, split_ratio=[0.8, 0.1, 0.1])

    print('After split:')
    print('Train message-passing graph: {} nodes; {} edges.'.format(
        datasets['train'][0].G.number_of_nodes(),
        datasets['train'][0].G.number_of_edges()))
    print('Val message-passing graph: {} nodes; {} edges.'.format(
        datasets['val'][0].G.number_of_nodes(),
        datasets['val'][0].G.number_of_edges()))
    print('Test message-passing graph: {} nodes; {} edges.'.format(
        datasets['test'][0].G.number_of_nodes(),
        datasets['test'][0].G.number_of_edges()))

    # node feature dimension
    input_dim = datasets['train'].num_node_features
    edge_feat_dim = datasets['train'].num_edge_features
    num_classes = datasets['train'].num_edge_labels
    print(
        'Node feature dim: {}; edge feature dim: {}; num classes: {}.'.format(
            input_dim, edge_feat_dim, num_classes))

    # relation type is both used for edge features and edge labels
    model = Net(input_dim, edge_feat_dim, num_classes, args).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=0.001,
                                 weight_decay=5e-3)
    follow_batch = []  # e.g., follow_batch = ['edge_index']

    # Only shuffle the training split; val/test order must stay stable.
    dataloaders = {
        split: DataLoader(ds,
                          collate_fn=Batch.collate(follow_batch),
                          batch_size=1,
                          shuffle=(split == 'train'))
        for split, ds in datasets.items()
    }
    print('Graphs after split: ')
    for key, dataloader in dataloaders.items():
        for batch in dataloader:
            print(key, ': ', batch)

    train(model, dataloaders, optimizer, args)
    def __init__(self,
                 res_graph_path,
                 user_graph_path,
                 user_per_res_path,
                 split=None):
        """Load the restaurant and user graphs and the visited-users table.

        Args:
            res_graph_path: gpickle path of the restaurant graph, e.g.
                "../graphs/restaurants_with_categories.gpickle".
            user_graph_path: gpickle path of the user graph, e.g.
                "../graphs/2017-2018_user_network.gpickle".
            user_per_res_path: CSV of visited users per restaurant, e.g.
                "../datasets/2017-2018_visited_users.csv"; must contain
                "business_id" and "user_ids" columns, the latter holding
                stringified Python lists.
            split: train/val/test ratios for the transductive node split;
                defaults to [0.8, 0.1, 0.1]. (Default moved out of the
                signature to avoid a shared mutable default argument.)
        """
        if split is None:
            split = [0.8, 0.1, 0.1]

        # 1. restaurant graphs
        self.res_G = nx.read_gpickle(res_graph_path)

        print(f"Number of restaurants: {self.res_G.number_of_nodes()}")
        print(f"Number of neighbors: {self.res_G.number_of_edges()}")

        # Bidirectional index <-> node-id maps for the restaurant graph.
        self.res_idx2node = dict(enumerate(self.res_G.nodes()))
        self.res_node2idx = {
            node: idx
            for idx, node in self.res_idx2node.items()
        }

        print("converting restaurant graph to pyg graph...", end=" ")
        self.res_pyg_graph = Graph(self.res_G)
        self.res_pyg_graph.node_label = torch.LongTensor(
            self.res_pyg_graph.node_label)
        print("done!")

        # 2. user graph
        self.user_G = nx.read_gpickle(user_graph_path)

        print(f"Number of users: {self.user_G.number_of_nodes()}")
        print(f"Number of friends: {self.user_G.number_of_edges()}")

        self.user_idx2node = dict(enumerate(self.user_G.nodes()))
        self.user_node2idx = {
            node: idx
            for idx, node in self.user_idx2node.items()
        }

        # Fixed copy-paste bug: this message previously said "restaurant
        # graph" while converting the user graph.
        print("converting user graph to pyg graph...", end=" ")
        self.user_pyg_graph = Graph(self.user_G)
        print("done!")

        # 3. visited users per restaurant
        self.visited_user_df = pd.read_csv(user_per_res_path)
        self.visited_user_df.set_index("business_id", inplace=True)
        # "user_ids" is stored as a stringified list; eval parses it back.
        # NOTE(review): eval on CSV content is unsafe for untrusted data —
        # consider ast.literal_eval.
        self.visited_user_df["user_ids"] = self.visited_user_df[
            "user_ids"].apply(eval)

        # Largest number of visitors any single restaurant has.
        self.max_k = self.visited_user_df["user_ids"].apply(len).max()
        print(self.max_k)

        # split: transductive node split over the restaurant graph only
        dataset = GraphDataset(graphs=[self.res_pyg_graph], task='node')
        dataset_train, dataset_val, dataset_test = dataset.split(
            transductive=True, split_ratio=split, shuffle=True)
        self.train_index = dataset_train.graphs[0].node_label_index
        self.val_index = dataset_val.graphs[0].node_label_index
        self.test_index = dataset_test.graphs[0].node_label_index

        self.res_x = self.res_pyg_graph.node_feature
        self.user_x = self.user_pyg_graph.node_feature
        self.labels = self.res_pyg_graph.node_label