def gen_data_loaders(self,
                     size,
                     batch_size,
                     train=True,
                     use_distributed_sampling=False):
    """Build two neighborhood loaders plus a placeholder third entry.

    Each of the first two loaders wraps ``size // 2`` neighborhood subgraphs
    sampled with ``utils.sample_neigh``. Each neighborhood size is drawn
    uniformly from ``[self.min_size, self.max_size]``.

    Args:
        size: total neighborhood budget; each of the two loaders gets
            ``size // 2`` neighborhoods.
        batch_size: combined batch size; each loader yields batches of
            ``batch_size // 2``.
        train: sample from ``self.train_set`` if True, else ``self.test_set``.
        use_distributed_sampling: accepted for interface compatibility; not
            used (no sampler is configured).

    Returns:
        ``[loader_0, loader_1, placeholders]``, where ``placeholders`` is a
        list of ``size // batch_size`` ``None`` entries.
    """
    source = self.train_set if train else self.test_set
    half_batch = batch_size // 2
    loaders = []
    for _ in range(2):
        neighs = []
        for _ in range(size // 2):
            graph, neigh = utils.sample_neigh(
                source, random.randint(self.min_size, self.max_size))
            neighs.append(graph.subgraph(neigh))
        dataset = GraphDataset(GraphDataset.list_to_graphs(neighs))
        loaders.append(
            TorchDataLoader(dataset,
                            collate_fn=Batch.collate([]),
                            batch_size=half_batch,
                            sampler=None,
                            shuffle=False))
    # One None per full batch; presumably a placeholder so callers can
    # iterate the three entries in lockstep -- TODO confirm with callers.
    loaders.append([None] * (size // batch_size))
    return loaders
    def gen_batch(self,
                  a,
                  b,
                  c,
                  train,
                  max_size=15,
                  min_size=5,
                  seed=None,
                  filter_negs=False,
                  sample_method="tree-pair"):
        """Sample a batch of positive and negative (target, query) graph pairs.

        Args:
            a: batch size; ``a // 2`` positive and ``a // 2`` negative pairs
                are produced.
            b, c: unused; kept for interface compatibility.
            train: draw from the train split if True, else the test split.
            max_size, min_size: bounds on the sampled neighborhood sizes.
            seed: if not None, reseeds the global ``random`` module first.
            filter_negs: if True, reject negative pairs whose query is
                isomorphic to a node-induced subgraph of the target.
            sample_method: "tree-pair" or "subgraph-tree".

        Returns:
            ``(pos_a, pos_b, neg_a, neg_b, nx_graphs)``. The first four are
            the outputs of ``utils.batch_nx_graphs``. ``nx_graphs`` holds deep
            copies of the raw networkx graph lists and the four anchor lists.

        Raises:
            ValueError: if ``sample_method`` is not recognised.
        """
        if sample_method not in ("tree-pair", "subgraph-tree"):
            raise ValueError(
                "unknown sample_method: {!r}".format(sample_method))
        batch_size = a
        train_set, test_set, task = self.dataset
        graphs = train_set if train else test_set
        if seed is not None:
            random.seed(seed)

        # Positive pairs: target and query are induced subgraphs of the same
        # source graph. For "tree-pair", the query's nodes are a prefix of
        # the target's node list.
        pos_a, pos_b = [], []
        pos_a_anchors, pos_b_anchors = [], []
        for _ in range(batch_size // 2):
            if sample_method == "tree-pair":
                size = random.randint(min_size + 1, max_size)
                graph, nodes_a = utils.sample_neigh(graphs, size)
                nodes_b = nodes_a[:random.randint(min_size, len(nodes_a) - 1)]
            else:  # "subgraph-tree"
                graph = None
                # NOTE(review): never terminates if no graph has at least
                # min_size + 1 nodes.
                while graph is None or len(graph) < min_size + 1:
                    graph = random.choice(graphs)
                nodes_a = graph.nodes
                _, nodes_b = utils.sample_neigh(
                    [graph], random.randint(min_size, len(graph) - 1))
            if self.node_anchored:
                # NOTE(review): the anchor is the first node of the whole
                # source graph, which need not be in nodes_a/nodes_b --
                # confirm this is intended.
                anchor = list(graph.nodes)[0]
                pos_a_anchors.append(anchor)
                pos_b_anchors.append(anchor)
            pos_a.append(graph.subgraph(nodes_a))
            pos_b.append(graph.subgraph(nodes_b))

        # Negative pairs: target and query are sampled independently.
        neg_a, neg_b = [], []
        neg_a_anchors, neg_b_anchors = [], []
        while len(neg_a) < batch_size // 2:
            if sample_method == "tree-pair":
                size = random.randint(min_size + 1, max_size)
                graph_a, nodes_a = utils.sample_neigh(graphs, size)
                graph_b, nodes_b = utils.sample_neigh(
                    graphs, random.randint(min_size, size - 1))
            else:  # "subgraph-tree"
                graph_a = None
                while graph_a is None or len(graph_a) < min_size + 1:
                    graph_a = random.choice(graphs)
                nodes_a = graph_a.nodes
                graph_b, nodes_b = utils.sample_neigh(
                    graphs, random.randint(min_size, len(graph_a) - 1))
            neigh_a = graph_a.subgraph(nodes_a)
            neigh_b = graph_b.subgraph(nodes_b)
            if filter_negs:
                # Skip the pair when neigh_b is isomorphic to a node-induced
                # subgraph of neigh_a, i.e. it is not really a negative.
                matcher = nx.algorithms.isomorphism.GraphMatcher(
                    neigh_a, neigh_b)
                if matcher.subgraph_is_isomorphic():
                    continue
            # Anchors are recorded only for accepted pairs so that they stay
            # index-aligned with neg_a / neg_b.
            if self.node_anchored:
                neg_a_anchors.append(list(graph_a.nodes)[0])
                neg_b_anchors.append(list(graph_b.nodes)[0])
            neg_a.append(neigh_a)
            neg_b.append(neigh_b)

        nx_graphs = (
            deepcopy(pos_a),
            deepcopy(pos_b),
            deepcopy(neg_a),
            deepcopy(neg_b),
            deepcopy(pos_a_anchors),
            deepcopy(pos_b_anchors),
            deepcopy(neg_a_anchors),
            deepcopy(neg_b_anchors),
        )

        pos_a = utils.batch_nx_graphs(
            pos_a, anchors=pos_a_anchors if self.node_anchored else None)
        pos_b = utils.batch_nx_graphs(
            pos_b, anchors=pos_b_anchors if self.node_anchored else None)
        neg_a = utils.batch_nx_graphs(
            neg_a, anchors=neg_a_anchors if self.node_anchored else None)
        neg_b = utils.batch_nx_graphs(
            neg_b, anchors=neg_b_anchors if self.node_anchored else None)
        return pos_a, pos_b, neg_a, neg_b, nx_graphs


if __name__ == "__main__":
    import os

    import matplotlib.pyplot as plt

    # Scatter-plot the average clustering coefficient against the average
    # shortest-path length of randomly sampled neighborhoods, per dataset.
    neigh_size = 11
    n_samples = 10000
    plt.rcParams.update({"font.size": 14})
    for name in ["enzymes", "reddit-binary", "cox2"]:
        data_source = DiskDataSource(name)
        train, test, _ = data_source.dataset
        neighs = [
            utils.sample_neigh(train, neigh_size) for _ in range(n_samples)
        ]
        clustering = [
            nx.average_clustering(graph.subgraph(nodes))
            for graph, nodes in neighs
        ]
        path_length = [
            nx.average_shortest_path_length(graph.subgraph(nodes))
            for graph, nodes in neighs
        ]
        plt.scatter(clustering, path_length, s=10, label=name)
    plt.legend()
    # savefig does not create missing directories.
    os.makedirs("plots", exist_ok=True)
    plt.savefig("plots/clustering-vs-path-length.png")
    def gen_batch(self,
                  a,
                  b,
                  c,
                  train,
                  max_size=15,
                  min_size=5,
                  seed=None,
                  filter_negs=False,
                  sample_method="tree-pair",
                  one_small=False):
        """Sample a batch of positive and negative (target, query) graph pairs.

        Args:
            a: batch size; ``a // 2`` positive and ``a // 2`` negative pairs
                are produced.
            b, c: unused; kept for interface compatibility.
            train: draw from the train split if True, else the test split.
            max_size, min_size: bounds on the sampled neighborhood sizes.
            seed: if not None, reseeds the global ``random`` module before
                sampling. When ``one_small`` is set, numpy's global RNG is
                also reseeded, using a value derived from ``seed``.
            filter_negs: if True, reject negative pairs whose query is
                isomorphic to a node-induced subgraph of the target.
            sample_method: "tree-pair" or "subgraph-tree".
            one_small: experimental ("hack") mode.
                - "tree-pair" positives: the query is the target minus one
                  randomly chosen node.
                - "tree-pair" negatives: the query starts from the target's
                  nodes minus the first one.
                - Every negative query is then cut down to its largest
                  connected component and given 1-2 random extra edges.

        Returns:
            ``(pos_a, pos_b, neg_a, neg_b, nx_graphs)``. The first four are
            the outputs of ``utils.batch_nx_graphs``. ``nx_graphs`` holds deep
            copies of the raw networkx graph lists and the four anchor lists.

        Raises:
            ValueError: if ``sample_method`` is not recognised.
        """
        if sample_method not in ("tree-pair", "subgraph-tree"):
            raise ValueError(
                "unknown sample_method: {!r}".format(sample_method))
        batch_size = a
        train_set, test_set, task = self.dataset

        graphs = train_set if train else test_set
        if seed is not None:
            random.seed(seed)
            if one_small:
                # The one_small perturbations draw from numpy's global RNG,
                # which random.seed does not touch. Derive a numpy seed from
                # `seed` with a private Random so the global `random` stream
                # is left exactly as random.seed(seed) set it.
                np.random.seed(random.Random(seed).randrange(2 ** 32))

        pos_a, pos_b = [], []
        pos_a_anchors, pos_b_anchors = [], []
        for _ in range(batch_size // 2):
            if sample_method == "tree-pair":
                size = random.randint(min_size + 1, max_size)
                graph, nodes_a = utils.sample_neigh(graphs, size)
                if one_small:
                    # Hack: query = target minus one randomly chosen node.
                    np.random.shuffle(nodes_a)
                    nodes_b = nodes_a[1:]
                else:
                    nodes_b = nodes_a[:random.randint(min_size,
                                                      len(nodes_a) - 1)]
            else:  # "subgraph-tree"
                graph = None
                # NOTE(review): never terminates if no graph has at least
                # min_size + 1 nodes.
                while graph is None or len(graph) < min_size + 1:
                    graph = random.choice(graphs)
                nodes_a = graph.nodes
                _, nodes_b = utils.sample_neigh(
                    [graph], random.randint(min_size, len(graph) - 1))
            if self.node_anchored:
                # NOTE(review): the anchor is the first node of the whole
                # source graph, which need not be in nodes_a/nodes_b.
                anchor = list(graph.nodes)[0]
                pos_a_anchors.append(anchor)
                pos_b_anchors.append(anchor)
            pos_a.append(graph.subgraph(nodes_a))
            pos_b.append(graph.subgraph(nodes_b))

        neg_a, neg_b = [], []
        neg_a_anchors, neg_b_anchors = [], []
        while len(neg_a) < batch_size // 2:
            if sample_method == "tree-pair":
                size = random.randint(min_size + 1, max_size)
                graph_a, nodes_a = utils.sample_neigh(graphs, size)
                if one_small:
                    # Hack: the query starts from the target's nodes minus
                    # the first one.
                    # The source graph is shared, not deep-copied: it is only
                    # read here. The query is rebuilt as a new DiGraph below
                    # before it is modified.
                    graph_b, nodes_b = graph_a, nodes_a[1:]
                else:
                    graph_b, nodes_b = utils.sample_neigh(
                        graphs, random.randint(min_size, size - 1))
            else:  # "subgraph-tree"
                graph_a = None
                while graph_a is None or len(graph_a) < min_size + 1:
                    graph_a = random.choice(graphs)
                nodes_a = graph_a.nodes
                graph_b, nodes_b = utils.sample_neigh(
                    graphs, random.randint(min_size, len(graph_a) - 1))
            neigh_a = graph_a.subgraph(nodes_a)
            neigh_b = graph_b.subgraph(nodes_b)
            if one_small:
                # Keep the query's largest connected component (ignoring
                # edge direction).
                components = nx.connected_components(neigh_b.to_undirected())
                neigh_b = neigh_b.subgraph(max(components, key=len))
                # Add 1-2 random missing edges with a random 'edge_type' in
                # [0, 18). This is presumably meant to make the query
                # non-contained in the target.
                non_edges = list(nx.non_edges(neigh_b))
                np.random.shuffle(non_edges)
                non_edges = non_edges[:np.random.randint(1, 3)]
                neigh_b = nx.DiGraph(neigh_b)
                neigh_b.add_edges_from(non_edges)
                for edge in non_edges:
                    neigh_b.edges[edge]['edge_type'] = np.random.randint(0, 18)
            if filter_negs:
                # Skip the pair when neigh_b is isomorphic to a node-induced
                # subgraph of neigh_a, i.e. it is not really a negative.
                matcher = nx.algorithms.isomorphism.GraphMatcher(
                    neigh_a, neigh_b)
                if matcher.subgraph_is_isomorphic():
                    continue
            # Anchors are recorded only for accepted pairs so that they stay
            # index-aligned with neg_a / neg_b.
            if self.node_anchored:
                neg_a_anchors.append(list(graph_a.nodes)[0])
                neg_b_anchors.append(list(graph_b.nodes)[0])
            neg_a.append(neigh_a)
            neg_b.append(neigh_b)

        nx_graphs = (
            deepcopy(pos_a),
            deepcopy(pos_b),
            deepcopy(neg_a),
            deepcopy(neg_b),
            deepcopy(pos_a_anchors),
            deepcopy(pos_b_anchors),
            deepcopy(neg_a_anchors),
            deepcopy(neg_b_anchors),
        )

        pos_a = utils.batch_nx_graphs(
            pos_a, anchors=pos_a_anchors if self.node_anchored else None)
        pos_b = utils.batch_nx_graphs(
            pos_b, anchors=pos_b_anchors if self.node_anchored else None)
        neg_a = utils.batch_nx_graphs(
            neg_a, anchors=neg_a_anchors if self.node_anchored else None)
        neg_b = utils.batch_nx_graphs(
            neg_b, anchors=neg_b_anchors if self.node_anchored else None)
        return pos_a, pos_b, neg_a, neg_b, nx_graphs
def pattern_growth(dataset, task, args):
    """Embed sampled neighborhoods with a trained model and search patterns.

    Steps:
        1. Load the model from ``args.model_path``.
        2. Convert ``dataset`` entries to undirected networkx graphs.
        3. Sample neighborhoods: whole graphs, "radial", or "tree".
        4. Embed them in batches of ``args.batch_size``.
        5. Run the MCTS or greedy search agent.
        6. Save pattern plots under ``plots/cluster/`` and pickle the
           patterns to ``args.out_path``.

    Args:
        dataset: iterable of graphs. Entries that are not ``nx.Graph`` go
            through ``pyg_utils.to_networkx``. For ``task == "graph-labeled"``
            this is a ``(graphs, labels)`` pair.
        task: "graph-labeled" keeps only graphs with label 0;
            "graph-truncate" keeps only the first 1000 graphs.
        args: parsed options (model, sampling, search and output settings).

    Raises:
        ValueError: if ``args.search_strategy`` is neither "mcts" nor
            "greedy".
    """
    # init model
    if args.method_type == "end2end":
        model = models.End2EndOrder(1, args.hidden_dim, args)
    elif args.method_type == "mlp":
        model = models.BaselineMLP(1, args.hidden_dim, args)
    else:
        model = models.OrderEmbedder(1, args.hidden_dim, args)
    model.to(utils.get_device())
    model.eval()
    model.load_state_dict(
        torch.load(args.model_path, map_location=utils.get_device()))

    if task == "graph-labeled":
        dataset, labels = dataset

    # load data
    print(len(dataset), "graphs")
    print("search strategy:", args.search_strategy)
    if task == "graph-labeled":
        print("using label 0")
    graphs = []
    for i, graph in enumerate(dataset):
        if task == "graph-labeled" and labels[i] != 0:
            continue
        if task == "graph-truncate" and i >= 1000:
            break
        if type(graph) is not nx.Graph:
            graph = pyg_utils.to_networkx(graph).to_undirected()
        graphs.append(graph)

    # Defined before branching so every sampling mode has a start time for
    # the final timing report and an anchor list for batching.
    start_time = time.time()
    neighs = []
    anchors = []
    if args.use_whole_graphs:
        neighs = graphs
    elif args.sample_method == "radial":
        for i, graph in enumerate(graphs):
            print(i)
            for j, node in enumerate(graph.nodes):
                if len(dataset) <= 10 and j % 100 == 0:
                    print(i, j)
                neigh = list(
                    nx.single_source_shortest_path_length(
                        graph, node, cutoff=args.radius).keys())
                if args.subgraph_sample_size != 0:
                    neigh = random.sample(
                        neigh, min(len(neigh), args.subgraph_sample_size))
                if len(neigh) > 1:
                    neigh = graph.subgraph(neigh)
                    if args.subgraph_sample_size != 0:
                        # Random node sampling can disconnect the
                        # neighborhood; keep its largest connected component.
                        neigh = neigh.subgraph(
                            max(nx.connected_components(neigh), key=len))
                    neigh = nx.convert_node_labels_to_integers(neigh)
                    neigh.add_edge(0, 0)
                    neighs.append(neigh)
    elif args.sample_method == "tree":
        for _ in tqdm(range(args.n_neighborhoods)):
            graph, neigh = utils.sample_neigh(
                graphs,
                random.randint(args.min_neighborhood_size,
                               args.max_neighborhood_size))
            neigh = graph.subgraph(neigh)
            neigh = nx.convert_node_labels_to_integers(neigh)
            neigh.add_edge(0, 0)
            neighs.append(neigh)
            if args.node_anchored:
                # Node 0 of the relabelled neighborhood is used as anchor.
                anchors.append(0)

    embs = []
    batch_size = args.batch_size
    if len(neighs) % batch_size != 0:
        print("WARNING: number of graphs not multiple of batch size")
    for i in range(len(neighs) // batch_size):
        lo, hi = i * batch_size, (i + 1) * batch_size
        with torch.no_grad():
            batch = utils.batch_nx_graphs(
                neighs[lo:hi],
                anchors=anchors[lo:hi] if args.node_anchored else None)
            emb = model.emb_model(batch)
            emb = emb.to(torch.device("cpu"))
        embs.append(emb)

    if args.analyze:
        # One row per neighborhood: concatenate the per-batch embeddings
        # (assumed to be (batch, dim) -- TODO confirm) and plot the first two
        # embedding dimensions against each other.
        embs_np = torch.cat(embs).numpy()
        plt.scatter(embs_np[:, 0], embs_np[:, 1], label="node neighborhood")

    if args.search_strategy == "mcts":
        assert args.method_type == "order"
        agent = MCTSSearchAgent(args.min_pattern_size,
                                args.max_pattern_size,
                                model,
                                graphs,
                                embs,
                                node_anchored=args.node_anchored,
                                analyze=args.analyze,
                                out_batch_size=args.out_batch_size)
    elif args.search_strategy == "greedy":
        agent = GreedySearchAgent(args.min_pattern_size,
                                  args.max_pattern_size,
                                  model,
                                  graphs,
                                  embs,
                                  node_anchored=args.node_anchored,
                                  analyze=args.analyze,
                                  model_type=args.method_type,
                                  out_batch_size=args.out_batch_size)
    else:
        raise ValueError(
            "unknown search strategy: {!r}".format(args.search_strategy))
    out_graphs = agent.run_search(args.n_trials)
    elapsed = time.time() - start_time
    print(elapsed, "TOTAL TIME")
    x = int(elapsed)
    print(x // 60, "mins", x % 60, "secs")

    # visualize out patterns; savefig does not create missing directories.
    os.makedirs("plots/cluster", exist_ok=True)
    count_by_size = defaultdict(int)
    for pattern in out_graphs:
        if args.node_anchored:
            colors = ["red"] + ["blue"] * (len(pattern) - 1)
            nx.draw(pattern, node_color=colors, with_labels=True)
        else:
            nx.draw(pattern)
        stem = "plots/cluster/{}-{}".format(len(pattern),
                                            count_by_size[len(pattern)])
        print("Saving {}.png".format(stem))
        plt.savefig(stem + ".png")
        plt.savefig(stem + ".pdf")
        plt.close()
        count_by_size[len(pattern)] += 1

    os.makedirs("results", exist_ok=True)
    with open(args.out_path, "wb") as f:
        pickle.dump(out_graphs, f)