    def gen_batch(self, graphs_a, graphs_b, _, train):
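        """Generate a batch of (graph_a, graph_b) pairs labeled by an exact
        subgraph-isomorphism test, caching each labeled batch under
        data/cache so later epochs can skip the expensive matching."""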
        def add_anchor(g):
            anchor = random.choice(list(g.G.nodes))
            for v in g.G.nodes:
                g.G.nodes[v]["node_feature"] = (torch.ones(1) if anchor == v
                                                or not self.node_anchored else
                                                torch.zeros(1))
            return g

        pos_a, pos_b, neg_a, neg_b = [], [], [], []
        fn = "data/cache/imbalanced-{}-{}-{}".format(self.dataset_name.lower(),
                                                     str(self.node_anchored),
                                                     self.batch_idx)
        if not os.path.exists(fn):
            graphs_a = graphs_a.apply_transform(add_anchor)
            graphs_b = graphs_b.apply_transform(add_anchor)
            for graph_a, graph_b in tqdm(list(zip(graphs_a.G, graphs_b.G))):
                matcher = nx.algorithms.isomorphism.GraphMatcher(
                    graph_a,
                    graph_b,
                    node_match=(lambda a, b: (a["node_feature"][0] > 0.5) ==
                                (b["node_feature"][0] > 0.5))
                    if self.node_anchored else None)
                if matcher.subgraph_is_isomorphic():
                    pos_a.append(graph_a)
                    pos_b.append(graph_b)
                else:
                    neg_a.append(graph_a)
                    neg_b.append(graph_b)
            if not os.path.exists("data/cache"):
                os.makedirs("data/cache")
            with open(fn, "wb") as f:
                pickle.dump((pos_a, pos_b, neg_a, neg_b), f)
            print("saved", fn)
        else:
            with open(fn, "rb") as f:
                print("loaded", fn)
                pos_a, pos_b, neg_a, neg_b = pickle.load(f)
        print(len(pos_a), "positives,", len(neg_a), "negatives")

        nx_graphs = (
            deepcopy(pos_a),
            deepcopy(pos_b),
            deepcopy(neg_a),
            deepcopy(neg_b),
        )

        if pos_a:
            pos_a = utils.batch_nx_graphs(pos_a)
            pos_b = utils.batch_nx_graphs(pos_b)
        neg_a = utils.batch_nx_graphs(neg_a)
        neg_b = utils.batch_nx_graphs(neg_b)
        self.batch_idx += 1
        return pos_a, pos_b, neg_a, neg_b, nx_graphs
def gen_alignment_matrix(model, query, target, method_type="order"):
    """Generate subgraph matching alignment matrix for a given query and
    target graph. Each entry (u, v) of the matrix contains the confidence score
    the model gives for the query graph, anchored at u, being a subgraph of the
    target graph, anchored at v.

    Args:
        model: the subgraph matching model. Must have been trained with
            node anchored setting (--node_anchored, default)
        query: the query graph (networkx Graph)
        target: the target graph (networkx Graph)
        method_type: the method used for the model.
            "order" for order embedding or "mlp" for MLP model
    """

    mat = np.zeros((len(query), len(target)))
    for u in query.nodes:
        for v in target.nodes:
            batch = utils.batch_nx_graphs([query, target], anchors=[u, v])
            embs = model.emb_model(batch)
            pred = model(embs[0].unsqueeze(0), embs[1].unsqueeze(0))
            raw_pred = model.predict(pred)
            if method_type == "order":
                raw_pred = torch.log(raw_pred)
            elif method_type == "mlp":
                raw_pred = raw_pred[0][1]
            mat[u][v] = raw_pred.item()
    return mat
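
# Usage sketch (illustrative, not part of the original source): `model` is
# assumed to be a trained, node-anchored matching model, loaded the same way
# as in pattern_growth() below; the toy graphs are arbitrary.
def _alignment_demo(model):
    query = nx.path_graph(3)                       # toy 3-node path query
    target = nx.gnp_random_graph(10, 0.3, seed=0)  # toy random target
    with torch.no_grad():
        mat = gen_alignment_matrix(model, query, target, method_type="order")
    print("alignment matrix shape:", mat.shape)    # (len(query), len(target))
    return mat
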
    def step(self):
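        """Run one round of greedy beam search: extend every beam by each
        frontier node, score candidates against the stored neighborhood
        embeddings, and keep the top n_beams extensions per beam set."""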
        new_beam_sets = []
        print("seeds come from", len(set(b[0][-1] for b in self.beam_sets)),
              "distinct graphs")
        analyze_embs_cur = []
        for beam_set in tqdm(self.beam_sets):
            new_beams = []
            for _, neigh, frontier, visited, graph_idx in beam_set:
                graph = self.dataset[graph_idx]
                if len(neigh) >= self.max_pattern_size or not frontier:
                    continue
                cand_neighs, anchors = [], []
                for cand_node in frontier:
                    cand_neigh = graph.subgraph(neigh + [cand_node])
                    cand_neighs.append(cand_neigh)
                    if self.node_anchored:
                        anchors.append(neigh[0])
                cand_embs = self.model.emb_model(utils.batch_nx_graphs(
                    cand_neighs,
                    anchors=anchors if self.node_anchored else None))
                best_score, best_node = float("inf"), None
                for cand_node, cand_emb in zip(frontier, cand_embs):
                    score, n_embs = 0, 0
                    for emb_batch in self.embs:
                        n_embs += len(emb_batch)
                        if self.model_type == "order":
                            score -= torch.sum(torch.argmax(
                                self.model.clf_model(
                                    self.model.predict((
                                        emb_batch.to(utils.get_device()),
                                        cand_emb)).unsqueeze(1)),
                                axis=1)).item()
                        elif self.model_type == "mlp":
                            score += torch.sum(self.model(
                                emb_batch.to(utils.get_device()),
                                cand_emb.unsqueeze(0).expand(
                                    len(emb_batch), -1))[:, 0]).item()
                        else:
                            print("unrecognized model type")
                    if score < best_score:
                        best_score = score
                        best_node = cand_node
                    new_frontier = list(((set(frontier) |
                                          set(graph.neighbors(cand_node))) -
                                         visited) - set([cand_node]))
                    new_beams.append((
                        score, neigh + [cand_node],
                        new_frontier, visited | set([cand_node]), graph_idx))
            new_beams = list(sorted(new_beams,
                                    key=lambda x: x[0]))[:self.n_beams]
            for score, neigh, frontier, visited, graph_idx in new_beams[:1]:
                graph = self.dataset[graph_idx]
                # add to record
                neigh_g = graph.subgraph(neigh).copy()
                neigh_g.remove_edges_from(nx.selfloop_edges(neigh_g))
                for v in neigh_g.nodes:
                    neigh_g.nodes[v]["anchor"] = 1 if v == neigh[0] else 0
                self.cand_patterns[len(neigh_g)].append((score, neigh_g))
                if self.rank_method in ["counts", "hybrid"]:
                    self.counts[len(neigh_g)][utils.wl_hash(neigh_g,
                        node_anchored=self.node_anchored)].append(neigh_g)
                if self.analyze and len(neigh) >= 3:
                    emb = self.model.emb_model(utils.batch_nx_graphs(
                        [neigh_g], anchors=[neigh[0]] if self.node_anchored
                        else None)).squeeze(0)
                    analyze_embs_cur.append(emb.detach().cpu().numpy())
            if len(new_beams) > 0:
                new_beam_sets.append(new_beams)
        self.beam_sets = new_beam_sets
        self.analyze_embs.append(analyze_embs_cur)

    def step(self):
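        """Run one round of MCTS-style pattern search: choose a seed node by
        UCT score (or sample a fresh one), grow a pattern up to max_size
        guided by UCT over WL-hashed states, then backpropagate the value
        estimate along the visited states."""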
        ps = np.array([len(g) for g in self.dataset], dtype=float)
        ps /= np.sum(ps)
        graph_dist = stats.rv_discrete(values=(np.arange(len(self.dataset)), ps))

        print("Size", self.max_size)
        print(len(self.visited_seed_nodes), "distinct seeds")
        n_simulations = self.n_trials // (self.max_pattern_size + 1 -
                                          self.min_pattern_size)
        for simulation_n in tqdm(range(n_simulations)):
            # pick seed node
            best_graph_idx, best_start_node, best_score = None, None, -float("inf")
            for cand_graph_idx, cand_start_node in self.visited_seed_nodes:
                state = cand_graph_idx, cand_start_node
                my_visit_counts = sum(self.visit_counts[state].values())
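                # UCT: mean action value plus an exploration bonus of
                # c_uct * sqrt(ln(simulations so far) / seed visit count)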
                q_score = (sum(self.cum_action_values[state].values()) /
                    (my_visit_counts or 1))
                uct_score = self.c_uct * np.sqrt(np.log(simulation_n or 1) /
                    (my_visit_counts or 1))
                node_score = q_score + uct_score
                if node_score > best_score:
                    best_score = node_score
                    best_graph_idx = cand_graph_idx
                    best_start_node = cand_start_node
            # if existing seed beats choosing a new seed
            if best_score >= self.c_uct * np.sqrt(np.log(simulation_n or 1)):
                graph_idx, start_node = best_graph_idx, best_start_node
                assert best_start_node in self.dataset[graph_idx].nodes
                graph = self.dataset[graph_idx]
            else:
                found = False
                while not found:
                    graph_idx = np.arange(len(self.dataset))[graph_dist.rvs()]
                    graph = self.dataset[graph_idx]
                    start_node = random.choice(list(graph.nodes))
                    # don't pick isolated nodes or small islands
                    if self.has_min_reachable_nodes(graph, start_node,
                        self.min_pattern_size):
                        found = True
                self.visited_seed_nodes.add((graph_idx, start_node))
            neigh = [start_node]
            frontier = list(set(graph.neighbors(start_node)) - set(neigh))
            visited = set([start_node])
            neigh_g = nx.Graph()
            neigh_g.add_node(start_node, anchor=1)
            cur_state = graph_idx, start_node
            state_list = [cur_state]
            while frontier and len(neigh) < self.max_size:
                cand_neighs, anchors = [], []
                for cand_node in frontier:
                    cand_neigh = graph.subgraph(neigh + [cand_node])
                    cand_neighs.append(cand_neigh)
                    if self.node_anchored:
                        anchors.append(neigh[0])
                cand_embs = self.model.emb_model(utils.batch_nx_graphs(
                    cand_neighs, anchors=anchors if self.node_anchored else None))
                best_v_score, best_node_score, best_node = 0, -float("inf"), None
                for cand_node, cand_emb in zip(frontier, cand_embs):
                    score, n_embs = 0, 0
                    for emb_batch in self.embs:
                        score += torch.sum(self.model.predict((
                            emb_batch.to(utils.get_device()), cand_emb))).item()
                        n_embs += len(emb_batch)
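                    # value estimate: a lower mean predicted violation across
                    # the stored neighborhood embeddings (i.e. a pattern that
                    # matches more neighborhoods) yields a higher v_score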
                    v_score = -np.log(score/n_embs + 1) + 1
                    # get wl hash of next state
                    neigh_g = graph.subgraph(neigh + [cand_node]).copy()
                    neigh_g.remove_edges_from(nx.selfloop_edges(neigh_g))
                    for v in neigh_g.nodes:
                        neigh_g.nodes[v]["anchor"] = 1 if v == neigh[0] else 0
                    next_state = utils.wl_hash(neigh_g,
                        node_anchored=self.node_anchored)
                    # compute node score
                    parent_visit_counts = sum(self.visit_counts[cur_state].values())
                    my_visit_counts = sum(self.visit_counts[next_state].values())
                    q_score = (sum(self.cum_action_values[next_state].values()) /
                        (my_visit_counts or 1))
                    uct_score = self.c_uct * np.sqrt(np.log(parent_visit_counts or
                        1) / (my_visit_counts or 1))
                    node_score = q_score + uct_score
                    if node_score > best_node_score:
                        best_node_score = node_score
                        best_v_score = v_score
                        best_node = cand_node
                frontier = list(((set(frontier) |
                    set(graph.neighbors(best_node))) - visited) -
                    set([best_node]))
                visited.add(best_node)
                neigh.append(best_node)

                # update visit counts, wl cache
                neigh_g = graph.subgraph(neigh).copy()
                neigh_g.remove_edges_from(nx.selfloop_edges(neigh_g))
                for v in neigh_g.nodes:
                    neigh_g.nodes[v]["anchor"] = 1 if v == neigh[0] else 0
                prev_state = cur_state
                cur_state = utils.wl_hash(neigh_g, node_anchored=self.node_anchored)
                state_list.append(cur_state)
                self.wl_hash_to_graphs[cur_state].append(neigh_g)

            # backprop value
            for i in range(0, len(state_list) - 1):
                self.cum_action_values[state_list[i]][
                    state_list[i+1]] += best_v_score
                self.visit_counts[state_list[i]][state_list[i+1]] += 1
        self.max_size += 1
    def gen_batch(self,
                  a,
                  b,
                  c,
                  train,
                  max_size=15,
                  min_size=5,
                  seed=None,
                  filter_negs=False,
                  sample_method="tree-pair"):
        batch_size = a
        train_set, test_set, task = self.dataset
        graphs = train_set if train else test_set
        if seed is not None:
            random.seed(seed)

        pos_a, pos_b = [], []
        pos_a_anchors, pos_b_anchors = [], []
        for i in range(batch_size // 2):
            if sample_method == "tree-pair":
                size = random.randint(min_size + 1, max_size)
                graph, a = utils.sample_neigh(graphs, size)
                b = a[:random.randint(min_size, len(a) - 1)]
            elif sample_method == "subgraph-tree":
                graph = None
                while graph is None or len(graph) < min_size + 1:
                    graph = random.choice(graphs)
                a = graph.nodes
                _, b = utils.sample_neigh([graph],
                                          random.randint(
                                              min_size,
                                              len(graph) - 1))
            if self.node_anchored:
                anchor = list(graph.nodes)[0]
                pos_a_anchors.append(anchor)
                pos_b_anchors.append(anchor)
            neigh_a, neigh_b = graph.subgraph(a), graph.subgraph(b)
            pos_a.append(neigh_a)
            pos_b.append(neigh_b)

        neg_a, neg_b = [], []
        neg_a_anchors, neg_b_anchors = [], []
        while len(neg_a) < batch_size // 2:
            if sample_method == "tree-pair":
                size = random.randint(min_size + 1, max_size)
                graph_a, a = utils.sample_neigh(graphs, size)
                graph_b, b = utils.sample_neigh(
                    graphs, random.randint(min_size, size - 1))
            elif sample_method == "subgraph-tree":
                graph_a = None
                while graph_a is None or len(graph_a) < min_size + 1:
                    graph_a = random.choice(graphs)
                a = graph_a.nodes
                graph_b, b = utils.sample_neigh(
                    graphs, random.randint(min_size,
                                           len(graph_a) - 1))
            if self.node_anchored:
                neg_a_anchors.append(list(graph_a.nodes)[0])
                neg_b_anchors.append(list(graph_b.nodes)[0])
            neigh_a, neigh_b = graph_a.subgraph(a), graph_b.subgraph(b)
            if filter_negs:
                matcher = nx.algorithms.isomorphism.GraphMatcher(
                    neigh_a, neigh_b)
                # neigh_b is subgraph-isomorphic to neigh_a, so this pair is
                # not a valid negative: resample
                if matcher.subgraph_is_isomorphic():
                    continue
            neg_a.append(neigh_a)
            neg_b.append(neigh_b)

        nx_graphs = (
            deepcopy(pos_a),
            deepcopy(pos_b),
            deepcopy(neg_a),
            deepcopy(neg_b),
            deepcopy(pos_a_anchors),
            deepcopy(pos_b_anchors),
            deepcopy(neg_a_anchors),
            deepcopy(neg_b_anchors),
        )

        pos_a = utils.batch_nx_graphs(
            pos_a, anchors=pos_a_anchors if self.node_anchored else None)
        pos_b = utils.batch_nx_graphs(
            pos_b, anchors=pos_b_anchors if self.node_anchored else None)
        neg_a = utils.batch_nx_graphs(
            neg_a, anchors=neg_a_anchors if self.node_anchored else None)
        neg_b = utils.batch_nx_graphs(
            neg_b, anchors=neg_b_anchors if self.node_anchored else None)
        return pos_a, pos_b, neg_a, neg_b, nx_graphs
    def gen_batch(self,
                  a,
                  b,
                  c,
                  train,
                  max_size=15,
                  min_size=5,
                  seed=None,
                  filter_negs=False,
                  sample_method="tree-pair",
                  one_small=False):
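        """Variant of the on-the-fly batch generator with a `one_small` mode:
        positives drop a single node from the sampled neighborhood, and
        negatives corrupt a copy of the anchor neighborhood by adding 1-2
        random typed edges so it is no longer a subgraph."""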
        batch_size = a
        train_set, test_set, task = self.dataset

        graphs = train_set if train else test_set
        if seed is not None:
            random.seed(seed)

        pos_a, pos_b = [], []
        pos_a_anchors, pos_b_anchors = [], []
        for i in range(batch_size // 2):
            if sample_method == "tree-pair":
                size = random.randint(min_size + 1, max_size)
                graph, a = utils.sample_neigh(graphs, size)

                # one_small: drop a single random node from the neighborhood
                # instead of truncating it to a random-length prefix
                if one_small:
                    np.random.shuffle(a)
                    b = a[1:]
                else:
                    b = a[:random.randint(min_size, len(a) - 1)]
            elif sample_method == "subgraph-tree":
                graph = None
                while graph is None or len(graph) < min_size + 1:
                    graph = random.choice(graphs)
                a = graph.nodes
                _, b = utils.sample_neigh([graph],
                                          random.randint(
                                              min_size,
                                              len(graph) - 1))
            if self.node_anchored:
                anchor = list(graph.nodes)[0]
                pos_a_anchors.append(anchor)
                pos_b_anchors.append(anchor)
            neigh_a, neigh_b = graph.subgraph(a), graph.subgraph(b)
            pos_a.append(neigh_a)
            pos_b.append(neigh_b)

        neg_a, neg_b = [], []
        neg_a_anchors, neg_b_anchors = [], []
        while len(neg_a) < batch_size // 2:
            if sample_method == "tree-pair":

                # hack
                if one_small:
                    size = random.randint(min_size + 1, max_size)
                    graph_a, a = utils.sample_neigh(graphs, size)
                    graph_b, b = deepcopy((graph_a, a))
                    b = b[1:]
                # hack
                else:
                    size = random.randint(min_size + 1, max_size)
                    graph_a, a = utils.sample_neigh(graphs, size)
                    graph_b, b = utils.sample_neigh(
                        graphs, random.randint(min_size, size - 1))
            elif sample_method == "subgraph-tree":
                graph_a = None
                while graph_a is None or len(graph_a) < min_size + 1:
                    graph_a = random.choice(graphs)
                a = graph_a.nodes
                graph_b, b = utils.sample_neigh(
                    graphs, random.randint(min_size,
                                           len(graph_a) - 1))
            if self.node_anchored:
                neg_a_anchors.append(list(graph_a.nodes)[0])
                neg_b_anchors.append(list(graph_b.nodes)[0])
            if one_small:
                neigh_a, neigh_b = graph_a.subgraph(a), graph_b.subgraph(b)
                # keep only the largest connected component of the query
                tmp = list(nx.connected_components(neigh_b.to_undirected()))
                tmp = sorted(tmp, key=lambda x: -len(x))
                neigh_b = neigh_b.subgraph(tmp[0])
                # corrupt the query with 1-2 random typed edges so it is no
                # longer a subgraph of neigh_a
                b_ne = list(nx.non_edges(neigh_b))
                np.random.shuffle(b_ne)
                b_ne = b_ne[:np.random.randint(1, 3)]
                neigh_b = nx.DiGraph(neigh_b)
                neigh_b.add_edges_from(b_ne)
                for i in b_ne:
                    neigh_b.edges[i]["edge_type"] = np.random.randint(0, 18)
            else:
                neigh_a, neigh_b = graph_a.subgraph(a), graph_b.subgraph(b)
            if filter_negs:
                matcher = nx.algorithms.isomorphism.GraphMatcher(
                    neigh_a, neigh_b)
                # neigh_b is subgraph-isomorphic to neigh_a, so this pair is
                # not a valid negative: resample
                if matcher.subgraph_is_isomorphic():
                    continue
            neg_a.append(neigh_a)
            neg_b.append(neigh_b)

        nx_graphs = (
            deepcopy(pos_a),
            deepcopy(pos_b),
            deepcopy(neg_a),
            deepcopy(neg_b),
            deepcopy(pos_a_anchors),
            deepcopy(pos_b_anchors),
            deepcopy(neg_a_anchors),
            deepcopy(neg_b_anchors),
        )

        pos_a = utils.batch_nx_graphs(
            pos_a, anchors=pos_a_anchors if self.node_anchored else None)
        pos_b = utils.batch_nx_graphs(
            pos_b, anchors=pos_b_anchors if self.node_anchored else None)
        neg_a = utils.batch_nx_graphs(
            neg_a, anchors=neg_a_anchors if self.node_anchored else None)
        neg_b = utils.batch_nx_graphs(
            neg_b, anchors=neg_b_anchors if self.node_anchored else None)
        return pos_a, pos_b, neg_a, neg_b, nx_graphs
def pattern_growth(dataset, task, args):
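    """Decode frequent patterns from the dataset: load a trained matching
    model, embed sampled node neighborhoods, then run the configured search
    agent (greedy or MCTS) and save/plot the resulting patterns."""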
    # init model
    if args.method_type == "end2end":
        model = models.End2EndOrder(1, args.hidden_dim, args)
    elif args.method_type == "mlp":
        model = models.BaselineMLP(1, args.hidden_dim, args)
    else:
        model = models.OrderEmbedder(1, args.hidden_dim, args)
    model.to(utils.get_device())
    model.eval()
    model.load_state_dict(
        torch.load(args.model_path, map_location=utils.get_device()))

    if task == "graph-labeled":
        dataset, labels = dataset

    # load data
    neighs_pyg, neighs = [], []
    print(len(dataset), "graphs")
    print("search strategy:", args.search_strategy)
    if task == "graph-labeled": print("using label 0")
    graphs = []
    for i, graph in enumerate(dataset):
        if task == "graph-labeled" and labels[i] != 0: continue
        if task == "graph-truncate" and i >= 1000: break
        if not isinstance(graph, nx.Graph):
            graph = pyg_utils.to_networkx(graph).to_undirected()
        graphs.append(graph)
    start_time = time.time()
    if args.use_whole_graphs:
        neighs = graphs
    else:
        anchors = []
        if args.sample_method == "radial":
            for i, graph in enumerate(graphs):
                print(i)
                for j, node in enumerate(graph.nodes):
                    if len(dataset) <= 10 and j % 100 == 0: print(i, j)
                    if args.use_whole_graphs:
                        neigh = graph.nodes
                    else:
                        neigh = list(
                            nx.single_source_shortest_path_length(
                                graph, node, cutoff=args.radius).keys())
                        if args.subgraph_sample_size != 0:
                            neigh = random.sample(
                                neigh,
                                min(len(neigh), args.subgraph_sample_size))
                    if len(neigh) > 1:
                        neigh = graph.subgraph(neigh)
                        if args.subgraph_sample_size != 0:
                            neigh = neigh.subgraph(
                                max(nx.connected_components(neigh), key=len))
                        neigh = nx.convert_node_labels_to_integers(neigh)
                        neigh.add_edge(0, 0)
                        neighs.append(neigh)
        elif args.sample_method == "tree":
            start_time = time.time()
            for j in tqdm(range(args.n_neighborhoods)):
                graph, neigh = utils.sample_neigh(
                    graphs,
                    random.randint(args.min_neighborhood_size,
                                   args.max_neighborhood_size))
                neigh = graph.subgraph(neigh)
                neigh = nx.convert_node_labels_to_integers(neigh)
                neigh.add_edge(0, 0)
                neighs.append(neigh)
                if args.node_anchored:
                    anchors.append(
                        0)  # after converting labels, 0 will be anchor

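    # embed all sampled neighborhoods in batches; the search agents score
    # candidate patterns against these stored embeddings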
    embs = []
    if len(neighs) % args.batch_size != 0:
        print("WARNING: number of graphs not multiple of batch size")
    for i in range(len(neighs) // args.batch_size):
        # any tail shorter than a full batch is dropped (warned above)
        top = (i + 1) * args.batch_size
        with torch.no_grad():
            batch = utils.batch_nx_graphs(
                neighs[i * args.batch_size:top],
                anchors=anchors if args.node_anchored else None)
            emb = model.emb_model(batch)
            emb = emb.to(torch.device("cpu"))

        embs.append(emb)

    if args.analyze:
        # concatenate per-batch embeddings into a single [n_neighs, dim] array
        embs_np = torch.cat(embs, dim=0).numpy()
        plt.scatter(embs_np[:, 0], embs_np[:, 1], label="node neighborhood")

    if args.search_strategy == "mcts":
        assert args.method_type == "order"
        agent = MCTSSearchAgent(args.min_pattern_size,
                                args.max_pattern_size,
                                model,
                                graphs,
                                embs,
                                node_anchored=args.node_anchored,
                                analyze=args.analyze,
                                out_batch_size=args.out_batch_size)
    elif args.search_strategy == "greedy":
        agent = GreedySearchAgent(args.min_pattern_size,
                                  args.max_pattern_size,
                                  model,
                                  graphs,
                                  embs,
                                  node_anchored=args.node_anchored,
                                  analyze=args.analyze,
                                  model_type=args.method_type,
                                  out_batch_size=args.out_batch_size)
    else:
        raise ValueError("unknown search strategy: " + args.search_strategy)
    out_graphs = agent.run_search(args.n_trials)
    print(time.time() - start_time, "TOTAL TIME")
    x = int(time.time() - start_time)
    print(x // 60, "mins", x % 60, "secs")

    # visualize output patterns
    if not os.path.exists("plots/cluster"):
        os.makedirs("plots/cluster")
    count_by_size = defaultdict(int)
    for pattern in out_graphs:
        if args.node_anchored:
            colors = ["red"] + ["blue"] * (len(pattern) - 1)
            nx.draw(pattern, node_color=colors, with_labels=True)
        else:
            nx.draw(pattern)
        print("Saving plots/cluster/{}-{}.png".format(
            len(pattern), count_by_size[len(pattern)]))
        plt.savefig("plots/cluster/{}-{}.png".format(
            len(pattern), count_by_size[len(pattern)]))
        plt.savefig("plots/cluster/{}-{}.pdf".format(
            len(pattern), count_by_size[len(pattern)]))
        plt.close()
        count_by_size[len(pattern)] += 1

    if not os.path.exists("results"):
        os.makedirs("results")
    with open(args.out_path, "wb") as f:
        pickle.dump(out_graphs, f)
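
# Usage sketch (illustrative; every field value below is an assumption, not a
# repo default): pattern_growth expects an argparse-style namespace carrying
# the attributes it reads above.
from types import SimpleNamespace

def _decode_demo(dataset):
    args = SimpleNamespace(
        method_type="order", hidden_dim=64,
        model_path="ckpt/model.pt",  # hypothetical checkpoint path
        search_strategy="greedy", use_whole_graphs=False,
        sample_method="tree", radius=3, subgraph_sample_size=0,
        n_neighborhoods=64, min_neighborhood_size=5,
        max_neighborhood_size=15, node_anchored=True, batch_size=64,
        analyze=False, min_pattern_size=5, max_pattern_size=10,
        out_batch_size=10, n_trials=100, out_path="results/patterns.pkl")
    pattern_growth(dataset, "graph", args)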