Example #1
0
def _rwr_trace_to_cogdl_graph(g: Graph,
                              seed: int,
                              trace: torch.Tensor,
                              positional_embedding_size: int,
                              entire_graph: bool = False):
    """Turn a walk trace around ``seed`` into a graph sample with a seed marker.

    The name suggests the trace comes from a random walk with restart; this
    function only relies on it being a tensor of node ids.

    Args:
        g: source graph.
        seed: id of the seed node in ``g``.
        trace: tensor of visited node ids; duplicates are ignored.
        positional_embedding_size: forwarded to
            ``_add_undirected_graph_positional_embedding``.
        entire_graph: if True, use a deep copy of the whole ``g`` instead of
            the subgraph induced on the visited nodes.

    Returns:
        The graph produced by ``_add_undirected_graph_positional_embedding``,
        carrying a ``seed`` attribute: a long tensor of length ``num_nodes``
        that is 1 at the seed's position and 0 elsewhere.
    """
    visited = torch.unique(trace).tolist()
    # Seed goes first, followed by every other visited node in the order
    # torch.unique returns them. torch.unique yields each id once, so the
    # filter drops at most one entry (and none if the seed was not visited).
    node_ids = [seed]
    node_ids.extend(v for v in visited if v != seed)

    if not entire_graph:
        subg = g.subgraph(node_ids)
    else:
        subg = copy.deepcopy(g)

    subg = _add_undirected_graph_positional_embedding(
        subg, positional_embedding_size)

    subg.seed = torch.zeros(subg.num_nodes, dtype=torch.long)
    # In the full copy the seed keeps its original id. In the induced subgraph
    # it is expected at index 0 because it was listed first (assumes
    # g.subgraph preserves the given node order — confirm).
    seed_pos = seed if entire_graph else 0
    subg.seed[seed_pos] = 1
    return subg
Example #2
0
    def top_k(self, graph, x: torch.Tensor, scores: torch.Tensor) -> Tuple[Graph, torch.Tensor]:
        """Keep the highest-scoring nodes and return the subgraph induced on them.

        The number of nodes kept is ``int(self.pooling_rate * num_nodes)``,
        raised to at least 2 and capped at the number of nodes in ``x``.

        Args:
            graph: input graph. When ``self.aug_adj`` is set, only its
                ``edge_index`` is used to build an augmented graph; otherwise
                ``graph`` itself is subgraphed.
            x: node feature matrix; ``x.shape[0]`` is taken as the node count.
            scores: per-node scores passed to ``torch.topk`` (assumed 1-D with
                one entry per node — TODO confirm).

        Returns:
            ``(new_batch, indices)``: the subgraph induced on the selected nodes
            (``row_norm()`` has been called on it), and the selected node
            indices as returned by ``torch.topk`` (descending score).
        """
        org_n_nodes = x.shape[0]
        # torch.topk raises when k exceeds the size of the scored dimension, so
        # clamp k to the node count. Otherwise, graphs with fewer than 2 nodes
        # (or pooling_rate > 1) would crash here.
        k = min(max(2, int(self.pooling_rate * org_n_nodes)), org_n_nodes)
        _, indices = torch.topk(scores, k)

        if self.aug_adj:
            # Multiply the unit-weighted adjacency by itself with `spspmm`
            # (presumably torch_sparse's sparse-sparse matmul, i.e. A @ A —
            # confirm). It is computed on CPU; the result moves back to x.device.
            edge_index = graph.edge_index.cpu()
            edge_attr = torch.ones(edge_index.shape[1])
            edge_index, _ = spspmm(edge_index, edge_attr, edge_index, edge_attr, org_n_nodes, org_n_nodes, org_n_nodes)
            edge_index = edge_index.to(x.device)
            batch = Graph(x=x, edge_index=edge_index)
        else:
            batch = graph
        new_batch = batch.subgraph(indices)

        new_batch.row_norm()
        return new_batch, indices