Esempio n. 1
0
def test_laplacian_lambda_max():
    """Check the repr of ``LaplacianLambdaMax`` and the ``lambda_max`` it stores.

    A three-node graph with edges 0-1 and 1-2 (listed in both directions) is
    transformed under each normalization scheme with scalar edge attributes,
    followed by one run whose edge attributes are random two-column values.
    """
    description = repr(LaplacianLambdaMax())
    assert description == 'LaplacianLambdaMax(normalization=None)'

    edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
    weights = torch.tensor([1, 1, 2, 2], dtype=torch.float)

    # (normalization scheme, expected largest eigenvalue) for the graph above.
    weighted_cases = ((None, 4.732049), ('sym', 2.0), ('rw', 2.0))
    for scheme, expected in weighted_cases:
        graph = Data(edge_index=edges, edge_attr=weights, num_nodes=3)
        transform = LaplacianLambdaMax(normalization=scheme,
                                       is_undirected=True)
        result = transform(graph)
        assert len(result) == 4
        assert torch.allclose(torch.tensor(result.lambda_max),
                              torch.tensor(expected))

    # Two-column edge attributes. NOTE(review): 3.0 matches the unweighted
    # graph, so the transform presumably ignores non-scalar attributes;
    # confirm against the transform's documentation.
    graph = Data(edge_index=edges, edge_attr=torch.randn(4, 2), num_nodes=3)
    result = LaplacianLambdaMax(normalization=None)(graph)
    assert len(result) == 4
    assert torch.allclose(torch.tensor(result.lambda_max), torch.tensor(3.0))
Esempio n. 2
0
    def forward(
        self, X: torch.FloatTensor, edge_index: Union[torch.LongTensor,
                                                      List[torch.LongTensor]]
    ) -> torch.FloatTensor:
        """
        Run one ASTGCN block on a batch of node-feature sequences.

        Arg types:
            * **X** (PyTorch Float Tensor) - Node features for T time periods, with shape (B, N_nodes, F_in, T_in).
            * **edge_index** (LongTensor or list of LongTensor): Edge indices. A single tensor is used for every time step; a list is indexed by time step.

        Return types:
            * **X** (PyTorch Float Tensor) - Hidden state tensor for all nodes, with shape (B, N_nodes, nb_time_filter, T_out).
        """
        batch_size, num_of_vertices, num_of_features, num_of_timesteps = X.shape

        # Re-weight the input along time with the temporal attention scores,
        # then derive the spatial attention from the re-weighted signal.
        temporal_scores = self._temporal_attention(X)
        flattened = X.reshape(batch_size, -1, num_of_timesteps)
        reweighted = torch.matmul(flattened, temporal_scores).reshape(
            batch_size, num_of_vertices, num_of_features, num_of_timesteps)
        spatial_scores = self._spatial_attention(reweighted)

        def largest_eigenvalue(graph_edges):
            # lambda_max of the Laplacian built from the given edges.
            graph = Data(edge_index=graph_edges,
                         edge_attr=None,
                         num_nodes=num_of_vertices)
            return LaplacianLambdaMax()(graph).lambda_max

        edges_vary_over_time = isinstance(edge_index, list)
        if not edges_vary_over_time:
            # One shared graph: its eigenvalue is computed once, up front.
            shared_lambda_max = largest_eigenvalue(edge_index)

        per_step = []
        for step in range(num_of_timesteps):
            if edges_vary_over_time:
                step_edges = edge_index[step]
                step_lambda_max = largest_eigenvalue(step_edges)
            else:
                step_edges = edge_index
                step_lambda_max = shared_lambda_max
            convolved = self._chebconv_attention(X[:, :, :, step],
                                                 step_edges,
                                                 spatial_scores,
                                                 lambda_max=step_lambda_max)
            # Trailing axis so the steps can be concatenated along time.
            per_step.append(convolved.unsqueeze(-1))
        graph_features = F.relu(torch.cat(per_step, dim=-1))

        # Temporal convolution on the graph features, 1x1-style shortcut on
        # the raw input, then layer norm with channels moved to the last axis.
        temporal_out = self._time_convolution(
            graph_features.permute(0, 2, 1, 3))
        shortcut = self._residual_convolution(X.permute(0, 2, 1, 3))
        normalized = self._layer_norm(
            F.relu(shortcut + temporal_out).permute(0, 3, 2, 1))
        return normalized.permute(0, 2, 3, 1)
Esempio n. 3
0
    def forward(self, x, edge_index):
        """
        Making a forward pass. This is one MSTGCN block.
        B is the batch size. N_nodes is the number of nodes in the graph. F_in is the dimension of input features.
        T_in is the length of input sequence in time. T_out is the length of output sequence in time.
        nb_time_filter is the number of time filters used.

        Arg types:
            * x (PyTorch Float Tensor) - Node features for T time periods, with shape (B, N_nodes, F_in, T_in).
            * edge_index (Tensor or list of Tensors): Edge indices. A single tensor is shared by all time steps; a list is indexed by time step when edges change over time.

        Return types:
            * output (PyTorch Float Tensor) - Hidden state tensor for all nodes, with shape (B, N_nodes, nb_time_filter, T_out).
        """
        # cheb gcn
        batch_size, num_of_vertices, in_channels, num_of_timesteps = x.shape
        if not isinstance(edge_index, list):
            data = Data(edge_index=edge_index,
                        edge_attr=None,
                        num_nodes=num_of_vertices)
            lambda_max = LaplacianLambdaMax()(data).lambda_max
            # Fold time into the batch axis so one convolution call handles
            # every (sample, time step) pair:
            # (B, N, F_in, T) -> (B, T, N, F_in) -> (B*T, N, F_in).
            # The axes must be ordered this way *before* reshaping; reshaping
            # a (F_in, B, N, T) view as (N, F_in, B*T) would mix values
            # belonging to different nodes and samples.
            tmp = x.permute(0, 3, 1, 2).reshape(
                batch_size * num_of_timesteps, num_of_vertices, in_channels)
            output = F.relu(
                self.cheb_conv(x=tmp,
                               edge_index=edge_index,
                               lambda_max=lambda_max))
            # Undo the fold: (B*T, N, F_out) -> (B, T, N, F_out)
            # -> (B, N, F_out, T). The channel count is read from the
            # convolution output rather than assumed.
            spatial_gcn = output.reshape(
                batch_size, num_of_timesteps, num_of_vertices,
                output.size(-1)).permute(0, 2, 3, 1)
        else:  # edge_index changes over time
            outputs = []
            for time_step in range(num_of_timesteps):
                data = Data(edge_index=edge_index[time_step],
                            edge_attr=None,
                            num_nodes=num_of_vertices)
                lambda_max = LaplacianLambdaMax()(data).lambda_max
                outputs.append(
                    torch.unsqueeze(
                        self.cheb_conv(x=x[:, :, :, time_step],
                                       edge_index=edge_index[time_step],
                                       lambda_max=lambda_max), -1))
            spatial_gcn = F.relu(torch.cat(outputs, dim=-1))  # (b,N,F,T)

        # convolution along the time axis
        time_conv_output = self.time_conv(spatial_gcn.permute(0, 2, 1,
                                                              3))  # (b,F,N,T)

        # residual shortcut
        x_residual = self.residual_conv(x.permute(0, 2, 1, 3))  # (b,F,N,T)

        # layer norm is applied with the channel axis last, then the layout
        # is restored
        x_residual = self.ln(
            F.relu(x_residual + time_conv_output).permute(0, 3, 2, 1)).permute(
                0, 2, 3, 1)  # (b,N,F,T)

        return x_residual
    def forward(self, X: torch.FloatTensor,
                edge_index: torch.LongTensor) -> torch.FloatTensor:
        """
        Making a forward pass with a single MSTGCN block.

        Arg types:
            * X (PyTorch FloatTensor) - Node features for T time periods, with shape (B, N_nodes, F_in, T_in).
            * edge_index (PyTorch LongTensor or list of LongTensors): Edge indices. A single tensor is shared by all time steps; a list is indexed by time step when edges change over time.

        Return types:
            * X (PyTorch FloatTensor) - Hidden state tensor for all nodes, with shape (B, N_nodes, nb_time_filter, T_out).
        """

        batch_size, num_of_vertices, in_channels, num_of_timesteps = X.shape

        if not isinstance(edge_index, list):

            lambda_max = LaplacianLambdaMax()(Data(
                edge_index=edge_index,
                edge_attr=None,
                num_nodes=num_of_vertices)).lambda_max

            # Fold time into the batch axis so one convolution call handles
            # every (sample, time step) pair:
            # (B, N, F_in, T) -> (B, T, N, F_in) -> (B*T, N, F_in).
            # The axes must be ordered this way *before* reshaping; reshaping
            # a (F_in, B, N, T) view as (N, F_in, B*T) would mix values
            # belonging to different nodes and samples.
            X_tilde = X.permute(0, 3, 1, 2)
            X_tilde = X_tilde.reshape(batch_size * num_of_timesteps,
                                      num_of_vertices, in_channels)
            X_tilde = F.relu(
                self._cheb_conv(x=X_tilde,
                                edge_index=edge_index,
                                lambda_max=lambda_max))
            # Undo the fold: (B*T, N, F_out) -> (B, T, N, F_out)
            # -> (B, N, F_out, T). The channel count is read from the
            # convolution output rather than assumed.
            X_tilde = X_tilde.reshape(batch_size, num_of_timesteps,
                                      num_of_vertices, X_tilde.size(-1))
            X_tilde = X_tilde.permute(0, 2, 3, 1)

        else:
            # Edges change over time: convolve each time slice with its own
            # graph and its own lambda_max, then concatenate along time.
            X_tilde = []
            for t in range(num_of_timesteps):
                lambda_max = LaplacianLambdaMax()(Data(
                    edge_index=edge_index[t],
                    edge_attr=None,
                    num_nodes=num_of_vertices,
                )).lambda_max
                X_tilde.append(
                    torch.unsqueeze(
                        self._cheb_conv(X[:, :, :, t],
                                        edge_index[t],
                                        lambda_max=lambda_max),
                        -1,
                    ))
            X_tilde = F.relu(torch.cat(X_tilde, dim=-1))

        # Temporal convolution on the graph features and a shortcut on the
        # raw input, both with channels moved to axis 1.
        X_tilde = self._time_conv(X_tilde.permute(0, 2, 1, 3))
        X = self._residual_conv(X.permute(0, 2, 1, 3))
        # Layer norm is applied with the channel axis last, then the
        # (B, N, F, T) layout is restored.
        X = self._layer_norm(F.relu(X + X_tilde).permute(0, 3, 2, 1))
        X = X.permute(0, 2, 3, 1)
        return X
Esempio n. 5
0
def test_chebconvatt():
    """
    Testing ChebConvAtt block.

    Builds a mock batch of node-feature sequences, applies the block to every
    time slice and checks the shape of the concatenated result. The model,
    the attention matrix and the edge index are all moved to the same device
    as the inputs so the test also runs when CUDA is available.
    """
    node_count = 307
    num_classes = 10
    edge_per_node = 15

    len_input = 12

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    node_features = 2
    K = 3
    nb_chev_filter = 64
    batch_size = 32

    x, edge_index = create_mock_data(node_count, edge_per_node, node_features)
    model = ChebConvAtt(node_features, nb_chev_filter, K).to(device)
    spatial_attention = torch.rand(batch_size, node_count, node_count)
    spatial_attention = torch.nn.functional.softmax(spatial_attention,
                                                    dim=1).to(device)
    model.train()
    T = len_input
    x_seq = torch.zeros([batch_size, node_count, node_features, T]).to(device)
    target_seq = torch.zeros([batch_size, node_count, T]).to(device)
    for b in range(batch_size):
        for t in range(T):
            # Only the features are stored per (sample, step); the edge index
            # returned by the final call is the one used for the whole batch.
            x, edge_index = create_mock_data(node_count, edge_per_node,
                                             node_features)
            x_seq[b, :, :, t] = x
            target = create_mock_target(node_count, num_classes)
            target_seq[b, :, t] = target
    shuffle = True
    train_dataset = torch.utils.data.TensorDataset(x_seq, target_seq)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=shuffle)

    # The graph is the same for every batch, so lambda_max is computed once,
    # before the edge index is moved next to the inputs.
    data = Data(edge_index=edge_index,
                edge_attr=None,
                num_nodes=node_count)
    lambda_max = LaplacianLambdaMax()(data).lambda_max
    edge_index = edge_index.to(device)

    for batch_data in train_loader:
        encoder_inputs, labels = batch_data
        outputs = []
        for time_step in range(T):
            outputs.append(
                torch.unsqueeze(
                    model(encoder_inputs[:, :, :, time_step],
                          edge_index,
                          spatial_attention,
                          lambda_max=lambda_max), -1))
        spatial_gcn = torch.nn.functional.relu(torch.cat(
            outputs, dim=-1))  # (b,N,F,T)
    assert spatial_gcn.shape == (batch_size, node_count, nb_chev_filter, T)
Esempio n. 6
0
    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: Union[torch.LongTensor, List[torch.LongTensor]],
    ) -> torch.FloatTensor:
        """
        Run one ASTGCN block on a batch of node-feature sequences.

        Arg types:
            * **X** (PyTorch Float Tensor) - Node features for T time periods, with shape (B, N_nodes, F_in, T_in).
            * **edge_index** (LongTensor or list of LongTensor): Edge indices. A single tensor is used for every time step; a list is indexed by time step.

        Return types:
            * **X** (PyTorch Float Tensor) - Hidden state tensor for all nodes, with shape (B, N_nodes, nb_time_filter, T_out).
        """
        # e.g. (32, 307, 1, 12)
        batch_size, num_of_vertices, num_of_features, num_of_timesteps = X.shape

        # Temporal attention scores, (B, T, T), e.g. (32, 12, 12).
        temporal_scores = self._temporal_attention(X)
        # (B, N*F, T) @ (B, T, T), reshaped back to (B, N, F, T).
        flattened = X.reshape(batch_size, -1, num_of_timesteps)
        reweighted = torch.matmul(flattened, temporal_scores).reshape(
            batch_size, num_of_vertices, num_of_features, num_of_timesteps)
        # Spatial attention scores, (B, N, N), e.g. (32, 307, 307).
        spatial_scores = self._spatial_attention(reweighted)

        def largest_eigenvalue(graph_edges):
            # lambda_max is only computed when the normalization is not
            # "sym"; otherwise None is handed to the convolution.
            graph = Data(edge_index=graph_edges,
                         edge_attr=None,
                         num_nodes=num_of_vertices)
            if self._normalization == "sym":
                return None
            return LaplacianLambdaMax()(graph).lambda_max

        edges_vary_over_time = isinstance(edge_index, list)
        if not edges_vary_over_time:
            # One shared graph: its eigenvalue is resolved once, up front.
            shared_lambda_max = largest_eigenvalue(edge_index)

        per_step = []
        for step in range(num_of_timesteps):
            if edges_vary_over_time:
                step_edges = edge_index[step]
                step_lambda_max = largest_eigenvalue(step_edges)
            else:
                step_edges = edge_index
                step_lambda_max = shared_lambda_max
            convolved = self._chebconv_attention(X[:, :, :, step],
                                                 step_edges,
                                                 spatial_scores,
                                                 lambda_max=step_lambda_max)
            # Trailing axis so the steps can be concatenated along time.
            per_step.append(convolved.unsqueeze(-1))
        graph_features = F.relu(torch.cat(per_step, dim=-1))

        # (b,N,F,T) -> (b,F,N,T), e.g. (32, 307, 64, 12) -> (32, 64, 307, 12),
        # then the convolution along the time axis is applied.
        temporal_out = self._time_convolution(
            graph_features.permute(0, 2, 1, 3))
        # Shortcut on the raw input, also (b,F,N,T), e.g. (32, 64, 307, 12).
        shortcut = self._residual_convolution(X.permute(0, 2, 1, 3))
        # Sum, permute to (b,T,N,F) for layer normalization, then permute
        # back to (b,N,F,T), e.g. (32, 307, 64, 12).
        normalized = self._layer_norm(
            F.relu(shortcut + temporal_out).permute(0, 3, 2, 1))
        return normalized.permute(0, 2, 3, 1)
Esempio n. 7
0
# Script fragment: builds a mock batch of node-feature sequences and applies
# `model` to every time slice. `model`, `spatial_attention`, `device` and the
# size constants (`len_input`, `batch_size`, `node_count`, ...) are defined
# outside this excerpt.
model.train()
T = len_input
# Zero-initialised buffers, filled slice by slice below and moved to `device`.
x_seq = torch.zeros([batch_size, node_count, node_features, T]).to(device)
target_seq = torch.zeros([batch_size, node_count, T]).to(device)
for b in range(batch_size):
    for t in range(T):
        # A fresh mock graph is drawn for every (sample, step), but only its
        # features are stored; `edge_index` keeps the value from the last call.
        x, edge_index = create_mock_data(node_count, edge_per_node,
                                         node_features)
        x_seq[b, :, :, t] = x
        target = create_mock_target(node_count, num_classes)
        target_seq[b, :, t] = target
shuffle = True
train_dataset = torch.utils.data.TensorDataset(x_seq, target_seq)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=shuffle)
for batch_data in train_loader:
    # `labels` is unpacked but not used in this excerpt.
    encoder_inputs, labels = batch_data
    # lambda_max is recomputed for every batch although `edge_index` does not
    # change inside this loop.
    data = Data(edge_index=edge_index, edge_attr=None, num_nodes=node_count)
    lambda_max = LaplacianLambdaMax()(data).lambda_max
    # NOTE(review): `encoder_inputs` lives on `device`; this excerpt does not
    # show `model`, `spatial_attention` or `edge_index` being moved there --
    # confirm they are, otherwise this fails when CUDA is available.
    outputs = []
    for time_step in range(T):
        # One model call per time slice; a trailing axis is added so the
        # results can be concatenated along time.
        outputs.append(
            torch.unsqueeze(
                model(encoder_inputs[:, :, :, time_step],
                      edge_index,
                      spatial_attention,
                      lambda_max=lambda_max), -1))
    spatial_gcn = torch.nn.functional.relu(torch.cat(
        outputs, dim=-1))  # (b,N,F,T)