def _rwr_trace_to_cogdl_graph(
    g: Graph,
    seed: int,
    trace: torch.Tensor,
    positional_embedding_size: int,
    entire_graph: bool = False,
):
    """Build a seed-marked graph from the nodes found in ``trace``.

    The distinct values in ``trace`` are collected and ``seed`` is moved to
    the front of that list. The result is either a deep copy of ``g``
    (``entire_graph=True``) or ``g.subgraph`` over that node list. It is then
    passed through ``_add_undirected_graph_positional_embedding`` and given a
    ``seed`` attribute: a ``torch.long`` vector of zeros with a single 1.

    Args:
        g: Source graph.
        seed: Id of the seed node (presumably the start node of the walk
            that produced ``trace`` -- confirm against the caller).
        trace: Tensor of node ids; only its unique values are used.
        positional_embedding_size: Forwarded unchanged to
            ``_add_undirected_graph_positional_embedding``.
        entire_graph: If true, keep every node of ``g`` instead of
            restricting to the traced nodes.

    Returns:
        The graph returned by the positional-embedding helper, with its
        ``seed`` attribute set.
    """
    # torch.unique yields each id at most once, so filtering out ``seed``
    # is equivalent to removing its single occurrence if present.
    others = [node for node in torch.unique(trace).tolist() if node != seed]
    ordered_nodes = [seed, *others]

    result = copy.deepcopy(g) if entire_graph else g.subgraph(ordered_nodes)
    result = _add_undirected_graph_positional_embedding(
        result, positional_embedding_size
    )

    # In the full graph the seed keeps its original id; in the subgraph it
    # was placed first in the node list, so position 0 is marked.
    seed_position = seed if entire_graph else 0
    result.seed = torch.zeros(result.num_nodes, dtype=torch.long)
    result.seed[seed_position] = 1
    return result
def top_k(self, graph, x: torch.Tensor, scores: torch.Tensor) -> Tuple[Graph, torch.Tensor]:
    """Keep the highest-scoring nodes and return the pooled graph.

    The number of nodes kept is ``int(self.pooling_rate * x.shape[0])``,
    raised to at least 2 and then capped at the number of available scores.
    The cap prevents ``torch.topk`` from failing with "selected index k out
    of range" on inputs with fewer than two nodes or with a pooling rate
    above 1; all other inputs behave exactly as before.

    Args:
        graph: Input graph. ``graph.edge_index`` is read when
            ``self.aug_adj`` is set; otherwise ``graph.subgraph`` is called
            on it directly.
        x: Node features; ``x.shape[0]`` is used as the node count.
        scores: Scores ranked with ``torch.topk`` (presumably one score per
            row of ``x`` -- confirm against the caller).

    Returns:
        A tuple ``(new_batch, indices)``: the subgraph induced by the
        selected nodes (after ``row_norm()`` has been called on it) and the
        indices returned by ``torch.topk``.
    """
    org_n_nodes = x.shape[0]
    num = int(self.pooling_rate * org_n_nodes)

    # torch.topk requires k <= size of the ranked dimension (1 for a 0-dim
    # tensor), so never ask for more entries than ``scores`` holds.
    available = scores.shape[-1] if scores.dim() > 0 else 1
    k = min(max(2, num), available)
    _, indices = torch.topk(scores, k)

    if self.aug_adj:
        # Multiply the unit-weight edge list with itself via spspmm on CPU
        # (presumably to add two-hop connections -- confirm), then move the
        # resulting edge index back to the feature device.
        edge_index = graph.edge_index.cpu()
        edge_attr = torch.ones(edge_index.shape[1])
        edge_index, _ = spspmm(
            edge_index,
            edge_attr,
            edge_index,
            edge_attr,
            org_n_nodes,
            org_n_nodes,
            org_n_nodes,
        )
        edge_index = edge_index.to(x.device)
        batch = Graph(x=x, edge_index=edge_index)
    else:
        batch = graph

    new_batch = batch.subgraph(indices)
    new_batch.row_norm()
    return new_batch, indices