Example #1
    def loopy_bp(self, node_feat, edge_feat, n2e_sp, e2e_sp, e2n_sp, subg_sp):
        input_node_linear = self.w_n2l(node_feat)
        input_edge_linear = self.w_e2l(edge_feat)

        n2epool_input = gnn_spmm(n2e_sp, input_node_linear)
        
        input_message = input_edge_linear + n2epool_input
        input_potential = F.relu(input_message)

        lv = 0
        cur_message_layer = input_potential
        while lv < self.max_lv:
            e2epool = gnn_spmm(e2e_sp, cur_message_layer)
            edge_linear = self.conv_params(e2epool)                    
            merged_linear = edge_linear + input_message

            cur_message_layer = F.relu(merged_linear)
            lv += 1

        e2npool = gnn_spmm(e2n_sp, cur_message_layer)
        hidden_msg = F.relu(e2npool)
        out_linear = self.out_params(hidden_msg)
        reluact_fp = F.relu(out_linear)

        y_potential = gnn_spmm(subg_sp, reluact_fp)

        return F.relu(y_potential)
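Note: gnn_spmm is used in every example here but is not defined in these snippets. A minimal sketch of what it computes, assuming the first argument is a torch sparse COO tensor (the original helper wraps this in a custom autograd Function):

import torch

def gnn_spmm(sp_mat, dense_mat):
    # sparse (N x M) matrix times dense (M x F) matrix -> dense (N x F)
    return torch.sparse.mm(sp_mat, dense_mat)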
Example #2
    def mean_field(self, node_feat, edge_feat, n2n_sp, e2n_sp, subg_sp):
        input_node_linear = self.w_n2l(node_feat)
        input_message = input_node_linear
        if edge_feat is not None:
            input_edge_linear = self.w_e2l(edge_feat)
            e2npool_input = gnn_spmm(e2n_sp, input_edge_linear)
            input_message += e2npool_input
        input_potential = F.relu(input_message)

        lv = 0
        cur_message_layer = input_potential
        while lv < self.max_lv:
            n2npool = gnn_spmm(n2n_sp, cur_message_layer)
            node_linear = self.conv_params(n2npool)
            merged_linear = node_linear + input_message

            cur_message_layer = F.relu(merged_linear)
            lv += 1
        if self.output_dim > 0:
            out_linear = self.out_params(cur_message_layer)
            reluact_fp = F.relu(out_linear)
        else:
            reluact_fp = cur_message_layer
            
        y_potential = gnn_spmm(subg_sp, reluact_fp)

        return F.relu(y_potential)
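For reference, the loop above is the structure2vec mean-field recursion: each level aggregates neighbor states through n2n_sp, applies conv_params, adds the input message back, and re-applies ReLU. A dense, purely illustrative restatement (A and conv stand in for n2n_sp and self.conv_params):

import torch
import torch.nn.functional as F

def mean_field_update(A, input_message, conv, max_lv):
    # A: dense (N x N) adjacency, input_message: (N x F), conv: nn.Linear F -> F
    h = F.relu(input_message)
    for _ in range(max_lv):
        pooled = A @ h                             # aggregate neighbor states
        h = F.relu(conv(pooled) + input_message)   # combine with the input message
    return h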
Example #3
    def node_level_embedding(self, node_feat, edge_feat, n2n_sp, e2n_sp,
                             node_degs):
        input_node_linear = self.w_n2l(
            node_feat
        )  # Note: this layer could in principle be removed, but it matches the channel dimensions, so it is hard to drop
        input_node_linear = self.bn1(input_node_linear)

        input_message = input_node_linear
        if edge_feat is not None:
            input_edge_linear = self.w_e2l(edge_feat)
            input_edge_linear = self.bne1(input_edge_linear)

            e2npool_input = gnn_spmm(e2n_sp, input_edge_linear)
            input_message += e2npool_input
        input_potential = F.relu(input_message)

        if self.k_hop_embedding:
            block = 0
            cur_message_layer = input_potential
            A = n2n_sp  # .to_dense() # 1 hop information

            while block < self.max_block:
                if block == 0:
                    block_input = cur_message_layer
                else:
                    block_input = cur_message_layer + input_potential
                h = self.multi_hop_embedding(block_input, A, node_degs,
                                             input_message)
                h = F.relu(h)  # FIXME: the ReLU after each block may be unnecessary
                cur_message_layer = h
                block += 1

        else:  # simply aggregate the node features from neighbors, as in structure2vec
            lv = 0
            cur_message_layer = input_potential
            while lv < self.max_lv:
                n2npool = gnn_spmm(n2n_sp, cur_message_layer)
                node_linear = self.conv_params(n2npool)  # layer 3
                merged_linear = node_linear + input_message

                cur_message_layer = F.relu(merged_linear)
                lv += 1

        if self.output_dim > 0:
            out_linear = self.out_params(cur_message_layer)
            reluact_fp = F.relu(out_linear)
        else:
            reluact_fp = cur_message_layer

        return reluact_fp
Example #4
    def sortpooling_embedding(self, node_feat, edge_feat, n2n_sp, e2n_sp,
                              subg_sp, graph_sizes, node_degs):
        ''' if exists edge feature, concatenate to node feature vector '''
        if edge_feat is not None:
            input_edge_linear = self.w_e2l(edge_feat)
            e2npool_input = gnn_spmm(e2n_sp, input_edge_linear)
            node_feat = torch.cat([node_feat, e2npool_input], 1)
        ''' graph convolution layers '''
        lv = 0
        cur_message_layer = node_feat
        cat_message_layers = []
        while lv < len(self.latent_dim):
            n2npool = gnn_spmm(
                n2n_sp,
                cur_message_layer) + cur_message_layer  # Y = (A + I) * X
            node_linear = self.conv_params[lv](n2npool)  # Y = Y * W
            normalized_linear = node_linear.div(node_degs)  # Y = D^-1 * Y
            cur_message_layer = torch.tanh(normalized_linear)
            cat_message_layers.append(cur_message_layer)
            lv += 1

        cur_message_layer = torch.cat(cat_message_layers, 1)
        ''' sortpooling layer '''
        sort_channel = cur_message_layer[:, -1]
        batch_sortpooling_graphs = torch.zeros(len(graph_sizes), self.k,
                                               self.total_latent_dim)
        if isinstance(node_feat.data, torch.cuda.FloatTensor):
            batch_sortpooling_graphs = batch_sortpooling_graphs.cuda()

        batch_sortpooling_graphs = Variable(batch_sortpooling_graphs)
        accum_count = 0
        for i in range(subg_sp.size()[0]):
            to_sort = sort_channel[accum_count:accum_count + graph_sizes[i]]
            k = self.k if self.k <= graph_sizes[i] else graph_sizes[i]
            _, topk_indices = to_sort.topk(k)
            topk_indices += accum_count
            sortpooling_graph = cur_message_layer.index_select(0, topk_indices)
            if k < self.k:
                to_pad = torch.zeros(self.k - k, self.total_latent_dim)
                if isinstance(node_feat.data, torch.cuda.FloatTensor):
                    to_pad = to_pad.cuda()

                to_pad = Variable(to_pad)
                sortpooling_graph = torch.cat((sortpooling_graph, to_pad), 0)
            batch_sortpooling_graphs[i] = sortpooling_graph
            accum_count += graph_sizes[i]

        return batch_sortpooling_graphs
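The top-k selection and zero padding above can be sanity-checked on a toy tensor (shapes are illustrative, not taken from the original model):

import torch

k, latent = 3, 4
feats = torch.randn(2, latent)                       # a graph with only 2 nodes
_, idx = feats[:, -1].topk(min(k, feats.size(0)))    # sort on the last channel
pooled = feats.index_select(0, idx)
if pooled.size(0) < k:                               # pad with zero rows up to k
    pooled = torch.cat([pooled, torch.zeros(k - pooled.size(0), latent)], 0)
print(pooled.shape)                                  # torch.Size([3, 4])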
Example #5
def process_sparse(graph_list, node_feat, edge_feat):
    graph_sizes = [graph_list[i].num_nodes for i in range(len(graph_list))]
    node_degs = [
        torch.Tensor(graph_list[i].degs) + 1 for i in range(len(graph_list))
    ]
    node_degs = torch.cat(node_degs).unsqueeze(1)
    n2n_sp, e2n_sp, subg_sp = GNNLIB.PrepareSparseMatrices(graph_list)

    n2n_sp = n2n_sp.cuda()
    e2n_sp = e2n_sp.cuda()
    subg_sp = subg_sp.cuda()
    node_degs = node_degs.cuda()

    node_feat = Variable(node_feat)
    if edge_feat is not None:
        edge_feat = Variable(edge_feat)
        if torch.cuda.is_available() and isinstance(node_feat,
                                                    torch.cuda.FloatTensor):
            edge_feat = edge_feat.cuda()
    n2n_sp = Variable(n2n_sp)
    e2n_sp = Variable(e2n_sp)
    subg_sp = Variable(subg_sp)
    node_degs = Variable(node_degs)

    if edge_feat is not None:
        input_edge_linear = edge_feat
        e2npool_input = gnn_spmm(e2n_sp, input_edge_linear)
        node_feat = torch.cat([node_feat, e2npool_input], 1)
    return graph_sizes, n2n_sp, e2n_sp, subg_sp, node_degs
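GNNLIB.PrepareSparseMatrices comes from the compiled s2v_lib backend and is not shown here. A pure-PyTorch sketch of what the node-to-node and subgraph pooling matrices could look like for a batch, purely as an illustration (the real library also builds the edge-to-node matrix and has its own ordering conventions):

import torch

def toy_sparse_matrices(edge_lists, graph_sizes):
    # edge_lists: one list of (u, v) pairs per graph, node ids local to each graph
    total = sum(graph_sizes)
    offsets = [0]
    for n in graph_sizes[:-1]:
        offsets.append(offsets[-1] + n)
    rows, cols = [], []
    for off, edges in zip(offsets, edge_lists):
        for u, v in edges:
            rows += [off + u, off + v]
            cols += [off + v, off + u]
    n2n_sp = torch.sparse_coo_tensor([rows, cols], torch.ones(len(rows)),
                                     (total, total))
    # subg_sp sums the node rows belonging to each graph into one row per graph
    g_rows = [i for i, n in enumerate(graph_sizes) for _ in range(n)]
    subg_sp = torch.sparse_coo_tensor([g_rows, list(range(total))],
                                      torch.ones(total),
                                      (len(graph_sizes), total))
    return n2n_sp, subg_sp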
Example #6
    def attention_gcn(self, node_feat, edge_feat, n2n_sp, e2n_sp, node_degs):
        input_node_linear = self.w_n2l(node_feat)
        input_node_linear = self.bn1(input_node_linear)

        input_message = input_node_linear
        if edge_feat is not None:
            input_edge_linear = self.w_e2l(edge_feat)
            input_edge_linear = self.bne1(input_edge_linear)

            e2npool_input = gnn_spmm(e2n_sp, input_edge_linear)
            input_message += e2npool_input
        input_potential = F.relu(input_message)

        block = 0
        cur_message_layer = input_potential
        A = n2n_sp

        while block < self.max_block:
            if block == 0:
                block_input = cur_message_layer
            else:
                block_input = cur_message_layer + input_potential
            h = self.multi_hop_embedding(block_input, A, node_degs,
                                         input_message)
            h = F.relu(h)
            cur_message_layer = h
            block += 1

        if self.output_dim > 0:
            out_linear = self.out_params(cur_message_layer)
            reluact_fp = F.relu(out_linear)
        else:
            reluact_fp = cur_message_layer

        return reluact_fp
Example #7
    def deepsets_embedding(self, node_feat, edge_feat, n2n_sp, e2n_sp, subg_sp, graph_sizes, node_degs):
        ''' if exists edge feature, concatenate to node feature vector '''
        if edge_feat is not None:
            input_edge_linear = self.w_e2l(edge_feat)
            e2npool_input = gnn_spmm(e2n_sp, input_edge_linear)
            node_feat = torch.cat([node_feat, e2npool_input], 1)

        ''' graph convolution layers '''
        lv = 0
        cur_message_layer = node_feat
        cat_message_layers = []
        while lv < len(self.latent_dim):
            #original=Variable(cur_message_layer)
            n2npool = gnn_spmm(n2n_sp, cur_message_layer) + cur_message_layer  # Y = (A + I) * X
            node_linear = self.conv_params[lv](n2npool)  # Y = Y * W
            normalized_linear = node_linear.div(node_degs)  # Y = D^-1 * Y
            cur_message_layer = torch.tanh(normalized_linear)  # note: could possibly be substituted with relu
            cat_message_layers.append(cur_message_layer)
            lv += 1

        cur_message_layer = torch.cat(cat_message_layers, 1)

        ''' padding layer: pad every graph to the maximum graph size (no sorting here) '''
        max_size = max(graph_sizes)
        batch_sortpooling_graphs = torch.zeros(len(graph_sizes), max_size, self.total_latent_dim)
        if isinstance(node_feat.data, torch.cuda.FloatTensor):
            batch_sortpooling_graphs = batch_sortpooling_graphs.cuda()

        batch_sortpooling_graphs = Variable(batch_sortpooling_graphs)
        accum_count = 0
        for i in range(subg_sp.size()[0]):
            R = torch.cuda.LongTensor(range(accum_count, accum_count + graph_sizes[i]))
            k = graph_sizes[i]
            sortpooling_graph = cur_message_layer.index_select(0, R)
            if k < max_size:
                to_pad = torch.zeros(max_size - k, self.total_latent_dim)
                if isinstance(node_feat.data, torch.cuda.FloatTensor):
                    to_pad = to_pad.cuda()

                to_pad = Variable(to_pad)
                sortpooling_graph = torch.cat((sortpooling_graph, to_pad), 0)
            batch_sortpooling_graphs[i] = sortpooling_graph
            accum_count += graph_sizes[i]

        # call to DeepSets
        return self.model.forward(batch_sortpooling_graphs)
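self.model here is the DeepSets network applied to the padded per-graph node sets. A minimal, hedged sketch of such a permutation-invariant model (layer sizes are placeholders, not the original configuration):

import torch
import torch.nn as nn

class DeepSets(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.phi = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        self.rho = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        # x: (num_graphs, max_size, in_dim), zero-padded as above
        h = self.phi(x)       # encode each node independently
        h = h.sum(dim=1)      # permutation-invariant sum over the node axis
        return self.rho(h)    # one output vector per graph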
Example #8
    def graphConvLayers(self, nodeFeats, edgeFeats, n2nSp, e2nSp,
                        graphSizes, nodeDegs):
        """graph convolution layers"""
        # if exists edge feature, concatenate to node feature vector
        if edgeFeats is not None:
            inputEdgeLinear = self.wE2L(edgeFeats)
            e2nPool = gnn_spmm(e2nSp, inputEdgeLinear)
            nodeFeats = torch.cat([nodeFeats, e2nPool], 1)

        lv = 0
        currMsgLayer = nodeFeats
        msgLayers = []
        while lv < len(self.latentDims):
            # Y = (A + I) * X
            n2npool = gnn_spmm(n2nSp, currMsgLayer) + currMsgLayer
            nodeLinear = self.graphConvParams[lv](n2npool)  # Y = Y * W
            normalizedLinear = nodeLinear.div(nodeDegs)  # Y = D^-1 * Y
            currMsgLayer = torch.tanh(normalizedLinear)
            msgLayers.append(currMsgLayer)
            lv += 1

        return torch.cat(msgLayers, 1)
Example #9
    def multi_hop_embedding(self, cur_message_layer, A, node_degs,
                            input_message):
        step = 0
        input_x = cur_message_layer
        n, m = cur_message_layer.shape
        result = torch.zeros((n, m * self.max_k)).to(A.device)
        while step < self.max_k:
            n2npool = gnn_spmm(A,
                               input_x) + cur_message_layer  # Y = (A + I) * X
            input_x = self.conv_params[step](n2npool)  # Y = Y * W
            input_x = self.bn2[step](input_x)
            result[:, step * self.latent_dim:(step + 1) * self.latent_dim] = input_x
            step += 1

        return self.bn3(torch.matmul(result, self.k_weight).view(n, -1))
Example #10
    def graph_level_embedding(self, node_emb, subg_sp, graph_sizes):

        if self.graph_level_attention:
            # Attention layer for nodes
            atten_layer1 = self.bn4(torch.tanh(self.att_params_w1(node_emb)))
            atten_layer2 = self.bn5(self.att_params_w2(atten_layer1))

            graph_emb = torch.zeros(len(graph_sizes), self.multi_h_emb_weight,
                                    self.latent_dim)
            graph_emb = graph_emb.to(node_emb.device)
            graph_emb = Variable(graph_emb)

            accum_count = 0
            for i in range(subg_sp.size()[0]):
                alpha = atten_layer2[
                    accum_count:accum_count +
                    graph_sizes[i]]  # nodes of a single graph
                # alpha = self.leakyrelu(alpha)
                alpha = F.softmax(
                    alpha, dim=-1
                )  # softmax for normalization    bs[32, 8] --> [node_num, multi_head_channel]

                alpha = F.dropout(alpha, self.dropout)
                alpha = alpha.t()

                input_before = node_emb[
                    accum_count:accum_count +
                    graph_sizes[i]]  #vs[32, 64] --> [node_num, latent_dim]
                emb_g = torch.matmul(
                    alpha, input_before
                )  # attention: alpha * h, a single row   bs[8,64] -->> [multi_head_channel, latent_dim]

                # emb_g = emb_g.view(1, -1)
                graph_emb[
                    i] = emb_g  # bs[graph_num, multi_head_channel, latent_dim]
                accum_count += graph_sizes[i]

            y_potential = graph_emb.view(len(graph_sizes), -1)

        else:  # average aggregator
            y_potential = gnn_spmm(subg_sp, node_emb)

        return F.relu(y_potential)
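A toy shape check for the per-graph attention pooling above, following the shapes noted in the comments (8 attention heads, latent_dim 64):

import torch
import torch.nn.functional as F

num_nodes, heads, latent = 32, 8, 64
alpha = F.softmax(torch.randn(num_nodes, heads), dim=-1).t()  # (heads, num_nodes)
node_emb = torch.randn(num_nodes, latent)
emb_g = alpha @ node_emb   # (heads, latent): one pooled vector per attention head
print(emb_g.shape)         # torch.Size([8, 64])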
Example #11
    def sortpooling_embedding(self, node_feat, edge_feat, n2n_sp, e2n_sp,
                              subg_sp, graph_sizes, node_degs):
        ''' if exists edge feature, concatenate to node feature vector '''
        if edge_feat is not None:
            input_edge_linear = self.w_e2l(edge_feat)
            e2npool_input = gnn_spmm(e2n_sp, input_edge_linear)
            node_feat = torch.cat([node_feat, e2npool_input], 1)
        ''' graph convolution layers '''
        A = ops.normalize_adj(n2n_sp)

        ver = 2

        if ver == 2:
            cur_message_layer = self.gUnet(A, node_feat)
        else:
            lv = 0
            cur_message_layer = node_feat
            cat_message_layers = []
            while lv < len(self.latent_dim):
                n2npool = gnn_spmm(
                    n2n_sp, cur_message_layer) + cur_message_layer  # noqa
                node_linear = self.conv_params[lv](n2npool)  # Y = Y * W
                normalized_linear = node_linear.div(node_degs)  # Y = D^-1 * Y
                cur_message_layer = torch.tanh(normalized_linear)
                cat_message_layers.append(cur_message_layer)
                lv += 1

            cur_message_layer = torch.cat(cat_message_layers, 1)
        ''' sortpooling layer '''
        sort_channel = cur_message_layer[:, -1]
        batch_sortpooling_graphs = torch.zeros(len(graph_sizes), self.k,
                                               self.total_latent_dim)
        if isinstance(node_feat.data, torch.cuda.FloatTensor):
            batch_sortpooling_graphs = batch_sortpooling_graphs.cuda()

        batch_sortpooling_graphs = Variable(batch_sortpooling_graphs)
        accum_count = 0
        for i in range(subg_sp.size()[0]):
            to_sort = sort_channel[accum_count:accum_count + graph_sizes[i]]
            k = self.k if self.k <= graph_sizes[i] else graph_sizes[i]
            _, topk_indices = to_sort.topk(k)
            topk_indices += accum_count
            sortpooling_graph = cur_message_layer.index_select(0, topk_indices)
            if k < self.k:
                to_pad = torch.zeros(self.k - k, self.total_latent_dim)
                if isinstance(node_feat.data, torch.cuda.FloatTensor):
                    to_pad = to_pad.cuda()

                to_pad = Variable(to_pad)
                sortpooling_graph = torch.cat((sortpooling_graph, to_pad), 0)
            batch_sortpooling_graphs[i] = sortpooling_graph
            accum_count += graph_sizes[i]
        ''' traditional 1d convolution and dense layers '''
        to_conv1d = batch_sortpooling_graphs.view(
            (-1, 1, self.k * self.total_latent_dim))
        conv1d_res = self.conv1d_params1(to_conv1d)
        conv1d_res = F.relu(conv1d_res)
        conv1d_res = self.maxpool1d(conv1d_res)
        conv1d_res = self.conv1d_params2(conv1d_res)
        conv1d_res = F.relu(conv1d_res)

        to_dense = conv1d_res.view(len(graph_sizes), -1)

        if self.output_dim > 0:
            out_linear = self.out_params(to_dense)
            reluact_fp = F.relu(out_linear)
        else:
            reluact_fp = to_dense

        return F.relu(reluact_fp)
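ops.normalize_adj is taken from the Graph U-Net code and is not shown in this snippet. A hedged dense sketch of a typical symmetric normalization it could perform (not the verified original implementation):

import torch

def normalize_adj_dense(adj):
    # D^{-1/2} (A + I) D^{-1/2} on a dense adjacency, as in standard GCNs
    adj = adj + torch.eye(adj.size(0), device=adj.device)
    d_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    return d_inv_sqrt.unsqueeze(1) * adj * d_inv_sqrt.unsqueeze(0)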
Example #12
    def sortpooling_embedding(self, non_zero, node_feat, n2n_sp, graph_sizes,
                              node_degs):
        ''' graph convolution layers '''
        lv = 0
        N = n2n_sp.size(0)
        batch_size = len(graph_sizes)
        n2n_sp = n2n_sp.cuda()
        node_degs = node_degs.cuda()
        if args.embedding:
            print('embedding')
            node_feat = self.embedding(
                node_feat.type(torch.LongTensor).cuda()).squeeze()
        cur_message_layer = node_feat
        cat_message_layers = []
        output = None

        if args.model == 'gin':
            while lv < 2 * len(self.latent_dim):
                n2npool = gnn_spmm(
                    n2n_sp,
                    cur_message_layer) + cur_message_layer  # Y = (A + I) * X
                # if lv==0:
                #     cur_message_layer = (nn.BatchNorm1d(n2npool.size()[1])(n2npool.cpu())).cuda()
                #     readout_features = self.readout(cur_message_layer, graph_sizes)
                #     cat_message_layers.append(readout_features)
                #     lv += 1
                # else:
                node_linear = self.conv_params[lv](n2npool)
                node_linear = F.relu(node_linear)
                cur_message_layer = self.conv_params[lv + 1](node_linear)
                cur_message_layer = F.relu(cur_message_layer)
                cur_message_layer = (nn.BatchNorm1d(
                    cur_message_layer.size()[1])(
                        cur_message_layer.cpu())).cuda()
                readout_features = self.readout(cur_message_layer, graph_sizes)
                output = readout_features
                cat_message_layers.append(readout_features)
                lv += 2
        else:
            while lv < len(self.latent_dim):
                # n2n_sp is the adjacency matrix of the whole batch (block-diagonal over all graphs)
                n2npool = gnn_spmm(
                    n2n_sp,
                    cur_message_layer) + cur_message_layer  # Y = (A + I) * X
                node_linear = self.conv_params[lv](
                    n2npool)  # Y = Y * W => shape N * F'
                normalized_linear = node_linear
                # normalized_linear = node_linear.div(node_degs)  # Y = D^-1 * Y
                cur_message_layer = torch.tanh(normalized_linear)
                cat_message_layers.append(cur_message_layer)
                lv += 1

        # Attention in concat feature
        if args.model == 'concat':
            cur_message_layer = torch.cat(cat_message_layers, 1)
            cur_message_layer = torch.cat([
                att(non_zero, cur_message_layer)[0] for att in self.attention
            ],
                                          dim=1)
            # (total_node, sum(latent_dim))

        # Separate attention each layer
        elif args.model == 'separate':
            cur_message_layer = torch.cat(cat_message_layers, 0)
            cur_message_layer = torch.cat([
                att(non_zero, cur_message_layer)[0] for att in self.attention
            ],
                                          dim=1)
            # (total_node * k, latent_dim)

        # Fusion attention multi scale
        elif args.model == 'fusion':
            a = []
            for f in cat_message_layers:
                tmp = torch.cat([
                    att(non_zero, f)[1].view(-1, 1) for att in self.attention
                ],
                                dim=1)
                a.append((torch.sum(tmp, 1) / len(self.attention)).view(-1, 1))
            # Attention fusion
            if args.ff == 'max':
                a = torch.cat(a, dim=1)
                a_r = torch.max(a, 1)[0]  # keep the max values, not the (values, indices) tuple
            elif args.ff == 'sum':
                a = torch.cat(a, dim=1)
                a_r = torch.sum(a, 1)
            elif args.ff == 'mul':
                a_r = None
                for i in range(len(cat_message_layers)):
                    if i == 0:
                        a_r = a[0]
                    else:
                        a_r = torch.mul(a_r, a[i])

            # M = AX
            a_r = a_r.view(-1)
            special_spmm = SpecialSpmm()
            non_zero = torch.LongTensor(non_zero)
            cur_message_layer = special_spmm(non_zero, a_r, torch.Size([N, N]),
                                             cat_message_layers[-1])
            # (total_node, latent_dim)

        # No attention and just concat standard GNN
        elif args.model == 'no-att':
            cur_message_layer = torch.cat(cat_message_layers, 1)
        ''' Readout function: sortpooling layer '''
        if args.model != 'gin':
            sort_channel = cur_message_layer[:, -1]
            # sort_channel: total_node * 1
            # sort only on the features of the last channel
            batch_sortpooling_graphs = torch.zeros(len(graph_sizes), self.k,
                                                   self.att_out_size)
            # every graph is truncated or padded to k nodes
            batch_sortpooling_graphs = Variable(batch_sortpooling_graphs)
            if isinstance(node_feat.data, torch.cuda.FloatTensor):
                batch_sortpooling_graphs = batch_sortpooling_graphs.cuda()
            accum_count = 0
            # the sort-pooling operation acts on the node features only
            for i in range(batch_size):
                to_sort = sort_channel[accum_count:accum_count +
                                       graph_sizes[i]]
                k = self.k if self.k <= graph_sizes[i] else graph_sizes[i]
                # below we only need to check whether padding is required
                _, topk_indices = to_sort.topk(k)
                # topk returns a (values, indices) tuple of the k largest entries
                topk_indices += accum_count
                # the indices are relative to to_sort, so add accum_count to index into the full feature matrix
                sortpooling_graph = cur_message_layer.index_select(
                    0, topk_indices).cuda()
                # pad if the graph has fewer than self.k nodes
                if k < self.k:
                    to_pad = torch.zeros(self.k - k, self.att_out_size)
                    if isinstance(node_feat.data, torch.cuda.FloatTensor):
                        to_pad = to_pad.cuda()

                    to_pad = Variable(to_pad)
                    sortpooling_graph = torch.cat((sortpooling_graph, to_pad),
                                                  0)
                batch_sortpooling_graphs[i] = sortpooling_graph
                accum_count += graph_sizes[i]
                # the features are sorted one graph at a time
            ''' traditional 1d convolution and dense layers '''
            # batch_sortpooling_graphs: batch_size * self.k * att_out_size
            # (number of graphs * fixed nodes per graph * final feature dimension)
            res = []

            to_conv1d = batch_sortpooling_graphs.view(
                (-1, 1, self.k * self.att_out_size))
            conv1d_res = self.conv1d_params1(to_conv1d)
            conv1d_res = F.relu(conv1d_res)
            # print("conv1d_res.shape", conv1d_res.shape) # 50 * 16 * 291
            conv1d_res = self.maxpool1d(conv1d_res)
            res.append(conv1d_res.view(len(graph_sizes), -1))
            # print("conv1d_res.shape", conv1d_res.shape) # 50 * 16 * 145
            conv1d_res = self.conv1d_params2(conv1d_res)
            conv1d_res = F.relu(conv1d_res)
            res.append(conv1d_res.view(len(graph_sizes), -1))
            # print("conv1d_res.shape", conv1d_res.shape) # 50 * 32 * 141

        if args.model == 'gin':
            # to_dense = torch.cat(cat_message_layers, 1)
            to_dense = cat_message_layers[-1]
        else:
            if args.concat == 0:
                to_dense = conv1d_res.view(len(graph_sizes), -1)
            elif args.concat == 1:
                to_dense = torch.cat(res, 1)

        return to_dense
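The GIN branch above relies on a self.readout that is not defined in this snippet. A plausible sum-pooling sketch, stated as an assumption rather than the original implementation:

import torch

def readout(node_feats, graph_sizes):
    # node_feats: (total_nodes, F), node rows stacked graph after graph
    chunks = torch.split(node_feats, list(graph_sizes), dim=0)
    return torch.stack([c.sum(dim=0) for c in chunks], dim=0)  # (num_graphs, F)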
Example #13
    def sortpooling_embedding(self, node_feat, edge_feat, n2n_sp, e2n_sp,
                              subg_sp, graph_sizes, node_degs):
        ''' if exists edge feature, concatenate to node feature vector '''
        if edge_feat is not None:
            input_edge_linear = self.w_e2l(edge_feat)
            e2npool_input = gnn_spmm(e2n_sp, input_edge_linear)
            node_feat = torch.cat([node_feat, e2npool_input], 1)

        ''' graph convolution layers '''
        # A = ops.normalize_adj(n2n_sp)

        lv = 0
        cur_message_layer = node_feat
        cat_message_layers = []
        while lv < len(self.latent_dim):
            n2npool = gnn_spmm(n2n_sp, cur_message_layer) + cur_message_layer  # Y = (A + I) * X
            node_linear = self.conv_params[lv](n2npool)  # Y = Y * W
            normalized_linear = node_linear.div(node_degs)  # Y = D^-1 * Y
            cur_message_layer = torch.tanh(normalized_linear)
            cat_message_layers.append(cur_message_layer)
            lv += 1
        '''  You may choose to concatenate the node features from different layers or not '''
        # cur_message_layer = torch.cat(cat_message_layers, 1)
        '''  CRF pooling '''
        '''  First Use GCNs to obtain u(x) for a batch '''

        lv2 = 0
        X = cur_message_layer  # [total_nodes, d]: node features before pooling
        while lv2 < len(self.latent_dim2):
            n2npool = gnn_spmm(n2n_sp, cur_message_layer) + cur_message_layer  # Y = (A + I) * X
            node_linear = self.conv_params_p[lv2](n2npool)  # Y = Y * W
            normalized_linear = node_linear.div(node_degs)  # Y = D^-1 * Y
            cur_message_layer = torch.tanh(normalized_linear)
            lv2 += 1


        batch_sortpooling_graphs = torch.zeros(len(graph_sizes), self.k, self.last_dim)
        batch_sortpooling_As = torch.zeros(len(graph_sizes), self.k, self.k)
    #    batch_sortpooling_Ux = torch.zeros(len(graph_sizes), self.k, self.k)
        if torch.cuda.is_available() and isinstance(node_feat.data, torch.cuda.FloatTensor):
            batch_sortpooling_graphs = batch_sortpooling_graphs.cuda()
            batch_sortpooling_As = batch_sortpooling_As.cuda()
        
        batch_sortpooling_graphs = Variable(batch_sortpooling_graphs)
        batch_sortpooling_As = Variable(batch_sortpooling_As)
        accum_count = 0

        n2n_dense = self.sparse_to_dense(n2n_sp)
        '''  CRF pooling '''
        '''  Second perform pooling for each graph '''

        for i in range(subg_sp.size()[0]):
            X_i = X[accum_count:accum_count + graph_sizes[i], :]
            A = n2n_dense[accum_count:accum_count + graph_sizes[i],
                          accum_count:accum_count + graph_sizes[i]]
            U_X = cur_message_layer[accum_count:accum_count + graph_sizes[i], :]
            X_out, A_out = self.Pool(A, X_i, U_X)
            batch_sortpooling_graphs[i] = X_out
            batch_sortpooling_As[i] = A_out
            accum_count += graph_sizes[i]


        ''' traditional 1d convolution and dense layers '''
        to_conv1d = batch_sortpooling_graphs.view(
            (-1, 1, self.k * self.last_dim))  # [b, 1, k*d]
        conv1d_res = self.conv1d_params1(to_conv1d)
        conv1d_res = F.relu(conv1d_res)
        conv1d_res = self.maxpool1d(conv1d_res)
        conv1d_res = self.conv1d_params2(conv1d_res)
        conv1d_res = F.relu(conv1d_res)

        to_dense = conv1d_res.view(len(graph_sizes), -1)

        if self.output_dim > 0:
            out_linear = self.out_params(to_dense)
            reluact_fp = F.relu(out_linear)
        else:
            reluact_fp = to_dense

        return F.relu(reluact_fp)
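self.sparse_to_dense and self.Pool (the CRF pooling step) are not defined in this snippet. Assuming n2n_sp is a torch sparse tensor, sparse_to_dense could be as simple as:

def sparse_to_dense(sp_mat):
    # materialize a torch sparse tensor as a dense matrix
    return sp_mat.to_dense()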
Example #14
    def sortpooling_embedding(self, node_feat, edge_feat, n2f_sp, f2n_sp,
                              subhg_sp, hypergraph_sizes, node_hdegs,
                              hyperedge_sizes):
        ''' if exists edge feature, concatenate to node feature vector '''

        edge_feat = None  # NOTE: edge features are force-disabled here

        if edge_feat is not None:
            input_edge_linear = self.w_e2l(edge_feat)
            #             input_edge_linear = edge_feat
            e2npool_input = gnn_spmm(n2f_sp, input_edge_linear)
            node_feat = torch.cat([node_feat, e2npool_input], 1)
        ''' graph convolution layers '''
        lv = 0
        cur_message_layer = node_feat
        cat_message_layers = []
        while lv < 2 * len(self.latent_dim):
            #             # OPTION 1
            #             f2npool = gnn_spmm(f2n_sp, cur_message_layer)  # Y = S^T * X

            #             print(f2npool.shape)
            #             print(f2n_sp.shape)
            #             print(cur_message_layer.shape)
            #             print(node_feat.shape)

            #             node_linear = self.conv_params[lv](f2npool)  # Y = Y * W
            #             normalized_linear = node_linear.div(hyperedge_sizes)  # Y = sizes^-1 * Y
            #             cur_message_layer = torch.tanh(normalized_linear) # Z = tanh(Y)
            #             n2fpool = gnn_spmm(n2f_sp, cur_message_layer)  # Y = S * Z
            #             node_linear = nn.Linear(cur_message_layer.shape[1], cur_message_layer.shape[1])(n2fpool)  # Z = Z * W
            #             normalized_linear = node_linear.div(node_hdegs)  # Z = D_f^-1 * Z
            #             cur_message_layer = torch.tanh(normalized_linear)
            #             cat_message_layers.append(cur_message_layer)
            #             lv += 1


            # OPTION 3
            #             f2npool = gnn_spmm(f2n_sp, cur_message_layer)  # Y = S^T * X
            #             node_linear = self.conv_params[lv](f2npool)  # Y = Y * W
            #             normalized_linear = node_linear.div(hyperedge_sizes)  # Y = sizes^-1 * Y
            #             node_linear = gnn_spmm(n2f_sp, normalized_linear) # Y = S * Y
            #             normalized_linear = node_linear.div(node_hdegs) # Y = D_f^-1 * Z
            #             cur_message_layer = torch.tanh(normalized_linear) # Z = tanh(Y)
            #             cat_message_layers.append(cur_message_layer)
            #             lv += 1

            # OPTION 4
            #             n2fpool = gnn_spmm(n2f_sp, cur_message_layer)  # Y = S * Z

            #             print(n2fpool.shape)
            #             print(n2f_sp.shape)
            #             print(cur_message_layer.shape)
            #             print(edge_feat.shape)

            #             node_linear = self.conv_params[lv](n2fpool)  # Z = Z * W
            #             normalized_linear = node_linear.div(node_hdegs)  # Z = D_f^-1 * Z
            #             cur_message_layer = torch.tanh(normalized_linear)
            #             cat_message_layers.append(cur_message_layer)
            #             lv += 1

            #             f2npool = gnn_spmm(f2n_sp, cur_message_layer)  # Y = S^T * X
            #             node_linear = self.conv_params[lv](f2npool)  # Y = Y * W
            #             normalized_linear = node_linear.div(hyperedge_sizes)  # Y = sizes^-1 * Y
            #             cur_message_layer = torch.tanh(normalized_linear) # Z = tanh(Y)
            #             lv += 1

            # OPTION 2
            f2npool = gnn_spmm(f2n_sp, cur_message_layer)  # Y = S^T * X
            node_linear = self.conv_params[lv](f2npool)  # Y = Y * W
            normalized_linear = node_linear.div(
                hyperedge_sizes)  # Y = sizes^-1 * Y
            cur_message_layer = torch.tanh(normalized_linear)  # Z = tanh(Y)
            lv += 1
            n2fpool = gnn_spmm(n2f_sp, cur_message_layer)  # Y = S * Z
            node_linear = self.conv_params[lv](n2fpool)  # Z = Z * W
            normalized_linear = node_linear.div(node_hdegs)  # Z = D_f^-1 * Z
            cur_message_layer = torch.tanh(normalized_linear)
            cat_message_layers.append(cur_message_layer)
            lv += 1
        cur_message_layer = torch.cat(cat_message_layers, 1)
        ''' sortpooling layer '''
        sort_channel = cur_message_layer[:, -1]
        batch_sortpooling_graphs = torch.zeros(len(hypergraph_sizes), self.k,
                                               self.total_latent_dim)
        if torch.cuda.is_available() and isinstance(node_feat.data,
                                                    torch.cuda.FloatTensor):
            batch_sortpooling_graphs = batch_sortpooling_graphs.cuda()

        batch_sortpooling_graphs = Variable(batch_sortpooling_graphs)
        accum_count = 0
        for i in range(subhg_sp.size()[0]):
            to_sort = sort_channel[accum_count:accum_count +
                                   hypergraph_sizes[i]]
            k = self.k if self.k <= hypergraph_sizes[i] else hypergraph_sizes[i]
            _, topk_indices = to_sort.topk(k)
            topk_indices += accum_count
            sortpooling_graph = cur_message_layer.index_select(0, topk_indices)
            if k < self.k:
                to_pad = torch.zeros(self.k - k, self.total_latent_dim)
                if torch.cuda.is_available() and isinstance(
                        node_feat.data, torch.cuda.FloatTensor):
                    to_pad = to_pad.cuda()

                to_pad = Variable(to_pad)
                sortpooling_graph = torch.cat((sortpooling_graph, to_pad), 0)
            batch_sortpooling_graphs[i] = sortpooling_graph
            accum_count += hypergraph_sizes[i]
        ''' traditional 1d convolution and dense layers '''
        to_conv1d = batch_sortpooling_graphs.view(
            (-1, 1, self.k * self.total_latent_dim))
        conv1d_res = self.conv1d_params1(to_conv1d)
        conv1d_res = self.conv1d_activation(conv1d_res)
        conv1d_res = self.maxpool1d(conv1d_res)
        conv1d_res = self.conv1d_params2(conv1d_res)
        conv1d_res = self.conv1d_activation(conv1d_res)

        to_dense = conv1d_res.view(len(hypergraph_sizes), -1)

        if self.output_dim > 0:
            out_linear = self.out_params(to_dense)
            reluact_fp = self.conv1d_activation(out_linear)
        else:
            reluact_fp = to_dense

        return self.conv1d_activation(reluact_fp)
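The active OPTION 2 above alternates a hyperedge update and a node update. A compact, purely illustrative restatement of one round, with a dense incidence matrix S and plain weight matrices standing in for the Linear layers:

import torch

def hyperconv_round(S, X, W_edge, W_node, hyperedge_sizes, node_hdegs):
    # S: dense (num_nodes x num_hyperedges) incidence matrix; sizes/degrees are 1-D tensors
    E = torch.tanh((S.t() @ X @ W_edge) / hyperedge_sizes.view(-1, 1))  # hyperedge states
    Z = torch.tanh((S @ E @ W_node) / node_hdegs.view(-1, 1))           # updated node states
    return Z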