def loopy_bp(self, node_feat, edge_feat, n2e_sp, e2e_sp, e2n_sp, subg_sp):
    """Loopy-belief-propagation style embedding with messages on edges.

    Edge messages start from the projected edge features plus the pooled
    projected node features (``n2e_sp``). They are refined ``self.max_lv``
    times by edge-to-edge propagation (``e2e_sp``), each round re-injecting
    the initial message. The result is pooled onto nodes (``e2n_sp``),
    projected by ``self.out_params`` and pooled per graph (``subg_sp``).

    Returns:
        ReLU of the per-graph pooled representation.
    """
    node_proj = self.w_n2l(node_feat)
    base_msg = self.w_e2l(edge_feat) + gnn_spmm(n2e_sp, node_proj)
    messages = F.relu(base_msg)

    # Message refinement; the un-activated initial message is added back each round.
    for _ in range(self.max_lv):
        messages = F.relu(self.conv_params(gnn_spmm(e2e_sp, messages)) + base_msg)

    # Edge -> node pooling, output projection, node -> graph pooling.
    node_hidden = F.relu(gnn_spmm(e2n_sp, messages))
    node_out = F.relu(self.out_params(node_hidden))
    return F.relu(gnn_spmm(subg_sp, node_out))
def mean_field(self, node_feat, edge_feat, n2n_sp, e2n_sp, subg_sp):
    """Mean-field (structure2vec) embedding with messages on nodes.

    The initial node message is the projected node features. When
    ``edge_feat`` is given, the projected edge features pooled onto nodes
    through ``e2n_sp`` are added to it in place. The message is refined
    ``self.max_lv`` times by neighbour aggregation (``n2n_sp``) with the
    initial message re-injected each round. It is then optionally projected
    (``self.output_dim > 0``) and pooled per graph with ``subg_sp``.

    Returns:
        ReLU of the per-graph pooled representation.
    """
    base_msg = self.w_n2l(node_feat)
    if edge_feat is not None:
        base_msg += gnn_spmm(e2n_sp, self.w_e2l(edge_feat))
    hidden = F.relu(base_msg)

    for _ in range(self.max_lv):
        hidden = F.relu(self.conv_params(gnn_spmm(n2n_sp, hidden)) + base_msg)

    node_out = F.relu(self.out_params(hidden)) if self.output_dim > 0 else hidden
    return F.relu(gnn_spmm(subg_sp, node_out))
def node_level_embedding(self, node_feat, edge_feat, n2n_sp, e2n_sp, node_degs):
    """Compute per-node embeddings (no graph-level pooling).

    The initial message is ``bn1(w_n2l(node_feat))``, plus, when
    ``edge_feat`` is given, ``bne1(w_e2l(edge_feat))`` pooled onto nodes via
    ``e2n_sp`` (added in place). Then one of two refinement paths runs:

    * ``self.k_hop_embedding``: ``self.max_block`` multi-hop blocks. Every
      block after the first also receives the initial potential as a residual.
    * otherwise: ``self.max_lv`` structure2vec-style neighbour aggregation
      rounds that re-inject the initial message.

    Returns:
        ``ReLU(out_params(h))`` when ``self.output_dim > 0``, else ``h``.
    """
    # w_n2l also matches the channel count expected by later layers, which is
    # why the author kept it.
    base_msg = self.bn1(self.w_n2l(node_feat))
    if edge_feat is not None:
        edge_msg = self.bne1(self.w_e2l(edge_feat))
        base_msg += gnn_spmm(e2n_sp, edge_msg)
    potential = F.relu(base_msg)

    hidden = potential
    if self.k_hop_embedding:
        for block_idx in range(self.max_block):
            block_in = hidden if block_idx == 0 else hidden + potential
            # NOTE(review): the author questioned whether this ReLU after each
            # block is needed.
            hidden = F.relu(
                self.multi_hop_embedding(block_in, n2n_sp, node_degs, base_msg))
    else:
        # Plain neighbour aggregation, as in structure2vec.
        for _ in range(self.max_lv):
            hidden = F.relu(self.conv_params(gnn_spmm(n2n_sp, hidden)) + base_msg)

    if self.output_dim <= 0:
        return hidden
    return F.relu(self.out_params(hidden))
def sortpooling_embedding(self, node_feat, edge_feat, n2n_sp, e2n_sp, subg_sp, graph_sizes, node_degs):
    """DGCNN graph convolutions followed by SortPooling.

    Each layer computes ``tanh(D^-1 (A + I) X W)``. The outputs of all layers
    are concatenated column-wise. Nodes of every graph are then ranked by the
    last channel, and the top ``self.k`` are kept; graphs with fewer nodes are
    zero-padded.

    Returns:
        Tensor of shape ``(len(graph_sizes), self.k, self.total_latent_dim)``.
    """
    # Optional edge features are pooled onto nodes and appended as channels.
    if edge_feat is not None:
        pooled_edges = gnn_spmm(e2n_sp, self.w_e2l(edge_feat))
        node_feat = torch.cat([node_feat, pooled_edges], 1)

    layer_outputs = []
    hidden = node_feat
    for layer_idx in range(len(self.latent_dim)):
        propagated = gnn_spmm(n2n_sp, hidden) + hidden  # (A + I) X
        hidden = F.tanh(self.conv_params[layer_idx](propagated).div(node_degs))
        layer_outputs.append(hidden)
    features = torch.cat(layer_outputs, 1)

    sort_key = features[:, -1]
    on_gpu = isinstance(node_feat.data, torch.cuda.FloatTensor)
    pooled = torch.zeros(len(graph_sizes), self.k, self.total_latent_dim)
    if on_gpu:
        pooled = pooled.cuda()
    pooled = Variable(pooled)

    offset = 0
    for g in range(subg_sp.size()[0]):
        size = graph_sizes[g]
        keep = min(self.k, size)
        _, order = sort_key[offset:offset + size].topk(keep)
        order += offset  # topk indices are local to this graph
        selected = features.index_select(0, order)
        if keep < self.k:
            filler = torch.zeros(self.k - keep, self.total_latent_dim)
            if on_gpu:
                filler = filler.cuda()
            selected = torch.cat((selected, Variable(filler)), 0)
        pooled[g] = selected
        offset += size
    return pooled
def process_sparse(graph_list, node_feat, edge_feat):
    """Build the batched sparse operators and degree vector for ``graph_list``.

    Args:
        graph_list: graphs exposing ``num_nodes`` and ``degs``. They are also
            passed to ``GNNLIB.PrepareSparseMatrices``.
        node_feat: unused; kept for call compatibility (see note).
        edge_feat: unused; kept for call compatibility (see note).

    Returns:
        ``(graph_sizes, n2n_sp, e2n_sp, subg_sp, node_degs)``. ``node_degs``
        is the concatenated per-node degree + 1 as a column. All tensors are
        moved to CUDA when it is available.

    NOTE(review): the previous version also pooled ``edge_feat`` onto the
    nodes and concatenated it to ``node_feat``, but that result was never
    returned, so the work (and its device handling) was dead. Callers that
    need the concatenation must do it themselves.
    """
    graph_sizes = [g.num_nodes for g in graph_list]
    # +1 presumably matches the self loop of the (A + I) propagation used by
    # the embedding layers — TODO confirm.
    node_degs = torch.cat([torch.Tensor(g.degs) + 1 for g in graph_list]).unsqueeze(1)
    n2n_sp, e2n_sp, subg_sp = GNNLIB.PrepareSparseMatrices(graph_list)

    # Only move to the GPU when one exists (previously this crashed on CPU).
    if torch.cuda.is_available():
        n2n_sp = n2n_sp.cuda()
        e2n_sp = e2n_sp.cuda()
        subg_sp = subg_sp.cuda()
        node_degs = node_degs.cuda()

    return (graph_sizes, Variable(n2n_sp), Variable(e2n_sp),
            Variable(subg_sp), Variable(node_degs))
def attention_gcn(self, node_feat, edge_feat, n2n_sp, e2n_sp, node_degs):
    """Per-node embeddings from stacked multi-hop blocks.

    The initial message is ``bn1(w_n2l(node_feat))``, plus
    ``bne1(w_e2l(edge_feat))`` pooled via ``e2n_sp`` (added in place) when
    edge features are given. ``self.max_block`` calls of
    ``self.multi_hop_embedding`` follow, each followed by a ReLU. Every block
    after the first also receives the initial potential as a residual.

    Returns:
        ``ReLU(out_params(h))`` when ``self.output_dim > 0``, else ``h``.
    """
    base_msg = self.bn1(self.w_n2l(node_feat))
    if edge_feat is not None:
        base_msg += gnn_spmm(e2n_sp, self.bne1(self.w_e2l(edge_feat)))
    potential = F.relu(base_msg)

    hidden = potential
    for block_idx in range(self.max_block):
        block_in = hidden + potential if block_idx else hidden
        hidden = F.relu(
            self.multi_hop_embedding(block_in, n2n_sp, node_degs, base_msg))

    if self.output_dim > 0:
        return F.relu(self.out_params(hidden))
    return hidden
def deepsets_embedding(self, node_feat, edge_feat, n2n_sp, e2n_sp, subg_sp, graph_sizes, node_degs):
    """Run DGCNN graph convolutions and feed zero-padded node sets to ``self.model``.

    Each layer computes ``tanh(D^-1 (A + I) X W)``, and all layer outputs
    are concatenated (width ``self.total_latent_dim``). Each graph's node
    rows are copied into a zero-padded batch of shape
    ``(len(graph_sizes), max(graph_sizes), self.total_latent_dim)``. That
    batch is passed to ``self.model.forward`` (the Deep Sets model).

    Args:
        node_feat: node features of the whole batch; the nodes of one graph
            are assumed to be stored contiguously (the slicing relies on it).
        edge_feat: optional edge features, pooled onto nodes via ``e2n_sp``.
        n2n_sp: batched adjacency matrix.
        e2n_sp: edge-to-node pooling matrix.
        subg_sp: its first dimension is the number of graphs iterated.
        graph_sizes: node count of each graph.
        node_degs: per-node normaliser used for ``D^-1``.

    Returns:
        Whatever ``self.model.forward`` returns for the padded batch.
    """
    if edge_feat is not None:
        e2npool_input = gnn_spmm(e2n_sp, self.w_e2l(edge_feat))
        node_feat = torch.cat([node_feat, e2npool_input], 1)

    cur_message_layer = node_feat
    cat_message_layers = []
    for lv in range(len(self.latent_dim)):
        n2npool = gnn_spmm(n2n_sp, cur_message_layer) + cur_message_layer  # (A + I) X
        normalized_linear = self.conv_params[lv](n2npool).div(node_degs)    # D^-1 X W
        # NOTE: the author suggested trying ReLU instead of tanh here.
        cur_message_layer = torch.tanh(normalized_linear)
        cat_message_layers.append(cur_message_layer)
    cur_message_layer = torch.cat(cat_message_layers, 1)

    # Built on the features' device. The previous torch.cuda.LongTensor index
    # made this block crash whenever CUDA was unavailable.
    max_size = max(graph_sizes)
    batch_graphs = torch.zeros(len(graph_sizes), max_size, self.total_latent_dim,
                               device=cur_message_layer.device)
    accum_count = 0
    for i in range(subg_sp.size()[0]):
        size = graph_sizes[i]
        # Contiguous node rows -> plain slice; rows past `size` stay zero,
        # which is exactly the padding the old code concatenated by hand.
        batch_graphs[i, :size] = cur_message_layer[accum_count:accum_count + size]
        accum_count += size

    # NOTE: .forward is called directly (as before), which bypasses module hooks.
    return self.model.forward(batch_graphs)
def graphConvLayers(self, nodeFeats, edgeFeats, n2nSp, e2nSp, graphSizes, nodeDegs):
    """Apply the stacked graph-convolution layers and concatenate their outputs.

    When ``edgeFeats`` is given, its projection (``self.wE2L``) pooled onto
    nodes via ``e2nSp`` is appended to ``nodeFeats``. Then, for each of the
    ``len(self.latentDims)`` layers, ``Z = tanh(D^-1 (A + I) X W)``.
    ``graphSizes`` is accepted but not used.

    Returns:
        Column-wise concatenation of every layer's output.
    """
    if edgeFeats is not None:
        nodeFeats = torch.cat([nodeFeats, gnn_spmm(e2nSp, self.wE2L(edgeFeats))], 1)

    layerOutputs = []
    hidden = nodeFeats
    for layer in range(len(self.latentDims)):
        aggregated = gnn_spmm(n2nSp, hidden) + hidden  # (A + I) X
        hidden = torch.tanh(self.graphConvParams[layer](aggregated).div(nodeDegs))
        layerOutputs.append(hidden)
    return torch.cat(layerOutputs, 1)
def multi_hop_embedding(self, cur_message_layer, A, node_degs, input_message):
    """Compute ``self.max_k`` propagation hops and mix them with ``self.k_weight``.

    Hop ``s`` applies ``conv_params[s]`` then ``bn2[s]`` to
    ``A @ h_{s-1} + X``, where ``X`` is ``cur_message_layer`` and
    ``h_{-1} = X``. The hop outputs occupy consecutive ``self.latent_dim``-wide
    column slots. The concatenation is multiplied by ``self.k_weight`` and
    passed through ``self.bn3``.

    ``node_degs`` and ``input_message`` are accepted but not used.

    Returns:
        ``bn3((hops @ k_weight).view(n, -1))`` with ``n`` the number of rows.
    """
    n, _ = cur_message_layer.shape
    # Sized from self.latent_dim, the width actually written per hop (it was
    # sized from the input width before). Allocated directly on A's device
    # instead of on the CPU and then copied.
    hops = torch.zeros((n, self.latent_dim * self.max_k), device=A.device)
    hop_input = cur_message_layer
    for step in range(self.max_k):
        # NOTE(review): the added term is always the block input X, not the
        # current hop input, so for step > 0 this is A h + X (a residual to
        # the block input) rather than (A + I) h. Confirm this is intended.
        propagated = gnn_spmm(A, hop_input) + cur_message_layer
        hop_input = self.bn2[step](self.conv_params[step](propagated))
        hops[:, step * self.latent_dim:(step + 1) * self.latent_dim] = hop_input
    return self.bn3(torch.matmul(hops, self.k_weight).view(n, -1))
def graph_level_embedding(self, node_emb, subg_sp, graph_sizes):
    """Pool node embeddings into one vector per graph.

    With ``self.graph_level_attention`` set, multi-head attention pooling is
    used. Node scores are ``bn5(att_params_w2(bn4(tanh(att_params_w1(h)))))``,
    one column per head. Per graph, the softmaxed (and, while training,
    dropped-out) scores weight that graph's node embeddings. The per-head
    results are flattened. Otherwise ``gnn_spmm(subg_sp, node_emb)`` is used
    (an averaging aggregator per the author's comment).

    Args:
        node_emb: node embeddings of the whole batch, graphs stored contiguously.
        subg_sp: its first dimension is the number of graphs iterated.
        graph_sizes: node count of each graph.

    Returns:
        ReLU of the graph embeddings. In the attention case the shape is
        ``(len(graph_sizes), self.multi_h_emb_weight * self.latent_dim)``.
    """
    if not self.graph_level_attention:
        return F.relu(gnn_spmm(subg_sp, node_emb))

    scores = self.bn5(self.att_params_w2(self.bn4(torch.tanh(self.att_params_w1(node_emb)))))
    num_graphs = len(graph_sizes)
    # Matches node_emb's dtype and device.
    graph_emb = node_emb.new_zeros(num_graphs, self.multi_h_emb_weight, self.latent_dim)

    offset = 0
    for i in range(subg_sp.size()[0]):
        size = graph_sizes[i]
        # NOTE(review): softmax runs over the head axis (dim=-1), so each
        # node's weights sum to 1 across heads rather than each head's weights
        # summing to 1 across nodes. Confirm this is intended.
        alpha = F.softmax(scores[offset:offset + size], dim=-1)
        # F.dropout defaults to training=True; without this flag attention
        # dropout was also applied at evaluation time.
        alpha = F.dropout(alpha, self.dropout, training=self.training)
        # (heads, nodes) @ (nodes, latent_dim) -> (heads, latent_dim)
        graph_emb[i] = torch.matmul(alpha.t(), node_emb[offset:offset + size])
        offset += size
    return F.relu(graph_emb.view(num_graphs, -1))
def sortpooling_embedding(self, node_feat, edge_feat, n2n_sp, e2n_sp, subg_sp, graph_sizes, node_degs):
    """Graph U-Net (or DGCNN) node encoder, SortPooling, and 1-D conv readout.

    Node features come from ``self.gUnet`` on the normalised adjacency
    (selected by the hard-coded ``ver = 2``). The alternative DGCNN stack
    ``tanh(D^-1 (A + I) X W)`` is kept behind that switch. Nodes of every
    graph are ranked by their last channel, the top ``self.k`` are kept
    (zero-padded), and the result is flattened and passed through two conv1d
    stages and an optional output projection.

    Returns:
        ReLU of the per-graph dense features.
    """
    if edge_feat is not None:
        edge_to_node = gnn_spmm(e2n_sp, self.w_e2l(edge_feat))
        node_feat = torch.cat([node_feat, edge_to_node], 1)

    adj = ops.normalize_adj(n2n_sp)
    ver = 2  # hard-coded switch: 2 -> Graph U-Net, otherwise DGCNN layers
    if ver != 2:
        stacked = []
        hidden = node_feat
        for lv in range(len(self.latent_dim)):
            agg = gnn_spmm(n2n_sp, hidden) + hidden  # noqa
            hidden = F.tanh(self.conv_params[lv](agg).div(node_degs))
            stacked.append(hidden)
        features = torch.cat(stacked, 1)
    else:
        features = self.gUnet(adj, node_feat)

    # SortPooling on the last channel.
    sort_key = features[:, -1]
    on_gpu = isinstance(node_feat.data, torch.cuda.FloatTensor)
    pooled = torch.zeros(len(graph_sizes), self.k, self.total_latent_dim)
    if on_gpu:
        pooled = pooled.cuda()
    pooled = Variable(pooled)

    offset = 0
    for g in range(subg_sp.size()[0]):
        size = graph_sizes[g]
        keep = min(self.k, size)
        _, order = sort_key[offset:offset + size].topk(keep)
        order += offset
        selected = features.index_select(0, order)
        if keep < self.k:
            filler = torch.zeros(self.k - keep, self.total_latent_dim)
            if on_gpu:
                filler = filler.cuda()
            selected = torch.cat((selected, Variable(filler)), 0)
        pooled[g] = selected
        offset += size

    # 1-D convolution readout over the flattened k x total_latent_dim rows.
    flat = pooled.view((-1, 1, self.k * self.total_latent_dim))
    conv_out = self.maxpool1d(F.relu(self.conv1d_params1(flat)))
    conv_out = F.relu(self.conv1d_params2(conv_out))
    dense = conv_out.view(len(graph_sizes), -1)
    if self.output_dim > 0:
        dense = F.relu(self.out_params(dense))
    return F.relu(dense)
def sortpooling_embedding(self, non_zero, node_feat, n2n_sp, graph_sizes, node_degs):
    """Encode a batch of graphs with GIN or attention-augmented GCN layers.

    Behaviour is driven by the module-level ``args`` namespace:

    * ``args.embedding``: when ``== True``, ``node_feat`` holds labels that
      are first looked up in ``self.embedding``.
    * ``args.model``: ``'gin'`` (GIN layers with ``self.readout``) or one of
      ``'concat'``, ``'separate'``, ``'fusion'``, ``'no-att'``. The latter use
      tanh GCN layers, optional attention, SortPooling and 1-D convolutions.
    * ``args.ff``: fusion rule for ``'fusion'``: ``'max'``, ``'sum'`` or ``'mul'``.
    * ``args.concat``: 0 -> dense features from the second conv1d stage;
      1 -> both conv1d stages concatenated.

    Args:
        non_zero: passed to the attention layers and (for ``'fusion'``) to
            ``SpecialSpmm`` with an ``N x N`` shape. Presumably these are the
            indices of the non-zero adjacency entries — TODO confirm.
        node_feat: node features, or node labels when ``args.embedding``.
        n2n_sp: sparse adjacency matrix covering every graph in the batch.
        graph_sizes: node count of each graph; the nodes of one graph are
            assumed to be stored contiguously (the slicing relies on it).
        node_degs: unused while the D^-1 normalisation stays disabled.

    Returns:
        For ``'gin'``, the readout of the last GIN layer (None when there are
        no layers). Otherwise, the flattened conv1d features, one row per
        graph. Previously these were computed but ``None`` was returned.

    Raises:
        ValueError: unknown ``args.ff`` (with ``'fusion'``) or ``args.concat``.
    """
    N = n2n_sp.size(0)
    batch_size = len(graph_sizes)
    n2n_sp = n2n_sp.cuda()

    if args.embedding == True:  # noqa: E712
        print('embedding')
        node_feat = self.embedding(
            node_feat.type(torch.LongTensor).cuda()).squeeze()

    cur_message_layer = node_feat
    cat_message_layers = []

    if args.model == 'gin':
        # A GIN layer is a two-layer MLP, so it consumes two conv_params entries.
        for lv in range(0, 2 * len(self.latent_dim), 2):
            n2npool = gnn_spmm(n2n_sp, cur_message_layer) + cur_message_layer  # (A + I) X
            node_linear = F.relu(self.conv_params[lv](n2npool))
            cur_message_layer = F.relu(self.conv_params[lv + 1](node_linear))
            # Batch statistics with unit weight and zero bias. This gives the
            # same result as the freshly constructed nn.BatchNorm1d (training
            # mode) previously built on every call, without creating a module
            # per forward pass or round-tripping through the CPU.
            cur_message_layer = F.batch_norm(
                cur_message_layer, None, None, training=True)
            cat_message_layers.append(
                self.readout(cur_message_layer, graph_sizes))
        return cat_message_layers[-1] if cat_message_layers else None

    for lv in range(len(self.latent_dim)):
        # n2n_sp is the adjacency matrix of all graphs in the batch.
        n2npool = gnn_spmm(n2n_sp, cur_message_layer) + cur_message_layer  # (A + I) X
        # D^-1 normalisation is currently disabled (was commented out).
        cur_message_layer = torch.tanh(self.conv_params[lv](n2npool))  # (N, F')
        cat_message_layers.append(cur_message_layer)

    if args.model == 'concat':
        # Attention over the concatenation of all layer outputs.
        cur_message_layer = torch.cat(cat_message_layers, 1)
        cur_message_layer = torch.cat(
            [att(non_zero, cur_message_layer)[0] for att in self.attention],
            dim=1)  # (total_node, sum(latent_dim))
    elif args.model == 'separate':
        # Layer outputs stacked along the node axis and attended per layer.
        cur_message_layer = torch.cat(cat_message_layers, 0)
        cur_message_layer = torch.cat(
            [att(non_zero, cur_message_layer)[0] for att in self.attention],
            dim=1)  # (total_node * k, latent_dim)
    elif args.model == 'fusion':
        # Per layer: average the attention coefficients of all heads.
        a = []
        for f in cat_message_layers:
            tmp = torch.cat(
                [att(non_zero, f)[1].view(-1, 1) for att in self.attention], dim=1)
            a.append((torch.sum(tmp, 1) / len(self.attention)).view(-1, 1))
        # Fuse the per-layer coefficients.
        if args.ff == 'max':
            # torch.max with a dim returns (values, indices); keep the values.
            a_r = torch.max(torch.cat(a, dim=1), 1)[0]
        elif args.ff == 'sum':
            a_r = torch.sum(torch.cat(a, dim=1), 1)
        elif args.ff == 'mul':
            a_r = a[0]
            for coeff in a[1:]:
                a_r = torch.mul(a_r, coeff)
        else:
            raise ValueError('unknown args.ff: {!r}'.format(args.ff))
        # M = A X with the fused coefficients as edge weights.
        a_r = a_r.view(-1)
        special_spmm = SpecialSpmm()
        non_zero = torch.LongTensor(non_zero)
        cur_message_layer = special_spmm(
            non_zero, a_r, torch.Size([N, N]),
            cat_message_layers[-1])  # (total_node, latent_dim)
    elif args.model == 'no-att':
        # Plain GNN: concatenate the layers, no attention.
        cur_message_layer = torch.cat(cat_message_layers, 1)

    # SortPooling: sort by the last channel only; every graph keeps self.k nodes.
    sort_channel = cur_message_layer[:, -1]
    on_gpu = isinstance(node_feat.data, torch.cuda.FloatTensor)
    batch_sortpooling_graphs = torch.zeros(batch_size, self.k, self.att_out_size)
    if on_gpu:
        batch_sortpooling_graphs = batch_sortpooling_graphs.cuda()
    accum_count = 0
    for i in range(batch_size):
        to_sort = sort_channel[accum_count:accum_count + graph_sizes[i]]
        k = min(self.k, graph_sizes[i])  # below only padding has to be decided
        _, topk_indices = to_sort.topk(k)
        # topk indices are relative to this graph; shift them to batch rows.
        topk_indices += accum_count
        sortpooling_graph = cur_message_layer.index_select(0, topk_indices).cuda()
        if k < self.k:
            to_pad = torch.zeros(self.k - k, self.att_out_size)
            if on_gpu:
                to_pad = to_pad.cuda()
            sortpooling_graph = torch.cat((sortpooling_graph, to_pad), 0)
        batch_sortpooling_graphs[i] = sortpooling_graph
        accum_count += graph_sizes[i]

    # 1-D convolutions over (num_graphs, 1, k * att_out_size).
    res = []
    to_conv1d = batch_sortpooling_graphs.view((-1, 1, self.k * self.att_out_size))
    conv1d_res = self.maxpool1d(F.relu(self.conv1d_params1(to_conv1d)))
    res.append(conv1d_res.view(batch_size, -1))
    conv1d_res = F.relu(self.conv1d_params2(conv1d_res))
    res.append(conv1d_res.view(batch_size, -1))

    if args.concat == 0:
        return conv1d_res.view(batch_size, -1)
    if args.concat == 1:
        return torch.cat(res, 1)
    raise ValueError('unknown args.concat: {!r}'.format(args.concat))
def sortpooling_embedding(self, node_feat, edge_feat, n2n_sp, e2n_sp, subg_sp, graph_sizes, node_degs):
    """GCN encoder, per-graph CRF pooling, and 1-D conv readout.

    Two GCN stacks of ``tanh(D^-1 (A + I) h W)`` layers run back to back:
    ``conv_params`` (``len(self.latent_dim)`` layers) gives the node features
    X, and ``conv_params_p`` (``len(self.latent_dim2)`` layers, applied to X)
    gives U(X). For every graph, ``self.Pool`` receives the graph's dense
    adjacency block, its X rows and its U(X) rows. The pooled node features
    (``self.k`` x ``self.last_dim`` per graph) are flattened and passed
    through two conv1d stages and an optional output projection. The pooled
    adjacency returned by ``self.Pool`` is stored but not used further.

    Returns:
        ReLU of the per-graph dense features.
    """
    if edge_feat is not None:
        pooled_edges = gnn_spmm(e2n_sp, self.w_e2l(edge_feat))
        node_feat = torch.cat([node_feat, pooled_edges], 1)

    def gcn_stack(h, weights, depth):
        # `depth` rounds of h <- tanh(D^-1 (A + I) h W_idx).
        for idx in range(depth):
            h = torch.tanh(weights[idx](gnn_spmm(n2n_sp, h) + h).div(node_degs))
        return h

    # X; concatenating the per-layer outputs is an option the author left disabled.
    feats = gcn_stack(node_feat, self.conv_params, len(self.latent_dim))
    # U(X), the input for CRF pooling.
    unary = gcn_stack(feats, self.conv_params_p, len(self.latent_dim2))

    num_graphs = len(graph_sizes)
    pooled_feats = torch.zeros(num_graphs, self.k, self.last_dim)
    pooled_adjs = torch.zeros(num_graphs, self.k, self.k)
    if torch.cuda.is_available() and isinstance(node_feat.data, torch.cuda.FloatTensor):
        pooled_feats = pooled_feats.cuda()
        pooled_adjs = pooled_adjs.cuda()
    pooled_feats = Variable(pooled_feats)
    pooled_adjs = Variable(pooled_adjs)

    dense_adj = self.sparse_to_dense(n2n_sp)
    start = 0
    for g in range(subg_sp.size()[0]):
        stop = start + graph_sizes[g]
        x_out, a_out = self.Pool(dense_adj[start:stop, start:stop],
                                 feats[start:stop, :],
                                 unary[start:stop, :])
        pooled_feats[g] = x_out
        pooled_adjs[g] = a_out
        start = stop

    # 1-D convolution readout over (num_graphs, 1, k * last_dim).
    flat = pooled_feats.view((-1, 1, self.k * self.last_dim))
    conv_out = self.maxpool1d(F.relu(self.conv1d_params1(flat)))
    conv_out = F.relu(self.conv1d_params2(conv_out))
    dense = conv_out.view(len(graph_sizes), -1)
    if self.output_dim > 0:
        dense = F.relu(self.out_params(dense))
    return F.relu(dense)
def sortpooling_embedding(self, node_feat, edge_feat, n2f_sp, f2n_sp, subhg_sp, hypergraph_sizes, node_hdegs, hyperedge_sizes):
    """Hypergraph convolutions, SortPooling, and 1-D conv readout.

    ``edge_feat`` is overridden with None on entry, so hyperedge features are
    currently ignored and the edge branch never runs.

    Each of the ``len(self.latent_dim)`` layers uses two ``conv_params``
    entries (the math in the author's notation):
        Y = tanh(sizes^-1 * (S^T X) W)    (f2n_sp, divided by hyperedge_sizes)
        Z = tanh(D_f^-1 * (S Y) W')       (n2f_sp, divided by node_hdegs)
    The Z of every layer is concatenated. Nodes are ranked by the last
    channel, the top ``self.k`` per hypergraph are kept (zero-padded), and the
    result goes through two conv1d stages and an optional projection, using
    ``self.conv1d_activation`` throughout.

    Returns:
        ``conv1d_activation`` of the per-hypergraph dense features.
    """
    edge_feat = None  # hyperedge features deliberately disabled
    if edge_feat is not None:
        e2npool_input = gnn_spmm(n2f_sp, self.w_e2l(edge_feat))
        node_feat = torch.cat([node_feat, e2npool_input], 1)

    layer_outputs = []
    hidden = node_feat
    for lv in range(0, 2 * len(self.latent_dim), 2):
        edge_side = gnn_spmm(f2n_sp, hidden)  # S^T X
        edge_side = torch.tanh(self.conv_params[lv](edge_side).div(hyperedge_sizes))
        node_side = gnn_spmm(n2f_sp, edge_side)  # S Y
        hidden = torch.tanh(self.conv_params[lv + 1](node_side).div(node_hdegs))
        layer_outputs.append(hidden)
    features = torch.cat(layer_outputs, 1)

    sort_key = features[:, -1]
    pooled = torch.zeros(len(hypergraph_sizes), self.k, self.total_latent_dim)
    if torch.cuda.is_available() and isinstance(node_feat.data, torch.cuda.FloatTensor):
        pooled = pooled.cuda()
    pooled = Variable(pooled)

    offset = 0
    for g in range(subhg_sp.size()[0]):
        size = hypergraph_sizes[g]
        keep = min(self.k, size)
        _, order = sort_key[offset:offset + size].topk(keep)
        order += offset
        selected = features.index_select(0, order)
        if keep < self.k:
            filler = torch.zeros(self.k - keep, self.total_latent_dim)
            if torch.cuda.is_available() and isinstance(node_feat.data, torch.cuda.FloatTensor):
                filler = filler.cuda()
            selected = torch.cat((selected, Variable(filler)), 0)
        pooled[g] = selected
        offset += size

    act = self.conv1d_activation
    flat = pooled.view((-1, 1, self.k * self.total_latent_dim))
    conv_out = self.maxpool1d(act(self.conv1d_params1(flat)))
    conv_out = act(self.conv1d_params2(conv_out))
    dense = conv_out.view(len(hypergraph_sizes), -1)
    if self.output_dim > 0:
        dense = act(self.out_params(dense))
    return act(dense)