Example #1
0
def batch_nx_graphs(graphs, anchors=None):
    """Convert networkx graphs into one augmented Batch on the active device.

    The input graphs are modified in place before batching:

    * if ``anchors`` is given, it is paired with ``graphs`` via ``zip`` and,
      for each pair, every node's "node_feature" attribute is overwritten
      with a 1-element float tensor that is 1.0 at the anchor node and 0.0
      elsewhere;
    * every edge gets an "edge_feature" attribute holding its "edge_type"
      value as a 1-element ``torch.long`` tensor (an index, not one-hot).

    Args:
        graphs: sequence of networkx graphs (mutated in place, see above).
        anchors: optional sequence of anchor nodes, one per graph.

    Returns:
        ``Batch.from_data_list(GraphDataset.list_to_graphs(graphs))`` passed
        through ``feature_preprocess.FeatureAugment().augment`` and moved to
        ``get_device()``.

    Raises:
        KeyError: if any edge lacks an "edge_type" attribute.
    """
    augmenter = feature_preprocess.FeatureAugment()

    if anchors is not None:
        for anchor, g in zip(anchors, graphs):
            for v in g.nodes:
                g.nodes[v]["node_feature"] = torch.tensor([float(v == anchor)])

    # Edge types are always encoded as an integer index tensor. This used to
    # sit behind a dataset-name check ('aifb' == 'aifb' or 'wn18' == 'wn18')
    # that was a tautology, so the encoding is applied unconditionally.
    for g in graphs:
        for e in g.edges:
            g.edges[e]["edge_feature"] = torch.tensor(
                [g.edges[e]["edge_type"]], dtype=torch.long)

    batch = Batch.from_data_list(GraphDataset.list_to_graphs(graphs))
    batch = augmenter.augment(batch)
    batch = batch.to(get_device())
    return batch
Example #2
0
    def test_unbatch_nested(self):
        """Round-trip a batch whose graphs carry dict-valued node properties.

        Two graphs get a ``node_property`` dict of tensors with 10 and 5
        rows (widths 2 and 3). After ``Batch.from_data_list`` followed by
        ``to_data_list``, each reconstructed graph must report the row and
        column counts it was built with.
        """
        feature_widths = [2, 3]
        row_counts = [10, 5]
        originals = []
        for idx, rows in enumerate(row_counts):
            graph = Graph()
            graph.G = nx.complete_graph(idx + 1)
            graph.node_property = {
                "node_prop0": torch.ones(rows, feature_widths[0]) * idx,
                "node_prop1": torch.ones(rows, feature_widths[1]) * idx,
            }
            originals.append(graph)

        batch = Batch.from_data_list(originals)

        # Split the batch back into per-graph objects.
        rebuilt = batch.to_data_list()

        first_prop0 = rebuilt[0].node_property["node_prop0"]
        self.assertEqual(first_prop0.size(0), 10)
        self.assertEqual(first_prop0.size(1), 2)

        second_prop1 = rebuilt[1].node_property["node_prop1"]
        self.assertEqual(second_prop1.size(0), 5)
        self.assertEqual(second_prop1.size(1), 3)
Example #3
0
 def test_batch_basic(self):
     """Batch two copies of the simple test graph and check the result.

     The graph returned by ``simple_networkx_graph`` is annotated with
     edge-, node- and graph-level features and labels, deep-copied, and
     both copies are converted and collated. The batch must report two
     graphs, and its ``node_feature`` must be twice as long as that of a
     single converted graph.
     """
     (G, x, y, edge_x, edge_y, _edge_index,
      graph_x, graph_y) = simple_networkx_graph()
     for attr_name, values in (("edge_feature", edge_x),
                               ("edge_label", edge_y)):
         Graph.add_edge_attr(G, attr_name, values)
     for attr_name, values in (("node_feature", x), ("node_label", y)):
         Graph.add_node_attr(G, attr_name, values)
     for attr_name, values in (("graph_feature", graph_x),
                               ("graph_label", graph_y)):
         Graph.add_graph_attr(G, attr_name, values)
     duplicate = deepcopy(G)
     converted = GraphDataset.list_to_graphs([G, duplicate])
     batch = Batch.from_data_list(converted)
     self.assertEqual(batch.num_graphs, 2)
     expected_length = 2 * len(converted[0].node_feature)
     self.assertEqual(len(batch.node_feature), expected_length)
Example #4
0
def batch_nx_graphs(graphs, anchors=None):
    """Wrap networkx graphs as DSGraphs and collate them into one Batch.

    When ``anchors`` is provided it is zipped with ``graphs``. For each
    (anchor, graph) pair, every node's "node_feature" attribute on the
    networkx graph is replaced in place by a 1-element float tensor: 1.0 at
    the anchor node, 0.0 elsewhere.

    Returns:
        The collated batch after ``feature_preprocess.FeatureAugment().augment``,
        moved to ``get_device()``.
    """
    augmenter = feature_preprocess.FeatureAugment()

    if anchors is not None:
        for anchor_node, nx_graph in zip(anchors, graphs):
            node_view = nx_graph.nodes
            for node in node_view:
                indicator = float(node == anchor_node)
                node_view[node]["node_feature"] = torch.tensor([indicator])

    wrapped = [DSGraph(nx_graph) for nx_graph in graphs]
    return augmenter.augment(Batch.from_data_list(wrapped)).to(get_device())
Example #5
0
    def test_collate_batch_nested(self):
        """Collating graphs with dict-valued node properties stacks their rows.

        Two graphs get ``node_property`` tensors with 10 and 5 rows. The
        resulting batch must contain two graphs, and its "node_prop0" tensor
        must have 10 + 5 rows.
        """
        widths = [2, 3]
        row_counts = [10, 5]

        def build(idx, rows):
            # One graph whose property tensors are filled with its index.
            graph = Graph()
            graph.G = nx.complete_graph(idx + 1)
            graph.node_property = {
                'node_prop0': torch.ones(rows, widths[0]) * idx,
                'node_prop1': torch.ones(rows, widths[1]) * idx
            }
            return graph

        graphs = [build(idx, rows) for idx, rows in enumerate(row_counts)]
        batch = Batch.from_data_list(graphs)

        self.assertEqual(batch.num_graphs, 2)
        stacked_rows = batch.node_property['node_prop0'].size(0)
        self.assertEqual(stacked_rows, sum(row_counts))
    def gen_batch(self, batch_target, batch_neg_target, batch_neg_query,
                  train):
        """Sample positive and negative (target, query) batches.

        Positive queries are random connected subgraphs sampled from the
        graphs of ``batch_target``. Negative queries come from a new batch
        built from ``batch_neg_target``:

        * for a random half (rounded down) of the indices, the "hard
          negatives", the target graph itself is reused;
        * for the rest, ``self.generator.generate(size=len(g))`` is used.

        When ``train`` is true, each hard-negative query gets between 1 and
        5 edges that were absent from the sampled subgraph (fewer if it has
        fewer non-edges, none if it is complete).

        Args:
            batch_target: batch whose graphs are the positive targets.
            batch_neg_target: batch whose graphs are the negative targets.
            batch_neg_query: NOTE(review): this argument is reassigned below
                before it is ever read, so the value passed in is ignored.
            train: training-mode flag. It shifts the subgraph-size range and
                enables the hard-negative edge perturbation.

        Returns:
            Tuple ``(pos_target, pos_query, neg_target, neg_query)``. Each
            element is passed through ``feature_preprocess.FeatureAugment().augment``
            and moved to ``utils.get_device()``.
        """
        def sample_subgraph(graph,
                            offset=0,
                            use_precomp_sizes=False,
                            filter_negs=False,
                            supersample_small_graphs=False,
                            neg_target=None,
                            hard_neg_idxs=None):
            """Grow a random connected node set in ``graph``.

            Returns ``(graph, DSGraph(sub))``, where ``sub`` is the subgraph
            of ``graph.G`` induced on the sampled nodes. For hard negatives
            in training, ``sub`` is a copy with extra edges added.

            Side effect: when ``self.node_anchored`` is set, every node of
            ``graph.G`` (the whole graph, not just the sample) gets a
            1-element "node_feature" tensor. The tensor is 1 at the first
            sampled node (the anchor) and 0 elsewhere.

            Args:
                graph: presumably a wrapper exposing a networkx graph as
                    ``graph.G`` (uses ``.graph``, ``.nodes``, ``.neighbors``,
                    ``.subgraph``).
                offset: added to both bounds of the subgraph-size range.
                use_precomp_sizes: if true, the size is read from
                    ``graph.G.graph["subgraph_size"]``.
                filter_negs: if true (together with ``train`` and
                    ``neg_target``), subgraphs of at most 6 nodes are
                    resampled while they are subgraph-isomorphic to
                    ``neg_target[graph.G.graph["idx"]]``.
                supersample_small_graphs: in training, sample the size with
                    weights ``(size - self.min_size + 2) ** -1.1``, which
                    favours small sizes, instead of uniformly.
                neg_target: indexable collection of graphs used only by
                    ``filter_negs``; presumably networkx graphs, since they
                    are passed to ``GraphMatcher``.
                hard_neg_idxs: set of ``graph.G.graph["idx"]`` values whose
                    samples get extra random edges when ``train`` is true.
            """
            # graph_idx is bound only when neg_target is given; it is read
            # only in the filter branch, which also requires neg_target.
            if neg_target is not None: graph_idx = graph.G.graph["idx"]
            use_hard_neg = (hard_neg_idxs is not None
                            and graph.G.graph["idx"] in hard_neg_idxs)
            done = False
            n_tries = 0  # NOTE(review): never incremented or read.
            while not done:
                if use_precomp_sizes:
                    size = graph.G.graph["subgraph_size"]
                else:
                    if train and supersample_small_graphs:
                        # Candidate sizes min_size+offset .. len(G)+offset-1,
                        # weighted by a decreasing power law, then normalized.
                        sizes = np.arange(self.min_size + offset,
                                          len(graph.G) + offset)
                        ps = (sizes - self.min_size + 2)**(-1.1)
                        ps /= ps.sum()
                        size = stats.rv_discrete(values=(sizes, ps)).rvs()
                    else:
                        # random.randint is inclusive at both ends. Training
                        # lowers the minimum size by one.
                        # NOTE(review): randint raises ValueError if the
                        # graph is too small for this range to be non-empty.
                        d = 1 if train else 0
                        size = random.randint(self.min_size + offset - d,
                                              len(graph.G) - 1 + offset)
                # Random connected growth: start at a random node, then
                # repeatedly add a random node adjacent to the chosen set.
                start_node = random.choice(list(graph.G.nodes))
                neigh = [start_node]
                frontier = list(
                    set(graph.G.neighbors(start_node)) - set(neigh))
                visited = set([start_node])
                while len(neigh) < size:
                    # NOTE(review): random.choice raises IndexError if the
                    # frontier empties, i.e. the start node's connected
                    # component has fewer than `size` nodes.
                    new_node = random.choice(list(frontier))
                    assert new_node not in neigh
                    neigh.append(new_node)
                    visited.add(new_node)
                    # The frontier keeps duplicates. A node adjacent to
                    # several chosen nodes appears several times and is
                    # therefore more likely to be picked next.
                    frontier += list(graph.G.neighbors(new_node))
                    frontier = [x for x in frontier if x not in visited]
                if self.node_anchored:
                    # The first sampled node is the anchor. Features are
                    # written on every node of the full graph, in place.
                    anchor = neigh[0]
                    for v in graph.G.nodes:
                        graph.G.nodes[v]["node_feature"] = (
                            torch.ones(1) if anchor == v else torch.zeros(1))
                        #print(v, graph.G.nodes[v]["node_feature"])
                neigh = graph.G.subgraph(neigh)
                if use_hard_neg and train:
                    # Copy so the edits below don't touch graph.G.
                    neigh = neigh.copy()
                    # NOTE(review): random.random() is always < 1.0, so this
                    # condition is always true and the "perturb anchor" else
                    # branch is unreachable. It still consumes one
                    # random.random() draw.
                    if random.random(
                    ) < 1.0 or not self.node_anchored:  # add edges
                        non_edges = list(nx.non_edges(neigh))
                        if len(non_edges) > 0:
                            # Add k distinct missing edges,
                            # k uniform in [1, min(#non-edges, 5)].
                            for u, v in random.sample(
                                    non_edges,
                                    random.randint(1, min(len(non_edges), 5))):
                                neigh.add_edge(u, v)
                    else:  # perturb anchor (currently unreachable, see above)
                        anchor = random.choice(list(neigh.nodes))
                        for v in neigh.nodes:
                            neigh.nodes[v]["node_feature"] = (torch.ones(1) if
                                                              anchor == v else
                                                              torch.zeros(1))

                # Optional rejection step, applied only to subgraphs of at
                # most 6 nodes (presumably to bound the isomorphism cost):
                # keep resampling while the sample is subgraph-isomorphic to
                # the corresponding negative target. No call in this method
                # passes filter_negs, so here the loop always exits after
                # one pass.
                if (filter_negs and train and len(neigh) <= 6
                        and neg_target is not None):
                    matcher = nx.algorithms.isomorphism.GraphMatcher(
                        neg_target[graph_idx], neigh)
                    if not matcher.subgraph_is_isomorphic(): done = True
                else:
                    done = True

            return graph, DSGraph(neigh)

        augmenter = feature_preprocess.FeatureAugment()

        # Presumably apply_transform_multi maps sample_subgraph over the
        # batch's graphs and collects its two return values into two batches.
        pos_target = batch_target
        pos_target, pos_query = pos_target.apply_transform_multi(
            sample_subgraph)
        neg_target = batch_neg_target
        # TODO: use hard negs
        # NOTE(review): the TODO above looks stale, since hard negatives are
        # selected here. A random half (rounded down) of the negative
        # indices reuse the target graph as the query source.
        hard_neg_idxs = set(
            random.sample(range(len(neg_target.G)),
                          int(len(neg_target.G) * 1 / 2)))
        #hard_neg_idxs = set()
        # Rebuild the negative-query source batch. This overwrites the
        # batch_neg_query argument.
        batch_neg_query = Batch.from_data_list(
            GraphDataset.list_to_graphs([
                self.generator.generate(
                    size=len(g)) if i not in hard_neg_idxs else g
                for i, g in enumerate(neg_target.G)
            ]))
        # Tag each graph with its position so sample_subgraph can check
        # membership in hard_neg_idxs.
        # NOTE(review): if list_to_graphs keeps the same networkx objects
        # for reused hard negatives, this tags (and sample_subgraph's anchor
        # features mutate) the neg_target graphs too. Confirm.
        for i, g in enumerate(batch_neg_query.G):
            g.graph["idx"] = i
        _, neg_query = batch_neg_query.apply_transform_multi(
            sample_subgraph, hard_neg_idxs=hard_neg_idxs)
        if self.node_anchored:

            def add_anchor(g, anchors=None):
                """Give nodes lacking "node_feature" a 1/0 anchor indicator.

                The anchor is ``anchors[g.G.graph["idx"]]`` when ``anchors``
                is given, otherwise a uniformly random node of ``g.G``.
                Nodes that already have "node_feature" are left unchanged.
                Mutates ``g`` and returns it.
                """
                if anchors is not None:
                    anchor = anchors[g.G.graph["idx"]]
                else:
                    anchor = random.choice(list(g.G.nodes))
                for v in g.G.nodes:
                    if "node_feature" not in g.G.nodes[v]:
                        g.G.nodes[v]["node_feature"] = (
                            torch.ones(1) if anchor == v else torch.zeros(1))
                return g

            # Called without anchors, so each negative target gets a random
            # anchor on nodes that don't have a feature yet.
            neg_target = neg_target.apply_transform(add_anchor)
        pos_target = augmenter.augment(pos_target).to(utils.get_device())
        pos_query = augmenter.augment(pos_query).to(utils.get_device())
        neg_target = augmenter.augment(neg_target).to(utils.get_device())
        neg_query = augmenter.augment(neg_query).to(utils.get_device())
        #print(len(pos_target.G[0]), len(pos_query.G[0]))
        return pos_target, pos_query, neg_target, neg_query