Beispiel #1
0
    def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
        """Compute new node features as a sum of optional LRP-aware terms.

        Each term is added only when its weight is not ``None``:

        * ``W_node``     -- the nodes' own features;
        * ``W_incoming`` -- edge features pooled with ``self.aggregation``
          using ``graphs.receivers`` as the index;
        * ``W_outgoing`` -- edge features pooled with ``self.aggregation``
          using ``graphs.senders`` as the index;
        * ``W_global``   -- the global features, repeated per node according
          to ``graphs.num_nodes_by_graph``;
        * ``bias``       -- added last.

        Terms are built with ``lrp.linear_eps`` / ``lrp.repeat_tensor`` and
        combined with ``lrp.add`` instead of plain torch ops (presumably so
        relevance can be propagated back through them -- see the ``lrp``
        module). The running sum starts from ``torch.tensor(0)``.

        Args:
            graphs: the input batch.

        Returns:
            ``graphs`` evolved with the new ``node_features``.
        """
        total = torch.tensor(0)

        if self.W_node is not None:
            own_term = lrp.linear_eps(graphs.node_features, self.W_node)
            total = lrp.add(total, own_term)

        if self.W_incoming is not None:
            # One pooled row per node, indexed by each edge's receiver.
            pooled_in = self.aggregation(graphs.edge_features,
                                         dim=0,
                                         index=graphs.receivers,
                                         dim_size=graphs.num_nodes)
            total = lrp.add(total,
                            lrp.linear_eps(pooled_in, self.W_incoming))

        if self.W_outgoing is not None:
            # One pooled row per node, indexed by each edge's sender.
            pooled_out = self.aggregation(graphs.edge_features,
                                          dim=0,
                                          index=graphs.senders,
                                          dim_size=graphs.num_nodes)
            total = lrp.add(total,
                            lrp.linear_eps(pooled_out, self.W_outgoing))

        if self.W_global is not None:
            per_graph = lrp.linear_eps(graphs.global_features, self.W_global)
            # Broadcast each graph's projected globals to its nodes.
            per_node = lrp.repeat_tensor(per_graph,
                                         dim=0,
                                         repeats=graphs.num_nodes_by_graph)
            total = lrp.add(total, per_node)

        if self.bias is not None:
            total = lrp.add(total, self.bias)

        return graphs.evolve(node_features=total)
Beispiel #2
0
    def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
        """Compute new edge features as a sum of optional LRP-aware terms.

        Each term is added only when its weight is not ``None``:

        * ``W_edge``     -- the edges' own features;
        * ``W_sender``   -- node features projected, then gathered with
          ``graphs.senders`` as the index;
        * ``W_receiver`` -- node features projected, then gathered with
          ``graphs.receivers`` as the index;
        * ``W_global``   -- the global features projected, then repeated per
          edge according to ``graphs.num_edges_by_graph``;
        * ``bias``       -- added last.

        All arithmetic goes through the ``lrp`` helpers (``linear_eps``,
        ``index_select``, ``repeat_tensor``, ``add``) rather than plain torch
        ops. The running sum starts from ``torch.tensor(0)``.

        Args:
            graphs: the input batch.

        Returns:
            ``graphs`` evolved with the new ``edge_features``.
        """
        total = torch.tensor(0)

        if self.W_edge is not None:
            own_term = lrp.linear_eps(graphs.edge_features, self.W_edge)
            total = lrp.add(total, own_term)

        if self.W_sender is not None:
            # Project node features first, then pick one row per edge.
            projected = lrp.linear_eps(graphs.node_features, self.W_sender)
            from_sender = lrp.index_select(projected,
                                           dim=0,
                                           index=graphs.senders)
            total = lrp.add(total, from_sender)

        if self.W_receiver is not None:
            projected = lrp.linear_eps(graphs.node_features, self.W_receiver)
            from_receiver = lrp.index_select(projected,
                                             dim=0,
                                             index=graphs.receivers)
            total = lrp.add(total, from_receiver)

        if self.W_global is not None:
            per_graph = lrp.linear_eps(graphs.global_features, self.W_global)
            per_edge = lrp.repeat_tensor(per_graph,
                                         dim=0,
                                         repeats=graphs.num_edges_by_graph)
            total = lrp.add(total, per_edge)

        if self.bias is not None:
            total = lrp.add(total, self.bias)

        return graphs.evolve(edge_features=total)
Beispiel #3
0
 def forward(self, graphs: tg.GraphBatch):
     """Update edges, then nodes, then globals, in that order.

     * Edges: ReLU of ``f_e(edge_features)`` + ``f_s(node_features)``
       gathered at ``senders`` + ``f_r(node_features)`` gathered at
       ``receivers`` + ``f_u(global_features)`` repeated per edge of each
       graph.
     * Nodes: ReLU of ``g_n(node_features)`` + ``g_in`` of the *updated*
       edges summed by ``receivers`` + ``g_out`` of the updated edges summed
       by ``senders`` + ``g_u(global_features)`` repeated per node of each
       graph.
     * Globals: ``h_e`` of the updated edges summed per graph + ``h_n`` of
       the updated nodes summed per graph + ``h_u`` of the old globals.
       No activation is applied to the globals.

     Args:
         graphs: the input batch.

     Returns:
         ``graphs`` evolved with the new edge, node and global features.
     """
     # Edge update: terms are summed in the same left-to-right order as
     # before, so floating-point results are unchanged.
     edge_term = self.f_e(graphs.edge_features)
     sender_term = self.f_s(graphs.node_features).index_select(
         dim=0, index=graphs.senders)
     receiver_term = self.f_r(graphs.node_features).index_select(
         dim=0, index=graphs.receivers)
     global_edge_term = tg.utils.repeat_tensor(
         self.f_u(graphs.global_features), graphs.num_edges_by_graph)
     edges = F.relu(edge_term + sender_term + receiver_term +
                    global_edge_term)

     # Node update, using the freshly computed edges.
     incoming = torch_scatter.scatter_add(
         edges, graphs.receivers, dim=0, dim_size=graphs.num_nodes)
     outgoing = torch_scatter.scatter_add(
         edges, graphs.senders, dim=0, dim_size=graphs.num_nodes)
     global_node_term = tg.utils.repeat_tensor(
         self.g_u(graphs.global_features), graphs.num_nodes_by_graph)
     nodes = F.relu(
         self.g_n(graphs.node_features) + self.g_in(incoming) +
         self.g_out(outgoing) + global_node_term)

     # Global update: one row per graph (dim_size=num_graphs).
     # Named ``new_globals`` so the builtin ``globals`` is not shadowed.
     edges_per_graph = torch_scatter.scatter_add(
         edges,
         segment_lengths_to_ids(graphs.num_edges_by_graph),
         dim=0,
         dim_size=graphs.num_graphs)
     nodes_per_graph = torch_scatter.scatter_add(
         nodes,
         segment_lengths_to_ids(graphs.num_nodes_by_graph),
         dim=0,
         dim_size=graphs.num_graphs)
     new_globals = (self.h_e(edges_per_graph) + self.h_n(nodes_per_graph) +
                    self.h_u(graphs.global_features))

     return graphs.evolve(
         edge_features=edges,
         node_features=nodes,
         global_features=new_globals,
     )
Beispiel #4
0
    def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
        """Compute new global features as a sum of optional LRP-aware terms.

        Each term is added only when its weight is not ``None``:

        * ``W_node``   -- node features pooled with ``self.aggregation`` into
          one row per graph (``dim_size=graphs.num_graphs``);
        * ``W_edges``  -- edge features pooled the same way;
        * ``W_global`` -- the current global features;
        * ``bias``     -- added last.

        The pooling index is built with
        ``tg.utils.segment_lengths_to_ids`` from the per-graph node/edge
        counts. Terms are combined with ``lrp.add`` / ``lrp.linear_eps``. The
        running sum starts from ``torch.tensor(0)``.

        Args:
            graphs: the input batch.

        Returns:
            ``graphs`` evolved with the new ``global_features``.
        """
        total = torch.tensor(0)

        if self.W_node is not None:
            node_graph_ids = tg.utils.segment_lengths_to_ids(
                graphs.num_nodes_by_graph)
            pooled_nodes = self.aggregation(graphs.node_features,
                                            dim=0,
                                            index=node_graph_ids,
                                            dim_size=graphs.num_graphs)
            total = lrp.add(total, lrp.linear_eps(pooled_nodes, self.W_node))

        if self.W_edges is not None:
            edge_graph_ids = tg.utils.segment_lengths_to_ids(
                graphs.num_edges_by_graph)
            pooled_edges = self.aggregation(graphs.edge_features,
                                            dim=0,
                                            index=edge_graph_ids,
                                            dim_size=graphs.num_graphs)
            total = lrp.add(total,
                            lrp.linear_eps(pooled_edges, self.W_edges))

        if self.W_global is not None:
            own_term = lrp.linear_eps(graphs.global_features, self.W_global)
            total = lrp.add(total, own_term)

        if self.bias is not None:
            total = lrp.add(total, self.bias)

        return graphs.evolve(global_features=total)
Beispiel #5
0
def test_collate_dicts(graphs_nx, features_shapes, device):
    """``GraphBatch.collate`` on samples that are dicts of graphs and tensors.

    Each sample maps 'in'/'out' to graphs and 'x'/'y' to tensors. After
    collation, each key must be present. The graph entries must yield the
    original graphs in order, and the tensor entries must equal the
    per-sample tensors stacked along dim 0.
    """
    graphs_in = [
        add_random_features(Graph.from_networkx(g),
                            **features_shapes).to(device) for g in graphs_nx
    ]
    graphs_out = list(reversed(graphs_in))
    xs = torch.rand(len(graphs_in), 10, 32)
    ys = torch.rand(len(graphs_in), 7)

    samples = [{
        'in': gi,
        'x': x,
        'y': y,
        'out': go
    } for gi, x, y, go in zip(graphs_in, xs, ys, graphs_out)]
    batch = GraphBatch.collate(samples)

    # zip() stops at the shorter input, so compare counts first. Otherwise a
    # batch that dropped graphs would still pass the element-wise checks.
    assert len(batch['in']) == len(graphs_in)
    for g1, g2 in zip(graphs_in, batch['in']):
        assert_graphs_equal(g1, g2)

    # torch.testing.assert_allclose is deprecated; assert_close is its
    # replacement and additionally checks dtype and device.
    torch.testing.assert_close(batch['x'], xs)
    torch.testing.assert_close(batch['y'], ys)

    assert len(batch['out']) == len(graphs_out)
    for g1, g2 in zip(graphs_out, batch['out']):
        assert_graphs_equal(g1, g2)
Beispiel #6
0
 def forward(self, graphs: tg.GraphBatch):
     """Read out per-graph global features from the node features.

     Node features pass through ``g_n`` and a ReLU. They are then summed per
     graph (one row per graph, ``dim_size=graphs.num_graphs``) and projected
     by ``h_n``. The returned batch keeps only these globals: ``num_edges``
     is set to 0, senders/receivers are removed, and the node and edge
     feature tensors are set to ``None``.

     Args:
         graphs: the input batch.

     Returns:
         ``graphs`` evolved as described above.
     """
     nodes = F.relu(self.g_n(graphs.node_features))
     node_graph_ids = segment_lengths_to_ids(graphs.num_nodes_by_graph)
     nodes_per_graph = torch_scatter.scatter_add(nodes,
                                                 node_graph_ids,
                                                 dim=0,
                                                 dim_size=graphs.num_graphs)
     # Named ``new_globals`` so the builtin ``globals`` is not shadowed.
     new_globals = self.h_n(nodes_per_graph)
     return graphs.evolve(num_edges=0,
                          edge_features=None,
                          node_features=None,
                          global_features=new_globals,
                          senders=None,
                          receivers=None)
def test_corner_cases(features_shapes, device):
    """Batching graphs with missing features and with no nodes/edges.

    Node/edge features are present on only some graphs, and global features
    on all of them. The batch must validate and must give back every graph
    unchanged. ``GraphBatch.from_graphs`` must raise ``ValueError`` when
    global features are present on some graphs but not on others.
    """
    # Only some graphs have node/edge features. Global features are either
    # present on all of them or absent from all of them.
    gfs = features_shapes['global_features_shape']
    graphs = [
        add_random_features(Graph(num_nodes=0, num_edges=0),
                            global_features_shape=gfs),
        add_random_features(Graph(num_nodes=0, num_edges=0),
                            global_features_shape=gfs),
        add_random_features(Graph(num_nodes=3, num_edges=0),
                            **features_shapes),
        add_random_features(Graph(num_nodes=0, num_edges=0),
                            **features_shapes),
        add_random_features(
            Graph(num_nodes=2,
                  senders=torch.tensor([0, 1]),
                  receivers=torch.tensor([1, 0])), **features_shapes)
    ]
    graphbatch = GraphBatch.from_graphs(graphs).to(device)
    validate_batch(graphbatch)

    # zip() would silently ignore graphs missing from the batch.
    assert len(graphbatch) == len(graphs)
    for g_orig, g_batch in zip(graphs, graphbatch):
        assert_graphs_equal(g_orig, g_batch.cpu())

    # Global features should be either present on all graphs or absent from
    # all graphs, whichever position the odd graph is in.
    with pytest.raises(ValueError):
        GraphBatch.from_graphs([
            Graph(num_nodes=0, num_edges=0),
            add_random_features(Graph(num_nodes=0, num_edges=0),
                                global_features_shape=10)
        ])
    with pytest.raises(ValueError):
        GraphBatch.from_graphs([
            add_random_features(Graph(num_nodes=0, num_edges=0),
                                global_features_shape=10),
            Graph(num_nodes=0, num_edges=0)
        ])
def test_from_graphs(graphs, features_shapes, device):
    """``GraphBatch.from_graphs`` keeps every graph and its node/edge counts.

    The batch is validated, checked against the per-graph node and edge
    counts, and compared graph by graph through iteration (``__iter__``)
    and indexing (``__getitem__``).
    """
    featured = [
        add_random_features(g, **features_shapes).to(device) for g in graphs
    ]
    graphbatch = GraphBatch.from_graphs(featured)

    validate_batch(graphbatch)

    assert len(featured) == len(graphbatch) == graphbatch.num_graphs
    expected_nodes = [g.num_nodes for g in featured]
    expected_edges = [g.num_edges for g in featured]
    assert expected_nodes == graphbatch.num_nodes_by_graph.tolist()
    assert expected_edges == graphbatch.num_edges_by_graph.tolist()

    # Sequential access (__iter__).
    for expected, actual in zip(featured, graphbatch):
        assert_graphs_equal(expected, actual)

    # Random access (__getitem__); lengths were already checked above.
    for idx, expected in enumerate(featured):
        assert_graphs_equal(expected, graphbatch[idx])
Beispiel #9
0
def test_collate_tuples(graphs_nx, features_shapes, device):
    """``GraphBatch.collate`` on samples that are tuples of graphs and tensors.

    Each sample is ``(graph_in, x, y, graph_out)``. After collation,
    positions 0 and 3 must yield the original graphs in order. Positions 1
    and 2 must equal the per-sample tensors stacked along dim 0.
    """
    graphs_in = [
        add_random_features(Graph.from_networkx(g),
                            **features_shapes).to(device) for g in graphs_nx
    ]
    graphs_out = list(reversed(graphs_in))
    xs = torch.rand(len(graphs_in), 10, 32)
    ys = torch.rand(len(graphs_in), 7)

    samples = list(zip(graphs_in, xs, ys, graphs_out))
    batch = GraphBatch.collate(samples)

    # zip() stops at the shorter input, so compare counts first. Otherwise a
    # batch that dropped graphs would still pass the element-wise checks.
    assert len(batch[0]) == len(graphs_in)
    for g1, g2 in zip(graphs_in, batch[0]):
        assert_graphs_equal(g1, g2)

    # torch.testing.assert_allclose is deprecated; assert_close is its
    # replacement and additionally checks dtype and device.
    torch.testing.assert_close(batch[1], xs)
    torch.testing.assert_close(batch[2], ys)

    assert len(batch[3]) == len(graphs_out)
    for g1, g2 in zip(graphs_out, batch[3]):
        assert_graphs_equal(g1, g2)
def test_from_to_networkxs(graphs_nx, features_shapes, device):
    """Round-trip networkx graphs through ``GraphBatch``.

    The batch built with ``GraphBatch.from_networkxs`` must validate, match
    the per-graph node/edge counts, and match each input graph through
    iteration and indexing. Converting back with ``to_networkxs`` must give
    the same number of graphs, each equal to its input.
    """
    graphs_nx = [add_random_features(g, **features_shapes) for g in graphs_nx]
    graphbatch = GraphBatch.from_networkxs(graphs_nx).to(device)

    validate_batch(graphbatch)

    assert len(graphs_nx) == len(graphbatch) == graphbatch.num_graphs
    assert [g.number_of_nodes()
            for g in graphs_nx] == graphbatch.num_nodes_by_graph.tolist()
    assert [g.number_of_edges()
            for g in graphs_nx] == graphbatch.num_edges_by_graph.tolist()

    # Test sequential access (__iter__)
    for g_nx, g in zip(graphs_nx, graphbatch):
        assert_graphs_equal(g_nx, g.cpu())

    # Test random access (__getitem__)
    for i, g_nx in enumerate(graphs_nx):
        assert_graphs_equal(g_nx, graphbatch[i].cpu())

    # Test back conversion. Materialise the result and check its length,
    # because zip() would silently ignore graphs lost in the conversion.
    graphs_nx_back = list(graphbatch.cpu().to_networkxs())
    assert len(graphs_nx_back) == len(graphs_nx)
    for g1, g2 in zip(graphs_nx, graphs_nx_back):
        assert_graphs_equal(g1, g2)
Beispiel #11
0
def graphbatch() -> GraphBatch:
    """Return a ``GraphBatch`` built from the values of ``graphs_for_test()``.

    The values are passed to ``GraphBatch.from_networkxs`` in their original
    iteration order.
    """
    test_graphs = graphs_for_test()
    return GraphBatch.from_networkxs(test_graphs.values())