def batch_nx_graphs(graphs, anchors=None):
    """Collate networkx graphs into one augmented batch on the active device.

    The input graphs are modified in place: node and edge attributes are
    written directly onto them before they are wrapped and batched.

    Args:
        graphs: Iterable of networkx graphs. It is traversed more than once,
            so pass a list rather than a one-shot iterator. Every edge must
            carry an ``'edge_type'`` attribute (a ``KeyError`` is raised
            otherwise).
        anchors: Optional iterable of anchor nodes, paired positionally with
            ``graphs`` via ``zip``. When given, each node receives a
            one-element ``"node_feature"`` tensor holding 1.0 for the anchor
            and 0.0 for every other node.

    Returns:
        The batch produced by ``Batch.from_data_list``, passed through
        ``FeatureAugment.augment`` and moved to ``get_device()``.
    """
    augmenter = feature_preprocess.FeatureAugment()

    if anchors is not None:
        for anchor, g in zip(anchors, graphs):
            for v in g.nodes:
                g.nodes[v]["node_feature"] = torch.tensor([float(v == anchor)])

    # Edge typing is applied unconditionally.  The previous guard compared
    # string literals with themselves ('aifb' == 'aifb' or 'wn18' == 'wn18'),
    # which is always true, so dropping it does not change behaviour.
    # The edge type is stored as a 1-element long tensor (an index), not as a
    # one-hot vector.  Original author's note: "90 edge types".
    for g in graphs:
        for e in g.edges:
            g.edges[e]["edge_feature"] = torch.tensor(
                [g.edges[e]['edge_type']], dtype=torch.long)

    batch = Batch.from_data_list(GraphDataset.list_to_graphs(graphs))
    batch = augmenter.augment(batch)
    batch = batch.to(get_device())
    return batch
def test_unbatch_nested(self):
    """Round-trip graphs with nested ``node_property`` dicts through a batch.

    Two graphs are given per-graph property tensors of 10 and 5 rows, batched
    with ``Batch.from_data_list`` and split again with ``to_data_list``; the
    recovered tensors must have their original row and column counts.
    """
    row_counts = [10, 5]
    col_counts = [2, 3]

    graphs = []
    for idx, rows in enumerate(row_counts):
        graph = Graph()
        graph.G = nx.complete_graph(idx + 1)
        graph.node_property = {
            "node_prop0": torch.ones(rows, col_counts[0]) * idx,
            "node_prop1": torch.ones(rows, col_counts[1]) * idx,
        }
        graphs.append(graph)

    # Collate, then immediately split back into individual graphs.
    restored = Batch.from_data_list(graphs).to_data_list()

    # (graph position, property key, expected (rows, cols))
    expectations = (
        (0, "node_prop0", (10, 2)),
        (1, "node_prop1", (5, 3)),
    )
    for position, key, (rows, cols) in expectations:
        self.assertEqual(
            restored[position].node_property[key].size(0), rows)
        self.assertEqual(
            restored[position].node_property[key].size(1), cols)
def test_batch_basic(self):
    """Batch a graph with its deep copy and check graph and node counts."""
    (G, x, y, edge_x, edge_y, _edge_index,
     graph_x, graph_y) = simple_networkx_graph()

    # (attacher, attribute key, values) -- applied in this order.
    attribute_plan = (
        (Graph.add_edge_attr, "edge_feature", edge_x),
        (Graph.add_edge_attr, "edge_label", edge_y),
        (Graph.add_node_attr, "node_feature", x),
        (Graph.add_node_attr, "node_label", y),
        (Graph.add_graph_attr, "graph_feature", graph_x),
        (Graph.add_graph_attr, "graph_label", graph_y),
    )
    for attach, key, values in attribute_plan:
        attach(G, key, values)

    twin = deepcopy(G)
    graphs = GraphDataset.list_to_graphs([G, twin])
    batch = Batch.from_data_list(graphs)

    self.assertEqual(batch.num_graphs, 2)
    # Both members are identical, so the batch holds twice the node rows.
    batched_rows = len(batch.node_feature)
    rows_per_graph = len(graphs[0].node_feature)
    self.assertEqual(batched_rows, 2 * rows_per_graph)
def batch_nx_graphs(graphs, anchors=None):
    """Wrap networkx graphs as ``DSGraph`` objects and collate them.

    Args:
        graphs: Iterable of networkx graphs. It is traversed twice when
            ``anchors`` is given, so pass a list. The graphs are modified in
            place when anchors are supplied.
        anchors: Optional iterable of anchor nodes paired positionally with
            ``graphs``. Each node then gets a one-element ``"node_feature"``
            tensor: 1.0 on the anchor, 0.0 elsewhere.

    Returns:
        The collated batch after ``FeatureAugment.augment``, moved to
        ``get_device()``.
    """
    augmenter = feature_preprocess.FeatureAugment()

    if anchors is not None:
        for anchor_node, nx_graph in zip(anchors, graphs):
            # nodes(data=True) yields the live attribute dict of each node.
            for node, attrs in nx_graph.nodes(data=True):
                attrs["node_feature"] = torch.tensor(
                    [float(node == anchor_node)])

    wrapped = [DSGraph(nx_graph) for nx_graph in graphs]
    collated = Batch.from_data_list(wrapped)
    return augmenter.augment(collated).to(get_device())
def test_collate_batch_nested(self):
    """Collating nested ``node_property`` dicts concatenates their rows."""
    row_counts = [10, 5]
    col_counts = [2, 3]

    def build(idx, rows):
        # One graph whose property tensors are filled with its own index.
        graph = Graph()
        graph.G = nx.complete_graph(idx + 1)
        graph.node_property = {
            'node_prop0': torch.ones(rows, col_counts[0]) * idx,
            'node_prop1': torch.ones(rows, col_counts[1]) * idx
        }
        return graph

    graphs = [build(idx, rows) for idx, rows in enumerate(row_counts)]
    batch = Batch.from_data_list(graphs)

    self.assertEqual(batch.num_graphs, 2)
    stacked_rows = batch.node_property['node_prop0'].size(0)
    self.assertEqual(stacked_rows, sum(row_counts))
def gen_batch(self, batch_target, batch_neg_target, batch_neg_query, train):
    """Build the positive and negative (target, query) batches for one step.

    Positive queries are node sets grown neighbour-by-neighbour inside each
    graph of ``batch_target``.  Negative queries are grown the same way
    inside a freshly assembled batch in which a random half of the
    ``batch_neg_target`` graphs are reused as-is (the "hard" negatives, which
    additionally get random extra edges when ``train`` is truthy) and the
    rest are replaced by graphs from ``self.generator``.

    Args:
        batch_target: Batch to draw positive pairs from; must provide
            ``apply_transform_multi``.
        batch_neg_target: Batch used as the negative targets; its ``.G`` is
            iterated and indexed as a sequence of graphs.
        batch_neg_query: Not read.  The name is rebound to a newly built
            batch before its first use.
        train: Truthy during training.  It widens the lower size bound by
            one, and gates size supersampling, hard-negative perturbation
            and negative filtering inside ``sample_subgraph``.

    Returns:
        ``(pos_target, pos_query, neg_target, neg_query)``, each passed
        through ``FeatureAugment.augment`` and moved to
        ``utils.get_device()``.
    """
    def sample_subgraph(graph,
                        offset=0,
                        use_precomp_sizes=False,
                        filter_negs=False,
                        supersample_small_graphs=False,
                        neg_target=None,
                        hard_neg_idxs=None):
        """Grow a random node set in ``graph.G`` and return it as a DSGraph.

        Returns ``(graph, DSGraph(subgraph))``.  ``graph.G`` is mutated in
        place when ``self.node_anchored`` is set (node features are
        overwritten).  ``train`` and ``self`` come from the enclosing scope.
        """
        # graph_idx is bound only when neg_target is given; its single use
        # below is guarded by the same condition.
        if neg_target is not None:
            graph_idx = graph.G.graph["idx"]
        # Short-circuits: graph["idx"] is only read when hard_neg_idxs is set.
        use_hard_neg = (hard_neg_idxs is not None
                        and graph.G.graph["idx"] in hard_neg_idxs)
        done = False
        n_tries = 0  # NOTE(review): never read or incremented in this block.
        # Resample until accepted.  With filter_negs left at its default
        # (neither call below sets it) the first sample is always accepted.
        while not done:
            if use_precomp_sizes:
                size = graph.G.graph["subgraph_size"]
            else:
                if train and supersample_small_graphs:
                    # Discrete distribution over candidate sizes with weight
                    # (size - min_size + 2) ** -1.1, i.e. small sizes are
                    # favoured.  np.arange excludes the upper bound.
                    sizes = np.arange(self.min_size + offset,
                                      len(graph.G) + offset)
                    ps = (sizes - self.min_size + 2)**(-1.1)
                    ps /= ps.sum()
                    size = stats.rv_discrete(values=(sizes, ps)).rvs()
                else:
                    # random.randint is inclusive at both ends; training
                    # lowers the minimum size by one.
                    d = 1 if train else 0
                    size = random.randint(self.min_size + offset - d,
                                          len(graph.G) - 1 + offset)
            # Random frontier expansion from a random start node.
            start_node = random.choice(list(graph.G.nodes))
            neigh = [start_node]
            frontier = list(
                set(graph.G.neighbors(start_node)) - set(neigh))
            visited = set([start_node])
            while len(neigh) < size:
                # NOTE(review): random.choice raises IndexError on an empty
                # frontier, which can happen if fewer than `size` nodes are
                # reachable from start_node.
                new_node = random.choice(list(frontier))
                assert new_node not in neigh
                neigh.append(new_node)
                visited.add(new_node)
                frontier += list(graph.G.neighbors(new_node))
                frontier = [x for x in frontier if x not in visited]
            if self.node_anchored:
                # The start node is the anchor.  The one-hot feature is
                # written on every node of the full graph, not only on the
                # sampled ones.
                anchor = neigh[0]
                for v in graph.G.nodes:
                    graph.G.nodes[v]["node_feature"] = (
                        torch.ones(1) if anchor == v else torch.zeros(1))
                    #print(v, graph.G.nodes[v]["node_feature"])
            # From here on `neigh` is a subgraph object, not a node list.
            neigh = graph.G.subgraph(neigh)
            if use_hard_neg and train:
                # Work on a copy before adding edges (presumably because a
                # networkx subgraph view cannot be modified -- confirm).
                neigh = neigh.copy()
                # NOTE(review): random.random() returns values in [0.0, 1.0),
                # so this condition is always true and the else branch below
                # never runs.
                if random.random(
                ) < 1.0 or not self.node_anchored:  # add edges
                    non_edges = list(nx.non_edges(neigh))
                    if len(non_edges) > 0:
                        # Add between 1 and min(5, #non-edges) missing edges.
                        for u, v in random.sample(
                                non_edges,
                                random.randint(1, min(len(non_edges), 5))):
                            neigh.add_edge(u, v)
                else:  # perturb anchor
                    anchor = random.choice(list(neigh.nodes))
                    for v in neigh.nodes:
                        neigh.nodes[v]["node_feature"] = (torch.ones(1)
                                                          if anchor == v else
                                                          torch.zeros(1))
            if (filter_negs and train and len(neigh) <= 6
                    and neg_target is not None):
                # For small negatives, accept the sample only if it does not
                # match inside the corresponding negative target.
                matcher = nx.algorithms.isomorphism.GraphMatcher(
                    neg_target[graph_idx], neigh)
                if not matcher.subgraph_is_isomorphic():
                    done = True
            else:
                done = True
        return graph, DSGraph(neigh)

    augmenter = feature_preprocess.FeatureAugment()
    pos_target = batch_target
    # Presumably applies sample_subgraph to every graph and collates the two
    # returned graphs into two batches -- confirm against the batch API.
    pos_target, pos_query = pos_target.apply_transform_multi(
        sample_subgraph)
    neg_target = batch_neg_target
    # TODO: use hard negs
    # Half of the negative-target indices (rounded down), chosen at random.
    hard_neg_idxs = set(
        random.sample(range(len(neg_target.G)),
                      int(len(neg_target.G) * 1 / 2)))
    #hard_neg_idxs = set()
    # Hard indices reuse the negative target graph itself; all others get a
    # generated graph with the same node count.  This rebinds the
    # batch_neg_query parameter.
    batch_neg_query = Batch.from_data_list(
        GraphDataset.list_to_graphs([
            self.generator.generate(
                size=len(g)) if i not in hard_neg_idxs else g
            for i, g in enumerate(neg_target.G)
        ]))
    # Tag each graph with its position so sample_subgraph can look it up in
    # hard_neg_idxs.
    for i, g in enumerate(batch_neg_query.G):
        g.graph["idx"] = i
    _, neg_query = batch_neg_query.apply_transform_multi(
        sample_subgraph, hard_neg_idxs=hard_neg_idxs)
    if self.node_anchored:

        def add_anchor(g, anchors=None):
            """Give nodes lacking "node_feature" a one-hot anchor feature."""
            if anchors is not None:
                anchor = anchors[g.G.graph["idx"]]
            else:
                anchor = random.choice(list(g.G.nodes))
            for v in g.G.nodes:
                # Existing features are kept; only missing ones are filled.
                if "node_feature" not in g.G.nodes[v]:
                    g.G.nodes[v]["node_feature"] = (
                        torch.ones(1) if anchor == v else torch.zeros(1))
            return g

        # Called without `anchors`, so each anchor is drawn at random.
        neg_target = neg_target.apply_transform(add_anchor)
    pos_target = augmenter.augment(pos_target).to(utils.get_device())
    pos_query = augmenter.augment(pos_query).to(utils.get_device())
    neg_target = augmenter.augment(neg_target).to(utils.get_device())
    neg_query = augmenter.augment(neg_query).to(utils.get_device())
    #print(len(pos_target.G[0]), len(pos_query.G[0]))
    return pos_target, pos_query, neg_target, neg_query