예제 #1
0
def test_fully_connected_singe_graph_batch() -> None:
    """FullyConnected must add edges to a batch built from a single random graph."""
    deterministic_seed(0)
    single = GraphData.random(5, 4, 3)
    original_batch = GraphBatch.from_data_list([single])
    connected = FullyConnected()(original_batch)
    # Edge counts are read after the transform, as in the original test.
    n_after = connected.edges.shape[1]
    n_before = original_batch.edges.shape[1]
    assert n_after > n_before
예제 #2
0
파일: conftest.py 프로젝트: jvrana/caldera
def create_data_constructor(param):
    """Return a zero-argument factory that builds random graph data.

    ``param`` selects the data class and, optionally, the seed and the
    arguments forwarded to the random constructor. Accepted forms:

    * ``data_cls`` -- seeds with ``deterministic_seed(0)`` and uses the
      default arguments;
    * ``(data_cls, seed)``;
    * ``(data_cls, seed, args)``;
    * ``(data_cls, seed, args, kwargs)``.

    In the tuple forms ``deterministic_seed(seed)`` is called unless ``seed``
    is ``None``. An ``args``/``kwargs`` of ``None`` falls back to the defaults.
    The default positional arguments are ``(5, 4, 3)``. When ``data_cls`` is
    ``GraphBatch`` they are prefixed with a batch size of 100.

    Seeding happens when this function is called, not when the returned
    factory is invoked.

    :param param: a data class or a 2-4 element tuple as described above.
    :return: ``functools.partial`` wrapping ``data_cls.random_batch`` if the
        class has that attribute, otherwise ``data_cls.random``.
    :raises ValueError: if ``param`` is a tuple whose length is not 2, 3 or 4.
    """
    default_args = (5, 4, 3)
    default_batch_size = 100

    args = None
    kwargs = None
    if isinstance(param, tuple):
        if len(param) == 2:
            data_cls, seed = param
        elif len(param) == 3:
            data_cls, seed, args = param
        elif len(param) == 4:
            data_cls, seed, args, kwargs = param
        else:
            # Fail with a descriptive message instead of the opaque
            # "not enough values to unpack" from a blind 4-way unpack.
            raise ValueError(
                "param must be a data class or a tuple "
                "(data_cls, seed[, args[, kwargs]]) of length 2-4; "
                "got a tuple of length {}".format(len(param))
            )
        if seed is not None:
            deterministic_seed(seed)
    else:
        data_cls = param
        deterministic_seed(0)

    if args is None:
        if data_cls is GraphBatch:
            # The batched constructor receives the batch size first.
            args = (default_batch_size,) + default_args
        else:
            args = default_args
    if kwargs is None:
        kwargs = {}

    if hasattr(data_cls, GraphBatch.random_batch.__name__):
        rndm_func = data_cls.random_batch
    else:
        rndm_func = data_cls.random
    return partial(rndm_func, *args, **kwargs)
예제 #3
0
def data(request):
    """Provide random graph data whose type is chosen by ``request.param``.

    Returns ``GraphData.random(5, 4, 3)`` when the parameter is ``GraphData``.
    Otherwise returns ``GraphBatch.random_batch(10, 5, 4, 3)``.
    Seeding with 0 first keeps the fixture reproducible.
    """
    deterministic_seed(0)
    if request.param is not GraphData:
        return GraphBatch.random_batch(10, 5, 4, 3)
    return GraphData.random(5, 4, 3)
예제 #4
0
def test_bfs_edges_call_signature(src, d):
    """Smoke-test ``bfs_nodes`` on a random 20-node edge list.

    With ``depth=None`` the test expects every node appearing in ``edges``
    to be returned. With ``depth=0`` it expects no nodes. Other depths are
    only exercised, not checked.
    """
    deterministic_seed(0)
    edges = torch.randint(20, (2, 1000))

    nodes = bfs_nodes(src, edges, depth=d)
    if d is None:
        assert len(nodes) == torch.unique(edges).numel()
    elif d == 0:
        assert len(nodes) == 0
    print(nodes)
예제 #5
0
def test_shuffle_graph_data(random_data):
    """Shuffle must reorder node, edge and connectivity tensors.

    For a ``GraphBatch`` the global features and the per-graph index tensors
    must also change.
    """
    deterministic_seed(0)
    shuffled = Shuffle()(random_data)
    for attr in ("x", "e", "edges"):
        assert not torch.allclose(getattr(shuffled, attr), getattr(random_data, attr))
    if type(random_data) is GraphBatch:
        assert not torch.allclose(shuffled.g, random_data.g)
        for attr in ("node_idx", "edge_idx"):
            assert not torch.all(getattr(shuffled, attr) == getattr(random_data, attr))
예제 #6
0
def test_mask_no_edges():
    """An all-True edge mask keeps all 5 nodes and all 3 edges."""
    deterministic_seed(0)

    node_feats = torch.randn(5, 5)
    edge_feats = torch.randn(3, 2)
    glob_feats = torch.randn(1, 1)
    connectivity = torch.LongTensor([[0, 0, 0], [1, 2, 3]])
    graph = GraphData(node_feats, edge_feats, glob_feats, edges=connectivity)

    keep_every_edge = torch.BoolTensor([True] * 3)
    masked = graph.apply_edge_mask(keep_every_edge)
    assert masked.num_nodes == 5
    assert masked.num_edges == 3
예제 #7
0
def test_mask_all_nodes():
    """An all-False node mask removes every node and, with them, every edge."""
    deterministic_seed(0)

    node_feats = torch.randn(5, 5)
    edge_feats = torch.randn(3, 2)
    glob_feats = torch.randn(1, 1)
    connectivity = torch.LongTensor([[0, 0, 0], [1, 2, 3]])
    graph = GraphData(node_feats, edge_feats, glob_feats, edges=connectivity)

    drop_every_node = torch.BoolTensor([False] * 5)
    masked = graph.apply_node_mask(drop_every_node)
    assert masked.num_nodes == 0
    assert masked.num_edges == 0
예제 #8
0
def test_fully_connected_singe_graph_batch_manual():
    """FullyConnected on two copies of a 3-node graph yields 18 distinct edges.

    Each graph contributes every ordered node pair, self-loops included
    (3 * 3 = 9 edges), so the two-graph batch has 2 * 9 = 18.
    """
    deterministic_seed(0)
    x = torch.randn((3, 1))
    e = torch.randn((2, 2))
    g = torch.randn((3, 1))
    graph = GraphData(x, e, g, torch.tensor([[0, 1], [0, 1]]))

    batch = GraphBatch.from_data_list([graph] * 2)
    connected = FullyConnected()(batch)
    print(connected.edges)

    n_edges = connected.edges.shape[1]
    assert n_edges == 18
    unique_edges = _edges_to_tuples_set(connected.edges)
    assert len(unique_edges) == 18
예제 #9
0
파일: conftest.py 프로젝트: jvrana/caldera
def seeds(request):
    """Call ``deterministic_seed`` with the parametrized value.

    Meant for indirect parametrization, for example:

    .. code-block::

        @pytest.mark.parametrize("seeds", list(range(10)), ids=lambda x: "seed" + str(x), indirect=True)
        def test_foo(seeds):
            pass  # do stuff

    :param request: pytest fixture request; ``request.param`` is the seed value.
    :return: None
    """
    seed_value = request.param
    deterministic_seed(seed_value)
예제 #10
0
def data(request):
    """Build a small fixed 4-node, 4-edge graph seeded with 0.

    The graph is wrapped in a one-graph ``GraphBatch`` when ``request.param``
    is ``GraphBatch``. Otherwise the ``GraphData`` itself is returned.
    """
    requested_cls = request.param
    deterministic_seed(0)

    node_attr = torch.randn((4, 1))
    edge_attr = torch.randn((4, 2))
    global_attr = torch.randn((3, 1))

    connectivity = torch.tensor([[0, 1, 2, 1], [1, 2, 3, 0]])

    graph = GraphData(node_attr, edge_attr, global_attr, connectivity)
    return GraphBatch.from_data_list([graph]) if requested_cls is GraphBatch else graph
예제 #11
0
def test_k_hop(edges, k, source, expected):
    """``induce(data, source, k)`` on a 10-node graph with ``edges`` equals ``expected``."""
    deterministic_seed(0)
    n_edges = edges.shape[1]
    graph = GraphData.random(
        5, 4, 3,
        min_nodes=10, max_nodes=10,
        min_edges=n_edges, max_edges=n_edges,
    )
    # Swap the randomly generated connectivity for the parametrized edges.
    graph.edges = edges
    graph.debug()

    result = induce(graph, source, k)
    print(result)
    assert torch.all(result == expected)
예제 #12
0
def test_mask_one_edges():
    """Masking out the middle edge drops it and its features.

    Node and global features are left untouched.
    """
    deterministic_seed(0)

    connectivity = torch.LongTensor([[0, 0, 0], [1, 2, 3]])
    want_edges = torch.LongTensor([[0, 0], [1, 3]])

    edge_feats = torch.randn(3, 2)
    keep = torch.BoolTensor([True, False, True])
    # Boolean indexing selects the same rows as indexing with torch.where(keep).
    want_e = edge_feats[keep]

    graph = GraphData(torch.randn(5, 5), edge_feats, torch.randn(1, 1), edges=connectivity)

    masked = graph.apply_edge_mask(keep)
    assert torch.all(masked.edges == want_edges)
    assert torch.all(masked.e == want_e)
    assert torch.all(masked.g == graph.g)
    assert torch.all(masked.x == graph.x)
예제 #13
0
def test_mask_one_node():
    """Removing node 0 drops both of its edges and re-indexes the survivor.

    Edges (0, 1) and (0, 3) touch node 0. Only (1, 2) remains, and it is
    renumbered to (0, 1) once node 0 is gone.
    """
    deterministic_seed(0)

    connectivity = torch.LongTensor([[0, 1, 0], [1, 2, 3]])
    want_edges = torch.LongTensor([[0], [1]])

    keep = torch.BoolTensor([False] + [True] * 4)

    node_feats = torch.randn(5, 5)
    want_x = node_feats[keep]

    edge_feats = torch.randn(3, 2)
    want_e = edge_feats[torch.LongTensor([1])]  # only the middle edge survives

    graph = GraphData(node_feats, edge_feats, torch.randn(1, 1), edges=connectivity)

    masked = graph.apply_node_mask(keep)
    assert torch.all(masked.edges == want_edges)

    print(masked.x)
    print(want_x)
    assert torch.allclose(masked.x, want_x)
    assert torch.allclose(masked.e, want_e)
    assert torch.allclose(masked.g, graph.g)
예제 #14
0
def test_fully_connected_singe_graph_batch():
    """FullyConnected must add edges to a single random GraphData.

    NOTE: despite the name, this variant transforms a ``GraphData``, not a
    ``GraphBatch``.
    """
    deterministic_seed(0)
    graph = GraphData.random(5, 4, 3)
    connected = FullyConnected()(graph)
    assert connected.edges.shape[1] > graph.edges.shape[1]
예제 #15
0
def test_fully_connected_graph_batch():
    """FullyConnected must add edges to a random batch of 100 graphs."""
    deterministic_seed(0)
    random_batch = GraphBatch.random_batch(100, 5, 4, 3)
    connected = FullyConnected()(random_batch)
    assert connected.edges.shape[1] > random_batch.edges.shape[1]
예제 #16
0
파일: conf.py 프로젝트: jvrana/caldera
# flake8: noqa
# Sphinx configuration: Sphinx executes this module and reads the
# module-level names below (project, copyright, author, release, ...).
import os
import sys

# Put the directory two levels up (presumably the repository root -- confirm)
# at the front of sys.path so that `import caldera` below resolves to the
# source tree. abspath resolves "../.." against the current working directory.
# This must run before the caldera import.
sys.path.insert(0, os.path.abspath("../.."))


# -- Project information -----------------------------------------------------
import caldera as pkg
from caldera.utils import deterministic_seed

# Seed once at config load. Presumably this makes any randomness during the
# docs build reproducible; TODO confirm intent.
deterministic_seed(0)

import datetime

now = datetime.datetime.now()
project = pkg.__title__
authors = pkg.__authors__
# "<current year>, <author1>,<author2>,..." -- authors joined with "," (no space).
# `copyright` is the name Sphinx expects, even though it shadows the builtin.
copyright = "{year}, {authors}".format(year=now.year, authors=",".join(authors))
# The first listed author becomes the single `author` value.
author = authors[0]
release = pkg.__version__

# -- General configuration ---------------------------------------------------
autosummary_generate = (
    True  # glob.glob("*.rst")  # Make _autosummary files and include them
)
autoclass_content = "both"  # include both class docstring and __init__

autodoc_default_options = {
    "member-order": "bysource",
    "special-members": "__init__",
예제 #17
0
 def seed(self, seed: int = SEED):
     """Forward ``seed`` to ``deterministic_seed``.

     :param seed: seed value; defaults to the module-level ``SEED``.
     """
     chosen_seed = seed
     deterministic_seed(chosen_seed)