def test_dataset_property(self):
    G, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = \
        simple_networkx_graph()
    Graph.add_edge_attr(G, "edge_feature", edge_x)
    Graph.add_edge_attr(G, "edge_label", edge_y)
    Graph.add_node_attr(G, "node_feature", x)
    Graph.add_node_attr(G, "node_label", y)
    Graph.add_graph_attr(G, "graph_feature", graph_x)
    Graph.add_graph_attr(G, "graph_label", graph_y)
    H = G.copy()
    Graph.add_graph_attr(H, "graph_label", torch.tensor([1]))
    graphs = GraphDataset.list_to_graphs([G, H])
    dataset = GraphDataset(graphs)
    self.assertEqual(dataset.num_node_labels, 5)
    self.assertEqual(dataset.num_node_features, 2)
    self.assertEqual(dataset.num_edge_labels, 4)
    self.assertEqual(dataset.num_edge_features, 2)
    self.assertEqual(dataset.num_graph_labels, 2)
    self.assertEqual(dataset.num_graph_features, 2)
    # default task is "node", so num_labels falls back to num_node_labels
    self.assertEqual(dataset.num_labels, 5)
    dataset = GraphDataset(graphs, task="edge")
    self.assertEqual(dataset.num_labels, 4)
    dataset = GraphDataset(graphs, task="link_pred")
    self.assertEqual(dataset.num_labels, 4)
    dataset = GraphDataset(graphs, task="graph")
    self.assertEqual(dataset.num_labels, 2)
def batch_nx_graphs(graphs, anchors=None):
    """Batch a list of networkx graphs into a single DeepSNAP Batch."""
    augmenter = feature_preprocess.FeatureAugment()
    if anchors is not None:
        # Mark the anchor node of each graph with a one-hot node feature.
        for anchor, g in zip(anchors, graphs):
            for v in g.nodes:
                g.nodes[v]["node_feature"] = torch.tensor([float(v == anchor)])
    # Datasets such as AIFB and WN18 carry categorical edge types (up to 90
    # for AIFB); encode them as long tensors for embedding lookup.
    for g in graphs:
        for e in g.edges:
            g.edges[e]["edge_feature"] = torch.tensor(
                [g.edges[e]["edge_type"]], dtype=torch.long)
    batch = Batch.from_data_list(GraphDataset.list_to_graphs(graphs))
    batch = augmenter.augment(batch)
    batch = batch.to(get_device())
    return batch
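# Minimal usage sketch for batch_nx_graphs, assuming graphs whose edges carry
# an integer "edge_type" attribute (required above). The toy graphs and the
# relation id 0 are illustrative, not taken from any dataset.
def _demo_batch_nx_graphs():
    import networkx as nx
    g1, g2 = nx.path_graph(4), nx.cycle_graph(5)
    for g in (g1, g2):
        for e in g.edges:
            g.edges[e]["edge_type"] = 0  # hypothetical relation id
    # Anchor node 0 in each graph; returns a DeepSNAP Batch on get_device().
    return batch_nx_graphs([g1, g2], anchors=[0, 0])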
def test_dataset_basic(self):
    G, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = \
        simple_networkx_graph()
    Graph.add_edge_attr(G, "edge_feature", edge_x)
    Graph.add_edge_attr(G, "edge_label", edge_y)
    Graph.add_node_attr(G, "node_feature", x)
    Graph.add_node_attr(G, "node_label", y)
    Graph.add_graph_attr(G, "graph_feature", graph_x)
    Graph.add_graph_attr(G, "graph_label", graph_y)
    H = deepcopy(G)
    graphs = GraphDataset.list_to_graphs([G, H])
    dataset = GraphDataset(graphs)
    self.assertEqual(len(dataset), 2)
def test_batch_basic(self):
    G, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = \
        simple_networkx_graph()
    Graph.add_edge_attr(G, "edge_feature", edge_x)
    Graph.add_edge_attr(G, "edge_label", edge_y)
    Graph.add_node_attr(G, "node_feature", x)
    Graph.add_node_attr(G, "node_label", y)
    Graph.add_graph_attr(G, "graph_feature", graph_x)
    Graph.add_graph_attr(G, "graph_label", graph_y)
    H = deepcopy(G)
    graphs = GraphDataset.list_to_graphs([G, H])
    batch = Batch.from_data_list(graphs)
    self.assertEqual(batch.num_graphs, 2)
    # Batching concatenates the two identical graphs, doubling the node count.
    self.assertEqual(
        len(batch.node_feature), 2 * len(graphs[0].node_feature))
def batch_nx_graphs_multi(graphs, anchors=None):
    """Variant of batch_nx_graphs that skips edge-type features."""
    augmenter = feature_preprocess.FeatureAugment()
    if anchors is not None:
        # Mark the anchor node of each graph with a one-hot node feature.
        for anchor, g in zip(anchors, graphs):
            for v in g.nodes:
                g.nodes[v]["node_feature"] = torch.tensor([float(v == anchor)])
    batch = Batch.from_data_list(GraphDataset.list_to_graphs(graphs))
    batch = augmenter.augment(batch)
    batch = batch.to(get_device())
    return batch
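# Companion sketch for batch_nx_graphs_multi: since it skips edge-type
# features, plain networkx graphs suffice. The star graphs are illustrative.
def _demo_batch_nx_graphs_multi():
    import networkx as nx
    graphs = [nx.star_graph(3), nx.star_graph(4)]
    # With anchors=None, node features are left entirely to FeatureAugment.
    return batch_nx_graphs_multi(graphs, anchors=None)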
def gen_data_loaders(self, size, batch_size, train=True,
                     use_distributed_sampling=False):
    loaders = []
    for i in range(2):
        neighs = []
        for j in range(size // 2):
            graph, neigh = utils.sample_neigh(
                self.train_set if train else self.test_set,
                random.randint(self.min_size, self.max_size))
            neighs.append(graph.subgraph(neigh))
        dataset = GraphDataset(GraphDataset.list_to_graphs(neighs))
        # Both loaders use half the batch size, so one target batch and one
        # negative-target batch together fill a full batch.
        loaders.append(
            TorchDataLoader(dataset,
                            collate_fn=Batch.collate([]),
                            batch_size=batch_size // 2,
                            sampler=None,
                            shuffle=False))
    loaders.append([None] * (size // batch_size))
    return loaders
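# Sketch of how the three loaders returned above are typically consumed
# together; the epoch size and batch size are illustrative values:
#
#   loaders = self.gen_data_loaders(size=1024, batch_size=64, train=True)
#   for batch_target, batch_neg_target, batch_neg_query in zip(*loaders):
#       # The third element is a None placeholder of matching length;
#       # negative queries are generated inside gen_batch.
#       ...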
def gen_batch(self, batch_target, batch_neg_target, batch_neg_query, train):
    def sample_subgraph(graph, offset=0, use_precomp_sizes=False,
                        filter_negs=False, supersample_small_graphs=False,
                        neg_target=None, hard_neg_idxs=None):
        if neg_target is not None:
            graph_idx = graph.G.graph["idx"]
        use_hard_neg = (hard_neg_idxs is not None
                        and graph.G.graph["idx"] in hard_neg_idxs)
        done = False
        while not done:
            if use_precomp_sizes:
                size = graph.G.graph["subgraph_size"]
            else:
                if train and supersample_small_graphs:
                    # Bias the size distribution toward small subgraphs.
                    sizes = np.arange(self.min_size + offset,
                                      len(graph.G) + offset)
                    ps = (sizes - self.min_size + 2) ** (-1.1)
                    ps /= ps.sum()
                    size = stats.rv_discrete(values=(sizes, ps)).rvs()
                else:
                    d = 1 if train else 0
                    size = random.randint(self.min_size + offset - d,
                                          len(graph.G) - 1 + offset)
            # Grow a connected subgraph by random frontier expansion.
            start_node = random.choice(list(graph.G.nodes))
            neigh = [start_node]
            frontier = list(set(graph.G.neighbors(start_node)) - set(neigh))
            visited = set([start_node])
            while len(neigh) < size:
                new_node = random.choice(list(frontier))
                assert new_node not in neigh
                neigh.append(new_node)
                visited.add(new_node)
                frontier += list(graph.G.neighbors(new_node))
                frontier = [x for x in frontier if x not in visited]
            if self.node_anchored:
                anchor = neigh[0]
                for v in graph.G.nodes:
                    graph.G.nodes[v]["node_feature"] = (
                        torch.ones(1) if anchor == v else torch.zeros(1))
            neigh = graph.G.subgraph(neigh)
            if use_hard_neg and train:
                neigh = neigh.copy()
                # The probability is currently 1.0, so hard negatives are
                # always perturbed by adding edges; the anchor-perturbation
                # branch below is effectively disabled.
                if random.random() < 1.0 or not self.node_anchored:
                    non_edges = list(nx.non_edges(neigh))
                    if len(non_edges) > 0:
                        for u, v in random.sample(
                                non_edges,
                                random.randint(1, min(len(non_edges), 5))):
                            neigh.add_edge(u, v)
                else:
                    # Perturb the anchor instead of the edge set.
                    anchor = random.choice(list(neigh.nodes))
                    for v in neigh.nodes:
                        neigh.nodes[v]["node_feature"] = (
                            torch.ones(1) if anchor == v else torch.zeros(1))
            if (filter_negs and train and len(neigh) <= 6
                    and neg_target is not None):
                # Keep the negative only if it is NOT a subgraph of the
                # corresponding negative target; otherwise resample.
                matcher = nx.algorithms.isomorphism.GraphMatcher(
                    neg_target[graph_idx], neigh)
                if not matcher.subgraph_is_isomorphic():
                    done = True
            else:
                done = True
        return graph, DSGraph(neigh)

    augmenter = feature_preprocess.FeatureAugment()

    pos_target = batch_target
    pos_target, pos_query = pos_target.apply_transform_multi(sample_subgraph)
    neg_target = batch_neg_target
    # TODO: use hard negs
    hard_neg_idxs = set(
        random.sample(range(len(neg_target.G)), len(neg_target.G) // 2))
    # Negative queries: random generator graphs, except hard negatives,
    # which reuse the negative target itself (perturbed in sample_subgraph).
    batch_neg_query = Batch.from_data_list(
        GraphDataset.list_to_graphs([
            self.generator.generate(size=len(g))
            if i not in hard_neg_idxs else g
            for i, g in enumerate(neg_target.G)
        ]))
    for i, g in enumerate(batch_neg_query.G):
        g.graph["idx"] = i
    _, neg_query = batch_neg_query.apply_transform_multi(
        sample_subgraph, hard_neg_idxs=hard_neg_idxs)

    if self.node_anchored:
        def add_anchor(g, anchors=None):
            if anchors is not None:
                anchor = anchors[g.G.graph["idx"]]
            else:
                anchor = random.choice(list(g.G.nodes))
            for v in g.G.nodes:
                if "node_feature" not in g.G.nodes[v]:
                    g.G.nodes[v]["node_feature"] = (
                        torch.ones(1) if anchor == v else torch.zeros(1))
            return g
        neg_target = neg_target.apply_transform(add_anchor)

    pos_target = augmenter.augment(pos_target).to(utils.get_device())
    pos_query = augmenter.augment(pos_query).to(utils.get_device())
    neg_target = augmenter.augment(neg_target).to(utils.get_device())
    neg_query = augmenter.augment(neg_query).to(utils.get_device())
    return pos_target, pos_query, neg_target, neg_query
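# Sketch of how gen_batch's four return values feed a training step; `model`
# is a hypothetical stand-in for whatever embedding model this data source
# feeds, shown only to make the outputs concrete:
#
#   pos_t, pos_q, neg_t, neg_q = self.gen_batch(
#       batch_target, batch_neg_target, batch_neg_query, train=True)
#   # Positive pairs: pos_q is sampled as a true subgraph of pos_t, so the
#   # model should predict "subgraph". Negative pairs: neg_q is either a
#   # random generator graph or a perturbed hard negative, so it should not.
#   emb_pos_t, emb_pos_q = model(pos_t), model(pos_q)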
def main():
    args = arg_parse()
    edge_train_mode = args.mode
    print('edge train mode: {}'.format(edge_train_mode))

    WN_graph = nx.read_gpickle(args.data_path)
    print('Each node has node ID (n_id). Example: ', WN_graph.nodes[0])
    print('Each edge has edge ID (id) and categorical label (e_label). '
          'Example: ', WN_graph[0][5871])

    graphs = GraphDataset.list_to_graphs([WN_graph])
    # Since both feature and label are relation types,
    # only the disjoint mode would make sense.
    dataset = GraphDataset(
        graphs,
        task='link_pred',
        edge_train_mode=edge_train_mode,
        edge_message_ratio=args.edge_message_ratio,
        edge_negative_sampling_ratio=args.neg_sampling_ratio)

    # Count the distinct edge types; labels are consecutive (0-17).
    labels = set()
    for u, v, edge_key in WN_graph.edges:
        labels.add(WN_graph[u][v][edge_key]['e_label'])
    num_edge_types = len(labels)

    print('Pre-transform: ', dataset[0])
    dataset = dataset.apply_transform(WN_transform,
                                      num_edge_types=num_edge_types,
                                      deep_copy=False)
    print('Post-transform: ', dataset[0])
    print('Initial data: {} nodes; {} edges.'.format(
        dataset[0].G.number_of_nodes(), dataset[0].G.number_of_edges()))
    print('Number of node features: {}'.format(dataset.num_node_features))

    # Split the dataset transductively.
    datasets = {}
    datasets['train'], datasets['val'], datasets['test'] = dataset.split(
        transductive=True, split_ratio=[0.8, 0.1, 0.1])
    print('After split:')
    print('Train message-passing graph: {} nodes; {} edges.'.format(
        datasets['train'][0].G.number_of_nodes(),
        datasets['train'][0].G.number_of_edges()))
    print('Val message-passing graph: {} nodes; {} edges.'.format(
        datasets['val'][0].G.number_of_nodes(),
        datasets['val'][0].G.number_of_edges()))
    print('Test message-passing graph: {} nodes; {} edges.'.format(
        datasets['test'][0].G.number_of_nodes(),
        datasets['test'][0].G.number_of_edges()))

    # Node feature dimension; the relation type serves as both edge feature
    # and edge label.
    input_dim = datasets['train'].num_node_features
    edge_feat_dim = datasets['train'].num_edge_features
    num_classes = datasets['train'].num_edge_labels
    print('Node feature dim: {}; edge feature dim: {}; num classes: {}.'
          .format(input_dim, edge_feat_dim, num_classes))

    model = Net(input_dim, edge_feat_dim, num_classes, args).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
                                 weight_decay=5e-3)

    follow_batch = []  # e.g., follow_batch = ['edge_index']
    dataloaders = {
        split: DataLoader(ds,
                          collate_fn=Batch.collate(follow_batch),
                          batch_size=1,
                          shuffle=(split == 'train'))
        for split, ds in datasets.items()
    }
    print('Graphs after split: ')
    for key, dataloader in dataloaders.items():
        for batch in dataloader:
            print(key, ': ', batch)

    train(model, dataloaders, optimizer, args)
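# Standard entry-point guard so the script runs main() when invoked directly;
# the flag values are supplied by arg_parse() above.
if __name__ == '__main__':
    main()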