Beispiel #1
0
 def test_to_networkx_list(self, fkey_gkey):
     """Convert a batch to networkx graphs and compare each with its source data.

     ``fkey_gkey`` supplies the ``(feature_key, global_attr_key)`` pair that is
     forwarded both to ``GraphBatch.to_networkx_list`` and to the comparator.
     """
     fkey, gkey = fkey_gkey
     datalist = [random_graph_data(5, 5, 5) for _ in range(3)]
     batch = GraphBatch.from_data_list(datalist)
     graphs = batch.to_networkx_list(feature_key=fkey, global_attr_key=gkey)
     # zip() stops at the shorter input, so a conversion that dropped graphs
     # would otherwise go unnoticed.
     assert len(graphs) == len(datalist)
     for data, graph in zip(datalist, graphs):
         Comparator.data_to_nx(data, graph, fkey, gkey)
Beispiel #2
0
def test_fully_connected_singe_graph_batch():
    """FullyConnected applied to a one-graph batch must produce more edges than the input had."""
    deterministic_seed(0)
    single_graph = GraphData.random(5, 4, 3)
    original = GraphBatch.from_data_list([single_graph])
    transform = FullyConnected()
    connected = transform(original)
    n_after = connected.edges.shape[1]
    n_before = original.edges.shape[1]
    assert n_after > n_before
Beispiel #3
0
def test_node_mask_entire_graph():
    """Apply a node mask to a three-graph batch and print the resulting indices.

    The mask is True for entries 0-4 and 10-14 and False for entries 5-9.
    """
    edge_pairs = [[0, 1, 2], [2, 3, 4]]
    # Three graphs built the same way; each gets its own random tensors.
    graphs = [
        GraphData(
            torch.randn(5, 3),
            torch.randn(3, 3),
            torch.randn(1, 1),
            edges=torch.LongTensor(edge_pairs),
        )
        for _ in range(3)
    ]

    batch = GraphBatch.from_data_list(graphs)
    flags = [keep for keep in (True, False, True) for _ in range(5)]
    masked = batch.apply_node_mask(torch.BoolTensor(flags))
    print(masked.node_idx)
    print(masked.edge_idx)
Beispiel #4
0
    def test_invalid_batch(self, offsets):
        """Batching two graphs with mismatched tensor widths must raise ``RuntimeError``.

        ``offsets`` holds three values that are added to the column counts of
        the second graph's first three tensors (presumably at least one of
        them is non-zero; verify against the parametrization).
        """
        reference = GraphData(
            torch.randn(10, 10),
            torch.randn(3, 4),
            torch.randn(1, 3),
            torch.randint(0, 10, torch.Size([2, 3])),
        )

        first_width = 10 + offsets[0]
        second_width = 4 + offsets[1]
        third_width = 3 + offsets[2]
        mismatched = GraphData(
            torch.randn(12, first_width),
            torch.randn(4, second_width),
            torch.randn(1, third_width),
            torch.randint(0, 10, torch.Size([2, 4])),
        )

        graphs = [reference, mismatched]
        with pytest.raises(RuntimeError):
            GraphBatch.from_data_list(graphs)
Beispiel #5
0
 def __getitem__(self, idx: int) -> GraphBatch:
     """Return the sample(s) at ``idx`` collated into a ``GraphBatch``.

     The raw items are fetched with ``transform=None``; ``self.transform``,
     when truthy, is applied to the assembled batch instead.
     """
     fetched = self.get(idx, transform=None)
     # A lone GraphData is wrapped so that it can be batched like a list.
     data_list = [fetched] if isinstance(fetched, GraphData) else fetched
     collated = GraphBatch.from_data_list(data_list)
     return self.transform(collated) if self.transform else collated
Beispiel #6
0
 def test_from_data_list(self, n):
     """Batch ``n`` random graphs and check the shapes of the stacked tensors."""
     batch = GraphBatch.from_data_list(
         [random_graph_data(5, 3, 4) for _ in range(n)]
     )
     # Row counts: x and e must exceed n, g must have exactly n rows.
     assert batch.x.shape[0] > n
     assert batch.e.shape[0] > n
     assert batch.g.shape[0] == n
     # Column counts follow the arguments given to random_graph_data.
     for attr_name, width in (("x", 5), ("e", 3), ("g", 4)):
         assert getattr(batch, attr_name).shape[1] == width
Beispiel #7
0
    def test_is_differentiable__to_datalist(self, attr):
        """``requires_grad`` set on a batch attribute must carry over to every unbatched graph."""
        batch = GraphBatch.from_data_list(
            [random_graph_data(5, 3, 4) for _ in range(300)]
        )

        batched_tensor = getattr(batch, attr)
        batched_tensor.requires_grad = True

        for unbatched in batch.to_data_list():
            assert getattr(unbatched, attr).requires_grad
Beispiel #8
0
    def validate_plot(self, batch, num_graphs=10):
        """Log comparison plots of predicted versus target graphs.

        Args:
            batch: An ``(x, y)`` pair of batches. Only the first ``num_graphs``
                graphs of each are used.
            num_graphs: Maximum number of graphs to plot.
        """
        x, y = batch
        # Unbatch each side once and re-batch only the leading graphs.
        x = GraphBatch.from_data_list(x.to_data_list()[:num_graphs])
        y = GraphBatch.from_data_list(y.to_data_list()[:num_graphs])

        # The last element of the model output is used as the prediction.
        # NOTE(review): the meaning of the literal 10 is not visible here
        # (presumably a step count); confirm against the model.
        y_hat = self.model.forward(x, 10)[-1]
        y_graphs = y.to_networkx_list()
        y_hat_graphs = y_hat.to_networkx_list()

        # comparison_plot returns (figure, axes); only the figure is logged.
        figs = [
            comparison_plot(predicted, target)[0]
            for predicted, target in zip(y_hat_graphs, y_graphs)
        ]
        with figs_to_pils(figs) as pils:
            self.logger.experiment.log({"image": [wandb.Image(pil) for pil in pils]})
Beispiel #9
0
def random_data_example(request):
    """Build a random ``GraphData`` or a ten-graph ``GraphBatch`` as selected by ``request.param``.

    ``request.param`` is indexed as a ``(cls, args)`` pair: ``cls`` chooses the
    kind of object and ``args`` is unpacked into ``random_graph_data``.

    Raises:
        Exception: If ``cls`` is neither ``GraphData`` nor ``GraphBatch``.
    """
    if request.param[0] is GraphData:
        return random_graph_data(*request.param[1])
    if request.param[0] is GraphBatch:
        return GraphBatch.from_data_list(
            [random_graph_data(*request.param[1]) for _ in range(10)]
        )
    raise Exception("Parameter not acceptable: {}".format(request.param))
Beispiel #10
0
def collate(data_list):
    """Collate a list of samples into a ``GraphBatch`` or a tuple of them.

    When the samples are tuples starting with a ``GraphData``, the list is
    transposed and every tuple position is collated on its own, giving a
    tuple with one entry per position. Any other list is passed straight to
    ``GraphBatch.from_data_list``.

    Raises:
        RuntimeError: If the samples are tuples whose first element is not a
            ``GraphData``.
    """
    first = data_list[0]
    if isinstance(first, tuple):
        if not isinstance(first[0], GraphData):
            raise RuntimeError("Cannot collate {}({})({})".format(
                type(data_list), type(first), type(first[0])))
        # Recurse per tuple position so nested tuples are handled as well.
        return tuple(
            collate([sample[i] for sample in data_list])
            for i in range(len(first))
        )
    return GraphBatch.from_data_list(data_list)
Beispiel #11
0
    def test_is_differentiable__append_nodes(self, attr):
        """Appending nodes must grow ``batch.x`` by the number added and keep ``requires_grad`` on ``attr``."""
        datalist = [random_graph_data(5, 3, 4) for _ in range(300)]
        for graph in datalist:
            tensor = getattr(graph, attr)
            tensor.requires_grad = True
        batch = GraphBatch.from_data_list(datalist)

        n_new = 10
        new_nodes = torch.randn(n_new, 5)
        idx = torch.ones(n_new, dtype=torch.long)
        expected_rows = batch.x.shape[0] + n_new
        batch.append_nodes(new_nodes, idx)
        assert batch.x.shape[0] == expected_rows
        assert getattr(batch, attr).requires_grad
Beispiel #12
0
def test_fully_connected_singe_graph_batch_manual():
    """FullyConnected on a batch holding the same small graph twice must yield 18 distinct edges."""
    deterministic_seed(0)
    node_attr = torch.randn((3, 1))
    edge_attr = torch.randn((2, 2))
    global_attr = torch.randn((3, 1))
    edge_index = torch.tensor([[0, 1], [0, 1]])
    graph = GraphData(node_attr, edge_attr, global_attr, edge_index)
    connected = FullyConnected()(GraphBatch.from_data_list([graph, graph]))
    print(connected.edges)
    expected_edges = 18
    assert connected.edges.shape[1] == expected_edges
    # Collapsing the edges into a set must not drop any: all are distinct.
    assert len(_edges_to_tuples_set(connected.edges)) == expected_edges
Beispiel #13
0
    def test_batch_append_nodes(self):
        """Appending three nodes must grow ``batch.x`` by exactly three rows."""
        datalist = [random_graph_data(5, 6, 7) for _ in range(10)]
        batch = GraphBatch.from_data_list(datalist)

        x = torch.randn(3, 5)
        idx = torch.tensor([0, 1, 2])

        n_nodes = batch.x.shape[0]
        batch.append_nodes(x, idx)

        # An exact count catches duplicated rows as well as missing ones.
        assert batch.x.shape[0] == n_nodes + x.shape[0]
        print(batch.node_idx.shape)
        print(batch.x.shape)
Beispiel #14
0
def data(request):
    """Return a fixed four-node graph as ``GraphData`` or as a one-graph ``GraphBatch``.

    ``request.param`` selects the result: ``GraphBatch`` wraps the graph in a
    batch, anything else returns the bare ``GraphData``. Presumably used as a
    parametrized pytest fixture (the decorator is not visible here).
    """
    wanted_cls = request.param
    deterministic_seed(0)

    node_attr = torch.randn((4, 1))
    edge_attr = torch.randn((4, 2))
    global_attr = torch.randn((3, 1))

    edge_index = torch.tensor([[0, 1, 2, 1], [1, 2, 3, 0]])

    graph = GraphData(node_attr, edge_attr, global_attr, edge_index)
    if wanted_cls is not GraphBatch:
        return graph
    return GraphBatch.from_data_list([graph])
Beispiel #15
0
def test_():
    """Batching a one-edge graph together with a zero-edge graph must not raise."""
    with_edge = GraphData(
        torch.randn(1, 5),
        torch.randn(1, 4),
        torch.randn(1, 3),
        torch.LongTensor([[0], [0]]),
    )
    # Second graph: empty edge-attribute tensor and empty edge index.
    without_edges = GraphData(
        torch.randn(1, 5),
        torch.randn(0, 4),
        torch.randn(1, 3),
        torch.LongTensor([[], []]),
    )

    GraphBatch.from_data_list([with_edge, without_edges])
Beispiel #16
0
    def test_is_differentiable__append_edges(self, attr):
        """Appending 20 edges must grow ``e``, ``edge_idx`` and ``edges`` by 20 and keep ``requires_grad`` on ``attr``."""
        datalist = [random_graph_data(5, 3, 4) for _ in range(300)]
        for graph in datalist:
            tensor = getattr(graph, attr)
            tensor.requires_grad = True
        batch = GraphBatch.from_data_list(datalist)

        n_new = 20
        new_edge_attr = torch.randn(n_new, 3)
        new_edges = torch.randint(0, batch.x.shape[0], (2, n_new))
        # Graph indices for the new edges, passed in ascending order.
        unsorted_idx = torch.randint(0, 30, (new_edges.shape[1], ))
        idx = torch.sort(unsorted_idx).values

        expected = batch.e.shape[0] + n_new
        batch.append_edges(new_edge_attr, new_edges, idx)
        assert batch.e.shape[0] == expected
        assert batch.edge_idx.shape[0] == expected
        assert batch.edges.shape[1] == expected

        assert getattr(batch, attr).requires_grad
Beispiel #17
0
    def test_to_datalist(self):
        """Round-trip 1000 random graphs through a batch and compare them with the originals."""
        datalist = [random_graph_data(5, 6, 7) for _ in range(1000)]
        batch = GraphBatch.from_data_list(datalist)
        print(batch.shape)
        print(batch.size)
        datalist2 = batch.to_data_list()

        assert len(datalist) == len(datalist2)

        def sort(a):
            # Reorder the columns of ``a`` by the sort order of its first row.
            return a[:, torch.sort(a).indices[0]]

        # Debug output: each recovered edge tensor, printed once.
        for data in datalist2:
            print(sort(data.edges))

        for d1, d2 in zip(datalist, datalist2):
            assert d1.allclose(d2)
Beispiel #18
0
    def test_basic_batch2(self):
        """Batch two small all-zero graphs and print the edges before and after unbatching."""
        two_node_graph = GraphData(
            torch.tensor([[0]] * 2),
            torch.tensor([[0]] * 2),
            torch.tensor([[0]]),
            torch.tensor([[0, 1], [1, 0]]),
        )
        five_node_graph = GraphData(
            torch.tensor([[0]] * 5),
            torch.tensor([[0]] * 3),
            torch.tensor([[0]]),
            torch.tensor([[1, 2, 1], [4, 2, 1]]),
        )

        batch = GraphBatch.from_data_list([two_node_graph, five_node_graph])
        print(batch.edges)

        recovered = batch.to_data_list()
        for position in (0, 1):
            print(recovered[position].edges)
Beispiel #19
0
    def test_to_and_from_datalist(self):
        """Two graphs batched and then unbatched must come back equal to the originals."""
        first = GraphData(
            torch.randn(4, 2),
            torch.randn(3, 4),
            torch.randn(1, 3),
            torch.randint(0, 4, torch.Size([2, 3])),
        )
        second = GraphData(
            torch.randn(2, 2),
            torch.randn(4, 4),
            torch.randn(1, 3),
            torch.randint(0, 2, torch.Size([2, 4])),
        )
        originals = [first, second]

        recovered = GraphBatch.from_data_list(originals).to_data_list()

        # Debug output: node tensors pairwise, then all edge tensors.
        for position, original in enumerate(originals):
            print(original.x)
            print(recovered[position].x)
        for original in originals:
            print(original.edges)
        for position in (0, 1):
            print(recovered[position].edges)

        for before, after in zip(originals, recovered):
            for attr_name in ("x", "e", "g"):
                assert torch.allclose(
                    getattr(before, attr_name), getattr(after, attr_name)
                )
            assert torch.all(torch.eq(before.edges, after.edges))
            assert before.allclose(after)
Beispiel #20
0
    def test_basic_batch(self):
        """Batch graphs of 10 and 12 nodes (3 and 4 edges) and check sizes and index vectors."""
        sizes = ((10, 3), (12, 4))  # (node count, edge count) per graph
        graphs = [
            GraphData(
                torch.randn(n_nodes, 10),
                torch.randn(n_edges, 4),
                torch.randn(1, 3),
                torch.randint(0, 10, torch.Size([2, n_edges])),
            )
            for n_nodes, n_edges in sizes
        ]

        batch = GraphBatch.from_data_list(graphs)
        assert batch.x.shape[0] == 22
        assert batch.e.shape[0] == 7
        assert batch.edges.shape[1] == 7
        assert batch.g.shape[0] == 2

        # Each node and edge must be tagged with the position of its graph.
        expected_node_idx = torch.tensor([0] * 10 + [1] * 12)
        expected_edge_idx = torch.tensor([0] * 3 + [1] * 4)
        assert torch.all(torch.eq(batch.node_idx, expected_node_idx))
        assert torch.all(torch.eq(batch.edge_idx, expected_edge_idx))
Beispiel #21
0
 def test_from_datalist(self, device):
     """A random batch moved to ``device`` must unbatch and re-batch without raising."""
     on_device = GraphBatch.random_batch(2, 5, 4, 3).to(device)
     unbatched = on_device.to_data_list()
     GraphBatch.from_data_list(unbatched)
Beispiel #22
0
def collate_list(data_list: List[GraphData]) -> GraphBatch:
    """Merge a list of ``GraphData`` objects into a single ``GraphBatch``."""
    batch = GraphBatch.from_data_list(data_list)
    return batch