def test_discrete_train_test_split_dynamic_batch():
    """Check snapshot shapes on both halves of a temporally split batch dataset.

    A ``DynamicGraphStaticSignalBatch`` is built from generated data (using
    only the first generated feature matrix as the static signal), split with
    ``temporal_signal_split(dataset, 0.8)``, and each resulting split is
    iterated twice to confirm that repeated passes yield well-formed snapshots.
    """
    snapshot_count = 250
    node_count = 100
    feature_count = 32
    graph_count = 10

    edge_indices, edge_weights, features, targets, batches = generate_signal(
        snapshot_count, node_count, feature_count, graph_count
    )
    static_feature = features[0]

    dataset = DynamicGraphStaticSignalBatch(
        edge_indices, edge_weights, static_feature, targets, batches
    )
    train_dataset, test_dataset = temporal_signal_split(dataset, 0.8)

    # Shapes asserted below are expressed in terms of this product.
    total_nodes = node_count * graph_count

    # The test split is walked first, then the train split; each one twice.
    for split in (test_dataset, train_dataset):
        for _ in range(2):
            for snapshot in split:
                edge_index_shape = snapshot.edge_index.shape
                assert edge_index_shape[0] == 2
                assert edge_index_shape[1] == snapshot.edge_attr.shape[0]
                assert snapshot.x.shape == (total_nodes, feature_count)
                assert snapshot.y.shape == (total_nodes, )
def test_dynamic_graph_static_signal_typing_batch():
    """Check that ``None`` inputs surface as ``None`` snapshot attributes.

    The dataset is constructed with ``None`` placeholders for every input, and
    each snapshot it yields must expose ``None`` for ``edge_index``,
    ``edge_attr``, ``x``, ``y`` and ``batch``.
    """
    placeholder = [None]
    dataset = DynamicGraphStaticSignalBatch(
        placeholder, placeholder, None, placeholder, placeholder
    )

    # Same attributes, in the same order, as the individual checks they replace.
    attribute_names = ("edge_index", "edge_attr", "x", "y", "batch")
    for snapshot in dataset:
        for attribute_name in attribute_names:
            assert getattr(snapshot, attribute_name) is None
def test_dynamic_graph_static_signal_batch_additional_attrs():
    """Check that extra keyword arguments are exposed on the dataset and snapshots.

    Two additional keyword features, ``optional1`` and ``optional2``, are
    passed to the constructor. Their names must be reported by
    ``additional_feature_keys``, and each snapshot must carry an attribute of
    shape ``(1, )`` under each name.
    """
    first_extra = [np.array([1])]
    second_extra = [np.array([2])]

    dataset = DynamicGraphStaticSignalBatch(
        [None],
        [None],
        None,
        [None],
        [None],
        optional1=first_extra,
        optional2=second_extra,
    )

    assert dataset.additional_feature_keys == ["optional1", "optional2"]

    for snapshot in dataset:
        for key in ("optional1", "optional2"):
            assert getattr(snapshot, key).shape == (1, )