Example #1
0
def test_train_test_split_dynamic_graph_static_signal_batch():
    """Split a batched static-graph temporal signal and check every snapshot.

    Builds a ``StaticGraphTemporalSignalBatch`` from the first edge-index set,
    first edge-weight set and first batch vector returned by
    ``generate_signal``. Only ``features`` and ``targets`` are passed as full
    sequences. The dataset is split 80/20 with ``temporal_signal_split``, and
    the tensor shapes of every snapshot in both halves are verified.

    Each half is iterated twice to check that it can be re-iterated. The
    number of snapshots seen on each pass is recorded, so the test cannot
    pass vacuously on an empty split or on an iterator that is exhausted
    after the first pass.
    """
    snapshot_count = 250
    node_count = 100
    feature_count = 32
    graph_count = 10

    edge_indices, edge_weights, features, targets, batches = generate_signal(
        snapshot_count, node_count, feature_count, graph_count)

    dataset = StaticGraphTemporalSignalBatch(edge_indices[0], edge_weights[0],
                                             features, targets, batches[0])

    train_dataset, test_dataset = temporal_signal_split(dataset, 0.8)

    # x and y are expected to stack node_count rows for each of graph_count
    # graphs in the batch.
    total_nodes = node_count * graph_count

    # Same check order as before: test split first, then train split.
    for split in (test_dataset, train_dataset):
        seen_per_pass = []
        for _ in range(2):
            seen = 0
            for snapshot in split:
                assert snapshot.edge_index.shape[0] == 2
                assert (snapshot.edge_index.shape[1] ==
                        snapshot.edge_attr.shape[0])
                assert snapshot.x.shape == (total_nodes, feature_count)
                assert snapshot.y.shape == (total_nodes, )
                seen += 1
            seen_per_pass.append(seen)
        # A non-empty first pass plus an identical second pass shows the
        # split is populated and that iterating again restarts from the start.
        assert seen_per_pass[0] > 0
        assert seen_per_pass[0] == seen_per_pass[1]
Example #2
0
def test_static_graph_temporal_signal_typing_batch():
    """Check snapshot attribute types when graph and batch inputs are ``None``.

    With no edge index, edge weights or batch vector, those attributes must
    stay ``None`` on every snapshot. The one-element numpy feature and target
    arrays must come back with shape ``(1,)``. Snapshots are counted so the
    per-snapshot assertions cannot be skipped silently by an empty iteration.
    """
    features = [np.array([1])]
    targets = [np.array([2])]
    dataset = StaticGraphTemporalSignalBatch(None, None, features, targets,
                                             None)
    seen = 0
    for snapshot in dataset:
        assert snapshot.edge_index is None
        assert snapshot.edge_attr is None
        assert snapshot.x.shape == (1, )
        assert snapshot.y.shape == (1, )
        assert snapshot.batch is None
        seen += 1
    # The dataset is expected to yield one snapshot per feature entry.
    assert seen == len(features)
Example #3
0
def test_static_graph_temporal_signal_batch():
    """Check that all-``None`` inputs produce all-``None`` snapshot attributes.

    Every argument is ``None``, or a list of ``None`` for the time-indexed
    features and targets. Each yielded snapshot must therefore expose ``None``
    for its edge index, edge attributes, x, y and batch vector. Snapshots are
    counted so the checks cannot pass vacuously on an empty iteration.
    """
    features = [None, None]
    targets = [None, None]
    dataset = StaticGraphTemporalSignalBatch(None, None, features, targets,
                                             None)
    seen = 0
    for snapshot in dataset:
        assert snapshot.edge_index is None
        assert snapshot.edge_attr is None
        assert snapshot.x is None
        assert snapshot.y is None
        assert snapshot.batch is None
        seen += 1
    # The dataset is expected to yield one snapshot per feature entry.
    assert seen == len(features)
Example #4
0
def test_static_graph_temporal_signal_batch_additional_attrs():
    """Check that extra keyword sequences become per-snapshot attributes.

    ``optional1`` and ``optional2`` are passed as keyword arguments. They must:

    * be recorded in ``additional_feature_keys`` in the order they were given;
    * appear on each snapshot as attributes of shape ``(1,)``.

    Snapshots are counted so the per-snapshot shape checks cannot be skipped
    silently.
    """
    features = [None]
    dataset = StaticGraphTemporalSignalBatch(None,
                                             None, features, [None],
                                             None,
                                             optional1=[np.array([1])],
                                             optional2=[np.array([2])])
    assert dataset.additional_feature_keys == ["optional1", "optional2"]
    seen = 0
    for snapshot in dataset:
        assert snapshot.optional1.shape == (1, )
        assert snapshot.optional2.shape == (1, )
        seen += 1
    # The dataset is expected to yield one snapshot per feature entry.
    assert seen == len(features)