def test_global_pool():
    """Check add/mean/max global pooling against dense torch reductions.

    Two graphs with ``N_1`` and ``N_2`` nodes (4 features each) are pooled
    with an explicit ``batch`` vector and with ``batch=None`` (one graph),
    and every pooled row is compared with the matching slice reduced
    directly.
    """
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.tensor([0] * N_1 + [1] * N_2)
    first, second = x[:N_1], x[N_1:]

    # Pooled sums/means may differ from torch.sum/mean in the last bit
    # because the reduction order differs, so compare with a tolerance
    # rather than exact list equality.
    out = global_add_pool(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], first.sum(dim=0))
    assert torch.allclose(out[1], second.sum(dim=0))

    out = global_add_pool(x, None)
    assert out.size() == (1, 4)
    assert torch.allclose(out, x.sum(dim=0, keepdim=True))

    out = global_mean_pool(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], first.mean(dim=0))
    assert torch.allclose(out[1], second.mean(dim=0))

    out = global_mean_pool(x, None)
    assert out.size() == (1, 4)
    assert torch.allclose(out, x.mean(dim=0, keepdim=True))

    out = global_max_pool(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], first.max(dim=0)[0])
    assert torch.allclose(out[1], second.max(dim=0)[0])

    out = global_max_pool(x, None)
    assert out.size() == (1, 4)
    assert torch.allclose(out, x.max(dim=0, keepdim=True)[0])
def forward(self, data):
    """Encode a batched graph and predict either a class or a token sequence.

    Node features are embedded together with ``data.node_depth`` by
    ``self.node_encoder``. A stack of convolutions follows, with
    ``self.pools[k]`` applied after ``self.convs[2k]`` (never after the last
    conv). The mean-pooled readout of every layer is combined by
    ``self.jump`` and passed through ``lin1`` + ReLU + dropout(0.5).

    Returns:
        ``self.graph_pred_linear(h)`` when ``self.num_class > 0``, otherwise
        a list of ``self.max_seq_len`` predictions, one per entry of
        ``self.graph_pred_linear_list``.
    """
    edge_index, batch = data.edge_index, data.batch
    h = self.node_encoder(data.x, data.node_depth.view(-1))

    weights = None  # edge weights only appear once a pooling layer emits them
    h = F.relu(self.conv1(h, edge_index))
    readouts = [global_mean_pool(h, batch)]

    last_idx = len(self.convs) - 1
    for layer_idx, conv in enumerate(self.convs):
        h = F.relu(conv(x=h, edge_index=edge_index, edge_weight=weights))
        readouts.append(global_mean_pool(h, batch))
        if layer_idx % 2 == 0 and layer_idx < last_idx:
            h, edge_index, weights, batch, _ = self.pools[layer_idx // 2](
                x=h, edge_index=edge_index, edge_weight=weights, batch=batch)

    h = self.jump(readouts)
    h = F.relu(self.lin1(h))
    h = F.dropout(h, p=0.5, training=self.training)

    if self.num_class > 0:
        return self.graph_pred_linear(h)
    return [self.graph_pred_linear_list[i](h) for i in range(self.max_seq_len)]
def forward(self, z, edge_index, batch, x=None, edge_weight=None, node_id=None):
    """Score each graph from integer node labels ``z`` plus optional extras.

    ``z`` is embedded by ``self.z_embedding``. If the embedding is 3-D
    (several labels per node), it is summed over dim 1. Node features ``x``
    are appended (as float) when ``self.use_feature`` is set and ``x`` is
    given. Learned node embeddings are appended when ``self.node_embedding``
    exists and ``node_id`` is given.

    After the conv stack, node states are mean-pooled per graph: all layers
    concatenated when ``self.jk`` is set, otherwise only the last layer. The
    result goes through ``lin1`` + ReLU + dropout + ``lin2``.
    ``edge_weight`` is accepted but not used here.
    """
    h = self.z_embedding(z)
    if h.ndim == 3:
        h = h.sum(dim=1)
    if self.use_feature and x is not None:
        h = torch.cat([h, x.to(torch.float)], 1)
    if self.node_embedding is not None and node_id is not None:
        h = torch.cat([h, self.node_embedding(node_id)], 1)

    h = self.conv1(h, edge_index)
    layer_outputs = [h]
    for conv in self.convs:
        h = conv(h, edge_index)
        layer_outputs.append(h)

    pool_input = torch.cat(layer_outputs, dim=1) if self.jk else layer_outputs[-1]
    out = global_mean_pool(pool_input, batch)
    out = F.relu(self.lin1(out))
    out = F.dropout(out, p=self.dropout, training=self.training)
    return self.lin2(out)
def forward(self, data):
    """Hierarchical graph encoder with summed multi-scale readouts.

    Three conv + pooling stages run in sequence. After each pooling step,
    the per-graph max- and mean-pooled features are concatenated, and the
    three readouts are summed. The sum passes through ``lin1`` + ReLU,
    dropout(0.5) and ``lin2`` + ReLU. ``data.batch`` is optional; when the
    attribute is missing, ``None`` is passed to the pooling layers.
    """
    x, edge_index = data.x, data.edge_index
    batch = getattr(data, 'batch', None)

    summed = None
    stages = ((self.conv1, self.pool1), (self.conv2, self.pool2), (self.conv3, self.pool3))
    for conv, pool in stages:
        x = F.relu(conv(x, edge_index))
        x, edge_index, _, batch, _, _ = pool(x, edge_index, None, batch)
        level = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)
        summed = level if summed is None else summed + level

    out = F.relu(self.lin1(summed))
    out = F.dropout(out, p=0.5, training=self.training)
    return F.relu(self.lin2(out))
def forward(self, data):
    """Run the GNN stack and a global readout; return ``(emb, log_probs)``.

    Graphs without node features (``data.num_node_features == 0``) get a
    constant 1-dimensional feature per node. Each layer applies conv ->
    ReLU -> dropout, followed by ``self.norm[i]`` on all but the last layer.

    ``self.global_pool`` selects the readout:
    falsy or ``'mean'`` -> mean, ``'max'`` -> max,
    ``'mix'`` -> concatenation of mean and max.

    Returns:
        ``(emb, out)``: ``emb`` is the output of ``self.post_mp`` and
        ``out`` is its log-softmax over dim 1.

    Raises:
        ValueError: if ``self.global_pool`` is none of the values above.
    """
    readout = self.global_pool
    # Validate first: an unrecognised value (including the natural spelling
    # 'mean') used to skip pooling silently and yield node-level outputs.
    if readout and readout not in ('mean', 'max', 'mix'):
        raise ValueError(
            f"unknown global_pool {readout!r}; expected None, 'mean', 'max' or 'mix'")

    x, edge_index, batch = data.x, data.edge_index, data.batch
    if data.num_node_features == 0:
        # Allocate on the graph's device; a plain torch.ones() lands on the
        # CPU and breaks GPU execution.
        x = torch.ones(data.num_nodes, 1, device=edge_index.device)

    for i in range(self.num_layers):
        x = self.convs[i](x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        if i != self.num_layers - 1:
            x = self.norm[i](x)

    if not readout or readout == 'mean':
        x = pyg_nn.global_mean_pool(x, batch)
    elif readout == 'max':
        x = pyg_nn.global_max_pool(x, batch)
    else:  # 'mix'
        x1 = pyg_nn.global_mean_pool(x, batch)
        x2 = pyg_nn.global_max_pool(x, batch)
        x = torch.cat((x1, x2), 1)

    emb = self.post_mp(x)
    return emb, F.log_softmax(emb, dim=1)
def forward(self, data):
    """Three conv + pooling stages with summed max/mean readouts.

    After each pooling step, the per-graph max- and mean-pooled features
    are concatenated, and the three readouts are summed. The sum passes
    through ``lin1`` + ReLU, dropout(0.5), ``lin2`` + ReLU and ``lin3``.
    Returns log-probabilities over the last dimension.
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch

    # The third stage previously reused ``self.pool2`` (copy-paste slip), so
    # stages two and three shared one set of pooling weights. Use
    # ``self.pool3`` when the module defines it; otherwise fall back to
    # ``pool2`` so models built without a third pooling layer keep working.
    pool3 = getattr(self, 'pool3', self.pool2)
    stages = ((self.conv1, self.pool1), (self.conv2, self.pool2), (self.conv3, pool3))

    summed = None
    for conv, pool in stages:
        x = F.relu(conv(x, edge_index))
        x, edge_index, _, batch, _ = pool(x, edge_index, None, batch)
        level = torch.cat(
            [gnn.global_max_pool(x, batch), gnn.global_mean_pool(x, batch)], dim=1)
        summed = level if summed is None else summed + level

    x = F.relu(self.lin1(summed))
    x = F.dropout(x, p=0.5, training=self.training)
    x = F.relu(self.lin2(x))
    return F.log_softmax(self.lin3(x), dim=-1)
def forward(self, x, edge_index, batch):
    """Graph classifier that scores the readout of every layer.

    The input features and the output of every conv layer are pooled per
    graph: mean when ``self.readout == 'mean'``, otherwise sum. Each pooled
    vector goes through its own batch norm (``self.bns_fc``) and linear head
    (``self.linears_prediction``), and the ReLU'd head outputs are summed.
    Optional dropout is applied to that sum before ``self.linear``.
    Returns log-probabilities over the last dimension.
    """
    if x.dim() == 1:
        x = x.unsqueeze(-1)

    pool = global_mean_pool if self.readout == 'mean' else global_add_pool
    output_list = [pool(x, batch)]
    hid_x = self.fc(x)
    for conv in self.conv_layers:
        hid_x = conv(hid_x, edge_index)
        output_list.append(pool(hid_x, batch))

    score_over_layer = 0
    for layer, h in enumerate(output_list):
        h = self.bns_fc[layer](h)
        score_over_layer += F.relu(self.linears_prediction[layer](h))

    # Dropout used to be applied to ``x`` (the raw node features), which was
    # overwritten on the very next line, so it never took effect. Apply it
    # to the aggregated scores that feed the final classifier.
    if self.dropout > 0:
        score_over_layer = F.dropout(score_over_layer, p=self.dropout, training=self.training)

    out = self.linear(score_over_layer)
    return F.log_softmax(out, dim=-1)
def forward(self, data):
    """Embed item ids, apply three conv + pooling stages, sum the readouts.

    After each pooling step, the per-graph max- and mean-pooled features
    are concatenated; the three readouts are summed and passed through
    ``fc1``, ``fc2`` and ``self.drop``. Returns a sigmoid score per graph
    with the trailing dimension squeezed.
    """
    edge_index, batch = data.edge_index, data.batch
    h = self.item_embedding(data.x).squeeze(1)

    summed = None
    stages = ((self.conv1, self.pool1), (self.conv2, self.pool2), (self.conv3, self.pool3))
    for conv, pool in stages:
        h = F.relu(conv(h, edge_index))
        h, edge_index, _, batch, *_ = pool(h, edge_index, batch=batch)
        level = torch.cat([global_max_pool(h, batch), global_mean_pool(h, batch)], dim=1)
        summed = level if summed is None else summed + level

    out = self.drop(self.fc2(self.fc1(summed)))
    return torch.sigmoid(self.linear(out)).squeeze(1)
def forward(self, entry1_data, entry2_data, get_latent_varaible=False):
    """Score a pair of entries, each given as ``(graph_batch, sequence_data)``.

    Each entry's graph is encoded by its own GNN (``gconv1`` / ``gconv2``)
    and mean-pooled per graph. The result is concatenated with the matching
    sequence encoding (``gconv1_seq`` / ``gconv2_seq``). Both entry vectors
    are then concatenated and fed to ``self.global_fc_nn``.

    Args:
        entry1_data, entry2_data: ``(graph, seq)`` pairs; ``graph`` must
            expose ``x``, ``edge_index``, ``edge_attr`` and ``batch``.
        get_latent_varaible: if true, return the ``global_fc_nn`` output.
            (Misspelled name kept for keyword compatibility.)

    Returns:
        The latent tensor, or the ``self.fc2`` output transformed according
        to ``self.out_activation_func``: ``'softmax'`` (over the last dim),
        ``'sigmoid'``, or ``None`` for raw scores.

    Raises:
        ValueError: for any other ``out_activation_func`` (this used to
            return ``None`` silently).
    """
    entry2_graph, entry2_seq_data = entry2_data
    entry1_graph, entry1_seq_data = entry1_data

    def _encode(graph, seq_data, graph_conv, seq_conv):
        # Graph branch -> per-graph mean, then append the sequence encoding
        # along the feature dimension.
        node_out = graph_conv(graph.x, graph.edge_index, graph.edge_attr, graph.batch)
        graph_mean = global_mean_pool(node_out, graph.batch)
        return t.cat([graph_mean, seq_conv(seq_data)], dim=-1)

    entry1_mean = _encode(entry1_graph, entry1_seq_data, self.gconv1, self.gconv1_seq)
    entry2_mean = _encode(entry2_graph, entry2_seq_data, self.gconv2, self.gconv2_seq)

    x = self.global_fc_nn(t.cat([entry1_mean, entry2_mean], dim=-1))
    if get_latent_varaible:
        return x

    x = self.fc2(x)
    activation = self.out_activation_func
    if activation == 'softmax':
        return F.softmax(x, dim=-1)
    if activation == 'sigmoid':
        return t.sigmoid(x)
    if activation is None:
        return x
    raise ValueError(
        f"unknown out_activation_func {activation!r}; expected 'softmax', 'sigmoid' or None")
def forward(self, data):
    """Conv stack with periodic pooling and a jumping-knowledge readout.

    When ``self.use_weight`` is set, ``data.edge_attr`` is passed to every
    conv as ``edge_weight`` and carried through the pooling layers;
    otherwise the edges are unweighted. ``self.pools[k]`` runs after
    ``self.convs[2k]`` (never after the last conv). The mean-pooled readout
    of every layer is combined by ``self.jump``. Returns log-probabilities
    over the last dimension.
    """
    x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
    weighted = self.use_weight

    def _conv_relu(conv, h, ei, ea):
        # Pass edge weights only when the model is configured to use them.
        extra = {'edge_weight': ea} if weighted else {}
        return F.relu(conv(x=h, edge_index=ei, **extra))

    x = _conv_relu(self.conv1, x, edge_index, edge_attr)
    readouts = [global_mean_pool(x, batch)]

    for i, conv in enumerate(self.convs):
        x = _conv_relu(conv, x, edge_index, edge_attr)
        readouts.append(global_mean_pool(x, batch))
        if i % 2 == 0 and i < len(self.convs) - 1:
            pool = self.pools[i // 2]
            if weighted:
                x, edge_index, edge_attr, batch, _, _ = pool(
                    x, edge_index, edge_attr=edge_attr, batch=batch)
            else:
                x, edge_index, _, batch, _, _ = pool(x, edge_index, batch=batch)

    x = F.relu(self.lin1(self.jump(readouts)))
    x = F.dropout(x, p=0.5, training=self.training)
    return F.log_softmax(self.lin2(x), dim=-1)
def forward(self, data):
    """Conv stack with optional graclus-based max pooling after every layer.

    With ``self.encode_edge`` set, node features go through
    ``self.atom_encoder`` and the first conv also receives
    ``data.edge_attr``. ``self.pooling_type`` selects the clustering used
    after each conv in ``self.convs``:
    ``'graclus'`` clusters the current graph,
    ``'complement'`` clusters its complement (batched negative edges), and
    ``'none'`` disables pooling.

    The readout is the per-layer mean pools combined by ``self.jump``,
    unless ``self.no_cat`` is set, in which case only the final node states
    are mean-pooled. Returns ``lin2(relu(lin1(readout)))``.

    Raises:
        ValueError: for any other ``pooling_type`` (previously this
            surfaced as an UnboundLocalError on ``cluster``).
    """
    x, edge_index, batch = data.x, data.edge_index, data.batch
    if self.encode_edge:
        x = self.atom_encoder(x)
        x = self.conv1(x, edge_index, data.edge_attr)
    else:
        x = self.conv1(x, edge_index)
    x = F.relu(x)

    pooling = self.pooling_type
    xs = [global_mean_pool(x, batch)]
    for conv in self.convs:
        x = F.relu(conv(x, edge_index))
        xs.append(global_mean_pool(x, batch))
        if pooling == 'none':
            continue
        if pooling == 'complement':
            complement = batched_negative_edges(
                edge_index=edge_index, batch=batch, force_undirected=True)
            cluster = graclus(complement, num_nodes=x.size(0))
        elif pooling == 'graclus':
            cluster = graclus(edge_index, num_nodes=x.size(0))
        else:
            raise ValueError(
                f"unknown pooling_type {pooling!r}; expected 'none', 'complement' or 'graclus'")
        pooled = max_pool(cluster, Batch(x=x, edge_index=edge_index, batch=batch))
        x, edge_index, batch = pooled.x, pooled.edge_index, pooled.batch

    x = global_mean_pool(x, batch) if self.no_cat else self.jump(xs)
    x = F.relu(self.lin1(x))
    return self.lin2(x)
def forward(self, x, edge_index, edge_weight, batch):
    """Project node features, then run three GCN + pooling stages.

    Features first go through ``lin1`` + ReLU. Each stage applies a ReLU'd
    GCN layer to the current (possibly pooled) graph, pools it, and
    mean-pools the surviving nodes per graph. The three per-graph readouts
    are summed and fed to ``self.mlp``.
    """
    h = self.lin1(x).relu()

    ei, ew, b = edge_index, edge_weight, batch
    readouts = []
    stages = ((self.gcn1, self.pool1), (self.gcn2, self.pool2), (self.gcn3, self.pool3))
    for gcn, pool in stages:
        conv_out = F.relu(gcn(h, ei, ew))
        h, ei, ew, b, _, _ = pool(conv_out, ei, ew, batch=b)
        readouts.append(geo_nn.global_mean_pool(h, b))

    return self.mlp(readouts[0] + readouts[1] + readouts[2])
def forward(self, h, edge_index, edge_attr, batch):
    """Run the conv stack, mean-pool per graph, and return log-probabilities.

    Each layer in ``self.layers`` receives pseudo-coordinates produced by
    the matching ``self.pseudo_proj[l](edge_attr)`` and is followed by ReLU.
    The per-graph mean goes through ``fc1`` + ELU and ``fc2``.

    In training mode the pass is instrumented for debugging: per-stage
    wall-clock times are printed, and the gradient hooks ``hook_gcn``,
    ``hook_relu``, ``hook_pool`` and ``hook`` are registered on intermediate
    tensors. Evaluation mode performs the same computation without them.
    """
    debug = self.training

    for layer_idx, layer in enumerate(self.layers):
        t1 = time.perf_counter()
        h = layer(h, edge_index, self.pseudo_proj[layer_idx](edge_attr))
        if debug:
            print("conv", layer_idx, "forward time: ", time.perf_counter() - t1)
            h.register_hook(hook_gcn)
        h = F.relu(h)
        if debug:
            h.register_hook(hook_relu)

    t2 = time.perf_counter()
    h = global_mean_pool(h, batch)
    if debug:
        print("pooling forward time: ", time.perf_counter() - t2)
        h.register_hook(hook_pool)

    t3 = time.perf_counter()
    h = self.fc1(h)
    if debug:
        print("fc1 forward time: ", time.perf_counter() - t3)
        h.register_hook(hook)
    h = F.elu(h)
    if debug:
        h.register_hook(hook_relu)

    t4 = time.perf_counter()
    h = self.fc2(h)
    if debug:
        print("fc2 forward time: ", time.perf_counter() - t4)
        h.register_hook(hook)

    # Normalise over the class dimension. The previous dim=0 normalised
    # across the graphs of the batch ([num_graphs, num_classes] after
    # global_mean_pool), which is wrong; with one graph it yields all zeros.
    h = F.log_softmax(h, dim=-1)
    if debug:
        h.register_hook(hook)
    return h
def forward(self, x, edge, batch, type='mean_pool'):
    """Pool node features ``x`` into one vector per graph.

    Args:
        x: node features.
        edge: edge index; only used by ``'sag_pool'``.
        batch: graph assignment vector for the nodes.
        type: ``'mean_pool'``, ``'max_pool'``, ``'sum_pool'`` or
            ``'sag_pool'`` (``self.sag_pool`` followed by a mean readout).
            The name shadows the builtin but is kept for keyword
            compatibility.

    Raises:
        ValueError: for any other ``type`` (this used to return ``None``).
    """
    if type == 'mean_pool':
        return global_mean_pool(x, batch)
    if type == 'max_pool':
        return global_max_pool(x, batch)
    if type == 'sum_pool':
        return global_add_pool(x, batch)
    if type == 'sag_pool':
        x1, _, _, pooled_batch, _, _ = self.sag_pool(x, edge, batch=batch)
        return global_mean_pool(x1, pooled_batch)
    raise ValueError(
        f"unknown pooling type {type!r}; expected 'mean_pool', 'max_pool', 'sum_pool' or 'sag_pool'")
def forward(self, data):
    """Conv stack paired with pooling layers and a jumping-knowledge readout.

    ``self.convs`` and ``self.pools`` are walked in lockstep. After the
    conv at position ``i``, ``self.pools[i]`` runs when ``i`` is even. The
    mean-pooled readout of every layer is combined by ``self.jump``.
    Returns log-probabilities over the last dimension.
    """
    edge_index, batch = data.edge_index, data.batch
    h = F.relu(self.conv1(data.x, edge_index))
    readouts = [global_mean_pool(h, batch)]

    # NOTE(review): zip() stops at the shorter of self.convs / self.pools,
    # and pooling uses pools[i] (not pools[i // 2]) on even i only; confirm
    # both lists have the intended lengths.
    for i, (conv, pool) in enumerate(zip(self.convs, self.pools)):
        h = F.relu(conv(h, edge_index))
        readouts.append(global_mean_pool(h, batch))
        if i % 2 == 0:
            h, edge_index, _, batch, _ = pool(h, edge_index, batch=batch)

    h = F.relu(self.lin1(self.jump(readouts)))
    h = F.dropout(h, p=0.5, training=self.training)
    return F.log_softmax(self.lin2(h), dim=-1)
def forward(self, data):
    """Conv stack with pooling after every other layer and a JK readout.

    ``self.pools[k]`` runs after ``self.convs[2k]`` (never after the last
    conv). The mean-pooled readout of every layer is combined by
    ``self.jump``. Returns log-probabilities over the last dimension.
    """
    edge_index, batch = data.edge_index, data.batch
    h = F.relu(self.conv1(data.x, edge_index))
    readouts = [global_mean_pool(h, batch)]

    n_convs = len(self.convs)
    for idx, conv in enumerate(self.convs):
        h = F.relu(conv(h, edge_index))
        readouts.append(global_mean_pool(h, batch))
        is_pool_step = idx % 2 == 0 and idx < n_convs - 1
        if is_pool_step:
            h, edge_index, batch, _ = self.pools[idx // 2](h, edge_index, batch=batch)

    h = F.relu(self.lin1(self.jump(readouts)))
    h = F.dropout(h, p=0.5, training=self.training)
    return F.log_softmax(self.lin2(h), dim=-1)
def forward(self, data16, data20, datacol):
    """Fuse three graph inputs into class log-probabilities.

    ``data16`` and ``data20`` share one encoder: conv1 -> conv2 -> conv3 ->
    i1 -> i2 -> per-graph mean. ``datacol`` uses ``self.colconv`` as its
    first conv but shares the remaining layers. The 16/20 embeddings are
    concatenated and passed through ``lin1``/``mlp``, the col embedding
    through ``lin2``/``mlp2``, and both results are concatenated for
    ``self.finallin``.
    """
    x16, edge_index16, batch16 = data16.features, data16.edge_index, data16.batch
    x20, edge_index20, batch20 = data20.features, data20.edge_index, data20.batch
    xcol, edge_indexcol, batchcol = datacol.features, datacol.edge_index, datacol.batch

    def _encode(first_conv, feats, edges, graph_batch):
        h = F.relu(first_conv(feats, edges))
        h = F.relu(self.conv2(h, edges))
        h = self.conv3(h, edges)
        # i1/i2 receive the transposed node states with a leading batch
        # axis (presumably (1, channels, nodes) -- confirm against i1/i2).
        h = self.i2(self.i1(torch.transpose(h, 0, 1).unsqueeze(0)))
        h = torch.transpose(h.squeeze(0), 0, 1)
        return global_mean_pool(h, graph_batch)

    x16 = _encode(self.conv1, x16, edge_index16, batch16)
    x20 = _encode(self.conv1, x20, edge_index20, batch20)
    xcol = _encode(self.colconv, xcol, edge_indexcol, batchcol)
    xcol = self.mlp2(self.lin2(xcol))

    fused = self.mlp(self.lin1(torch.cat([x16, x20], dim=1)))
    out = self.finallin(torch.cat([fused, xcol], dim=1))
    return F.log_softmax(out, dim=-1)
def test_permuted_global_pool():
    """Global add/mean/max pooling must not depend on node order in ``batch``."""
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.cat([torch.zeros(N_1), torch.ones(N_2)]).to(torch.long)
    perm = torch.randperm(N_1 + N_2)

    px, pbatch = x[perm], batch[perm]
    groups = (px[pbatch == 0], px[pbatch == 1])

    # Each pooling op paired with the dense reduction it must reproduce.
    checks = (
        (global_add_pool, lambda g: g.sum(dim=0)),
        (global_mean_pool, lambda g: g.mean(dim=0)),
        (global_max_pool, lambda g: g.max(dim=0)[0]),
    )
    for pool, reference in checks:
        out = pool(px, pbatch)
        assert out.size() == (2, 4)
        for row, group in zip(out, groups):
            assert torch.allclose(row, reference(group))
def forward(self, pos, batch):
    """Classify point clouds with three pseudo-coordinate conv stages.

    Each stage does the following:
    - builds a radius graph over the current points;
    - encodes every edge as its relative offset scaled by the stage radius
      into [0, 1] (``/ (2 * radius) + 0.5``, then clamped);
    - applies ELU(conv).

    Farthest point sampling keeps 50% of the points after the first stage
    and 25% after the second. The per-cloud mean goes through
    ``lin1``/``lin2`` (ELU), dropout(0.5) and ``lin3``. Returns
    log-probabilities.
    """
    x = pos.new_ones((pos.size(0), 1))

    # (conv, radius, fps ratio applied after the conv; None = no sampling)
    stages = ((self.conv1, 0.2, 0.5), (self.conv2, 0.4, 0.25), (self.conv3, 1, None))
    for conv, radius, ratio in stages:
        edge_index = radius_graph(pos, r=radius, batch=batch)
        src, dst = edge_index[0], edge_index[1]
        pseudo = ((pos[dst] - pos[src]) / (2 * radius) + 0.5).clamp(min=0, max=1)
        x = F.elu(conv(x, edge_index, pseudo))
        if ratio is not None:
            keep = fps(pos, batch, ratio=ratio)
            x, pos, batch = x[keep], pos[keep], batch[keep]

    x = global_mean_pool(x, batch)
    x = F.elu(self.lin1(x))
    x = F.elu(self.lin2(x))
    x = F.dropout(x, p=0.5, training=self.training)
    return F.log_softmax(self.lin3(x), dim=-1)
def forward(self, data):
    """Graph discriminator: conv, then two graclus-pool + conv stages.

    The first conv uses the incoming ``data.edge_attr``. Before each later
    conv, the graph is coarsened with graclus clustering (normalized-cut
    weights) and ``max_pool``. Edge attributes are then recomputed from node
    positions as ``(pos[col] - pos[row]) / (2 * self.args.cutoff) + 0.5``.
    The per-graph mean goes through ``fc1`` + ELU, dropout
    (``self.args.disc_dropout``) and ``fc2``.

    Returns the raw ``fc2`` output when ``self.args.wgan`` is set, otherwise
    its sigmoid. The caller's ``data`` object is not modified.
    """
    import copy

    # Work on a shallow copy: the attribute assignments below used to
    # overwrite the caller's batch (x, edge_attr), breaking any reuse of it.
    # Tensors themselves are not copied.
    data = copy.copy(data)
    data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))

    for conv in (self.conv2, self.conv3):
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        # edge_attr is recomputed from positions right after pooling, so
        # there is no need for max_pool to pool the old one.
        data.edge_attr = None
        data = max_pool(cluster, data, transform=T.Cartesian(cat=False))
        row, col = data.edge_index
        data.edge_attr = (data.pos[col] - data.pos[row]) / (2 * self.args.cutoff) + 0.5
        data.x = F.elu(conv(data.x, data.edge_index, data.edge_attr))

    x = global_mean_pool(data.x, data.batch)
    x = F.elu(self.fc1(x))
    x = F.dropout(x, training=self.training, p=self.args.disc_dropout)
    y = self.fc2(x)
    if self.args.wgan:
        return y
    return torch.sigmoid(y)
def forward(self, x, edge_index, batch, x_mord):
    """Combine a graph-conv embedding with an MLP over descriptor features.

    Graph branch: three graph convolutions (ReLU after the first two), then
    a per-graph mean. Descriptor branch: ``x_mord`` passes through three
    ReLU(linear) -> batch-norm -> dropout stages. The two results are
    concatenated, and ``self.linear`` + sigmoid produces the output.
    """
    # Graph branch.
    g = self.graph_conv1(x, edge_index).relu()
    g = self.graph_conv2(g, edge_index).relu()
    g = global_mean_pool(self.graph_conv3(g, edge_index), batch)

    # Descriptor branch.
    d = x_mord
    dense_stages = (
        (self.dense_fc1, self.dense_batch_norm1),
        (self.dense_fc2, self.dense_batch_norm2),
        (self.dense_fc3, self.dense_batch_norm3),
    )
    for fc, norm in dense_stages:
        d = self.dense_dropout(norm(F.relu(fc(d))))

    return torch.sigmoid(self.linear(torch.cat([g, d], dim=1)))
def forward(self, data):
    """Three convs, a per-graph mean, and a tanh/batch-norm MLP head.

    Side effects: stores the L1-normalised node features in
    ``self.original_x`` and ``data.y`` cast to long in ``self.label``.

    The pooled vector gets dropout(0.1). It then passes through three
    linear -> batch-norm -> tanh stages (the first two followed by
    dropout(0.1)) and finally ``lin4``, whose output is returned.
    """
    x, edge_index, y, batch = data.x, data.edge_index, data.y, data.batch
    x = F.normalize(x, p=1., dim=-1)
    self.original_x = x
    self.label = y.long()

    h = x
    for conv in (self.conv1, self.conv2, self.conv3):
        h = conv(h, edge_index)

    # Readout layer
    h = global_mean_pool(h, batch)
    h = F.dropout(h, p=0.1, training=self.training)

    head = ((self.lin1, self.bn1, True), (self.lin2, self.bn2, True), (self.lin3, self.bn3, False))
    for lin, bn, use_dropout in head:
        h = torch.tanh(bn(lin(h)))
        if use_dropout:
            h = F.dropout(h, p=0.1, training=self.training)
    return self.lin4(h)
def forward(self, data):
    """Four two-edge-set conv stages, concatenated and regressed per graph.

    Each stage does the following:
    - runs one conv over ``data.edge_index_1`` and one over
      ``data.edge_index_2`` (both ReLU'd) on the previous stage output;
    - merges them with an MLP followed by batch norm.

    All four stage outputs are concatenated, mean-pooled per graph, and
    passed through ``fc1``-``fc3`` (ReLU) and ``fc4``. The result is
    flattened with ``view(-1)``.
    """
    stages = (
        (self.conv1_1, self.conv1_2, self.mlp_1, self.bn1),
        (self.conv2_1, self.conv2_2, self.mlp_2, self.bn2),
        (self.conv3_1, self.conv3_2, self.mlp_3, self.bn3),
        (self.conv4_1, self.conv4_2, self.mlp_4, self.bn4),
    )
    h = data.x
    stage_outputs = []
    for conv_a, conv_b, mlp, bn in stages:
        h_a = F.relu(conv_a(h, data.edge_index_1))
        h_b = F.relu(conv_b(h, data.edge_index_2))
        h = bn(mlp(torch.cat([h_a, h_b], dim=-1)))
        stage_outputs.append(h)

    out = global_mean_pool(torch.cat(stage_outputs, dim=-1), data.batch)
    for fc in (self.fc1, self.fc2, self.fc3):
        out = F.relu(fc(out))
    return self.fc4(out).view(-1)
def project(self, data, reg_hook=False):
    """Run ``data`` through the encoder, pooling and dense stack for visualization.

    Every layer in ``self.conv_encoder`` is followed by tanh. With
    ``reg_hook`` set, ``self.activations_hook`` is registered as a gradient
    hook on each conv activation. Node states are pooled per graph according
    to ``self.pooling`` (``'mean'``, ``'add'`` or ``'max'``); any other value
    skips pooling. All entries of ``self.linear_layers`` except the last use
    ReLU; the last is applied without activation, and its output is returned.

    NOTE(review): the earlier summary said "up to last hidden layer", yet the
    final entry of ``self.linear_layers`` is applied too -- confirm whether
    that entry is a hidden layer or the output layer.
    """
    h, edge_index = data.x, data.edge_index
    for conv in self.conv_encoder:
        h = torch.tanh(conv(h, edge_index))
        if reg_hook:
            h.register_hook(self.activations_hook)

    if self.pooling == 'mean':
        h = global_mean_pool(h, data.batch)
    elif self.pooling == 'add':
        h = global_add_pool(h, data.batch)
    elif self.pooling == 'max':
        h = global_max_pool(h, data.batch)

    for dense in self.linear_layers[:-1]:
        h = F.relu(dense(h))
    return self.linear_layers[-1](h)
def forward(self, data):
    """Run a forward pass.

    No activation functions are used, so that even an untrained network
    still propagates gradients.

    Parameters
    ----------
    data : torch_geometric.data.Batch
        Batch graph.

    Returns
    -------
    y : torch.Tensor
        Per-graph predictions of shape ``(n_samples, n_targets)``.
    """
    node_feats = data.x  # (n_nodes_batch, n_node_features)
    edges = data.edge_index  # (2, n_edges_batch)
    edge_feats = data.edge_features  # (n_edges_batch, n_edge_features=1)
    graph_ids = data.batch  # (n_nodes_batch,)

    h = self.conv(node_feats, edges, edge_feats)  # (n_nodes_batch, n_channels)
    h = global_mean_pool(h, graph_ids)  # (n_samples, n_channels)
    return self.fc2(self.fc1(h))  # (n_samples, n_targets)
def forward(self, x, edge_index, batch):
    """Seven batch-norm + conv layers with two pooling steps, then a linear head.

    Each layer normalises its input with ``self.batchnK`` before the conv.
    Layers 1-6 apply ReLU; layer 7 does not. ``self.pool1`` runs after
    layer 2 and ``self.pool2`` after layer 4. Node states are mean-pooled
    per graph and classified by ``self.lin``.
    """
    def _norm_conv(norm, conv, h, ei, activate=True):
        h = conv(norm(h), ei)
        return h.relu() if activate else h

    # 1. Obtain node embeddings
    x = _norm_conv(self.batchn1, self.conv1, x, edge_index)
    x = _norm_conv(self.batchn2, self.conv2, x, edge_index)
    x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)

    x = _norm_conv(self.batchn3, self.conv3, x, edge_index)
    x = _norm_conv(self.batchn4, self.conv4, x, edge_index)
    x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)

    x = _norm_conv(self.batchn5, self.conv5, x, edge_index)
    x = _norm_conv(self.batchn6, self.conv6, x, edge_index)
    x = _norm_conv(self.batchn7, self.conv7, x, edge_index, activate=False)

    # 2. Readout layer: [batch_size, hidden_channels]
    x = global_mean_pool(x, batch)

    # 3. Final classifier
    return self.lin(x)
def forward(self, data):
    """Point-cloud classifier with two farthest-point-sampling stages.

    Only ``data.pos[:, :3]`` is used as input; the first conv receives
    ``None`` features. FPS keeps ratio 0.375 after conv1 and 0.334 after
    conv2; conv3 and conv4 then run on the remaining points. Every hidden
    activation is hardtanh. The readout is max or mean pooling according to
    ``self.pool``; any other value skips pooling. Returns
    ``{'out': log_probs}``.
    """
    pos, batch = data.pos[:, :3], data.batch

    x = F.hardtanh(self.conv1(None, pos, batch))
    for ratio, conv in ((0.375, self.conv2), (0.334, self.conv3)):
        keep = fps(pos, batch, ratio=ratio)
        x, pos, batch = x[keep], pos[keep], batch[keep]
        x = F.hardtanh(conv(x, pos, batch))
    x = F.hardtanh(self.conv4(x, pos, batch))

    if self.pool == 'max':
        x = global_max_pool(x, batch)
    elif self.pool == 'mean':
        x = global_mean_pool(x, batch)

    for lin in (self.lin1, self.lin2):
        x = F.hardtanh(lin(x))
    return {'out': F.log_softmax(self.lin3(x), dim=-1)}
def forward(self, data):
    """Three convs separated by two graclus + max-pool steps; returns ``fc1`` output.

    Before every conv, edge attributes are recomputed from node positions
    as ``(pos[col] - pos[row]) / (2 * 28 * cutoff) + 0.5``, where
    ``cutoff`` is a module-level value defined outside this block. Between
    convs, the graph is coarsened with graclus clustering (normalized-cut
    weights) and ``max_pool``.

    Returns ``self.fc1`` applied to the per-graph mean. ``self.fc2`` is not
    used by this method. The caller's ``data`` object is not modified.
    """
    import copy

    def _edge_attr(graph):
        row, col = graph.edge_index
        return (graph.pos[col] - graph.pos[row]) / (2 * 28 * cutoff) + 0.5

    # Shallow copy so the attribute assignments below do not overwrite the
    # caller's batch (x and edge_attr used to be clobbered in place).
    data = copy.copy(data)
    data.edge_attr = _edge_attr(data)
    data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))

    for conv in (self.conv2, self.conv3):
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        # edge_attr is rebuilt from positions right after pooling.
        data.edge_attr = None
        data = max_pool(cluster, data, transform=T.Cartesian(cat=False))
        data.edge_attr = _edge_attr(data)
        data.x = F.elu(conv(data.x, data.edge_index, data.edge_attr))

    x = global_mean_pool(data.x, data.batch)
    return self.fc1(x)
def forward(self, data):
    """Two-edge-set conv model over one-hot node labels, regressed per graph.

    ``data.x`` holds exactly one integer label per node, which is one-hot
    encoded to width 492. Each of the four stages does the following:
    - runs one conv over ``data.edge_index_1`` and one over
      ``data.edge_index_2`` (both ReLU'd);
    - merges them with an MLP + batch norm.

    All stage outputs are concatenated, mean-pooled per graph, and passed
    through ``fc1``-``fc3`` (ReLU) and ``fc4``. The result is flattened with
    ``view(-1)``.
    """
    num_labels = 492  # one-hot width (hard-coded label space)
    labels = data.x.long()
    # Build the one-hot matrix directly on the labels' device instead of
    # allocating on the CPU and moving it to a module-level ``device``.
    # view(n) keeps the requirement of exactly one label per node.
    x = F.one_hot(labels.view(labels.size(0)), num_labels).to(torch.get_default_dtype())

    stages = (
        (self.conv1_1, self.conv1_2, self.mlp_1, self.bn1),
        (self.conv2_1, self.conv2_2, self.mlp_2, self.bn2),
        (self.conv3_1, self.conv3_2, self.mlp_3, self.bn3),
        (self.conv4_1, self.conv4_2, self.mlp_4, self.bn4),
    )
    h = x
    stage_outputs = []
    for conv_a, conv_b, mlp, bn in stages:
        h_a = F.relu(conv_a(h, data.edge_index_1))
        h_b = F.relu(conv_b(h, data.edge_index_2))
        h = bn(mlp(torch.cat([h_a, h_b], dim=-1)))
        stage_outputs.append(h)

    out = global_mean_pool(torch.cat(stage_outputs, dim=-1), data.batch)
    for fc in (self.fc1, self.fc2, self.fc3):
        out = F.relu(fc(out))
    return self.fc4(out).view(-1)
def pool_func(x, batch, mode="sum"):
    """Pool node features ``x`` into one row per graph in ``batch``.

    Args:
        x: node feature matrix.
        batch: graph assignment vector forwarded to the pooling op
            (``None`` treats all nodes as one graph).
        mode: ``"sum"`` (alias ``"add"``), ``"mean"`` or ``"max"``.

    Raises:
        ValueError: for an unknown ``mode`` (this used to return ``None``).
    """
    if mode in ("sum", "add"):
        return global_add_pool(x, batch)
    if mode == "mean":
        return global_mean_pool(x, batch)
    if mode == "max":
        return global_max_pool(x, batch)
    raise ValueError(f"unknown pooling mode {mode!r}; expected 'sum', 'mean' or 'max'")