def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
    # Node update: accumulate relevance-tracked linear contributions from the
    # nodes themselves, aggregated incoming/outgoing edges, and the globals.
    new_nodes = torch.tensor(0)
    if self.W_node is not None:
        new_nodes = lrp.add(
            new_nodes,
            lrp.linear_eps(graphs.node_features, self.W_node))
    if self.W_incoming is not None:
        new_nodes = lrp.add(
            new_nodes,
            lrp.linear_eps(
                self.aggregation(graphs.edge_features, dim=0,
                                 index=graphs.receivers,
                                 dim_size=graphs.num_nodes),
                self.W_incoming))
    if self.W_outgoing is not None:
        new_nodes = lrp.add(
            new_nodes,
            lrp.linear_eps(
                self.aggregation(graphs.edge_features, dim=0,
                                 index=graphs.senders,
                                 dim_size=graphs.num_nodes),
                self.W_outgoing))
    if self.W_global is not None:
        new_nodes = lrp.add(
            new_nodes,
            lrp.repeat_tensor(
                lrp.linear_eps(graphs.global_features, self.W_global),
                dim=0,
                repeats=graphs.num_nodes_by_graph))
    if self.bias is not None:
        new_nodes = lrp.add(new_nodes, self.bias)
    return graphs.evolve(node_features=new_nodes)
def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
    # Edge update: accumulate relevance-tracked linear contributions from the
    # edges themselves, their sender/receiver node features, and the globals.
    new_edges = torch.tensor(0)
    if self.W_edge is not None:
        new_edges = lrp.add(
            new_edges,
            lrp.linear_eps(graphs.edge_features, self.W_edge))
    if self.W_sender is not None:
        new_edges = lrp.add(
            new_edges,
            lrp.index_select(
                lrp.linear_eps(graphs.node_features, self.W_sender),
                dim=0, index=graphs.senders))
    if self.W_receiver is not None:
        new_edges = lrp.add(
            new_edges,
            lrp.index_select(
                lrp.linear_eps(graphs.node_features, self.W_receiver),
                dim=0, index=graphs.receivers))
    if self.W_global is not None:
        new_edges = lrp.add(
            new_edges,
            lrp.repeat_tensor(
                lrp.linear_eps(graphs.global_features, self.W_global),
                dim=0,
                repeats=graphs.num_edges_by_graph))
    if self.bias is not None:
        new_edges = lrp.add(new_edges, self.bias)
    return graphs.evolve(edge_features=new_edges)
def forward(self, graphs: tg.GraphBatch):
    # Full graph-network block: update the edges first, then the nodes from
    # the new edges, then the per-graph globals from the new nodes and edges.
    edges = F.relu(
        self.f_e(graphs.edge_features)
        + self.f_s(graphs.node_features).index_select(dim=0, index=graphs.senders)
        + self.f_r(graphs.node_features).index_select(dim=0, index=graphs.receivers)
        + tg.utils.repeat_tensor(self.f_u(graphs.global_features),
                                 graphs.num_edges_by_graph))
    nodes = F.relu(
        self.g_n(graphs.node_features)
        + self.g_in(torch_scatter.scatter_add(
            edges, graphs.receivers, dim=0, dim_size=graphs.num_nodes))
        + self.g_out(torch_scatter.scatter_add(
            edges, graphs.senders, dim=0, dim_size=graphs.num_nodes))
        + tg.utils.repeat_tensor(self.g_u(graphs.global_features),
                                 graphs.num_nodes_by_graph))
    globals = (
        self.h_e(torch_scatter.scatter_add(
            edges, segment_lengths_to_ids(graphs.num_edges_by_graph),
            dim=0, dim_size=graphs.num_graphs))
        + self.h_n(torch_scatter.scatter_add(
            nodes, segment_lengths_to_ids(graphs.num_nodes_by_graph),
            dim=0, dim_size=graphs.num_graphs))
        + self.h_u(graphs.global_features))
    return graphs.evolve(
        edge_features=edges,
        node_features=nodes,
        global_features=globals,
    )
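# --- Usage sketch (not from the original source) ----------------------------
# A minimal example of driving a full graph-network block like the one above.
# It assumes the torchgraphs-style API seen elsewhere in this file (tg.Graph
# with num_nodes/senders/receivers keywords, tg.GraphBatch.from_graphs) and
# additionally assumes Graph accepts the feature tensors directly; the module
# name `GraphNetwork` and all feature sizes are made up for illustration.
import torch
import torchgraphs as tg

g = tg.Graph(
    num_nodes=2,
    senders=torch.tensor([0, 1]),
    receivers=torch.tensor([1, 0]),
    node_features=torch.rand(2, 8),
    edge_features=torch.rand(2, 4),
    global_features=torch.rand(6),
)
batch = tg.GraphBatch.from_graphs([g, g])
# net = GraphNetwork(...)  # hypothetical module exposing the f_*/g_*/h_* layers
# out = net(batch)         # a new GraphBatch with updated edge/node/global features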
def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
    # Global update: aggregate nodes and edges per graph, then accumulate
    # relevance-tracked linear contributions together with the old globals.
    new_globals = torch.tensor(0)
    if self.W_node is not None:
        index = tg.utils.segment_lengths_to_ids(graphs.num_nodes_by_graph)
        new_globals = lrp.add(
            new_globals,
            lrp.linear_eps(
                self.aggregation(graphs.node_features, dim=0, index=index,
                                 dim_size=graphs.num_graphs),
                self.W_node))
    if self.W_edges is not None:
        index = tg.utils.segment_lengths_to_ids(graphs.num_edges_by_graph)
        new_globals = lrp.add(
            new_globals,
            lrp.linear_eps(
                self.aggregation(graphs.edge_features, dim=0, index=index,
                                 dim_size=graphs.num_graphs),
                self.W_edges))
    if self.W_global is not None:
        new_globals = lrp.add(
            new_globals,
            lrp.linear_eps(graphs.global_features, self.W_global))
    if self.bias is not None:
        new_globals = lrp.add(new_globals, self.bias)
    return graphs.evolve(global_features=new_globals)
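# --- Background sketch (not from the original source) ------------------------
# The lrp.linear_eps calls in the node, edge, and global update blocks above
# presumably implement the epsilon rule of Layer-wise Relevance Propagation
# (Bach et al., 2015). A minimal standalone version of that rule for z = x @ W,
# written only to make the redistribution explicit; the actual lrp module may
# differ in signature and details.
import torch

def linear_eps_relevance(x: torch.Tensor, W: torch.Tensor,
                         relevance: torch.Tensor,
                         eps: float = 1e-6) -> torch.Tensor:
    z = x @ W                                   # forward pre-activations, one row per item
    z = z + eps * torch.where(z >= 0,           # epsilon-stabilized denominator,
                              torch.ones_like(z),   # avoiding sign(0) == 0
                              -torch.ones_like(z))
    s = relevance / z                           # relevance per unit of activation
    c = s @ W.t()                               # redistribute back through the weights
    return x * c                                # input relevance, same shape as x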
def test_collate_dicts(graphs_nx, features_shapes, device):
    graphs_in = [
        add_random_features(Graph.from_networkx(g), **features_shapes).to(device)
        for g in graphs_nx
    ]
    graphs_out = list(reversed(graphs_in))
    xs = torch.rand(len(graphs_in), 10, 32)
    ys = torch.rand(len(graphs_in), 7)
    samples = [
        {'in': gi, 'x': x, 'y': y, 'out': go}
        for gi, x, y, go in zip(graphs_in, xs, ys, graphs_out)
    ]

    batch = GraphBatch.collate(samples)

    for g1, g2 in zip(graphs_in, batch['in']):
        assert_graphs_equal(g1, g2)
    torch.testing.assert_allclose(xs, batch['x'])
    torch.testing.assert_allclose(ys, batch['y'])
    for g1, g2 in zip(graphs_out, batch['out']):
        assert_graphs_equal(g1, g2)
def forward(self, graphs: tg.GraphBatch):
    # Readout: update the nodes, pool them per graph into new globals, and
    # return a batch that keeps only the globals (no nodes or edges).
    nodes = F.relu(self.g_n(graphs.node_features))
    globals = self.h_n(torch_scatter.scatter_add(
        nodes, segment_lengths_to_ids(graphs.num_nodes_by_graph),
        dim=0, dim_size=graphs.num_graphs))
    return graphs.evolve(
        num_edges=0,
        edge_features=None,
        node_features=None,
        global_features=globals,
        senders=None,
        receivers=None,
    )
def test_corner_cases(features_shapes, device):
    # Only some graphs have node/edge features; global features are either
    # present on all of them or absent from all of them.
    gfs = features_shapes['global_features_shape']
    graphs = [
        add_random_features(Graph(num_nodes=0, num_edges=0),
                            global_features_shape=gfs),
        add_random_features(Graph(num_nodes=0, num_edges=0),
                            global_features_shape=gfs),
        add_random_features(Graph(num_nodes=3, num_edges=0), **features_shapes),
        add_random_features(Graph(num_nodes=0, num_edges=0), **features_shapes),
        add_random_features(
            Graph(num_nodes=2,
                  senders=torch.tensor([0, 1]),
                  receivers=torch.tensor([1, 0])),
            **features_shapes)
    ]
    graphbatch = GraphBatch.from_graphs(graphs).to(device)
    validate_batch(graphbatch)
    for g_orig, g_batch in zip(graphs, graphbatch):
        assert_graphs_equal(g_orig, g_batch.cpu())

    # Global features must be either present on all graphs or absent from all graphs
    with pytest.raises(ValueError):
        GraphBatch.from_graphs([
            Graph(num_nodes=0, num_edges=0),
            add_random_features(Graph(num_nodes=0, num_edges=0),
                                global_features_shape=10)
        ])
    with pytest.raises(ValueError):
        GraphBatch.from_graphs([
            add_random_features(Graph(num_nodes=0, num_edges=0),
                                global_features_shape=10),
            Graph(num_nodes=0, num_edges=0)
        ])
def test_from_graphs(graphs, features_shapes, device):
    graphs = [add_random_features(g, **features_shapes).to(device) for g in graphs]
    graphbatch = GraphBatch.from_graphs(graphs)
    validate_batch(graphbatch)

    assert len(graphs) == len(graphbatch) == graphbatch.num_graphs
    assert [g.num_nodes for g in graphs] == graphbatch.num_nodes_by_graph.tolist()
    assert [g.num_edges for g in graphs] == graphbatch.num_edges_by_graph.tolist()

    # Test sequential access (__iter__)
    for g, gb in zip(graphs, graphbatch):
        assert_graphs_equal(g, gb)

    # Test random access (__getitem__)
    for i in range(len(graphbatch)):
        assert_graphs_equal(graphs[i], graphbatch[i])
def test_collate_tuples(graphs_nx, features_shapes, device):
    graphs_in = [
        add_random_features(Graph.from_networkx(g), **features_shapes).to(device)
        for g in graphs_nx
    ]
    graphs_out = list(reversed(graphs_in))
    xs = torch.rand(len(graphs_in), 10, 32)
    ys = torch.rand(len(graphs_in), 7)
    samples = list(zip(graphs_in, xs, ys, graphs_out))

    batch = GraphBatch.collate(samples)

    for g1, g2 in zip(graphs_in, batch[0]):
        assert_graphs_equal(g1, g2)
    torch.testing.assert_allclose(xs, batch[1])
    torch.testing.assert_allclose(ys, batch[2])
    for g1, g2 in zip(graphs_out, batch[3]):
        assert_graphs_equal(g1, g2)
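# --- Usage sketch (not from the original source) -----------------------------
# The two collate tests above suggest GraphBatch.collate can serve directly as
# a DataLoader collate_fn for samples mixing graphs with plain tensors. This
# is an extrapolation from the tests, not verified here; the dataset, feature
# shapes, and field names are made up, and add_random_features is the test
# helper used above.
from torch.utils.data import DataLoader

dataset = [
    {'in': add_random_features(Graph(num_nodes=3, num_edges=0),
                               node_features_shape=8),
     'y': torch.rand(7)}
    for _ in range(16)
]
loader = DataLoader(dataset, batch_size=4, collate_fn=GraphBatch.collate)
# Each iteration would yield a dict where batch['in'] is a GraphBatch of
# 4 graphs and batch['y'] is a (4, 7) tensor.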
def test_from_to_networkxs(graphs_nx, features_shapes, device):
    graphs_nx = [add_random_features(g, **features_shapes) for g in graphs_nx]
    graphbatch = GraphBatch.from_networkxs(graphs_nx).to(device)
    validate_batch(graphbatch)

    assert len(graphs_nx) == len(graphbatch) == graphbatch.num_graphs
    assert [g.number_of_nodes() for g in graphs_nx] == graphbatch.num_nodes_by_graph.tolist()
    assert [g.number_of_edges() for g in graphs_nx] == graphbatch.num_edges_by_graph.tolist()

    # Test sequential access (__iter__)
    for g_nx, g in zip(graphs_nx, graphbatch):
        assert_graphs_equal(g_nx, g.cpu())

    # Test random access (__getitem__)
    for i in range(len(graphbatch)):
        assert_graphs_equal(graphs_nx[i], graphbatch[i].cpu())

    # Test back conversion
    graphs_nx_back = graphbatch.cpu().to_networkxs()
    for g1, g2 in zip(graphs_nx, graphs_nx_back):
        assert_graphs_equal(g1, g2)
def graphbatch() -> GraphBatch:
    return GraphBatch.from_networkxs(graphs_for_test().values())