def forward(self, data):
    """Embed all nodes, sum-pool per graph and return log-probabilities.

    Reads ``node_feature``, ``edge_index`` and ``batch`` from ``data``.  The
    input of each conv layer depends on ``self.args.skip``:

    * ``'learnable'`` -- every earlier layer output, scaled by a sigmoid of
      the matching ``self.learnable_skip`` entry and flattened per node;
    * ``'all'`` -- the plain concatenation of every earlier layer output;
    * anything else -- only the previous layer's output.

    Returns ``log_softmax`` (dim 1) of ``self.post_mp`` applied to the
    sum-pooled concatenation of all layer outputs.
    """
    x, edge_index, batch = data.node_feature, data.edge_index, data.batch
    x = self.pre_mp(x)
    num_nodes = x.size(0)
    skip = self.args.skip

    # all_emb stacks the layer outputs along dim 1 (one slot per layer);
    # emb concatenates the same outputs along the feature dim.
    all_emb = x.unsqueeze(1)
    emb = x

    for i, conv in enumerate(self.convs):
        if skip == 'learnable':
            skip_vals = self.learnable_skip[i, :i + 1].unsqueeze(0).unsqueeze(-1)
            curr_emb = (all_emb * torch.sigmoid(skip_vals)).view(num_nodes, -1)
            # The gated input must be the one that is kept: previously this
            # result was immediately overwritten by conv(emb, ...), which made
            # the learnable skip weights a no-op.
            x = conv(curr_emb, edge_index)
        elif skip == 'all':
            x = conv(emb, edge_index)
        else:
            x = conv(x, edge_index)

        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        emb = torch.cat((emb, x), 1)
        if skip == 'learnable':
            all_emb = torch.cat((all_emb, x.unsqueeze(1)), 1)

    emb = pyg_nn.global_add_pool(emb, batch)
    emb = self.post_mp(emb)
    return F.log_softmax(emb, dim=1)
def forward(self, data):
    """Run the graph conv stack, sum-pool per graph and apply the FC head.

    When ``self.net_type`` is ``'gmmcn'`` each layer additionally receives
    ``data.pseudo`` and all inputs are cast (float / long / float); otherwise
    layers are called with ``(x, data.edge_index)`` only.

    Returns:
        ``(embedding, output)`` -- the pooled representation after
        ``self.drop`` and the result of ``self.fc_out`` (no sigmoid is applied
        here; per the original comment it lives in the loss function).
    """
    uses_pseudo = self.net_type == 'gmmcn'

    # Graph convolutional part
    h = data.x
    for gcn_layer in self.gcn:
        if uses_pseudo:
            conv_out = gcn_layer(h.float(), data.edge_index.long(),
                                 data.pseudo.float())
        else:
            conv_out = gcn_layer(h, data.edge_index)
        h = F.relu(conv_out)

    # Global sum pooling followed by dropout
    embedding = self.drop(global_add_pool(h, data.batch))

    # Fully-connected part; the hidden layer is skipped when fc_dim <= 0
    hidden = embedding
    if self.fc_dim > 0:
        hidden = F.relu(self.fc(hidden))
    output = self.fc_out(hidden)
    return embedding, output
def test_permuted_global_pool():
    """Pooling a shuffled two-graph batch must match per-graph reductions."""
    n_first, n_second = 4, 6
    total = n_first + n_second

    x = torch.randn(total, 4)
    batch = torch.cat([torch.zeros(n_first),
                       torch.ones(n_second)]).to(torch.long)
    perm = torch.randperm(total)

    shuffled_x = x[perm]
    shuffled_batch = batch[perm]
    # Rows of each graph, gathered from the shuffled layout.
    rows_per_graph = [shuffled_x[shuffled_batch == graph] for graph in (0, 1)]

    cases = (
        (global_add_pool, lambda rows: rows.sum(dim=0)),
        (global_mean_pool, lambda rows: rows.mean(dim=0)),
        (global_max_pool, lambda rows: rows.max(dim=0)[0]),
    )
    for pool, reference in cases:
        out = pool(shuffled_x, shuffled_batch)
        assert out.size() == (2, 4)
        for graph, rows in enumerate(rows_per_graph):
            assert torch.allclose(out[graph], reference(rows))
def forward(self, data):
    """Apply ``conv1``, sum-pool nodes per graph, then ``gather_layer``.

    Reads ``x``, ``edge_index`` and ``batch`` from ``data`` and returns the
    output of ``self.gather_layer`` on the pooled node features.
    """
    node_feats, edges, graph_ids = data.x, data.edge_index, data.batch
    pooled = global_add_pool(self.conv1(node_feats, edges), graph_ids)
    return self.gather_layer(pooled)
def forward(self, data):
    """Two conv + batch-norm stages, sum pooling, then a 3-layer MLP head.

    Prints the raw inputs and the tensor shape after every stage (debug
    tracing kept from the original) and returns ``log_softmax`` over dim 1.
    """
    def _trace(tag, tensor):
        # Shape trace after each stage.
        print(tag, tensor.shape)

    x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
    for tag, value in (('x ', x),
                       ('edge_index ', edge_index),
                       ('edge_attr ', edge_attr)):
        print(tag, value)

    x = F.relu(self.nnconv1(x, edge_index, edge_attr))
    _trace('x1 ', x)
    x = self.bn1(x)
    _trace('x2 ', x)

    # NOTE(review): unlike nnconv1, nnconv2 is called without edge_attr --
    # confirm this matches the layer's signature.
    x = F.relu(self.nnconv2(x, edge_index))
    _trace('x3 ', x)
    x = self.bn2(x)
    _trace('x4 ', x)

    x = global_add_pool(x, data.batch)
    _trace('x5 ', x)

    x = F.relu(self.fc1(x))
    _trace('x6 ', x)
    x = F.relu(self.fc2(x))
    _trace('x8 ', x)
    x = F.dropout(x, p=0.2, training=self.training)
    _trace('x9 ', x)
    x = self.fc3(x)
    _trace('x10 ', x)

    x = F.log_softmax(x, dim=1)
    _trace('x11 ', x)
    return x
def forward(self, batched_data):
    """Encode nodes, run the conv stack with residual inputs, and predict.

    Each layer receives the previous layer's node states, ``edge_index``,
    ``edge_attr`` and the states of whichever earlier layers
    ``self.residual_connections`` lists under ``str(layer_idx)``.  The final
    states are concatenated with the encoded inputs, passed through the
    ``classifier_l * classifier_r`` product, sum-pooled per graph, and fed to
    the first ``self.max_seq_len`` heads of ``self.graph_pred_linear_list``.

    Returns:
        A list with one prediction per sequence position.
    """
    edge_index = batched_data.edge_index
    batch = batched_data.batch
    x = self.node_encoder(batched_data.x, batched_data.node_depth.view(-1, ))

    # states[k] holds the node states after layer k (states[0] = encoder out)
    states = [x]
    for layer_idx, _ in enumerate(self.layer_timesteps):
        residual_ids = self.residual_connections.get(str(layer_idx))
        if residual_ids is None:
            residual_states = []
        else:
            residual_states = [states[src] for src in residual_ids]

        states.append(self.convs[layer_idx](states[-1], edge_index,
                                            batched_data.edge_attr,
                                            residual_states))

    # NOTE(review): readout assumed to run once, after the layer loop.
    hx = torch.cat([states[-1], x], dim=-1)
    gated = self.classifier_l(hx) * self.classifier_r(hx)
    output = global_add_pool(gated, batch=batch)

    return [self.graph_pred_linear_list[pos](output)
            for pos in range(self.max_seq_len)]
def forward(self, data):
    """Fuse a graph branch and a ``data.target`` branch into one prediction.

    Graph branch: three conv -> ReLU -> batch-norm stages, sum pooling,
    ``fc1_xd`` with ReLU and dropout (p=0.2).  Target branch: embedding,
    flattened to ``1000 * 128`` columns, then ``fc1_xt``.  Both are
    concatenated and passed through two dense + ReLU + dropout stages and
    ``self.out``.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch
    target = data.target

    # Graph branch
    graph_stages = ((self.conv1, self.bn1),
                    (self.conv2, self.bn2),
                    (self.conv3, self.bn3))
    for conv, norm in graph_stages:
        x = norm(F.relu(conv(x, edge_index)))
    x = global_add_pool(x, batch)
    x = F.relu(self.fc1_xd(x))
    x = F.dropout(x, p=0.2, training=self.training)

    # Target branch: flatten the embedded sequence before the dense layer
    xt = self.embedding_xt(target).view(-1, 1000 * 128)
    xt = self.fc1_xt(xt)

    # Joint dense head
    xc = torch.cat((x, xt), 1)
    for dense in (self.fc1, self.fc2):
        xc = self.dropout(self.relu(dense(xc)))
    return self.out(xc)
def forward(self, z, edge_index, batch, x=None, edge_weight=None, node_id=None):
    """Score each graph in the batch from its first two nodes.

    Args:
        z: integer labels looked up in ``self.z_embedding``; when the lookup
            yields a 3-D tensor (several labels per node) the label
            embeddings are summed over dim 1.
        edge_index, edge_weight: forwarded to every conv layer.
        batch: graph id of each node.
        x: optional node features, concatenated to the label embedding when
            ``self.use_feature`` is set.
        node_id: optional ids for ``self.node_embedding``; its output is
            concatenated as well when both are available.

    Returns:
        ``self.lin2`` applied to the ReLU/dropout-activated ``self.lin1`` of
        the element-wise product of the two center-node embeddings.
    """
    z_emb = self.z_embedding(z)
    if z_emb.ndim == 3:  # in case z has multiple integer labels
        z_emb = z_emb.sum(dim=1)

    if self.use_feature and x is not None:
        x = torch.cat([z_emb, x.to(torch.float)], 1)
    else:
        x = z_emb
    if self.node_embedding is not None and node_id is not None:
        x = torch.cat([x, self.node_embedding(node_id)], 1)

    for conv in self.convs[:-1]:
        x = conv(x, edge_index, edge_weight)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
    x = self.convs[-1](x, edge_index, edge_weight)

    # Center pooling: take the first node of every graph and the node right
    # after it.  (The former ``if True: ... else:`` sum-pooling alternative
    # was unreachable and has been removed.)
    # NOTE(review): assumes the two target nodes are stored first in each
    # graph -- confirm against the data pipeline.
    _, center_indices = np.unique(batch.cpu().numpy(), return_index=True)
    x = x[center_indices] * x[center_indices + 1]

    x = F.relu(self.lin1(x))
    x = F.dropout(x, p=self.dropout, training=self.training)
    return self.lin2(x)
def forward(self, batch):
    """Average graph embeddings over ``self.num_perm`` random id assignments.

    For every permutation, each graph's nodes get a randomly permuted slice
    of ``self.node_ids`` appended to their features; the augmented features
    go through ``self.node_embedder`` and are sum-pooled per graph.  The
    pooled results are averaged over all permutations.

    Returns:
        The averaged graph embeddings, or ``None`` when ``self.num_perm`` is 0.
    """
    x, edge_index, batch_ids = batch.x, batch.edge_index, batch.batch

    # Everything below is identical for every permutation, so compute it once
    # instead of once per (permutation, graph) pair.
    num_graphs = torch.max(batch_ids).item() + 1
    nodes_per_graph = [(batch_ids == graph).nonzero().squeeze(1)
                       for graph in range(num_graphs)]
    # Plain attribute access replaces the former self.__getattr__(f"node_ids").
    # NOTE(review): assumes node_ids is a registered buffer/parameter with
    # fixed_size columns -- confirm in __init__.
    id_table = self.node_ids.to(x.device)

    out = None
    for _ in range(self.num_perm):
        new_x = torch.empty(x.size(0), x.size(1) + self.fixed_size,
                            device=x.device)
        for node_indices in nodes_per_graph:
            graph_size = node_indices.size(0)
            perm = torch.randperm(graph_size)
            # Tile the id table until it covers the graph, then trim.
            node_ids = id_table.repeat(graph_size // self.fixed_size + 1,
                                       1)[:graph_size]
            new_x[node_indices] = torch.cat(
                [x[node_indices], node_ids[perm, :]], dim=1)

        h_v = self.node_embedder.forward(new_x, edge_index)
        h_g = global_add_pool(h_v, batch_ids)
        if out is None:
            out = h_g / self.num_perm
        else:
            out += h_g / self.num_perm
    return out
def forward(self, x, edge_index, batch, pretr=False):
    """Three-conv encoder with concatenated layer outputs and two heads.

    The outputs of ``conv1`` (ReLU-activated), ``conv2`` and ``conv3`` are
    concatenated, passed through ``fc1`` + ReLU and sum-pooled per graph.

    Returns:
        ``log_softmax(fc3(pooled))`` when ``pretr`` is falsy, otherwise the
        tuple ``(fc2(pooled), log_softmax(fc3(pooled)))``.
    """
    h1 = F.relu(self.conv1(x, edge_index))
    h2 = self.conv2(h1, edge_index)
    h3 = self.conv3(h2, edge_index)

    node_repr = F.relu(self.fc1(torch.cat([h1, h2, h3], dim=1)))
    graph_repr = global_add_pool(node_repr, batch)

    if not pretr:
        return F.log_softmax(self.fc3(graph_repr), dim=-1)

    out1 = self.fc2(graph_repr)
    log_probs = F.log_softmax(self.fc3(graph_repr), dim=-1)
    return out1, log_probs
def forward(self, x, edge_index, batch, pretr=False):
    """Five conv -> ReLU -> batch-norm stages, sum pooling and a linear head.

    After pooling, ``fc1`` + ReLU and dropout (p=0.5) are applied; the final
    projection is ``fc3`` when ``pretr`` is truthy and ``fc2`` otherwise.
    Returns ``log_softmax`` over the last dim.
    """
    stages = (
        (self.conv1, self.bn1),
        (self.conv2, self.bn2),
        (self.conv3, self.bn3),
        (self.conv4, self.bn4),
        (self.conv5, self.bn5),
    )
    for conv, norm in stages:
        x = norm(F.relu(conv(x, edge_index)))

    x = global_add_pool(x, batch)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, p=0.5, training=self.training)

    # NOTE(review): the log_softmax is taken to apply to both heads; confirm
    # it was not meant for the non-pretraining branch only.
    head = self.fc3 if pretr else self.fc2
    return F.log_softmax(head(x), dim=-1)
def forward(self, batched_data):
    """Encode nodes, run the conv stack, pool per graph and predict.

    The last layer's node states are concatenated with the encoded inputs,
    passed through the ``classifier_l * classifier_r`` product and sum-pooled.

    Returns:
        ``self.graph_pred_linear(pooled)`` when ``self.num_class > 0``;
        otherwise a list with one prediction per position, produced by the
        first ``self.max_seq_len`` heads of ``self.graph_pred_linear_list``.
    """
    edge_index, batch = batched_data.edge_index, batched_data.batch
    x = self.node_encoder(batched_data.x, batched_data.node_depth.view(-1, ))

    # states[k] holds the node states after layer k (states[0] = encoder out)
    states = [x]
    for layer_idx, _ in enumerate(self.layer_timesteps):
        states.append(self.convs[layer_idx](states[-1], edge_index))

    # NOTE(review): readout assumed to run once, after the layer loop.
    hx = torch.cat([states[-1], x], dim=-1)
    gated = self.classifier_l(hx) * self.classifier_r(hx)
    output = global_add_pool(gated, batch=batch)

    if self.num_class > 0:
        return self.graph_pred_linear(output)
    return [self.graph_pred_linear_list[pos](output)
            for pos in range(self.max_seq_len)]
def pool_func(x, batch, mode="sum"):
    """Pool node features ``x`` per graph according to ``mode``.

    Args:
        x: node features.
        batch: graph id of each node.
        mode: ``"sum"``, ``"mean"`` or ``"max"``.

    Returns:
        The result of ``global_add_pool``, ``global_mean_pool`` or
        ``global_max_pool`` respectively.

    Raises:
        ValueError: if ``mode`` is not one of the supported names.  An unknown
            mode used to fall through and silently return ``None``.
    """
    if mode == "sum":
        return global_add_pool(x, batch)
    if mode == "mean":
        return global_mean_pool(x, batch)
    if mode == "max":
        return global_max_pool(x, batch)
    raise ValueError(
        f"Unsupported pooling mode {mode!r}; expected 'sum', 'mean' or 'max'")
def forward(self, data):
    """Build a radius graph from positions, message-pass, pool and regress.

    The call signature used for every layer depends on ``self.namemodel``
    (``"DeepSet"``, ``"PointNet"``, ``"MetaNet"`` or the generic default).
    The pre-activation output of the last processed layer is stored in
    ``self.h``; the concatenated sum/mean/max pools plus the global feature
    ``u`` are stored in ``self.pooled`` and passed to ``self.lin``.
    """
    x, pos, batch, u = data.x, data.pos, data.batch, data.u

    # Edges connect neighbours within a radius.
    # NOTE(review): the radius is taken from self.k_nn despite its name.
    edge_index = radius_graph(pos, r=self.k_nn, batch=batch, loop=self.loop)

    # Message passing
    # NOTE(review): storing self.h and applying ReLU are assumed to happen
    # after every layer (inside the loop).
    for layer in self.layers:
        model = self.namemodel
        if model == "DeepSet":
            x = layer(x)
        elif model == "PointNet":
            x = layer(x=x, pos=pos, edge_index=edge_index)
        elif model == "MetaNet":
            x, _unused, u = layer(x, edge_index, None, u, batch)
        else:
            x = layer(x=x, edge_index=edge_index)
        self.h = x
        x = x.relu()

    # Mix different global pooling operators with the global feature u
    pooled_parts = [pool(x, batch) for pool in
                    (global_add_pool, global_mean_pool, global_max_pool)]
    self.pooled = torch.cat(pooled_parts + [u], dim=1)

    # Final linear layer
    return self.lin(self.pooled)
def forward(self, data):
    """Five conv -> ReLU -> batch-norm stages, sum pooling and a 2-layer head.

    Only the last stage's output is pooled.  Returns the ``fc2`` output
    flattened with ``view(-1)``.
    """
    stages = (
        (self.conv1, self.bn1),
        (self.conv2, self.bn2),
        (self.conv3, self.bn3),
        (self.conv4, self.bn4),
        (self.conv5, self.bn5),
    )
    h = data.x
    for conv, norm in stages:
        h = norm(F.relu(conv(h, data.edge_index)))

    pooled = global_add_pool(h, data.batch)
    hidden = F.relu(self.fc1(pooled))
    return self.fc2(hidden).view(-1)
def forward(self, data):
    """Fuse a five-stage graph branch with a convolved ``data.target`` branch.

    Graph branch: five conv -> ReLU -> batch-norm stages, sum pooling,
    ``fc1_xd`` with ReLU and dropout (p=0.2).  Target branch: embedding,
    ``conv_xt_1``, flattening to ``32 * 121`` columns, then ``fc1_xt``.  The
    concatenation goes through two dense + ReLU + dropout stages and
    ``self.out``.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch
    target = data.target

    # Graph branch
    graph_stages = (
        (self.conv1, self.bn1),
        (self.conv2, self.bn2),
        (self.conv3, self.bn3),
        (self.conv4, self.bn4),
        (self.conv5, self.bn5),
    )
    graph_repr = x
    for conv, norm in graph_stages:
        graph_repr = norm(F.relu(conv(graph_repr, edge_index)))
    graph_repr = global_add_pool(graph_repr, batch)
    graph_repr = F.relu(self.fc1_xd(graph_repr))
    graph_repr = F.dropout(graph_repr, p=0.2, training=self.training)

    # Target branch
    conv_xt = self.conv_xt_1(self.embedding_xt(target))
    xt = self.fc1_xt(conv_xt.view(-1, 32 * 121))

    # Joint dense head
    xc = torch.cat((graph_repr, xt), 1)
    for dense in (self.fc1, self.fc2):
        xc = self.dropout(self.relu(dense(xc)))
    return self.out(xc)
def forward(self, batched_data):
    """Message passing with one virtual node per graph.

    Before every layer the graph's virtual-node embedding is added to each of
    its nodes; after every layer but the last, the virtual node is updated
    from the sum-pooled node states through ``self.mlp_virtualnode_list``.

    Returns:
        Node representations: the last layer's output for ``JK == "last"``,
        or the sum of the input embedding and *all* layer outputs for
        ``JK == "sum"``.

    Raises:
        ValueError: for any other ``self.JK`` value (this used to surface as
            an unbound local variable).
    """
    x, edge_index, edge_attr, batch = (batched_data.x, batched_data.edge_index,
                                       batched_data.edge_attr,
                                       batched_data.batch)

    ### virtual node embeddings for graphs
    # NOTE(review): the graph count is read from the last batch entry, which
    # assumes graph ids are sorted -- confirm against the loader.
    num_graphs = batch[-1].item() + 1
    virtualnode_embedding = self.virtualnode_embedding(
        torch.zeros(num_graphs, dtype=edge_index.dtype,
                    device=edge_index.device))

    h_list = [self.atom_encoder(x)]
    last_layer = self.num_layer - 1
    for layer in range(self.num_layer):
        ### add message from virtual nodes to graph nodes
        h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

        ### Message passing among graph nodes
        h = self.convs[layer](h_list[layer], edge_index, edge_attr)
        h = self.batch_norms[layer](h)
        if layer != last_layer:
            # no ReLU on the last layer
            h = F.relu(h)
        h = F.dropout(h, self.drop_ratio, training=self.training)

        if self.residual:
            h = h + h_list[layer]
        h_list.append(h)

        ### update the virtual nodes
        if layer < last_layer:
            ### add message from graph nodes to virtual nodes
            virtualnode_embedding_temp = global_add_pool(
                h_list[layer], batch) + virtualnode_embedding
            ### transform virtual nodes using MLP
            update = F.dropout(
                self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                self.drop_ratio, training=self.training)
            if self.residual:
                virtualnode_embedding = virtualnode_embedding + update
            else:
                virtualnode_embedding = update

    ### Different implementations of Jk-concat
    if self.JK == "last":
        return h_list[-1]
    if self.JK == "sum":
        # h_list has num_layer + 1 entries; the former range(num_layer) loop
        # left out the final layer's output.
        return sum(h_list)
    raise ValueError(f"Unsupported JK mode: {self.JK!r}")
def forward(self, x, edge_index, edge_attr, batch):
    """Embed atoms, refine a per-graph embedding, and predict.

    Atom stage: ``lin1`` + leaky ReLU, then one conv/GRU update per pair in
    ``atom_convs`` / ``atom_grus`` (only the first conv receives
    ``edge_attr``).  Molecule stage: the sum-pooled atom states are refined
    for ``self.num_timesteps`` rounds with ``mol_conv`` / ``mol_gru`` over
    edges that link every node to its graph id.  Returns ``lin2`` of the
    dropout-regularised graph embedding.
    """
    def _gated_update(message, state, gru):
        # ELU -> dropout -> GRU -> ReLU, shared by both stages.
        message = F.elu_(message)
        message = F.dropout(message, p=self.dropout, training=self.training)
        return gru(message, state).relu_()

    # Atom embedding
    x = F.leaky_relu_(self.lin1(x))
    x = _gated_update(self.atom_convs[0](x, edge_index, edge_attr), x,
                      self.atom_grus[0])
    for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
        x = _gated_update(conv(x, edge_index), x, gru)

    # Molecule embedding: edges run from node i to its graph id batch[i]
    node_ids = torch.arange(batch.size(0), device=batch.device)
    mol_edge_index = torch.stack([node_ids, batch], dim=0)

    out = global_add_pool(x, batch).relu_()
    for _ in range(self.num_timesteps):
        out = _gated_update(self.mol_conv((x, out), mol_edge_index), out,
                            self.mol_gru)

    # Predictor
    out = F.dropout(out, p=self.dropout, training=self.training)
    return self.lin2(out)
def get_distribution_parameters(self, node_embeddings, batch):
    """Map node embeddings to the parameters of the output distribution.

    The embeddings go through ``node_transform``, optionally ``aggregate``
    (with ``batch``) when it is not ``None``, then ``final_transform`` and
    ``output_activation``.

    Returns:
        * ``(n, p)`` when ``'binomial'`` is in ``self.output_type``: ``n`` is
          the number of nodes per graph (repeated per expert) and ``p`` the
          sigmoid of the second output channel; the first channel is unused.
        * ``(mu, var)`` when ``'gaussian'`` is in ``self.output_type``: each
          ``? x no_experts x dim_target``, with ``var`` made positive by
          softplus plus ``1e-8``.

    Raises:
        ValueError: if ``self.output_type`` names neither family (this used
            to surface as an unbound local variable).
    """
    hidden = self.node_transform(node_embeddings)
    if self.aggregate is not None:
        hidden = self.aggregate(hidden, batch)
    out = self.output_activation(self.final_transform(hidden))

    if 'binomial' in self.output_type:
        params = torch.reshape(out, [-1, self.no_experts, 2])  # ? x no_experts x K
        p = torch.sigmoid(params[:, :, 1])
        # Counting ones per graph yields the number of nodes in each graph.
        ones = torch.ones(node_embeddings.shape[0], self.no_experts,
                          device=node_embeddings.device)
        n = global_add_pool(ones, batch)
        return n, p

    if 'gaussian' in self.output_type:
        # Assume isotropic gaussians
        params = torch.reshape(
            out, [-1, self.no_experts, 2, self.dim_target])  # ? x no_experts x 2 x F
        mu, var = params[:, :, 0, :], params[:, :, 1, :]
        var = torch.nn.functional.softplus(var) + 1e-8
        return mu, var  # each has shape ? x no_experts x F

    raise ValueError(f"Unsupported output_type: {self.output_type!r}")
def forward(self, x, edge_index, batch):
    """Two conv layers (ReLU in between), sum pooling and a linear classifier.

    Returns ``log_softmax`` over dim 1 of ``self.lin`` on the pooled features.
    """
    hidden = F.relu(self.conv1(x, edge_index))
    node_out = self.conv2(hidden, edge_index)
    logits = self.lin(global_add_pool(node_out, batch))
    return logits.log_softmax(dim=1)
def forward(self, data, batch_size=None, **kwargs):
    """Encode node inputs, run the GNN, pool per sample and classify.

    Args:
        data: provides ``x``, ``batch``, ``edge_index``, ``pos`` and, when
            ``batch_size`` is not given, a ``'size'`` entry whose sum is used
            as the batch size (so empty samples still get an output row).
        batch_size: explicit number of samples, forwarded as ``size`` to the
            pooling operator.
        **kwargs: accepted but unused.

    Returns:
        ``{'logits': ...}`` with the output of ``self.fc``.

    Raises:
        NotImplementedError: if ``self.global_aggr`` is neither ``'max'`` nor
            ``'sum'``.
    """
    x, batch, edge_index, pos = data.x, data.batch, data.edge_index, data.pos

    # infer real batch size, in case of empty samples
    if batch_size is None:
        batch_size = data['size'].sum().item()

    img_feature = self.encoder(x).flatten(1)
    node_input = torch.cat([img_feature, pos], dim=1)
    node_feature = self.encoder2(self.gnn(x=node_input, edge_index=edge_index))

    aggr = self.global_aggr
    if aggr == 'max':
        pool = gnn.global_max_pool
    elif aggr == 'sum':
        pool = gnn.global_add_pool
    else:
        raise NotImplementedError()
    global_feature = pool(node_feature, batch, size=batch_size)

    return {'logits': self.fc(global_feature)}
def forward(
    self, data: 'torch_geometric.data.Data'
) -> Tuple['torch.tensor', 'torch.tensor', 'torch.tensor']:
    """
    torch.nn.module forward operation

    Args:
        data (torch_geometric.data.Data): data to be fed forward; must have
            edge attributes and edge index defined.  When the graph has no
            node features, a constant feature of 1 is used for every node.

    Returns:
        Tuple[torch.tensor, torch.tensor, torch.tensor]:
            (GCN output, node embeddings, edge embeddings); the embeddings
            are the ReLU-activated conv outputs of the last message step.
    """
    # Get batch
    x, edge_attr, edge_index, batch = (data.x, data.edge_attr,
                                       data.edge_index, data.batch)
    row, col = edge_index

    if data.num_node_features == 0:
        # Create the placeholder features on the same device as the graph;
        # a CPU tensor here breaks lin0 when the model lives on a GPU.
        x = torch.ones(data.num_nodes, 1, device=edge_index.device)

    out = F.relu(self.lin0(x))
    out_edge = F.relu(self.lin0_edge(edge_attr))
    h = out.unsqueeze(0)
    h_edge = out_edge.unsqueeze(0)

    # Feed forward, node and edge messages
    for _ in range(self._n_messages):
        m = F.relu(self.node_conv(out, edge_index))
        emb_node = m
        m = F.dropout(m, p=self._dropout, training=self.training)
        out, h = self.node_gru(m.unsqueeze(0), h)
        out = out.squeeze(0)

        m_edge = F.relu(self.edge_conv(out_edge, edge_index))
        emb_edge = m_edge
        m_edge = F.dropout(m_edge, p=self._dropout, training=self.training)
        out_edge, h_edge = self.edge_gru(m_edge.unsqueeze(0), h_edge)
        out_edge = out_edge.squeeze(0)

    # Concatenate node network and edge network output tensors
    # NOTE(review): out_edge is indexed with col (node indices) here --
    # confirm this is the intended pairing.
    out = torch.cat([out[row], out_edge[col]], dim=1)

    # Perform scatter add, reshape to original node dimensionality
    out = scatter_add(out, col, dim=0, dim_size=x.size(0))

    # Perform summation over all nodes w.r.t. current batch
    out = pyg_nn.global_add_pool(out, batch)

    # Perform post-message passing feed forward operations
    for layer in self.post_conv[:-1]:
        out = layer(out)
        out = F.relu(out)
        out = F.dropout(out, p=self._dropout, training=self.training)
    out = self.post_conv[-1](out)

    # Return fed-forward data, node embedding, edge embedding
    return out, emb_node, emb_edge
def forward(self, data):
    """Embed every node, sum-pool per graph and project with ``post_mp``.

    ``data`` is run through ``self.feat_preprocess`` once (it is then tagged
    with ``preprocessed = True``).  The input of each conv layer depends on
    ``self.skip``:

    * ``'learnable'`` -- every earlier layer output, scaled by a sigmoid of
      the matching ``self.learnable_skip`` entry and flattened per node;
    * ``'all'`` -- the plain concatenation of every earlier layer output;
    * anything else -- only the previous layer's output.

    For ``conv_type == "PNA"`` the sum/mean/max convs are applied side by
    side and concatenated; for ``"RGCN"`` (learnable skip) the flattened
    ``data.edge_feature`` is passed as the edge type.

    Returns:
        The ``post_mp`` projection of the pooled embeddings (no softmax).
    """
    if self.feat_preprocess is not None:
        if not hasattr(data, "preprocessed"):
            data = self.feat_preprocess(data)
            data.preprocessed = True

    x, edge_index, batch = data.node_feature, data.edge_index, data.batch
    # Edge types are only consumed by the RGCN conv.  The former guard
    # ('aifb' == 'aifb' or 'wn18' == 'wn18') was always true, so every input
    # had to carry edge_feature and the plain branch was unreachable.
    edge_type = None
    if self.conv_type == "RGCN":
        edge_type = data.edge_feature.reshape(-1)
    x = self.pre_mp(x)

    def _pna(layer_idx, inputs):
        # Sum/mean/max aggregators side by side on the feature dim.
        return torch.cat((self.convs_sum[layer_idx](inputs, edge_index),
                          self.convs_mean[layer_idx](inputs, edge_index),
                          self.convs_max[layer_idx](inputs, edge_index)),
                         dim=-1)

    is_pna = self.conv_type == "PNA"
    num_layers = len(self.convs_sum) if is_pna else len(self.convs)

    # all_emb stacks the layer outputs along dim 1 (one slot per layer);
    # emb concatenates the same outputs along the feature dim.
    all_emb = x.unsqueeze(1)
    emb = x
    for i in range(num_layers):
        if self.skip == 'learnable':
            skip_vals = self.learnable_skip[i, :i + 1].unsqueeze(0).unsqueeze(-1)
            curr_emb = all_emb * torch.sigmoid(skip_vals)
            curr_emb = curr_emb.view(x.size(0), -1)
            if is_pna:
                x = _pna(i, curr_emb)
            elif self.conv_type == "RGCN":
                x = self.convs[i](curr_emb, edge_index, edge_type)
            else:
                x = self.convs[i](curr_emb, edge_index)
        elif self.skip == 'all':
            if is_pna:
                x = _pna(i, emb)
            else:
                x = self.convs[i](emb, edge_index)
        else:
            x = self.convs[i](x, edge_index)

        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        emb = torch.cat((emb, x), 1)
        if self.skip == 'learnable':
            all_emb = torch.cat((all_emb, x.unsqueeze(1)), 1)

    emb = pyg_nn.global_add_pool(emb, batch)
    emb = self.post_mp(emb)
    return emb
def forward(self, data):
    """Convolve node features in place and sum-pool the readout selection.

    ``data.x`` is overwritten with the output of ``self.conv``.
    ``self.readout`` must return six values; only the first (node features)
    and the fourth (their graph ids) are used for pooling.

    Returns:
        ``(data, global_graph_emb)``.
    """
    data.x = self.conv(data.x, data.edge_index)

    readout = self.readout(data.x, data.edge_index, batch=data.batch)
    selected_x, _, _, selected_batch, _, _ = readout

    return data, global_add_pool(selected_x, selected_batch)
def forward(self, anchor_batch, negative_batch, positive_batch, anchor: Tensor, negative: Tensor, positive: Tensor, anchor_gt: Tensor, negative_gt: Tensor, positive_gt: Tensor) -> Tensor:
    """Triplet loss on sum-pooled embeddings with a target-ratio weight.

    Each of ``anchor`` / ``positive`` / ``negative`` is sum-pooled with its
    own batch vector.  The negative distance is scaled by
    ``|negative_gt - anchor_gt| / (|positive_gt - anchor_gt| + eps)`` before
    the hinge with ``self.margin``.

    Returns:
        The mean hinge loss over the batch.
    """
    pooled_anchor = global_add_pool(anchor, anchor_batch)
    pooled_positive = global_add_pool(positive, positive_batch)
    pooled_negative = global_add_pool(negative, negative_batch)

    dist_to_positive = torch.linalg.norm(pooled_positive - pooled_anchor, dim=1)
    dist_to_negative = torch.linalg.norm(pooled_negative - pooled_anchor, dim=1)

    gap_negative = torch.abs(negative_gt - anchor_gt)
    gap_positive = torch.abs(positive_gt - anchor_gt) + self.eps
    weight = torch.div(gap_negative, gap_positive)

    hinge = F.relu((dist_to_positive - weight * dist_to_negative) + self.margin)
    return torch.mean(hinge)
def forward(self, data):
    """Weighted average of subgraph outputs for every input graph.

    Subgraphs are sampled with ``subgraph_loader`` (its sampling arguments
    ``k``, ``super_node_size``, ``num_tours`` and ``num_cpus`` are read from
    the enclosing scope), embedded in chunks by ``self.gnn_layer``, projected
    by ``self.output_layer``, multiplied by their weights and sum-pooled per
    source graph.

    Returns:
        The pooled weighted outputs divided by the pooled weights.
    """
    subgraph_data = subgraph_loader(
        data, k, super_node_size, num_tours, num_cpus
    )
    # Move inputs to wherever the parameters live.  The former ``.cuda()``
    # calls targeted the *current* CUDA device, which need not be the
    # model's, and the device check was repeated for every tensor.
    device = next(self.parameters()).device

    subgraphs = [
        get_subgraph(data[subgraph_data.batch[i].item()],
                     subgraph_data.subgraphs[i].squeeze())
        for i in range(len(subgraph_data.subgraphs))
    ]

    # Embed in fixed-size chunks; slicing already clamps at the list end.
    chunk_size = 500
    embedded_chunks = []
    for start in range(0, len(subgraphs), chunk_size):
        chunk = Batch.from_data_list(subgraphs[start:start + chunk_size])
        embedded_chunks.append(
            self.gnn_layer(chunk.x.to(device), chunk.edge_index.to(device),
                           chunk.batch.to(device)))

    outputs = self.output_layer(torch.cat(embedded_chunks, dim=0))

    weights = subgraph_data.weights.to(device)
    batch = subgraph_data.batch.to(device)

    norm = global_add_pool(weights, batch)
    energy = global_add_pool(outputs * weights, batch)
    return energy / norm
def forward(self, data):
    """Conv stack with interleaved pooling and a jumping-knowledge readout.

    A sum-pooled graph readout is collected after ``conv1`` and after every
    conv in ``self.convs``.  After each even-indexed conv except the last
    one, ``self.pools[i // 2]`` coarsens the graph (nodes, edges and batch
    vector).  The readouts are merged by ``self.jump`` and classified by
    ``lin1`` -> dropout (p=0.5) -> ``lin2``; returns ``log_softmax`` over the
    last dim.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch

    x = F.relu(self.conv1(x, edge_index))
    readouts = [global_add_pool(x, batch)]

    final_idx = len(self.convs) - 1
    for i, conv in enumerate(self.convs):
        x = F.relu(conv(x, edge_index))
        readouts.append(global_add_pool(x, batch))

        if i % 2 != 0 or i >= final_idx:
            continue
        x, edge_index, _, batch, _, _ = self.pools[i // 2](
            x, edge_index, batch=batch)

    x = self.jump(readouts)
    x = F.relu(self.lin1(x))
    x = F.dropout(x, p=0.5, training=self.training)
    return F.log_softmax(self.lin2(x), dim=-1)
def forward(self, x, edge_index, batch):
    """Conv -> batch-norm -> ReLU stack, sum pooling and a 2-layer classifier.

    Returns ``log_softmax`` over the last dim of ``lin2`` applied to the
    batch-normed, ReLU-activated and dropout-regularised (p=0.5) ``lin1``
    output.
    """
    h = x
    for conv, batch_norm in zip(self.convs, self.batch_norms):
        h = conv(h, edge_index)
        h = batch_norm(h)
        h = F.relu(h)

    graph_repr = global_add_pool(h, batch)

    hidden = F.relu(self.batch_norm1(self.lin1(graph_repr)))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    logits = self.lin2(hidden)
    return F.log_softmax(logits, dim=-1)
def test_permuted_global_pool():
    """Global pooling must not depend on the order of the nodes."""
    n_first, n_second = 4, 6
    total = n_first + n_second

    x = torch.randn(total, 4)
    batch = torch.cat([torch.zeros(n_first),
                       torch.ones(n_second)]).to(torch.long)
    perm = torch.randperm(total)

    for pool in (global_add_pool, global_mean_pool, global_max_pool):
        in_order = pool(x, batch)
        shuffled = pool(x[perm], batch[perm])
        assert torch.allclose(in_order, shuffled)
def forward(self, x, edge_index, edge_attr, batch):
    """Embed nodes and edges, run the conv stack, sum-pool and apply the MLP.

    Args:
        x: node inputs for ``self.node_emb``.
            NOTE(review): presumably integer ids of shape ``[num_nodes, 1]``
            -- confirm against the dataset.
        edge_index: forwarded to every conv layer.
        edge_attr: edge inputs for ``self.edge_emb``.
        batch: graph id of each node.

    Returns:
        ``self.mlp`` applied to the per-graph sum of the node states.
    """
    # Drop only the trailing singleton dim.  A bare squeeze() would also
    # remove the node dim for a batch that holds a single node, turning the
    # embedding into a 1-D tensor.
    if x.dim() > 1:
        x = x.squeeze(-1)
    x = self.node_emb(x)
    edge_attr = self.edge_emb(edge_attr)

    for conv, batch_norm in zip(self.convs, self.batch_norms):
        x = F.relu(batch_norm(conv(x, edge_index, edge_attr)))

    x = global_add_pool(x, batch)
    return self.mlp(x)