Example #1
0
def test_discrete_train_test_split_dynamic_batch():
    """Check snapshot shapes on both halves of a temporal train/test split.

    A ``DynamicGraphTemporalSignalBatch`` is built from ``generate_signal``
    output and split with ``temporal_signal_split(dataset, 0.8)``. Each split
    (test first, then train) is iterated twice, and every snapshot is checked
    for consistent edge-index / edge-attribute sizes and for the expected
    shapes of ``x``, ``y`` and ``batch``.
    """
    snapshot_count = 250
    node_count = 100
    feature_count = 32
    graph_count = 10
    # Rows expected in x / y / batch for every snapshot.
    total_nodes = node_count * graph_count

    edge_indices, edge_weights, features, targets, batches = generate_signal(
        snapshot_count, node_count, feature_count, graph_count)

    dataset = DynamicGraphTemporalSignalBatch(edge_indices, edge_weights,
                                              features, targets, batches)

    train_dataset, test_dataset = temporal_signal_split(dataset, 0.8)

    # One shared assertion loop instead of a copy-pasted block per split;
    # iteration order (test x2, then train x2) matches the original test.
    for split in (test_dataset, train_dataset):
        for _ in range(2):
            seen = 0
            for snapshot in split:
                seen += 1
                assert snapshot.edge_index.shape[0] == 2
                assert (snapshot.edge_index.shape[1]
                        == snapshot.edge_attr.shape[0])
                assert snapshot.x.shape == (total_nodes, feature_count)
                assert snapshot.y.shape == (total_nodes, )
                # The batch vector should survive the split as well.
                assert snapshot.batch.shape == (total_nodes, )
            # Without this, an empty split would make the loop pass vacuously.
            assert seen > 0
Example #2
0
def test_dynamic_graph_temporal_signal_batch_additional_attrs():
    """Extra keyword arrays are recorded and exposed on every snapshot.

    Two additional per-snapshot features (``optional1`` and ``optional2``)
    are passed alongside ``None`` placeholders for the five positional
    inputs. The dataset must report their keys in order, and each snapshot
    must expose them as attributes of shape ``(1,)``.
    """
    extra_features = {
        "optional1": [np.array([1])],
        "optional2": [np.array([2])],
    }
    # edge_indices, edge_weights, features, targets, batches
    placeholders = [[None] for _ in range(5)]
    dataset = DynamicGraphTemporalSignalBatch(*placeholders, **extra_features)
    assert dataset.additional_feature_keys == ["optional1", "optional2"]

    for snapshot in dataset:
        for key in ("optional1", "optional2"):
            assert getattr(snapshot, key).shape == (1, )
Example #3
0
def test_dynamic_graph_temporal_signal_batch():
    """All-``None`` inputs yield snapshots whose graph fields are ``None``.

    Builds a two-snapshot dataset where edge indices, edge weights, features,
    targets and batches are all ``None`` placeholders. It then checks that
    ``edge_index``, ``edge_attr``, ``x``, ``y`` and ``batch`` are ``None`` on
    every snapshot.
    """
    # Five independent two-element lists, one per positional input.
    inputs = [[None, None] for _ in range(5)]
    dataset = DynamicGraphTemporalSignalBatch(*inputs)

    expected_none = ("edge_index", "edge_attr", "x", "y", "batch")
    for snapshot in dataset:
        for attr in expected_none:
            assert getattr(snapshot, attr) is None
Example #4
0
def test_dynamic_graph_temporal_signal_real_batch():
    """Check snapshot shapes of a batch dataset built from generated data.

    ``generate_signal`` output is wrapped in a
    ``DynamicGraphTemporalSignalBatch``, which is iterated twice. Every
    snapshot is checked for consistent edge-index / edge-attribute sizes and
    for the expected shapes of ``x``, ``y`` and ``batch``.
    """
    snapshot_count = 250
    node_count = 100
    feature_count = 32
    graph_count = 10
    # Rows expected in x / y / batch for every snapshot.
    total_nodes = node_count * graph_count

    # Use the declared parameters rather than re-typing their literal values,
    # so the expected shapes below stay in sync if the parameters change.
    edge_indices, edge_weights, features, targets, batches = generate_signal(
        snapshot_count, node_count, feature_count, graph_count)

    dataset = DynamicGraphTemporalSignalBatch(edge_indices, edge_weights,
                                              features, targets, batches)

    for _ in range(2):
        seen = 0
        for snapshot in dataset:
            seen += 1
            assert snapshot.edge_index.shape[0] == 2
            assert snapshot.edge_index.shape[1] == snapshot.edge_attr.shape[0]
            assert snapshot.x.shape == (total_nodes, feature_count)
            assert snapshot.y.shape == (total_nodes, )
            assert snapshot.batch.shape == (total_nodes, )
        # Without this, an empty dataset would make the loop pass vacuously.
        assert seen > 0