def forward(self, x, batch: Optional[torch.Tensor] = None):
    """Run three kNN / edge-conv / graclus stages followed by the output head.

    Args:
        x: Node features. They are scaled by ``self.datanorm`` and passed
            through ``self.inputnet`` before the first graph is built.
        batch: Optional assignment vector forwarded to ``knn_graph`` and to
            the pooling helpers.

    Returns:
        ``self.output(x).squeeze(-1)`` evaluated on the pooled features.
    """
    x = self.inputnet(self.datanorm * x)

    # Stage 1: build an undirected kNN graph, convolve, then coarsen.
    graph = to_undirected(
        knn_graph(x, self.k, batch, loop=False, flow=self.edgeconv1.flow))
    x = self.edgeconv1(x, graph)
    cut_weight = normalized_cut_2d(graph, x)
    clusters = graclus(graph, cut_weight, x.size(0))
    x, graph, batch, _ = max_pool(clusters, x, graph, batch)

    # Stage 2 ("Additional layer by Shamik").
    # NOTE(review): the graph uses edgeconv3's flow but the convolution
    # applied is edgeconv1 again -- confirm whether edgeconv3 was intended.
    graph = to_undirected(
        knn_graph(x, self.k, batch, loop=False, flow=self.edgeconv3.flow))
    x = self.edgeconv1(x, graph)
    cut_weight = normalized_cut_2d(graph, x)
    clusters = graclus(graph, cut_weight, x.size(0))
    x, graph, batch, _ = max_pool(clusters, x, graph, batch)

    # Stage 3: the last coarsening only needs the pooled features and batch.
    graph = to_undirected(
        knn_graph(x, self.k, batch, loop=False, flow=self.edgeconv2.flow))
    x = self.edgeconv2(x, graph)
    cut_weight = normalized_cut_2d(graph, x)
    clusters = graclus(graph, cut_weight, x.size(0))
    x, batch = max_pool_x(clusters, x, batch)

    if batch is not None:
        x = global_max_pool(x, batch)

    return self.output(x).squeeze(-1)
def forward(self, x, pos, batch=None):
    """Classify point clouds with a stack of transformer / down-sampling blocks.

    Args:
        x: Node features, or ``None`` to use a single constant feature per
            point.
        pos: Point positions; ``pos.shape[0]`` gives the number of points.
        batch: Optional assignment vector forwarded to ``knn_graph``, the
            transition blocks and the global pooling.

    Returns:
        Log-softmax (over the last dimension) of ``self.mlp_output`` applied
        to the mean-pooled features.
    """
    # Add dummy features in case there are none.
    if x is None:
        # ``pos.device`` is valid for CPU and CUDA tensors alike, whereas
        # ``pos.get_device()`` yields -1 on CPU, which is not a usable device.
        x = torch.ones((pos.shape[0], 1), device=pos.device)

    # First block.
    x = self.mlp_input(x)
    edge_index = knn_graph(pos, k=self.k, batch=batch)
    x = self.transformer_input(x, pos, edge_index)

    # Backbone: each step down-samples, rebuilds the graph and transforms.
    for i, transformer in enumerate(self.transformers_down):
        x, pos, batch = self.transition_down[i](x, pos, batch=batch)
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        x = transformer(x, pos, edge_index)

    # Global average pooling followed by the class scores.
    x = global_mean_pool(x, batch)
    out = self.mlp_output(x)

    return F.log_softmax(out, dim=-1)
def forward(self, x, ret_activations=False, relu_activations=False):
    """Apply the edge-conv stack and the two fully connected layers.

    Args:
        x: Input whose first dimension is the batch size; it is reshaped to
            ``(batch_size * self.num_hits, self.node_feat_size)``.
        ret_activations: If true, return the ``self.fc1`` activations instead
            of the ``self.fc2`` output (used for the Frechet ParticleNet
            Distance).
        relu_activations: With ``ret_activations``, apply ReLU to the
            returned activations.

    Returns:
        The ``fc1`` activations (optionally through ReLU) or the ``fc2``
        output. No softmax is applied because PyTorch's cross-entropy loss
        already includes it.
    """
    batch_size = x.size(0)
    x = x.reshape(batch_size * self.num_hits, self.node_feat_size)

    # Mark the first hit of every sample with 1; the running sum minus one
    # then yields the sample index of each hit.
    sample_starts = torch.zeros(
        batch_size * self.num_hits, dtype=int).to(self.device)
    sample_starts[torch.arange(batch_size) * self.num_hits] = 1
    batch = torch.cumsum(sample_starts, 0) - 1

    for i in range(self.num_edge_convs):
        # Only the first two feature columns (angular coordinates) drive the
        # kNN search in the first edge-conv block.
        knn_input = x[:, :2] if i == 0 else x
        edge_index = knn_graph(knn_input, self.k, batch)
        # Skip connection: concatenate the conv output with its input.
        x = torch.cat((self.edge_convs[i](x, edge_index), x), dim=1)

    x = global_mean_pool(x, batch)
    x = self.fc1(x)

    if ret_activations:
        return F.relu(x) if relu_activations else x

    x = self.dropout_layer(F.relu(x))
    return self.fc2(x)
def forward(self, data):
    """Run two kNN / edge-conv / graclus stages on ``data`` and the output head.

    ``data`` is modified in place: ``data.x``, ``data.edge_index`` and
    ``data.edge_attr`` are overwritten, and the name is rebound to the result
    of ``max_pool`` after the first stage.

    Args:
        data: Object exposing ``x`` and ``batch`` attributes.

    Returns:
        ``self.output(x).squeeze(-1)`` on the globally max-pooled features.
    """
    data.x = self.datanorm * data.x
    data.x = self.inputnet(data.x)

    # Stage 1: graph, convolution, clustering and pooling of the whole object.
    knn_edges = knn_graph(
        data.x, self.k, data.batch, loop=False, flow=self.edgeconv1.flow)
    data.edge_index = to_undirected(knn_edges)
    data.x = self.edgeconv1(data.x, data.edge_index)
    cut_weight = normalized_cut_2d(data.edge_index, data.x)
    clusters = graclus(data.edge_index, cut_weight, data.x.size(0))
    data.edge_attr = None
    data = max_pool(clusters, data)

    # Stage 2: same recipe, but only features and batch are pooled.
    knn_edges = knn_graph(
        data.x, self.k, data.batch, loop=False, flow=self.edgeconv2.flow)
    data.edge_index = to_undirected(knn_edges)
    data.x = self.edgeconv2(data.x, data.edge_index)
    cut_weight = normalized_cut_2d(data.edge_index, data.x)
    clusters = graclus(data.edge_index, cut_weight, data.x.size(0))
    pooled_x, pooled_batch = max_pool_x(clusters, data.x, data.batch)

    pooled_x = global_max_pool(pooled_x, pooled_batch)
    return self.output(pooled_x).squeeze(-1)
def forward(self, x, pos, batch=None):
    """Compute per-point class scores with a down/up transformer backbone.

    Args:
        x: Node features, or ``None`` to use a single constant feature per
            point.
        pos: Point positions; ``pos.shape[0]`` gives the number of points.
        batch: Optional assignment vector forwarded to ``knn_graph`` and the
            transition blocks.

    Returns:
        Log-softmax (over the last dimension) of ``self.mlp_output`` applied
        to the features produced by the up-sampling path.
    """
    # Add dummy features in case there are none.
    if x is None:
        # ``pos.device`` is valid for CPU and CUDA tensors alike, whereas
        # ``.to(pos.get_device())`` receives -1 on CPU and fails.
        x = torch.ones((pos.shape[0], 1), device=pos.device)

    # First block.
    x = self.mlp_input(x)
    edge_index = knn_graph(pos, k=self.k, batch=batch)
    x = self.transformer_input(x, pos, edge_index)

    # Per-level outputs, saved for the skip connections of the up path.
    out_x = [x]
    out_pos = [pos]
    out_batch = [batch]

    # Backbone down: reduce cardinality and augment dimensionality.
    for i, transformer in enumerate(self.transformers_down):
        x, pos, batch = self.transition_down[i](x, pos, batch=batch)
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        x = transformer(x, pos, edge_index)

        out_x.append(x)
        out_pos.append(pos)
        out_batch.append(batch)

    # Summit.
    x = self.mlp_summit(x)
    edge_index = knn_graph(pos, k=self.k, batch=batch)
    x = self.transformer_summit(x, pos, edge_index)

    # Backbone up: augment cardinality and reduce dimensionality, walking the
    # saved levels from the coarsest back to the input resolution.
    for i in range(len(self.transformers_down)):
        coarse, fine = -i - 1, -i - 2
        x = self.transition_up[coarse](
            x=out_x[fine], x_sub=x,
            pos=out_pos[fine], pos_sub=out_pos[coarse],
            batch_sub=out_batch[coarse], batch=out_batch[fine])
        edge_index = knn_graph(out_pos[fine], k=self.k, batch=out_batch[fine])
        x = self.transformers_up[coarse](x, out_pos[fine], edge_index)

    # Class scores.
    out = self.mlp_output(x)

    return F.log_softmax(out, dim=-1)
def test_knn_graph(dtype, device):
    """Check the 2-NN edge set of a four-point square for both flow settings."""
    corners = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)

    expected_target_to_source = {
        (0, 1), (0, 3), (1, 0), (1, 2), (2, 1), (2, 3), (3, 0), (3, 2),
    }
    found = knn_graph(corners, k=2, flow='target_to_source')
    assert to_set(found) == expected_target_to_source

    expected_source_to_target = {
        (1, 0), (3, 0), (0, 1), (2, 1), (1, 2), (3, 2), (0, 3), (2, 3),
    }
    found = knn_graph(corners, k=2, flow='source_to_target')
    assert to_set(found) == expected_source_to_target
def forward(self, data):
    """Chain four graph convolutions, pool per graph and apply the MLP head.

    The kNN graph is rebuilt before every convolution from the ``pos_idx``
    columns of the most recent feature tensor.

    Args:
        data: Object exposing ``x``, ``edge_index`` and ``batch``; the stored
            ``edge_index`` is read but not used.

    Returns:
        Output of ``self.nn4``. When ``self.mode == 'angle'`` columns 0 and 1
        are replaced in place by their ``self.tanh`` values.
    """
    # Changing xtype to float, change back after saving graphs properly.
    x, _, batch = data.x, data.edge_index, data.batch

    # Check if this recalculation of edge indices is correct, maybe you can
    # do it over all of x.
    layer_outputs = [x]
    convs = (self.conv_add, self.conv_add2, self.conv_add3, self.conv_add4)
    for conv in convs:
        current = layer_outputs[-1]
        edge_index = knn_graph(
            x=current[:, self.pos_idx], k=self.k, batch=batch).to(self.device)
        layer_outputs.append(conv(current, edge_index))

    # Concatenate the input with every intermediate output, then drop the
    # list so the intermediates can be released.
    x = torch.cat(tuple(layer_outputs), dim=1)
    del layer_outputs

    x = self.nn1(x)
    x = self.relu(x)
    x = self.nn2(x)

    # Four per-graph aggregations, concatenated feature-wise.
    pooled_max, _ = scatter_max(x, batch, dim=0)
    pooled_min, _ = scatter_min(x, batch, dim=0)
    pooled_sum = scatter_sum(x, batch, dim=0)
    pooled_mean = scatter_mean(x, batch, dim=0)
    x = torch.cat((pooled_max, pooled_min, pooled_sum, pooled_mean), dim=1)

    x = self.relu(x)
    x = self.nn3(x)
    x = self.relu(x)
    x = self.nn4(x)

    if self.mode == 'angle':
        x[:, 0] = self.tanh(x[:, 0])
        x[:, 1] = self.tanh(x[:, 1])

    return x
def forward(self, x, batch=None):
    """Build a neighbour graph in a learned space and propagate features.

    Args:
        x: Input node features.
        batch: Optional assignment vector forwarded to the graph builder.

    Returns:
        Tuple ``(edge_index, out)`` where ``out`` is ``self.lin_fout`` applied
        to the propagated features concatenated with ``x``.

    Raises:
        Exception: If ``self.neighbor_algo`` is neither ``"knn"`` nor
            ``"radius"``.
    """
    spatial = self.lin_s(x)
    to_propagate = self.lin_flr(x)

    algo = self.neighbor_algo
    if algo == "knn":
        edge_index = knn_graph(
            spatial, self.k, batch,
            loop=False, flow=self.flow, cosine=False)
    elif algo == "radius":
        edge_index = radius_graph(
            spatial, self.radius, batch,
            loop=False, flow=self.flow, max_num_neighbors=self.k)
    else:
        raise Exception("Unknown neighbor algo {}".format(
            self.neighbor_algo))

    # Squared distance between the two endpoints of every edge.
    reference = spatial.index_select(0, edge_index[1])
    neighbors = spatial.index_select(0, edge_index[0])
    offsets = reference - neighbors
    distancessq = torch.sum(offsets**2, dim=-1)

    # Factor 10 gives a better initial spread.
    distance_weight = torch.exp(-10. * distancessq)

    prop_feat = self.propagate(
        edge_index, x=to_propagate, edge_weight=distance_weight)

    combined = torch.cat([prop_feat, x], dim=-1)
    return edge_index, self.lin_fout(combined)
def forward(self, x, pos, batch=None):
    """Apply the X-transformation convolution over kNN neighbourhoods.

    Args:
        x: Optional node features; a 1-D tensor is treated as one feature
            per node. ``None`` skips the feature concatenation.
        pos: Node positions; a 1-D tensor is treated as one coordinate per
            node.
        batch: Optional assignment vector forwarded to ``knn_graph``.

    Returns:
        Output of ``self.conv`` on the transformed neighbourhood features.
    """
    if pos.dim() == 1:
        pos = pos.unsqueeze(-1)
    N, D = pos.size()
    K = self.kernel_size

    # Self-loops are included; K * dilation candidates are requested per node.
    row, col = knn_graph(pos, K * self.dilation, batch, loop=True)

    if self.dilation > 1:
        # Randomly draw K of the K * dilation candidates of every node.
        dil = self.dilation
        picks = torch.randint(
            K * dil, (N, K), dtype=torch.long, device=row.device)
        node_offsets = torch.arange(N, dtype=torch.long, device=row.device)
        node_offsets = node_offsets * (K * dil)
        picks = (picks + node_offsets.view(-1, 1)).view(-1)
        row, col = row[picks], col[picks]

    # Relative position for every selected edge.
    pos = pos[col] - pos[row]

    x_star = self.mlp1(pos.view(N * K, D))
    if x is not None:
        if x.dim() == 1:
            x = x.unsqueeze(-1)
        neighbor_x = x[col].view(N, K, self.in_channels)
        x_star = torch.cat([x_star, neighbor_x], dim=-1)
    x_star = x_star.transpose(1, 2).contiguous()
    x_star = x_star.view(N, self.in_channels + self.hidden_channels, K, 1)

    transform_matrix = self.mlp2(pos.view(N, K * D))
    transform_matrix = transform_matrix.view(N, 1, K, K)

    x_transformed = torch.matmul(transform_matrix, x_star)
    x_transformed = x_transformed.view(N, -1, K)

    return self.conv(x_transformed)
def test_cluster(self):
    """Check that ``torch_cluster.knn_graph`` links each square corner to its two nearest corners.

    The edges are compared as an unordered collection of ``(source, target)``
    pairs so the test does not depend on the order in which ``knn_graph``
    emits the neighbours of a node.
    """
    x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
    batch = torch.tensor([0, 0, 0, 0])
    edge_index = torch_cluster.knn_graph(x, k=2, batch=batch, loop=False)

    expected_edges = [
        (2, 0), (1, 0),
        (3, 1), (0, 1),
        (3, 2), (0, 2),
        (1, 3), (2, 3),
    ]
    # Transpose the 2 x E index tensor into a list of (ei, ej) pairs.
    found_edges = list(zip(*edge_index.tolist()))

    self.assertCountEqual(found_edges, expected_edges)
def forward(self, x: Tensor, pos: Tensor, batch: Optional[Tensor] = None):
    """Apply the X-transformation convolution over kNN neighbourhoods.

    Args:
        x: Node features, or ``None`` to skip the feature concatenation; a
            1-D tensor is treated as one feature per node.
        pos: Node positions; a 1-D tensor is treated as one coordinate per
            node.
        batch: Optional assignment vector forwarded to ``knn_graph``.

    Returns:
        Output of ``self.conv`` on the transformed neighbourhood features.
    """
    if pos.dim() == 1:
        pos = pos.unsqueeze(-1)
    N, D = pos.size()
    K = self.kernel_size

    edge_index = knn_graph(
        pos, K * self.dilation, batch,
        loop=True, flow='target_to_source', num_workers=self.num_workers)

    if self.dilation > 1:
        # Keep every ``dilation``-th edge column.
        edge_index = edge_index[:, ::self.dilation]

    row = edge_index[0]
    col = edge_index[1]

    # Relative position for every kept edge.
    pos = pos[col] - pos[row]

    x_star = self.mlp1(pos)
    if x is not None:
        if x.dim() == 1:
            x = x.unsqueeze(-1)
        neighbor_x = x[col].view(N, K, self.in_channels)
        x_star = torch.cat([x_star, neighbor_x], dim=-1)
    x_star = x_star.transpose(1, 2).contiguous()

    transform_matrix = self.mlp2(pos.view(N, K * D))

    x_transformed = torch.matmul(x_star, transform_matrix)

    return self.conv(x_transformed)
def get_graph_feature(x, k, batch=None):
    """Build per-edge features from a kNN graph and group them by sample.

    For every edge, the difference between the features at ``edges[1]`` and
    ``edges[0]`` is concatenated with the features at ``edges[0]``.

    Args:
        x: Node features, indexed along dimension 0 by the edge endpoints.
        k: Number of neighbours passed to ``knn_graph``.
        batch: Optional assignment vector; when given, the number of samples
            is ``batch.max() + 1``, otherwise a single sample is assumed.

    Returns:
        Contiguous tensor obtained by viewing the edge features as
        ``(batch_size, -1, channels)`` and swapping the last two dimensions.
    """
    if batch is not None:
        batch_size = batch.max() + 1
    else:
        batch_size = 1

    # kNN edges.
    edges = knn_graph(x, k, batch=batch)
    first, second = edges[0], edges[1]

    first_feat = x[first]
    edge_feat = torch.cat([x[second] - first_feat, first_feat], dim=1)

    grouped = edge_feat.view(batch_size, -1, edge_feat.size(1))
    return grouped.permute(0, 2, 1).contiguous()
def knn(x, k=2, loop=False, dtype=None, device=None):
    """Return the kNN graph of ``x`` as an ``N x N`` ``SparseTensor`` of ones.

    Args:
        x: 2-D tensor with one row per node.
        k: Number of neighbours per node.
        loop: Whether ``knn_graph`` should include self-loops.
        dtype: Dtype of the stored edge values (framework default if ``None``).
        device: Device the edge index is moved to; ``None`` leaves it where
            ``knn_graph`` produced it.

    Returns:
        ``SparseTensor`` with ``sparse_sizes=(N, N)`` whose stored values are
        all ones.

    Raises:
        ValueError: If ``x`` is not 2-D (from unpacking ``x.shape``).
    """
    N, _ = x.shape

    # All nodes belong to one graph. The vector is allocated on x's device so
    # that knn_graph does not receive inputs on different devices.
    batch = torch.zeros(N, dtype=torch.long, device=x.device)
    edge_index = knn_graph(x, k, batch=batch, loop=loop).to(device)

    # Follow the edge index's device so row/col/value always agree, including
    # when ``device`` is None and ``x`` does not live on the default device.
    edge_val = torch.ones(
        edge_index.shape[-1], dtype=dtype, device=edge_index.device)

    return SparseTensor(
        row=edge_index[0],
        col=edge_index[1],
        value=edge_val,
        sparse_sizes=(N, N),
    )
def test_knn_graph(dtype, device):
    """Check 2-NN edges of a four-point square for both flow settings.

    The neighbour side of the edge index is sorted within each pair before
    comparison, so the order of a node's two neighbours does not matter.
    """
    corners = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)

    row, col = knn_graph(corners, k=2, flow='target_to_source')
    col, _ = col.view(-1, 2).sort(dim=-1)
    col = col.view(-1)
    assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
    assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]

    row, col = knn_graph(corners, k=2, flow='source_to_target')
    row, _ = row.view(-1, 2).sort(dim=-1)
    row = row.view(-1)
    assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
    assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
def forward(self, batch):
    """Attach a kNN edge index (without self-loops) to ``batch`` and return it.

    Args:
        batch: Object exposing ``pos`` and ``batch``; its ``edge_index``
            attribute is overwritten in place.

    Returns:
        The same ``batch`` object.
    """
    batch.edge_index = knn_graph(
        batch.pos,
        k=self.k,
        batch=batch.batch,
        loop=False,
    )
    return batch
def forward(self, x, batch=None):
    """Build a cosine kNN graph on ``x`` and run the parent convolution on it.

    Args:
        x: Node features, used both for the neighbour search and as the
            convolution input.
        batch: Optional assignment vector forwarded to ``knn_graph``.

    Returns:
        Result of the parent class's ``forward(x, edge_index)``.
    """
    neighbours = knn_graph(
        x,
        self.k,
        batch,
        loop=False,
        flow=self.flow,
        cosine=True,
    )
    parent = super(DynamicEdgeConv, self)
    return parent.forward(x, neighbours)
def test_knn_graph_large(dtype, device):
    """Compare ``knn_graph`` on 1000 random 3-D points against a SciPy k-d tree."""
    x = torch.randn(1000, 3, dtype=dtype, device=device)
    edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True)

    # Reference neighbours: the query points are the indexed points themselves.
    points = x.cpu().numpy()
    tree = scipy.spatial.cKDTree(points)
    _, neighbours = tree.query(points, k=5)
    truth = {(i, j) for i, ns in enumerate(neighbours) for j in ns}

    assert to_set(edge_index.cpu()) == truth
def forward(self, pos, batch):
    """Score each graph from two point convolutions on a 16-NN graph.

    Args:
        pos: Point positions, used both as initial features and as positions.
        batch: Assignment vector forwarded to ``knn_graph`` and to the global
            max pooling.

    Returns:
        Sigmoid of ``self.classifier`` applied to the max-pooled features.
    """
    # Self-loops are included in the neighbour graph.
    edge_index = knn_graph(pos, k=16, batch=batch, loop=True)

    h = self.conv1(h=pos, pos=pos, edge_index=edge_index).relu()
    h = self.conv2(h=h, pos=pos, edge_index=edge_index).relu()

    pooled = global_max_pool(h, batch)
    logits = self.classifier(pooled)
    return torch.sigmoid(logits)
def forward(self, data):
    """Apply two graph convolutions, pool their concatenation and run the MLP.

    Args:
        data: Object exposing ``x``, ``edge_index`` and ``batch``. When
            ``self.k_graph`` is truthy, ``data.edge_index`` is overwritten in
            place with a kNN graph built on ``data.x``.

    Returns:
        Output of ``self.mlp`` on the pooled features.
    """
    x = data.x

    if self.k_graph:
        # Use the configured neighbour count; the bare name ``k_graph`` is
        # not defined in this scope.
        # NOTE(review): no batch vector is passed, so neighbours are searched
        # across all nodes in ``data`` -- confirm this is intended.
        data.edge_index = knn_graph(data.x, self.k_graph)

    x1 = F.elu(self.conv1(x, data.edge_index))
    x2 = self.conv2(x1, data.edge_index)

    x = self.pool(torch.cat([x1, x2], dim=1), data.batch)
    x = self.mlp(x)
    return x
def forward(self, data):
    """Run two graph convolutions with residual blocks and pool per graph.

    A single kNN graph, built from the ``pos_idx`` columns of the input
    features, is shared by both graph convolutions.

    Args:
        data: Object exposing ``x``, ``edge_index`` and ``batch``; the stored
            ``edge_index`` is read but not used.

    Returns:
        ReLU of ``self.nncat`` applied to the concatenated max / min / sum /
        mean aggregations over ``batch``.
    """
    x, _, batch = data.x, data.edge_index, data.batch
    edge_index = knn_graph(
        x=x[:, self.pos_idx], k=self.k, batch=batch).to(self.device)

    # First graph convolution followed by two residual blocks.
    x = self.relu(self.GGconv1(x, edge_index))
    x = self.relu(self.nn1(x))
    x = x + self.resblock1(x)
    x = x + self.resblock2(x)

    # Second graph convolution (same graph) followed by two residual blocks.
    x = self.relu(self.nn2(x))
    x = self.relu(self.GGconv2(x, edge_index))
    x = x + self.resblock3(x)
    x = x + self.resblock4(x)

    x = self.relu(self.nn3(x))

    # Four per-graph aggregations, concatenated feature-wise.
    x = torch.cat(
        (
            scatter_max(x, batch, dim=0)[0],
            scatter_min(x, batch, dim=0)[0],
            scatter_sum(x, batch, dim=0),
            scatter_mean(x, batch, dim=0),
        ),
        dim=1,
    )

    return self.relu(self.nncat(x))
def test_knn_graph(dtype, device):
    """Check 2-NN edges of a four-point square with the default flow.

    ``col`` is sorted within each pair before comparison, so the order of a
    node's two neighbours does not matter.
    """
    corners = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)

    row, col = knn_graph(corners, k=2)
    col, _ = col.view(-1, 2).sort(dim=-1)
    col = col.view(-1)

    assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
    assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
def forward(self, data):
    """Embed each graph, then average every embedding with its nearest ones.

    Args:
        data: Object exposing ``pos``, ``batch`` and ``edge_index``.

    Returns:
        ``self.lin2`` applied to the neighbourhood-averaged embeddings.
    """
    pos, batch, eidx = data.pos, data.batch, data.edge_index

    x1 = self.conv1(pos, eidx)
    x2 = self.conv2(x1, batch)
    x = self.lin1(torch.cat([x1, x2], dim=1))

    x = global_max_pool(x, batch)
    x = self.mlp(x)

    # kNN over the pooled embeddings with self-loops, hence k_global + 1
    # entries per embedding (assuming k_global < min streamline length).
    group = self.k_global + 1
    neighbour_ids = knn_graph(x, group, batch=None, loop=True)[0]
    averaged = x[neighbour_ids.view(-1, group)].mean(1)

    return self.lin2(averaged)
def forward(self, data):
    """Chain three edge convolutions on rebuilt kNN graphs and apply the output layer.

    The leftover debug ``print`` calls and the unconditional
    ``raise Exception('stop')`` were removed; they made everything after the
    first convolution unreachable.

    Args:
        data: Object exposing ``x`` (used for the first neighbour search),
            ``features`` (input to the first convolution) and ``batch``.

    Returns:
        ``self.output`` applied to the concatenated features, with the last
        dimension squeezed.
    """
    # Use the coords for the first knn step.
    clustering1 = to_undirected(
        knn_graph(data.x, self.k, data.batch, loop=False,
                  flow=self.edgeconv1.flow))
    out1 = self.edgeconv1(data.features, clustering1)

    # Now use the outputted features of the previous layer for the knn.
    clustering2 = to_undirected(
        knn_graph(out1, self.k, data.batch, loop=False,
                  flow=self.edgeconv2.flow))
    out2 = self.edgeconv2(out1, clustering2)

    clustering3 = to_undirected(
        knn_graph(out2, self.k, data.batch, loop=False,
                  flow=self.edgeconv3.flow))
    out3 = self.edgeconv3(out2, clustering3)

    # Cat all outputs together.
    # NOTE(review): torch.cat defaults to dim=0, which stacks the tensors
    # along the node axis; confirm whether a feature-wise concatenation
    # (dim=-1) was intended.
    edgeconv_out = torch.cat([data.features, out1, out2, out3])

    # Run the output layer.
    return self.output(edgeconv_out).squeeze(-1)
def load_edges(self, idx, nhits):
    """Populate ``self.edge_index`` (and possibly ``self.edge_attr``) for one event.

    Exactly one strategy is used, in this priority order: distance-weighted
    edges from ``self.dist_pos_matrix``, a fully connected graph, or a kNN
    graph on columns 2:5 of the event's first ``nhits`` rows.

    Args:
        idx: Index of the event in ``self.event_data``.
        nhits: Number of hits (nodes) to use for the event.
    """
    if self.distance_weighted:
        # The helper's result holds the edge index first, attributes second.
        weighted = self.dist_pos_matrix(idx, nhits)
        self.edge_index = weighted[0]
        self.edge_attr = weighted[1]
        return

    if self.fully_connected:
        # Indices of an all-ones nhits x nhits matrix: every ordered pair of
        # hits, including each hit paired with itself.
        dense = torch.ones([nhits, nhits], dtype=torch.int64)
        self.edge_index = dense.to_sparse()._indices()
        return

    pos = torch.as_tensor(
        self.event_data[idx, :nhits, 2:5], dtype=torch.float)
    self.edge_index = knn_graph(pos, k=self.k_neighbours)
def forward(self, x, pos, batch=None, edge_index=None):
    """Apply ``self.layers``, optionally after pooling the graph with ``mgpool``.

    Args:
        x: Node features.
        pos: Node positions.
        batch: Optional assignment vector.
        edge_index: Optional precomputed edges; when ``None`` a kNN graph
            without self-loops is built on ``x``.

    Returns:
        Without pooling: ``(features, pos, batch)``. With pooling:
        ``(features, new_pos, new_batch, (index, values, origsize, newsize))``
        where the trailing tuple is the pooling bookkeeping from ``mgpool``.
    """
    if edge_index is None:
        edge_index = knn_graph(x, self.k, batch, loop=False)
    # ``device`` is resolved from the enclosing module scope.
    edge_index = edge_index.to(device)

    if not self.pool:
        out = self.layers(x, pos, edge_index, batch, self.k)
        return out, pos, batch

    (new_adj, new_feat, new_pos, new_batch,
     index, values, origsize, newsize) = mgpool(x, pos, edge_index, batch)
    out = self.layers(new_feat, new_pos, new_adj, new_batch, self.k)
    return out, new_pos, new_batch, (index, values, origsize, newsize)
def test_cluster(self):
    """Check ``torch_cluster.knn_graph`` on four square corners, ignoring edge order."""
    x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]).to(device)
    batch = torch.tensor([0, 0, 0, 0]).to(device)
    edge_index = torch_cluster.knn_graph(
        x, k=2, batch=batch, loop=False).to(device)

    test_edge_index = torch.LongTensor(
        [[2, 1, 3, 0, 3, 0, 1, 2],
         [0, 0, 1, 1, 2, 2, 3, 3]]).to(device)

    found_rows = edge_index.tolist()
    expected_rows = test_edge_index.tolist()
    del edge_index, test_edge_index

    # Need to transpose the edges to (ei, ej) format.
    found_pairs = list(zip(found_rows[0], found_rows[1]))
    expected_pairs = list(zip(expected_rows[0], expected_rows[1]))

    self.assertCountEqual(found_pairs, expected_pairs)
def _featurize_as_graph(self, protein):
    """Convert one protein record into a ``torch_geometric`` ``Data`` graph.

    Args:
        protein: Mapping with keys ``'name'``, ``'coords'`` and ``'seq'``.

    Returns:
        ``Data`` object carrying the positions taken from index 1 of the
        second coordinate axis as ``x``, the encoded sequence, scalar and
        vector node / edge features, the kNN ``edge_index`` and the
        finiteness ``mask``.
    """
    name = protein['name']
    with torch.no_grad():
        coords = torch.as_tensor(
            protein['coords'], device=self.device, dtype=torch.float32)
        encoded = [self.letter_to_num[a] for a in protein['seq']]
        seq = torch.as_tensor(encoded, device=self.device, dtype=torch.long)

        # Entries whose coordinates are not all finite are set to infinity.
        mask = torch.isfinite(coords.sum(dim=(1, 2)))
        coords[~mask] = np.inf

        X_ca = coords[:, 1]
        edge_index = torch_cluster.knn_graph(X_ca, k=self.top_k)

        # Edge features: positional embedding plus RBF of the edge length.
        pos_embeddings = self._positional_embeddings(edge_index)
        E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]
        rbf = _rbf(
            E_vectors.norm(dim=-1),
            D_count=self.num_rbf,
            device=self.device,
        )

        # Node features.
        dihedrals = self._dihedrals(coords)
        orientations = self._orientations(X_ca)
        sidechains = self._sidechains(coords)

        node_s = dihedrals
        node_v = torch.cat(
            [orientations, sidechains.unsqueeze(-2)], dim=-2)
        edge_s = torch.cat([rbf, pos_embeddings], dim=-1)
        edge_v = _normalize(E_vectors).unsqueeze(-2)

        # Replace NaN / infinite entries in every feature tensor.
        node_s, node_v, edge_s, edge_v = (
            torch.nan_to_num(t) for t in (node_s, node_v, edge_s, edge_v)
        )

        data = torch_geometric.data.Data(
            x=X_ca, seq=seq, name=name,
            node_s=node_s, node_v=node_v,
            edge_s=edge_s, edge_v=edge_v,
            edge_index=edge_index, mask=mask)
    return data
def forward(self, x, labels=None, epoch=None):
    """Run the dense layer and the stack of graph layers on rebuilt kNN graphs.

    Args:
        x: Input whose first dimension is the batch size.
        labels: Not used by this method.
        epoch: Not used by this method.

    Returns:
        Tensor reshaped to
        ``(batch_size, self.args.num_hits, self.args.node_feat_size)``.
    """
    args = self.args
    x = F.leaky_relu(self.dense(x), negative_slope=args.leaky_relu_alpha)

    batch_size = x.size(0)
    x = x.reshape(batch_size * args.num_hits, args.graphcnng_layers[0])

    # Mark the first hit of every sample with 1; the running sum minus one
    # then yields the sample index of each hit.
    sample_starts = torch.zeros(
        batch_size * args.num_hits, dtype=int).to(args.device)
    sample_starts[torch.arange(batch_size) * args.num_hits] = 1
    batch = torch.cumsum(sample_starts, 0) - 1

    last = len(self.layers) - 1
    for i, layer in enumerate(self.layers):
        edge_index = knn_graph(x, args.num_knn, batch)
        # Edge attribute: feature difference between the two endpoints.
        edge_attr = x[edge_index[0]] - x[edge_index[1]]
        x = self.bn_layers[i](layer(x, edge_index, edge_attr))
        # No activation after the final layer.
        if i < last:
            x = F.leaky_relu(x, negative_slope=args.leaky_relu_alpha)

    if args.graphcnng_tanh:
        x = F.tanh(x)

    return x.reshape(batch_size, args.num_hits, args.node_feat_size)
def forward(self, pos, batch):
    """Compute per-point outputs from three point convolutions on a 16-NN graph.

    Args:
        pos: Point positions, used both as initial features and as positions.
        batch: Assignment vector forwarded to ``knn_graph``.

    Returns:
        Output of ``self.classifier`` on the features after the third
        convolution; no pooling is applied.
    """
    # Compute the kNN graph. The batch vector is passed so that no edges are
    # created between points of different examples, and ``loop=True`` adds
    # self-loops in order to preserve central point information.
    edge_index = knn_graph(pos, k=16, batch=batch, loop=True)

    # Bipartite message passing: three convolutions, each followed by ReLU.
    h = self.conv1(h=pos, pos=pos, edge_index=edge_index).relu()
    h = self.conv2(h=h, pos=pos, edge_index=edge_index).relu()
    h = self.conv3(h=h, pos=pos, edge_index=edge_index).relu()

    return self.classifier(h)
def forward(self, x, batch=None):
    """Apply ``self.nn`` to kNN edge features and aggregate them per node.

    Args:
        x: Node features; a 1-D tensor is treated as one feature per node.
        batch: Optional assignment vector forwarded to ``knn_graph``.

    Returns:
        Edge outputs grouped as ``(-1, self.k, channels)`` and reduced over
        the neighbour dimension with sum (``'add'``), mean (``'mean'``) or,
        for any other ``self.aggr``, max.
    """
    if x.dim() == 1:
        x = x.unsqueeze(-1)

    row, col = knn_graph(x, self.k, batch, loop=False)
    x_row = x.index_select(0, row)
    x_col = x.index_select(0, col)

    # Edge feature: the ``row`` endpoint's features and the offset to ``col``.
    edge_feat = torch.cat([x_row, x_col - x_row], dim=1)
    out = self.nn(edge_feat)
    out = out.view(-1, self.k, out.size(-1))

    if self.aggr == 'add':
        return out.sum(dim=1)
    if self.aggr == 'mean':
        return out.mean(dim=1)
    return out.max(dim=1)[0]