def gen_batch(self, graphs_a, graphs_b, _, train):
    """Split paired graphs into positive/negative subgraph pairs, cached on disk.

    For each pair ``(graph_a, graph_b)`` drawn from ``graphs_a.G`` and
    ``graphs_b.G``, networkx's ``GraphMatcher.subgraph_is_isomorphic`` decides
    whether some node-induced subgraph of ``graph_a`` is isomorphic to
    ``graph_b``. Matching pairs go to the positive lists, the rest to the
    negative lists.

    The split is pickled to
    ``data/cache/imbalanced-<dataset>-<node_anchored>-<batch_idx>``. A later
    call with the same batch index reloads it, so the expensive matching runs
    once per batch.

    Args:
        graphs_a, graphs_b: collections exposing ``apply_transform`` and a
            ``.G`` sequence of networkx graphs (presumably DeepSNAP batches --
            TODO confirm).
        _: unused.
        train: unused here.

    Returns:
        Tuple ``(pos_a, pos_b, neg_a, neg_b, nx_graphs)``.

        - The first four are batched with ``utils.batch_nx_graphs`` when
          non-empty; an empty list is returned unchanged.
        - ``nx_graphs`` holds deep copies of the four unbatched lists.
    """
    def add_anchor(g):
        # Mark one random node as the anchor (feature 1, others 0). Without
        # node anchoring every node gets feature 1.
        anchor = random.choice(list(g.G.nodes))
        for v in g.G.nodes:
            g.G.nodes[v]["node_feature"] = (
                torch.ones(1) if anchor == v or not self.node_anchored
                else torch.zeros(1))
        return g

    pos_a, pos_b, neg_a, neg_b = [], [], [], []
    fn = "data/cache/imbalanced-{}-{}-{}".format(
        self.dataset_name.lower(), str(self.node_anchored), self.batch_idx)
    if not os.path.exists(fn):
        graphs_a = graphs_a.apply_transform(add_anchor)
        graphs_b = graphs_b.apply_transform(add_anchor)
        for graph_a, graph_b in tqdm(list(zip(graphs_a.G, graphs_b.G))):
            # When anchored, the matching must map anchor to anchor: nodes
            # only match if both or neither carry the anchor feature.
            matcher = nx.algorithms.isomorphism.GraphMatcher(
                graph_a, graph_b,
                node_match=(lambda a, b: (a["node_feature"][0] > 0.5) ==
                            (b["node_feature"][0] > 0.5))
                if self.node_anchored else None)
            if matcher.subgraph_is_isomorphic():
                pos_a.append(graph_a)
                pos_b.append(graph_b)
            else:
                neg_a.append(graph_a)
                neg_b.append(graph_b)
        os.makedirs("data/cache", exist_ok=True)
        with open(fn, "wb") as f:
            pickle.dump((pos_a, pos_b, neg_a, neg_b), f)
        print("saved", fn)
    else:
        # NOTE: pickle.load executes arbitrary code; this is only acceptable
        # because the cache file is produced locally by the branch above.
        with open(fn, "rb") as f:
            print("loaded", fn)
            pos_a, pos_b, neg_a, neg_b = pickle.load(f)
    print(len(pos_a), len(neg_a))
    nx_graphs = (
        deepcopy(pos_a),
        deepcopy(pos_b),
        deepcopy(neg_a),
        deepcopy(neg_b),
    )
    # Either side of the split may be empty for a given batch; only batch
    # the non-empty lists.
    if pos_a:
        pos_a = utils.batch_nx_graphs(pos_a)
        pos_b = utils.batch_nx_graphs(pos_b)
    if neg_a:
        neg_a = utils.batch_nx_graphs(neg_a)
        neg_b = utils.batch_nx_graphs(neg_b)
    self.batch_idx += 1
    return pos_a, pos_b, neg_a, neg_b, nx_graphs
def gen_alignment_matrix(model, query, target, method_type="order"):
    """Generate subgraph matching alignment matrix for a given query and
    target graph. Each entry (u, v) of the matrix contains the confidence score
    the model gives for the query graph, anchored at u, being a subgraph of the
    target graph, anchored at v.

    Matrix rows and columns:

    - If a graph's node labels are exactly the integers ``0..len(graph)-1``,
      row/column ``k`` corresponds to node ``k``.
    - For any other labelling, rows/columns follow the iteration order of
      ``query.nodes`` / ``target.nodes``.

    Args:
        model: the subgraph matching model. Must have been trained with
            node anchored setting (--node_anchored, default)
        query: the query graph (networkx Graph)
        target: the target graph (networkx Graph)
        method_type: the method used for the model.
            "order" for order embedding or "mlp" for MLP model.
            For any other value the raw ``model.predict`` output is stored.

    Returns:
        numpy array of shape (len(query), len(target)).
    """
    def _positions(graph):
        # Map each node to its matrix index. Integer labels 0..n-1 keep their
        # label as index (the historical behaviour); other labellings fall
        # back to iteration order instead of failing on mat[label].
        nodes = list(graph.nodes)
        if set(nodes) == set(range(len(nodes))):
            return {node: int(node) for node in nodes}
        return {node: i for i, node in enumerate(nodes)}

    query_pos = _positions(query)
    target_pos = _positions(target)
    mat = np.zeros((len(query), len(target)))
    # Only scalar scores are read back via .item(); skip autograd bookkeeping.
    with torch.no_grad():
        for u in query.nodes:
            for v in target.nodes:
                batch = utils.batch_nx_graphs([query, target], anchors=[u, v])
                embs = model.emb_model(batch)
                pred = model(embs[0].unsqueeze(0), embs[1].unsqueeze(0))
                raw_pred = model.predict(pred)
                if method_type == "order":
                    raw_pred = torch.log(raw_pred)
                elif method_type == "mlp":
                    raw_pred = raw_pred[0][1]
                mat[query_pos[u]][target_pos[v]] = raw_pred.item()
    return mat
def step(self):
    """Grow every beam set by one node (one step of greedy beam search).

    For each beam in each beam set:

    - The beam is a tuple ``(score, neigh, frontier, visited, graph_idx)``.
    - It is skipped once ``neigh`` reaches ``self.max_pattern_size`` nodes or
      its frontier is empty.
    - Otherwise every frontier node is tried as an extension of ``neigh``.
      The candidate subgraph is embedded and scored against every embedding
      batch in ``self.embs``.

    All extensions produced from one beam set are then sorted by score
    (ascending, so lower is better) and the first ``self.n_beams`` survive.

    The single best survivor of each beam set is recorded:

    - it is appended to ``self.cand_patterns``;
    - for the "counts"/"hybrid" rank methods, it is also added to
      ``self.counts`` keyed by its WL hash;
    - when ``self.analyze`` is set and the pattern has at least 3 nodes,
      its embedding is collected for analysis.

    Beam sets that produced no extensions are dropped from
    ``self.beam_sets``.
    """
    new_beam_sets = []
    # b[0][-1] is the graph_idx of the first beam in each beam set.
    print("seeds come from", len(set(b[0][-1] for b in self.beam_sets)),
        "distinct graphs")
    analyze_embs_cur = []
    for beam_set in tqdm(self.beam_sets):
        new_beams = []
        for _, neigh, frontier, visited, graph_idx in beam_set:
            graph = self.dataset[graph_idx]
            # Beam is finished: size cap reached or nothing left to add.
            if len(neigh) >= self.max_pattern_size or not frontier:
                continue
            cand_neighs, anchors = [], []
            for cand_node in frontier:
                cand_neigh = graph.subgraph(neigh + [cand_node])
                cand_neighs.append(cand_neigh)
                if self.node_anchored:
                    # The first node of the growing pattern is its anchor.
                    anchors.append(neigh[0])
            cand_embs = self.model.emb_model(utils.batch_nx_graphs(
                cand_neighs, anchors=anchors if self.node_anchored else None))
            best_score, best_node = float("inf"), None
            for cand_node, cand_emb in zip(frontier, cand_embs):
                score, n_embs = 0, 0
                for emb_batch in self.embs:
                    n_embs += len(emb_batch)
                    if self.model_type == "order":
                        # Subtract the summed argmax class of clf_model over
                        # this batch. Presumably class 1 means "candidate is
                        # a subgraph of the neighborhood" -- TODO confirm.
                        score -= torch.sum(torch.argmax(
                            self.model.clf_model(self.model.predict((
                            emb_batch.to(utils.get_device()),
                            cand_emb)).unsqueeze(1)), axis=1)).item()
                    elif self.model_type == "mlp":
                        # Add column 0 of the MLP output for every
                        # neighborhood in the batch (presumably the
                        # "not a subgraph" score -- TODO confirm).
                        score += torch.sum(self.model(
                            emb_batch.to(utils.get_device()),
                            cand_emb.unsqueeze(0).expand(len(emb_batch), -1)
                            )[:,0]).item()
                    else:
                        print("unrecognized model type")
                if score < best_score:
                    best_score = score
                    best_node = cand_node
                # Every candidate (not only the best) becomes a new beam; the
                # beam-width cut happens after all beams of the set are done.
                new_frontier = list(((set(frontier) |
                    set(graph.neighbors(cand_node))) - visited) -
                    set([cand_node]))
                new_beams.append((
                    score, neigh + [cand_node], new_frontier,
                    visited | set([cand_node]), graph_idx))
        new_beams = list(sorted(new_beams, key=lambda x: x[0]))[:self.n_beams]
        for score, neigh, frontier, visited, graph_idx in new_beams[:1]:
            graph = self.dataset[graph_idx]
            # add to record
            neigh_g = graph.subgraph(neigh).copy()
            neigh_g.remove_edges_from(nx.selfloop_edges(neigh_g))
            for v in neigh_g.nodes:
                neigh_g.nodes[v]["anchor"] = 1 if v == neigh[0] else 0
            self.cand_patterns[len(neigh_g)].append((score, neigh_g))
            if self.rank_method in ["counts", "hybrid"]:
                self.counts[len(neigh_g)][utils.wl_hash(neigh_g,
                    node_anchored=self.node_anchored)].append(neigh_g)
            if self.analyze and len(neigh) >= 3:
                emb = self.model.emb_model(utils.batch_nx_graphs(
                    [neigh_g], anchors=[neigh[0]] if self.node_anchored
                    else None)).squeeze(0)
                analyze_embs_cur.append(emb.detach().cpu().numpy())
        if len(new_beams) > 0:
            new_beam_sets.append(new_beams)
    self.beam_sets = new_beam_sets
    self.analyze_embs.append(analyze_embs_cur)
def step(self):
    """Run one round of MCTS simulations, then raise ``self.max_size`` by one.

    The round runs
    ``self.n_trials // (self.max_pattern_size + 1 - self.min_pattern_size)``
    simulations. Each simulation:

    1. Picks a seed. It reuses the already-visited (graph_idx, node) seed
       with the best UCT score if that score is at least the exploration
       bonus ``c_uct * sqrt(log(simulation_n))``. Otherwise it samples a new
       seed node from a graph chosen with probability proportional to its
       node count, rejecting seeds for which ``has_min_reachable_nodes(...)``
       is false.
    2. Grows the pattern one node at a time until the frontier is empty or
       the pattern has ``self.max_size`` nodes. Each candidate extension is
       scored by UCT over the WL hash of the resulting subgraph. The best
       candidate is added and the new pattern is cached in
       ``self.wl_hash_to_graphs``.
    3. Backs up the value of the last chosen extension along every visited
       state transition, updating ``cum_action_values`` and ``visit_counts``.
    """
    # Graph sampling distribution proportional to graph size. (np.float was
    # an alias of the builtin float and was removed in NumPy 1.24.)
    ps = np.array([len(g) for g in self.dataset], dtype=float)
    ps /= np.sum(ps)
    graph_dist = stats.rv_discrete(values=(np.arange(len(self.dataset)), ps))

    print("Size", self.max_size)
    print(len(self.visited_seed_nodes), "distinct seeds")
    n_simulations = self.n_trials // (
        self.max_pattern_size + 1 - self.min_pattern_size)
    for simulation_n in tqdm(range(n_simulations)):
        # pick seed node
        best_graph_idx, best_start_node, best_score = None, None, -float("inf")
        for cand_graph_idx, cand_start_node in self.visited_seed_nodes:
            state = cand_graph_idx, cand_start_node
            my_visit_counts = sum(self.visit_counts[state].values())
            q_score = (sum(self.cum_action_values[state].values()) /
                (my_visit_counts or 1))
            uct_score = self.c_uct * np.sqrt(np.log(simulation_n or 1) /
                (my_visit_counts or 1))
            node_score = q_score + uct_score
            if node_score > best_score:
                best_score = node_score
                best_graph_idx = cand_graph_idx
                best_start_node = cand_start_node
        # if existing seed beats choosing a new seed
        if best_score >= self.c_uct * np.sqrt(np.log(simulation_n or 1)):
            graph_idx, start_node = best_graph_idx, best_start_node
            assert best_start_node in self.dataset[graph_idx].nodes
            graph = self.dataset[graph_idx]
        else:
            found = False
            while not found:
                graph_idx = np.arange(len(self.dataset))[graph_dist.rvs()]
                graph = self.dataset[graph_idx]
                start_node = random.choice(list(graph.nodes))
                # don't pick isolated nodes or small islands
                if self.has_min_reachable_nodes(graph, start_node,
                    self.min_pattern_size):
                    found = True
            self.visited_seed_nodes.add((graph_idx, start_node))

        neigh = [start_node]
        frontier = list(set(graph.neighbors(start_node)) - set(neigh))
        visited = set([start_node])
        cur_state = graph_idx, start_node
        state_list = [cur_state]
        while frontier and len(neigh) < self.max_size:
            cand_neighs, anchors = [], []
            for cand_node in frontier:
                cand_neigh = graph.subgraph(neigh + [cand_node])
                cand_neighs.append(cand_neigh)
                if self.node_anchored:
                    # The seed node is the anchor of every candidate.
                    anchors.append(neigh[0])
            cand_embs = self.model.emb_model(utils.batch_nx_graphs(
                cand_neighs, anchors=anchors if self.node_anchored else None))
            best_v_score, best_node_score, best_node = 0, -float("inf"), None
            for cand_node, cand_emb in zip(frontier, cand_embs):
                score, n_embs = 0, 0
                for emb_batch in self.embs:
                    score += torch.sum(self.model.predict((
                        emb_batch.to(utils.get_device()), cand_emb))).item()
                    n_embs += len(emb_batch)
                # Value shrinks as the mean model.predict score over all
                # neighborhood embeddings grows.
                v_score = -np.log(score/n_embs + 1) + 1
                # get wl hash of next state
                neigh_g = graph.subgraph(neigh + [cand_node]).copy()
                neigh_g.remove_edges_from(nx.selfloop_edges(neigh_g))
                for v in neigh_g.nodes:
                    neigh_g.nodes[v]["anchor"] = 1 if v == neigh[0] else 0
                next_state = utils.wl_hash(neigh_g,
                    node_anchored=self.node_anchored)
                # compute node score
                parent_visit_counts = sum(self.visit_counts[cur_state].values())
                my_visit_counts = sum(self.visit_counts[next_state].values())
                q_score = (sum(self.cum_action_values[next_state].values()) /
                    (my_visit_counts or 1))
                uct_score = self.c_uct * np.sqrt(
                    np.log(parent_visit_counts or 1) / (my_visit_counts or 1))
                node_score = q_score + uct_score
                if node_score > best_node_score:
                    best_node_score = node_score
                    best_v_score = v_score
                    best_node = cand_node
            frontier = list(((set(frontier) |
                set(graph.neighbors(best_node))) - visited) -
                set([best_node]))
            visited.add(best_node)
            neigh.append(best_node)

            # update visit counts, wl cache
            neigh_g = graph.subgraph(neigh).copy()
            neigh_g.remove_edges_from(nx.selfloop_edges(neigh_g))
            for v in neigh_g.nodes:
                neigh_g.nodes[v]["anchor"] = 1 if v == neigh[0] else 0
            cur_state = utils.wl_hash(neigh_g, node_anchored=self.node_anchored)
            state_list.append(cur_state)
            self.wl_hash_to_graphs[cur_state].append(neigh_g)

        # backprop value: every transition on the path receives the value of
        # the final chosen extension.
        for i in range(0, len(state_list) - 1):
            self.cum_action_values[state_list[i]][
                state_list[i+1]] += best_v_score
            self.visit_counts[state_list[i]][state_list[i+1]] += 1
    self.max_size += 1
def gen_batch(self, a, b, c, train, max_size=15, min_size=5, seed=None,
        filter_negs=False, sample_method="tree-pair"):
    """Sample a balanced batch of positive and negative subgraph pairs.

    Positive pairs (A, B) take B as a node subset of A from the same dataset
    graph, so B is an induced subgraph of A. Negative pairs draw A and B
    independently. With ``filter_negs``, negatives where some induced
    subgraph of A is isomorphic to B are resampled.

    Args:
        a: the batch size. The first positional slot is reused for it; ``b``
            and ``c`` are unused and kept for interface compatibility.
        train: sample from the train split if true, else the test split.
        max_size, min_size: size bounds for the sampled neighborhoods.
        seed: if not None, reseeds Python's global ``random`` module.
        filter_negs: resample negatives that are actually subgraph pairs.
        sample_method: "tree-pair" or "subgraph-tree".

    Returns:
        ``(pos_a, pos_b, neg_a, neg_b, nx_graphs)``.

        - The first four are built with ``utils.batch_nx_graphs``, with
          anchors when ``self.node_anchored``.
        - ``nx_graphs`` holds deep copies of the unbatched graphs and of the
          four anchor lists.
    """
    def anchor_in(nodes, g):
        # First node of the sampled collection `nodes` (list or NodeView)
        # that is present in `g`; else g's first node, or None if g is empty.
        # The anchor must be a node of the graph it is attached to, otherwise
        # no node gets flagged.
        for node in nodes:
            if node in g:
                return node
        return next(iter(g.nodes), None)

    batch_size = a
    train_set, test_set, _task = self.dataset
    graphs = train_set if train else test_set
    if seed is not None:
        random.seed(seed)

    pos_a, pos_b = [], []
    pos_a_anchors, pos_b_anchors = [], []
    for _ in range(batch_size // 2):
        if sample_method == "tree-pair":
            size = random.randint(min_size + 1, max_size)
            graph, nodes_a = utils.sample_neigh(graphs, size)
            nodes_b = nodes_a[:random.randint(min_size, len(nodes_a) - 1)]
        elif sample_method == "subgraph-tree":
            graph = None
            while graph is None or len(graph) < min_size + 1:
                graph = random.choice(graphs)
            nodes_a = graph.nodes
            _, nodes_b = utils.sample_neigh([graph], random.randint(
                min_size, len(graph) - 1))
        neigh_a, neigh_b = graph.subgraph(nodes_a), graph.subgraph(nodes_b)
        if self.node_anchored:
            # Positive pairs share one anchor, which must lie in B (and
            # therefore in A, since B's nodes are a subset of A's).
            anchor = anchor_in(nodes_b, neigh_b)
            pos_a_anchors.append(anchor)
            pos_b_anchors.append(anchor)
        pos_a.append(neigh_a)
        pos_b.append(neigh_b)

    neg_a, neg_b = [], []
    neg_a_anchors, neg_b_anchors = [], []
    while len(neg_a) < batch_size // 2:
        if sample_method == "tree-pair":
            size = random.randint(min_size + 1, max_size)
            graph_a, nodes_a = utils.sample_neigh(graphs, size)
            graph_b, nodes_b = utils.sample_neigh(
                graphs, random.randint(min_size, size - 1))
        elif sample_method == "subgraph-tree":
            graph_a = None
            while graph_a is None or len(graph_a) < min_size + 1:
                graph_a = random.choice(graphs)
            nodes_a = graph_a.nodes
            graph_b, nodes_b = utils.sample_neigh(
                graphs, random.randint(min_size, len(graph_a) - 1))
        neigh_a, neigh_b = graph_a.subgraph(nodes_a), graph_b.subgraph(nodes_b)
        if filter_negs:
            # subgraph_is_isomorphic() is true when some induced subgraph of
            # A is isomorphic to B, i.e. this "negative" is really a
            # positive: resample it.
            matcher = nx.algorithms.isomorphism.GraphMatcher(neigh_a, neigh_b)
            if matcher.subgraph_is_isomorphic():
                continue
        if self.node_anchored:
            # Record anchors only once the pair is accepted so the anchor
            # lists stay index-aligned with neg_a / neg_b.
            neg_a_anchors.append(anchor_in(nodes_a, neigh_a))
            neg_b_anchors.append(anchor_in(nodes_b, neigh_b))
        neg_a.append(neigh_a)
        neg_b.append(neigh_b)

    nx_graphs = (
        deepcopy(pos_a),
        deepcopy(pos_b),
        deepcopy(neg_a),
        deepcopy(neg_b),
        deepcopy(pos_a_anchors),
        deepcopy(pos_b_anchors),
        deepcopy(neg_a_anchors),
        deepcopy(neg_b_anchors),
    )
    pos_a = utils.batch_nx_graphs(
        pos_a, anchors=pos_a_anchors if self.node_anchored else None)
    pos_b = utils.batch_nx_graphs(
        pos_b, anchors=pos_b_anchors if self.node_anchored else None)
    neg_a = utils.batch_nx_graphs(
        neg_a, anchors=neg_a_anchors if self.node_anchored else None)
    neg_b = utils.batch_nx_graphs(
        neg_b, anchors=neg_b_anchors if self.node_anchored else None)
    return pos_a, pos_b, neg_a, neg_b, nx_graphs
def gen_batch(self, a, b, c, train, max_size=15, min_size=5, seed=None,
        filter_negs=False, sample_method="tree-pair", one_small=False):
    """Sample a balanced batch of positive and negative subgraph pairs.

    The sampling follows the plain tree-pair / subgraph-tree scheme:
    positive B is a node subset of A from the same graph, and negatives are
    drawn independently. With ``one_small`` (marked HACK in the code):

    - In tree-pair positives, B is A minus one random node.
    - In negatives, B starts as A minus its first node, is reduced to its
      largest connected component, converted to a DiGraph, and gets 1-2
      random non-edges added with a random ``edge_type`` in [0, 18).

    Args:
        a: the batch size. The first positional slot is reused for it; ``b``
            and ``c`` are unused and kept for interface compatibility.
        train: sample from the train split if true, else the test split.
        max_size, min_size: size bounds for the sampled neighborhoods.
        seed: if not None, reseeds Python's global ``random`` module only.
            The one_small paths also draw from ``np.random``, which is not
            reseeded.
        filter_negs: resample negatives that are actually subgraph pairs.
        sample_method: "tree-pair" or "subgraph-tree".
        one_small: enable the perturbation described above.

    Returns:
        ``(pos_a, pos_b, neg_a, neg_b, nx_graphs)``.

        - The first four are built with ``utils.batch_nx_graphs``, with
          anchors when ``self.node_anchored``.
        - ``nx_graphs`` holds deep copies of the unbatched graphs and of the
          four anchor lists.
    """
    def anchor_in(nodes, g):
        # First node of the sampled collection `nodes` (list or NodeView)
        # that is present in `g`; else g's first node, or None if g is empty.
        # The anchor must be a node of the graph it is attached to, otherwise
        # no node gets flagged.
        for node in nodes:
            if node in g:
                return node
        return next(iter(g.nodes), None)

    batch_size = a
    train_set, test_set, _task = self.dataset
    graphs = train_set if train else test_set
    if seed is not None:
        random.seed(seed)

    pos_a, pos_b = [], []
    pos_a_anchors, pos_b_anchors = [], []
    for _ in range(batch_size // 2):
        if sample_method == "tree-pair":
            size = random.randint(min_size + 1, max_size)
            graph, nodes_a = utils.sample_neigh(graphs, size)
            if one_small:
                # HACK: B is A minus one random node.
                np.random.shuffle(nodes_a)
                nodes_b = nodes_a[1:]
            else:
                nodes_b = nodes_a[:random.randint(min_size, len(nodes_a) - 1)]
        elif sample_method == "subgraph-tree":
            graph = None
            while graph is None or len(graph) < min_size + 1:
                graph = random.choice(graphs)
            nodes_a = graph.nodes
            _, nodes_b = utils.sample_neigh([graph], random.randint(
                min_size, len(graph) - 1))
        neigh_a, neigh_b = graph.subgraph(nodes_a), graph.subgraph(nodes_b)
        if self.node_anchored:
            # Positive pairs share one anchor, which must lie in B (and
            # therefore in A, since B's nodes are a subset of A's).
            anchor = anchor_in(nodes_b, neigh_b)
            pos_a_anchors.append(anchor)
            pos_b_anchors.append(anchor)
        pos_a.append(neigh_a)
        pos_b.append(neigh_b)

    neg_a, neg_b = [], []
    neg_a_anchors, neg_b_anchors = [], []
    while len(neg_a) < batch_size // 2:
        if sample_method == "tree-pair":
            size = random.randint(min_size + 1, max_size)
            graph_a, nodes_a = utils.sample_neigh(graphs, size)
            if one_small:
                # HACK: B starts as A minus its first node and is perturbed
                # below. B is rebuilt as a new DiGraph before any mutation, so
                # graph_a itself never needs copying.
                graph_b, nodes_b = graph_a, list(nodes_a)[1:]
            else:
                graph_b, nodes_b = utils.sample_neigh(
                    graphs, random.randint(min_size, size - 1))
        elif sample_method == "subgraph-tree":
            graph_a = None
            while graph_a is None or len(graph_a) < min_size + 1:
                graph_a = random.choice(graphs)
            nodes_a = graph_a.nodes
            graph_b, nodes_b = utils.sample_neigh(
                graphs, random.randint(min_size, len(graph_a) - 1))
        neigh_a, neigh_b = graph_a.subgraph(nodes_a), graph_b.subgraph(nodes_b)
        if one_small:
            # HACK: keep B's largest connected component, then add 1-2 random
            # non-edges (with a random edge_type), presumably so that B is no
            # longer a subgraph of A.
            components = sorted(
                nx.connected_components(neigh_b.to_undirected()),
                key=lambda comp: -len(comp))
            neigh_b = neigh_b.subgraph(components[0])
            b_ne = list(nx.non_edges(neigh_b))
            np.random.shuffle(b_ne)
            b_ne = b_ne[:np.random.randint(1, 3)]
            neigh_b = nx.DiGraph(neigh_b)
            neigh_b.add_edges_from(b_ne)
            for edge in b_ne:
                neigh_b.edges[edge]['edge_type'] = np.random.randint(0, 18)
        if filter_negs:
            # subgraph_is_isomorphic() is true when some induced subgraph of
            # A is isomorphic to B, i.e. this "negative" is really a
            # positive: resample it.
            matcher = nx.algorithms.isomorphism.GraphMatcher(neigh_a, neigh_b)
            if matcher.subgraph_is_isomorphic():
                continue
        if self.node_anchored:
            # Record anchors only once the pair is accepted so the anchor
            # lists stay index-aligned with neg_a / neg_b. The B anchor is
            # taken from the final B graph, which may have lost nodes above.
            neg_a_anchors.append(anchor_in(nodes_a, neigh_a))
            neg_b_anchors.append(anchor_in(nodes_b, neigh_b))
        neg_a.append(neigh_a)
        neg_b.append(neigh_b)

    nx_graphs = (
        deepcopy(pos_a),
        deepcopy(pos_b),
        deepcopy(neg_a),
        deepcopy(neg_b),
        deepcopy(pos_a_anchors),
        deepcopy(pos_b_anchors),
        deepcopy(neg_a_anchors),
        deepcopy(neg_b_anchors),
    )
    pos_a = utils.batch_nx_graphs(
        pos_a, anchors=pos_a_anchors if self.node_anchored else None)
    pos_b = utils.batch_nx_graphs(
        pos_b, anchors=pos_b_anchors if self.node_anchored else None)
    neg_a = utils.batch_nx_graphs(
        neg_a, anchors=neg_a_anchors if self.node_anchored else None)
    neg_b = utils.batch_nx_graphs(
        neg_b, anchors=neg_b_anchors if self.node_anchored else None)
    return pos_a, pos_b, neg_a, neg_b, nx_graphs
def pattern_growth(dataset, task, args):
    """Mine frequent subgraph patterns from ``dataset`` with a trained model.

    Steps:

    1. Build the matching model selected by ``args.method_type`` and load
       its weights from ``args.model_path``.
    2. Collect networkx graphs from the dataset.
    3. Build neighborhoods: whole graphs, radial neighborhoods around every
       node, or ``args.n_neighborhoods`` tree-sampled neighborhoods.
    4. Embed the neighborhoods in full batches of ``args.batch_size``.
    5. Run the MCTS or greedy search agent over those embeddings.
    6. Draw every output pattern to ``plots/cluster/`` and pickle the list
       to ``args.out_path``.

    Raises:
        ValueError: if ``args.search_strategy`` is neither "mcts" nor
            "greedy", or if ``args.use_whole_graphs`` is combined with
            ``args.node_anchored`` (no anchor is defined for whole graphs).
    """
    # Fail fast on configurations that cannot complete.
    if args.search_strategy not in ("mcts", "greedy"):
        raise ValueError(
            "unknown search strategy: {}".format(args.search_strategy))
    if args.use_whole_graphs and args.node_anchored:
        raise ValueError(
            "use_whole_graphs is not supported together with node_anchored")

    # init model
    if args.method_type == "end2end":
        model = models.End2EndOrder(1, args.hidden_dim, args)
    elif args.method_type == "mlp":
        model = models.BaselineMLP(1, args.hidden_dim, args)
    else:
        model = models.OrderEmbedder(1, args.hidden_dim, args)
    model.to(utils.get_device())
    model.eval()
    model.load_state_dict(torch.load(args.model_path,
        map_location=utils.get_device()))

    if task == "graph-labeled":
        dataset, labels = dataset

    # load data
    neighs = []
    print(len(dataset), "graphs")
    print("search strategy:", args.search_strategy)
    if task == "graph-labeled":
        print("using label 0")
    graphs = []
    for i, graph in enumerate(dataset):
        if task == "graph-labeled" and labels[i] != 0:
            continue
        if task == "graph-truncate" and i >= 1000:
            break
        # Exact type check on purpose: anything that is not a plain
        # nx.Graph is converted via pyg_utils.to_networkx.
        if type(graph) is not nx.Graph:
            graph = pyg_utils.to_networkx(graph).to_undirected()
        graphs.append(graph)

    # Timing covers neighborhood sampling, embedding and the search, for
    # every sampling method.
    start_time = time.time()
    # anchors[k] is the anchor node label of neighs[k] (only filled when
    # node anchoring is on).
    anchors = []
    if args.use_whole_graphs:
        neighs = graphs
    elif args.sample_method == "radial":
        for i, graph in enumerate(graphs):
            print(i)
            for j, node in enumerate(graph.nodes):
                if len(dataset) <= 10 and j % 100 == 0:
                    print(i, j)
                neigh = list(nx.single_source_shortest_path_length(
                    graph, node, cutoff=args.radius).keys())
                if args.subgraph_sample_size != 0:
                    neigh = random.sample(
                        neigh, min(len(neigh), args.subgraph_sample_size))
                if len(neigh) > 1:
                    neigh = graph.subgraph(neigh)
                    if args.subgraph_sample_size != 0:
                        neigh = neigh.subgraph(
                            max(nx.connected_components(neigh), key=len))
                    # convert_node_labels_to_integers (default ordering)
                    # labels nodes by their position in neigh.nodes, so the
                    # centre's new label is its position; fall back to 0 if
                    # sampling dropped the centre.
                    neigh_nodes = list(neigh.nodes)
                    anchor = neigh_nodes.index(node) if node in neigh else 0
                    neigh = nx.convert_node_labels_to_integers(neigh)
                    neigh.add_edge(0, 0)
                    neighs.append(neigh)
                    if args.node_anchored:
                        anchors.append(anchor)
    elif args.sample_method == "tree":
        for j in tqdm(range(args.n_neighborhoods)):
            graph, neigh = utils.sample_neigh(graphs,
                random.randint(args.min_neighborhood_size,
                    args.max_neighborhood_size))
            neigh = graph.subgraph(neigh)
            neigh = nx.convert_node_labels_to_integers(neigh)
            neigh.add_edge(0, 0)
            neighs.append(neigh)
            if args.node_anchored:
                anchors.append(0)  # after converting labels, 0 will be anchor

    embs = []
    if len(neighs) % args.batch_size != 0:
        print("WARNING: number of graphs not multiple of batch size")
    for i in range(len(neighs) // args.batch_size):
        lo, hi = i * args.batch_size, (i + 1) * args.batch_size
        with torch.no_grad():
            # Slice the anchors together with the graphs so anchors stay
            # aligned with their graphs.
            batch = utils.batch_nx_graphs(neighs[lo:hi],
                anchors=anchors[lo:hi] if args.node_anchored else None)
            emb = model.emb_model(batch)
            emb = emb.to(torch.device("cpu"))
        embs.append(emb)

    if args.analyze:
        # Concatenate the per-batch (batch, dim) tensors into (N, dim) so
        # each point is one neighborhood embedding.
        embs_np = torch.cat(embs).numpy()
        plt.scatter(embs_np[:, 0], embs_np[:, 1], label="node neighborhood")

    if args.search_strategy == "mcts":
        assert args.method_type == "order"
        agent = MCTSSearchAgent(args.min_pattern_size, args.max_pattern_size,
            model, graphs, embs, node_anchored=args.node_anchored,
            analyze=args.analyze, out_batch_size=args.out_batch_size)
    else:  # "greedy" (validated at the top)
        agent = GreedySearchAgent(args.min_pattern_size, args.max_pattern_size,
            model, graphs, embs, node_anchored=args.node_anchored,
            analyze=args.analyze, model_type=args.method_type,
            out_batch_size=args.out_batch_size)
    out_graphs = agent.run_search(args.n_trials)
    print(time.time() - start_time, "TOTAL TIME")
    x = int(time.time() - start_time)
    print(x // 60, "mins", x % 60, "secs")

    # visualize out patterns
    os.makedirs("plots/cluster", exist_ok=True)
    count_by_size = defaultdict(int)
    for pattern in out_graphs:
        if args.node_anchored:
            colors = ["red"] + ["blue"] * (len(pattern) - 1)
            nx.draw(pattern, node_color=colors, with_labels=True)
        else:
            nx.draw(pattern)
        stem = "plots/cluster/{}-{}".format(
            len(pattern), count_by_size[len(pattern)])
        print("Saving {}.png".format(stem))
        plt.savefig("{}.png".format(stem))
        plt.savefig("{}.pdf".format(stem))
        plt.close()
        count_by_size[len(pattern)] += 1

    os.makedirs("results", exist_ok=True)
    with open(args.out_path, "wb") as f:
        pickle.dump(out_graphs, f)