# Edge update: new edge features are the (epsilon-rule) linear contributions of the
# previous edge, its sender node, its receiver node, and the graph's global features,
# plus an optional bias; each term is accumulated through the lrp wrappers.
def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
    new_edges = torch.tensor(0)
    if self.W_edge is not None:
        new_edges = lrp.add(
            new_edges,
            lrp.linear_eps(graphs.edge_features, self.W_edge))
    if self.W_sender is not None:
        new_edges = lrp.add(
            new_edges,
            lrp.index_select(lrp.linear_eps(graphs.node_features, self.W_sender),
                             dim=0, index=graphs.senders))
    if self.W_receiver is not None:
        new_edges = lrp.add(
            new_edges,
            lrp.index_select(lrp.linear_eps(graphs.node_features, self.W_receiver),
                             dim=0, index=graphs.receivers))
    if self.W_global is not None:
        new_edges = lrp.add(
            new_edges,
            lrp.repeat_tensor(lrp.linear_eps(graphs.global_features, self.W_global),
                              dim=0, repeats=graphs.num_edges_by_graph))
    if self.bias is not None:
        new_edges = lrp.add(new_edges, self.bias)
    return graphs.evolve(edge_features=new_edges)
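For reference, repeat_tensor is assumed here to broadcast one row of per-graph (global) features to every edge of the corresponding graph. A minimal sketch of that behaviour using torch.repeat_interleave (the toy tensors and the equivalence are assumptions, not the library's implementation):

import torch

# Two graphs with one global feature each; the first graph has 3 edges, the second has 1.
global_features = torch.tensor([[1.0], [2.0]])
num_edges_by_graph = torch.tensor([3, 1])

# Broadcast each graph's global row to all of its edges (assumed repeat_tensor semantics).
per_edge_globals = torch.repeat_interleave(global_features, num_edges_by_graph, dim=0)
# per_edge_globals -> [[1.], [1.], [1.], [2.]]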
# Node update: new node features sum the (epsilon-rule) linear contributions of the
# previous node, the aggregated incoming edges (grouped by receiver), the aggregated
# outgoing edges (grouped by sender), and the global features, plus an optional bias.
def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
    new_nodes = torch.tensor(0)
    if self.W_node is not None:
        new_nodes = lrp.add(
            new_nodes,
            lrp.linear_eps(graphs.node_features, self.W_node))
    if self.W_incoming is not None:
        new_nodes = lrp.add(
            new_nodes,
            lrp.linear_eps(
                self.aggregation(graphs.edge_features, dim=0,
                                 index=graphs.receivers, dim_size=graphs.num_nodes),
                self.W_incoming))
    if self.W_outgoing is not None:
        new_nodes = lrp.add(
            new_nodes,
            lrp.linear_eps(
                self.aggregation(graphs.edge_features, dim=0,
                                 index=graphs.senders, dim_size=graphs.num_nodes),
                self.W_outgoing))
    if self.W_global is not None:
        new_nodes = lrp.add(
            new_nodes,
            lrp.repeat_tensor(lrp.linear_eps(graphs.global_features, self.W_global),
                              dim=0, repeats=graphs.num_nodes_by_graph))
    if self.bias is not None:
        new_nodes = lrp.add(new_nodes, self.bias)
    return graphs.evolve(node_features=new_nodes)
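The self.aggregation calls above scatter edge features onto their endpoint nodes. As a toy illustration of the incoming case, assuming the aggregation is a sum (torch_scatter.scatter_add) over a small batch with 3 nodes and 4 edges:

import torch
import torch_scatter

# 4 edges with scalar features; receivers[k] is the node that edge k points to.
edge_features = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
receivers = torch.tensor([0, 0, 2, 1])

# Sum the features of all edges arriving at each of the 3 nodes.
incoming = torch_scatter.scatter_add(edge_features, receivers, dim=0, dim_size=3)
# incoming -> [[1 + 2], [4], [3]] == [[3.], [4.], [3.]]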
# Global update: new graph-level features sum the (epsilon-rule) linear contributions
# of the per-graph aggregated nodes, the per-graph aggregated edges, and the previous
# global features, plus an optional bias.
def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
    new_globals = torch.tensor(0)
    if self.W_node is not None:
        index = tg.utils.segment_lengths_to_ids(graphs.num_nodes_by_graph)
        new_globals = lrp.add(
            new_globals,
            lrp.linear_eps(
                self.aggregation(graphs.node_features, dim=0,
                                 index=index, dim_size=graphs.num_graphs),
                self.W_node))
    if self.W_edges is not None:
        index = tg.utils.segment_lengths_to_ids(graphs.num_edges_by_graph)
        new_globals = lrp.add(
            new_globals,
            lrp.linear_eps(
                self.aggregation(graphs.edge_features, dim=0,
                                 index=index, dim_size=graphs.num_graphs),
                self.W_edges))
    if self.W_global is not None:
        new_globals = lrp.add(
            new_globals,
            lrp.linear_eps(graphs.global_features, self.W_global))
    if self.bias is not None:
        new_globals = lrp.add(new_globals, self.bias)
    return graphs.evolve(global_features=new_globals)
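The three blocks above build on lrp.linear_eps, i.e. a linear transformation whose relevance is later redistributed with the LRP epsilon rule. A generic sketch of that rule for a bias-free linear layer, written independently of the lrp module's actual interface (the function name, signature, and eps default below are assumptions):

import torch

def linear_eps_relevance(inputs: torch.Tensor, weight: torch.Tensor,
                         relevance_out: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # inputs:        (N, in_features)   activations entering the linear layer
    # weight:        (out_features, in_features)
    # relevance_out: (N, out_features)  relevance arriving at the layer's output
    z = inputs @ weight.t()                       # forward pre-activations
    z = torch.where(z >= 0, z + eps, z - eps)     # epsilon stabiliser, keeps the sign
    s = relevance_out / z                         # per-output "sensitivity"
    c = s @ weight                                # redistribute through the weights
    return inputs * c                             # relevance assigned to the inputs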
# Full graph network block: ReLU edge update from edge, sender, receiver and global
# features; ReLU node update from node, incoming/outgoing edge sums and globals;
# global update from per-graph edge and node sums and the previous globals.
def forward(self, graphs: tg.GraphBatch):
    edges = F.relu(
        self.f_e(graphs.edge_features)
        + self.f_s(graphs.node_features).index_select(dim=0, index=graphs.senders)
        + self.f_r(graphs.node_features).index_select(dim=0, index=graphs.receivers)
        + tg.utils.repeat_tensor(self.f_u(graphs.global_features),
                                 graphs.num_edges_by_graph))
    nodes = F.relu(
        self.g_n(graphs.node_features)
        + self.g_in(torch_scatter.scatter_add(
            edges, graphs.receivers, dim=0, dim_size=graphs.num_nodes))
        + self.g_out(torch_scatter.scatter_add(
            edges, graphs.senders, dim=0, dim_size=graphs.num_nodes))
        + tg.utils.repeat_tensor(self.g_u(graphs.global_features),
                                 graphs.num_nodes_by_graph))
    globals = (
        self.h_e(torch_scatter.scatter_add(
            edges, segment_lengths_to_ids(graphs.num_edges_by_graph),
            dim=0, dim_size=graphs.num_graphs))
        + self.h_n(torch_scatter.scatter_add(
            nodes, segment_lengths_to_ids(graphs.num_nodes_by_graph),
            dim=0, dim_size=graphs.num_graphs))
        + self.h_u(graphs.global_features))
    return graphs.evolve(
        edge_features=edges,
        node_features=nodes,
        global_features=globals,
    )
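segment_lengths_to_ids, used here and in the global block above, is assumed to turn per-graph counts into a flat vector of graph ids that then drives the per-graph pooling. A minimal sketch consistent with that usage (an assumption, not the library's code):

import torch

def segment_lengths_to_ids(lengths: torch.Tensor) -> torch.Tensor:
    # e.g. lengths = [2, 3]  ->  [0, 0, 1, 1, 1]
    return torch.repeat_interleave(
        torch.arange(lengths.numel(), device=lengths.device), lengths)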
# Node-and-global readout: applies a ReLU node transform, pools the resulting node
# features per graph into new global features, and returns a batch that keeps only
# those globals (edges and nodes are dropped).
def forward(self, graphs: tg.GraphBatch):
    nodes = F.relu(self.g_n(graphs.node_features))
    globals = self.h_n(torch_scatter.scatter_add(
        nodes, segment_lengths_to_ids(graphs.num_nodes_by_graph),
        dim=0, dim_size=graphs.num_graphs))
    return graphs.evolve(
        num_edges=0,
        edge_features=None,
        node_features=None,
        global_features=globals,
        senders=None,
        receivers=None)
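The readout above pools node features into one row per graph. A small numeric example of that pooling, assuming two graphs with 2 and 3 nodes respectively (toy values, sum aggregation as in the code):

import torch
import torch_scatter

node_features = torch.arange(5, dtype=torch.float).unsqueeze(-1)  # 5 nodes, shape (5, 1)
graph_ids = torch.tensor([0, 0, 1, 1, 1])   # segment_lengths_to_ids([2, 3])

pooled = torch_scatter.scatter_add(node_features, graph_ids, dim=0, dim_size=2)
# pooled -> [[0 + 1], [2 + 3 + 4]] == [[1.], [9.]]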