def test_dcrnn_layer():
    """
    Shape checks for the DCRNN layer.

    The layer is called three ways -- without edge weights, with edge
    weights, and with edge weights plus the previous output fed back as
    the fourth argument -- and every result must have shape
    ``(number of nodes, out_channels)``.
    """
    node_count = 100
    edges_per_node = 10
    feature_dim = 64
    hidden_dim = 16
    expected_shape = (node_count, hidden_dim)

    # Mock graph: node features plus connectivity, then one weight per edge.
    X, edge_index = create_mock_data(node_count, edges_per_node, feature_dim)
    edge_weight = create_mock_edge_weight(edge_index)

    layer = DCRNN(in_channels=feature_dim, out_channels=hidden_dim, K=2)

    # Connectivity only.
    H = layer(X, edge_index)
    assert H.shape == expected_shape

    # Connectivity and edge weights.
    H = layer(X, edge_index, edge_weight)
    assert H.shape == expected_shape

    # Same inputs, with the previous output passed back in.
    H = layer(X, edge_index, edge_weight, H)
    assert H.shape == expected_shape
def test_dcrnn_layer():
    """
    Shape checks for the DCRNN layer on CUDA when available, else CPU.

    Two layers are exercised (``K=2`` and ``K=3``). Each is called with
    different argument combinations -- with and without edge weights, and
    with a previous output fed back as the fourth argument -- and every
    result must have shape ``(number of nodes, out_channels)``.
    """
    node_count = 100
    edges_per_node = 10
    feature_dim = 64
    hidden_dim = 16
    expected_shape = (node_count, hidden_dim)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Build the mock graph and move every tensor to the chosen device.
    X, edge_index = create_mock_data(node_count, edges_per_node, feature_dim)
    X = X.to(device)
    edge_index = edge_index.to(device)
    edge_weight = create_mock_edge_weight(edge_index).to(device)

    # --- First layer: K=2 ---
    layer = DCRNN(in_channels=feature_dim, out_channels=hidden_dim, K=2).to(device)

    H = layer(X, edge_index)
    assert H.shape == expected_shape

    H = layer(X, edge_index, edge_weight)
    assert H.shape == expected_shape

    H = layer(X, edge_index, edge_weight, H)
    assert H.shape == expected_shape

    # --- Second layer: K=3 ---
    layer = DCRNN(in_channels=feature_dim, out_channels=hidden_dim, K=3).to(device)

    # The first call reuses the output produced by the K=2 layer above.
    H = layer(X, edge_index, edge_weight, H)
    assert H.shape == expected_shape

    H = layer(X, edge_index, edge_weight)
    assert H.shape == expected_shape

    H = layer(X, edge_index, edge_weight, H)
    assert H.shape == expected_shape
def __init__(self, node_features):
    """
    Create the model's two sub-layers.

    ``self.recurrent`` is a ``DCRNN`` built from ``node_features``, 32
    and 1; ``self.linear`` maps those 32 values to a single output.
    """
    super().__init__()
    # Width shared by the DCRNN output and the linear layer's input.
    hidden_width = 32
    # NOTE(review): the third positional argument is presumably DCRNN's
    # ``K`` -- confirm against the DCRNN signature.
    self.recurrent = DCRNN(node_features, hidden_width, 1)
    self.linear = torch.nn.Linear(hidden_width, 1)
def __init__(self, node_features, filters):
    """
    Create the model's two sub-layers.

    ``self.recurrent`` is a ``DCRNN`` built from ``node_features``,
    ``filters`` and 1; ``self.linear`` maps ``filters`` values to a
    single output.
    """
    super().__init__()
    # Build both layers first, in the original order, then attach them.
    # NOTE(review): the third positional argument is presumably DCRNN's
    # ``K`` -- confirm against the DCRNN signature.
    recurrent_layer = DCRNN(node_features, filters, 1)
    readout_layer = torch.nn.Linear(filters, 1)
    self.recurrent = recurrent_layer
    self.linear = readout_layer
def __init__(self, node_features, num_classes):
    """
    Create two stacked ``DCRNN`` layers and a linear output layer.

    ``self.recurrent_1`` takes ``node_features`` inputs and produces 32
    channels, ``self.recurrent_2`` reduces those 32 to 16, and
    ``self.linear`` maps the 16 channels to ``num_classes`` outputs.
    """
    super().__init__()
    first_width = 32
    second_width = 16
    # NOTE(review): 5 is passed as DCRNN's third positional argument,
    # presumably ``K`` -- confirm against the DCRNN signature.
    third_arg = 5
    self.recurrent_1 = DCRNN(node_features, first_width, third_arg)
    self.recurrent_2 = DCRNN(first_width, second_width, third_arg)
    self.linear = torch.nn.Linear(second_width, num_classes)