Example #1
0
def data(request):
    """Fixture returning either one random graph or a random batch of graphs.

    ``request.param`` selects the container: when it is ``GraphData`` a single
    ``GraphData.random(5, 4, 3)`` is returned, otherwise a
    ``GraphBatch.random_batch(10, 5, 4, 3)``. ``deterministic_seed(0)`` is
    called first (presumably so the generated data is repeatable across runs).
    """
    deterministic_seed(0)
    if request.param is not GraphData:
        return GraphBatch.random_batch(10, 5, 4, 3)
    return GraphData.random(5, 4, 3)
Example #2
0
def data(request):
    """Fixture producing random graph data with pinned node/edge bounds.

    The ``min_*``/``max_*`` keyword arguments are set to equal values
    (10 nodes, 5 edges). ``request.param`` chooses between a single
    ``GraphData`` and a ``GraphBatch.random_batch`` call with leading
    argument 100.
    """
    sizes = (5, 4, 3)
    bounds = {"min_nodes": 10, "max_nodes": 10, "min_edges": 5, "max_edges": 5}
    if request.param is not GraphData:
        return GraphBatch.random_batch(100, *sizes, **bounds)
    return GraphData.random(*sizes, **bounds)
Example #3
0
def test_shuffle_graphs(shuffle):
    """Check the effect of graph-level shuffling on a random ``GraphBatch``.

    ``shuffle`` is a fixture that, given ``data``, returns a pair of results
    to compare (presumably the data before and after shuffling -- TODO confirm
    against the fixture definition). The assertions require ``e``, ``x`` and
    ``edges`` to be identical between the pair, while ``g``, ``node_idx`` and
    ``edge_idx`` must differ.
    """
    args = (5, 4, 3)
    kwargs = dict(min_nodes=5, max_nodes=5, min_edges=5, max_edges=5)
    data = GraphBatch.random_batch(100, *args, **kwargs)
    # Guard *before* calling ``shuffle``: an unsupported container would
    # otherwise raise inside the shuffle call and never reach this xfail.
    # An exact class check is kept so only a plain ``GraphData`` is skipped.
    if data.__class__ is GraphData:
        pytest.xfail("GraphData has no `shuffle_graphs` method")
    data1, data2 = shuffle(data)

    assert torch.all(data1.e == data2.e)
    assert not torch.all(data1.g == data2.g)
    assert torch.all(data1.x == data2.x)
    assert torch.all(data1.edges == data2.edges)
    assert not torch.all(data1.node_idx == data2.node_idx)
    assert not torch.all(data1.edge_idx == data2.edge_idx)
Example #4
0
def test_k_hop_random_graph_benchmark(benchmark):
    """Benchmark k-hop subgraph extraction using ``tensor_induce``.

    A single random batch is built outside the timed region. Each timed run
    marks 10 random seed nodes, calls ``tensor_induce`` with depth ``k`` and
    applies the resulting node mask to the batch.
    """
    k = 2
    batch = GraphBatch.random_batch(1000, 50, 20, 30)

    def run():
        # Boolean mask over every node in the batch; ``randint`` samples with
        # replacement, so fewer than 10 distinct nodes may end up selected.
        nodes = torch.full((batch.num_nodes,), False, dtype=torch.bool)
        idx = torch.randint(batch.num_nodes, (10,))
        nodes[idx] = True
        node_mask = tensor_induce(batch, nodes, k)
        # The subgraph itself is discarded; only its construction is timed.
        batch.apply_node_mask(node_mask)

    benchmark(run)
Example #5
0
def test_k_hop_random_graph_benchmark2(benchmark):
    """Benchmark k-hop subgraph extraction using ``floyd_warshall_neighbors``.

    A single random batch is built outside the timed region. Each timed run
    creates ``n`` random seed-node masks, passes them together to
    ``floyd_warshall_neighbors`` with ``depth=k`` and applies every returned
    mask to the batch.
    """
    k = 2
    batch = GraphBatch.random_batch(1000, 50, 20, 30)

    def make_seed_mask():
        # Boolean mask over every node in the batch with 10 randomly chosen
        # nodes set (sampled with replacement, so duplicates are possible).
        nodes = torch.full((batch.num_nodes,), False, dtype=torch.bool)
        idx = torch.randint(batch.num_nodes, (10,))
        nodes[idx] = True
        return nodes

    def run(n):
        nodes_list = tuple(make_seed_mask() for _ in range(n))
        masks = floyd_warshall_neighbors(batch, nodes_list, depth=k)
        for mask in masks:
            # Subgraphs are discarded; only their construction is timed.
            batch.apply_node_mask(mask)

    benchmark(run, 1)
Example #6
0
def test_fully_connected_graph_batch():
    """Applying ``FullyConnected`` must increase a random batch's edge count."""
    deterministic_seed(0)
    original = GraphBatch.random_batch(100, 5, 4, 3)
    transform = FullyConnected()
    transformed = transform(original)
    edges_before = original.edges.shape[1]
    edges_after = transformed.edges.shape[1]
    assert edges_after > edges_before
Example #7
0
def test_serialize_graph_batch():
    """A pickled ``GraphBatch`` must round-trip with its tensors unchanged.

    Previously the unpickled object was discarded, so the test only checked
    that pickling did not raise. It now also compares the restored tensors.
    """
    import torch

    data = GraphBatch.random_batch(100, 5, 4, 3)
    restored = pickle.loads(pickle.dumps(data))
    assert type(restored) is type(data)
    # ``torch.equal`` checks shape and values, unlike ``torch.all(a == b)``,
    # which can broadcast mismatched shapes.
    for attr in ("x", "e", "g", "edges", "node_idx", "edge_idx"):
        assert torch.equal(getattr(restored, attr), getattr(data, attr)), attr
Example #8
0
def test_compose(seeds):
    """Smoke test: a ``Compose`` pipeline of transforms runs on a random batch.

    ``seeds`` is requested only for its side effect (presumably RNG seeding).
    The transformed output is not inspected.
    """
    pipeline = Compose([
        RandomNodeMask(0.1),
        RandomHop(1, 3),
        Shuffle(),
    ])

    batch = GraphBatch.random_batch(1000, 5, 4, 3)
    pipeline(batch)
Example #9
0
 def test_from_datalist(self, device):
     """Split a 2-graph random batch on ``device`` into a list and rebuild it."""
     on_device = GraphBatch.random_batch(2, 5, 4, 3).to(device)
     GraphBatch.from_data_list(on_device.to_data_list())
Example #10
0
def test_graph_batch_random_batch():
    """Smoke test: build a random batch and print its ``size`` then ``shape``."""
    batch = GraphBatch.random_batch(10, 5, 5, 5)
    for attribute in ("size", "shape"):
        print(getattr(batch, attribute))
def data(seeds, request):
    """Fixture: random ``GraphBatch`` whose leading argument is ``request.param``.

    ``seeds`` is requested only for its side effect (presumably RNG seeding).
    """
    return GraphBatch.random_batch(request.param, 5, 4, 3)