Example #1
    def forward(self, data):
        x = self.atom_encoder(data.x)
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        # Combine the three categorical edge attributes into a single
        # relation index (mixed-radix encoding) for the relational conv layers.
        edge_attr = torch.LongTensor([
            edge_type[0] + edge_type[1] * 5 + edge_type[2] * 30
            for edge_type in edge_attr
        ]).to(self.device)
        for i, layer in enumerate(self.rgcn_list):
            x = layer(x, edge_index, edge_attr)
            x = F.relu(x)
            if i == len(self.rgcn_list) - 1: continue
            x = self.batchnorm(x)

        if self.pool_layer == 'add':
            x = global_add_pool(x, data.batch)
        if self.pool_layer == 'mean':
            x = global_mean_pool(x, data.batch)
        if self.pool_layer == 'max':
            x = global_max_pool(x, data.batch)
        if self.pool_layer == 'sort':
            x = global_sort_pool(x, data.batch, self.k)

        x = F.relu(self.linear1(x))
        x = self.linear2(x)

        return x
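The per-edge Python loop above can also be written as a single vectorized expression. A minimal sketch, assuming data.edge_attr is an integer tensor of shape [num_edges, 3]; the helper name is hypothetical:

import torch

def encode_edge_types(edge_attr: torch.Tensor) -> torch.Tensor:
    # Hypothetical vectorized equivalent of the loop in the example above:
    # collapse the three categorical edge attributes into one relation index
    # using the same multipliers (1, 5, 30).
    return (edge_attr[:, 0] + edge_attr[:, 1] * 5 + edge_attr[:, 2] * 30).long()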
Example #2
    def embed_graph(self, x, edge_index, batch=None):
        attn_weights = dict()

        x = F.one_hot(x, num_classes=self.config.num_feature_dim).float()
        x = F.relu(self.gc1(x, edge_index))
        x = F.dropout(x, self.config.dropout, training=self.training)
        x = self.gc2(x, edge_index)

        if self.config.pooling_type == "sagpool":
            x, edge_index, _, batch, attn_weights['pool_perm'], attn_weights[
                'pool_score'] = self.pool1(x, edge_index, batch=batch)
        elif self.config.pooling_type == "topk":
            x, edge_index, _, batch, attn_weights['pool_perm'], attn_weights[
                'pool_score'] = self.pool1(x, edge_index, batch=batch)
        elif self.config.pooling_type == "asa":
            x, edge_index, _, batch, attn_weights['pool_perm'] = self.pool1(
                x, edge_index, batch=batch)

        if self.config.readout_type == "add":
            x = global_add_pool(x, batch)
        elif self.config.readout_type == "mean":
            x = global_mean_pool(x, batch)
        elif self.config.readout_type == "max":
            x = global_max_pool(x, batch)
        elif self.config.readout_type == "sort":
            x = global_sort_pool(x, batch, k=100)
        else:
            pass

        attn_weights['batch'] = batch
        x = self.fc(x)
        return x, attn_weights
Example #3
    def sortpooling_embedding_tg(self, data):
        '''
        If an edge feature exists, concatenate it to the node feature vector.
        '''

        node_feat, edge_index, batch = data.x, data.edge_index, data.batch
        '''G-UNet Layer to process the graph data'''
        # the output feature dimension of gUnet is
        cur_message_layer = self.gUnet(x=node_feat,
                                       edge_index=edge_index,
                                       batch=batch)
        ''' sortpooling layer '''
        # the output of global_sort_pool has shape (B, k * total_latent_dim)
        batch_sortpooling_graphs = global_sort_pool(cur_message_layer, batch,
                                                    self.k)
        ''' traditional 1d convolution and dense layers '''
        to_conv1d = batch_sortpooling_graphs.view(
            (-1, 1, self.k * self.total_latent_dim))

        conv1d_res = self.conv1d_params1(to_conv1d)
        conv1d_res = F.relu(conv1d_res)
        conv1d_res = self.maxpool1d(conv1d_res)
        conv1d_res = self.conv1d_params2(conv1d_res)
        conv1d_res = F.relu(conv1d_res)

        to_dense = conv1d_res.view(batch_sortpooling_graphs.size(0), -1)

        if self.output_dim > 0:
            out_linear = self.out_params(to_dense)
            reluact_fp = F.relu(out_linear)
        else:
            reluact_fp = to_dense

        return F.relu(reluact_fp)
Example #4
 def forward(self, x):
     return global_sort_pool(x=x,
                             batch=torch.tensor(
                                 [0 for i in range(x.size()[0])],
                                 dtype=torch.long,
                                 device=x.device),
                             k=self.k)
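The example above builds the single-graph batch vector from a Python list. A small sketch of an equivalent wrapper that allocates the all-zeros batch index directly on the node device (the class name and the k argument are assumptions mirroring the example):

import torch
from torch_geometric.nn import global_sort_pool

class SingleGraphSortPool(torch.nn.Module):
    # Hypothetical wrapper, equivalent to the example above: applies
    # global_sort_pool to one unbatched graph via an all-zeros batch vector.
    def __init__(self, k: int):
        super().__init__()
        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        return global_sort_pool(x, batch, k=self.k)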
Example #5
    def forward(self, data):
        # Implement Equation 4.2 of the paper i.e. concat all layers' graph representations and apply linear model
        # note: this can be decomposed in one smaller linear model per layer
        x, edge_index, batch = data.x, data.edge_index, data.batch

        hidden_repres = []

        for conv in self.convs:
            x = torch.tanh(conv(x, edge_index))
            hidden_repres.append(x)

        # apply sortpool
        x_to_sortpool = torch.cat(hidden_repres, dim=1)

        x_1d = global_sort_pool(x_to_sortpool, batch, self.k)  # in the code the authors sort the last channel only

        # apply 1D convolutional layers
        x_1d = torch.unsqueeze(x_1d, dim=1)
        conv1d_res = F.relu(self.conv1d_params1(x_1d))
        conv1d_res = self.maxpool1d(conv1d_res)
        conv1d_res = F.relu(self.conv1d_params2(conv1d_res))
        conv1d_res = conv1d_res.reshape(conv1d_res.shape[0], -1)

        # apply dense layer
        out_dense = self.dense_layer(conv1d_res)

        return out_dense
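Several of the DGCNN-style examples above and below refer to 1D convolution layers (conv1d_params1, maxpool1d, conv1d_params2) and a dense layer that are defined elsewhere in their classes. The following is a minimal, self-contained sketch of how that read-out head is typically sized in DGCNN; the class name, constructor arguments and channel sizes are assumptions following the usual DGCNN convention, not code taken from any of these projects:

import torch
from torch import nn

class DGCNNReadout(nn.Module):
    # Hypothetical sketch of the 1D-conv head applied after global_sort_pool,
    # following the common DGCNN sizing (all names and sizes are assumptions).
    def __init__(self, k: int, total_latent_dim: int, num_classes: int):
        super().__init__()
        self.k = k
        # kernel_size == stride == total_latent_dim, so each step of the first
        # convolution reads exactly one pooled node's concatenated features.
        self.conv1d_params1 = nn.Conv1d(1, 16, kernel_size=total_latent_dim,
                                        stride=total_latent_dim)
        self.maxpool1d = nn.MaxPool1d(2, 2)
        self.conv1d_params2 = nn.Conv1d(16, 32, kernel_size=5, stride=1)
        dense_dim = ((k - 2) // 2 + 1 - 5 + 1) * 32
        self.dense_layer = nn.Linear(dense_dim, num_classes)

    def forward(self, x_1d: torch.Tensor) -> torch.Tensor:
        # x_1d: [num_graphs, k * total_latent_dim], as returned by global_sort_pool.
        x = x_1d.unsqueeze(1)
        x = torch.relu(self.conv1d_params1(x))
        x = self.maxpool1d(x)
        x = torch.relu(self.conv1d_params2(x))
        x = x.view(x.size(0), -1)
        return self.dense_layer(x)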
Example #6
    def forward(self, z, edge_index, batch, x=None, edge_weight=None, node_id=None):
        z_emb = self.z_embedding(z)
        if z_emb.ndim == 3:  # in case z has multiple integer labels
            z_emb = z_emb.sum(dim=1)
        if self.use_feature and x is not None:
            x = torch.cat([z_emb, x.to(torch.float)], 1)
        else:
            x = z_emb
        if self.node_embedding is not None and node_id is not None:
            n_emb = self.node_embedding(node_id)
            x = torch.cat([x, n_emb], 1)
        xs = [x]

        for conv in self.convs:
            xs += [torch.tanh(conv(xs[-1], edge_index, edge_weight))]
        x = torch.cat(xs[1:], dim=-1)

        # Global pooling.
        if self.random_pool: 
            x = global_random_pool(x, batch, self.k)
        else: 
            x = global_sort_pool(x, batch, self.k)
        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
        x = F.relu(self.conv1(x))
        x = self.maxpool1d(x)
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        # MLP.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x
Example #7
    def forward(self, data):
        # Implement Equation 4.2 of the paper i.e. concat all layers' graph representations and apply linear model
        # note: this can be decomposed in one smaller linear model per layer
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # print("in DGCNN:",edge_index.shape)
        hidden_repres = []  # store the intermediate layer outputs

        for conv in self.convs:
            x = torch.tanh(conv(x, edge_index))
            hidden_repres.append(x)

        # apply sortpool
        x_to_sortpool = torch.cat(hidden_repres, dim=1)
        x_1d = global_sort_pool(
            x_to_sortpool, batch,
            self.k)  # in the code the authors sort the last channel only
        # x_1d shape: (b, k*f)
        # apply 1D convolutional layers
        # x_1d shape: (b, 1, k*f); nn.Conv1d expects input of shape (N, C_in, L),
        # i.e. C_in planes, where f = total_latent_dim
        x_1d = torch.unsqueeze(x_1d, dim=1)
        conv1d_res = F.relu(self.conv1d_params1(x_1d))
        conv1d_res = self.maxpool1d(conv1d_res)
        conv1d_res = F.relu(self.conv1d_params2(conv1d_res))
        # flatten the channel planes, giving shape (b, 32 * ...)
        conv1d_res = conv1d_res.reshape(conv1d_res.shape[0], -1)

        # apply dense layer
        out_dense = self.dense_layer(conv1d_res)
        return out_dense
Example #8
    def forward(self, x, edge_index, edge_attr, *args, **kwargs):
        """
        Multi-dimensional edge_attr is implemented as separate channels that are concatenated before the dense layer.
        :param x: Node feature matrix with shape [num_nodes, num_node_features]
        :param edge_index: Graph connectivity in COO format with shape [2, num_edges] and type torch.long
        :param edge_attr: Edge feature matrix with shape [num_edges, num_edge_features], num_edge_features >= 1
        :param args:
        :param kwargs:
        :return:
        """
        print(getattr(self, "conv1d1_channel1"))
        if x.dim() == 3:  # a batch
            if x.shape[0] != 1:
                raise Exception("batch size greater than 1 is not supported.")
            x, edge_index, edge_attr = x.squeeze(0), edge_index.squeeze(
                0), edge_attr.squeeze(0)

        for i, channel in enumerate(range(self.channels)):
            e = edge_attr[:, i]
            xx = self.gcnconv1[channel](x, edge_index, e)
            xx = torch.cat([
                xx, self.gcnconv2[channel](xx[:, -1].unsqueeze(-1), edge_index,
                                           e)
            ],
                           dim=-1)
            xx = torch.cat([
                xx, self.gcnconv3[channel](xx[:, -1].unsqueeze(-1), edge_index,
                                           e)
            ],
                           dim=-1)
            xx = torch.cat([
                xx, self.gcnconv4[channel](xx[:, -1].unsqueeze(-1), edge_index,
                                           e)
            ],
                           dim=-1)
            xx = torch.cat([
                xx, self.gcnconv5[channel](xx[:, -1].unsqueeze(-1), edge_index,
                                           e)
            ],
                           dim=-1)
            N, D = xx.size()
            # single-graph batch vector, kept on the same device as xx
            xx = global_sort_pool(x=xx,
                                  batch=torch.tensor(
                                      [0 for i in range(self.num_nodes)],
                                      dtype=torch.long,
                                      device=xx.device),
                                  k=self.num_nodes)
            xx = xx.view(-1, N, D).permute(0, 2, 1)
            xx = self.conv1d1[channel](xx)
            xx = self.maxpool1[channel](xx)
            xx = self.conv1d2[channel](xx)
            xx = self.maxpool2[channel](xx)

            all_x = xx if i == 0 else torch.cat([all_x, xx], dim=0)

        x = all_x.view(1, -1)
        x = F.elu(self.drop1(self.fc1(x)))
        x = F.elu(self.drop2(self.fc2(x)))
        x = self.fc3(x)

        return x
Example #9
 def forward(self, data):
     x, edge_index, edge_type, batch = data.x, data.edge_index, data.edge_type, data.batch
     if self.adj_dropout > 0:
         edge_index, edge_type = dropout_adj(
             edge_index,
             edge_type,
             p=self.adj_dropout,
             force_undirected=self.force_undirected,
             num_nodes=len(x),
             training=self.training)
     concat_states = []
     for conv in self.convs:
         x = torch.tanh(conv(x, edge_index, edge_type))
         concat_states.append(x)
     concat_states = torch.cat(concat_states, 1)
     x = global_sort_pool(concat_states, batch,
                          self.k)  # batch * (k*hidden)
     x = x.unsqueeze(1)  # batch * 1 * (k*hidden)
     x = F.relu(self.conv1d_params1(x))
     x = self.maxpool1d(x)
     x = F.relu(self.conv1d_params2(x))
     x = x.view(len(x), -1)  # flatten
     x = F.relu(self.lin1(x))
     x = F.dropout(x, p=0.5, training=self.training)
     x = self.lin2(x)
     if self.regression:
         return x[:, 0]
     else:
         return F.log_softmax(x, dim=-1)
Example #10
    def forward(self, data):
        x = self.atom_encoder(data.x)
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        edge_attr = torch.LongTensor([
            edge_type[0] + edge_type[1] * 5 + edge_type[2] * 30
            for edge_type in edge_attr
        ]).to(self.device)
        for i in range(len(self.rgcn_list)):
            x_rgcn = self.rgcn_list[i](x, edge_index, edge_attr)
            x_gconv = self.graphconv_list[i](x, edge_index)
            x = torch.cat((x_rgcn, x_gconv), 1)
            x = F.relu(x)
            if i == len(self.rgcn_list) - 1: continue
            x = self.batchnorm(x)

        # x = self.graph_conv(x,edge_index)
        # x = F.relu(x)

        if self.pool_layer == 'add':
            x = global_add_pool(x, data.batch)
        if self.pool_layer == 'mean':
            x = global_mean_pool(x, data.batch)
        if self.pool_layer == 'max':
            x = global_max_pool(x, data.batch)
        if self.pool_layer == 'sort':
            x = global_sort_pool(x, data.batch, self.k)

        x = F.relu(self.linear1(x))
        x = self.linear2(x)

        return x
Example #11
    def forward(self, sample):
        x, edge_index = sample.x, sample.edge_index

        # Dropout layer
        # edge_index = self.dropout_edges(edge_index, dropout=0.2)

        x = self.dense_input(x, self.empty_edges)
        x = F.gelu(x)

        x = self.input(x, edge_index)
        x = F.gelu(x)
        x = self.conv1(x, edge_index)
        x = F.gelu(x)

        # if self.pooling_layers > 1:
        #     batch = torch.tensor([0 for _ in x], dtype=torch.long, device=self.device)
        #     pooled = self.topkpool1(x, edge_index, batch=batch)
        #     x, edge_index = pooled[0], pooled[1]

        x = self.conv2(x, edge_index)
        x = F.gelu(x)

        if self.pooling_layers > 0:
            batch = torch.tensor([0 for _ in x], dtype=torch.long, device=self.device)
            pooled = self.topkpool2(x, edge_index, batch=batch)
            x, edge_index = pooled[0], pooled[1]

        x = self.conv3(x, edge_index)
        x = F.gelu(x)

        # For large graphs
        # while len(x) > 8:
        #     batch = torch.tensor([0 for _ in x], dtype=torch.long, device=self.device)
        #     pooled = self.topkpool3(x, edge_index, batch=batch)
        #     x, edge_index = pooled[0], pooled[1]
        #     x = self.conv4(x, edge_index)
        #     x = F.gelu(x)

        batch = torch.tensor([0 for _ in x], dtype=torch.long, device=self.device)
        # With sort_pool it works, but we have the same problem: the output layer learns the order of the pooled nodes.
        # Using k = 3; let's see what happens when the nodes are shuffled.
        if self.final_pooling == "avg_pool_x":
            cluster = torch.as_tensor([i % self.final_nodes for i in range(len(x))], device=self.device)
            (x, cluster) = avg_pool_x(cluster, x, batch)
        elif self.final_pooling == "sort_pooling":
            x = global_sort_pool(x, batch, self.final_nodes)
        elif self.final_pooling == "topk" or self.final_pooling == "asap" or self.final_pooling == "sag":
            pooled = self.last_pooling_layer(x, edge_index)
            x = pooled[0]
        elif self.final_pooling == "max_pool_x":
            cluster = torch.as_tensor([i % self.final_nodes for i in range(len(x))], device=self.device)
            (x, cluster) = max_pool_x(cluster, x, batch)
            # (x2, cluster2) = avg_pool_x(cluster, x, batch)
            # x = torch.cat([x1.view(-1), x2.view(-1)])

        return self.output(x.view(-1))
Example #12
 def forward(self, data):
     x, edge_index, batch = data.x, data.edge_index, data.batch
     x = F.relu(self.conv1(x, edge_index))
     for conv in self.convs:
         x = F.relu(conv(x, edge_index))
     x = global_sort_pool(x, batch, self.k)
     x = F.relu(self.lin1(x))
     x = F.dropout(x, p=0.5, training=self.training)
     x = self.lin2(x)
     return F.log_softmax(x, dim=-1)
Example #13
 def forward(self, x, batch):
     batch_size = batch.max() + 1
     re = global_sort_pool(x, batch, self.k)
     # re shape: (bs, k * out_dim)
     re = re.unsqueeze(2).transpose(2, 1)
     conv1 = F.relu(self.conv1d_p1(re))
     conv1 = self.pool(conv1)
     conv2 = F.relu(self.conv1d_p2(conv1))
     to_dense = conv2.view(batch_size, -1)
     out = F.relu(to_dense)
     return out
Example #14
    def forward(self, x, edge_index, edge_attr, *args, **kwargs):
        """
        Multi-dimensional edge_attr is implemented as separate channels that are concatenated before the dense layer.
        :param x: Node feature matrix with shape [num_nodes, num_node_features]
        :param edge_index: Graph connectivity in COO format with shape [2, num_edges] and type torch.long
        :param edge_attr: Edge feature matrix with shape [num_edges, num_edge_features], num_edge_features >= 1
        :param args:
        :param kwargs:
        :return:
        """
        if x.dim() == 3:  # a batch
            if x.shape[0] != 1:
                raise Exception("batch size greater than 1 is not supported.")
            x, edge_index, edge_attr = x.squeeze(0), edge_index.squeeze(
                0), edge_attr.squeeze(0)

        for i, channel in enumerate(range(self.channels)):
            e = edge_attr[:, i]
            # gcn conv
            for conv_level in range(1, self.gcn_conv_level + 1):
                conv = getattr(
                    self, "gcnconv{}_channel{}".format(conv_level, channel))
                if conv_level == 1:
                    xx = conv(x, edge_index, e)
                else:
                    xx = torch.cat(
                        [xx, conv(xx[:, -1].view(-1, 1), edge_index, e)],
                        dim=-1)
            # sort pool
            N, D = xx.size()
            xx = global_sort_pool(x=xx,
                                  batch=torch.tensor(
                                      [0 for i in range(self.num_nodes)],
                                      dtype=torch.long,
                                      device=xx.device),
                                  k=self.num_nodes)
            xx = xx.view(-1, N, D).permute(0, 2, 1)
            # conv and pool
            xx = getattr(self, "conv1d1_channel{}".format(channel))(xx)
            xx = getattr(self, "maxpool1_channel{}".format(channel))(xx)
            xx = getattr(self, "conv1d2_channel{}".format(channel))(xx)
            xx = getattr(self, "maxpool2_channel{}".format(channel))(xx)

            all_x = xx if i == 0 else torch.cat([all_x, xx], dim=0)

        x = all_x.view(1, -1)
        x = F.elu(self.drop1(self.fc1(x)))
        x = F.elu(self.drop2(self.fc2(x)))
        x = self.fc3(x)

        print(x)
        return x
Example #15
    def forward(self, x, edge_index, batch):
        # x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))

        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = global_sort_pool(x, batch, self.k)
        x = x.view(len(x), self.k, -1).permute(0, 2, 1)
        x = F.relu(self.conv1d(x))
        x = x.view(len(x), -1)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x
Example #16
 def forward(self, feats_in_l, idx_targets, sizes_subg):
     if self.type_pool == 'center':
         if self.type_res == 'none':
             return feats_in_l[-1][idx_targets]
         else:  # regular JK
             feats_root_l = [f[idx_targets] for f in feats_in_l]
             feat_in = self.f_residue(feats_root_l)
     elif self.type_pool in ['max', 'mean', 'sum']:
         # first pool subgraph at each layer, then residue
         offsets = torch.cumsum(sizes_subg, dim=0)
         offsets = torch.roll(offsets, 1)
         offsets[0] = 0
         idx = torch.arange(feats_in_l[-1].shape[0]).to(
             feats_in_l[-1].device)
         if self.type_res == 'none':
             feat_pool = F.embedding_bag(idx,
                                         feats_in_l[-1],
                                         offsets,
                                         mode=self.type_pool)
             feat_root = feats_in_l[-1][idx_targets]
         else:
             feat_pool_l = []
             for feat in feats_in_l:
                 feat_pool = F.embedding_bag(idx,
                                             feat,
                                             offsets,
                                             mode=self.type_pool)
                 feat_pool_l.append(feat_pool)
             feat_pool = self.f_residue(feat_pool_l)
             feat_root = self.f_residue(
                 [f[idx_targets] for f in feats_in_l])
         feat_in = torch.cat([feat_root, feat_pool], dim=1)
     elif self.type_pool == 'sort':
         if self.type_res == 'none':
             feat_pool_in = feats_in_l[-1]
             feat_root = feats_in_l[-1][idx_targets]
         else:
             feat_pool_in = self.f_residue(feats_in_l)
             feat_root = self.f_residue(
                 [f[idx_targets] for f in feats_in_l])
         arange = torch.arange(sizes_subg.size(0)).to(sizes_subg.device)
         idx_batch = torch.repeat_interleave(arange, sizes_subg)
         feat_pool_k = global_sort_pool(feat_pool_in, idx_batch,
                                        self.k)  # #subg x (k * F)
         feat_pool = self.nn_pool(feat_pool_k)
         feat_in = torch.cat([feat_root, feat_pool], dim=1)
     else:
         raise NotImplementedError
     return self.f_norm(self.nn(feat_in))
Example #17
    def forward(self, x, edge_index, batch):
        xs = [x]
        for conv in self.convs:
            xs += [conv(xs[-1], edge_index).tanh()]
        x = torch.cat(xs[1:], dim=-1)

        # Global pooling.
        x = global_sort_pool(x, batch, self.k)
        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
        x = self.conv1(x).relu()
        x = self.maxpool1d(x)
        x = self.conv2(x).relu()
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        return self.mlp(x)
Example #18
    def forward(self, data):
        x, edge_index, edge_weight, batch = data.x, data.edge_index, data.edge_weight, data.batch
        edge_index2, edge_weight2 = data.edge_index2, data.edge_weight2
        
        x0,x1,x2 = self.ib1(x, edge_index, edge_weight, edge_index2, edge_weight2)
        x0 = F.dropout(x0, p=self.dropout, training=self.training)
        x1 = F.dropout(x1, p=self.dropout, training=self.training)
        x2 = F.dropout(x2, p=self.dropout, training=self.training)
        x = x0+x1+x2
        x = F.dropout(x, p=self.dropout, training=self.training)

        x0,x1,x2 = self.ib2(x, edge_index, edge_weight, edge_index2, edge_weight2)
        x0 = F.dropout(x0, p=self.dropout, training=self.training)
        x1 = F.dropout(x1, p=self.dropout, training=self.training)
        x2 = F.dropout(x2, p=self.dropout, training=self.training)
        x = x0+x1+x2
        x = F.dropout(x, p=self.dropout, training=self.training)

        x0,x1,x2 = self.ib3(x, edge_index, edge_weight, edge_index2, edge_weight2)
        x0 = F.dropout(x0, p=self.dropout, training=self.training)
        x1 = F.dropout(x1, p=self.dropout, training=self.training)
        x2 = F.dropout(x2, p=self.dropout, training=self.training)
        x = x0+x1+x2

        x0,x1,x2 = self.final(x, edge_index, edge_weight, edge_index2, edge_weight2)
        x0 = F.dropout(x0, p=self.dropout, training=self.training)
        x1 = F.dropout(x1, p=self.dropout, training=self.training)
        x2 = F.dropout(x2, p=self.dropout, training=self.training)
        y = x0+x1+x2

        x = torch.cat([x, y], 1)

         # Global pooling.
        x = global_sort_pool(x, batch, self.k)
        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
        x = F.relu(self.conv1(x))
        x = self.maxpool1d(x)                           
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        # MLP.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)

        return x
Example #19
def test_global_sort_pool():
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

    out = global_sort_pool(x, batch, k=5)
    assert out.size() == (2, 5 * 4)
    out = out.view(2, 5, 4)

    # First graph output has been filled up with zeros.
    assert out[0, -1].tolist() == [0, 0, 0, 0]

    # Nodes are sorted.
    expected = 3 - torch.arange(4)
    assert out[0, :4, -1].argsort().tolist() == expected.tolist()

    expected = 4 - torch.arange(5)
    assert out[1, :, -1].argsort().tolist() == expected.tolist()
Example #20
File: gnn2.py Project: vthost/DAGNN
    def sortpooling_embedding(self, batched_data):
        x, edge_index, node_depth, batch = batched_data.x, batched_data.edge_index,  batched_data.node_depth, batched_data.batch

        x = self.node_encoder(x, node_depth.view(-1,))

        node_feat = x
        edge_feat = batched_data.edge_attr if hasattr(batched_data, 'edge_attr') and \
            batched_data.edge_attr is not None else None
        ''' if an edge feature exists, concatenate it to the node feature vector '''
        if edge_feat is not None:
            # we added inverse before..
            # edge_index = torch.cat([edge_index, torch.stack([edge_index[1],edge_index[0]], dim=0)], dim=-1)
            e2n_sp = torch.zeros(x.shape[0], edge_index.shape[1]).to(edge_feat.device).scatter_(0, edge_index, 1)
            e2npool_input = torch.mm(e2n_sp, edge_feat)
            node_feat = torch.cat([node_feat, e2npool_input], 1)

        ''' graph convolution layers '''
        cur_message_layer = self.compute_message_layers(node_feat, edge_index, batch)  # put in extra function to reuse rest in unet

        ''' sortpooling layer '''
        batch_sortpooling_graphs = global_sort_pool(cur_message_layer, batch, self.k)

        ''' traditional 1d convolution and dense layers '''
        to_conv1d = batch_sortpooling_graphs.view((-1, 1, self.k * self.total_latent_dim))
        conv1d_res = self.conv1d_params1(to_conv1d)
        conv1d_res = self.conv1d_activation(conv1d_res)
        conv1d_res = self.maxpool1d(conv1d_res)
        conv1d_res = self.conv1d_params2(conv1d_res)
        conv1d_res = self.conv1d_activation(conv1d_res)

        to_dense = conv1d_res.view(conv1d_res.shape[0], -1)

        # if self.output_dim > 0:
        #     out_linear = self.out_params(to_dense)
        #     reluact_fp = self.conv1d_activation(out_linear)
        # else:
        #     reluact_fp = to_dense
        #
        # return self.mlp(self.conv1d_activation(reluact_fp))
        pred_list = []
        for i in range(self.max_seq_len):
            pred_list.append(self.graph_pred_linear_list[i](to_dense))

        return pred_list
Example #21
def test_global_sort_pool():
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

    out = global_sort_pool(x, batch, k=5)
    assert out.size() == (2, 5 * 4)
    out = out.view(2, 5, 4)

    # Features are individually sorted.
    expected = torch.arange(4).view(1, 1, 4).expand_as(out)
    assert out.argsort(dim=2).tolist() == expected.tolist()

    # First graph output has been filled up with zeros.
    assert out[0, -1].tolist() == [0, 0, 0, 0]

    # Nodes are sorted.
    expected = 4 - torch.arange(5).view(1, 5).expand(2, 5)
    assert out.argsort(dim=1)[:, :, -1].tolist() == expected.tolist()
Example #22
    def forward(self, x, edge_index, batch):
        xs = [x]
        for conv in self.convs:
            xs += [torch.tanh(conv(xs[-1], edge_index))]
        x = torch.cat(xs[1:], dim=-1)

        # Global pooling.
        x = global_sort_pool(x, batch, self.k)
        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
        x = F.relu(self.conv1(x))
        x = self.maxpool1d(x)
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        # MLP.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x
Example #23
    def sortpooling_embedding_tg(self, data):
        '''
        If an edge feature exists, concatenate it to the node feature vector.
        '''

        node_feat, edge_index, batch = data.x, data.edge_index, data.batch
        # TODO: remove edge_attr consideration
        # if data.edge_attr is not None:
        #     input_edge_linear = self.w_e2l(data.edge_attr)
        #     e2npool_input = gnn_spmm(e2n_sp, input_edge_linear)
        #     node_feat = torch.cat([node_feat, e2npool_input], 1)
        '''G-UNet Layer to process the graph data'''
        # the output feature dimension of gUnet is
        cur_message_layer = self.gUnet(x=node_feat,
                                       edge_index=edge_index,
                                       batch=batch)

        # X = torch.cat([cur_message_layer, node_feat], 1)
        # cur_message_layer = self.end_gcn(X, edge_index=edge_index, batch=batch)
        ''' sortpooling layer '''
        # the output of global_sort_pool has shape (B, k * total_latent_dim)
        batch_sortpooling_graphs = global_sort_pool(cur_message_layer, batch,
                                                    self.k)
        ''' traditional 1d convolution and dense layers '''
        to_conv1d = batch_sortpooling_graphs.view(
            (-1, 1, self.k * self.total_latent_dim))

        conv1d_res = self.conv1d_params1(to_conv1d)
        conv1d_res = F.relu(conv1d_res)
        conv1d_res = self.maxpool1d(conv1d_res)
        conv1d_res = self.conv1d_params2(conv1d_res)
        conv1d_res = F.relu(conv1d_res)

        to_dense = conv1d_res.view(batch_sortpooling_graphs.size(0), -1)

        if self.output_dim > 0:
            out_linear = self.out_params(to_dense)
            reluact_fp = F.relu(out_linear)
        else:
            reluact_fp = to_dense

        return F.relu(reluact_fp)
Example #24
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        edge_index, _ = remove_self_loops(edge_index)

        x_1 = torch.tanh(self.conv1(x, edge_index))
        x_2 = torch.tanh(self.conv2(x_1, edge_index))
        x_3 = torch.tanh(self.conv3(x_2, edge_index))
        x_4 = torch.tanh(self.conv4(x_3, edge_index))
        x = torch.cat([x_1, x_2, x_3, x_4], dim=-1)
        x = global_sort_pool(x, batch, k=30)
        x = x.view(x.size(0), 1, x.size(-1))
        x = self.relu(self.conv5(x))
        x = self.pool(x)
        x = self.relu(self.conv6(x))
        x = x.view(x.size(0), -1)
        out = self.relu(self.classifier_1(x))
        out = self.drop_out(out)
        classes = F.log_softmax(self.classifier_2(out), dim=-1)

        return classes
Example #25
def test_global_sort_pool_smaller_than_k():
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

    # Set k which is bigger than both N_1=4 and N_2=6.
    out = global_sort_pool(x, batch, k=10)
    assert out.size() == (2, 10 * 4)
    out = out.view(2, 10, 4)

    # Both graph outputs have been filled up with zeros.
    assert out[0, -1].tolist() == [0, 0, 0, 0]
    assert out[1, -1].tolist() == [0, 0, 0, 0]

    # Nodes are sorted.
    expected = 3 - torch.arange(4)
    assert out[0, :4, -1].argsort().tolist() == expected.tolist()

    expected = 5 - torch.arange(6)
    assert out[1, :6, -1].argsort().tolist() == expected.tolist()
Example #26
    def forward(self, data):
        # Implement Equation 4.2 of the paper i.e. concat all layers' graph representations and apply linear model
        # note: this can be decomposed in one smaller linear model per layer
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # print("in DGCNN:",edge_index.shape)
        hidden_repres = []

        for conv in self.convs:
            x = torch.tanh(conv(x, edge_index))
            hidden_repres.append(x)

        # apply sortpool
        x_to_sortpool = torch.cat(hidden_repres, dim=1)
        x_1d = global_sort_pool(
            x_to_sortpool, batch,
            self.k)  # in the code the authors sort the last channel only
        # apply 1D convolutional layers
        x_1d = torch.unsqueeze(x_1d, dim=1)
        conv1d_res = F.relu(self.conv1d_params1(x_1d))
        conv1d_res = self.maxpool1d(conv1d_res)
        conv1d_res = F.relu(self.conv1d_params2(conv1d_res))
        conv1d_res = conv1d_res.reshape(conv1d_res.shape[0], -1)

        # apply dense layer
        out_dense = self.dense_layer(conv1d_res)

        # Don't apply sigmoid during training b/c using BCEWithLogitsLoss
        if self.classification and not self.training:
            out_dense = self.sigmoid(out_dense)
        if self.multiclass:
            out_dense = out_dense.reshape(
                (out_dense.size(0), -1, self.multiclass_num_classes
                 ))  # batch size x num targets x num classes per target
            if not self.training:
                out_dense = self.multiclass_softmax(
                    out_dense
                )  # to get probabilities during evaluation, but not during training as we're using CrossEntropyLoss

        return out_dense
Example #27
 def forward(self, x, edge_index, edge_weights, batch, debug=False):
     """
     :param x: nodes' features (Nodes x Features_In)
     :param edge_index: COO-formatted sparse graph edges (2 x Edges)
     :param edge_weights: the weights of the corresponding edges
     :param batch: graph labels; indicates which graph in the batch each node belongs to.
         See the PyTorch Geometric docs.
     :param debug: whether to print shapes after each layer or not
     :return: a tensor containing labels for every graph in the batch
     """
     if debug:
         print(x.shape)
     x = self.sage1(x, edge_index, edge_weights)  # (NxFI, NxN) --> (NxFO)
     if debug:
         print(x.shape)
     x, edge_index, edge_weights, batch, _, _ = self.pooling1(x, edge_index, edge_weights, batch)
     if debug:
         print(x.shape)
     x = func.dropout(x, training=self.training)
     if debug:
         print(x.shape)
     x = self.sage2(x, edge_index, edge_weights)
     if debug:
         print(x.shape)
     x, edge_index, edge_weights, batch, _, _ = self.pooling2(x, edge_index, edge_weights, batch)
     x = func.dropout(x, training=self.training)
     if debug:
         print(x.shape)
     x = gnn.global_sort_pool(x, batch, 10)
     if debug:
         print(x.shape)
     x = func.relu(self.lin1(x))
     if debug:
         print(x.shape)
     x = func.relu(self.lin2(x))
     if debug:
         print(x.shape)
     return x
Example #28
  def forward(self, data):
      x = self.atom_encoder(data.x)
      edge_index = data.edge_index
      
      for i, layer in enumerate(self.graph_conv_list) : 
          x = layer(x, edge_index)
          x = F.relu(x)
          if i == len(self.graph_conv_list) - 1: continue
          x = self.batchnorm(x)  
          
      if self.pool_layer == 'add':
          x = global_add_pool(x, data.batch)
      if self.pool_layer == 'mean':
          x = global_mean_pool(x, data.batch)
      if self.pool_layer == 'max':
          x = global_max_pool(x, data.batch)
      if self.pool_layer == 'sort':
          x = global_sort_pool(x, data.batch, self.k)
 
      x = F.relu(self.linear1(x))
      x = self.linear2(x)
      
      return x
Example #29
    def forward(self, data):
        x = data.x if self.args.use_feature else None
        z1, z2, w, edge_index, batch, x, edge_weight, node_id = data.z1, data.z2, data.w, data.edge_index, data.batch, x, data.edge_weight, None

        z1_emb = self.z1_embedding(z1)
        z2_emb = self.z2_embedding(z2)
        w_emb = self.w_embedding(w)

        z_emb = torch.cat([w_emb, z1_emb], 1)
        z_emb = torch.cat([z_emb, z2_emb], 1)

        if self.use_feature and x is not None:
            x = torch.cat([z_emb, x.to(torch.float)], 1)
        else:
            x = z_emb
        if self.node_embedding is not None and node_id is not None:
            n_emb = self.node_embedding(node_id)
            x = torch.cat([x, n_emb], 1)
        xs = [x]

        for conv in self.convs:
            xs += [torch.tanh(conv(xs[-1], edge_index))]
        x = torch.cat(xs[1:], dim=-1)

        # Global pooling.
        x = global_sort_pool(x, batch, self.k)
        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
        x = F.relu(self.conv1(x))
        x = self.maxpool1d(x)
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        # MLP.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.lin2(x))
        return x
Example #30
    def forward(self, x, edge_index, batch=None):
        attn_weights = dict()

        x = F.one_hot(x, num_classes=self.config.num_feature_dim).float()
        for layer in range(self.config.num_layers - 1):
            x = F.relu(self.gin_convs[layer](x, edge_index))
            x = self.batch_norms[layer](x)

        if self.config.pooling_type == "sagpool":
            x, edge_index, _, batch, attn_weights['pool_perm'], attn_weights[
                'pool_score'] = self.pool1(x, edge_index, batch=batch)
        elif self.config.pooling_type == "topk":
            x, edge_index, _, batch, attn_weights['pool_perm'], attn_weights[
                'pool_score'] = self.pool1(x, edge_index, batch=batch)
        elif self.config.pooling_type == "asa":
            x, edge_index, _, batch, attn_weights['pool_perm'] = self.pool1(
                x, edge_index, batch=batch)
        else:
            pass

        if self.config.readout_type == "add":
            x = global_add_pool(x, batch)
        elif self.config.readout_type == "mean":
            x = global_mean_pool(x, batch)
        elif self.config.readout_type == "max":
            x = global_max_pool(x, batch)
        elif self.config.readout_type == "sort":
            x = global_sort_pool(x, batch, k=100)
        else:
            pass
        attn_weights['batch'] = batch
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)

        return F.log_softmax(x, dim=-1), attn_weights