def forward(self, data, conv_train=False):
    """Run the conv stack, the GNN-only head and the SOM-based readout.

    :param data: batch object; ``data.x``, ``data.edge_index`` and ``data.batch`` are read
    :param conv_train: when truthy, stop right after the GNN-only head and return
        ``(None, None, gnn_out)``
    :return: tuple ``(h, h_conv, gnn_out)``: ``h`` is the output of the SOM readout
        branch, ``h_conv`` the feature-wise concatenation of the three conv layer
        outputs (node level), ``gnn_out`` the output of the GNN-only head
    """
    edge_index = data.edge_index
    batch = data.batch

    def pool_all(node_feats):
        # mean, sum and max graph-level pooling, concatenated along the feature axis
        return torch.cat(
            [gap(node_feats, batch), gadd(node_feats, batch), gmp(node_feats, batch)],
            dim=1,
        )

    # --- convolutional stack: conv -> act -> norm, dropout between layers only ---
    conv_stack = (
        (self.conv1, self.act1, self.norm1),
        (self.conv2, self.act2, self.norm2),
        (self.conv3, self.act3, self.norm3),
    )
    layer_feats = []
    layer_in = data.x
    for conv, act, norm in conv_stack:
        if layer_feats:
            # the previous layer's (un-dropped) output feeds this layer through dropout
            layer_in = self.dropout(layer_feats[-1])
        layer_feats.append(norm(act(conv(layer_in, edge_index))))
    h_conv = torch.cat(layer_feats, dim=1)

    # --- GNN-only head ---
    gnn_out = self.out_fun(self.lin_GNN(pool_all(h_conv)))
    if conv_train:
        return None, None, gnn_out

    # --- SOM stage: every SOM runs before any readout conv ---
    som_outputs = []
    for som, feats in zip((self.som1, self.som2, self.som3), layer_feats):
        _, _, som_out = som(feats)
        som_outputs.append(som_out)

    # --- readout convs on the SOM outputs (the conv-stack activations are reused) ---
    readout_stack = (
        (self.out_conv1, self.act1, self.out_norm1),
        (self.out_conv2, self.act2, self.out_norm2),
        (self.out_conv3, self.act3, self.out_norm3),
    )
    readout_feats = [
        norm(act(conv(som_out, edge_index)))
        for (conv, act, norm), som_out in zip(readout_stack, som_outputs)
    ]
    som_out_conv = torch.cat(readout_feats, dim=1)

    # --- final MLP: norm, two (linear -> act -> dropout) layers, output layer ---
    h = self.out_norm4(pool_all(som_out_conv))
    for lin in (self.lin_out1, self.lin_out2):
        h = self.dropout(self.out_act(lin(h)))
    return self.out_fun(self.lin_out3(h)), h_conv, gnn_out
def forward(self, data, hidden_layer_aggregator=None):
    """Compute the model output from powers of the normalized graph Laplacian.

    Builds ``H = [X, L X, L^2 X, ..., L^(k-1) X]`` with ``k = self.max_k`` and ``L``
    the symmetrically normalized Laplacian of ``data.edge_index``. ``H`` is then
    passed through ``self.lin`` (together with ``self.xhi_layer_mask``),
    ``self.reservoir_act_fun`` and ``self.bn_hidden_rec``, pooled per graph
    (mean, sum, max) and fed to the output stage selected by ``self.output``.

    :param data: batch object; ``data.x``, ``data.edge_index`` and ``data.batch`` are read
    :param hidden_layer_aggregator: unused; kept for interface compatibility
    :return: the result of the selected output stage
    :raises ValueError: if ``self.output`` names an unknown output stage
    """
    X = data.x
    k = self.max_k
    num_nodes = X.size(0)

    # Pass num_nodes explicitly: otherwise get_laplacian infers the node count from
    # edge_index, and trailing isolated nodes would get no diagonal entries.
    L_edge_index, L_values = get_laplacian(
        data.edge_index, normalization="sym", num_nodes=num_nodes
    )
    # sparse_coo_tensor keeps the device of its inputs (the legacy
    # torch.sparse.FloatTensor constructor is deprecated and CPU-only).
    L = torch.sparse_coo_tensor(
        L_edge_index, L_values, (num_nodes, num_nodes)
    ).to_dense()

    # Each term is L applied to the previous one (L^i X = L (L^(i-1) X)), instead of
    # recomputing a dense matrix power from scratch at every step.
    H = [X]
    propagated = X
    for _ in range(k - 1):
        propagated = torch.mm(L, propagated)
        H.append(propagated)

    H = self.lin(torch.cat(H, dim=1), self.xhi_layer_mask)
    H = self.reservoir_act_fun(H)
    H = self.bn_hidden_rec(H)

    # graph-level representation: mean, sum and max pooling concatenated
    H = torch.cat(
        [gap(H, data.batch), gadd(H, data.batch), gmp(H, data.batch)], dim=1
    )

    if self.output == "funnel" or self.output is None:
        return self.funnel_output(H)
    if self.output == "one_layer":
        return self.one_layer_out(H)
    if self.output == "restricted_funnel":
        return self.restricted_funnel_output(H)
    # Explicit raise: an assert is stripped under ``python -O`` and the method would
    # then silently return None.
    raise ValueError(f"error in output stage: unknown output {self.output!r}")
def forward(self, data):
    '''
    model forward method
    :param data: current batch; ``data.x``, ``data.edge_index`` and ``data.batch`` are read
    :return: the model output given the batch (result of the output stage selected
        by ``self.output``)
    :raises ValueError: if ``self.output`` names an unknown output stage
    '''
    edge_index = data.edge_index
    X = self.norm0(self.conv0(data.x, edge_index))
    k = self.k
    num_nodes = X.shape[0]

    # Dense adjacency matrix A with a unit value per edge. The values are created
    # directly on the device of the indices and with X's dtype, so A can be
    # multiplied with X without extra transfers. sparse_coo_tensor replaces the
    # deprecated, CPU-only torch.sparse.FloatTensor constructor.
    edge_values = torch.ones(
        edge_index.size(1), dtype=X.dtype, device=edge_index.device
    )
    A = torch.sparse_coo_tensor(
        edge_index, edge_values, (num_nodes, num_nodes)
    ).to_dense()

    # H = [X, A X, A^2 X, ..., A^(k-1) X]; each term is A applied to the previous
    # one instead of recomputing a matrix power from scratch at every step.
    H = [X]
    propagated = X
    for _ in range(k - 1):
        propagated = torch.mm(A, propagated)
        H.append(propagated)

    # project H by W
    H = self.bn_hidden_rec(self.lin(torch.cat(H, dim=1), self.xhi_layer_mask))

    # compute the graph layer representation using 3 different pooling strategies
    H = torch.cat(
        [gap(H, data.batch), gadd(H, data.batch), gmp(H, data.batch)], dim=1
    )

    # compute the readout
    if self.output == "funnel" or self.output is None:
        return self.funnel_output(H)
    if self.output == "restricted_funnel":
        return self.restricted_funnel_output(H)
    # Explicit raise: an assert is stripped under ``python -O`` and the method would
    # then silently return None.
    raise ValueError(f"error in output stage: unknown output {self.output!r}")
def readout_fw(self, data):
    """Run only the readout part of the model on precomputed reservoir states.

    ``data.reservoir`` is passed through ``self.reservoir_act_fun`` and
    ``self.bn_hidden_rec``, pooled per graph (mean, sum, max) and fed to the output
    stage selected by ``self.output``.

    :param data: batch object; ``data.reservoir`` and ``data.batch`` are read
    :return: the result of the selected output stage
    :raises ValueError: if ``self.output`` names an unknown output stage
    """
    H = self.bn_hidden_rec(self.reservoir_act_fun(data.reservoir))

    # graph-level representation: mean, sum and max pooling concatenated
    H = torch.cat(
        [gap(H, data.batch), gadd(H, data.batch), gmp(H, data.batch)], dim=1
    )

    if self.output == "funnel" or self.output is None:
        return self.funnel_output(H)
    if self.output == "one_layer":
        return self.one_layer_out(H)
    if self.output == "restricted_funnel":
        return self.restricted_funnel_output(H)
    if self.output == "svm":
        return self.svm_output(H)
    # Explicit raise: an assert is stripped under ``python -O`` and the method would
    # then silently return None.
    raise ValueError(f"error in output stage: unknown output {self.output!r}")