def forward(self, data):
    """Run the GNN stack, pool node embeddings per graph and classify.

    Each layer is conv -> ReLU -> dropout, followed by ``self.norm[i]`` on every
    layer except the last. Node embeddings are then pooled per graph according
    to ``self.global_pool`` and passed through ``self.post_mp``.

    Args:
        data: batched graph exposing ``x``, ``edge_index``, ``batch``,
            ``num_node_features`` and ``num_nodes``.

    Returns:
        ``(emb, out)``: ``emb`` is the output of ``self.post_mp`` and ``out``
        is its ``log_softmax`` over dim 1.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch
    if data.num_node_features == 0:
        # Featureless graphs get a constant feature per node. Allocate it on
        # the graph's device so GPU inputs do not hit a CPU/GPU mismatch.
        x = torch.ones(data.num_nodes, 1, device=edge_index.device)

    last = self.num_layers - 1
    for i in range(self.num_layers):
        x = self.convs[i](x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # Normalisation sits between layers only, never after the last one.
        if i != last:
            x = self.norm[i](x)

    # Readout: falsy -> mean, 'max' -> max, 'mix' -> concat(mean, max).
    # NOTE(review): any other value skips pooling and keeps node-level
    # features; confirm that is intended.
    if not self.global_pool:
        x = pyg_nn.global_mean_pool(x, batch)
    elif self.global_pool == 'max':
        x = pyg_nn.global_max_pool(x, batch)
    elif self.global_pool == 'mix':
        mean_part = pyg_nn.global_mean_pool(x, batch)
        max_part = pyg_nn.global_max_pool(x, batch)
        x = torch.cat((mean_part, max_part), 1)

    x = self.post_mp(x)
    emb = x
    out = F.log_softmax(x, dim=1)
    return emb, out
def forward(self, data):
    """Three conv -> pool stages with a hierarchical (max || mean) readout.

    Every stage contributes ``cat(global_max_pool, global_mean_pool)`` of its
    pooled node features. The three stage readouts are summed and passed
    through ``lin1`` -> dropout(0.5) -> ``lin2``, each linear followed by ReLU.
    """
    x, edge_index = data.x, data.edge_index
    batch = getattr(data, 'batch', None)

    readout = None
    stages = ((self.conv1, self.pool1),
              (self.conv2, self.pool2),
              (self.conv3, self.pool3))
    for conv, pool in stages:
        x = F.relu(conv(x, edge_index))
        x, edge_index, _, batch, _, _ = pool(x, edge_index, None, batch)
        stage_readout = torch.cat(
            [global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)
        readout = stage_readout if readout is None else readout + stage_readout

    hidden = F.relu(self.lin1(readout))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    return F.relu(self.lin2(hidden))
def forward(self, data):
    """Three conv -> pool stages with a summed (max || mean) readout, then an MLP.

    Returns:
        Log-probabilities (``log_softmax`` over the last dim) from ``lin3``.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch

    # `*_` absorbs any trailing outputs of the pooling layer so the
    # unpacking works whether it returns 5 or 6 values.
    x = F.relu(self.conv1(x, edge_index))
    x, edge_index, _, batch, *_ = self.pool1(x, edge_index, None, batch)
    x1 = torch.cat(
        [gnn.global_max_pool(x, batch), gnn.global_mean_pool(x, batch)], dim=1)

    x = F.relu(self.conv2(x, edge_index))
    x, edge_index, _, batch, *_ = self.pool2(x, edge_index, None, batch)
    x2 = torch.cat(
        [gnn.global_max_pool(x, batch), gnn.global_mean_pool(x, batch)], dim=1)

    x = F.relu(self.conv3(x, edge_index))
    # The third stage uses its own pooling layer, self.pool3 (it previously
    # reused self.pool2 by mistake).
    x, edge_index, _, batch, *_ = self.pool3(x, edge_index, None, batch)
    x3 = torch.cat(
        [gnn.global_max_pool(x, batch), gnn.global_mean_pool(x, batch)], dim=1)

    x = x1 + x2 + x3
    x = F.relu(self.lin1(x))
    x = F.dropout(x, p=0.5, training=self.training)
    x = F.relu(self.lin2(x))
    x = F.log_softmax(self.lin3(x), dim=-1)
    return x
def test_global_pool():
    """Global add/mean/max pooling must match dense per-graph reductions.

    Sums and means are compared with ``torch.allclose``. The pooled values may
    come from a different reduction order than ``Tensor.sum``/``Tensor.mean``,
    so exact float equality can fail spuriously. Max involves no rounding and
    is compared exactly with ``torch.equal``.
    """
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])
    first, second = x[:N_1], x[N_1:]

    out = global_add_pool(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], first.sum(dim=0))
    assert torch.allclose(out[1], second.sum(dim=0))

    out = global_add_pool(x, None)
    assert out.size() == (1, 4)
    assert torch.allclose(out, x.sum(dim=0, keepdim=True))

    out = global_mean_pool(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], first.mean(dim=0))
    assert torch.allclose(out[1], second.mean(dim=0))

    out = global_mean_pool(x, None)
    assert out.size() == (1, 4)
    assert torch.allclose(out, x.mean(dim=0, keepdim=True))

    out = global_max_pool(x, batch)
    assert out.size() == (2, 4)
    assert torch.equal(out[0], first.max(dim=0)[0])
    assert torch.equal(out[1], second.max(dim=0)[0])

    out = global_max_pool(x, None)
    assert out.size() == (1, 4)
    assert torch.equal(out, x.max(dim=0, keepdim=True)[0])
def forward(self, data):
    """Embed item ids, run three conv -> pool stages and predict a probability.

    Each stage contributes ``cat(global_max_pool, global_mean_pool)``. The
    summed readout goes through ``fc1`` -> ``fc2`` -> ``drop`` -> ``linear``
    and a sigmoid; dim 1 of the result is squeezed.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch
    # Look up embeddings; squeeze(1) removes dim 1 when it has size 1.
    x = self.item_embedding(x).squeeze(1)

    stage_readouts = []
    for conv, pool in ((self.conv1, self.pool1),
                       (self.conv2, self.pool2),
                       (self.conv3, self.pool3)):
        x = F.relu(conv(x, edge_index))
        x, edge_index, _, batch, *_ = pool(x, edge_index, batch=batch)
        stage_readouts.append(torch.cat(
            [global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1))

    h = stage_readouts[0] + stage_readouts[1] + stage_readouts[2]
    h = self.fc2(self.fc1(h))
    h = self.drop(h)
    return torch.sigmoid(self.linear(h)).squeeze(1)
def forward(self, data):
    """Apply conv -> ReLU -> dropout per layer, optional max readout, classify.

    Args:
        data: graph batch exposing ``x``, ``edge_index``, ``batch``,
            ``num_node_features`` and ``num_nodes``.

    Returns:
        ``log_softmax`` (dim 1) of ``self.post_mp`` applied to node features,
        or to per-graph max-pooled features when ``self.task == 'graph'``.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch

    if data.num_node_features == 0:
        # Featureless graphs get a constant feature per node, allocated on
        # the graph's device so GPU inputs keep working.
        x = torch.ones(data.num_nodes, 1, device=edge_index.device)

    for i in range(self.num_layers):
        x = self.convs[i](x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

    if self.task == 'graph':
        # Max readout; an earlier author note reports mean pooling performed
        # much worse here.
        x = pyg_nn.global_max_pool(x, batch)

    x = self.post_mp(x)
    return F.log_softmax(x, dim=1)
def partial_forward(self, x, edge_index):
    """Encode a single graph with three GCN layers and max-pool it.

    Every node is assigned to graph 0, so the pooled result describes the
    whole input graph.

    Args:
        x: node feature tensor.
        edge_index: graph connectivity passed to each GCN layer.

    Returns:
        Output of ``geonn.global_max_pool`` with an all-zero batch vector.
    """
    x = self.activation(self.gcn1(x, edge_index))
    x = self.activation(self.gcn2(x, edge_index))
    x = self.activation(self.gcn3(x, edge_index))
    # Single-graph batch vector. Create it on x's device; the previous
    # CPU-only tensor broke pooling for GPU inputs.
    batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
    return geonn.global_max_pool(x, batch)
def forward(self, data):
    """Message passing on a radius graph, then a mixed global readout.

    Side effects: ``self.h`` holds the last layer's output before ReLU, and
    ``self.pooled`` holds the concatenated readout fed to ``self.lin``.

    Returns:
        ``self.lin(cat([add_pool, mean_pool, max_pool, u], dim=1))``.
    """
    x, pos, batch, u = data.x, data.pos, data.batch, data.u

    # NOTE: self.k_nn is passed as the neighbourhood *radius* here.
    edge_index = radius_graph(pos, r=self.k_nn, batch=batch, loop=self.loop)

    kind = self.namemodel
    for layer in self.layers:
        if kind == "DeepSet":
            x = layer(x)
        elif kind == "PointNet":
            x = layer(x=x, pos=pos, edge_index=edge_index)
        elif kind == "MetaNet":
            # MetaNet layers also update the global features u.
            x, _, u = layer(x, edge_index, None, u, batch)
        else:
            x = layer(x=x, edge_index=edge_index)
        self.h = x
        x = x.relu()

    readouts = [
        global_add_pool(x, batch),
        global_mean_pool(x, batch),
        global_max_pool(x, batch),
    ]
    self.pooled = torch.cat(readouts + [u], dim=1)
    return self.lin(self.pooled)
def forward(self, data):
    """EdgeConv + graclus coarsening network with a per-graph max readout.

    Works on a shallow copy of ``data`` so the caller's object keeps its
    original ``x``, ``edge_index`` and ``edge_attr``.

    Returns:
        ``self.output`` applied to the per-graph max-pooled features, with
        the last dim squeezed.
    """
    import copy

    # The attribute assignments below used to overwrite the caller's object;
    # e.g. calling forward twice on it re-applied datanorm/inputnet. A
    # shallow copy shares tensors but not attribute bindings.
    data = copy.copy(data)

    data.x = self.datanorm * data.x
    data.x = self.inputnet(data.x)

    # Stage 1: dynamic kNN graph -> EdgeConv -> graclus clustering -> max pool.
    data.edge_index = to_undirected(
        knn_graph(data.x, self.k, data.batch, loop=False,
                  flow=self.edgeconv1.flow))
    data.x = self.edgeconv1(data.x, data.edge_index)
    weight = normalized_cut_2d(data.edge_index, data.x)
    cluster = graclus(data.edge_index, weight, data.x.size(0))
    # Presumably cleared so max_pool does not try to pool edge features.
    data.edge_attr = None
    data = max_pool(cluster, data)

    # Stage 2: same, but only node features are pooled, then a global max readout.
    data.edge_index = to_undirected(
        knn_graph(data.x, self.k, data.batch, loop=False,
                  flow=self.edgeconv2.flow))
    data.x = self.edgeconv2(data.x, data.edge_index)
    weight = normalized_cut_2d(data.edge_index, data.x)
    cluster = graclus(data.edge_index, weight, data.x.size(0))
    x, batch = max_pool_x(cluster, data.x, data.batch)

    x = global_max_pool(x, batch)
    return self.output(x).squeeze(-1)
def forward(self, x, batch: Optional[torch.Tensor] = None):
    """Three EdgeConv + graclus coarsening stages, then a max readout.

    Args:
        x: node features.
        batch: optional graph-assignment vector. When ``None``, no final
            ``global_max_pool`` is applied and ``self.output`` runs on the
            cluster-pooled features directly.

    Returns:
        ``self.output(x).squeeze(-1)``.
    """
    x = self.datanorm * x
    x = self.inputnet(x)

    # Stage 1: kNN graph -> edgeconv1 -> graclus -> max pool.
    edge_index = to_undirected(
        knn_graph(x, self.k, batch, loop=False, flow=self.edgeconv1.flow))
    x = self.edgeconv1(x, edge_index)
    weight = normalized_cut_2d(edge_index, x)
    cluster = graclus(edge_index, weight, x.size(0))
    x, edge_index, batch, _ = max_pool(cluster, x, edge_index, batch)

    # Stage 2 (additional layer): the graph is built with edgeconv3's flow,
    # so the convolution must be edgeconv3 as well. It previously reused
    # edgeconv1, leaving edgeconv3 unused.
    edge_index = to_undirected(
        knn_graph(x, self.k, batch, loop=False, flow=self.edgeconv3.flow))
    x = self.edgeconv3(x, edge_index)
    weight = normalized_cut_2d(edge_index, x)
    cluster = graclus(edge_index, weight, x.size(0))
    x, edge_index, batch, _ = max_pool(cluster, x, edge_index, batch)

    # Stage 3: kNN graph -> edgeconv2 -> graclus -> node-only max pool.
    edge_index = to_undirected(
        knn_graph(x, self.k, batch, loop=False, flow=self.edgeconv2.flow))
    x = self.edgeconv2(x, edge_index)
    weight = normalized_cut_2d(edge_index, x)
    cluster = graclus(edge_index, weight, x.size(0))
    x, batch = max_pool_x(cluster, x, batch)

    if batch is not None:
        x = global_max_pool(x, batch)
    return self.output(x).squeeze(-1)
def forward(self, data):
    """Per-point classifier: two dynamic kNN convs plus a broadcast global feature.

    ``conv1`` encodes positions on a kNN graph (k=30) of ``pos``, and ``conv2``
    runs on a kNN graph rebuilt in ``conv1``'s feature space. A per-graph
    feature (max-pool of ``relu(lin0(x2))``) is broadcast back to each point
    and concatenated with ``x1`` and ``x2`` before the MLP head.

    Returns:
        Per-point log-probabilities (``log_softmax`` over the last dim).
    """
    # Move inputs to wherever the model lives instead of hard-coding
    # .cuda(), so the model also runs on CPU. data.x is not used here.
    device = next(self.parameters()).device
    pos = data.pos.to(device)
    batch = data.batch.to(device)

    edge_index = knn_graph(pos, k=30, batch=batch)
    x1 = self.conv1(pos, edge_index)
    edge_index = knn_graph(x1, k=30, batch=batch)
    x2 = self.conv2(x1, edge_index)

    graph_feat = global_max_pool(F.relu(self.lin0(x2)), batch)
    # Gather each point's graph feature. The previous repeat/view trick
    # assumed every graph had the same number of points and sorted `batch`;
    # indexing has neither restriction and gives the same result when both hold.
    globalfeats = graph_feat[batch]

    concat_features = torch.cat((x1, x2, globalfeats), dim=1)
    x = F.relu(self.lin1(concat_features))
    x = F.relu(self.lin2(x))
    x = F.dropout(x, p=0.5, training=self.training)
    x = self.lin3(x)
    return F.log_softmax(x, dim=-1)
def forward(self, data):
    """Two ELU graph convolutions, optional graph pooling and globals, MLP head.

    If ``data`` has no ``batch`` attribute, one is attached (all zeros), so a
    single graph is treated as a batch of one. Note this mutates ``data``.

    Returns:
        ``relu(mlp2(relu(mlp1(x))))``. ``x`` is per-graph when
        ``self.task == 'graph'``; global features ``data.u`` are appended
        when ``self.G > 0``.
    """
    if not hasattr(data, 'batch'):
        # Virtual single-graph batch, built on the feature tensor's device
        # (the numpy-based version always landed on the CPU).
        data.batch = torch.zeros(data.x.size(0), dtype=torch.long,
                                 device=data.x.device)

    x = F.elu(self.conv1(data.x, data.edge_index))
    x = F.dropout(x, training=self.training)
    x = F.elu(self.conv2(x, data.edge_index))
    x = F.dropout(x, training=self.training)

    # Global pooling for graph-level tasks.
    if self.task == 'graph':
        x = global_max_pool(x, data.batch)

    # Concatenate G global features per graph.
    if self.G > 0:
        u = data.u.view(-1, self.G)
        x = torch.cat((x, u), 1)

    x = F.relu(self.mlp1(x))
    x = F.relu(self.mlp2(x))
    return x
def forward(self, points, features, batch):
    """Set-abstraction step: sample anchors, group neighbours, encode, pool.

    Anchors are sampled with farthest-point sampling at ratio
    ``1 / self.nb_neighbors``. Neighbours within ``self.radius`` of each
    anchor are gathered, and their radius-normalised offsets are encoded.
    The per-neighbour features are max-pooled per anchor.

    Returns:
        ``(fps_points, output_features, fps_batch)``.
    """
    anchor_idx = gnn.fps(x=points, batch=batch, ratio=1 / self.nb_neighbors)
    fps_points = points[anchor_idx]
    fps_batch = batch[anchor_idx]

    # cluster[k] indexes the anchor, neighbour_idx[k] the neighbouring point.
    cluster, neighbour_idx = gnn.radius(x=points, y=fps_points,
                                        batch_x=batch, batch_y=fps_batch,
                                        r=self.radius)

    offsets = (points[neighbour_idx] - fps_points[cluster]) / self.radius
    rel_encoded = self.neighborhood_enc(offsets, cluster)

    per_neighbour = torch.cat(
        [offsets, rel_encoded[cluster], features[neighbour_idx]], dim=1)
    pooled = gnn.global_max_pool(x=F.relu(self.fc1(per_neighbour)),
                                 batch=cluster)
    global_part = F.relu(self.fc1_global(pooled))

    return fps_points, torch.cat([rel_encoded, global_part], dim=1), fps_batch
def forward(self, data):
    """GRU-refined message passing, one pooling step, max readout, linear head.

    Eight rounds of ``conv`` + ``gru`` update the node states. ``pool1`` then
    coarsens the graph and the result is max-pooled per graph.

    Returns:
        ``lin2`` output flattened with ``view(-1)``.
    """
    out = F.relu(self.lin0(data.x))
    h = out.unsqueeze(0)

    for _ in range(8):
        m = F.relu(self.conv(out, data.edge_index, data.edge_attr))
        out, h = self.gru(m.unsqueeze(0), h)
        out = out.squeeze(0)

    # pool1 drops nodes, so the readout must use the *pooled* batch vector;
    # data.batch no longer lines up with `out` after pooling.
    out, _, _, batch, _, _ = self.pool1(out, data.edge_index, None, data.batch)
    out = global_max_pool(out, batch)

    out = self.lin2(out)
    return out.view(-1)
def forward(self, data):
    """RGCN + GraphConv stack with a configurable readout and a 2-layer head.

    Each layer concatenates an RGCN output and a GraphConv output, then
    applies ReLU. ``self.batchnorm`` is applied between layers but not after
    the last one.

    Returns:
        ``linear2(relu(linear1(x)))``. ``x`` is pooled per graph when
        ``self.pool_layer`` is 'add', 'mean', 'max' or 'sort'.
    """
    x = self.atom_encoder(data.x)
    edge_index = data.edge_index
    # Collapse the three categorical edge features into one relation id,
    # t0 + 5*t1 + 30*t2, in a single vectorised op. This replaces a
    # per-edge Python loop that was slow on large batches.
    ea = data.edge_attr
    edge_attr = (ea[:, 0] + ea[:, 1] * 5 + ea[:, 2] * 30).long().to(self.device)

    last = len(self.rgcn_list) - 1
    for i in range(len(self.rgcn_list)):
        x_rgcn = self.rgcn_list[i](x, edge_index, edge_attr)
        x_gconv = self.graphconv_list[i](x, edge_index)
        x = F.relu(torch.cat((x_rgcn, x_gconv), 1))
        if i == last:
            continue
        x = self.batchnorm(x)

    if self.pool_layer == 'add':
        x = global_add_pool(x, data.batch)
    elif self.pool_layer == 'mean':
        x = global_mean_pool(x, data.batch)
    elif self.pool_layer == 'max':
        x = global_max_pool(x, data.batch)
    elif self.pool_layer == 'sort':
        x = global_sort_pool(x, data.batch, self.k)

    x = F.relu(self.linear1(x))
    x = self.linear2(x)
    return x
def forward(self, data):
    """RGCN stack with a configurable readout and a 2-layer head.

    Each RGCN layer is followed by ReLU. ``self.batchnorm`` is applied between
    layers but not after the last one.

    Returns:
        ``linear2(relu(linear1(x)))``. ``x`` is pooled per graph when
        ``self.pool_layer`` is 'add', 'mean', 'max' or 'sort'.
    """
    x = self.atom_encoder(data.x)
    edge_index = data.edge_index
    # Collapse the three categorical edge features into one relation id,
    # t0 + 5*t1 + 30*t2, in a single vectorised op (no per-edge Python loop).
    ea = data.edge_attr
    edge_attr = (ea[:, 0] + ea[:, 1] * 5 + ea[:, 2] * 30).long().to(self.device)

    last = len(self.rgcn_list) - 1
    for i, layer in enumerate(self.rgcn_list):
        x = F.relu(layer(x, edge_index, edge_attr))
        if i == last:
            continue
        x = self.batchnorm(x)

    if self.pool_layer == 'add':
        x = global_add_pool(x, data.batch)
    elif self.pool_layer == 'mean':
        x = global_mean_pool(x, data.batch)
    elif self.pool_layer == 'max':
        x = global_max_pool(x, data.batch)
    elif self.pool_layer == 'sort':
        x = global_sort_pool(x, data.batch, self.k)

    x = F.relu(self.linear1(x))
    x = self.linear2(x)
    return x
def embed_graph(self, x, edge_index, batch=None):
    """Embed a graph of categorical node labels into a single vector.

    Node labels are one-hot encoded, passed through two graph convolutions,
    optionally pooled (``config.pooling_type``) and read out
    (``config.readout_type``), then projected by ``self.fc``.

    Returns:
        ``(embedding, attn_weights)``. ``attn_weights`` always contains
        ``'batch'`` and, depending on the pooling layer, ``'pool_perm'`` and
        ``'pool_score'``.
    """
    attn_weights = {}
    cfg = self.config

    h = F.one_hot(x, num_classes=cfg.num_feature_dim).float()
    h = F.relu(self.gc1(h, edge_index))
    h = F.dropout(h, cfg.dropout, training=self.training)
    h = self.gc2(h, edge_index)

    pooling = cfg.pooling_type
    if pooling in ("sagpool", "topk"):
        h, edge_index, _, batch, perm, score = self.pool1(h, edge_index, batch=batch)
        attn_weights['pool_perm'] = perm
        attn_weights['pool_score'] = score
    elif pooling == "asa":
        h, edge_index, _, batch, perm = self.pool1(h, edge_index, batch=batch)
        attn_weights['pool_perm'] = perm

    readout = cfg.readout_type
    if readout == "add":
        h = global_add_pool(h, batch)
    elif readout == "mean":
        h = global_mean_pool(h, batch)
    elif readout == "max":
        h = global_max_pool(h, batch)
    elif readout == "sort":
        h = global_sort_pool(h, batch, k=100)

    attn_weights['batch'] = batch
    return self.fc(h), attn_weights
def forward(self, positions, features, batch_indices):
    """Align point coordinates with a learned 3x3 transform, then run residual layers.

    A per-cloud 3x3 matrix is predicted from raw coordinates only (extra
    features such as normals are not used for it), broadcast to each point
    via ``batch_indices`` and applied to ``positions``. The aligned
    coordinates are concatenated with ``features`` and refined by
    ``self.n_layers`` residual blocks.
    """
    pos, feat, batch = positions, features, batch_indices

    # Transform network: per-point features -> per-cloud max -> 3x3 matrix.
    t = self.transform_2(self.transform_1(pos, batch))
    t = global_max_pool(t, batch)
    t = self.transform_4(self.transform_3(t))
    per_point = t[batch].view(-1, 3, 3)

    aligned = torch.einsum("ni,nij->nj", pos, per_point)

    x = torch.cat([aligned, feat], dim=-1).contiguous()
    for i in range(self.n_layers):
        branch = self.linear_layers[i](self.conv_layers[i](x, batch))
        x = self.linear_transform[i](x) + branch
    return x
def forward(self, data):
    """Stack of convolutions, MLP head, optional graph max-pool, log-softmax.

    ReLU and dropout are applied between consecutive convolutions. They are
    also applied after the last one unless ``self.model_type == "GCN"``.
    ``self.post_mp`` runs on node features *before* the optional per-graph
    max pooling.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch

    def act_then_drop(t):
        return F.dropout(F.relu(t), p=self.dropout, training=self.training)

    x = self.convs[0](x, edge_index)
    for layer_idx in range(1, self.num_layers):
        x = self.convs[layer_idx](act_then_drop(x), edge_index)
    if self.model_type != "GCN":
        x = act_then_drop(x)

    x = self.post_mp(x)
    if self.task == "graph":
        x = pyg_nn.global_max_pool(x, batch)
    return F.log_softmax(x, dim=1)
def func_global_max_pooling(self, x3, batch):
    """Max-pool point features per graph.

    With ``self._use_scatter_pooling`` set, this defers to ``global_max_pool``
    and ``batch``. Otherwise ``batch`` is ignored, and ``x3`` is assumed to
    hold ``self._batch_size`` equally sized graphs stored contiguously.

    Returns:
        Tensor of shape ``(num_graphs, channels)``.
    """
    if not self._use_scatter_pooling:
        grouped = x3.view((self._batch_size, -1, x3.shape[-1]))
        return grouped.max(dim=1).values
    return global_max_pool(x3, batch)
def forward(self, data):
    """Two graph convolutions, a linear projection, optional pooling and globals.

    If ``data`` has no ``batch`` attribute, one is attached (all zeros), so a
    single graph is treated as a batch of one. Note this mutates ``data``.

    Returns:
        ``self.mlp1(x)``. ``x`` is per-graph when ``self.task == 'graph'``;
        global features ``data.u`` are appended when ``self.G > 0``.
    """
    if not hasattr(data, 'batch'):
        # Virtual single-graph batch, built on the feature tensor's device
        # (the numpy-based version always landed on the CPU).
        data.batch = torch.zeros(data.x.size(0), dtype=torch.long,
                                 device=data.x.device)

    x = data.x
    x1 = self.conv1(x, data.edge_index)
    x2 = self.conv2(x1, data.edge_index)

    # Optionally feed both conv outputs (skip concatenation) to lin1.
    if self.conclayers:
        x = self.lin1(torch.cat([x1, x2], dim=1))
    else:
        x = self.lin1(x2)

    # Global pooling for graph-level classification.
    if self.task == 'graph':
        x = global_max_pool(x, data.batch)

    # Concatenate G global features per graph.
    if self.G > 0:
        u = data.u.view(-1, self.G)
        x = torch.cat((x, u), 1)

    x = self.mlp1(x)
    return x
def forward(self, data):
    """Graph binary classifier returning probabilities in (0, 1).

    Edge features are optionally projected by two linear layers
    (``edge_flag == 1``); otherwise the GNN layers receive no edge features.
    Node states are max-pooled per graph and passed through dropout,
    ``linear1`` + ReLU, ``bn2`` and ``linear2``, then a sigmoid.

    Raises:
        RuntimeError: if the mean of ``self.linear2.weight`` is NaN, which is
            treated as a sign of exploding gradients.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch
    edge_attr = data.edge_attr

    if self.edge_flag == 1:
        edge_attr = self.linear_edge2(self.linear_edge1(edge_attr))
    else:
        edge_attr = None

    for i in range(self.n_layers):
        x = self.gnn_layers[i](x, edge_index, edge_attr)

    x = global_max_pool(x, batch)

    x = F.dropout(x, p=self.dropout_rate, training=self.training)
    x = self.bn2(torch.relu(self.linear1(x)))
    x = self.linear2(x)
    if torch.isnan(torch.mean(self.linear2.weight)):
        raise RuntimeError("Exploding gradients. Tune learning rate")
    # Binary classification: squash the output into (0, 1).
    return torch.sigmoid(x)
def pool_func(x, batch, mode="sum"):
    """Pool node features ``x`` per graph according to ``batch``.

    Args:
        x: node feature tensor.
        batch: graph-assignment vector passed through to the pooling op.
        mode: ``"sum"``, ``"mean"`` or ``"max"``.

    Returns:
        The pooled per-graph features.

    Raises:
        ValueError: for any other ``mode``. It previously returned ``None``
            silently, which surfaced later as a confusing error.
    """
    if mode == "sum":
        return global_add_pool(x, batch)
    if mode == "mean":
        return global_mean_pool(x, batch)
    if mode == "max":
        return global_max_pool(x, batch)
    raise ValueError(
        f"Unsupported pooling mode {mode!r}; expected 'sum', 'mean' or 'max'.")
def homo_gnn_softmax(x, index, size=None):
    '''Scale-invariant wrapper around ``gnn_softmax``.

    Each entry of ``x`` is divided by the (max - min) range of its group in
    ``index`` before the grouped softmax. As a result,
    ``homo_gnn_softmax(s * x) == homo_gnn_softmax(x)`` for ``s > 0``, up to
    floating-point rounding. Groups whose range is zero are divided by 1
    instead, which avoids a division by zero.

    The NaN checks use ``assert`` and are therefore skipped under ``python -O``.
    '''
    def _check_no_nan(t):
        assert(not torch.sum(torch.isnan(t)))

    _check_no_nan(x)
    group_max = global_max_pool(x, index, size=size)
    _check_no_nan(group_max)
    group_min = -global_max_pool(-x, index, size=size)
    _check_no_nan(group_min)

    spread = (group_max - group_min)[index]
    _check_no_nan(spread)
    is_flat = (spread == 0).type(torch.float)
    spread = torch.ones_like(spread) * is_flat + spread * (1. - is_flat)

    x = x / spread
    _check_no_nan(x)
    return gnn_softmax(x, index, size)
def project(self, data, reg_hook=False):
    """Run ``data`` through the encoder, readout and dense head for visualisation.

    Conv layers are followed by tanh and hidden dense layers by ReLU. The
    last linear layer is applied without an activation, and its output is
    returned. Pooling depends on ``self.pooling`` ('mean', 'add' or 'max');
    any other value skips pooling.

    Args:
        data: graph object with ``x``, ``edge_index`` and ``batch``.
        reg_hook: when true, ``self.activations_hook`` is registered as a
            gradient hook on the conv-encoder output.
    """
    edges = data.edge_index
    h = data.x
    for conv in self.conv_encoder:
        h = torch.tanh(conv(h, edges))
    if reg_hook:
        h.register_hook(self.activations_hook)

    if self.pooling == 'mean':
        h = global_mean_pool(h, data.batch)
    elif self.pooling == 'add':
        h = global_add_pool(h, data.batch)
    elif self.pooling == 'max':
        h = global_max_pool(h, data.batch)

    *hidden_layers, head = self.linear_layers
    for dense in hidden_layers:
        h = F.relu(dense(h))
    return head(h)
def forward(self, pos, batch):
    """Point-cloud encoder shared by a classifier head and a discriminator head.

    Three radius-graph convolution stages (radii 0.2, 0.4, 1). The first two
    are followed by farthest-point subsampling at ratios 0.5 and 0.25. The
    encoder output is max-pooled per cloud.

    Returns:
        ``(log_softmax(class_logits), log_softmax(discriminator_logits))``.
    """
    # (radius for the graph, FPS ratio applied after the conv, or None)
    schedule = ((0.2, 0.5), (0.4, 0.25), (1, None))
    x = None
    for stage, (radius, ratio) in enumerate(schedule):
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(self.features[stage](x, pos, edge_index))
        if ratio is not None:
            keep = fps(pos, batch, ratio=ratio)
            x, pos, batch = x[keep], pos[keep], batch[keep]

    feat = global_max_pool(x, batch)

    logits = F.relu(self.classifier[0](feat))
    logits = F.relu(self.classifier[1](logits))
    logits = F.dropout(logits, p=0.5, training=self.training)
    logits = self.classifier[2](logits)

    disc = F.relu(self.discriminator[0](feat))
    disc = F.dropout(disc, p=0.5, training=self.training)
    disc = self.discriminator[1](disc)

    return F.log_softmax(logits, dim=-1), F.log_softmax(disc, dim=-1)
def test_permuted_global_pool():
    """Global pooling must not depend on node order.

    Nodes are shuffled (so ``batch`` is unsorted), and each pooled row is
    compared against the dense reduction of that graph's nodes.
    """
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.cat([torch.zeros(N_1), torch.ones(N_2)]).to(torch.long)
    perm = torch.randperm(N_1 + N_2)
    px, pbatch = x[perm], batch[perm]
    graphs = (px[pbatch == 0], px[pbatch == 1])

    cases = (
        (global_add_pool, lambda t: t.sum(dim=0)),
        (global_mean_pool, lambda t: t.mean(dim=0)),
        (global_max_pool, lambda t: t.max(dim=0)[0]),
    )
    for pool, dense in cases:
        out = pool(px, pbatch)
        assert out.size() == (2, 4)
        for row, nodes in zip(out, graphs):
            assert torch.allclose(row, dense(nodes))
def forward(self, pos, batch):
    """Point-cloud classifier: three radius-graph conv stages, max readout, MLP.

    Stage radii are 0.2, 0.4 and 1. Farthest-point subsampling at ratios 0.5
    and 0.25 follows the first two stages.

    Returns:
        Per-cloud log-probabilities (``log_softmax`` over the last dim).
    """
    stages = (
        (self.conv1, 0.2, 0.5),
        (self.conv2, 0.4, 0.25),
        (self.conv3, 1, None),
    )
    x = None
    for conv, radius, ratio in stages:
        edge_index = radius_graph(pos, r=radius, batch=batch)
        x = F.relu(conv(x, pos, edge_index))
        if ratio is not None:
            keep = fps(pos, batch, ratio=ratio)
            x, pos, batch = x[keep], pos[keep], batch[keep]

    x = global_max_pool(x, batch)
    x = F.relu(self.lin1(x))
    x = F.relu(self.lin2(x))
    x = F.dropout(x, p=0.5, training=self.training)
    return F.log_softmax(self.lin3(x), dim=-1)
def forward(self, data):
    """Point-cloud classifier with hardtanh activations and FPS downsampling.

    Only the first three coordinates of ``data.pos`` are used; the first conv
    receives no input features. FPS ratios 0.375 and 0.334 are applied before
    ``conv2`` and ``conv3`` respectively.

    Returns:
        ``{'out': log-probabilities}``. Features are pooled per cloud when
        ``self.pool`` is 'max' or 'mean'.
    """
    pos, batch = data.pos[:, :3], data.batch

    h = F.hardtanh(self.conv1(None, pos, batch))
    for ratio, conv in ((0.375, self.conv2), (0.334, self.conv3)):
        keep = fps(pos, batch, ratio=ratio)
        h, pos, batch = h[keep], pos[keep], batch[keep]
        h = F.hardtanh(conv(h, pos, batch))
    h = F.hardtanh(self.conv4(h, pos, batch))

    if self.pool == 'max':
        h = global_max_pool(h, batch)
    elif self.pool == 'mean':
        h = global_mean_pool(h, batch)

    h = F.hardtanh(self.lin1(h))
    h = F.hardtanh(self.lin2(h))
    h = self.lin3(h)
    return {'out': F.log_softmax(h, dim=-1)}
def forward(self, pos, batch):
    """Point-cloud classifier: two chained convs with a skip concatenation.

    Outputs of ``conv1`` and ``conv2`` are concatenated and projected
    per point by ``lin1``, max-pooled per cloud, then classified by ``mlp``.

    Returns:
        Per-cloud log-probabilities (``log_softmax`` over dim 1).
    """
    feats1 = self.conv1(pos, batch)
    feats2 = self.conv2(feats1, batch)
    point_feats = self.lin1(torch.cat([feats1, feats2], dim=1))
    cloud_feats = global_max_pool(point_feats, batch)
    return F.log_softmax(self.mlp(cloud_feats), dim=1)