def test_avg_pool_x():
    """avg_pool_x should mean-pool node features within each cluster.

    Also checks the ``size`` variant: each graph gets exactly ``size``
    cluster slots, and a graph with fewer clusters is padded with zeros.
    """
    node_features = torch.Tensor(
        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    cluster = torch.tensor([0, 1, 0, 1, 2, 2])
    batch = torch.tensor([0, 0, 0, 0, 1, 1])

    # Without ``size``: (pooled features, batch vector of the pooled nodes).
    result = avg_pool_x(cluster, node_features, batch)
    pooled_x, pooled_batch = result[0], result[1]
    assert pooled_x.tolist() == [[3, 4], [5, 6], [10, 11]]
    assert pooled_batch.tolist() == [0, 0, 1]

    # With ``size=2``: graph 1 has a single cluster, so its second slot is zero.
    padded_x, _ = avg_pool_x(cluster, node_features, batch, size=2)
    assert padded_x.tolist() == [[3, 4], [5, 6], [10, 11], [0, 0]]
def forward(self, data, return_hidden_feature=False):
    """Predict a per-graph value via covalent then non-covalent propagation.

    ``data`` is updated in place: when CUDA is available its tensors are
    moved to the GPU, and ``edge_index`` / ``edge_attr`` may be replaced by
    an undirected and/or self-looped version.

    Args:
        data: batched graph exposing ``x``, ``edge_attr``, ``edge_index``
            and ``batch``.
        return_hidden_feature: also return intermediate features. Setting
            ``self.always_return_hidden_feature`` has the same effect.

    Returns:
        ``self.output(pool_x)`` by default; when hidden features are
        requested, the tuple ``(avg_covalent_x, avg_non_covalent_x, pool_x,
        fc0_x, fc1_x, output_x)``.
    """
    if torch.cuda.is_available():
        data.x = data.x.cuda()
        data.edge_attr = data.edge_attr.cuda()
        data.edge_index = data.edge_index.cuda()
        data.batch = data.batch.cuda()

    # make sure that we have an undirected graph
    # NOTE(review): only edge_index is symmetrised here; if reverse edges are
    # added, edge_attr no longer has one entry per edge — confirm inputs are
    # always undirected already, or symmetrise edge_attr alongside.
    if not is_undirected(data.edge_index):
        data.edge_index = to_undirected(data.edge_index)

    # make sure that nodes can propagate messages to themselves
    if not contains_self_loops(data.edge_index):
        data.edge_index, data.edge_attr = add_self_loops(
            data.edge_index, data.edge_attr.view(-1))

    # split edges into covalent and non-covalent neighbourhoods
    covalent_edge_index, covalent_edge_attr = self.covalent_neighbor_threshold(
        data.edge_index, data.edge_attr)
    non_covalent_edge_index, non_covalent_edge_attr = (
        self.non_covalent_neighbor_threshold(data.edge_index, data.edge_attr))

    # covalent propagation feeds the non-covalent propagation
    covalent_x = self.covalent_propagation(
        data.x, covalent_edge_index, covalent_edge_attr)
    non_covalent_x = self.non_covalent_propagation(
        covalent_x, non_covalent_edge_index, non_covalent_edge_attr)

    # Zero out the protein rows (x[:, 14] == -1) so the gather below sums
    # ligand atoms only. Work on a copy: writing into non_covalent_x directly
    # would also zero those rows in the avg_non_covalent_x hidden feature and
    # modify an autograd output in place.
    non_covalent_ligand_only_x = non_covalent_x.clone()
    non_covalent_ligand_only_x[data.x[:, 14] == -1] = 0
    pool_x = self.global_add_pool(non_covalent_ligand_only_x, data.batch)

    # fully connected and output layers
    if return_hidden_feature or self.always_return_hidden_feature:
        # Using data.batch as the cluster vector averages node features per graph.
        avg_covalent_x, _ = avg_pool_x(data.batch, covalent_x, data.batch)
        avg_non_covalent_x, _ = avg_pool_x(data.batch, non_covalent_x, data.batch)
        fc0_x, fc1_x, output_x = self.output(pool_x, return_hidden_feature=True)
        return avg_covalent_x, avg_non_covalent_x, pool_x, fc0_x, fc1_x, output_x
    return self.output(pool_x)
def forward(self, data):
    """Score each graph with two EdgeConv stages separated by graclus pooling.

    ``data`` is modified in place (``x``, ``edge_index`` and ``edge_attr``
    are overwritten) before the first coarsening step.

    Returns:
        ``self.output`` of the per-graph mean feature, with the trailing
        dimension squeezed.
    """

    def edgeconv_then_cluster(graph, conv):
        # Rebuild a symmetric kNN graph on the current node features, apply the
        # EdgeConv layer, then compute a graclus clustering of the result.
        graph.edge_index = to_undirected(
            knn_graph(graph.x, self.k, graph.batch, loop=False, flow=conv.flow))
        graph.x = conv(graph.x, graph.edge_index)
        edge_weight = normalized_cut_2d(graph.edge_index, graph.x)
        return graclus(graph.edge_index, edge_weight, graph.x.size(0))

    data.x = self.inputnet(data.x)

    # Stage 1: convolve, cluster, and coarsen the whole graph object.
    clusters = edgeconv_then_cluster(data, self.edgeconv1)
    data.edge_attr = None
    data = avg_pool(clusters, data)

    # Stage 2: convolve, cluster, pool node features, then average per graph.
    clusters = edgeconv_then_cluster(data, self.edgeconv2)
    pooled_x, pooled_batch = avg_pool_x(clusters, data.x, data.batch)
    graph_x = global_mean_pool(pooled_x, pooled_batch)
    return self.output(graph_x).squeeze(-1)
def forward(self, sample):
    """Encode one graph and pool it down before ``self.output``.

    All nodes are treated as belonging to a single graph: every batch vector
    built here is all zeros.

    Args:
        sample: graph object exposing ``x`` (node features) and
            ``edge_index``.

    Returns:
        ``self.output`` applied to the flattened (``view(-1)``) pooled
        node features.
    """

    def single_graph_batch(nodes):
        # One zero per node, created directly on the device (no Python loop).
        return torch.zeros(nodes.size(0), dtype=torch.long, device=self.device)

    def round_robin_clusters(nodes):
        # Node i goes to cluster i % final_nodes.
        return torch.arange(nodes.size(0), device=self.device) % self.final_nodes

    x, edge_index = sample.x, sample.edge_index

    x = F.gelu(self.dense_input(x, self.empty_edges))
    x = F.gelu(self.input(x, edge_index))
    x = F.gelu(self.conv1(x, edge_index))
    x = F.gelu(self.conv2(x, edge_index))

    if self.pooling_layers > 0:
        pooled = self.topkpool2(x, edge_index, batch=single_graph_batch(x))
        x, edge_index = pooled[0], pooled[1]

    x = F.gelu(self.conv3(x, edge_index))

    # Recomputed because pooling above may have changed the node count.
    batch = single_graph_batch(x)
    # NOTE: with sort pooling the output layer still learns the order of the
    # pooled nodes.
    if self.final_pooling == "avg_pool_x":
        x, _ = avg_pool_x(round_robin_clusters(x), x, batch)
    elif self.final_pooling == "sort_pooling":
        x = global_sort_pool(x, batch, self.final_nodes)
    elif self.final_pooling in ("topk", "asap", "sag"):
        x = self.last_pooling_layer(x, edge_index)[0]
    elif self.final_pooling == "max_pool_x":
        x, _ = max_pool_x(round_robin_clusters(x), x, batch)

    return self.output(x.view(-1))
def forward(self, data):
    """Classify point clouds with two dynamic-graph stages and an MLP head.

    Stage 1 builds a kNN graph on ``data.pos``, extracts features with
    ``self.ds1`` and coarsens with graclus. Stage 2 rebuilds a kNN graph in
    feature space and applies ``self.dd2``. The per-graph max of the result
    is fed through ``nn1``..``nn4`` (ReLU + batch norm between layers) and
    ``self.sm``.

    Returns:
        ``self.sm`` applied to the final linear layer's output.
    """
    pos, batch = data.pos, data.batch

    # Build first edges on the raw 3-D coordinates.
    edge_index = knn_graph(pos, self.k, batch, loop=False)
    _, _, features_dd, _ = self.ds1(pos, edge_index, None)

    # graclus coarsening: positions are averaged, features max-pooled per cluster.
    cluster = graclus(edge_index)
    pos_gra, batch_gra = avg_pool_x(cluster, pos, batch)
    features_gra, _ = max_pool_x(cluster, features_dd, batch)

    # kNN in feature space, built without autograd tracking.
    # NOTE(review): norm(dim=2) assumes features_gra has at least 3 dims — confirm.
    with torch.no_grad():
        edge_index_gra = knn_graph(
            features_gra.norm(dim=2), self.k, batch_gra, loop=False)

    _, _, features_dd2, _ = self.dd2(pos_gra, edge_index_gra, features_gra)

    # Using batch_gra as the cluster vector takes the max over each graph.
    y, _ = max_pool_x(batch_gra, self.nn1(features_dd2), batch_gra)
    for norm, linear in ((self.bn1, self.nn2), (self.bn2, self.nn3), (self.bn3, self.nn4)):
        y = linear(norm(torch.nn.functional.relu(y)))
    return self.sm(y)
def forward(self, data, return_hidden_feature=False):
    """Predict a per-graph value via covalent then non-covalent propagation.

    ``data`` is updated in place: when CUDA is available its tensors are
    moved to the GPU, and ``edge_index`` / ``edge_attr`` may be replaced by
    an undirected and/or self-looped version.

    Args:
        data: batched graph exposing ``x``, ``edge_attr``, ``edge_index``
            and ``batch``.
        return_hidden_feature: also return intermediate features.

    Returns:
        ``self.output(pool_x)`` by default; with ``return_hidden_feature``
        the tuple ``(avg_covalent_x, avg_non_covalent_x, pool_x, fc0_x,
        fc1_x, output_x)``.
    """
    # Guarded: calling .cuda() unconditionally fails on CPU-only machines.
    if torch.cuda.is_available():
        data.x = data.x.cuda()
        data.edge_attr = data.edge_attr.cuda()
        data.edge_index = data.edge_index.cuda()
        data.batch = data.batch.cuda()

    # make sure that we have an undirected graph
    # NOTE(review): only edge_index is symmetrised here; if reverse edges are
    # added, edge_attr no longer has one entry per edge — confirm inputs are
    # always undirected already, or symmetrise edge_attr alongside.
    if not is_undirected(data.edge_index):
        data.edge_index = to_undirected(data.edge_index)

    # make sure that nodes can propagate messages to themselves
    if not contains_self_loops(data.edge_index):
        data.edge_index, data.edge_attr = add_self_loops(
            data.edge_index, data.edge_attr.view(-1))

    # split edges into covalent and non-covalent neighbourhoods
    covalent_edge_index, covalent_edge_attr = self.covalent_neighbor_threshold(
        data.edge_index, data.edge_attr)
    non_covalent_edge_index, non_covalent_edge_attr = (
        self.non_covalent_neighbor_threshold(data.edge_index, data.edge_attr))

    # covalent propagation feeds the non-covalent propagation
    covalent_x = self.covalent_propagation(
        data.x, covalent_edge_index, covalent_edge_attr)
    non_covalent_x = self.non_covalent_propagation(
        covalent_x, non_covalent_edge_index, non_covalent_edge_attr)

    # Zero out the protein rows (x[:, 14] == -1) so the gather below sums
    # ligand atoms only. Work on a copy: writing into non_covalent_x directly
    # would also zero those rows in the avg_non_covalent_x hidden feature and
    # modify an autograd output in place.
    non_covalent_ligand_only_x = non_covalent_x.clone()
    non_covalent_ligand_only_x[data.x[:, 14] == -1] = 0
    pool_x = self.global_add_pool(non_covalent_ligand_only_x, data.batch)

    # fully connected and output layers
    if return_hidden_feature:
        # Using data.batch as the cluster vector averages node features per graph.
        avg_covalent_x, _ = avg_pool_x(data.batch, covalent_x, data.batch)
        avg_non_covalent_x, _ = avg_pool_x(data.batch, non_covalent_x, data.batch)
        fc0_x, fc1_x, output_x = self.output(pool_x, return_hidden_feature=True)
        return avg_covalent_x, avg_non_covalent_x, pool_x, fc0_x, fc1_x, output_x
    return self.output(pool_x)