def test_metrla():
    """Check snapshot tensor shapes for the METR-LA dataset built with defaults.

    The dataset comes from ``get_dataset()`` called without arguments, with
    raw data stored under ``/tmp/``. The dataset is iterated twice, and every
    snapshot of both passes must expose the expected shapes.
    """
    dataset = METRLADatasetLoader(raw_data_dir="/tmp/").get_dataset()

    # (attribute name, expected shape), checked in this order per snapshot.
    expected_shapes = (
        ("edge_index", (2, 1722)),
        ("edge_attr", (1722,)),
        ("x", (207, 2, 12)),
        ("y", (207, 12)),
    )

    for _ in range(2):
        for snapshot in dataset:
            for attribute, shape in expected_shapes:
                assert getattr(snapshot, attribute).shape == shape
def test_metrla_task_generator():
    """Check snapshot tensor shapes for a custom input/output window.

    The dataset is requested with ``num_timesteps_in=6`` and
    ``num_timesteps_out=5``, with raw data stored under ``/tmp/``. The
    dataset is iterated twice, and every snapshot of both passes must expose
    the expected shapes.
    """
    dataset = METRLADatasetLoader(raw_data_dir="/tmp/").get_dataset(
        num_timesteps_in=6,
        num_timesteps_out=5,
    )

    # (attribute name, expected shape), checked in this order per snapshot.
    expected_shapes = (
        ("edge_index", (2, 1722)),
        ("edge_attr", (1722,)),
        ("x", (207, 2, 6)),
        ("y", (207, 5)),
    )

    for _ in range(2):
        for snapshot in dataset:
            for attribute, shape in expected_shapes:
                assert getattr(snapshot, attribute).shape == shape
from torch_geometric.nn import GCNConv
from torch_geometric_temporal.nn.recurrent import A3TGCN2
from torch_geometric_temporal.dataset import METRLADatasetLoader

# NOTE(review): `torch` and `plt` are used below but are not imported in this
# section -- presumably they are imported earlier in the file; confirm.

# GPU support: prefer CUDA, but fall back to the CPU so the script still runs
# on machines without a usable GPU instead of failing on the first transfer.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
shuffle = True
batch_size = 32

# Dataset
# Traffic forecasting dataset based on Los Angeles Metropolitan traffic
# 207 loop detectors on highways
# March 2012 - June 2012
# From the paper: Diffusion Convolutional Recurrent Neural Network
loader = METRLADatasetLoader()
dataset = loader.get_dataset(num_timesteps_in=12, num_timesteps_out=12)

print("Dataset type: ", dataset)
# Count snapshots by streaming over the dataset once; this avoids hashing and
# holding every snapshot in memory just to obtain a count.
print("Number of samples / sequences: ", sum(1 for _ in dataset))

# Visualize traffic over time
sensor_number = 1
hours = 24
# First target value of one sensor for each of the first `hours` snapshots.
# NOTE(review): the full dataset is materialized before slicing on purpose;
# stopping iteration early (e.g. with itertools.islice) may leave the signal
# object's internal cursor mid-sequence -- confirm before optimizing.
sensor_labels = [
    bucket.y[sensor_number][0].item()
    for bucket in list(dataset)[:hours]
]
plt.plot(sensor_labels)

# Train test split
from torch_geometric_temporal.signal import temporal_signal_split