def forward(self, data):
    """Encode atoms, run the relational conv stack, pool, and apply the MLP head.

    Args:
        data: Object exposing ``x``, ``edge_index``, ``edge_attr`` and
            ``batch``. ``edge_attr`` is indexed as ``[num_edges, >= 3]``;
            only its first three columns are used (presumably three
            categorical edge features -- TODO confirm against the dataset).

    Returns:
        Output of ``self.linear2``. When ``self.pool_layer`` is one of
        ``'add'``, ``'mean'``, ``'max'`` or ``'sort'`` the matching
        ``global_*_pool`` is applied first; any other value skips pooling.
    """
    x = self.atom_encoder(data.x)
    edge_index = data.edge_index

    # Fold the three edge-feature columns into one relation id
    # (col0 + 5 * col1 + 30 * col2).  Done as a single tensor expression
    # rather than a per-edge Python loop, which is slow and forces one
    # host round-trip per edge when edge_attr lives on an accelerator.
    edge_attr = torch.as_tensor(data.edge_attr)
    if edge_attr.numel() == 0:
        # No edges: keep the empty-LongTensor result regardless of the
        # shape the empty attribute tensor happens to have.
        edge_type = torch.zeros(0, dtype=torch.long)
    else:
        edge_type = (
            edge_attr[:, 0] + edge_attr[:, 1] * 5 + edge_attr[:, 2] * 30
        ).long()
    edge_type = edge_type.to(self.device)

    last = len(self.rgcn_list) - 1
    for i, layer in enumerate(self.rgcn_list):
        x = F.relu(layer(x, edge_index, edge_type))
        # Batch norm is applied between layers only, not after the last one.
        if i != last:
            x = self.batchnorm(x)

    if self.pool_layer == 'add':
        x = global_add_pool(x, data.batch)
    elif self.pool_layer == 'mean':
        x = global_mean_pool(x, data.batch)
    elif self.pool_layer == 'max':
        x = global_max_pool(x, data.batch)
    elif self.pool_layer == 'sort':
        x = global_sort_pool(x, data.batch, self.k)

    x = F.relu(self.linear1(x))
    return self.linear2(x)
def embed_graph(self, x, edge_index, batch=None):
    """Embed a graph with two conv layers, optional pooling and a readout.

    Args:
        x: Integer node labels; one-hot encoded with
            ``self.config.num_feature_dim`` classes before the first layer.
        edge_index: Connectivity passed to ``self.gc1`` / ``self.gc2`` and
            to ``self.pool1`` when pooling is enabled.
        batch: Optional batch assignment forwarded to pooling and readout.

    Returns:
        ``(embedding, attn_weights)`` where ``embedding`` is the output of
        ``self.fc`` and ``attn_weights`` is a dict holding ``'batch'`` plus,
        depending on ``self.config.pooling_type``, ``'pool_perm'`` and
        ``'pool_score'``.
    """
    cfg = self.config
    attn_weights = {}

    h = F.one_hot(x, num_classes=cfg.num_feature_dim).float()
    h = F.relu(self.gc1(h, edge_index))
    h = F.dropout(h, cfg.dropout, training=self.training)
    h = self.gc2(h, edge_index)

    pooling = cfg.pooling_type
    if pooling in ("sagpool", "topk"):
        # Both poolers are unpacked as a 6-tuple ending in (perm, score).
        h, edge_index, _, batch, perm, score = self.pool1(
            h, edge_index, batch=batch)
        attn_weights['pool_perm'] = perm
        attn_weights['pool_score'] = score
    elif pooling == "asa":
        # This pooler is unpacked as a 5-tuple; no score is recorded.
        h, edge_index, _, batch, perm = self.pool1(
            h, edge_index, batch=batch)
        attn_weights['pool_perm'] = perm

    readout = cfg.readout_type
    if readout == "add":
        h = global_add_pool(h, batch)
    elif readout == "mean":
        h = global_mean_pool(h, batch)
    elif readout == "max":
        h = global_max_pool(h, batch)
    elif readout == "sort":
        h = global_sort_pool(h, batch, k=100)
    # Any other readout type leaves the node-level features untouched.

    attn_weights['batch'] = batch
    return self.fc(h), attn_weights
def sortpooling_embedding_tg(self, data):
    """Embed a batch of graphs: G-UNet, sort pooling, 1-D convs, optional dense.

    Args:
        data: Object exposing ``x``, ``edge_index`` and ``batch``.

    Returns:
        ReLU-activated embedding: the output of ``self.out_params`` when
        ``self.output_dim > 0``, otherwise the flattened convolution
        features.
    """
    batch = data.batch

    # G-UNet layer processes the graph data.
    message = self.gUnet(x=data.x, edge_index=data.edge_index, batch=batch)

    # Sort pooling; per the original note the result has shape
    # (B, k * total_latent_dim).
    pooled = global_sort_pool(message, batch, self.k)

    # Traditional 1-D convolution stack on a single input channel.
    feats = pooled.view((-1, 1, self.k * self.total_latent_dim))
    feats = F.relu(self.conv1d_params1(feats))
    feats = self.maxpool1d(feats)
    feats = F.relu(self.conv1d_params2(feats))
    flat = feats.view(pooled.size(0), -1)

    if self.output_dim > 0:
        embedding = F.relu(self.out_params(flat))
    else:
        embedding = flat
    return F.relu(embedding)
def forward(self, x):
    """Sort-pool every row of ``x`` as one single graph.

    Args:
        x: Tensor whose first dimension enumerates the nodes.

    Returns:
        Result of ``global_sort_pool`` with ``k=self.k`` when all nodes are
        assigned to graph 0.
    """
    # All nodes belong to graph 0.  Allocate the index vector directly on
    # x's device instead of materialising a Python list first.
    batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
    return global_sort_pool(x=x, batch=batch, k=self.k)
def forward(self, data):
    """Run the conv stack, sort-pool the concatenated outputs, apply the CNN head.

    Implements Equation 4.2 of the paper referenced by the original author:
    the graph representations of all layers are concatenated and fed to a
    linear model (this could be decomposed into one smaller linear model
    per layer).

    Args:
        data: Object exposing ``x``, ``edge_index`` and ``batch``.

    Returns:
        Output of ``self.dense_layer`` on the flattened convolution features.
    """
    edge_index, batch = data.edge_index, data.batch

    # tanh-activated output of every conv layer, kept for concatenation.
    layer_outputs = []
    h = data.x
    for conv in self.convs:
        h = torch.tanh(conv(h, edge_index))
        layer_outputs.append(h)

    # Sort pooling (original note: in the authors' code only the last
    # channel is used for sorting).
    pooled = global_sort_pool(torch.cat(layer_outputs, dim=1), batch, self.k)

    # 1-D convolutional layers on a single input channel.
    feats = pooled.unsqueeze(1)
    feats = F.relu(self.conv1d_params1(feats))
    feats = self.maxpool1d(feats)
    feats = F.relu(self.conv1d_params2(feats))
    flat = feats.reshape(feats.shape[0], -1)

    # Dense layer.
    return self.dense_layer(flat)
def forward(self, z, edge_index, batch, x=None, edge_weight=None,
            node_id=None):
    """Embed node labels, run the conv stack, pool, and classify.

    Args:
        z: Integer node labels fed to ``self.z_embedding``. If the embedding
            result is 3-D (several labels per node) it is summed over dim 1.
        edge_index: Connectivity passed to every conv.
        batch: Batch assignment passed to the global pooling.
        x: Optional node features, concatenated (as float) to the label
            embedding when ``self.use_feature`` is truthy.
        edge_weight: Optional edge weights passed to every conv.
        node_id: Optional node ids; embedded and concatenated when
            ``self.node_embedding`` is not ``None``.

    Returns:
        Output of ``self.lin2``.
    """
    z_emb = self.z_embedding(z)
    if z_emb.ndim == 3:  # in case z has multiple integer labels
        z_emb = z_emb.sum(dim=1)

    use_raw_features = self.use_feature and x is not None
    h = torch.cat([z_emb, x.to(torch.float)], 1) if use_raw_features else z_emb

    if self.node_embedding is not None and node_id is not None:
        h = torch.cat([h, self.node_embedding(node_id)], 1)

    # Keep each layer's tanh output; the input features are not included.
    layer_outputs = []
    for conv in self.convs:
        h = torch.tanh(conv(h, edge_index, edge_weight))
        layer_outputs.append(h)
    h = torch.cat(layer_outputs, dim=-1)

    # Global pooling.
    pool = global_random_pool if self.random_pool else global_sort_pool
    h = pool(h, batch, self.k)

    h = h.unsqueeze(1)  # [num_graphs, 1, k * hidden]
    h = F.relu(self.conv1(h))
    h = self.maxpool1d(h)
    h = F.relu(self.conv2(h))
    h = h.view(h.size(0), -1)  # [num_graphs, dense_dim]

    # MLP.
    h = F.relu(self.lin1(h))
    h = F.dropout(h, p=0.5, training=self.training)
    return self.lin2(h)
def forward(self, data):
    """Run the conv stack, sort-pool the concatenated outputs, apply the CNN head.

    Implements Equation 4.2 of the paper referenced by the original author:
    the graph representations of all layers are concatenated and fed to a
    linear model (this could be decomposed into one smaller linear model
    per layer).

    Args:
        data: Object exposing ``x``, ``edge_index`` and ``batch``.

    Returns:
        Output of ``self.dense_layer`` on the flattened convolution features.
    """
    edge_index, batch = data.edge_index, data.batch

    # Stores the intermediate (per-layer) outputs.
    intermediate = []
    h = data.x
    for conv in self.convs:
        h = torch.tanh(conv(h, edge_index))
        intermediate.append(h)

    # Sort pooling (original note: in the authors' code only the last
    # channel is used for sorting).  Author's shape note: (b, k * f).
    stacked = torch.cat(intermediate, dim=1)
    pooled = global_sort_pool(stacked, batch, self.k)

    # 1-D convolutional layers.  Author's shape note: (b, 1, k * f);
    # nn.Conv1d takes (N, C_in, signal), i.e. C_in planes, with
    # f = total_latent_dim.
    feats = torch.unsqueeze(pooled, dim=1)
    feats = F.relu(self.conv1d_params1(feats))
    feats = self.maxpool1d(feats)
    feats = F.relu(self.conv1d_params2(feats))

    # Collapse the channel planes into one vector per graph.
    flat = feats.reshape(feats.shape[0], -1)

    # Dense layer.
    return self.dense_layer(flat)
def forward(self, x, edge_index, edge_attr, *args, **kwargs):
    """Run one conv/sort-pool/CNN pipeline per edge-attribute channel.

    Multi-dimensional edge_attr is implemented as separate channels whose
    outputs are concatenated before the dense layers.

    :param x: Node feature matrix with shape [num_nodes, num_node_features]
    :param edge_index: Graph connectivity in COO format with shape
        [2, num_edges] and type torch.long
    :param edge_attr: Edge feature matrix with shape
        [num_edges, num_edge_features], num_edge_features >= 1
    :param args: unused
    :param kwargs: unused
    :return: output of ``self.fc3`` for the single input graph
    :raises Exception: if ``x`` is 3-D with a leading size other than 1
    """
    if x.dim() == 3:  # a batch
        if x.shape[0] != 1:
            raise Exception("batch size greater than 1 is not supported.")
        x = x.squeeze(0)
        edge_index = edge_index.squeeze(0)
        edge_attr = edge_attr.squeeze(0)

    channel_outputs = []
    for channel in range(self.channels):
        e = edge_attr[:, channel]

        xx = self.gcnconv1[channel](x, edge_index, e)
        # Each later conv consumes only the most recently appended column
        # and its output is appended as new column(s).
        for convs in (self.gcnconv2, self.gcnconv3,
                      self.gcnconv4, self.gcnconv5):
            newest = xx[:, -1].unsqueeze(-1)
            xx = torch.cat([xx, convs[channel](newest, edge_index, e)],
                           dim=-1)

        N, D = xx.size()
        # All nodes belong to one graph.  The index vector is created on
        # xx's device so the pooling call does not mix devices.
        batch = torch.zeros(self.num_nodes, dtype=torch.long,
                            device=xx.device)
        xx = global_sort_pool(x=xx, batch=batch, k=self.num_nodes)
        xx = xx.view(-1, N, D).permute(0, 2, 1)

        xx = self.conv1d1[channel](xx)
        xx = self.maxpool1[channel](xx)
        xx = self.conv1d2[channel](xx)
        xx = self.maxpool2[channel](xx)
        channel_outputs.append(xx)

    x = torch.cat(channel_outputs, dim=0).view(1, -1)
    x = F.elu(self.drop1(self.fc1(x)))
    x = F.elu(self.drop2(self.fc2(x)))
    return self.fc3(x)
def forward(self, data):
    """Edge-typed conv stack, sort pooling, 1-D CNN and a two-layer MLP.

    Args:
        data: Object exposing ``x``, ``edge_index``, ``edge_type`` and
            ``batch``.

    Returns:
        ``x[:, 0]`` of the final linear output when ``self.regression`` is
        truthy, otherwise its log-softmax over the last dimension.
    """
    x, batch = data.x, data.batch
    edge_index, edge_type = data.edge_index, data.edge_type

    if self.adj_dropout > 0:
        edge_index, edge_type = dropout_adj(
            edge_index,
            edge_type,
            p=self.adj_dropout,
            force_undirected=self.force_undirected,
            num_nodes=len(x),
            training=self.training,
        )

    # Collect the tanh output of every conv layer.
    states = []
    h = x
    for conv in self.convs:
        h = torch.tanh(conv(h, edge_index, edge_type))
        states.append(h)

    h = global_sort_pool(torch.cat(states, 1), batch, self.k)  # batch * (k*hidden)
    h = h.unsqueeze(1)  # batch * 1 * (k*hidden)
    h = F.relu(self.conv1d_params1(h))
    h = self.maxpool1d(h)
    h = F.relu(self.conv1d_params2(h))
    h = h.view(len(h), -1)  # flatten

    h = F.relu(self.lin1(h))
    h = F.dropout(h, p=0.5, training=self.training)
    h = self.lin2(h)

    if self.regression:
        return h[:, 0]
    return F.log_softmax(h, dim=-1)
def forward(self, data):
    """Encode atoms, run paired relational/plain convs, pool, apply the MLP head.

    Each layer concatenates the output of ``self.rgcn_list[i]`` (which
    receives the relation ids) with the output of
    ``self.graphconv_list[i]`` (which does not) along dim 1.

    Args:
        data: Object exposing ``x``, ``edge_index``, ``edge_attr`` and
            ``batch``. ``edge_attr`` is indexed as ``[num_edges, >= 3]``;
            only its first three columns are used (presumably three
            categorical edge features -- TODO confirm against the dataset).

    Returns:
        Output of ``self.linear2``. When ``self.pool_layer`` is one of
        ``'add'``, ``'mean'``, ``'max'`` or ``'sort'`` the matching
        ``global_*_pool`` is applied first; any other value skips pooling.
    """
    x = self.atom_encoder(data.x)
    edge_index = data.edge_index

    # Fold the three edge-feature columns into one relation id
    # (col0 + 5 * col1 + 30 * col2) with a single tensor expression rather
    # than a per-edge Python loop.
    edge_attr = torch.as_tensor(data.edge_attr)
    if edge_attr.numel() == 0:
        # No edges: keep the empty-LongTensor result regardless of the
        # shape the empty attribute tensor happens to have.
        edge_type = torch.zeros(0, dtype=torch.long)
    else:
        edge_type = (
            edge_attr[:, 0] + edge_attr[:, 1] * 5 + edge_attr[:, 2] * 30
        ).long()
    edge_type = edge_type.to(self.device)

    last = len(self.rgcn_list) - 1
    for i, rgcn in enumerate(self.rgcn_list):
        x_rgcn = rgcn(x, edge_index, edge_type)
        x_gconv = self.graphconv_list[i](x, edge_index)
        x = F.relu(torch.cat((x_rgcn, x_gconv), 1))
        # Batch norm is applied between layers only, not after the last one.
        if i != last:
            x = self.batchnorm(x)

    if self.pool_layer == 'add':
        x = global_add_pool(x, data.batch)
    elif self.pool_layer == 'mean':
        x = global_mean_pool(x, data.batch)
    elif self.pool_layer == 'max':
        x = global_max_pool(x, data.batch)
    elif self.pool_layer == 'sort':
        x = global_sort_pool(x, data.batch, self.k)

    x = F.relu(self.linear1(x))
    return self.linear2(x)
def forward(self, sample):
    """Run the GELU conv stack on one graph, pool to a fixed size, and score it.

    Args:
        sample: Object exposing ``x`` and ``edge_index``; all nodes are
            treated as belonging to a single graph.

    Returns:
        ``self.output`` applied to the flattened pooled node features.
    """
    x, edge_index = sample.x, sample.edge_index

    # NOTE: edge dropout (dropout=0.2) was disabled by the original author.
    x = F.gelu(self.dense_input(x, self.empty_edges))
    x = F.gelu(self.input(x, edge_index))
    x = F.gelu(self.conv1(x, edge_index))

    # NOTE: an extra top-k pooling step here (for pooling_layers > 1) was
    # disabled by the original author.
    x = F.gelu(self.conv2(x, edge_index))

    if self.pooling_layers > 0:
        # Single-graph batch vector, allocated directly instead of being
        # built from a Python list that iterates the rows of x.
        batch = torch.zeros(len(x), dtype=torch.long, device=self.device)
        pooled = self.topkpool2(x, edge_index, batch=batch)
        x, edge_index = pooled[0], pooled[1]

    x = F.gelu(self.conv3(x, edge_index))

    # NOTE: a repeated pool + conv4 loop "for large graphs" (while
    # len(x) > 8) was disabled by the original author.

    # x may have fewer rows after pooling, so the batch vector is rebuilt.
    batch = torch.zeros(len(x), dtype=torch.long, device=self.device)

    # Author's note: with sort_pool it works but we have the same problem:
    # the output layer learns the order of the pooled nodes (tried k = 3
    # and shuffling the nodes).
    if self.final_pooling == "avg_pool_x":
        # Round-robin assignment of the nodes to final_nodes clusters.
        cluster = torch.arange(len(x), device=self.device) % self.final_nodes
        x, cluster = avg_pool_x(cluster, x, batch)
    elif self.final_pooling == "sort_pooling":
        x = global_sort_pool(x, batch, self.final_nodes)
    elif self.final_pooling in ("topk", "asap", "sag"):
        x = self.last_pooling_layer(x, edge_index)[0]
    elif self.final_pooling == "max_pool_x":
        cluster = torch.arange(len(x), device=self.device) % self.final_nodes
        x, cluster = max_pool_x(cluster, x, batch)

    return self.output(x.view(-1))
def forward(self, data):
    """ReLU conv stack, sort pooling, and a two-layer MLP classifier.

    Args:
        data: Object exposing ``x``, ``edge_index`` and ``batch``.

    Returns:
        Log-softmax (over the last dimension) of the ``self.lin2`` output.
    """
    edge_index = data.edge_index

    h = F.relu(self.conv1(data.x, edge_index))
    for conv in self.convs:
        h = F.relu(conv(h, edge_index))

    h = global_sort_pool(h, data.batch, self.k)

    h = F.relu(self.lin1(h))
    h = F.dropout(h, p=0.5, training=self.training)
    logits = self.lin2(h)
    return F.log_softmax(logits, dim=-1)
def forward(self, x, batch):
    """Sort-pool node features and run them through two 1-D conv layers.

    Args:
        x: Node features passed to ``global_sort_pool``.
        batch: Batch assignment; ``batch.max() + 1`` is used as the number
            of graphs when flattening the result.

    Returns:
        ReLU-activated, flattened convolution features with one row per
        graph.
    """
    num_graphs = batch.max() + 1

    # Author's shape note: bs * (k * out_dim), e.g. 2910.
    pooled = global_sort_pool(x, batch, self.k)
    # Insert a singleton channel axis for Conv1d.
    pooled = pooled.unsqueeze(1)

    h = F.relu(self.conv1d_p1(pooled))
    h = self.pool(h)
    h = F.relu(self.conv1d_p2(h))

    flat = h.view(num_graphs, -1)
    return F.relu(flat)
def forward(self, x, edge_index, edge_attr, *args, **kwargs):
    """Run one conv/sort-pool/CNN pipeline per edge-attribute channel.

    Multi-dimensional edge_attr is implemented as separate channels whose
    outputs are concatenated before the dense layers.  Per-channel modules
    are looked up by attribute name (``gcnconv{level}_channel{c}``,
    ``conv1d1_channel{c}``, ``maxpool1_channel{c}``, ...).

    :param x: Node feature matrix with shape [num_nodes, num_node_features]
    :param edge_index: Graph connectivity in COO format with shape
        [2, num_edges] and type torch.long
    :param edge_attr: Edge feature matrix with shape
        [num_edges, num_edge_features], num_edge_features >= 1
    :param args: unused
    :param kwargs: unused
    :return: output of ``self.fc3`` for the single input graph
    :raises Exception: if ``x`` is 3-D with a leading size other than 1
    """
    if x.dim() == 3:  # a batch
        if x.shape[0] != 1:
            raise Exception("batch size greater than 1 is not supported.")
        x = x.squeeze(0)
        edge_index = edge_index.squeeze(0)
        edge_attr = edge_attr.squeeze(0)

    channel_outputs = []
    for channel in range(self.channels):
        e = edge_attr[:, channel]

        # gcn conv: the first level sees the node features, every later
        # level sees only the most recently appended column.
        for conv_level in range(1, self.gcn_conv_level + 1):
            conv = getattr(
                self, "gcnconv{}_channel{}".format(conv_level, channel))
            if conv_level == 1:
                xx = conv(x, edge_index, e)
            else:
                newest = xx[:, -1].view(-1, 1)
                xx = torch.cat([xx, conv(newest, edge_index, e)], dim=-1)

        # sort pool: all nodes belong to one graph; allocate the index
        # vector directly on xx's device.
        N, D = xx.size()
        batch = torch.zeros(self.num_nodes, dtype=torch.long,
                            device=xx.device)
        xx = global_sort_pool(x=xx, batch=batch, k=self.num_nodes)
        xx = xx.view(-1, N, D).permute(0, 2, 1)

        # conv and pool
        for stage in ("conv1d1", "maxpool1", "conv1d2", "maxpool2"):
            xx = getattr(self, "{}_channel{}".format(stage, channel))(xx)
        channel_outputs.append(xx)

    x = torch.cat(channel_outputs, dim=0).view(1, -1)
    x = F.elu(self.drop1(self.fc1(x)))
    x = F.elu(self.drop2(self.fc2(x)))
    return self.fc3(x)
def forward(self, x, edge_index, batch):
    """ReLU conv stack, sort pooling, one 1-D conv, and a two-layer MLP.

    Args:
        x: Node features.
        edge_index: Connectivity passed to every conv.
        batch: Batch assignment passed to ``global_sort_pool``.

    Returns:
        Output of ``self.lin2`` (no final activation).
    """
    h = F.relu(self.conv1(x, edge_index))
    for conv in self.convs:
        h = F.relu(conv(h, edge_index))

    pooled = global_sort_pool(h, batch, self.k)
    # Split the flat pooled vector into k slots and move the feature axis
    # into the channel position expected by Conv1d.
    pooled = pooled.view(len(pooled), self.k, -1).permute(0, 2, 1)

    feats = F.relu(self.conv1d(pooled))
    feats = feats.view(len(feats), -1)

    hidden = F.relu(self.lin1(feats))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    return self.lin2(hidden)
def forward(self, feats_in_l, idx_targets, sizes_subg):
    """Build the readout for the target nodes of each subgraph.

    Behaviour is selected by ``self.type_pool``:

    * ``'center'``: use only the target-node rows.
    * ``'max'`` / ``'mean'`` / ``'sum'``: pool every subgraph with
      ``F.embedding_bag`` and concatenate with the target-node rows.
    * ``'sort'``: sort-pool every subgraph, pass the result through
      ``self.nn_pool`` and concatenate with the target-node rows.

    When ``self.type_res`` is ``'none'`` only the last entry of
    ``feats_in_l`` is used; otherwise the per-layer results are combined
    with ``self.f_residue``.

    Args:
        feats_in_l: Sequence of per-layer node features.
        idx_targets: Index selecting the target-node rows.
        sizes_subg: Number of nodes in each subgraph (used for pooling).

    Returns:
        ``feats_in_l[-1][idx_targets]`` directly for ``'center'`` with
        ``type_res == 'none'``; otherwise ``self.f_norm(self.nn(feat_in))``.

    Raises:
        NotImplementedError: for any other ``self.type_pool``.
    """
    pool_kind = self.type_pool
    if pool_kind not in ('center', 'max', 'mean', 'sum', 'sort'):
        raise NotImplementedError
    plain = self.type_res == 'none'

    if pool_kind == 'center':
        if plain:
            return feats_in_l[-1][idx_targets]
        # regular JK
        feat_in = self.f_residue([f[idx_targets] for f in feats_in_l])
        return self.f_norm(self.nn(feat_in))

    if pool_kind == 'sort':
        if plain:
            feat_pool_in = feats_in_l[-1]
            feat_root = feats_in_l[-1][idx_targets]
        else:
            feat_pool_in = self.f_residue(feats_in_l)
            feat_root = self.f_residue([f[idx_targets] for f in feats_in_l])
        # Subgraph id of every node: id i repeated sizes_subg[i] times.
        subg_ids = torch.arange(sizes_subg.size(0), device=sizes_subg.device)
        idx_batch = torch.repeat_interleave(subg_ids, sizes_subg)
        # #subg x (k * F)
        feat_pool_k = global_sort_pool(feat_pool_in, idx_batch, self.k)
        feat_pool = self.nn_pool(feat_pool_k)
    else:
        # 'max' / 'mean' / 'sum': first pool the subgraph at each layer,
        # then residue.  Offsets are the start position of each subgraph
        # (cumulative sizes shifted right by one, first entry zeroed).
        offsets = torch.roll(torch.cumsum(sizes_subg, dim=0), 1)
        offsets[0] = 0
        idx = torch.arange(feats_in_l[-1].shape[0],
                           device=feats_in_l[-1].device)
        if plain:
            feat_pool = F.embedding_bag(idx, feats_in_l[-1], offsets,
                                        mode=self.type_pool)
            feat_root = feats_in_l[-1][idx_targets]
        else:
            pooled_per_layer = [
                F.embedding_bag(idx, feat, offsets, mode=self.type_pool)
                for feat in feats_in_l
            ]
            feat_pool = self.f_residue(pooled_per_layer)
            feat_root = self.f_residue([f[idx_targets] for f in feats_in_l])

    feat_in = torch.cat([feat_root, feat_pool], dim=1)
    return self.f_norm(self.nn(feat_in))
def forward(self, x, edge_index, batch):
    """tanh conv stack, sort pooling, 1-D CNN, and the MLP head.

    Args:
        x: Node features.
        edge_index: Connectivity passed to every conv.
        batch: Batch assignment passed to ``global_sort_pool``.

    Returns:
        Output of ``self.mlp`` on the flattened convolution features.
    """
    # Keep each layer's tanh output; the raw input is not part of the
    # concatenation.
    layer_outputs = []
    h = x
    for conv in self.convs:
        h = torch.tanh(conv(h, edge_index))
        layer_outputs.append(h)
    h = torch.cat(layer_outputs, dim=-1)

    # Global pooling.
    h = global_sort_pool(h, batch, self.k)
    h = h.unsqueeze(1)  # [num_graphs, 1, k * hidden]
    h = torch.relu(self.conv1(h))
    h = self.maxpool1d(h)
    h = torch.relu(self.conv2(h))
    flat = h.view(h.size(0), -1)  # [num_graphs, dense_dim]

    return self.mlp(flat)
def forward(self, data):
    """Three-branch blocks with dropout, sort pooling, 1-D CNN and MLP.

    Every block (``ib1``, ``ib2``, ``ib3``, ``final``) is unpacked into
    three tensors; each is passed through dropout and the three are summed.
    The sums after ``ib1`` and ``ib2`` get one more dropout.  The ``ib3``
    sum and the ``final`` sum are concatenated along dim 1 before pooling.

    Args:
        data: Object exposing ``x``, ``edge_index``, ``edge_weight``,
            ``batch``, ``edge_index2`` and ``edge_weight2``.

    Returns:
        Output of ``self.lin2``.
    """
    x, batch = data.x, data.batch
    edge_index, edge_weight = data.edge_index, data.edge_weight
    edge_index2, edge_weight2 = data.edge_index2, data.edge_weight2

    def drop(t):
        return F.dropout(t, p=self.dropout, training=self.training)

    def branch_sum(block, inp):
        # Dropout is applied to the branches in order (0, 1, 2) so the
        # random-number consumption matches a straight-line version.
        b0, b1, b2 = block(inp, edge_index, edge_weight,
                           edge_index2, edge_weight2)
        b0 = drop(b0)
        b1 = drop(b1)
        b2 = drop(b2)
        return b0 + b1 + b2

    x = drop(branch_sum(self.ib1, x))
    x = drop(branch_sum(self.ib2, x))
    x = branch_sum(self.ib3, x)
    y = branch_sum(self.final, x)
    x = torch.cat([x, y], 1)

    # Global pooling.
    x = global_sort_pool(x, batch, self.k)
    x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
    x = F.relu(self.conv1(x))
    x = self.maxpool1d(x)
    x = F.relu(self.conv2(x))
    x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

    # MLP.
    x = F.relu(self.lin1(x))
    x = F.dropout(x, p=0.5, training=self.training)
    return self.lin2(x)
def test_global_sort_pool():
    """Check output size, zero padding and node ordering of global_sort_pool."""
    sizes = (4, 6)
    num_feats, k = 4, 5
    x = torch.randn(sum(sizes), num_feats)
    batch = torch.tensor([0] * sizes[0] + [1] * sizes[1])

    out = global_sort_pool(x, batch, k=k)
    assert out.size() == (2, k * num_feats)

    out = out.view(2, k, num_feats)

    # The first graph has only 4 nodes, so its last slot is zero padding.
    assert out[0, -1].tolist() == [0, 0, 0, 0]

    # Nodes are sorted: an ascending argsort of the last channel must
    # list the slots in reverse order.
    first_order = out[0, :4, -1].argsort().tolist()
    assert first_order == (3 - torch.arange(4)).tolist()

    second_order = out[1, :, -1].argsort().tolist()
    assert second_order == (4 - torch.arange(5)).tolist()
def sortpooling_embedding(self, batched_data):
    """Embed a batch of graphs and emit one prediction per sequence position.

    Args:
        batched_data: Object exposing ``x``, ``edge_index``, ``node_depth``
            and ``batch``, and optionally ``edge_attr``.

    Returns:
        List of ``self.max_seq_len`` tensors, the i-th being
        ``self.graph_pred_linear_list[i]`` applied to the flattened
        convolution features.
    """
    x = batched_data.x
    edge_index = batched_data.edge_index
    node_depth = batched_data.node_depth
    batch = batched_data.batch

    x = self.node_encoder(x, node_depth.view(-1,))
    node_feat = x

    # A missing attribute and an explicit None are treated the same way.
    edge_feat = getattr(batched_data, 'edge_attr', None)

    # If edge features exist, concatenate them to the node feature vector.
    if edge_feat is not None:
        # (Inverse edges were already added upstream, per the original note.)
        # Node-by-edge indicator matrix: entry (n, e) is 1 when node n is
        # listed in column e of edge_index.
        e2n_sp = torch.zeros(x.shape[0], edge_index.shape[1])
        e2n_sp = e2n_sp.to(edge_feat.device).scatter_(0, edge_index, 1)
        # Sum the features of all edges touching each node.
        e2npool_input = torch.mm(e2n_sp, edge_feat)
        node_feat = torch.cat([node_feat, e2npool_input], 1)

    # Graph convolution layers (a separate method so it can be reused
    # elsewhere, per the original note).
    message = self.compute_message_layers(node_feat, edge_index, batch)

    # Sort pooling layer.
    pooled = global_sort_pool(message, batch, self.k)

    # Traditional 1-D convolution and dense layers.
    feats = pooled.view((-1, 1, self.k * self.total_latent_dim))
    feats = self.conv1d_activation(self.conv1d_params1(feats))
    feats = self.maxpool1d(feats)
    feats = self.conv1d_activation(self.conv1d_params2(feats))
    to_dense = feats.view(feats.shape[0], -1)

    # NOTE: the optional out_params / mlp head was disabled by the
    # original author; one linear head per sequence position is used.
    return [
        self.graph_pred_linear_list[i](to_dense)
        for i in range(self.max_seq_len)
    ]
def test_global_sort_pool():
    """Check size, per-node feature order, zero padding and node ordering."""
    sizes = (4, 6)
    num_feats, k = 4, 5
    x = torch.randn(sum(sizes), num_feats)
    batch = torch.tensor([0] * sizes[0] + [1] * sizes[1])

    out = global_sort_pool(x, batch, k=k)
    assert out.size() == (2, k * num_feats)

    out = out.view(2, k, num_feats)

    # Features are individually sorted: within every slot an ascending
    # argsort over the feature axis must be the identity permutation.
    identity = torch.arange(4).view(1, 1, 4).expand_as(out)
    assert out.argsort(dim=2).tolist() == identity.tolist()

    # First graph output has been filled up with zeros.
    assert out[0, -1].tolist() == [0, 0, 0, 0]

    # Nodes are sorted: an ascending argsort over the slot axis, read at
    # the last channel, must list the slots in reverse order.
    descending = 4 - torch.arange(5).view(1, 5).expand(2, 5)
    node_order = out.argsort(dim=1)[:, :, -1]
    assert node_order.tolist() == descending.tolist()
def forward(self, x, edge_index, batch):
    """tanh conv stack, sort pooling, 1-D CNN, and a two-layer MLP.

    Args:
        x: Node features.
        edge_index: Connectivity passed to every conv.
        batch: Batch assignment passed to ``global_sort_pool``.

    Returns:
        Output of ``self.lin2`` (no final activation).
    """
    # Keep each layer's tanh output; the raw input is not part of the
    # concatenation.
    layer_outputs = []
    h = x
    for conv in self.convs:
        h = torch.tanh(conv(h, edge_index))
        layer_outputs.append(h)
    h = torch.cat(layer_outputs, dim=-1)

    # Global pooling.
    h = global_sort_pool(h, batch, self.k)
    h = h.unsqueeze(1)  # [num_graphs, 1, k * hidden]
    h = F.relu(self.conv1(h))
    h = self.maxpool1d(h)
    h = F.relu(self.conv2(h))
    h = h.view(h.size(0), -1)  # [num_graphs, dense_dim]

    # MLP.
    h = F.relu(self.lin1(h))
    h = F.dropout(h, p=0.5, training=self.training)
    return self.lin2(h)
def sortpooling_embedding_tg(self, data):
    """Embed a batch of graphs: G-UNet, sort pooling, 1-D convs, optional dense.

    Args:
        data: Object exposing ``x``, ``edge_index`` and ``batch``.

    Returns:
        ReLU-activated embedding: the output of ``self.out_params`` when
        ``self.output_dim > 0``, otherwise the flattened convolution
        features.
    """
    batch = data.batch

    # TODO (from the original author): remove edge_attr consideration.
    # Edge features are currently not merged into the node features.

    # G-UNet layer processes the graph data.  (A follow-up end_gcn stage
    # on [message, node_feat] was disabled by the original author.)
    message = self.gUnet(x=data.x, edge_index=data.edge_index, batch=batch)

    # Sort pooling; per the original note the result has shape
    # (B, k * total_latent_dim).
    pooled = global_sort_pool(message, batch, self.k)

    # Traditional 1-D convolution stack on a single input channel.
    feats = pooled.view((-1, 1, self.k * self.total_latent_dim))
    feats = F.relu(self.conv1d_params1(feats))
    feats = self.maxpool1d(feats)
    feats = F.relu(self.conv1d_params2(feats))
    flat = feats.view(pooled.size(0), -1)

    if self.output_dim > 0:
        embedding = F.relu(self.out_params(flat))
    else:
        embedding = flat
    return F.relu(embedding)
def forward(self, data):
    """Four tanh convs, sort pooling (k=30), 1-D CNN and a classifier.

    Args:
        data: Object exposing ``x``, ``edge_index`` and ``batch``.

    Returns:
        Log-softmax (over the last dimension) of the ``self.classifier_2``
        output.
    """
    batch = data.batch
    edge_index, _ = remove_self_loops(data.edge_index)

    # Chain the four convs, keeping every tanh output for concatenation.
    layer_outputs = []
    h = data.x
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
        h = torch.tanh(conv(h, edge_index))
        layer_outputs.append(h)
    h = torch.cat(layer_outputs, dim=-1)

    pooled = global_sort_pool(h, batch, k=30)
    # Add a singleton channel axis for the 1-D convolutions.
    feats = pooled.view(pooled.size(0), 1, pooled.size(-1))
    feats = self.relu(self.conv5(feats))
    feats = self.pool(feats)
    feats = self.relu(self.conv6(feats))
    flat = feats.view(feats.size(0), -1)

    hidden = self.drop_out(self.relu(self.classifier_1(flat)))
    return F.log_softmax(self.classifier_2(hidden), dim=-1)
def test_global_sort_pool_smaller_than_k():
    """Check global_sort_pool when k exceeds the size of every graph."""
    sizes = (4, 6)
    num_feats = 4
    x = torch.randn(sum(sizes), num_feats)
    batch = torch.tensor([0] * sizes[0] + [1] * sizes[1])

    # Set k which is bigger than both graph sizes (4 and 6).
    k = 10
    out = global_sort_pool(x, batch, k=k)
    assert out.size() == (2, k * num_feats)

    out = out.view(2, k, num_feats)

    # Both graph outputs have been filled up with zeros.
    for graph in (0, 1):
        assert out[graph, -1].tolist() == [0, 0, 0, 0]

    # Nodes are sorted: an ascending argsort of the last channel over the
    # real (non-padded) slots must list them in reverse order.
    first_order = out[0, :4, -1].argsort().tolist()
    assert first_order == (3 - torch.arange(4)).tolist()

    second_order = out[1, :6, -1].argsort().tolist()
    assert second_order == (5 - torch.arange(6)).tolist()
def forward(self, data):
    """Conv stack, sort pooling, 1-D CNN, dense layer and task-specific output.

    Implements Equation 4.2 of the paper referenced by the original author:
    the graph representations of all layers are concatenated and fed to a
    linear model (this could be decomposed into one smaller linear model
    per layer).

    Args:
        data: Object exposing ``x``, ``edge_index`` and ``batch``.

    Returns:
        The dense-layer output.  In eval mode with ``self.classification``
        set it is passed through ``self.sigmoid``.  With ``self.multiclass``
        set it is reshaped to (batch size, num targets, classes per target)
        and, in eval mode, passed through ``self.multiclass_softmax``.
    """
    edge_index, batch = data.edge_index, data.batch

    # tanh-activated output of every conv layer, kept for concatenation.
    layer_outputs = []
    h = data.x
    for conv in self.convs:
        h = torch.tanh(conv(h, edge_index))
        layer_outputs.append(h)

    # Sort pooling (original note: in the authors' code only the last
    # channel is used for sorting).
    pooled = global_sort_pool(torch.cat(layer_outputs, dim=1), batch, self.k)

    # 1-D convolutional layers on a single input channel.
    feats = torch.unsqueeze(pooled, dim=1)
    feats = F.relu(self.conv1d_params1(feats))
    feats = self.maxpool1d(feats)
    feats = F.relu(self.conv1d_params2(feats))
    flat = feats.reshape(feats.shape[0], -1)

    # Dense layer.
    out = self.dense_layer(flat)

    evaluating = not self.training

    # No sigmoid during training because BCEWithLogitsLoss is used.
    if self.classification and evaluating:
        out = self.sigmoid(out)

    if self.multiclass:
        # batch size x num targets x num classes per target
        out = out.reshape((out.size(0), -1, self.multiclass_num_classes))
        # Probabilities only during evaluation; training uses
        # CrossEntropyLoss on the raw values.
        if evaluating:
            out = self.multiclass_softmax(out)

    return out
def forward(self, x, edge_index, edge_weights, batch, debug=False):
    """Two conv + pooling stages, sort pooling (k=10) and two dense layers.

    :param x: nodes' features (Nodes x Features_In)
    :param edge_index: COO-formatted sparse graph edges (2 x Edges)
    :param edge_weights: the weights of corresponding edges
    :param batch: graph labels indicating which nodes belong to which
        graph in the batch (see the PyTorch Geometric docs)
    :param debug: whether to print shapes after each layer or not
    :return: a tensor containing labels for every graph in the batch
    """
    def trace(tensor):
        # Shape logging, active only when debug is requested.
        if debug:
            print(tensor.shape)

    trace(x)

    x = self.sage1(x, edge_index, edge_weights)  # (NxFI, NxN) --> (NxFO)
    trace(x)

    x, edge_index, edge_weights, batch, _, _ = self.pooling1(
        x, edge_index, edge_weights, batch)
    trace(x)

    x = func.dropout(x, training=self.training)
    trace(x)

    x = self.sage2(x, edge_index, edge_weights)
    trace(x)

    # No shape is logged between the second pooling and its dropout.
    x, edge_index, edge_weights, batch, _, _ = self.pooling2(
        x, edge_index, edge_weights, batch)
    x = func.dropout(x, training=self.training)
    trace(x)

    x = gnn.global_sort_pool(x, batch, 10)
    trace(x)

    x = func.relu(self.lin1(x))
    trace(x)

    x = func.relu(self.lin2(x))
    trace(x)

    return x
def forward(self, data):
    """Encode atoms, run the graph-conv stack, pool, and apply the MLP head.

    Args:
        data: Object exposing ``x``, ``edge_index`` and ``batch``.

    Returns:
        Output of ``self.linear2``. When ``self.pool_layer`` is one of
        ``'add'``, ``'mean'``, ``'max'`` or ``'sort'`` the matching
        ``global_*_pool`` is applied first; any other value skips pooling.
    """
    h = self.atom_encoder(data.x)
    edge_index = data.edge_index

    final_idx = len(self.graph_conv_list) - 1
    for idx, conv in enumerate(self.graph_conv_list):
        h = F.relu(conv(h, edge_index))
        # Batch norm sits between layers only; the last layer skips it.
        if idx != final_idx:
            h = self.batchnorm(h)

    pool_kind = self.pool_layer
    if pool_kind == 'add':
        h = global_add_pool(h, data.batch)
    elif pool_kind == 'mean':
        h = global_mean_pool(h, data.batch)
    elif pool_kind == 'max':
        h = global_max_pool(h, data.batch)
    elif pool_kind == 'sort':
        h = global_sort_pool(h, data.batch, self.k)

    hidden = F.relu(self.linear1(h))
    return self.linear2(hidden)
def forward(self, data):
    """Embed the z1/z2/w labels, run the conv stack, sort-pool, and apply the MLP.

    Args:
        data: Object exposing ``z1``, ``z2``, ``w``, ``edge_index``,
            ``batch``, ``edge_weight`` and (when features are used) ``x``.

    Returns:
        ReLU-activated output of ``self.lin2``.
    """
    # NOTE(review): raw features are gated by ``self.args.use_feature``
    # here and by ``self.use_feature`` below -- confirm both always agree.
    x = data.x if self.args.use_feature else None
    z1, z2, w = data.z1, data.z2, data.w
    edge_index, batch = data.edge_index, data.batch
    edge_weight = data.edge_weight  # read but not passed to the convs
    node_id = None  # fixed to None, so the node-embedding branch is inactive

    z1_emb = self.z1_embedding(z1)
    z2_emb = self.z2_embedding(z2)
    w_emb = self.w_embedding(w)
    # Column order of the label embedding: w, z1, z2.
    z_emb = torch.cat([w_emb, z1_emb, z2_emb], 1)

    if self.use_feature and x is not None:
        h = torch.cat([z_emb, x.to(torch.float)], 1)
    else:
        h = z_emb

    if self.node_embedding is not None and node_id is not None:
        h = torch.cat([h, self.node_embedding(node_id)], 1)

    # Keep each layer's tanh output; the input features are not included.
    layer_outputs = []
    for conv in self.convs:
        h = torch.tanh(conv(h, edge_index))
        layer_outputs.append(h)
    h = torch.cat(layer_outputs, dim=-1)

    # Global pooling.
    h = global_sort_pool(h, batch, self.k)
    h = h.unsqueeze(1)  # [num_graphs, 1, k * hidden]
    h = F.relu(self.conv1(h))
    h = self.maxpool1d(h)
    h = F.relu(self.conv2(h))
    h = h.view(h.size(0), -1)  # [num_graphs, dense_dim]

    # MLP.
    h = F.relu(self.lin1(h))
    h = F.dropout(h, p=0.5, training=self.training)
    return F.relu(self.lin2(h))
def forward(self, x, edge_index, batch=None):
    """Conv + batch-norm stack, optional pooling, readout and classifier.

    Args:
        x: Integer node labels; one-hot encoded with
            ``self.config.num_feature_dim`` classes before the first layer.
        edge_index: Connectivity passed to the convs and to ``self.pool1``
            when pooling is enabled.
        batch: Optional batch assignment forwarded to pooling and readout.

    Returns:
        ``(log_probs, attn_weights)`` where ``log_probs`` is the
        log-softmax of the ``self.fc2`` output and ``attn_weights`` is a
        dict holding ``'batch'`` plus, depending on
        ``self.config.pooling_type``, ``'pool_perm'`` and ``'pool_score'``.
    """
    cfg = self.config
    attn_weights = {}

    h = F.one_hot(x, num_classes=cfg.num_feature_dim).float()
    # Only the first num_layers - 1 conv / norm pairs are applied.
    for idx in range(cfg.num_layers - 1):
        h = F.relu(self.gin_convs[idx](h, edge_index))
        h = self.batch_norms[idx](h)

    pooling = cfg.pooling_type
    if pooling in ("sagpool", "topk"):
        # Both poolers are unpacked as a 6-tuple ending in (perm, score).
        h, edge_index, _, batch, perm, score = self.pool1(
            h, edge_index, batch=batch)
        attn_weights['pool_perm'] = perm
        attn_weights['pool_score'] = score
    elif pooling == "asa":
        # This pooler is unpacked as a 5-tuple; no score is recorded.
        h, edge_index, _, batch, perm = self.pool1(
            h, edge_index, batch=batch)
        attn_weights['pool_perm'] = perm

    readout = cfg.readout_type
    if readout == "add":
        h = global_add_pool(h, batch)
    elif readout == "mean":
        h = global_mean_pool(h, batch)
    elif readout == "max":
        h = global_max_pool(h, batch)
    elif readout == "sort":
        h = global_sort_pool(h, batch, k=100)
    # Any other readout type leaves the node-level features untouched.

    attn_weights['batch'] = batch

    h = F.relu(self.fc1(h))
    h = F.dropout(h, p=0.5, training=self.training)
    logits = self.fc2(h)
    return F.log_softmax(logits, dim=-1), attn_weights