# Example #1
# 0
def test_collate_dicts(graphs_nx, features_shapes, device):
    """Collating dict samples must batch every key independently.

    Each sample maps ``'in'``/``'out'`` to graphs and ``'x'``/``'y'`` to
    tensors. After ``GraphBatch.collate``, the graph entries must unpack to
    the original graphs in the original order, and the tensor entries must
    match ``xs``/``ys`` (the per-sample tensors stacked along dim 0).
    """
    graphs_in = [
        add_random_features(Graph.from_networkx(g),
                            **features_shapes).to(device) for g in graphs_nx
    ]
    graphs_out = list(reversed(graphs_in))
    xs = torch.rand(len(graphs_in), 10, 32)
    ys = torch.rand(len(graphs_in), 7)

    samples = [{
        'in': gi,
        'x': x,
        'y': y,
        'out': go
    } for gi, x, y, go in zip(graphs_in, xs, ys, graphs_out)]
    batch = GraphBatch.collate(samples)

    # Materialize the batched graphs and check their count first: zip()
    # alone stops at the shorter sequence, so a batch that lost (or gained)
    # graphs would otherwise pass unnoticed.
    batched_in = list(batch['in'])
    assert len(batched_in) == len(graphs_in)
    for g1, g2 in zip(graphs_in, batched_in):
        assert_graphs_equal(g1, g2)

    torch.testing.assert_allclose(xs, batch['x'])

    torch.testing.assert_allclose(ys, batch['y'])

    batched_out = list(batch['out'])
    assert len(batched_out) == len(graphs_out)
    for g1, g2 in zip(graphs_out, batched_out):
        assert_graphs_equal(g1, g2)
# Example #2
# 0
def test_graph_properties(graph_nx):
    """Degree views of a ``Graph`` must match those of its networkx source."""
    featured_nx = add_dummy_features(graph_nx)
    graph = Graph.from_networkx(featured_nx)

    # Checked in order: total degree, in-degree, out-degree. The networkx
    # views yield (node, degree) pairs, so only the degree is kept.
    for view_name in ('degree', 'in_degree', 'out_degree'):
        actual = list(getattr(graph, view_name))
        expected = [deg for _, deg in getattr(featured_nx, view_name)]
        assert actual == expected
# Example #3
# 0
def test_global_functions(graph_nx):
    """Global features and their node/edge broadcasts must have consistent shapes.

    ``global_features`` must have exactly ``global_features_shape``.
    ``global_features_as_nodes`` / ``global_features_as_edges`` must prepend
    one leading dimension of size ``num_nodes`` / ``num_edges`` respectively.
    """
    featured_nx = add_dummy_features(graph_nx)
    graph = Graph.from_networkx(featured_nx)
    global_shape = graph.global_features_shape

    assert graph.global_features.shape == global_shape

    per_node_shape = (graph.num_nodes, *global_shape)
    assert graph.global_features_as_nodes.shape == per_node_shape

    per_edge_shape = (graph.num_edges, *global_shape)
    assert graph.global_features_as_edges.shape == per_edge_shape
# Example #4
# 0
def _check_node_view(view, num_nodes, feature_shape):
    """Assert shape consistency of a per-node feature view in all access modes.

    ``view`` is exercised three ways: indexed by node, iterated, and called
    with ``aggregation='sum'``.

    Args:
        view: a per-node accessor of a ``Graph``, e.g.
            ``graph.out_edge_features``.
        num_nodes: number of nodes of the graph the view belongs to.
        feature_shape: expected trailing shape of every feature tensor.
    """
    # By node index: only the trailing dims are compared; the leading dim
    # (presumably the number of edges/neighbors of that node) may vary.
    for node_index in range(num_nodes):
        assert view[node_index].shape[1:] == feature_shape

    # Iterator: also count the entries, otherwise an empty or truncated
    # iterator would pass silently.
    num_iterated = 0
    for per_node in iter(view):
        assert per_node.shape[1:] == feature_shape
        num_iterated += 1
    assert num_iterated == num_nodes

    # As tensor: the aggregated result must have one row per node.
    assert view(aggregation='sum').shape == (num_nodes, *feature_shape)


def test_node_functions(graph_nx):
    """Per-node views of edge and neighbor-node features have consistent shapes."""
    graph_nx = add_dummy_features(graph_nx)
    graph = Graph.from_networkx(graph_nx)
    num_nodes = graph.num_nodes
    edge_shape = graph.edge_features_shape
    node_shape = graph.node_features_shape

    # Features of the outgoing edges
    _check_node_view(graph.out_edge_features, num_nodes, edge_shape)
    # Features of the incoming edges
    _check_node_view(graph.in_edge_features, num_nodes, edge_shape)
    # Features of the successor nodes
    _check_node_view(graph.successor_features, num_nodes, node_shape)
    # Features of the predecessor nodes
    _check_node_view(graph.predecessor_features, num_nodes, node_shape)
# Example #5
# 0
def test_collate_tuples(graphs_nx, features_shapes, device):
    """Collating tuple samples must batch every position independently.

    Each sample is ``(graph_in, x, y, graph_out)``. After
    ``GraphBatch.collate``, positions 0 and 3 must unpack to the original
    graphs in the original order, and positions 1 and 2 must match
    ``xs``/``ys`` (the per-sample tensors stacked along dim 0).
    """
    graphs_in = [
        add_random_features(Graph.from_networkx(g),
                            **features_shapes).to(device) for g in graphs_nx
    ]
    graphs_out = list(reversed(graphs_in))
    xs = torch.rand(len(graphs_in), 10, 32)
    ys = torch.rand(len(graphs_in), 7)

    samples = list(zip(graphs_in, xs, ys, graphs_out))
    batch = GraphBatch.collate(samples)

    # Materialize the batched graphs and check their count first: zip()
    # alone stops at the shorter sequence, so a batch that lost (or gained)
    # graphs would otherwise pass unnoticed.
    batched_in = list(batch[0])
    assert len(batched_in) == len(graphs_in)
    for g1, g2 in zip(graphs_in, batched_in):
        assert_graphs_equal(g1, g2)

    torch.testing.assert_allclose(xs, batch[1])

    torch.testing.assert_allclose(ys, batch[2])

    batched_out = list(batch[3])
    assert len(batched_out) == len(graphs_out)
    for g1, g2 in zip(graphs_out, batched_out):
        assert_graphs_equal(g1, g2)
# Example #6
# 0
def _check_per_edge_access(view, num_edges, feature_shape):
    """Assert that indexing and iterating ``view`` yield one entry per edge.

    Args:
        view: a per-edge accessor of a ``Graph``, e.g.
            ``graph.sender_features``.
        num_edges: number of edges of the graph the view belongs to.
        feature_shape: exact expected shape of every per-edge entry.
    """
    # By edge index
    for edge_index in range(num_edges):
        assert view[edge_index].shape == feature_shape

    # Iterator: also count the entries, otherwise an empty or truncated
    # iterator would pass silently.
    num_iterated = 0
    for entry in view:
        assert entry.shape == feature_shape
        num_iterated += 1
    assert num_iterated == num_edges


def test_edge_functions(graph_nx):
    """Per-edge views of edge and endpoint-node features have consistent shapes."""
    graph_nx = add_dummy_features(graph_nx)
    graph = Graph.from_networkx(graph_nx)
    num_edges = graph.num_edges
    edge_shape = graph.edge_features_shape
    node_shape = graph.node_features_shape

    # Edge features (exposed directly as a tensor, so no call for the
    # "as tensor" check)
    _check_per_edge_access(graph.edge_features, num_edges, edge_shape)
    assert graph.edge_features.shape == (num_edges, *edge_shape)

    # Features of the sender nodes
    _check_per_edge_access(graph.sender_features, num_edges, node_shape)
    assert graph.sender_features().shape == (num_edges, *node_shape)

    # Features of the receiver nodes
    _check_per_edge_access(graph.receiver_features, num_edges, node_shape)
    assert graph.receiver_features().shape == (num_edges, *node_shape)
# Example #7
# 0
def test_from_networkx(graph_nx, features_shapes):
    """Converting a featured networkx graph to ``Graph`` must preserve its content."""
    featured_nx = add_random_features(graph_nx, **features_shapes)
    assert_graphs_equal(featured_nx, Graph.from_networkx(featured_nx))
# Example #8
# 0
def graphs() -> Sequence[Graph]:
    """Return every graph from ``graphs_for_test()`` converted to a ``Graph``."""
    nx_graphs = graphs_for_test().values()
    return list(map(Graph.from_networkx, nx_graphs))
# Example #9
# 0
def graph(request) -> Graph:
    """Build a ``Graph`` from the networkx graph held in ``request.param``."""
    nx_graph = request.param
    converted = Graph.from_networkx(nx_graph)
    return converted