Example #1
0
    def test_to_networkx(self, keys):
        """Convert a GraphData to networkx and compare it against the source.

        ``keys`` is a ``(feature_key, global_attr_key)`` pair; a ``None`` entry
        means that keyword is not passed, so ``to_networkx`` uses its default.
        """
        feature_key, global_attr_key = keys
        requested = {
            "feature_key": feature_key,
            "global_attr_key": global_attr_key,
        }
        # Forward only the keywords that were explicitly provided.
        kwargs = {name: value for name, value in requested.items() if value is not None}

        node_attrs = torch.randn(10, 5)
        edge_attrs = torch.randn(5, 4)
        global_attrs = torch.randn(1, 3)
        edge_index = torch.tensor([[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]])
        data = GraphData(node_attrs, edge_attrs, global_attrs, edge_index)

        graph = data.to_networkx(**kwargs)
        assert isinstance(graph, nx.OrderedMultiDiGraph)
        assert graph.number_of_nodes() == 10
        assert graph.number_of_edges() == 5

        # Fall back to the keys this test assumes to_networkx uses by default.
        Comparator.data_to_nx(
            data,
            graph,
            kwargs.get("feature_key", "features"),
            kwargs.get("global_attr_key", None),
        )
Example #2
0
def test_node_mask_entire_graph():
    """Masking every node of one graph in a batch removes its nodes and edges.

    Three graphs with 5 nodes and 3 edges each are batched. The node mask
    keeps the nodes of the first and last graph and drops all 5 nodes of
    the middle graph.
    """
    datalist = [
        GraphData(
            torch.randn(5, 3),
            torch.randn(3, 3),
            torch.randn(1, 1),
            edges=torch.LongTensor([[0, 1, 2], [2, 3, 4]]),
        )
        for _ in range(3)
    ]
    batch = GraphBatch.from_data_list(datalist)

    node_mask = torch.BoolTensor([True] * 5 + [False] * 5 + [True] * 5)
    masked = batch.apply_node_mask(node_mask)

    # NOTE(review): assumes GraphBatch exposes num_nodes / num_edges like
    # GraphData does (see test_mask_all_nodes) — confirm.
    assert masked.num_nodes == 10
    # Every edge of the middle graph touches one of its (masked) nodes, so
    # only the 3 + 3 edges of the outer graphs survive.
    assert masked.num_edges == 6
Example #3
0
def test_train_shortest_path():
    """Smoke-test training ``Network`` on generated shortest-path graphs."""
    graphs = [
        generate_shorest_path_example(100, 0.01, 1000) for _ in range(10)
    ]
    input_data = [
        GraphData.from_networkx(g, feature_key="_features") for g in graphs
    ]
    target_data = [
        GraphData.from_networkx(g, feature_key="_target") for g in graphs
    ]

    loader = GraphDataLoader(input_data,
                             target_data,
                             batch_size=32,
                             shuffle=True)

    network = Network()

    # Run one forward pass before the optimizer is created.
    # NOTE(review): presumably this materialises lazily-sized layers so that
    # network.parameters() is fully populated — confirm against Network.
    for input_batch, _ in loader:
        network(input_batch, 10)
        break

    loss_fn = torch.nn.BCELoss()
    optimizer = torch.optim.AdamW(network.parameters())
    for _ in range(10):
        for input_batch, target_batch in loader:
            # Reset gradients each step; otherwise backward() adds onto the
            # previous steps' gradients and every update uses their running sum.
            optimizer.zero_grad()
            output = network(input_batch, 10)[0]
            x, y = output.x, target_batch.x
            loss = loss_fn(x.flatten(), y[:, 0].flatten())
            loss.backward()
            print(loss.detach())
            optimizer.step()
Example #4
0
def test_loader_zipped():
    """A loader over two parallel datalists yields pairs of distinct GraphBatches."""
    n_items = 32 * 5
    first = [GraphData.random(5, 4, 3) for _ in range(n_items)]
    second = [GraphData.random(5, 4, 3) for _ in range(n_items)]
    loader = GraphDataLoader(first, second, batch_size=32, shuffle=True)

    for batch_a, batch_b in loader:
        for batch in (batch_a, batch_b):
            assert isinstance(batch, GraphBatch)
        assert batch_a is not batch_b
Example #5
0
    def test_append_edges(self):
        """Appending 3 edges grows the edge attributes and edge index from 5 to 8."""
        data = GraphData(
            torch.randn(10, 5),
            torch.randn(5, 4),
            torch.randn(1, 3),
            torch.randint(0, 10, torch.Size([2, 5])),
        )
        assert data.e.shape[0] == 5
        assert data.edges.shape[1] == 5

        e = torch.randn(3, 4)
        edges = torch.randint(0, 10, torch.Size([2, 3]))
        data.append_edges(e, edges)

        # Like append_nodes (see test_append_nodes), the append mutates `data`.
        assert data.e.shape[0] == 8
        assert data.edges.shape[1] == 8
Example #6
0
    def test_append_nodes(self):
        """``append_nodes`` grows the node attribute tensor in place."""
        args = (
            torch.randn(10, 5),
            torch.randn(5, 4),
            torch.randn(1, 3),
            torch.randint(0, 10, torch.Size([2, 5])),
        )
        data = GraphData(*args)

        n_before = data.x.shape[0]
        assert n_before == 10
        data.append_nodes(torch.randn(2, 5))
        assert data.x.shape[0] == n_before + 2
Example #7
0
    def test_invalid_append_edges(self):
        """Appending edges whose attribute and index counts disagree must raise."""
        x = torch.randn(10, 5)
        e = torch.randn(5, 4)
        g = torch.randn(1, 3)
        edges = torch.randint(0, 10, torch.Size([2, 5]))
        data = GraphData(x, e, g, edges)

        new_attrs = torch.randn(3, 4)
        # Four index columns for only three attribute rows.
        new_index = torch.randint(0, 10, torch.Size([2, 4]))
        with pytest.raises(RuntimeError):
            data.append_edges(new_attrs, new_index)
Example #8
0
 def test_not_eq(self):
     """Graphs that differ in their node tensors compare unequal."""
     x, e, g, edges = (
         torch.randn(20, 5),
         torch.randn(3, 4),
         torch.randn(1, 3),
         torch.randint(0, 10, torch.Size([2, 3])),
     )
     # Keep only the first 10 nodes; the other tensors are full-slice views.
     shrunk = (x[:10], e[:], g[:], edges[:])
     data1 = GraphData(x, e, g, edges)
     data2 = GraphData(*shrunk)
     assert not (data1 == data2)
     assert data1 is not data2
Example #9
0
def test_generate_shortest_path_example():
    """Generated graphs flag source/target on every node and convert to the expected shapes."""
    g = generate_shorest_path_example(100, 0.01, 10)

    for _, node_attrs in g.nodes(data=True):
        assert "source" in node_attrs
        assert "target" in node_attrs

    features = GraphData.from_networkx(g, feature_key="_features")
    targets = GraphData.from_networkx(g, feature_key="_target")

    assert tuple(features.shape) == (4, 1, 1)
    assert tuple(targets.shape) == (2, 2, 1)
Example #10
0
def test_mask_all_nodes():
    """A node mask that keeps nothing yields a graph with no nodes and no edges."""
    deterministic_seed(0)

    data = GraphData(
        torch.randn(5, 5),
        torch.randn(3, 2),
        torch.randn(1, 1),
        edges=torch.LongTensor([[0, 0, 0], [1, 2, 3]]),
    )
    keep_nothing = torch.BoolTensor([False] * 5)
    masked = data.apply_node_mask(keep_nothing)

    assert masked.num_nodes == 0
    assert masked.num_edges == 0
Example #11
0
def test_mask_no_edges():
    """An all-True edge mask removes no edges and leaves every node in place."""
    deterministic_seed(0)

    data = GraphData(
        torch.randn(5, 5),
        torch.randn(3, 2),
        torch.randn(1, 1),
        edges=torch.LongTensor([[0, 0, 0], [1, 2, 3]]),
    )
    keep_all = torch.BoolTensor([True] * 3)
    masked = data.apply_edge_mask(keep_all)

    assert masked.num_nodes == 5
    assert masked.num_edges == 3
Example #12
0
    def test_from_networkx_no_edge(self, keys):
        """Convert an edgeless networkx graph and compare it with the result.

        ``keys`` is a ``(feature_key, global_attr_key)`` pair; ``None`` entries
        are not passed, so ``from_networkx`` falls back to its defaults.
        """
        feature_key, global_attr_key = keys
        requested = {"feature_key": feature_key, "global_attr_key": global_attr_key}
        kwargs = {name: value for name, value in requested.items() if value is not None}

        fkey = kwargs.get("feature_key", "features")
        gkey = kwargs.get("global_attr_key", None)

        g = nx.OrderedMultiDiGraph()
        for node in ("node1", "node2"):
            g.add_node(node, **{fkey: np.random.randn(5)})
        # NOTE(review): presumably needed so the conversion sees an empty
        # ordered edge list on this graph type — confirm.
        g.ordered_edges = []
        g.set_global({fkey: np.random.randn(3)}, gkey)

        data = GraphData.from_networkx(g, **kwargs)

        Comparator.data_to_nx(data, g, fkey, gkey)
Example #13
0
def data(request):
    """Return a seeded random GraphData, or a GraphBatch when ``request.param`` is not GraphData."""
    deterministic_seed(0)
    if request.param is not GraphData:
        return GraphBatch.random_batch(10, 5, 4, 3)
    return GraphData.random(5, 4, 3)
Example #14
0
 def run():
     """Scatter the edge index of a large random graph with ``scatter_coo``.

     The result is sized ``(num_nodes, num_nodes)`` and discarded; presumably
     this function exists to be timed.
     """
     data = GraphData.random(5, 4, 3, min_edges=1000, min_nodes=1000)
     edge_index = torch.cat([data.edges])
     n = data.num_nodes
     sparse_mask = scatter_coo(edge_index, 1, expand=True, size=(n, n))
Example #15
0
def data(request):
    """Return a 10-node / 5-edge random GraphData, or ``random_batch(100, ...)`` of such graphs."""
    sizes = (5, 4, 3)
    bounds = {"min_nodes": 10, "max_nodes": 10, "min_edges": 5, "max_edges": 5}
    if request.param is not GraphData:
        return GraphBatch.random_batch(100, *sizes, **bounds)
    return GraphData.random(*sizes, **bounds)
Example #16
0
def test_fully_connected_singe_graph_batch():
    """FullyConnected increases the edge count of a batch holding one graph."""
    deterministic_seed(0)
    batch = GraphBatch.from_data_list([GraphData.random(5, 4, 3)])
    transformed = FullyConnected()(batch)
    n_before = batch.edges.shape[1]
    assert transformed.edges.shape[1] > n_before
Example #17
0
 def test_view_graph_data(self, slices):
     """A view taken with ``slices`` shares storage with the original graph."""
     data = GraphData.random(5, 5, 5, min_nodes=10, max_nodes=10)
     assert data.shape == (5, 5, 5)
     view = data.view(*slices)
     print(view.shape)
     assert view.share_storage(data)
Example #18
0
 def test_invalid_n_edges(self):
     """An edge index with three rows instead of two is rejected."""
     x = torch.randn(10, 5)
     e = torch.randn(5, 4)
     g = torch.randn(1, 3)
     bad_edges = torch.randint(0, 10, torch.Size([3, 5]))
     with pytest.raises(RuntimeError):
         GraphData(x, e, g, bad_edges)
Example #19
0
def test_loader_first():
    """``GraphDataLoader.first`` returns one GraphBatch holding ``batch_size`` graphs."""
    batch_size = 32
    graphs = [GraphData.random(5, 4, 3) for _ in range(batch_size * 5)]
    batch = GraphDataLoader(graphs, batch_size=batch_size, shuffle=True).first()

    assert isinstance(batch, GraphBatch)
    assert batch.shape == (5, 4, 3)
    assert batch.num_graphs == batch_size
Example #20
0
def test_loader_dataset():
    """Each batch drawn from a GraphDataset-backed loader has ``size[-1] == 32``."""
    batch_size = 32
    dataset = GraphDataset([GraphData.random(5, 4, 3) for _ in range(batch_size * 4)])
    loader = GraphDataLoader(dataset, shuffle=True, batch_size=batch_size)

    for batch in loader:
        print(batch.size)
        assert isinstance(batch, GraphBatch)
        assert batch.size[-1] == batch_size
Example #21
0
 def test_floyd_warshall_neighbors(self, nodes, depth, return_matrix):
     """Smoke test: run floyd_warshall_neighbors on a random graph and print the result."""
     data = GraphData.random(5, 4, 3, min_edges=300, min_nodes=100)
     print(data.density())
     neighbors = floyd_warshall_neighbors(
         data, nodes, depth, return_matrix=return_matrix
     )
     print(neighbors)
Example #22
0
def test_():
    """Batch a graph with one edge together with a graph that has no edges.

    ``data2`` has an empty edge index, so this checks that
    ``GraphBatch.from_data_list`` accepts edgeless graphs and combines counts.
    """
    data1 = GraphData(
        torch.randn(1, 5),
        torch.randn(1, 4),
        torch.randn(1, 3),
        torch.LongTensor([[0], [0]]),
    )

    data2 = GraphData(
        torch.randn(1, 5),
        torch.randn(0, 4),
        torch.randn(1, 3),
        torch.LongTensor([[], []]),
    )

    batch = GraphBatch.from_data_list([data1, data2])
    assert batch.num_graphs == 2
    # One node per graph; the only edge (a self-loop) comes from data1.
    assert batch.x.shape[0] == 2
    assert batch.edges.shape[1] == 1
    assert batch.e.shape[0] == 1
Example #23
0
    def test_invalid_batch(self, offsets):
        """Batching graphs whose attribute widths differ must raise.

        ``offsets[0..2]`` are added to the second graph's node, edge and
        global attribute widths. NOTE(review): assumes the parametrization
        never supplies all-zero offsets.
        """
        dx, de, dg = offsets[0], offsets[1], offsets[2]

        reference = GraphData(
            torch.randn(10, 10),
            torch.randn(3, 4),
            torch.randn(1, 3),
            torch.randint(0, 10, torch.Size([2, 3])),
        )
        mismatched = GraphData(
            torch.randn(12, 10 + dx),
            torch.randn(4, 4 + de),
            torch.randn(1, 3 + dg),
            torch.randint(0, 10, torch.Size([2, 4])),
        )

        with pytest.raises(RuntimeError):
            GraphBatch.from_data_list([reference, mismatched])
Example #24
0
 def test_invalid_global_ndims(self):
     """A 1-D global attribute tensor is rejected (valid globals here are 2-D)."""
     node_attrs = torch.randn(10, 5)
     edge_attrs = torch.randn(5, 4)
     flat_globals = torch.randn(1)
     edge_index = torch.randint(0, 10, torch.Size([2, 5]))
     with pytest.raises(RuntimeError):
         GraphData(node_attrs, edge_attrs, flat_globals, edge_index)
Example #25
0
 def test_invalid_number_of_nodes(self):
     """Edge indices that refer to non-existent nodes are rejected.

     The graph has 10 nodes (indices 0-9) but every edge index is 11. The edge
     index has 5 columns to match the 5 edge attribute rows, so the node-index
     range is the only invalid input and the test cannot pass because of an
     unrelated edge-count mismatch.
     """
     with pytest.raises(RuntimeError):
         GraphData(
             torch.randn(10, 5),
             torch.randn(5, 4),
             torch.randn(1, 3),
             torch.randint(11, 12, torch.Size([2, 5])),
         )
Example #26
0
 def test_invalid_global_shape(self):
     """A global attribute tensor of shape ``(3,)`` is rejected.

     Every other input is valid (5 in-range edges matching 5 edge attribute
     rows), so the RuntimeError can only come from the global tensor.
     """
     with pytest.raises(RuntimeError):
         GraphData(
             torch.randn(10, 5),
             torch.randn(5, 4),
             torch.randn(3),
             torch.randint(0, 10, torch.Size([2, 5])),
         )
Example #27
0
    def sigmoid_circuit(cls, data_size, batch_size):
        """Build a loader of random directed-tree "sigmoid circuit" graphs.

        Each example is a random tree with 2-19 nodes whose edges get a random
        direction. Root nodes (as reported by ``nx_utils.iter_roots``) get a
        target of ``np.array(10.0)``. Every other node's target is ``func``
        applied to the sum of its predecessors' targets, evaluated in
        topological order.

        :param data_size: number of graphs to generate.
        :param batch_size: batch size of the returned loader.
        :return: a ``GraphDataLoader`` over ``(input, target)`` GraphData pairs;
            inputs are read from the ``"features"`` key, targets from ``"target"``.
        """
        import math

        def func(x):
            # 1 - sigmoid(x)
            return 1 - 1.0 / (1 + math.exp(-x))

        input_data = []
        output_data = []
        for _ in range(data_size):
            n_size = np.random.randint(2, 20)
            tree = nx.random_tree(n_size)

            # randomize edge directions
            g = nx.DiGraph()
            for n1, n2 in tree.edges():
                if np.random.randint(2) == 0:
                    g.add_edge(n1, n2)
                else:
                    g.add_edge(n2, n1)
            cls._default_g(g)

            for n in nx_utils.iter_roots(g):
                g.nodes[n]["target"] = np.array(10.0)

            for n in nx.topological_sort(g):
                ndata = g.nodes[n]
                if "target" not in ndata:
                    # np.concatenate rejects 0-d arrays such as np.array(10.0),
                    # so reduce each predecessor target with np.sum and add them.
                    total = sum(
                        np.sum(g.nodes[p]["target"]) for p in g.predecessors(n)
                    )
                    # Store as a 0-d array, matching the root targets.
                    ndata["target"] = np.array(func(total))

            input_data.append(
                GraphData.from_networkx(g, feature_key="features"))
            output_data.append(GraphData.from_networkx(g,
                                                       feature_key="target"))

        # NOTE(review): other call sites pass parallel datalists
        # (GraphDataLoader(inputs, targets, ...)); confirm a list of pairs is
        # supported here.
        return GraphDataLoader(list(zip(input_data, output_data)),
                               batch_size=batch_size)
Example #28
0
def test_mask_one_edges():
    """Masking out the middle edge keeps the other two edges and their attributes."""
    deterministic_seed(0)

    edges = torch.LongTensor([[0, 0, 0], [1, 2, 3]])
    expected_edges = torch.LongTensor([[0, 0], [1, 3]])

    e = torch.randn(3, 2)
    edge_mask = torch.BoolTensor([True, False, True])
    # Boolean indexing selects the same rows as indexing with torch.where(mask).
    expected_e = e[edge_mask]

    data = GraphData(torch.randn(5, 5), e, torch.randn(1, 1), edges=edges)
    masked = data.apply_edge_mask(edge_mask)

    assert torch.all(masked.edges == expected_edges)
    assert torch.all(masked.e == expected_e)
    # Global and node attributes must be untouched by an edge mask.
    assert torch.all(masked.g == data.g)
    assert torch.all(masked.x == data.x)
Example #29
0
 def test_from_networkx_missing_node_data(self):
     """Nodes lacking the feature key convert to zero-width node features."""
     graph = nx.DiGraph()
     graph.add_node(1)
     # Node 2 is created implicitly by add_edge, also without node data.
     graph.add_edge(1, 2, feature=np.array([10.0]))
     graph.set_global({"feature": np.array([12.0])})

     converted = GraphData.from_networkx(graph, feature_key="feature")
     expected_e = torch.tensor([[10.0]])
     expected_g = torch.tensor([[12.0]])

     assert converted.x.shape == (2, 0)
     assert torch.all(converted.e == expected_e)
     assert torch.all(converted.g == expected_g)
Example #30
0
 def test_from_networkx_missing_glob_data(self):
     """A graph without global data converts to a ``(1, 0)`` global tensor."""
     graph = nx.DiGraph()
     for node, value in ((1, 10.0), (2, 11.0)):
         graph.add_node(node, feature=np.array([value]))
     graph.add_edge(1, 2, feature=np.array([12.0]))

     converted = GraphData.from_networkx(graph, feature_key="feature")

     assert torch.all(converted.x == torch.tensor([[10.0], [11.0]]))
     assert torch.all(converted.e == torch.tensor([[12.0]]))
     assert converted.g.shape == (1, 0)