# Example #1
# 0
def test_collate_dicts(graphs_nx, features_shapes, device):
    """Collating a list of dict samples yields a dict of batched fields.

    Each sample holds two graphs (``'in'`` and ``'out'``) and two tensors
    (``'x'`` and ``'y'``). After ``GraphBatch.collate`` each key must map to
    a batch whose items match the per-sample inputs, in order, with no items
    dropped.
    """
    graphs_in = [
        add_random_features(Graph.from_networkx(g),
                            **features_shapes).to(device) for g in graphs_nx
    ]
    # Reversed order so a mix-up between the 'in' and 'out' fields is detectable.
    graphs_out = list(reversed(graphs_in))
    xs = torch.rand(len(graphs_in), 10, 32)
    ys = torch.rand(len(graphs_in), 7)

    samples = [{
        'in': gi,
        'x': x,
        'y': y,
        'out': go
    } for gi, x, y, go in zip(graphs_in, xs, ys, graphs_out)]
    batch = GraphBatch.collate(samples)

    # zip() stops at the shorter input, so materialize each batched field and
    # check its length. Otherwise a collate that dropped graphs would still
    # pass the element-wise comparisons below.
    batched_in = list(batch['in'])
    assert len(batched_in) == len(graphs_in)
    for g1, g2 in zip(graphs_in, batched_in):
        assert_graphs_equal(g1, g2)

    torch.testing.assert_allclose(xs, batch['x'])

    torch.testing.assert_allclose(ys, batch['y'])

    batched_out = list(batch['out'])
    assert len(batched_out) == len(graphs_out)
    for g1, g2 in zip(graphs_out, batched_out):
        assert_graphs_equal(g1, g2)
# Example #2
# 0
def test_device(graph, features_shapes, device):
    """Moving a featured graph with ``.to(device)`` relocates its features.

    Every name listed in ``_feature_fields`` on the moved graph must be
    either ``None`` or a tensor whose ``.device`` equals ``device``. Moving
    the graph back with ``.cpu()`` must give a graph equal to the original.
    """
    featured = add_random_features(graph, **features_shapes)
    moved = featured.to(device)

    for field_name in moved._feature_fields:
        value = getattr(moved, field_name)
        # Absent features are allowed; present ones must live on `device`.
        assert value is None or value.device == device

    assert_graphs_equal(featured, moved.cpu())
def test_linear_graph_network(graphbatch: GraphBatch, device):
    """A Sequential edge -> node -> global linear network keeps batch topology.

    Builds ``EdgeLinear``/``NodeLinear``/``GlobalLinear`` layers (each
    followed by its ReLU) sized from the module-level ``linear_features``
    dict, runs them over a random-featured batch on ``device``, and checks
    that graph/node/edge counts and the sender/receiver indices are
    unchanged.
    """
    graphbatch = add_random_features(graphbatch, **linear_features).to(device)

    node_linear = NodeLinear(
        out_features=linear_features['node_features_shape'],
        incoming_features=linear_features['edge_features_shape'],
        node_features=linear_features['node_features_shape'],
        global_features=linear_features['global_features_shape'],
        aggregation='mean')
    edge_linear = EdgeLinear(
        out_features=linear_features['edge_features_shape'],
        edge_features=linear_features['edge_features_shape'],
        sender_features=linear_features['node_features_shape'],
        receiver_features=linear_features['node_features_shape'],
        global_features=linear_features['global_features_shape'])
    global_linear = GlobalLinear(
        out_features=linear_features['global_features_shape'],
        edge_features=linear_features['edge_features_shape'],
        node_features=linear_features['node_features_shape'],
        global_features=linear_features['global_features_shape'],
        aggregation='mean')

    net = torch.nn.Sequential(
        OrderedDict([
            ('edge', edge_linear),
            ('edge_relu', EdgeReLU()),
            ('node', node_linear),
            ('node_relu', NodeReLU()),
            ('global', global_linear),
            ('global_relu', GlobalReLU()),
        ]))
    net.to(device)

    # Invoke the module rather than .forward(): nn.Module.__call__ also runs
    # any registered hooks, which a direct .forward() call would skip.
    result = net(graphbatch)

    assert graphbatch.num_graphs == result.num_graphs
    assert graphbatch.num_nodes == result.num_nodes
    assert graphbatch.num_edges == result.num_edges
    # torch.equal also requires identical shapes. An element-wise `==`
    # followed by .all() broadcasts and could pass on mismatched sizes.
    assert torch.equal(graphbatch.num_nodes_by_graph, result.num_nodes_by_graph)
    assert torch.equal(graphbatch.num_edges_by_graph, result.num_edges_by_graph)
    assert torch.equal(graphbatch.senders, result.senders)
    assert torch.equal(graphbatch.receivers, result.receivers)
def test_from_graphs(graphs, features_shapes, device):
    """``GraphBatch.from_graphs`` round-trips graphs via iteration and indexing.

    After batching random-featured graphs on ``device``, the batch must pass
    ``validate_batch``, report matching graph/node/edge counts, and give back
    each original graph both through ``__iter__`` and through
    ``__getitem__``.
    """
    graphs = [
        add_random_features(g, **features_shapes).to(device) for g in graphs
    ]
    graphbatch = GraphBatch.from_graphs(graphs)

    validate_batch(graphbatch)

    assert len(graphs) == len(graphbatch) == graphbatch.num_graphs
    assert [g.num_nodes
            for g in graphs] == graphbatch.num_nodes_by_graph.tolist()
    assert [g.num_edges
            for g in graphs] == graphbatch.num_edges_by_graph.tolist()

    # Test sequential access (__iter__). Materialize the iteration first:
    # __len__ being correct does not guarantee __iter__ yields that many
    # items, and zip() would silently stop at the shorter side.
    iterated = list(graphbatch)
    assert len(iterated) == len(graphs)
    for g, gb in zip(graphs, iterated):
        assert_graphs_equal(g, gb)

    # Test random access (__getitem__)
    for i, g in enumerate(graphs):
        assert_graphs_equal(g, graphbatch[i])
# Example #5
# 0
def test_collate_tuples(graphs_nx, features_shapes, device):
    """Collating a list of tuple samples yields a tuple of batched fields.

    Each sample is ``(graph_in, x, y, graph_out)``. After
    ``GraphBatch.collate`` position ``i`` of the result must batch position
    ``i`` of every sample, in order, with no items dropped.
    """
    graphs_in = [
        add_random_features(Graph.from_networkx(g),
                            **features_shapes).to(device) for g in graphs_nx
    ]
    # Reversed order so a mix-up between positions 0 and 3 is detectable.
    graphs_out = list(reversed(graphs_in))
    xs = torch.rand(len(graphs_in), 10, 32)
    ys = torch.rand(len(graphs_in), 7)

    samples = list(zip(graphs_in, xs, ys, graphs_out))
    batch = GraphBatch.collate(samples)

    # zip() stops at the shorter input, so materialize each batched field and
    # check its length. Otherwise a collate that dropped graphs would still
    # pass the element-wise comparisons below.
    batched_in = list(batch[0])
    assert len(batched_in) == len(graphs_in)
    for g1, g2 in zip(graphs_in, batched_in):
        assert_graphs_equal(g1, g2)

    torch.testing.assert_allclose(xs, batch[1])

    torch.testing.assert_allclose(ys, batch[2])

    batched_out = list(batch[3])
    assert len(batched_out) == len(graphs_out)
    for g1, g2 in zip(graphs_out, batched_out):
        assert_graphs_equal(g1, g2)
def test_from_to_networkxs(graphs_nx, features_shapes, device):
    """Batching networkx graphs and converting back preserves every graph.

    Builds a ``GraphBatch`` from random-featured networkx graphs, moves it to
    ``device``, validates it, checks the per-graph node/edge counts, then
    compares each graph (moved back to CPU) against its source. The
    comparison is done via ``__iter__``, via ``__getitem__``, and after
    ``to_networkxs()``.
    """
    graphs_nx = [add_random_features(g, **features_shapes) for g in graphs_nx]
    graphbatch = GraphBatch.from_networkxs(graphs_nx).to(device)

    validate_batch(graphbatch)

    assert len(graphs_nx) == len(graphbatch) == graphbatch.num_graphs
    assert [g.number_of_nodes()
            for g in graphs_nx] == graphbatch.num_nodes_by_graph.tolist()
    assert [g.number_of_edges()
            for g in graphs_nx] == graphbatch.num_edges_by_graph.tolist()

    # Test sequential access (__iter__). Materialize first so a short
    # iteration is caught instead of silently truncated by zip().
    iterated = list(graphbatch)
    assert len(iterated) == len(graphs_nx)
    for g_nx, g in zip(graphs_nx, iterated):
        assert_graphs_equal(g_nx, g.cpu())

    # Test random access (__getitem__)
    for i, g_nx in enumerate(graphs_nx):
        assert_graphs_equal(g_nx, graphbatch[i].cpu())

    # Test back conversion; the length check ensures no graph was lost.
    graphs_nx_back = list(graphbatch.cpu().to_networkxs())
    assert len(graphs_nx_back) == len(graphs_nx)
    for g1, g2 in zip(graphs_nx, graphs_nx_back):
        assert_graphs_equal(g1, g2)
def test_corner_cases(features_shapes, device):
    """Batching handles empty graphs and graphs lacking node/edge features.

    Also checks that ``GraphBatch.from_graphs`` raises ``ValueError`` when
    global features are present on only some of the input graphs, whichever
    graph comes first.
    """
    # Only some graphs have node/edge features, global features are either present on all of them or absent from all
    gfs = features_shapes['global_features_shape']
    graphs = [
        add_random_features(Graph(num_nodes=0, num_edges=0),
                            global_features_shape=gfs),
        add_random_features(Graph(num_nodes=0, num_edges=0),
                            global_features_shape=gfs),
        add_random_features(Graph(num_nodes=3, num_edges=0),
                            **features_shapes),
        add_random_features(Graph(num_nodes=0, num_edges=0),
                            **features_shapes),
        add_random_features(
            Graph(num_nodes=2,
                  senders=torch.tensor([0, 1]),
                  receivers=torch.tensor([1, 0])), **features_shapes)
    ]
    graphbatch = GraphBatch.from_graphs(graphs).to(device)
    validate_batch(graphbatch)

    # Materialize the iteration so a short __iter__ is detected rather than
    # silently truncated by zip().
    batched = list(graphbatch)
    assert len(batched) == len(graphs)
    for g_orig, g_batch in zip(graphs, batched):
        assert_graphs_equal(g_orig, g_batch.cpu())

    # Global features should be either present on all graphs or absent from
    # all graphs. The inputs are built outside pytest.raises so that only
    # from_graphs() is expected to raise. A ValueError during graph
    # construction would otherwise make the test pass for the wrong reason.
    globals_absent_first = [
        Graph(num_nodes=0, num_edges=0),
        add_random_features(Graph(num_nodes=0, num_edges=0),
                            global_features_shape=10)
    ]
    globals_present_first = [
        add_random_features(Graph(num_nodes=0, num_edges=0),
                            global_features_shape=10),
        Graph(num_nodes=0, num_edges=0)
    ]
    for mixed in (globals_absent_first, globals_present_first):
        with pytest.raises(ValueError):
            GraphBatch.from_graphs(mixed)
# Example #8
# 0
def test_from_networkx(graph_nx, features_shapes):
    """``Graph.from_networkx`` yields a graph equal to its featured source.

    Equality is as judged by ``assert_graphs_equal``.
    """
    featured_nx = add_random_features(graph_nx, **features_shapes)
    assert_graphs_equal(featured_nx, Graph.from_networkx(featured_nx))
# Example #9
# 0
def test_to_networkx(graph, features_shapes):
    """``Graph.to_networkx`` yields a networkx graph equal to its source.

    Equality is as judged by ``assert_graphs_equal``.
    """
    featured = add_random_features(graph, **features_shapes)
    assert_graphs_equal(featured, featured.to_networkx())