Esempio n. 1
0
def load_data(args):
    '''Wraps the dgl's load_data utility to handle ppi special case.

    For any dataset other than ``'ppi'`` the call is forwarded unchanged to
    ``_load_data(args)``. For PPI, the train/valid/test splits are merged
    into one networkx graph whose nodes are ordered train, then valid, then
    test. The graph is returned together with boolean split masks and the
    node features and labels concatenated in that same order.

    Args:
        args: Parsed arguments. Only ``args.dataset`` is read here; for
            non-PPI datasets the whole object is passed to ``_load_data``.

    Returns:
        For non-PPI datasets, whatever ``_load_data(args)`` returns.
        For PPI, a ``PPIDataset`` namedtuple with fields:
            train_mask, val_mask, test_mask: numpy bool arrays, one entry
                per node of the combined graph.
            features, labels: the per-split arrays concatenated along
                axis 0 (train, val, test).
            num_labels: fixed at 121.
            graph: the combined graph converted with ``to_networkx()``,
                with every edge's attribute dict cleared.
    '''
    if args.dataset != 'ppi':
        return _load_data(args)
    train_dataset = PPIDataset('train')
    val_dataset = PPIDataset('valid')
    test_dataset = PPIDataset('test')
    PPIDataType = namedtuple('PPIDataset', [
        'train_mask', 'test_mask', 'val_mask', 'features', 'labels',
        'num_labels', 'graph'
    ])
    # The batch order (train, val, test) fixes the node order of the combined
    # graph; the masks and concatenations below rely on that same order.
    G = dgl.BatchedDGLGraph(
        [train_dataset.graph, val_dataset.graph, test_dataset.graph],
        edge_attrs=None,
        node_attrs=None)
    G = G.to_networkx()
    # hack to dodge the potential bugs of to_networkx: empty every edge's
    # attribute dict so no edge data is carried in the networkx graph
    for (_, _, edge_data) in G.edges(data=True):
        edge_data.clear()
    train_nodes_num = train_dataset.graph.number_of_nodes()
    val_nodes_num = val_dataset.graph.number_of_nodes()
    test_nodes_num = test_dataset.graph.number_of_nodes()
    nodes_num = G.number_of_nodes()
    assert nodes_num == train_nodes_num + val_nodes_num + test_nodes_num

    # construct mask
    # Use explicit cumulative offsets rather than negative slices. With a
    # negative slice, an empty test split (-0) would make test_mask select
    # every node and val_mask select none.
    val_start = train_nodes_num
    test_start = train_nodes_num + val_nodes_num
    train_mask = np.zeros((nodes_num, ), dtype=bool)
    train_mask[:val_start] = True
    val_mask = np.zeros((nodes_num, ), dtype=bool)
    val_mask[val_start:test_start] = True
    test_mask = np.zeros((nodes_num, ), dtype=bool)
    test_mask[test_start:] = True

    # construct features (same train, val, test order as the graph nodes)
    features = np.concatenate(
        [train_dataset.features, val_dataset.features, test_dataset.features],
        axis=0)

    labels = np.concatenate(
        [train_dataset.labels, val_dataset.labels, test_dataset.labels],
        axis=0)

    data = PPIDataType(graph=G,
                       train_mask=train_mask,
                       test_mask=test_mask,
                       val_mask=val_mask,
                       features=features,
                       labels=labels,
                       num_labels=121)
    return data
Esempio n. 2
0
def load_data(args):
    '''Wraps the dgl's load_data utility to handle ppi special case.

    For any dataset other than ``'ppi'``, ``_load_data(args)`` is called.
    Its first graph and its ``num_classes`` are repackaged into a ``Dataset``
    namedtuple.

    For PPI, every graph of the train, valid and test splits is batched into
    one combined graph, with nodes ordered train, then valid, then test.
    Boolean ``train_mask``/``val_mask``/``test_mask`` tensors are attached to
    the combined graph's ``ndata``.

    Args:
        args: Parsed arguments. Only ``args.dataset`` is read here; for
            non-PPI datasets the whole object is passed to ``_load_data``.

    Returns:
        A ``Dataset`` namedtuple with fields ``g`` (the graph) and
        ``num_classes``. For PPI, ``num_classes`` is taken from
        ``train_dataset.num_labels``.
    '''
    DataType = namedtuple('Dataset', ['num_classes', 'g'])
    if args.dataset != 'ppi':
        dataset = _load_data(args)
        data = DataType(g=dataset[0], num_classes=dataset.num_classes)
        return data

    def _batch_split(split_dataset):
        # Merge every graph of one PPI split into a single batched graph.
        return dgl.batch(
            [split_dataset[i] for i in range(len(split_dataset))],
            edge_attrs=None,
            node_attrs=None)

    # Splits are loaded and batched in the order train, valid, test.
    train_dataset = PPIDataset('train')
    train_graph = _batch_split(train_dataset)
    val_graph = _batch_split(PPIDataset('valid'))
    test_graph = _batch_split(PPIDataset('test'))
    # The batch order (train, val, test) fixes the node order that the masks
    # below rely on.
    G = dgl.batch([train_graph, val_graph, test_graph],
                  edge_attrs=None,
                  node_attrs=None)

    train_nodes_num = train_graph.number_of_nodes()
    val_nodes_num = val_graph.number_of_nodes()
    test_nodes_num = test_graph.number_of_nodes()
    nodes_num = G.number_of_nodes()
    assert nodes_num == train_nodes_num + val_nodes_num + test_nodes_num

    # construct mask
    # Use explicit cumulative offsets rather than negative slices. With a
    # negative slice, an empty test split (-0) would make test_mask select
    # every node and val_mask select none.
    val_start = train_nodes_num
    test_start = train_nodes_num + val_nodes_num
    train_mask = np.zeros((nodes_num, ), dtype=bool)
    train_mask[:val_start] = True
    val_mask = np.zeros((nodes_num, ), dtype=bool)
    val_mask[val_start:test_start] = True
    test_mask = np.zeros((nodes_num, ), dtype=bool)
    test_mask[test_start:] = True

    G.ndata['train_mask'] = torch.tensor(train_mask, dtype=torch.bool)
    G.ndata['val_mask'] = torch.tensor(val_mask, dtype=torch.bool)
    G.ndata['test_mask'] = torch.tensor(test_mask, dtype=torch.bool)

    data = DataType(g=G, num_classes=train_dataset.num_labels)
    return data