def build_model(args):
    """Construct the subgraph-matching model selected by ``args.method_type``.

    Args:
        args: parsed argument namespace; reads ``method_type``, ``hidden_dim``,
            ``test`` and ``model_path``.

    Returns:
        The model instance, moved to the default device.  When ``args.test``
        is set and ``args.model_path`` is non-empty, the saved checkpoint is
        loaded into the model first.

    Raises:
        ValueError: if ``args.method_type`` is not one of the supported
            values ``"order"`` or ``"mlp"``.
    """
    # build model
    if args.method_type == "order":
        model = models.OrderEmbedder(1, args.hidden_dim, args)
    elif args.method_type == "mlp":
        model = models.BaselineMLP(1, args.hidden_dim, args)
    else:
        # An unrecognized method_type previously left `model` unbound and
        # crashed below with UnboundLocalError; fail fast with a clear error.
        raise ValueError("unknown method_type: {}".format(args.method_type))
    model.to(utils.get_device())
    if args.test and args.model_path:
        model.load_state_dict(torch.load(args.model_path,
            map_location=utils.get_device()))
    return model
from common import utils from subgraph_matching.config import parse_encoder # Now we load the model and a dataset to analyze embeddings on, here ENZYMES. from subgraph_matching.train import make_data_source parser = argparse.ArgumentParser() utils.parse_optimizer(parser) parse_encoder(parser) args = parser.parse_args("") args.model_path = os.path.join("..", args.model_path) print("Using dataset {}".format(args.dataset)) model = models.OrderEmbedder(1, args.hidden_dim, args) model.to(utils.get_device()) model.eval() model.load_state_dict( torch.load(args.model_path, map_location=utils.get_device())) train, test, task = data.load_dataset("wn18") from collections import Counter done = False train_accs = [] while not done: data_source = make_data_source(args) loaders = data_source.gen_data_loaders(args.eval_interval * args.batch_size,
def pattern_growth(dataset, task, args):
    """Mine frequent subgraph patterns from `dataset` with a trained model.

    Pipeline: load the embedding model from ``args.model_path``, convert the
    dataset to networkx graphs, sample node neighborhoods ("radial" BFS balls
    or "tree"-sampled subgraphs), embed them in batches, run a search agent
    (MCTS or greedy) over the embedding space, then plot and pickle the
    resulting patterns.

    Args:
        dataset: iterable of graphs (networkx or PyG); for the
            "graph-labeled" task this is a ``(graphs, labels)`` pair.
        task: one of "graph-labeled", "graph-truncate", or another
            graph-level task name.
        args: parsed argument namespace controlling model type, sampling,
            batching, search strategy and output paths.

    Returns:
        None.  Side effects: writes pattern plots under ``plots/cluster/``
        and pickles the mined graphs to ``args.out_path``.

    Raises:
        ValueError: if ``args.search_strategy`` is not "mcts" or "greedy".
    """
    # Fix: start_time was previously assigned only inside the "tree"
    # sampling branch, so the "radial" path crashed with NameError when
    # reporting TOTAL TIME below.  Stamp it here; the tree branch re-stamps
    # it to keep its original timing semantics (time from sampling onward).
    start_time = time.time()

    # init model
    if args.method_type == "end2end":
        model = models.End2EndOrder(1, args.hidden_dim, args)
    elif args.method_type == "mlp":
        model = models.BaselineMLP(1, args.hidden_dim, args)
    else:
        model = models.OrderEmbedder(1, args.hidden_dim, args)
    model.to(utils.get_device())
    model.eval()
    model.load_state_dict(torch.load(args.model_path,
        map_location=utils.get_device()))

    if task == "graph-labeled":
        dataset, labels = dataset

    # load data
    neighs_pyg, neighs = [], []
    print(len(dataset), "graphs")
    print("search strategy:", args.search_strategy)
    if task == "graph-labeled":
        print("using label 0")
    graphs = []
    for i, graph in enumerate(dataset):
        # graph-labeled: mine patterns only from label-0 graphs.
        if task == "graph-labeled" and labels[i] != 0:
            continue
        # graph-truncate: cap the dataset at 1000 graphs.
        if task == "graph-truncate" and i >= 1000:
            break
        if not type(graph) == nx.Graph:
            graph = pyg_utils.to_networkx(graph).to_undirected()
        graphs.append(graph)

    if args.use_whole_graphs:
        neighs = graphs
    else:
        anchors = []
        if args.sample_method == "radial":
            # One neighborhood per node: the BFS ball of radius args.radius.
            for i, graph in enumerate(graphs):
                print(i)
                for j, node in enumerate(graph.nodes):
                    if len(dataset) <= 10 and j % 100 == 0:
                        print(i, j)
                    if args.use_whole_graphs:
                        # NOTE(review): unreachable here (outer branch already
                        # handled use_whole_graphs) — kept for fidelity.
                        neigh = graph.nodes
                    else:
                        neigh = list(nx.single_source_shortest_path_length(
                            graph, node, cutoff=args.radius).keys())
                        if args.subgraph_sample_size != 0:
                            neigh = random.sample(neigh,
                                min(len(neigh), args.subgraph_sample_size))
                    if len(neigh) > 1:
                        neigh = graph.subgraph(neigh)
                        if args.subgraph_sample_size != 0:
                            # Sampling may disconnect the ball; keep only the
                            # largest connected component.
                            neigh = neigh.subgraph(max(
                                nx.connected_components(neigh), key=len))
                        neigh = nx.convert_node_labels_to_integers(neigh)
                        # Self-loop on the anchor node (node 0 after
                        # relabeling), as expected by the batching code.
                        neigh.add_edge(0, 0)
                        neighs.append(neigh)
        elif args.sample_method == "tree":
            start_time = time.time()
            # Sample args.n_neighborhoods random connected subgraphs of
            # random size in [min_neighborhood_size, max_neighborhood_size].
            for j in tqdm(range(args.n_neighborhoods)):
                graph, neigh = utils.sample_neigh(graphs,
                    random.randint(args.min_neighborhood_size,
                        args.max_neighborhood_size))
                neigh = graph.subgraph(neigh)
                neigh = nx.convert_node_labels_to_integers(neigh)
                neigh.add_edge(0, 0)
                neighs.append(neigh)
                if args.node_anchored:
                    anchors.append(0)   # after converting labels, 0 will be anchor

    # Embed all neighborhoods in batches; trailing remainder (< batch_size)
    # is dropped, hence the warning.
    embs = []
    if len(neighs) % args.batch_size != 0:
        print("WARNING: number of graphs not multiple of batch size")
    for i in range(len(neighs) // args.batch_size):
        #top = min(len(neighs), (i+1)*args.batch_size)
        top = (i + 1) * args.batch_size
        with torch.no_grad():
            batch = utils.batch_nx_graphs(
                neighs[i * args.batch_size:top],
                anchors=anchors if args.node_anchored else None)
            emb = model.emb_model(batch)
            emb = emb.to(torch.device("cpu"))
            embs.append(emb)

    if args.analyze:
        # NOTE(review): stacking batch tensors then scattering [:, 0]/[:, 1]
        # plots only the first two batches' leading components — presumably a
        # 2-D projection for diagnostics; confirm intended behavior.
        embs_np = torch.stack(embs).numpy()
        plt.scatter(embs_np[:, 0], embs_np[:, 1], label="node neighborhood")

    if args.search_strategy == "mcts":
        # MCTS search is only defined for the order-embedding model.
        assert args.method_type == "order"
        agent = MCTSSearchAgent(args.min_pattern_size, args.max_pattern_size,
            model, graphs, embs, node_anchored=args.node_anchored,
            analyze=args.analyze, out_batch_size=args.out_batch_size)
    elif args.search_strategy == "greedy":
        agent = GreedySearchAgent(args.min_pattern_size, args.max_pattern_size,
            model, graphs, embs, node_anchored=args.node_anchored,
            analyze=args.analyze, model_type=args.method_type,
            out_batch_size=args.out_batch_size)
    else:
        # Previously an unknown strategy crashed below with
        # UnboundLocalError on `agent`; fail fast instead.
        raise ValueError("unknown search_strategy: {}".format(
            args.search_strategy))
    out_graphs = agent.run_search(args.n_trials)
    print(time.time() - start_time, "TOTAL TIME")
    x = int(time.time() - start_time)
    print(x // 60, "mins", x % 60, "secs")

    # visualize out patterns
    count_by_size = defaultdict(int)
    for pattern in out_graphs:
        if args.node_anchored:
            # Anchor node (0) drawn in red, the rest in blue.
            colors = ["red"] + ["blue"] * (len(pattern) - 1)
            nx.draw(pattern, node_color=colors, with_labels=True)
        else:
            nx.draw(pattern)
        print("Saving plots/cluster/{}-{}.png".format(
            len(pattern), count_by_size[len(pattern)]))
        plt.savefig("plots/cluster/{}-{}.png".format(
            len(pattern), count_by_size[len(pattern)]))
        plt.savefig("plots/cluster/{}-{}.pdf".format(
            len(pattern), count_by_size[len(pattern)]))
        plt.close()
        count_by_size[len(pattern)] += 1

    if not os.path.exists("results"):
        os.makedirs("results")
    with open(args.out_path, "wb") as f:
        pickle.dump(out_graphs, f)