# Example no. 1
# 0
def test_discrete_train_test_split_dynamic_batch():
    """Split a DynamicGraphStaticSignalBatch temporally and check every snapshot.

    Both halves returned by ``temporal_signal_split`` are iterated twice, so
    the test also checks that each split can be re-iterated.

    For every snapshot it checks:
    - the ``edge_index`` / ``edge_attr`` layout;
    - the shapes of ``x`` and ``y``.

    It also checks that a split is not empty, so the shape checks cannot pass
    vacuously, and that a second pass yields as many snapshots as the first.
    """
    snapshot_count = 250
    node_count = 100
    feature_count = 32
    graph_count = 10
    train_ratio = 0.8
    # Row count expected in each snapshot's x and y: graph_count graphs of
    # node_count nodes each.
    batched_node_count = node_count * graph_count

    edge_indices, edge_weights, features, targets, batches = generate_signal(
        snapshot_count, node_count, feature_count, graph_count)

    # The static-signal dataset is built from a single feature matrix rather
    # than one per snapshot; use the first generated one.
    feature = features[0]

    dataset = DynamicGraphStaticSignalBatch(edge_indices, edge_weights,
                                            feature, targets, batches)

    train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio)

    for split in (test_dataset, train_dataset):
        seen_per_epoch = []
        for _ in range(2):
            seen = 0
            for snapshot in split:
                edge_index = snapshot.edge_index
                assert edge_index.shape[0] == 2
                assert edge_index.shape[1] == snapshot.edge_attr.shape[0]
                assert snapshot.x.shape == (batched_node_count, feature_count)
                assert snapshot.y.shape == (batched_node_count, )
                seen += 1
            seen_per_epoch.append(seen)
        # An empty split would make every check above pass vacuously.
        assert seen_per_epoch[0] > 0
        # Re-iterating a split must yield the same number of snapshots.
        assert seen_per_epoch[0] == seen_per_epoch[1]
# Example no. 2
# 0
def test_dynamic_graph_static_signal_typing_batch():
    """Check that None inputs come back as None attributes on each snapshot.

    The dataset is built from single-element ``[None]`` lists and a ``None``
    feature. Every yielded snapshot must report ``None`` for its edge index,
    edge attributes, features, targets and batch vector.
    """
    empty_edges = [None]
    empty_weights = [None]
    empty_targets = [None]
    empty_batches = [None]
    dataset = DynamicGraphStaticSignalBatch(empty_edges, empty_weights, None,
                                            empty_targets, empty_batches)

    # Checked in this order on every snapshot.
    expected_none = ("edge_index", "edge_attr", "x", "y", "batch")
    for snapshot in dataset:
        for attribute in expected_none:
            assert getattr(snapshot, attribute) is None
# Example no. 3
# 0
def test_dynamic_graph_static_signal_batch_additional_attrs():
    """Check that extra keyword features are recorded and exposed on snapshots.

    The test passes two extra per-snapshot features, ``optional1`` and
    ``optional2``, and checks two things:
    - the dataset lists their names, in order, in ``additional_feature_keys``;
    - every snapshot exposes each of them as a one-element array.
    """
    extra_features = {
        "optional1": [np.array([1])],
        "optional2": [np.array([2])],
    }
    dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None],
                                            [None], **extra_features)

    assert dataset.additional_feature_keys == ["optional1", "optional2"]
    for snapshot in dataset:
        for key in extra_features:
            assert getattr(snapshot, key).shape == (1, )