def _forward_core(self, latent0, data):
    """Run ``self.core`` on ``data`` concatenated with ``latent0``.

    The ``e``, ``x`` and ``g`` attributes of both inputs are concatenated
    along ``dim=1`` (``latent0`` first), wrapped in a ``GraphBatch`` that
    reuses the ``edges``/``node_idx``/``edge_idx`` of ``data``, and passed
    to ``self.core``.

    :param latent0: object providing the first half of each concatenation
    :param data: object providing the second half and the graph indices
    :return: ``GraphBatch`` built from the outputs of ``self.core``
    """
    edge_attr = torch.cat([latent0.e, data.e], dim=1)
    node_attr = torch.cat([latent0.x, data.x], dim=1)
    glob_attr = torch.cat([latent0.g, data.g], dim=1)
    merged = GraphBatch(
        node_attr,
        edge_attr,
        glob_attr,
        data.edges,
        data.node_idx,
        data.edge_idx,
    )
    # self.core hands back (edge, node, global); GraphBatch takes
    # (node, edge, global), hence the reordering below.
    edge_out, node_out, glob_out = self.core(merged)
    return GraphBatch(
        node_out,
        edge_out,
        glob_out,
        merged.edges,
        merged.node_idx,
        merged.edge_idx,
    )
def core_process(self, latent0: GraphBatch, data: GraphBatch) -> GraphBatch:
    """Apply ``self.core`` to the attribute-wise concatenation of two batches.

    For each of ``e``, ``x`` and ``g`` (in that order) the tensors of
    ``latent0`` and ``data`` are joined with ``torch.cat(..., dim=1)``.

    :param latent0: batch whose attributes come first in each concatenation
    :param data: batch whose attributes come second; also supplies
        ``edges``, ``node_idx`` and ``edge_idx``
    :return: a new ``GraphBatch`` holding the outputs of ``self.core``
    """
    e_cat, x_cat, g_cat = (
        torch.cat([getattr(latent0, name), getattr(data, name)], dim=1)
        for name in ("e", "x", "g")
    )
    stacked = GraphBatch(
        x_cat, e_cat, g_cat, data.edges, data.node_idx, data.edge_idx
    )
    # The core module returns (edge, node, global) in that order.
    e_new, x_new, g_new = self.core(stacked)
    return GraphBatch(
        x_new, e_new, g_new, stacked.edges, stacked.node_idx, stacked.edge_idx
    )
def data(request):
    """Build random graph data of the class named by ``request.param``.

    Presumably a parametrized pytest fixture (the decorator is not visible
    here).  ``GraphData`` yields ``GraphData.random(...)``; any other
    parameter yields ``GraphBatch.random_batch(100, ...)``.  Node and edge
    counts are pinned by setting min and max to the same value.
    """
    feature_sizes = (5, 4, 3)
    size_limits = {
        "min_nodes": 10,
        "max_nodes": 10,
        "min_edges": 5,
        "max_edges": 5,
    }
    if request.param is not GraphData:
        return GraphBatch.random_batch(100, *feature_sizes, **size_limits)
    return GraphData.random(*feature_sizes, **size_limits)
def test_to_networkx_list(self, fkey_gkey):
    """Check ``to_networkx_list`` output against the graphs that built the batch.

    :param fkey_gkey: pair ``(feature_key, global_attr_key)`` forwarded to
        ``to_networkx_list`` and to the comparator
    """
    feature_key, global_key = fkey_gkey
    originals = [random_graph_data(5, 5, 5) for _ in range(3)]
    batch = GraphBatch.from_data_list(originals)
    nx_graphs = batch.to_networkx_list(
        feature_key=feature_key, global_attr_key=global_key
    )
    for original, nx_graph in zip(originals, nx_graphs):
        Comparator.data_to_nx(original, nx_graph, feature_key, global_key)
def data(request):
    """Return seeded random data of the class given by ``request.param``.

    Presumably a parametrized pytest fixture (decorator not visible here).
    The seed is reset to 0 first; ``GraphData`` yields
    ``GraphData.random(5, 4, 3)``, anything else yields
    ``GraphBatch.random_batch(10, 5, 4, 3)``.
    """
    deterministic_seed(0)
    if request.param is GraphData:
        return GraphData.random(5, 4, 3)
    return GraphBatch.random_batch(10, 5, 4, 3)
def test_fully_connected_singe_graph_batch():
    """``FullyConnected`` must add edges to a batch built from one random graph."""
    deterministic_seed(0)
    graph = GraphData.random(5, 4, 3)
    original = GraphBatch.from_data_list([graph])
    transform = FullyConnected()
    connected = transform(original)
    n_edges_after = connected.edges.shape[1]
    n_edges_before = original.edges.shape[1]
    assert n_edges_after > n_edges_before
def test_node_mask_entire_graph():
    """Mask out every node of the middle graph in a three-graph batch.

    Each graph has 5 nodes and 3 edges.  The mask keeps the first and last
    five nodes and drops the middle five; the resulting index tensors are
    printed for inspection (no assertions are made).
    """

    def make_graph():
        # Same construction, and same RNG call order, for every graph.
        return GraphData(
            torch.randn(5, 3),
            torch.randn(3, 3),
            torch.randn(1, 1),
            edges=torch.LongTensor([[0, 1, 2], [2, 3, 4]]),
        )

    batch = GraphBatch.from_data_list([make_graph() for _ in range(3)])
    keep = torch.BoolTensor([True] * 5 + [False] * 5 + [True] * 5)
    masked = batch.apply_node_mask(keep)
    print(masked.node_idx)
    print(masked.edge_idx)
def test_invalid_batch(self, offsets):
    """Batching graphs with mismatched feature widths must raise ``RuntimeError``.

    :param offsets: indexable of three ints added to the node, edge and
        global feature widths of the second graph
    """
    reference = GraphData(
        torch.randn(10, 10),
        torch.randn(3, 4),
        torch.randn(1, 3),
        torch.randint(0, 10, torch.Size([2, 3])),
    )
    node_width = 10 + offsets[0]
    edge_width = 4 + offsets[1]
    glob_width = 3 + offsets[2]
    mismatched = GraphData(
        torch.randn(12, node_width),
        torch.randn(4, edge_width),
        torch.randn(1, glob_width),
        torch.randint(0, 10, torch.Size([2, 4])),
    )
    with pytest.raises(RuntimeError):
        GraphBatch.from_data_list([reference, mismatched])
def __getitem__(self, idx: int) -> GraphBatch:
    """Fetch the sample(s) at ``idx`` and return them as one ``GraphBatch``.

    ``self.get`` is called with ``transform=None``; ``self.transform``, if
    truthy, is applied afterwards to the assembled batch.

    :param idx: index forwarded to ``self.get``
    :return: the (optionally transformed) batch
    """
    fetched = self.get(idx, transform=None)
    # A single GraphData is wrapped so from_data_list always gets a list.
    as_list = [fetched] if isinstance(fetched, GraphData) else fetched
    batch = GraphBatch.from_data_list(as_list)
    if not self.transform:
        return batch
    return self.transform(batch)
def test_from_data_list(self, n):
    """Check attribute shapes of a batch built from ``n`` random graphs."""
    graphs = [random_graph_data(5, 3, 4) for _ in range(n)]
    batch = GraphBatch.from_data_list(graphs)
    x_shape = batch.x.shape
    e_shape = batch.e.shape
    g_shape = batch.g.shape

    # First dimension: x and e must exceed n rows, g has exactly n rows.
    assert x_shape[0] > n
    assert e_shape[0] > n
    assert g_shape[0] == n

    # Second dimension: the sizes passed to random_graph_data.
    assert x_shape[1] == 5
    assert e_shape[1] == 3
    assert g_shape[1] == 4
def to(batch, device, **kwargs):
    """Return a new ``GraphBatch`` with every tensor moved via ``.to``.

    :param batch: object exposing ``x``, ``e``, ``g``, ``edges``,
        ``node_idx`` and ``edge_idx``
    :param device: first positional argument for each ``.to`` call
    :param kwargs: extra keyword arguments forwarded to each ``.to`` call
    :return: ``GraphBatch`` built from the converted tensors
    """
    field_names = ("x", "e", "g", "edges", "node_idx", "edge_idx")
    moved = (
        getattr(batch, name).to(device, **kwargs) for name in field_names
    )
    return GraphBatch(*moved)
def random_data_example(request):
    """Create random graph data as described by ``request.param``.

    ``request.param[0]`` selects the class and ``request.param[1]`` holds
    the positional arguments for ``random_graph_data``.  Presumably used as
    a parametrized pytest fixture (decorator not visible here).

    :return: one random graph for ``GraphData``; a batch of ten random
        graphs for ``GraphBatch``
    :raises Exception: if the requested class is neither of the two
    """
    kind = request.param[0]
    if kind is GraphData:
        return random_graph_data(*request.param[1])
    if kind is GraphBatch:
        graphs = [random_graph_data(*request.param[1]) for _ in range(10)]
        return GraphBatch.from_data_list(graphs)
    raise Exception("Parameter not acceptable: {}".format(request.param))
def test_is_differentiable__to_datalist(self, attr):
    """``requires_grad`` set on a batch attribute must survive ``to_data_list``.

    :param attr: name of the batch attribute to flag
    """
    source_graphs = [random_graph_data(5, 3, 4) for _ in range(300)]
    batch = GraphBatch.from_data_list(source_graphs)
    getattr(batch, attr).requires_grad = True
    for piece in batch.to_data_list():
        assert getattr(piece, attr).requires_grad
def validate_plot(self, batch, num_graphs=10):
    """Log comparison plots of predictions versus targets.

    The first ``num_graphs`` graphs of the input and target batches are
    re-batched, ``self.model.forward(x, 10)`` is called and the last element
    of its result is taken as the prediction.  One figure per
    (prediction, target) pair is produced with ``comparison_plot`` and the
    figures are logged under the ``"image"`` key.

    :param batch: pair ``(x, y)`` of input and target batches
    :param num_graphs: maximum number of graphs to plot
    """
    x, y = batch
    # Each batch is unpacked exactly once; the original also made two
    # to_data_list() calls whose results were thrown away.
    x = GraphBatch.from_data_list(x.to_data_list()[:num_graphs])
    y = GraphBatch.from_data_list(y.to_data_list()[:num_graphs])
    y_hat = self.model.forward(x, 10)[-1]

    y_graphs = y.to_networkx_list()
    y_hat_graphs = y_hat.to_networkx_list()

    figs = []
    for predicted, target in zip(y_hat_graphs, y_graphs):
        fig, _axes = comparison_plot(predicted, target)
        figs.append(fig)

    with figs_to_pils(figs) as pils:
        self.logger.experiment.log({"image": [wandb.Image(pil) for pil in pils]})
def create_loader(generator, graphs, batch_size, shuffle, pin_memory=False):
    """Build a ``GraphDataLoader`` of (input, target) pairs from networkx graphs.

    Inputs are read with the default feature key; targets are read with
    ``feature_key="target"``.

    :param generator: object whose ``n_parts`` gives the input node feature size
    :param graphs: networkx graphs to convert
    :param batch_size: loader batch size; ``None`` means one batch holding
        every graph
    :param shuffle: forwarded to the loader
    :param pin_memory: forwarded to the loader
    :return: the configured ``GraphDataLoader``
    """
    input_batch = GraphBatch.from_networkx_list(
        graphs, n_edge_feat=1, n_node_feat=generator.n_parts, n_glob_feat=1
    )
    label_batch = GraphBatch.from_networkx_list(
        graphs, n_edge_feat=16, n_node_feat=1, n_glob_feat=1, feature_key="target"
    )
    inputs = input_batch.to_data_list()
    labels = label_batch.to_data_list()
    effective_batch_size = len(inputs) if batch_size is None else batch_size
    pairs = list(zip(inputs, labels))
    return GraphDataLoader(
        pairs,
        batch_size=effective_batch_size,
        shuffle=shuffle,
        pin_memory=pin_memory,
    )
def collate(data_list):
    """Collate a list of samples into a ``GraphBatch`` or a tuple of them.

    If the samples are tuples whose first element is a ``GraphData``, each
    tuple position is collated separately (recursively) and a tuple of the
    results is returned.  Otherwise the list is passed straight to
    ``GraphBatch.from_data_list``.

    :param data_list: non-empty list of samples, or of tuples of samples
    :return: a ``GraphBatch``, or a tuple with one collated entry per
        position of the first sample tuple
    :raises RuntimeError: if the samples are tuples that do not start with
        a ``GraphData`` instance
    """
    first = data_list[0]
    if not isinstance(first, tuple):
        return GraphBatch.from_data_list(data_list)

    # isinstance is the idiomatic spelling of issubclass(type(x), T).
    if not isinstance(first[0], GraphData):
        raise RuntimeError("Cannot collate {}({})({})".format(
            type(data_list), type(first), type(first[0])))

    # Transpose: collect position i of every sample and collate each column.
    return tuple(
        collate([sample[i] for sample in data_list])
        for i in range(len(first))
    )
def test_is_differentiable__append_nodes(self, attr):
    """``append_nodes`` must grow ``x`` by ten rows and keep ``requires_grad``.

    :param attr: name of the attribute flagged with ``requires_grad`` on
        every source graph before batching
    """
    graphs = [random_graph_data(5, 3, 4) for _ in range(300)]
    for graph in graphs:
        getattr(graph, attr).requires_grad = True
    batch = GraphBatch.from_data_list(graphs)

    extra_nodes = torch.randn(10, 5)
    # Second argument to append_nodes: ten ones of dtype long.
    graph_ids = torch.ones(10, dtype=torch.long)
    nodes_before = batch.x.shape[0]
    batch.append_nodes(extra_nodes, graph_ids)

    assert batch.x.shape[0] == nodes_before + 10
    assert getattr(batch, attr).requires_grad
def test_fully_connected_singe_graph_batch_manual():
    """``FullyConnected`` on two copies of a 3-node graph must yield 18 unique edges."""
    deterministic_seed(0)
    node_attr = torch.randn((3, 1))
    edge_attr = torch.randn((2, 2))
    glob_attr = torch.randn((3, 1))
    edge_index = torch.tensor([[0, 1], [0, 1]])
    graph = GraphData(node_attr, edge_attr, glob_attr, edge_index)

    doubled = GraphBatch.from_data_list([graph, graph])
    connected = FullyConnected()(doubled)
    print(connected.edges)

    assert connected.edges.shape[1] == 18
    # The set size equals the edge count, so no edge is duplicated.
    unique_edges = _edges_to_tuples_set(connected.edges)
    assert len(unique_edges) == 18
def forward(self, data, steps, save_all: bool = False):
    """Encode ``data``, run ``steps`` core iterations and decode each one.

    Every iteration concatenates the encoder output (``latent0``) with the
    current state along ``dim=1``, feeds it through ``self.core``, then
    through ``self.decoder`` and ``self.output_transform``.

    :param data: input batch providing ``edges``, ``node_idx`` and
        ``edge_idx`` alongside whatever ``self.encoder`` consumes
    :param steps: number of core processing iterations
    :param save_all: keep the decoded output of every step instead of only
        the last one
    :return: list of ``GraphBatch`` outputs; one per step when ``save_all``
        is true, otherwise at most one (the last).  Empty when ``steps``
        is 0.
    """
    # Encode; the sub-modules return (edge, node, global) in that order.
    e, x, g = self.encoder(data)
    latent0 = GraphBatch(x, e, g, data.edges, data.node_idx, data.edge_idx)

    # Graph index tensors are reused unchanged for every GraphBatch below.
    meta = (latent0.edges, latent0.node_idx, latent0.edge_idx)

    outputs = []
    for _ in range(steps):
        # Core processing step on [latent0 ; current state].
        e = torch.cat([latent0.e, e], dim=1)
        x = torch.cat([latent0.x, x], dim=1)
        g = torch.cat([latent0.g, g], dim=1)
        e, x, g = self.core(GraphBatch(x, e, g, *meta))

        # Decode the current state.
        _e, _x, _g = self.decoder(GraphBatch(x, e, g, *meta))

        # Apply the output transform to the decoded batch.
        _e, _x, _g = self.output_transform(GraphBatch(_x, _e, _g, *meta))
        step_output = GraphBatch(_x, _e, _g, *meta)

        if save_all:
            outputs.append(step_output)
        else:
            outputs = [step_output]
    return outputs
def test_k_hop_random_graph(k):
    """Induce a ``k``-step node mask from node 0 on two grid graphs and print the result.

    :param k: third argument passed to ``induce``
    """
    grids = [nx.grid_graph(dim=[2, 3, 4]) for _ in range(2)]
    directed = [nx_to_directed(grid) for grid in grids]
    for graph in directed:
        nx_random_features(graph, 5, 4, 3)
    batch = GraphBatch.from_networkx_list(directed)

    # Start with only the first node selected.
    seeds = torch.BoolTensor([False] * batch.num_nodes)
    seeds[0] = True

    induced_mask = induce(batch, seeds, k)
    subgraph = batch.apply_node_mask(induced_mask)
    print(subgraph.info())
def test_shuffle_graphs(shuffle):
    """Compare the two batches returned by ``shuffle`` attribute by attribute.

    ``e``, ``x`` and ``edges`` must match between the two batches, while
    ``g``, ``node_idx`` and ``edge_idx`` must differ somewhere.

    :param shuffle: callable taking a batch and returning a pair of batches
    """
    batch = GraphBatch.random_batch(
        100, 5, 4, 3, min_nodes=5, max_nodes=5, min_edges=5, max_edges=5
    )
    first, second = shuffle(batch)
    if batch.__class__ is GraphData:
        pytest.xfail("GraphData has no `shuffle_graphs` method")
    assert torch.all(first.e == second.e)
    assert not torch.all(first.g == second.g)
    assert torch.all(first.x == second.x)
    assert torch.all(first.edges == second.edges)
    assert not torch.all(first.node_idx == second.node_idx)
    assert not torch.all(first.edge_idx == second.edge_idx)
def grid_data(request):
    """Build grid-graph data of the class named by ``request.param``.

    Presumably a parametrized pytest fixture (decorator not visible here).

    :return: ``GraphData`` from one grid graph, or ``GraphBatch`` from ten
    :raises ValueError: if ``request.param`` is neither class
    """

    def make_directed_grid():
        # Random features are attached first, then the graph is made directed.
        featured = nx_random_features(nx.grid_graph([2, 4, 3]), 5, 4, 3)
        return nx_to_directed(featured)

    kind = request.param
    if kind is GraphData:
        return GraphData.from_networkx(make_directed_grid())
    if kind is GraphBatch:
        return GraphBatch.from_networkx_list(
            [make_directed_grid() for _ in range(10)]
        )
    raise ValueError()
def data(request):
    """Return a small seeded graph as ``GraphData`` or a one-graph ``GraphBatch``.

    Presumably a parametrized pytest fixture (decorator not visible here).
    The graph has 4 nodes and 4 edges; the seed is reset to 0 before the
    random attributes are drawn.
    """
    wanted_cls = request.param
    deterministic_seed(0)
    node_attr = torch.randn((4, 1))
    edge_attr = torch.randn((4, 2))
    glob_attr = torch.randn((3, 1))
    edge_index = torch.tensor([[0, 1, 2, 1], [1, 2, 3, 0]])
    graph = GraphData(node_attr, edge_attr, glob_attr, edge_index)
    if wanted_cls is not GraphBatch:
        return graph
    return GraphBatch.from_data_list([graph])
def test_batch_append_nodes(self):
    """``append_nodes`` must increase the number of rows in ``batch.x``."""
    graphs = [random_graph_data(5, 6, 7) for _ in range(10)]
    batch = GraphBatch.from_data_list(graphs)
    new_nodes = torch.randn(3, 5)
    graph_ids = torch.tensor([0, 1, 2])
    shape_before = batch.x.shape
    batch.append_nodes(new_nodes, graph_ids)
    print(batch.node_idx.shape)
    assert shape_before[0] < batch.x.shape[0]
    print(batch.node_idx.shape)
    print(batch.x.shape)
def test_k_hop_random_graph_benchmark(benchmark):
    """Benchmark k-hop masking with ``tensor_induce``.

    :param benchmark: callable that runs the function under measurement
    :return:
    """
    k = 2
    batch = GraphBatch.random_batch(1000, 50, 20, 30)

    def run():
        # Select ten randomly drawn node indices as starting points.
        seeds = torch.full((batch.num_nodes, ), False, dtype=torch.bool)
        picked = torch.randint(batch.num_nodes, (10, ))
        seeds[picked] = True
        induced = tensor_induce(batch, seeds, k)
        batch.apply_node_mask(induced)

    benchmark(run)
def test_():
    """Batching a one-edge graph with a zero-edge graph must not raise."""
    with_edge = GraphData(
        torch.randn(1, 5),
        torch.randn(1, 4),
        torch.randn(1, 3),
        torch.LongTensor([[0], [0]]),
    )
    # Second graph: zero rows of edge attributes and empty edge lists.
    without_edge = GraphData(
        torch.randn(1, 5),
        torch.randn(0, 4),
        torch.randn(1, 3),
        torch.LongTensor([[], []]),
    )
    GraphBatch.from_data_list([with_edge, without_edge])
def test_is_differentiable__append_edges(self, attr):
    """``append_edges`` must add twenty edges and keep ``requires_grad``.

    :param attr: name of the attribute flagged with ``requires_grad`` on
        every source graph before batching
    """
    graphs = [random_graph_data(5, 3, 4) for _ in range(300)]
    for graph in graphs:
        getattr(graph, attr).requires_grad = True
    batch = GraphBatch.from_data_list(graphs)

    extra_attr = torch.randn(20, 3)
    extra_edges = torch.randint(0, batch.x.shape[0], (2, 20))
    # Random values in [0, 30), sorted ascending before use.
    graph_ids = torch.randint(0, 30, (extra_edges.shape[1], ))
    graph_ids = torch.sort(graph_ids).values

    edges_before = batch.e.shape[0]
    batch.append_edges(extra_attr, extra_edges, graph_ids)
    expected = edges_before + 20

    assert batch.e.shape[0] == expected
    assert batch.edge_idx.shape[0] == expected
    assert batch.edges.shape[1] == expected
    assert getattr(batch, attr).requires_grad
def test_to_datalist(self):
    """Round-trip 1000 random graphs through a batch and compare with ``allclose``."""
    originals = [random_graph_data(5, 6, 7) for _ in range(1000)]
    batch = GraphBatch.from_data_list(originals)
    print(batch.shape)
    print(batch.size)

    restored = batch.to_data_list()
    assert len(originals) == len(restored)

    def reorder_columns(edges):
        # Column order is taken from the first row of the sort indices.
        order = torch.sort(edges).indices[0]
        return edges[:, order]

    # The edge listing is printed in two full passes.
    for _ in range(2):
        for item in restored:
            print(reorder_columns(item.edges))

    for before, after in zip(originals, restored):
        assert before.allclose(after)
def test_basic_batch2(self):
    """Batch a 2-node and a 5-node graph and print edges before and after unbatching."""
    small = GraphData(
        torch.tensor([[0], [0]]),
        torch.tensor([[0], [0]]),
        torch.tensor([[0]]),
        torch.tensor([[0, 1], [1, 0]]),
    )
    large = GraphData(
        torch.tensor([[0], [0], [0], [0], [0]]),
        torch.tensor([[0], [0], [0]]),
        torch.tensor([[0]]),
        torch.tensor([[1, 2, 1], [4, 2, 1]]),
    )
    batch = GraphBatch.from_data_list([small, large])
    print(batch.edges)
    unbatched = batch.to_data_list()
    print(unbatched[0].edges)
    print(unbatched[1].edges)
def test_k_hop_random_graph_benchmark2(benchmark):
    """Benchmark k-hop masking with ``floyd_warshall_neighbors``.

    :param benchmark: callable that runs the function under measurement
    :return:
    """
    k = 2
    batch = GraphBatch.random_batch(1000, 50, 20, 30)

    def make_seed_mask():
        # Boolean mask with ten randomly drawn node indices set to True.
        seeds = torch.full((batch.num_nodes, ), False, dtype=torch.bool)
        picked = torch.randint(batch.num_nodes, (10, ))
        seeds[picked] = True
        return seeds

    def run(n):
        seed_masks = tuple([make_seed_mask() for _ in range(n)])
        neighbor_masks = floyd_warshall_neighbors(batch, seed_masks, depth=k)
        for neighbor_mask in neighbor_masks:
            batch.apply_node_mask(neighbor_mask)

    benchmark(run, 1)