def forward(self, data):
    """ data.x: (num_nodes, num_features) """
    x, edge_index = data.x, data.edge_index
    batch_info = create_batch_info(data, self.edge_counter)

    # Create the context matrix: either the raw node features, or a one-hot
    # encoding of the node coloring with shape (num_nodes, n_colors, 1)
    if self.use_x:
        u = x
    else:
        u = data.x.new_zeros((data.num_nodes, batch_info['n_colors']))
        u.scatter_(1, data.coloring, 1)
        u = u[..., None]

    # Forward pass: accumulate a graph-level readout after each convolution
    out = self.no_prop(u, batch_info)
    u = self.initial_lin(u)
    for i, (conv, bn, extractor) in enumerate(
            zip(self.convs, self.batch_norm_list, self.feature_extractors)):
        if self.use_batch_norm and i > 0:
            u = bn(u)
        u = conv(u, edge_index, batch_info)
        global_features = extractor.forward(u, batch_info)
        out += global_features / len(self.convs)

    # Two-layer MLP with dropout and a residual connection
    out = torch.relu(self.after_conv(out)) + out
    out = F.dropout(out, p=self.dropout_prob, training=self.training)
    out = self.final_lin(out)
    return F.log_softmax(out, dim=-1)
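# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): how the one-hot local
# context `u` is built from a node coloring on a toy graph. The shapes mirror
# the scatter_ call in forward(); `coloring` and `n_colors` are stand-ins for
# data.coloring and batch_info['n_colors'].
import torch

num_nodes, n_colors = 5, 3
coloring = torch.tensor([[0], [1], [2], [0], [1]])  # one color id per node
u = torch.zeros(num_nodes, n_colors)
u.scatter_(1, coloring, 1)                          # one-hot over the colors
u = u[..., None]                                    # shape (num_nodes, n_colors, 1)
assert u.shape == (5, 3, 1)
# ---------------------------------------------------------------------------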
def forward(self, data):
    """ data.x: (num_nodes, num_features) """
    x, edge_index, batch, batch_size = data.x, data.edge_index, data.batch, data.num_graphs
    batch_info = create_batch_info(data, self.edge_counter)

    # Create the context matrix: one-hot encoding of the node coloring,
    # shape (num_nodes, n_colors, 1)
    u = data.x.new_zeros((data.num_nodes, batch_info['n_colors']))
    u.scatter_(1, data.coloring, 1)
    u = u[..., None]

    # Map x to u: nonzero entries of x[:, 0] (shortest-path ids) flag nodes whose
    # color channel is set to 1 on every node of the same graph; x[:, 1]
    # (a Laplacian feature) is written into each node's own color channel
    shortest_path_ids = x[:, 0]
    lap_feat = x[:, 1]
    u_shortest_path = torch.zeros_like(u)
    u_lap_feat = torch.zeros_like(u)
    non_zero = shortest_path_ids.nonzero()[:, 0]
    nonzero_batch = batch_info['batch'][non_zero]
    nonzero_color = batch_info['coloring'][non_zero][:, 0]
    for b, c in zip(nonzero_batch, nonzero_color):
        u_shortest_path[batch == b, c] = 1
    for i, feat in enumerate(lap_feat):
        u_lap_feat[i, batch_info['coloring'][i]] = feat
    u = torch.cat((u, u_shortest_path, u_lap_feat), dim=2)

    # Forward pass: after each convolution, extract features and update a GRU
    # hidden state with them
    u = self.initial_lin_u(u)
    hidden_state = None
    for i, (conv, bn_u) in enumerate(zip(self.convs, self.batch_norm_u)):
        if i > 0:
            u = bn_u(u)
        u = conv(u, edge_index, batch_info, self.debug_model)
        extracted = self.extractor(x, u, batch_info)[None, :, :]
        hidden_state = self.gru(extracted, hidden_state)[1]

    # Compute the final representation: node-level and graph-level outputs
    out = hidden_state[0, :, :]
    nodes_out = self.final_node(out)
    after_set2set = self.set2set(out, batch_info['batch'])
    graph_out = self.final_graph(after_set2set)
    return nodes_out, graph_out
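# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not repo code): the second Python loop above,
# which writes a scalar node feature into the channel of the node's own color,
# can equivalently be expressed with advanced indexing. Toy tensors below;
# `coloring` stands in for batch_info['coloring'][:, 0].
import torch

num_nodes, n_colors = 4, 3
coloring = torch.tensor([0, 2, 1, 0])             # color id of each node
lap_feat = torch.tensor([0.5, -1.0, 2.0, 3.0])    # one scalar feature per node
u_lap_feat = torch.zeros(num_nodes, n_colors, 1)
u_lap_feat[torch.arange(num_nodes), coloring, 0] = lap_feat
# Each row now holds its feature at the position of that node's color.
# ---------------------------------------------------------------------------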
def forward(self, data):
    """ data.x: (num_nodes, num_node_features) """
    x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

    # Compute information about the batch
    batch_info = create_batch_info(data, self.edge_counter)

    # Create the context matrix: raw node features, a mapping of x into the
    # colored context, or a one-hot encoding of the node coloring
    if self.use_x:
        assert x is not None
        u = x
    elif self.map_x_to_u:
        u = map_x_to_u(data, batch_info)
    else:
        u = data.x.new_zeros((data.num_nodes, batch_info['n_colors']))
        u.scatter_(1, data.coloring, 1)
        u = u[..., None]

    # Forward pass: accumulate a graph-level readout after each convolution
    out = self.no_prop(u, batch_info)
    u = self.initial_lin(u)
    for i in range(len(self.convs)):
        conv = self.convs[i]
        bn = self.batch_norm_list[i]
        extractor = self.feature_extractor if self.shared_extractor else self.feature_extractors[i]
        if self.use_batch_norm and i > 0:
            u = bn(u)
        u = conv(u, edge_index, edge_attr, batch_info) + (u if self.residual else 0)
        global_features = extractor.forward(u, batch_info)
        out += global_features / len(self.convs)

    # Final MLP with a residual connection; the model outputs one scalar per graph
    out = self.final_lin(torch.relu(self.after_conv(out)) + out)
    assert out.shape[1] == 1
    return out[:, 0]
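# ---------------------------------------------------------------------------
# Illustrative sketch (assumption about the constructor, not the repo's code):
# for the shared/per-layer branch in the loop above to work, the module needs
# either a single `feature_extractor` or a `feature_extractors` ModuleList.
# `hidden` and the nn.Linear stand-ins are placeholders for the real extractor.
import torch.nn as nn

class ReadoutHead(nn.Module):
    def __init__(self, num_layers: int, hidden: int, shared_extractor: bool):
        super().__init__()
        self.shared_extractor = shared_extractor
        if shared_extractor:
            # One extractor reused after every convolution
            self.feature_extractor = nn.Linear(hidden, hidden)
        else:
            # One extractor per convolution layer
            self.feature_extractors = nn.ModuleList(
                nn.Linear(hidden, hidden) for _ in range(num_layers))
# ---------------------------------------------------------------------------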