Example #1
0
    def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
        """Compute new edge features as a sum of optional ``lrp`` terms.

        Each term whose weight (or the bias) is not ``None`` is accumulated
        with ``lrp.add`` in this fixed order: edge term, sender term,
        receiver term, global term, bias. The ``lrp`` helpers presumably
        implement layer-wise-relevance-propagation-aware ops (``linear_eps``
        suggests the epsilon rule) -- TODO confirm against the ``lrp`` module.

        Args:
            graphs: batch whose edge/node/global features, ``senders``,
                ``receivers`` and ``num_edges_by_graph`` are read.

        Returns:
            ``graphs.evolve(edge_features=...)`` carrying the accumulated sum.
            If no weight and no bias is set, the edge features are the scalar
            ``torch.tensor(0)``.
        """
        acc = torch.tensor(0)

        if self.W_edge is not None:
            edge_term = lrp.linear_eps(graphs.edge_features, self.W_edge)
            acc = lrp.add(acc, edge_term)

        if self.W_sender is not None:
            # Project every node once, then pick each edge's sender row.
            sender_proj = lrp.linear_eps(graphs.node_features, self.W_sender)
            sender_term = lrp.index_select(sender_proj,
                                           dim=0,
                                           index=graphs.senders)
            acc = lrp.add(acc, sender_term)

        if self.W_receiver is not None:
            # Same as above, but gathered at each edge's receiver.
            receiver_proj = lrp.linear_eps(graphs.node_features,
                                           self.W_receiver)
            receiver_term = lrp.index_select(receiver_proj,
                                             dim=0,
                                             index=graphs.receivers)
            acc = lrp.add(acc, receiver_term)

        if self.W_global is not None:
            # Broadcast each graph's projected globals to all of its edges.
            global_proj = lrp.linear_eps(graphs.global_features,
                                         self.W_global)
            global_term = lrp.repeat_tensor(
                global_proj, dim=0, repeats=graphs.num_edges_by_graph)
            acc = lrp.add(acc, global_term)

        if self.bias is not None:
            acc = lrp.add(acc, self.bias)

        return graphs.evolve(edge_features=acc)
Example #2
0
    def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
        """Compute new node features as a sum of optional ``lrp`` terms.

        Enabled terms (non-``None`` weights, then the bias) are accumulated
        with ``lrp.add`` in this fixed order: node term, incoming-edge term,
        outgoing-edge term, global term, bias. Edge features are reduced per
        node by ``self.aggregation`` (presumably a scatter-style reduction
        such as sum/mean -- TODO confirm) before their linear projection.

        Args:
            graphs: batch whose node/edge/global features, ``senders``,
                ``receivers``, ``num_nodes`` and ``num_nodes_by_graph`` are
                read.

        Returns:
            ``graphs.evolve(node_features=...)`` carrying the accumulated sum.
            If no weight and no bias is set, the node features are the scalar
            ``torch.tensor(0)``.
        """
        acc = torch.tensor(0)

        if self.W_node is not None:
            node_term = lrp.linear_eps(graphs.node_features, self.W_node)
            acc = lrp.add(acc, node_term)

        if self.W_incoming is not None:
            # Reduce edges onto the node they point to (their receiver).
            incoming = self.aggregation(graphs.edge_features,
                                        dim=0,
                                        index=graphs.receivers,
                                        dim_size=graphs.num_nodes)
            incoming_term = lrp.linear_eps(incoming, self.W_incoming)
            acc = lrp.add(acc, incoming_term)

        if self.W_outgoing is not None:
            # Reduce edges onto the node they leave from (their sender).
            outgoing = self.aggregation(graphs.edge_features,
                                        dim=0,
                                        index=graphs.senders,
                                        dim_size=graphs.num_nodes)
            outgoing_term = lrp.linear_eps(outgoing, self.W_outgoing)
            acc = lrp.add(acc, outgoing_term)

        if self.W_global is not None:
            # Broadcast each graph's projected globals to all of its nodes.
            global_proj = lrp.linear_eps(graphs.global_features,
                                         self.W_global)
            global_term = lrp.repeat_tensor(
                global_proj, dim=0, repeats=graphs.num_nodes_by_graph)
            acc = lrp.add(acc, global_term)

        if self.bias is not None:
            acc = lrp.add(acc, self.bias)

        return graphs.evolve(node_features=acc)
Example #3
0
    def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
        """Compute new global features as a sum of optional ``lrp`` terms.

        Enabled terms (non-``None`` weights, then the bias) are accumulated
        with ``lrp.add`` in this fixed order: node term, edge term, global
        term, bias. Nodes and edges are reduced per graph by
        ``self.aggregation`` before their linear projection.

        Args:
            graphs: batch whose node/edge/global features, ``num_graphs``,
                ``num_nodes_by_graph`` and ``num_edges_by_graph`` are read.

        Returns:
            ``graphs.evolve(global_features=...)`` carrying the accumulated
            sum. If no weight and no bias is set, the global features are the
            scalar ``torch.tensor(0)``.
        """
        acc = torch.tensor(0)

        if self.W_node is not None:
            # Presumably maps each node to the index of its graph -- TODO
            # confirm segment_lengths_to_ids semantics.
            node_graph_ids = tg.utils.segment_lengths_to_ids(
                graphs.num_nodes_by_graph)
            pooled_nodes = self.aggregation(graphs.node_features,
                                            dim=0,
                                            index=node_graph_ids,
                                            dim_size=graphs.num_graphs)
            node_term = lrp.linear_eps(pooled_nodes, self.W_node)
            acc = lrp.add(acc, node_term)

        if self.W_edges is not None:
            edge_graph_ids = tg.utils.segment_lengths_to_ids(
                graphs.num_edges_by_graph)
            pooled_edges = self.aggregation(graphs.edge_features,
                                            dim=0,
                                            index=edge_graph_ids,
                                            dim_size=graphs.num_graphs)
            edge_term = lrp.linear_eps(pooled_edges, self.W_edges)
            acc = lrp.add(acc, edge_term)

        if self.W_global is not None:
            global_term = lrp.linear_eps(graphs.global_features,
                                         self.W_global)
            acc = lrp.add(acc, global_term)

        if self.bias is not None:
            acc = lrp.add(acc, self.bias)

        return graphs.evolve(global_features=acc)
Example #4
0
 def forward(self, graphs: tg.GraphBatch):
     """Run one full graph-network update over edges, nodes and globals.

     Edges are updated first. Nodes are then updated from the *new* edges,
     and globals from the *new* edges and *new* nodes. Edge and node updates
     go through ReLU; the global update does not.

     Args:
         graphs: batch whose edge/node/global features, ``senders``,
             ``receivers``, ``num_nodes``, ``num_graphs``,
             ``num_nodes_by_graph`` and ``num_edges_by_graph`` are read.

     Returns:
         ``graphs.evolve(...)`` with edge, node and global features replaced
         by the updated tensors.
     """
     # Edge update: own features + projected sender and receiver node
     # features gathered per edge + the graph's globals repeated per edge.
     edges = F.relu(
         self.f_e(graphs.edge_features)
         + self.f_s(graphs.node_features).index_select(
             dim=0, index=graphs.senders)
         + self.f_r(graphs.node_features).index_select(
             dim=0, index=graphs.receivers)
         + tg.utils.repeat_tensor(self.f_u(graphs.global_features),
                                  graphs.num_edges_by_graph))

     # Node update: updated edges are summed onto their receiver (incoming)
     # and onto their sender (outgoing); globals are repeated per node.
     nodes = F.relu(
         self.g_n(graphs.node_features)
         + self.g_in(torch_scatter.scatter_add(
             edges, graphs.receivers, dim=0, dim_size=graphs.num_nodes))
         + self.g_out(torch_scatter.scatter_add(
             edges, graphs.senders, dim=0, dim_size=graphs.num_nodes))
         + tg.utils.repeat_tensor(self.g_u(graphs.global_features),
                                  graphs.num_nodes_by_graph))

     # Global update: per-graph sums of the updated edges and nodes.
     # Named ``new_globals`` (not ``globals``) to avoid shadowing the builtin.
     new_globals = (
         self.h_e(torch_scatter.scatter_add(
             edges,
             segment_lengths_to_ids(graphs.num_edges_by_graph),
             dim=0,
             dim_size=graphs.num_graphs))
         + self.h_n(torch_scatter.scatter_add(
             nodes,
             segment_lengths_to_ids(graphs.num_nodes_by_graph),
             dim=0,
             dim_size=graphs.num_graphs))
         + self.h_u(graphs.global_features))

     return graphs.evolve(
         edge_features=edges,
         node_features=nodes,
         global_features=new_globals,
     )
Example #5
0
 def forward(self, graphs: tg.GraphBatch):
     """Read out per-graph global features from node features only.

     Each node is projected by ``g_n`` and passed through ReLU. The results
     are summed per graph and projected by ``h_n``.

     Args:
         graphs: batch whose ``node_features``, ``num_nodes_by_graph`` and
             ``num_graphs`` are read.

     Returns:
         ``graphs.evolve(...)`` that keeps only the new global features:
         edges, node features, senders and receivers are cleared (set to
         ``None``) and ``num_edges`` is set to 0.
     """
     nodes = F.relu(self.g_n(graphs.node_features))
     # Named ``new_globals`` (not ``globals``) to avoid shadowing the builtin.
     new_globals = self.h_n(
         torch_scatter.scatter_add(
             nodes,
             segment_lengths_to_ids(graphs.num_nodes_by_graph),
             dim=0,
             dim_size=graphs.num_graphs))
     # NOTE(review): only ``num_edges`` is reset here. If GraphBatch also
     # tracks per-graph edge counts (``num_edges_by_graph``), confirm that
     # evolve() keeps them consistent with zero edges.
     return graphs.evolve(num_edges=0,
                          edge_features=None,
                          node_features=None,
                          global_features=new_globals,
                          senders=None,
                          receivers=None)