Ejemplo n.º 1
0
def test_a3tgcn_layer():
    """
    Check that A3TGCN yields a (number_of_nodes, out_channels) tensor for
    each supported call form, running on CUDA when available and CPU otherwise.
    """
    num_nodes, edges_per_node = 100, 10
    in_dim, out_dim, n_periods = 64, 16, 7
    expected_shape = (num_nodes, out_dim)

    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    features, edges = create_mock_attention_data(num_nodes, edges_per_node,
                                                 in_dim, n_periods)
    # Move inputs first so the mock edge weights are built from the
    # on-device edge index.
    features, edges = features.to(dev), edges.to(dev)
    weights = create_mock_edge_weight(edges).to(dev)

    model = A3TGCN(in_channels=in_dim,
                   out_channels=out_dim,
                   periods=n_periods).to(dev)

    # Features and edge index only.
    hidden = model(features, edges)
    assert hidden.shape == expected_shape

    # With explicit edge weights.
    hidden = model(features, edges, weights)
    assert hidden.shape == expected_shape

    # With edge weights and the previous output fed back as the fourth argument.
    hidden = model(features, edges, weights, hidden)
    assert hidden.shape == expected_shape
Ejemplo n.º 2
0
def test_a3tgcn_layer():
    """
    Check that A3TGCN yields a (number_of_nodes, out_channels) tensor for
    each supported call form, using default (CPU) tensors.
    """
    num_nodes, edges_per_node = 100, 10
    in_dim, out_dim, n_periods = 64, 16, 7
    expected_shape = (num_nodes, out_dim)

    features, edges = create_mock_attention_data(num_nodes, edges_per_node,
                                                 in_dim, n_periods)
    weights = create_mock_edge_weight(edges)

    model = A3TGCN(in_channels=in_dim, out_channels=out_dim, periods=n_periods)

    # Features and edge index only.
    hidden = model(features, edges)
    assert hidden.shape == expected_shape

    # With explicit edge weights.
    hidden = model(features, edges, weights)
    assert hidden.shape == expected_shape

    # With edge weights and the previous output fed back as the fourth argument.
    hidden = model(features, edges, weights, hidden)
    assert hidden.shape == expected_shape
Ejemplo n.º 3
0
 def __init__(self, node_features, periods):
     """
     Build an A3TGCN recurrent layer followed by a linear readout.

     Args:
         node_features: Passed as the first positional argument of A3TGCN
             (presumably its ``in_channels``; confirm against A3TGCN).
         periods: Forwarded unchanged to A3TGCN as its third argument.
     """
     # Explicit two-argument super() kept as in the original; the enclosing
     # class header is not visible here.
     super(RecurrentGCN, self).__init__()
     hidden_size = 32  # width shared by the A3TGCN output and the Linear input
     self.recurrent = A3TGCN(node_features, hidden_size, periods)
     # Maps the hidden_size-wide representation to a single output value.
     self.linear = torch.nn.Linear(hidden_size, 1)