import torch

from common import models, utils


def build_model(args):
    # build model according to the requested method type
    if args.method_type == "order":
        model = models.OrderEmbedder(1, args.hidden_dim, args)
    elif args.method_type == "mlp":
        model = models.BaselineMLP(1, args.hidden_dim, args)
    else:
        raise ValueError("unknown method_type: {}".format(args.method_type))
    model.to(utils.get_device())
    if args.test and args.model_path:
        model.load_state_dict(torch.load(args.model_path,
            map_location=utils.get_device()))
    return model
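
For context, a minimal sketch of how build_model might be wired up from the
command line, assuming the parse_optimizer/parse_encoder helpers seen in the
next example register method_type, hidden_dim, test and model_path
(illustrative, not the repo's actual entry point):

import argparse

from common import utils
from subgraph_matching.config import parse_encoder

parser = argparse.ArgumentParser()
utils.parse_optimizer(parser)  # optimizer flags (learning rate, etc.)
parse_encoder(parser)          # encoder flags: method_type, hidden_dim, ...
args = parser.parse_args()

model = build_model(args)  # loads the checkpoint when args.test and args.model_path are set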
Example 2

import argparse
import os
from collections import Counter

import torch

from common import data, models, utils
from subgraph_matching.config import parse_encoder

# Now we load the model and a dataset to analyze embeddings on.

from subgraph_matching.train import make_data_source

parser = argparse.ArgumentParser()

utils.parse_optimizer(parser)
parse_encoder(parser)
args = parser.parse_args("")
args.model_path = os.path.join("..", args.model_path)  # this example runs one directory below the repo root

print("Using dataset {}".format(args.dataset))
model = models.OrderEmbedder(1, args.hidden_dim, args)
model.to(utils.get_device())
model.load_state_dict(
    torch.load(args.model_path, map_location=utils.get_device()))
model.eval()

train, test, task = data.load_dataset("wn18")

done = False
train_accs = []
while not done:
    data_source = make_data_source(args)
    loaders = data_source.gen_data_loaders(
        args.eval_interval * args.batch_size, args.batch_size,
        train=False)  # trailing arguments reconstructed from train.py's call -- an assumption
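    # (The example is truncated at this point. In the full script, the
    # loaders are zipped batch-by-batch and passed through
    # data_source.gen_batch to build positive/negative (target, query)
    # pairs whose embeddings are computed with model.emb_model -- a summary
    # of the surrounding repo code, not part of the original snippet.)
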
Example 3

import os
import pickle
import random
import time
from collections import defaultdict

import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch_geometric.utils as pyg_utils
from tqdm import tqdm

from common import models, utils
# import paths below follow the layout of the surrounding repo (assumed)
from subgraph_mining.search_agents import GreedySearchAgent, MCTSSearchAgent


def pattern_growth(dataset, task, args):
    # init model
    if args.method_type == "end2end":
        model = models.End2EndOrder(1, args.hidden_dim, args)
    elif args.method_type == "mlp":
        model = models.BaselineMLP(1, args.hidden_dim, args)
    else:
        model = models.OrderEmbedder(1, args.hidden_dim, args)
    model.to(utils.get_device())
    model.eval()
    model.load_state_dict(
        torch.load(args.model_path, map_location=utils.get_device()))

    if task == "graph-labeled":
        dataset, labels = dataset

    # load data
    start_time = time.time()  # set up front; originally only assigned in the "tree" branch below
    neighs = []
    print(len(dataset), "graphs")
    print("search strategy:", args.search_strategy)
    if task == "graph-labeled": print("using label 0")
    graphs = []
    for i, graph in enumerate(dataset):
        if task == "graph-labeled" and labels[i] != 0: continue
        if task == "graph-truncate" and i >= 1000: break
        if not isinstance(graph, nx.Graph):
            graph = pyg_utils.to_networkx(graph).to_undirected()
        graphs.append(graph)
    if args.use_whole_graphs:
        neighs = graphs
    else:
        anchors = []
        if args.sample_method == "radial":
            for i, graph in enumerate(graphs):
                print(i)
                for j, node in enumerate(graph.nodes):
                    if len(dataset) <= 10 and j % 100 == 0: print(i, j)
                    # use_whole_graphs is always false on this code path,
                    # so take the radius-bounded BFS neighborhood
                    neigh = list(
                        nx.single_source_shortest_path_length(
                            graph, node, cutoff=args.radius).keys())
                    if args.subgraph_sample_size != 0:
                        neigh = random.sample(
                            neigh,
                            min(len(neigh), args.subgraph_sample_size))
                    if len(neigh) > 1:
                        neigh = graph.subgraph(neigh)
                        if args.subgraph_sample_size != 0:
                            neigh = neigh.subgraph(
                                max(nx.connected_components(neigh), key=len))
                        neigh = nx.convert_node_labels_to_integers(neigh)
                        neigh.add_edge(0, 0)
                        neighs.append(neigh)
        elif args.sample_method == "tree":
            start_time = time.time()
            for j in tqdm(range(args.n_neighborhoods)):
                graph, neigh = utils.sample_neigh(
                    graphs,
                    random.randint(args.min_neighborhood_size,
                                   args.max_neighborhood_size))
                neigh = graph.subgraph(neigh)
                neigh = nx.convert_node_labels_to_integers(neigh)
                neigh.add_edge(0, 0)
                neighs.append(neigh)
                if args.node_anchored:
                    anchors.append(
                        0)  # after converting labels, 0 will be anchor

    embs = []
    if len(neighs) % args.batch_size != 0:
        print("WARNING: number of graphs not multiple of batch size")
    for i in range(len(neighs) // args.batch_size):
        # note: any trailing partial batch is dropped (hence the warning above)
        top = (i + 1) * args.batch_size
        with torch.no_grad():
            batch = utils.batch_nx_graphs(
                neighs[i * args.batch_size:top],
                anchors=anchors if args.node_anchored else None)
            emb = model.emb_model(batch)
            emb = emb.to(torch.device("cpu"))

        embs.append(emb)
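
    # In the order-embedding space learned by OrderEmbedder, the subgraph
    # relation corresponds to coordinate-wise dominance: a query graph is
    # predicted to occur inside a neighborhood when its embedding satisfies
    # torch.all(emb_query <= emb_neigh), up to the model's trained margin.
    # The search agents below grow candidate patterns under this criterion.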

    if args.analyze:
        embs_np = torch.cat(embs).numpy()  # cat (not stack) to get one row per neighborhood
        plt.scatter(embs_np[:, 0], embs_np[:, 1], label="node neighborhood")

    if args.search_strategy == "mcts":
        assert args.method_type == "order"
        agent = MCTSSearchAgent(args.min_pattern_size,
                                args.max_pattern_size,
                                model,
                                graphs,
                                embs,
                                node_anchored=args.node_anchored,
                                analyze=args.analyze,
                                out_batch_size=args.out_batch_size)
    elif args.search_strategy == "greedy":
        agent = GreedySearchAgent(args.min_pattern_size,
                                  args.max_pattern_size,
                                  model,
                                  graphs,
                                  embs,
                                  node_anchored=args.node_anchored,
                                  analyze=args.analyze,
                                  model_type=args.method_type,
                                  out_batch_size=args.out_batch_size)
    else:
        raise ValueError("unknown search_strategy: {}".format(
            args.search_strategy))
    out_graphs = agent.run_search(args.n_trials)
    elapsed = time.time() - start_time
    print(elapsed, "TOTAL TIME")
    print(int(elapsed) // 60, "mins", int(elapsed) % 60, "secs")

    # visualize output patterns
    os.makedirs("plots/cluster", exist_ok=True)  # savefig below needs this directory
    count_by_size = defaultdict(int)
    for pattern in out_graphs:
        if args.node_anchored:
            colors = ["red"] + ["blue"] * (len(pattern) - 1)
            nx.draw(pattern, node_color=colors, with_labels=True)
        else:
            nx.draw(pattern)
        print("Saving plots/cluster/{}-{}.png".format(
            len(pattern), count_by_size[len(pattern)]))
        plt.savefig("plots/cluster/{}-{}.png".format(
            len(pattern), count_by_size[len(pattern)]))
        plt.savefig("plots/cluster/{}-{}.pdf".format(
            len(pattern), count_by_size[len(pattern)]))
        plt.close()
        count_by_size[len(pattern)] += 1

    if not os.path.exists("results"):
        os.makedirs("results")
    with open(args.out_path, "wb") as f:
        pickle.dump(out_graphs, f)
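
A minimal follow-up sketch for inspecting the mined patterns; only the pickle
layout (a list of networkx graphs, as dumped above) comes from the code, while
the path and the summary itself are illustrative:

import pickle
from collections import Counter

with open("results/patterns.pkl", "rb") as f:  # use the same path passed as args.out_path
    patterns = pickle.load(f)

print(Counter(len(p) for p in patterns))  # pattern count by number of nodes
for p in patterns[:3]:
    print(p.number_of_nodes(), "nodes,", p.number_of_edges(), "edges")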