Example #1
def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if data.num_node_features == 0:
            x = torch.ones(data.num_nodes, 1)

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            if i != self.num_layers - 1:
                x = self.norm[i](x)

        if not self.global_pool:
            x = pyg_nn.global_mean_pool(x, batch)
        elif self.global_pool == 'max':
            x = pyg_nn.global_max_pool(x, batch)
        elif self.global_pool == 'mix':
            x1 = pyg_nn.global_mean_pool(x, batch)
            x2 = pyg_nn.global_max_pool(x, batch)
            x = torch.cat((x1, x2), 1)

        x = self.post_mp(x)
        emb = x
        out = F.log_softmax(x, dim=1)

        return emb, out
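A minimal standalone sketch of the three readout options above, using only torch_geometric.nn's pooling functions (toy sizes; note that the 'mix' readout doubles the feature dimension, which the post-message-passing head has to account for):

import torch
from torch_geometric.nn import global_mean_pool, global_max_pool

x = torch.randn(7, 16)                        # 7 nodes, 16 channels
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])   # two graphs in one batch

mean_out = global_mean_pool(x, batch)         # (2, 16)
max_out = global_max_pool(x, batch)           # (2, 16)
mix_out = torch.cat((mean_out, max_out), 1)   # (2, 32): 'mix' doubles the width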
Example #2
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        batch = data.batch if hasattr(data, 'batch') else None

        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([global_max_pool(x, batch),
                        global_mean_pool(x, batch)],
                       dim=1)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([global_max_pool(x, batch),
                        global_mean_pool(x, batch)],
                       dim=1)

        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([global_max_pool(x, batch),
                        global_mean_pool(x, batch)],
                       dim=1)

        x = x1 + x2 + x3

        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.lin2(x))

        return x
Example #3
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat(
            [gnn.global_max_pool(x, batch),
             gnn.global_mean_pool(x, batch)],
            dim=1)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat(
            [gnn.global_max_pool(x, batch),
             gnn.global_mean_pool(x, batch)],
            dim=1)

        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat(
            [gnn.global_max_pool(x, batch),
             gnn.global_mean_pool(x, batch)],
            dim=1)

        x = x1 + x2 + x3

        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.lin2(x))
        x = F.log_softmax(self.lin3(x), dim=-1)

        return x
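The conv and pool modules in the two hierarchical examples above are defined in __init__ (not shown). A plausible sketch under the assumption that GraphConv encoders are paired with TopKPooling (SAGPooling exposes the same call signature); the class name and hidden sizes are illustrative, not from the original source:

import torch
from torch_geometric.nn import GraphConv, TopKPooling

class TopKNet(torch.nn.Module):  # hypothetical name
    def __init__(self, in_channels, num_classes, hidden=128):
        super().__init__()
        self.conv1 = GraphConv(in_channels, hidden)
        self.pool1 = TopKPooling(hidden, ratio=0.8)
        self.conv2 = GraphConv(hidden, hidden)
        self.pool2 = TopKPooling(hidden, ratio=0.8)
        self.conv3 = GraphConv(hidden, hidden)
        self.pool3 = TopKPooling(hidden, ratio=0.8)
        # each level's readout concatenates max- and mean-pooling: 2 * hidden wide
        self.lin1 = torch.nn.Linear(2 * hidden, hidden)
        self.lin2 = torch.nn.Linear(hidden, hidden // 2)
        self.lin3 = torch.nn.Linear(hidden // 2, num_classes)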
Example #4
def test_global_pool():
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

    out = global_add_pool(x, batch)
    assert out.size() == (2, 4)
    assert out[0].tolist() == x[:4].sum(dim=0).tolist()
    assert out[1].tolist() == x[4:].sum(dim=0).tolist()

    out = global_add_pool(x, None)
    assert out.size() == (1, 4)
    assert out.tolist() == x.sum(dim=0, keepdim=True).tolist()

    out = global_mean_pool(x, batch)
    assert out.size() == (2, 4)
    assert out[0].tolist() == x[:4].mean(dim=0).tolist()
    assert out[1].tolist() == x[4:].mean(dim=0).tolist()

    out = global_mean_pool(x, None)
    assert out.size() == (1, 4)
    assert out.tolist() == x.mean(dim=0, keepdim=True).tolist()

    out = global_max_pool(x, batch)
    assert out.size() == (2, 4)
    assert out[0].tolist() == x[:4].max(dim=0)[0].tolist()
    assert out[1].tolist() == x[4:].max(dim=0)[0].tolist()

    out = global_max_pool(x, None)
    assert out.size() == (1, 4)
    assert out.tolist() == x.max(dim=0, keepdim=True)[0].tolist()
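The shapes asserted above follow from a scatter over the batch vector; for intuition, a sketch of global_add_pool in plain PyTorch (add_pool is an illustrative name, not a PyG function):

import torch

def add_pool(x, batch, num_graphs):
    # scatter-add node features into one row per graph
    out = torch.zeros(num_graphs, x.size(-1), dtype=x.dtype, device=x.device)
    return out.index_add_(0, batch, x)

x = torch.randn(10, 4)
batch = torch.tensor([0] * 4 + [1] * 6)
assert torch.allclose(add_pool(x, batch, 2)[0], x[:4].sum(dim=0))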
Example #5
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.item_embedding(x).squeeze(1)

        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, *_ = self.pool1(x, edge_index, batch=batch)
        x1 = torch.cat([global_max_pool(x, batch),
                        global_mean_pool(x, batch)],
                       dim=1)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, *_ = self.pool2(x, edge_index, batch=batch)
        x2 = torch.cat([global_max_pool(x, batch),
                        global_mean_pool(x, batch)],
                       dim=1)

        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, *_ = self.pool3(x, edge_index, batch=batch)
        x3 = torch.cat([global_max_pool(x, batch),
                        global_mean_pool(x, batch)],
                       dim=1)

        x = x1 + x2 + x3

        x = self.fc1(x)
        x = self.fc2(x)
        x = self.drop(x)

        x = torch.sigmoid(self.linear(x)).squeeze(1)

        return x
Example #6
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        ############################################################################
        # TODO: Your code here! 
        # Each layer in GNN should consist of a convolution (specified in model_type),
        # a non-linearity (use RELU), and dropout. 
        # HINT: the __init__ function contains parameters you will need. You may 
        # also find pyg_nn.global_max_pool useful for graph classification.
        # Our implementation is ~6 lines, but don't worry if you deviate from this.

        # sanity check
        if data.num_node_features == 0:
            x = torch.ones(data.num_nodes, 1)

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            emb = x
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # global_mean_pool will lead to a terrible result
        if self.task == 'graph':
            x = pyg_nn.global_max_pool(x, batch)

        ############################################################################

        x = self.post_mp(x)

        return F.log_softmax(x, dim=1)
Example #7
    def partial_forward(self, x, edge_index):
        x = self.activation(self.gcn1(x, edge_index))
        x = self.activation(self.gcn2(x, edge_index))
        x = self.activation(self.gcn3(x, edge_index))
        # Single-graph input: an all-zeros batch vector pools every node together
        batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        x = geonn.global_max_pool(x, batch)
        return x
Example #8
    def forward(self, data):

        x, pos, batch, u = data.x, data.pos, data.batch, data.u

        # Get edges using positions by computing the kNNs or the neighbors within a radius
        #edge_index = knn_graph(pos, k=self.k_nn, batch=batch, loop=self.loop)
        edge_index = radius_graph(pos,
                                  r=self.k_nn,
                                  batch=batch,
                                  loop=self.loop)

        # Start message passing
        for layer in self.layers:
            if self.namemodel == "DeepSet":
                x = layer(x)
            elif self.namemodel == "PointNet":
                x = layer(x=x, pos=pos, edge_index=edge_index)
            elif self.namemodel == "MetaNet":
                x, dumb, u = layer(x, edge_index, None, u, batch)
            else:
                x = layer(x=x, edge_index=edge_index)
            self.h = x
            x = x.relu()

        # Mix different global pooling layers
        addpool = global_add_pool(x, batch)  # [num_examples, hidden_channels]
        meanpool = global_mean_pool(x, batch)
        maxpool = global_max_pool(x, batch)
        #self.pooled = torch.cat([addpool, meanpool, maxpool], dim=1)
        self.pooled = torch.cat([addpool, meanpool, maxpool, u], dim=1)

        # Final linear layer
        return self.lin(self.pooled)
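Note that the snippet reuses self.k_nn as the radius r, so one attribute serves both construction modes. For reference, the two options the comment mentions differ in neighborhood degree; a small standalone sketch with toy sizes:

import torch
from torch_geometric.nn import knn_graph, radius_graph

pos = torch.rand(100, 3)
batch = torch.zeros(100, dtype=torch.long)

ei_knn = knn_graph(pos, k=8, batch=batch, loop=False)       # fixed degree k per node
ei_rad = radius_graph(pos, r=0.2, batch=batch, loop=False)  # all neighbors within r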
Example #9
    def forward(self, data):
        data.x = self.datanorm * data.x
        data.x = self.inputnet(data.x)

        data.edge_index = to_undirected(
            knn_graph(data.x,
                      self.k,
                      data.batch,
                      loop=False,
                      flow=self.edgeconv1.flow))
        data.x = self.edgeconv1(data.x, data.edge_index)

        weight = normalized_cut_2d(data.edge_index, data.x)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        data.edge_attr = None
        data = max_pool(cluster, data)

        data.edge_index = to_undirected(
            knn_graph(data.x,
                      self.k,
                      data.batch,
                      loop=False,
                      flow=self.edgeconv2.flow))
        data.x = self.edgeconv2(data.x, data.edge_index)

        weight = normalized_cut_2d(data.edge_index, data.x)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        x, batch = max_pool_x(cluster, data.x, data.batch)

        x = global_max_pool(x, batch)

        return self.output(x).squeeze(-1)
Example #10
    def forward(self, x, batch: Optional[torch.Tensor] = None):
        if batch is None:
            # Single-graph input: synthesize an all-zeros batch vector so every
            # pooling step below sees a valid assignment
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        x = self.datanorm * x
        x = self.inputnet(x)

        edge_index = to_undirected(knn_graph(x, self.k, batch, loop=False, flow=self.edgeconv1.flow))
        x = self.edgeconv1(x, edge_index)
        weight = normalized_cut_2d(edge_index, x)
        cluster = graclus(edge_index, weight, x.size(0))
        # max_pool expects a Data/Batch object, so wrap the tensors before coarsening
        data = max_pool(cluster, Batch(x=x, edge_index=edge_index, batch=batch))
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Additional layer by Shamik
        edge_index = to_undirected(knn_graph(x, self.k, batch, loop=False, flow=self.edgeconv3.flow))
        x = self.edgeconv3(x, edge_index)
        weight = normalized_cut_2d(edge_index, x)
        cluster = graclus(edge_index, weight, x.size(0))
        data = max_pool(cluster, Batch(x=x, edge_index=edge_index, batch=batch))
        x, edge_index, batch = data.x, data.edge_index, data.batch

        edge_index = to_undirected(knn_graph(x, self.k, batch, loop=False, flow=self.edgeconv2.flow))
        x = self.edgeconv2(x, edge_index)

        weight = normalized_cut_2d(edge_index, x)
        cluster = graclus(edge_index, weight, x.size(0))
        x, batch = max_pool_x(cluster, x, batch)

        x = global_max_pool(x, batch)

        return self.output(x).squeeze(-1)
Example #11
    def forward(self, data):

        t, pos, batch = data.x, data.pos, data.batch
        pos = pos.cuda()
        batch = batch.cuda()
        t = t.cuda()

        dsize = pos.size()[0]
        bsize = batch[-1].item() + 1
        edge_index = knn_graph(pos, k=30, batch=batch)
        x1 = self.conv1(pos, edge_index)
        edge_index = knn_graph(x1, k=30, batch=batch)
        x2 = self.conv2(x1, edge_index)

        x2max = F.relu(self.lin0(x2))

        x2max = global_max_pool(x2max, batch)
        globalfeats = x2max.repeat(1, int(dsize / bsize)).view(
            dsize,
            x2max.size()[1])

        concat_features = torch.cat((x1, x2, globalfeats), dim=1)

        x = F.relu(self.lin1(concat_features))
        x = F.relu(self.lin2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin3(x)
        return F.log_softmax(x, dim=-1)
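The repeat/view broadcast above assumes every graph in the batch has the same number of nodes (dsize must be divisible by bsize). Indexing the pooled tensor with the batch vector is the size-agnostic equivalent:

# x2max: (num_graphs, F) pooled features; batch: (num_nodes,)
globalfeats = x2max[batch]  # (num_nodes, F), correct even for ragged batches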
Example #12
    def forward(self, data):

        if not hasattr(data, 'batch'):
            # Create a virtual null batch for single-graph input
            data.batch = torch.zeros(data.x.shape[0], dtype=torch.long)

        x = F.elu(self.conv1(data.x, data.edge_index))
        x = F.dropout(x, training=self.training)

        x = F.elu(self.conv2(x, data.edge_index))
        x = F.dropout(x, training=self.training)

        # ** Global pooling **
        if self.task == 'graph':
            x = global_max_pool(x, data.batch)

        # Global features concatenated
        if self.G > 0:
            u = data.u.view(-1, self.G)
            x = torch.cat((x, u), 1)

        x = F.relu(self.mlp1(x))
        x = F.relu(self.mlp2(x))

        return x
Example #13
    def forward(self, points, features, batch):
        ratio = 1 / self.nb_neighbors
        fps_indices = gnn.fps(x=points, batch=batch, ratio=ratio)
        fps_points = points[fps_indices]
        fps_batch = batch[fps_indices]

        radius_cluster, radius_indices = gnn.radius(x=points,
                                                    y=fps_points,
                                                    batch_x=batch,
                                                    batch_y=fps_batch,
                                                    r=self.radius)

        anchor_points = fps_points[radius_cluster]
        radius_points = points[radius_indices]
        radius_features = features[radius_indices]

        relative_points = (radius_points - anchor_points) / self.radius
        rel_encoded = self.neighborhood_enc(relative_points, radius_cluster)
        rel_enc_mapped = rel_encoded[radius_cluster]

        fc_input = torch.cat(
            [relative_points, rel_enc_mapped, radius_features], dim=1)

        fc1_features = F.relu(self.fc1(fc_input))

        max_features = gnn.global_max_pool(x=fc1_features,
                                           batch=radius_cluster)

        fc1_global_features = F.relu(self.fc1_global(max_features))

        output_features = torch.cat([rel_encoded, fc1_global_features], dim=1)

        return fps_points, output_features, fps_batch
Example #14
    def forward(self, data):
        out = F.relu(self.lin0(data.x))
        h = out.unsqueeze(0)

        # 8 rounds of message passing with a shared conv and a GRU state update
        for i in range(8):
            m = F.relu(self.conv(out, data.edge_index, data.edge_attr))
            out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)

        out, edge_index, _, batch, perm, score = self.pool1(
            out, data.edge_index, None, data.batch)
        # Pool with the batch vector returned by pool1: after pooling,
        # `out` has fewer rows than data.batch
        out = global_max_pool(out, batch)
        out = self.lin2(out)

        return out.view(-1)
Example #15
    def forward(self, data):
        x = self.atom_encoder(data.x)
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        edge_attr = (edge_attr[:, 0] + edge_attr[:, 1] * 5 +
                     edge_attr[:, 2] * 30).long().to(self.device)
        for i in range(len(self.rgcn_list)):
            x_rgcn = self.rgcn_list[i](x, edge_index, edge_attr)
            x_gconv = self.graphconv_list[i](x, edge_index)
            x = torch.cat((x_rgcn, x_gconv), 1)
            x = F.relu(x)
            if i == len(self.rgcn_list) - 1: continue
            x = self.batchnorm(x)

        # x = self.graph_conv(x,edge_index)
        # x = F.relu(x)

        if self.pool_layer == 'add':
            x = global_add_pool(x, data.batch)
        if self.pool_layer == 'mean':
            x = global_mean_pool(x, data.batch)
        if self.pool_layer == 'max':
            x = global_max_pool(x, data.batch)
        if self.pool_layer == 'sort':
            x = global_sort_pool(x, data.batch, self.k)

        x = F.relu(self.linear1(x))
        x = self.linear2(x)

        return x
Example #16
    def forward(self, data):
        x = self.atom_encoder(data.x)
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        edge_attr = (edge_attr[:, 0] + edge_attr[:, 1] * 5 +
                     edge_attr[:, 2] * 30).long().to(self.device)
        for i, layer in enumerate(self.rgcn_list):
            x = layer(x, edge_index, edge_attr)
            x = F.relu(x)
            if i == len(self.rgcn_list) - 1: continue
            x = self.batchnorm(x)

        if self.pool_layer == 'add':
            x = global_add_pool(x, data.batch)
        if self.pool_layer == 'mean':
            x = global_mean_pool(x, data.batch)
        if self.pool_layer == 'max':
            x = global_max_pool(x, data.batch)
        if self.pool_layer == 'sort':
            x = global_sort_pool(x, data.batch, self.k)

        x = F.relu(self.linear1(x))
        x = self.linear2(x)

        return x
Example #17
    def embed_graph(self, x, edge_index, batch=None):
        attn_weights = dict()

        x = F.one_hot(x, num_classes=self.config.num_feature_dim).float()
        x = F.relu(self.gc1(x, edge_index))
        x = F.dropout(x, self.config.dropout, training=self.training)
        x = self.gc2(x, edge_index)

        if self.config.pooling_type == "sagpool":
            x, edge_index, _, batch, attn_weights['pool_perm'], attn_weights[
                'pool_score'] = self.pool1(x, edge_index, batch=batch)
        elif self.config.pooling_type == "topk":
            x, edge_index, _, batch, attn_weights['pool_perm'], attn_weights[
                'pool_score'] = self.pool1(x, edge_index, batch=batch)
        elif self.config.pooling_type == "asa":
            x, edge_index, _, batch, attn_weights['pool_perm'] = self.pool1(
                x, edge_index, batch=batch)

        if self.config.readout_type == "add":
            x = global_add_pool(x, batch)
        elif self.config.readout_type == "mean":
            x = global_mean_pool(x, batch)
        elif self.config.readout_type == "max":
            x = global_max_pool(x, batch)
        elif self.config.readout_type == "sort":
            x = global_sort_pool(x, batch, k=100)
        else:
            pass

        attn_weights['batch'] = batch
        x = self.fc(x)
        return x, attn_weights
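Unlike the other readouts, global_sort_pool keeps k nodes per graph and flattens them, so the following fc layer must expect k * channels inputs rather than channels; a quick shape sketch:

import torch
from torch_geometric.nn import global_sort_pool

x = torch.randn(10, 8)                     # 10 nodes, 8 channels
batch = torch.zeros(10, dtype=torch.long)  # a single graph
out = global_sort_pool(x, batch, k=4)      # sort by last channel, keep 4 nodes
print(out.shape)                           # torch.Size([1, 32]) == (B, k * channels)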
Example #18
    def forward(self, positions, features, batch_indices):
        # Lab: (B,), Pos: (N, 3), Batch: (N,)
        pos, feat, batch = positions, features, batch_indices

        # TransformNet:
        x = pos  # Don't use the normals!

        x = self.transform_1(x, batch)  # (N, 3) -> (N, 128)
        x = self.transform_2(x)  # (N, 128) -> (N, 1024)
        x = global_max_pool(x, batch)  # (B, 1024)

        x = self.transform_3(x)  # (B, 256)
        x = self.transform_4(x)  # (B, 3*3)
        x = x[batch]  # (N, 3*3)
        x = x.view(-1, 3, 3)  # (N, 3, 3)

        # Apply the transform:
        x0 = torch.einsum("ni,nij->nj", pos, x)  # (N, 3)

        # Add features to coordinates
        x = torch.cat([x0, feat], dim=-1).contiguous()

        for i in range(self.n_layers):
            x_i = self.conv_layers[i](x, batch)
            x_i = self.linear_layers[i](x_i)
            x = self.linear_transform[i](x)
            x = x + x_i

        return x
Example #19
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        ############################################################################
        # TODO: Your code here!
        # Each layer in GNN should consist of a convolution (specified in model_type),
        # a non-linearity (use RELU), and dropout.
        # HINT: the __init__ function contains parameters you will need. You may
        # also find pyg_nn.global_max_pool useful for graph classification.
        # Our implementation is ~6 lines, but don't worry if you deviate from this.
        x = self.convs[0](x, edge_index)
        if self.num_layers > 1:
            for l in range(1, self.num_layers):
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
                x = self.convs[l](x, edge_index)
        if self.model_type != "GCN":
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.post_mp(x)
        if self.task == "graph":
            #x = scatter_mean(x, batch, dim=0)
            x = pyg_nn.global_max_pool(x, batch)
        ############################################################################

        return F.log_softmax(x, dim=1)
Example #20
    def func_global_max_pooling(self, x3, batch):
        if self._use_scatter_pooling:
            return global_max_pool(x3, batch)
        else:
            global_feature = x3.view(
                (self._batch_size, -1, x3.shape[-1])).max(1)
            return global_feature[0]
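The dense fallback agrees with the scatter path only when every graph in the batch has the same node count, which is what the view assumes; a sketch checking the equivalence with illustrative sizes:

import torch
from torch_geometric.nn import global_max_pool

x = torch.randn(8, 16)                        # 2 graphs x 4 nodes each
batch = torch.arange(2).repeat_interleave(4)  # [0, 0, 0, 0, 1, 1, 1, 1]

dense = x.view(2, -1, x.size(-1)).max(1)[0]   # the fallback's dense path
assert torch.allclose(dense, global_max_pool(x, batch))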
Example #21
    def forward(self, data):

        if not hasattr(data, 'batch'):
            # Create a virtual null batch for single-graph input
            data.batch = torch.zeros(data.x.shape[0], dtype=torch.long)

        x = data.x

        x1 = self.conv1(x, data.edge_index)
        x2 = self.conv2(x1, data.edge_index)

        if self.conclayers:
            x = self.lin1(torch.cat([x1, x2], dim=1))

        else:
            x = self.lin1(x2)

        # ** Global pooling (to handle graph level classification) **
        if self.task == 'graph':
            x = global_max_pool(x, data.batch)

        # Global features concatenated
        if self.G > 0:
            u = data.u.view(-1, self.G)
            x = torch.cat((x, u), 1)

        # Final layers
        x = self.mlp1(x)

        return x
Example #22
    def forward(self, data):

        x, y, edge_index, edge_attr, batch = data.x, data.y, data.edge_index, data.edge_attr, data.batch

        # Initial CGC ??
        # x = self.cgc1(x, edge_index, edge_attr)

        if self.edge_flag == 1:
            edge_attr = self.linear_edge1(edge_attr)
            edge_attr = self.linear_edge2(edge_attr)
        else:
            edge_attr = None

        for i in range(self.n_layers):
            x = self.gnn_layers[i](x, edge_index, edge_attr)

        # print(criterias)
        # print(batch)

        # Pooling
        x = global_max_pool(x, batch)

        # Output block
        x = F.dropout(x, p=self.dropout_rate,
                      training=self.training)  # dropout_rate
        x = torch.relu(self.linear1(x))
        x = self.bn2(x)
        x = self.linear2(x)

        if torch.isnan(torch.mean(self.linear2.weight)):
            raise RuntimeError("Exploding gradients. Tune learning rate")

        x = torch.sigmoid(x)  # binary classification: constrain the output to (0, 1)

        return x
Example #23
def pool_func(x, batch, mode="sum"):
    if mode == "sum":
        return global_add_pool(x, batch)
    elif mode == "mean":
        return global_mean_pool(x, batch)
    elif mode == "max":
        return global_max_pool(x, batch)
    else:
        # Fail loudly instead of silently returning None for unknown modes
        raise ValueError(f"unsupported pooling mode: {mode}")
Example #24
def homo_gnn_softmax(x, index, size=None):
    '''
    re-scale so that homo_softmax(s*x) = homo_softmax(x) when s > 0
    '''
    assert(not torch.sum(torch.isnan(x)))
    x_max = global_max_pool(x, index, size=size)
    assert(not torch.sum(torch.isnan(x_max)))
    x_min = -global_max_pool(-x, index, size=size)
    assert(not torch.sum(torch.isnan(x_min)))
    x_diff = (x_max-x_min)[index]
    assert(not torch.sum(torch.isnan(x_diff)))
    zero_mask = (x_diff == 0).type(torch.float)
    x_diff = torch.ones_like(x_diff)*zero_mask + x_diff*(1.-zero_mask)
    x = x/x_diff
    assert(not torch.sum(torch.isnan(x)))
    return gnn_softmax(x, index, size)
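A quick check of the advertised property, assuming gnn_softmax refers to torch_geometric.utils.softmax: for s > 0, x_max and x_min both scale by s, so the division cancels the scale before the softmax is applied.

import torch

x = torch.randn(6, 4)
index = torch.tensor([0, 0, 0, 1, 1, 1])
assert torch.allclose(homo_gnn_softmax(x, index), homo_gnn_softmax(3.0 * x, index))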
Example #25
    def project(self, data, reg_hook=False):
        "Projects data up to last hidden layer for visualization."

        x, edge_index = data.x, data.edge_index

        for conv_layer in self.conv_encoder:
            x = conv_layer(x, edge_index)
            #x = F.relu(x)
            x = torch.tanh(x)

        if reg_hook:
            h = x.register_hook(self.activations_hook)

        if self.pooling == 'mean':
            x = global_mean_pool(x, data.batch)
        elif self.pooling == 'add':
            x = global_add_pool(x, data.batch)
        elif self.pooling == 'max':
            x = global_max_pool(x, data.batch)

        for dense_layer in self.linear_layers[:-1]:
            x = dense_layer(x)
            #x = torch.tanh(x)
            x = F.relu(x)

        x = self.linear_layers[-1](x)

        return x
Example #26
    def forward(self, pos, batch):
        radius = 0.2
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(self.features[0](None, pos, edge_index))

        idx = fps(pos, batch, ratio=0.5)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        radius = 0.4
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(self.features[1](x, pos, edge_index))

        idx = fps(pos, batch, ratio=0.25)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        radius = 1
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(self.features[2](x, pos, edge_index))

        x = global_max_pool(x, batch)
        feat = x

        x = F.relu(self.classifier[0](x))
        x = F.relu(self.classifier[1](x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.classifier[2](x)

        x2 = F.relu(self.discriminator[0](feat))
        x2 = F.dropout(x2, p=0.5, training=self.training)
        x2 = self.discriminator[1](x2)
        return F.log_softmax(x, dim=-1), F.log_softmax(x2, dim=-1)
Example #27
def test_permuted_global_pool():
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.cat([torch.zeros(N_1), torch.ones(N_2)]).to(torch.long)
    perm = torch.randperm(N_1 + N_2)

    px = x[perm]
    pbatch = batch[perm]
    px1 = px[pbatch == 0]
    px2 = px[pbatch == 1]

    out = global_add_pool(px, pbatch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], px1.sum(dim=0))
    assert torch.allclose(out[1], px2.sum(dim=0))

    out = global_mean_pool(px, pbatch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], px1.mean(dim=0))
    assert torch.allclose(out[1], px2.mean(dim=0))

    out = global_max_pool(px, pbatch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], px1.max(dim=0)[0])
    assert torch.allclose(out[1], px2.max(dim=0)[0])
Example #28
    def forward(self, pos, batch):
        radius = 0.2
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(self.conv1(None, pos, edge_index))

        idx = fps(pos, batch, ratio=0.5)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        radius = 0.4
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(self.conv2(x, pos, edge_index))

        idx = fps(pos, batch, ratio=0.25)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        radius = 1
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(self.conv3(x, pos, edge_index))

        x = global_max_pool(x, batch)

        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin3(x)
        return F.log_softmax(x, dim=-1)
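The fps ratios above set how aggressively the cloud is downsampled at each set-abstraction stage; a standalone sketch of one sampling-plus-grouping step with toy sizes:

import torch
from torch_geometric.nn import fps, radius_graph

pos = torch.rand(512, 3)
batch = torch.zeros(512, dtype=torch.long)

idx = fps(pos, batch, ratio=0.5)                     # farthest-point sampling keeps ~256 points
pos, batch = pos[idx], batch[idx]
edge_index = radius_graph(pos, r=0.2, batch=batch)   # neighborhoods on the sampled cloud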
Example #29
    def forward(self, data):
        x, pos, batch = data.x, data.pos[:, :3], data.batch
        x = F.hardtanh(self.conv1(None, pos, batch))

        idx = fps(pos, batch, ratio=0.375)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        x = F.hardtanh(self.conv2(x, pos, batch))

        idx = fps(pos, batch, ratio=0.334)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        x = F.hardtanh(self.conv3(x, pos, batch))
        x = F.hardtanh(self.conv4(x, pos, batch))
        if self.pool == 'max':
            x = global_max_pool(x, batch)
        elif self.pool == 'mean':
            x = global_mean_pool(x, batch)

        x = F.hardtanh(self.lin1(x))
        x = F.hardtanh(self.lin2(x))
        x = self.lin3(x)
        return {
            'out': F.log_softmax(x, dim=-1)
        }
Example #30
    def forward(self, pos, batch):
        x1 = self.conv1(pos, batch)
        x2 = self.conv2(x1, batch)
        out = self.lin1(torch.cat([x1, x2], dim=1))
        out = global_max_pool(out, batch)
        out = self.mlp(out)
        return F.log_softmax(out, dim=1)