Esempio n. 1
0
 def _forward_core(self, latent0, data):
     """Run ``self.core`` on ``data`` joined with ``latent0``.

     The ``e``, ``x`` and ``g`` attributes of both inputs are concatenated
     along ``dim=1`` (``latent0`` first) before ``self.core`` is applied.

     :param latent0: object exposing ``e``, ``x`` and ``g`` tensors.
     :param data: object exposing ``e``, ``x``, ``g``, ``edges``,
         ``node_idx`` and ``edge_idx``.
     :return: a ``GraphBatch`` built from the outputs of ``self.core``.
     """
     joined = {
         name: torch.cat([getattr(latent0, name), getattr(data, name)], dim=1)
         for name in ("e", "x", "g")
     }
     merged = GraphBatch(
         joined["x"],
         joined["e"],
         joined["g"],
         data.edges,
         data.node_idx,
         data.edge_idx,
     )
     # ``self.core`` hands back (e, x, g); ``GraphBatch`` takes (x, e, g).
     core_e, core_x, core_g = self.core(merged)
     return GraphBatch(
         core_x, core_e, core_g, merged.edges, merged.node_idx, merged.edge_idx
     )
Esempio n. 2
0
 def core_process(self, latent0: GraphBatch, data: GraphBatch) -> GraphBatch:
     """Concatenate ``latent0`` with ``data`` and apply ``self.core``.

     :param latent0: batch whose ``e``/``x``/``g`` tensors come first in each
         ``dim=1`` concatenation.
     :param data: batch whose ``e``/``x``/``g`` tensors come second and whose
         ``edges``, ``node_idx`` and ``edge_idx`` are passed along.
     :return: new ``GraphBatch`` wrapping the tensors returned by
         ``self.core``.
     """
     e_cat = torch.cat([latent0.e, data.e], dim=1)
     x_cat = torch.cat([latent0.x, data.x], dim=1)
     g_cat = torch.cat([latent0.g, data.g], dim=1)
     core_input = GraphBatch(
         x_cat, e_cat, g_cat, data.edges, data.node_idx, data.edge_idx
     )
     # ``self.core`` yields (e, x, g) while ``GraphBatch`` takes (x, e, g).
     e_out, x_out, g_out = self.core(core_input)
     return GraphBatch(
         x_out,
         e_out,
         g_out,
         core_input.edges,
         core_input.node_idx,
         core_input.edge_idx,
     )
Esempio n. 3
0
def data(request):
    """Return random graph data of the class named by ``request.param``.

    Presumably a parametrized pytest fixture (the decorator is not visible
    here).

    :param request: object whose ``param`` attribute selects the class.
    :return: ``GraphData.random(...)`` when ``request.param`` is
        ``GraphData``; otherwise ``GraphBatch.random_batch(100, ...)``.
    """
    sizes = (5, 4, 3)
    bounds = dict(min_nodes=10, max_nodes=10, min_edges=5, max_edges=5)
    if request.param is not GraphData:
        return GraphBatch.random_batch(100, *sizes, **bounds)
    return GraphData.random(*sizes, **bounds)
Esempio n. 4
0
 def test_to_networkx_list(self, fkey_gkey):
     """Check ``to_networkx_list`` output against the source data.

     Three random graphs are batched, converted with ``to_networkx_list``
     and each result is compared to its source via
     ``Comparator.data_to_nx``.

     :param fkey_gkey: pair ``(feature_key, global_attr_key)``.
     """
     feature_key, global_key = fkey_gkey
     originals = [random_graph_data(5, 5, 5) for _ in range(3)]
     converted = GraphBatch.from_data_list(originals).to_networkx_list(
         feature_key=feature_key, global_attr_key=global_key
     )
     for source, nx_graph in zip(originals, converted):
         Comparator.data_to_nx(source, nx_graph, feature_key, global_key)
Esempio n. 5
0
def data(request):
    """Return seeded random graph data of the class in ``request.param``.

    ``deterministic_seed(0)`` is called first so the random data is
    reproducible.

    :param request: object whose ``param`` attribute selects the class.
    :return: ``GraphData.random(5, 4, 3)`` for ``GraphData``; otherwise
        ``GraphBatch.random_batch(10, 5, 4, 3)``.
    """
    deterministic_seed(0)
    if request.param is GraphData:
        return GraphData.random(5, 4, 3)
    return GraphBatch.random_batch(10, 5, 4, 3)
Esempio n. 6
0
def test_fully_connected_singe_graph_batch():
    """``FullyConnected`` must add edges to a batch holding a single graph."""
    deterministic_seed(0)
    original = GraphBatch.from_data_list([GraphData.random(5, 4, 3)])
    connected = FullyConnected()(original)
    # ``edges.shape[1]`` is used as the edge count of each batch.
    n_edges_before = original.edges.shape[1]
    n_edges_after = connected.edges.shape[1]
    assert n_edges_after > n_edges_before
Esempio n. 7
0
def test_node_mask_entire_graph():
    """Apply a node mask that is ``False`` for the middle block of 5 entries.

    Three graphs are batched and the mask is True/False/True in runs of
    five; presumably this drops every node of the second graph (TODO
    confirm). The test has no assertions and only prints the resulting
    ``node_idx`` and ``edge_idx``.
    """

    def make_graph():
        # Same shapes and edge list for each of the three graphs.
        return GraphData(
            torch.randn(5, 3),
            torch.randn(3, 3),
            torch.randn(1, 1),
            edges=torch.LongTensor([[0, 1, 2], [2, 3, 4]]),
        )

    graphs = [make_graph() for _ in range(3)]
    batch = GraphBatch.from_data_list(graphs)

    keep = torch.BoolTensor([True] * 5 + [False] * 5 + [True] * 5)
    masked = batch.apply_node_mask(keep)
    print(masked.node_idx)
    print(masked.edge_idx)
Esempio n. 8
0
    def test_invalid_batch(self, offsets):
        """Batching graphs with shifted second dimensions must fail.

        :param offsets: indexable of three ints added to the second
            dimension of the second graph's first three tensors.
        """
        reference = GraphData(
            torch.randn(10, 10),
            torch.randn(3, 4),
            torch.randn(1, 3),
            torch.randint(0, 10, torch.Size([2, 3])),
        )

        # Same layout as ``reference`` except for the offset widths.
        width_0 = 10 + offsets[0]
        width_1 = 4 + offsets[1]
        width_2 = 3 + offsets[2]
        mismatched = GraphData(
            torch.randn(12, width_0),
            torch.randn(4, width_1),
            torch.randn(1, width_2),
            torch.randint(0, 10, torch.Size([2, 4])),
        )

        with pytest.raises(RuntimeError):
            GraphBatch.from_data_list([reference, mismatched])
Esempio n. 9
0
 def __getitem__(self, idx: int) -> GraphBatch:
     """Return the sample(s) at ``idx`` collated into one ``GraphBatch``.

     The samples are fetched with ``transform=None``; ``self.transform`` is
     applied to the collated batch instead, when it is truthy.

     :param idx: index forwarded to ``self.get``.
     :return: the (optionally transformed) batch.
     """
     fetched = self.get(idx, transform=None)
     # A lone ``GraphData`` is wrapped so it can be batched like a list.
     graphs = [fetched] if isinstance(fetched, GraphData) else fetched
     collated = GraphBatch.from_data_list(graphs)
     return self.transform(collated) if self.transform else collated
Esempio n. 10
0
 def test_from_data_list(self, n):
     """Check tensor shapes of a batch built from ``n`` random graphs.

     :param n: number of graphs created with ``random_graph_data(5, 3, 4)``.
     """
     batch = GraphBatch.from_data_list(
         [random_graph_data(5, 3, 4) for _ in range(n)]
     )
     x_shape = batch.x.shape
     e_shape = batch.e.shape
     g_shape = batch.g.shape

     # First dimension: more than ``n`` rows for x and e, exactly ``n`` for g.
     assert x_shape[0] > n
     assert e_shape[0] > n
     assert g_shape[0] == n
     # Second dimension matches the sizes given to ``random_graph_data``.
     assert x_shape[1] == 5
     assert e_shape[1] == 3
     assert g_shape[1] == 4
Esempio n. 11
0
def to(batch, device, **kwargs):
    """Build a new ``GraphBatch`` with every tensor of ``batch`` moved.

    :param batch: object exposing ``x``, ``e``, ``g``, ``edges``,
        ``node_idx`` and ``edge_idx``.
    :param device: first argument forwarded to each attribute's ``.to``.
    :param kwargs: extra keyword arguments forwarded to each ``.to`` call.
    :return: ``GraphBatch`` assembled from the results of the ``.to`` calls.
    """
    field_names = ("x", "e", "g", "edges", "node_idx", "edge_idx")
    moved = [getattr(batch, name).to(device, **kwargs) for name in field_names]
    return GraphBatch(*moved)
Esempio n. 12
0
def random_data_example(request):
    """Create random graph data as described by ``request.param``.

    ``request.param[0]`` selects the class and ``request.param[1]`` holds
    the positional arguments for ``random_graph_data``.

    :param request: object carrying the ``param`` pair.
    :return: one random graph for ``GraphData``, or a ``GraphBatch`` of ten
        random graphs for ``GraphBatch``.
    :raises Exception: if ``request.param[0]`` is neither class.
    """
    kind = request.param[0]
    if kind is GraphData:
        return random_graph_data(*request.param[1])
    if kind is GraphBatch:
        return GraphBatch.from_data_list(
            [random_graph_data(*request.param[1]) for _ in range(10)]
        )
    raise Exception("Parameter not acceptable: {}".format(request.param))
Esempio n. 13
0
    def test_is_differentiable__to_datalist(self, attr):
        """``requires_grad`` set on a batch must survive ``to_data_list``.

        :param attr: name of the batch attribute whose ``requires_grad``
            flag is switched on.
        """
        batch = GraphBatch.from_data_list(
            [random_graph_data(5, 3, 4) for _ in range(300)]
        )
        getattr(batch, attr).requires_grad = True

        for unbatched in batch.to_data_list():
            assert getattr(unbatched, attr).requires_grad
Esempio n. 14
0
    def validate_plot(self, batch, num_graphs=10):
        """Log comparison plots of predictions against targets.

        The first ``num_graphs`` graphs of the input and target batches are
        re-batched, the model is run on the inputs, and one comparison
        figure per graph is logged under the ``"image"`` key of
        ``self.logger.experiment``.

        :param batch: pair ``(x, y)`` of input and target batches.
        :param num_graphs: maximum number of graphs to plot.
        """
        x, y = batch
        # Each batch is converted to a data list exactly once and truncated.
        x = GraphBatch.from_data_list(x.to_data_list()[:num_graphs])
        y = GraphBatch.from_data_list(y.to_data_list()[:num_graphs])

        # Only the last element of the model's output list is plotted.
        y_hat = self.model.forward(x, 10)[-1]
        y_graphs = y.to_networkx_list()
        y_hat_graphs = y_hat.to_networkx_list()

        figs = []
        for idx, yg in enumerate(y_graphs):
            # ``comparison_plot`` returns a pair; only the figure is kept.
            fig, _ = comparison_plot(y_hat_graphs[idx], yg)
            figs.append(fig)
        with figs_to_pils(figs) as pils:
            self.logger.experiment.log({"image": [wandb.Image(pil) for pil in pils]})
Esempio n. 15
0
def create_loader(generator, graphs, batch_size, shuffle, pin_memory=False):
    """Build a ``GraphDataLoader`` over (input, target) graph pairs.

    ``graphs`` is converted twice with ``GraphBatch.from_networkx_list``:
    once with the default feature key and once with
    ``feature_key="target"``. Both batches are split back into data lists
    and zipped together.

    :param generator: object whose ``n_parts`` attribute is used as
        ``n_node_feat`` of the input batch.
    :param graphs: graphs passed to ``GraphBatch.from_networkx_list``.
    :param batch_size: loader batch size; ``None`` means the length of the
        input data list.
    :param shuffle: forwarded to ``GraphDataLoader``.
    :param pin_memory: forwarded to ``GraphDataLoader``.
    :return: the constructed ``GraphDataLoader``.
    """
    input_batch = GraphBatch.from_networkx_list(
        graphs,
        n_edge_feat=1,
        n_node_feat=generator.n_parts,
        n_glob_feat=1,
    )
    target_batch = GraphBatch.from_networkx_list(
        graphs,
        n_edge_feat=16,
        n_node_feat=1,
        n_glob_feat=1,
        feature_key="target",
    )
    inputs = input_batch.to_data_list()
    targets = target_batch.to_data_list()
    effective_batch_size = len(inputs) if batch_size is None else batch_size
    return GraphDataLoader(
        list(zip(inputs, targets)),
        batch_size=effective_batch_size,
        shuffle=shuffle,
        pin_memory=pin_memory,
    )
Esempio n. 16
0
def collate(data_list):
    """Collate a list of graphs, or a list of graph tuples, into batches.

    :param data_list: non-empty list of ``GraphData`` objects, or of tuples
        whose first element is a ``GraphData``.
    :return: a ``GraphBatch`` for a plain list; for a list of tuples, a
        tuple holding one collated result per tuple position.
    :raises RuntimeError: if the items are tuples whose first element is not
        a ``GraphData``.
    """
    first = data_list[0]
    if not isinstance(first, tuple):
        return GraphBatch.from_data_list(data_list)
    if not issubclass(type(first[0]), GraphData):
        raise RuntimeError("Cannot collate {}({})({})".format(
            type(data_list), type(data_list[0]), type(data_list[0][0])))
    # Collate position by position, recursing on each column of the tuples.
    return tuple(
        collate([sample[position] for sample in data_list])
        for position in range(len(first))
    )
Esempio n. 17
0
    def test_is_differentiable__append_nodes(self, attr):
        """``append_nodes`` must grow ``x`` and keep ``requires_grad``.

        :param attr: name of the attribute whose ``requires_grad`` flag is
            set on every graph before batching.
        """
        graphs = []
        for _ in range(300):
            graph = random_graph_data(5, 3, 4)
            getattr(graph, attr).requires_grad = True
            graphs.append(graph)
        batch = GraphBatch.from_data_list(graphs)

        extra_nodes = torch.randn(10, 5)
        graph_ids = torch.ones(10, dtype=torch.long)
        rows_before = batch.x.shape[0]
        batch.append_nodes(extra_nodes, graph_ids)

        assert batch.x.shape[0] == rows_before + 10
        assert getattr(batch, attr).requires_grad
Esempio n. 18
0
def test_fully_connected_singe_graph_batch_manual():
    """``FullyConnected`` on two copies of a small graph gives 18 edges.

    The expected 18 presumably corresponds to 3 x 3 node pairs in each of
    the two graphs (TODO confirm). The edges must also be 18 distinct
    tuples according to ``_edges_to_tuples_set``.
    """
    deterministic_seed(0)
    graph = GraphData(
        torch.randn((3, 1)),
        torch.randn((2, 2)),
        torch.randn((3, 1)),
        torch.tensor([[0, 1], [0, 1]]),
    )
    connected = FullyConnected()(GraphBatch.from_data_list([graph, graph]))
    print(connected.edges)
    assert connected.edges.shape[1] == 18
    unique_edges = _edges_to_tuples_set(connected.edges)
    assert len(unique_edges) == 18
Esempio n. 19
0
    def forward(self, data, steps, save_all: bool = False):
        """Encode ``data``, then run ``steps`` core/decode/transform rounds.

        Each round concatenates the encoded batch (``latent0``) with the
        previous core output along ``dim=1``, applies ``self.core``,
        ``self.decoder`` and ``self.output_transform`` in turn, and wraps
        the result in a ``GraphBatch`` sharing the input's ``edges``,
        ``node_idx`` and ``edge_idx``.

        :param data: input batch exposing ``edges``, ``node_idx`` and
            ``edge_idx``, accepted by ``self.encoder``.
        :param steps: number of processing rounds.
        :param save_all: if True, keep the output of every round; otherwise
            keep only the most recent one.
        :return: list of ``GraphBatch`` outputs; empty when ``steps`` is 0.
        """
        # encode
        e, x, g = self.encoder(data)
        data = GraphBatch(x, e, g, data.edges, data.node_idx, data.edge_idx)

        # the encoded batch is re-used as the skip input of every round
        latent0 = data
        # graph topology passed unchanged to every batch built below
        meta = (data.edges, data.node_idx, data.edge_idx)

        outputs = []
        for _ in range(steps):
            # core processing step
            e = torch.cat([latent0.e, e], dim=1)
            x = torch.cat([latent0.x, x], dim=1)
            g = torch.cat([latent0.g, g], dim=1)
            data = GraphBatch(x, e, g, *meta)
            e, x, g = self.core(data)

            # decode
            data = GraphBatch(x, e, g, *meta)
            _e, _x, _g = self.decoder(data)
            decoded = GraphBatch(_x, _e, _g, *meta)

            # transform
            _e, _x, _g = self.output_transform(decoded)
            gt = GraphBatch(_x, _e, _g, *meta)
            if save_all:
                outputs.append(gt)
            else:
                outputs = [gt]

        return outputs
Esempio n. 20
0
def test_k_hop_random_graph(k):
    """Induce a node mask from the first node and apply it to the batch.

    Two directed 2x3x4 grid graphs with random features are batched; the
    mask returned by ``induce`` for node 0 is applied and the resulting
    subgraph's ``info()`` is printed. There are no assertions.

    :param k: third argument forwarded to ``induce``.
    """
    graphs = [nx_to_directed(nx.grid_graph(dim=[2, 3, 4])) for _ in range(2)]
    for graph in graphs:
        nx_random_features(graph, 5, 4, 3)
    batch = GraphBatch.from_networkx_list(graphs)

    seed_nodes = torch.BoolTensor([False] * batch.num_nodes)
    seed_nodes[0] = True
    subgraph = batch.apply_node_mask(induce(batch, seed_nodes, k))
    print(subgraph.info())
Esempio n. 21
0
def test_shuffle_graphs(shuffle):
    """Compare the two batches returned by ``shuffle`` attribute by attribute.

    ``e``, ``x`` and ``edges`` must be equal everywhere, while ``g``,
    ``node_idx`` and ``edge_idx`` must differ in at least one position.

    :param shuffle: callable taking a batch and returning a pair of batches.
    """
    batch = GraphBatch.random_batch(
        100, 5, 4, 3, min_nodes=5, max_nodes=5, min_edges=5, max_edges=5
    )
    before, after = shuffle(batch)
    if batch.__class__ is GraphData:
        pytest.xfail("GraphData has no `shuffle_graphs` method")

    assert torch.all(before.e == after.e)
    assert not torch.all(before.g == after.g)
    assert torch.all(before.x == after.x)
    assert torch.all(before.edges == after.edges)
    assert not torch.all(before.node_idx == after.node_idx)
    assert not torch.all(before.edge_idx == after.edge_idx)
Esempio n. 22
0
def grid_data(request):
    """Return grid-graph data of the class given by ``request.param``.

    Each graph is a 2x4x3 ``nx.grid_graph`` passed through
    ``nx_random_features(..., 5, 4, 3)`` and ``nx_to_directed``.

    :param request: object whose ``param`` attribute selects the class.
    :return: a ``GraphData`` from one graph, or a ``GraphBatch`` from ten.
    :raises ValueError: if ``request.param`` is neither class.
    """

    def make_graph():
        featured = nx_random_features(nx.grid_graph([2, 4, 3]), 5, 4, 3)
        return nx_to_directed(featured)

    if request.param is GraphData:
        return GraphData.from_networkx(make_graph())
    if request.param is GraphBatch:
        return GraphBatch.from_networkx_list([make_graph() for _ in range(10)])
    raise ValueError()
Esempio n. 23
0
def data(request):
    """Return a small seeded graph as ``GraphData`` or a one-graph batch.

    :param request: object whose ``param`` attribute selects the class.
    :return: a ``GraphBatch`` wrapping the single graph when
        ``request.param`` is ``GraphBatch``; otherwise the ``GraphData``.
    """
    deterministic_seed(0)
    graph = GraphData(
        torch.randn((4, 1)),
        torch.randn((4, 2)),
        torch.randn((3, 1)),
        torch.tensor([[0, 1, 2, 1], [1, 2, 3, 0]]),
    )
    if request.param is GraphBatch:
        return GraphBatch.from_data_list([graph])
    return graph
Esempio n. 24
0
    def test_batch_append_nodes(self):
        """``append_nodes`` must increase the first dimension of ``batch.x``.

        Shapes of ``node_idx`` and ``x`` are printed for inspection.
        """
        batch = GraphBatch.from_data_list(
            [random_graph_data(5, 6, 7) for _ in range(10)]
        )

        extra_nodes = torch.randn(3, 5)
        graph_ids = torch.tensor([0, 1, 2])

        rows_before = batch.x.shape[0]
        batch.append_nodes(extra_nodes, graph_ids)
        print(batch.node_idx.shape)

        assert rows_before < batch.x.shape[0]
        print(batch.node_idx.shape)
        print(batch.x.shape)
Esempio n. 25
0
def test_k_hop_random_graph_benchmark(benchmark):
    """Benchmark ``tensor_induce`` plus ``apply_node_mask`` with k = 2.

    :param benchmark: callable that receives the function to time
        (presumably the pytest-benchmark fixture).
    """
    hops = 2
    batch = GraphBatch.random_batch(1000, 50, 20, 30)

    def run():
        # Mark ten randomly drawn node positions as the starting set.
        seeds = torch.full((batch.num_nodes, ), False, dtype=torch.bool)
        seeds[torch.randint(batch.num_nodes, (10, ))] = True
        batch.apply_node_mask(tensor_induce(batch, seeds, hops))

    benchmark(run)
Esempio n. 26
0
def test_():
    """Batch a one-edge graph together with a graph that has no edges.

    The test only checks that ``GraphBatch.from_data_list`` runs; there are
    no assertions.
    """
    with_edge = GraphData(
        torch.randn(1, 5),
        torch.randn(1, 4),
        torch.randn(1, 3),
        torch.LongTensor([[0], [0]]),
    )
    # Second graph: zero rows in its second tensor and an empty edge list.
    without_edges = GraphData(
        torch.randn(1, 5),
        torch.randn(0, 4),
        torch.randn(1, 3),
        torch.LongTensor([[], []]),
    )

    GraphBatch.from_data_list([with_edge, without_edges])
Esempio n. 27
0
    def test_is_differentiable__append_edges(self, attr):
        """``append_edges`` must grow edge tensors and keep ``requires_grad``.

        :param attr: name of the attribute whose ``requires_grad`` flag is
            set on every graph before batching.
        """
        graphs = []
        for _ in range(300):
            graph = random_graph_data(5, 3, 4)
            getattr(graph, attr).requires_grad = True
            graphs.append(graph)
        batch = GraphBatch.from_data_list(graphs)

        extra_attr = torch.randn(20, 3)
        extra_edges = torch.randint(0, batch.x.shape[0], (2, 20))
        # Sorted random indices, one per new edge.
        graph_ids = torch.sort(
            torch.randint(0, 30, (extra_edges.shape[1], ))
        ).values

        rows_before = batch.e.shape[0]
        batch.append_edges(extra_attr, extra_edges, graph_ids)
        expected_rows = rows_before + 20
        assert batch.e.shape[0] == expected_rows
        assert batch.edge_idx.shape[0] == expected_rows
        assert batch.edges.shape[1] == expected_rows

        assert getattr(batch, attr).requires_grad
Esempio n. 28
0
    def test_to_datalist(self):
        """Round-trip 1000 graphs through a batch and compare with the input.

        ``to_data_list`` must return as many graphs as were batched, and
        each must be ``allclose`` to its original. Batch shape/size and the
        column-sorted edges of every recovered graph are printed for
        inspection.
        """
        datalist = [random_graph_data(5, 6, 7) for _ in range(1000)]
        batch = GraphBatch.from_data_list(datalist)
        print(batch.shape)
        print(batch.size)
        datalist2 = batch.to_data_list()

        assert len(datalist) == len(datalist2)

        def sort(a):
            # Reorder the columns of ``a`` by the sort order of its first row.
            return a[:, torch.sort(a).indices[0]]

        # Printed once; the original repeated this loop verbatim.
        for data in datalist2:
            print(sort(data.edges))

        for d1, d2 in zip(datalist, datalist2):
            assert d1.allclose(d2)
Esempio n. 29
0
    def test_basic_batch2(self):
        """Batch two all-zero graphs and print edges before and after.

        The test has no assertions; it prints the batched ``edges`` and the
        ``edges`` of the two graphs recovered by ``to_data_list``.
        """
        small = GraphData(
            torch.tensor([[0]] * 2),
            torch.tensor([[0]] * 2),
            torch.tensor([[0]]),
            torch.tensor([[0, 1], [1, 0]]),
        )
        large = GraphData(
            torch.tensor([[0]] * 5),
            torch.tensor([[0]] * 3),
            torch.tensor([[0]]),
            torch.tensor([[1, 2, 1], [4, 2, 1]]),
        )

        batch = GraphBatch.from_data_list([small, large])
        print(batch.edges)

        recovered = batch.to_data_list()
        print(recovered[0].edges)
        print(recovered[1].edges)
Esempio n. 30
0
def test_k_hop_random_graph_benchmark2(benchmark):
    """Benchmark ``floyd_warshall_neighbors`` plus masking with depth 2.

    :param benchmark: callable that receives the function to time and its
        argument (presumably the pytest-benchmark fixture).
    """
    hops = 2
    batch = GraphBatch.random_batch(1000, 50, 20, 30)

    def random_seed_mask():
        # Boolean mask with ten randomly drawn node positions set to True.
        mask = torch.full((batch.num_nodes, ), False, dtype=torch.bool)
        mask[torch.randint(batch.num_nodes, (10, ))] = True
        return mask

    def run(n):
        seed_masks = tuple([random_seed_mask() for _ in range(n)])
        for mask in floyd_warshall_neighbors(batch, seed_masks, depth=hops):
            batch.apply_node_mask(mask)

    benchmark(run, 1)