def test_laplacian_lambda_max():
    """Check ``LaplacianLambdaMax`` on a weighted three-node path graph.

    Covers the string representation, the three normalization modes on an
    undirected weighted graph, and the case where the edge attributes are
    multi-dimensional.
    """
    assert repr(LaplacianLambdaMax()) == 'LaplacianLambdaMax(normalization=None)'

    # Path graph 0 - 1 - 2, stored as directed pairs with weights 1 and 2.
    edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
    weights = torch.tensor([1, 1, 2, 2], dtype=torch.float)

    # (normalization, expected largest eigenvalue) for the weighted graph.
    weighted_cases = ((None, 4.732049), ('sym', 2.0), ('rw', 2.0))
    for normalization, expected in weighted_cases:
        graph = Data(edge_index=edges, edge_attr=weights, num_nodes=3)
        transform = LaplacianLambdaMax(normalization=normalization,
                                       is_undirected=True)
        result = transform(graph)
        assert len(result) == 4
        assert torch.allclose(torch.tensor(result.lambda_max),
                              torch.tensor(expected))

    # Random two-column edge attributes: the asserted value is fixed, so it
    # must not depend on the attribute values.
    graph = Data(edge_index=edges, edge_attr=torch.randn(4, 2), num_nodes=3)
    result = LaplacianLambdaMax(normalization=None)(graph)
    assert len(result) == 4
    assert torch.allclose(torch.tensor(result.lambda_max), torch.tensor(3.0))
def forward(
    self,
    X: torch.FloatTensor,
    edge_index: Union[torch.LongTensor, List[torch.LongTensor]],
) -> torch.FloatTensor:
    """
    Making a forward pass with the ASTGCN block.

    Arg types:
        * **X** (PyTorch Float Tensor) - Node features for T time periods,
          with shape (B, N_nodes, F_in, T_in).
        * **edge_index** (LongTensor): Edge indices. Either a single tensor
          (static graph) or a list holding one tensor per time step
          (edges change over time).

    Return types:
        * **X** (PyTorch Float Tensor) - Hidden state tensor for all nodes,
          with shape (B, N_nodes, nb_time_filter, T_out).
    """
    batch_size, num_of_vertices, num_of_features, num_of_timesteps = X.shape

    # Temporal attention re-weights the time axis of the flattened input;
    # the result only feeds the spatial attention module.
    temporal_weights = self._temporal_attention(X)
    flattened = X.reshape(batch_size, -1, num_of_timesteps)
    reweighted = torch.matmul(flattened, temporal_weights).reshape(
        batch_size, num_of_vertices, num_of_features, num_of_timesteps
    )
    spatial_weights = self._spatial_attention(reweighted)

    # A static graph needs its largest Laplacian eigenvalue only once.
    graph_is_static = not isinstance(edge_index, list)
    if graph_is_static:
        static_graph = Data(edge_index=edge_index, edge_attr=None,
                            num_nodes=num_of_vertices)
        static_lambda_max = LaplacianLambdaMax()(static_graph).lambda_max

    per_step = []
    for step in range(num_of_timesteps):
        if graph_is_static:
            step_edges, step_lambda_max = edge_index, static_lambda_max
        else:
            step_edges = edge_index[step]
            step_graph = Data(edge_index=step_edges, edge_attr=None,
                              num_nodes=num_of_vertices)
            step_lambda_max = LaplacianLambdaMax()(step_graph).lambda_max
        convolved = self._chebconv_attention(
            X[:, :, :, step], step_edges, spatial_weights,
            lambda_max=step_lambda_max,
        )
        per_step.append(torch.unsqueeze(convolved, -1))
    graph_features = F.relu(torch.cat(per_step, dim=-1))

    # Swap the node and feature axes for the two convolutions, add the
    # shortcut, layer-normalize, and restore the original axis order.
    time_features = self._time_convolution(graph_features.permute(0, 2, 1, 3))
    shortcut = self._residual_convolution(X.permute(0, 2, 1, 3))
    normalized = self._layer_norm(
        F.relu(shortcut + time_features).permute(0, 3, 2, 1)
    )
    return normalized.permute(0, 2, 3, 1)
def forward(self, x, edge_index):
    """
    Making a forward pass. This is one MSTGCN block.

    B is the batch size. N_nodes is the number of nodes in the graph.
    F_in is the dimension of input features. T_in is the length of input
    sequence in time. T_out is the length of output sequence in time.
    nb_time_filter is the number of time filters used.

    Arg types:
        * x (PyTorch Float Tensor) - Node features for T time periods,
          with shape (B, N_nodes, F_in, T_in).
        * edge_index (Tensor): Edge indices. Either a single tensor
          (static graph) or a list with one tensor per time step
          (edges change over time).

    Return types:
        * output (PyTorch Float Tensor) - Hidden state tensor for all
          nodes, with shape (B, N_nodes, nb_time_filter, T_out).
    """
    batch_size, num_of_vertices, in_channels, num_of_timesteps = x.shape

    if not isinstance(edge_index, list):
        # Static graph: one eigenvalue, one batched Chebyshev convolution.
        data = Data(edge_index=edge_index, edge_attr=None,
                    num_nodes=num_of_vertices)
        lambda_max = LaplacianLambdaMax()(data).lambda_max

        # Bring the axes into (N_nodes, F_in, B, T_in) order BEFORE
        # flattening (B, T_in). Permuting to any other order and then
        # reshaping to this shape would reinterpret memory and mix values
        # of different nodes into one node row.
        tmp = x.permute(1, 2, 0, 3).reshape(
            num_of_vertices, in_channels,
            num_of_timesteps * batch_size)  # (N_nodes, F_in, B*T_in)
        tmp = tmp.permute(2, 0, 1)  # (B*T_in, N_nodes, F_in)

        output = F.relu(
            self.cheb_conv(x=tmp, edge_index=edge_index,
                           lambda_max=lambda_max))

        # Undo the flattening: (B*T_in, N, F_out) -> (B, N, F_out, T_in).
        spatial_gcn = output.permute(1, 2, 0).reshape(
            num_of_vertices, self.nb_time_filter, batch_size,
            num_of_timesteps).permute(2, 0, 1, 3)
    else:
        # edge_index changes over time: convolve each time slice with its
        # own graph and its own largest Laplacian eigenvalue.
        outputs = []
        for time_step in range(num_of_timesteps):
            data = Data(edge_index=edge_index[time_step], edge_attr=None,
                        num_nodes=num_of_vertices)
            lambda_max = LaplacianLambdaMax()(data).lambda_max
            outputs.append(
                torch.unsqueeze(
                    self.cheb_conv(x=x[:, :, :, time_step],
                                   edge_index=edge_index[time_step],
                                   lambda_max=lambda_max), -1))
        spatial_gcn = F.relu(torch.cat(outputs, dim=-1))  # (b,N,F,T)

    # convolution along the time axis
    time_conv_output = self.time_conv(
        spatial_gcn.permute(0, 2, 1, 3))  # (b,F,N,T)

    # residual shortcut
    x_residual = self.residual_conv(x.permute(0, 2, 1, 3))  # (b,F,N,T)

    # Layer norm is applied with the feature axis last, then the result is
    # permuted back to (b,N,F,T).
    x_residual = self.ln(
        F.relu(x_residual + time_conv_output).permute(0, 3, 2, 1)).permute(
            0, 2, 3, 1)
    return x_residual
def forward(self, X: torch.FloatTensor,
            edge_index: torch.LongTensor) -> torch.FloatTensor:
    """
    Making a forward pass with a single MSTGCN block.

    Arg types:
        * X (PyTorch FloatTensor) - Node features for T time periods,
          with shape (B, N_nodes, F_in, T_in).
        * edge_index (PyTorch LongTensor): Edge indices. Either a single
          tensor (static graph) or a list with one tensor per time step
          (edges change over time).

    Return types:
        * X (PyTorch FloatTensor) - Hidden state tensor for all nodes,
          with shape (B, N_nodes, nb_time_filter, T_out).
    """
    batch_size, num_of_vertices, in_channels, num_of_timesteps = X.shape

    if not isinstance(edge_index, list):
        # Static graph: one eigenvalue, one batched Chebyshev convolution.
        lambda_max = LaplacianLambdaMax()(Data(
            edge_index=edge_index,
            edge_attr=None,
            num_nodes=num_of_vertices,
        )).lambda_max

        # Bring the axes into (N_nodes, F_in, B, T_in) order BEFORE
        # flattening (B, T_in). Permuting to any other order and then
        # reshaping to this shape would reinterpret memory and mix values
        # of different nodes into one node row.
        X_tilde = X.permute(1, 2, 0, 3)
        X_tilde = X_tilde.reshape(num_of_vertices, in_channels,
                                  num_of_timesteps * batch_size)
        X_tilde = X_tilde.permute(2, 0, 1)  # (B*T_in, N_nodes, F_in)

        X_tilde = F.relu(
            self._cheb_conv(x=X_tilde, edge_index=edge_index,
                            lambda_max=lambda_max))

        # Undo the flattening: (B*T_in, N, F_out) -> (B, N, F_out, T_in).
        X_tilde = X_tilde.permute(1, 2, 0)
        X_tilde = X_tilde.reshape(num_of_vertices, self.nb_time_filter,
                                  batch_size, num_of_timesteps)
        X_tilde = X_tilde.permute(2, 0, 1, 3)
    else:
        # Edges change over time: convolve each time slice with its own
        # graph and its own largest Laplacian eigenvalue.
        X_tilde = []
        for t in range(num_of_timesteps):
            lambda_max = LaplacianLambdaMax()(Data(
                edge_index=edge_index[t],
                edge_attr=None,
                num_nodes=num_of_vertices,
            )).lambda_max
            X_tilde.append(
                torch.unsqueeze(
                    self._cheb_conv(X[:, :, :, t], edge_index[t],
                                    lambda_max=lambda_max),
                    -1,
                ))
        X_tilde = F.relu(torch.cat(X_tilde, dim=-1))

    # Swap the node and feature axes for the two convolutions, add the
    # shortcut, layer-normalize, and restore the original axis order.
    X_tilde = self._time_conv(X_tilde.permute(0, 2, 1, 3))
    X = self._residual_conv(X.permute(0, 2, 1, 3))
    X = self._layer_norm(F.relu(X + X_tilde).permute(0, 3, 2, 1))
    X = X.permute(0, 2, 3, 1)
    return X
def test_chebconvatt():
    """
    Testing ChebConvAtt block.

    Builds a random batch of node-feature sequences, runs the block on
    every time slice with a shared spatial attention matrix, and checks
    the shape of the stacked output.
    """
    node_count = 307
    num_classes = 10
    edge_per_node = 15
    len_input = 12
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    node_features = 2
    K = 3
    nb_chev_filter = 64
    batch_size = 32

    x, edge_index = create_mock_data(node_count, edge_per_node, node_features)

    # The inputs below are moved to `device`, so the model and the
    # attention matrix must be there too; otherwise the forward call
    # hits a device mismatch whenever CUDA is available.
    model = ChebConvAtt(node_features, nb_chev_filter, K).to(device)
    spatial_attention = torch.rand(batch_size, node_count,
                                   node_count).to(device)
    spatial_attention = torch.nn.functional.softmax(spatial_attention, dim=1)
    model.train()

    T = len_input
    x_seq = torch.zeros([batch_size, node_count, node_features, T]).to(device)
    target_seq = torch.zeros([batch_size, node_count, T]).to(device)
    for b in range(batch_size):
        for t in range(T):
            x, edge_index = create_mock_data(node_count, edge_per_node,
                                             node_features)
            x_seq[b, :, :, t] = x
            target = create_mock_target(node_count, num_classes)
            target_seq[b, :, t] = target

    # Only the last generated graph is used below. Assumes
    # create_mock_data returns edge_index as a tensor - TODO confirm.
    edge_index = edge_index.to(device)

    shuffle = True
    train_dataset = torch.utils.data.TensorDataset(x_seq, target_seq)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=shuffle)

    for batch_data in train_loader:
        # The targets are not needed for the shape check.
        encoder_inputs, _ = batch_data
        data = Data(edge_index=edge_index, edge_attr=None,
                    num_nodes=node_count)
        lambda_max = LaplacianLambdaMax()(data).lambda_max

        outputs = []
        for time_step in range(T):
            outputs.append(
                torch.unsqueeze(
                    model(encoder_inputs[:, :, :, time_step], edge_index,
                          spatial_attention, lambda_max=lambda_max), -1))
        spatial_gcn = torch.nn.functional.relu(
            torch.cat(outputs, dim=-1))  # (b,N,F,T)

        assert spatial_gcn.shape == (batch_size, node_count, nb_chev_filter, T)
def forward(
    self,
    X: torch.FloatTensor,
    edge_index: Union[torch.LongTensor, List[torch.LongTensor]],
) -> torch.FloatTensor:
    """
    Making a forward pass with the ASTGCN block.

    Arg types:
        * **X** (PyTorch Float Tensor) - Node features for T time periods,
          with shape (B, N_nodes, F_in, T_in).
        * **edge_index** (LongTensor): Edge indices. Either a single tensor
          (static graph) or a list holding one tensor per time step
          (edges change over time).

    Return types:
        * **X** (PyTorch Float Tensor) - Hidden state tensor for all nodes,
          with shape (B, N_nodes, nb_time_filter, T_out).
    """
    # Example shape: (32, 307, 1, 12).
    batch_size, num_of_vertices, num_of_features, num_of_timesteps = X.shape

    def _lambda_max_for(edges):
        # The graph object is always built; the eigenvalue is only
        # computed when the normalization is not "sym", otherwise None
        # is passed on to the convolution.
        graph = Data(edge_index=edges, edge_attr=None,
                     num_nodes=num_of_vertices)
        if self._normalization == "sym":
            return None
        return LaplacianLambdaMax()(graph).lambda_max

    # Temporal attention, e.g. (32, 12, 12), is applied to the input
    # flattened to (B, N*F, T); the product is reshaped back to 4D.
    temporal_weights = self._temporal_attention(X)
    flattened = X.reshape(batch_size, -1, num_of_timesteps)
    reweighted = torch.matmul(flattened, temporal_weights).reshape(
        batch_size, num_of_vertices, num_of_features, num_of_timesteps
    )
    # Spatial attention, e.g. (32, 307, 307).
    spatial_weights = self._spatial_attention(reweighted)

    graph_is_static = not isinstance(edge_index, list)
    if graph_is_static:
        static_lambda_max = _lambda_max_for(edge_index)

    per_step = []
    for step in range(num_of_timesteps):
        if graph_is_static:
            step_edges, step_lambda_max = edge_index, static_lambda_max
        else:
            step_edges = edge_index[step]
            step_lambda_max = _lambda_max_for(step_edges)
        convolved = self._chebconv_attention(
            X[:, :, :, step], step_edges, spatial_weights,
            lambda_max=step_lambda_max,
        )
        per_step.append(torch.unsqueeze(convolved, -1))
    graph_features = F.relu(torch.cat(per_step, dim=-1))

    # (b,N,F,T) -> (b,F,N,T), e.g. (32, 307, 64, 12) -> (32, 64, 307, 12),
    # then the convolution along the time axis.
    time_features = self._time_convolution(graph_features.permute(0, 2, 1, 3))
    # The shortcut goes through the same axis swap, e.g. (32, 64, 307, 12).
    shortcut = self._residual_convolution(X.permute(0, 2, 1, 3))

    # Sum, permute to e.g. (32, 12, 307, 64) for layer normalization, then
    # permute to the (b,N,F,T) output layout, e.g. (32, 307, 64, 12).
    normalized = self._layer_norm(
        F.relu(shortcut + time_features).permute(0, 3, 2, 1)
    )
    return normalized.permute(0, 2, 3, 1)
model.train() T = len_input x_seq = torch.zeros([batch_size, node_count, node_features, T]).to(device) target_seq = torch.zeros([batch_size, node_count, T]).to(device) for b in range(batch_size): for t in range(T): x, edge_index = create_mock_data(node_count, edge_per_node, node_features) x_seq[b, :, :, t] = x target = create_mock_target(node_count, num_classes) target_seq[b, :, t] = target shuffle = True train_dataset = torch.utils.data.TensorDataset(x_seq, target_seq) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle) for batch_data in train_loader: encoder_inputs, labels = batch_data data = Data(edge_index=edge_index, edge_attr=None, num_nodes=node_count) lambda_max = LaplacianLambdaMax()(data).lambda_max outputs = [] for time_step in range(T): outputs.append( torch.unsqueeze( model(encoder_inputs[:, :, :, time_step], edge_index, spatial_attention, lambda_max=lambda_max), -1)) spatial_gcn = torch.nn.functional.relu(torch.cat( outputs, dim=-1)) # (b,N,F,T) # (b,N,F,T)