def gcn_no_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
                add_self_loops=True, dtype=None):
    """Attach edge weights and optional self-loops WITHOUT degree normalization.

    Args:
        edge_index: Either a ``SparseTensor`` adjacency or an edge-index
            tensor whose second dimension enumerates the edges.
        edge_weight: Optional per-edge weights for the edge-index path. When
            ``None``, a vector of ones (one entry per edge) is created. Not
            consulted on the ``SparseTensor`` path.
        num_nodes: Optional node count, resolved through ``maybe_num_nodes``
            on the edge-index path.
        improved: When true, self-loops are filled with ``2.0`` instead of
            ``1.0``.
        add_self_loops: Whether self-loops are added at all.
        dtype: dtype used for the generated unit weights / unit values.

    Returns:
        The (possibly modified) ``SparseTensor`` when one was passed in,
        otherwise the tuple ``(edge_index, edge_weight)``.
    """
    fill_value = 2. if improved else 1.

    if isinstance(edge_index, SparseTensor):
        adj_t = edge_index
        # A value-less adjacency gets unit values so the diagonal fill below
        # has something to be consistent with.
        if not adj_t.has_value():
            adj_t = adj_t.fill_value(1., dtype=dtype)
        if add_self_loops:
            adj_t = fill_diag(adj_t, fill_value)
        return adj_t

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    if add_self_loops:
        edge_index, tmp_edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)
        assert tmp_edge_weight is not None
        edge_weight = tmp_edge_weight

    # No degree computation happens here, so the row/col split that the
    # normalizing variants need is intentionally absent.
    return edge_index, edge_weight
def gcn_norm(adj_t):
    """Symmetrically normalize a sparse adjacency after adding unit self-loops.

    The diagonal is filled with 1, the per-row sums are raised to the power
    -1/2, and the adjacency is scaled by that vector once along rows and once
    along columns.

    NOTE(review): another ``gcn_norm`` with a different signature is defined
    further down in this file; if both live in the same module the later one
    rebinds the name -- confirm which one callers expect.

    Args:
        adj_t: Sparse adjacency accepted by the ``torch_sparse`` helpers.

    Returns:
        The normalized sparse adjacency.
    """
    looped = torch_sparse.fill_diag(adj_t, 1)  # add self-loops

    # Row sums -> d^{-1/2}, computed in place on the freshly reduced tensor.
    inv_sqrt_deg = torch_sparse.sum(looped, dim=1).pow_(-0.5)
    # Rows that sum to zero produce inf; zero them for numerical stability.
    is_inf = inv_sqrt_deg == float('inf')
    inv_sqrt_deg.masked_fill_(is_inf, 0.)

    row_scaled = torch_sparse.mul(looped, inv_sqrt_deg.view(-1, 1))  # row-wise
    return torch_sparse.mul(row_scaled, inv_sqrt_deg.view(1, -1))  # col-wise
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, flow="source_to_target", dtype=None):
    """Apply symmetric degree normalization to a graph.

    Handles two input representations:

    * ``SparseTensor``: missing values are set to 1, self-loops are optionally
      filled on the diagonal, and the adjacency is scaled by ``deg^{-1/2}``
      along rows and then along columns. Only ``flow="source_to_target"`` is
      accepted.
    * edge-index tensor: unit weights are created when ``edge_weight`` is
      ``None``, remaining self-loops are optionally added, degrees are
      accumulated with ``scatter_add`` over the second row of ``edge_index``
      (``"source_to_target"``) or the first row (``"target_to_source"``), and
      each edge weight is multiplied by ``deg^{-1/2}`` of both endpoints.

    Args:
        edge_index: ``SparseTensor`` adjacency or edge-index tensor.
        edge_weight: Optional per-edge weights (edge-index path only).
        num_nodes: Optional node count, resolved via ``maybe_num_nodes``.
        improved: Use ``2.0`` instead of ``1.0`` as the self-loop weight.
        add_self_loops: Whether to add self-loops before normalizing.
        flow: Message direction; validated with an ``assert``.
        dtype: dtype for generated unit weights / values.

    Returns:
        The normalized ``SparseTensor``, or ``(edge_index, edge_weight)`` for
        the edge-index path.
    """
    if improved:
        loop_weight = 2.
    else:
        loop_weight = 1.

    if isinstance(edge_index, SparseTensor):
        assert flow in ["source_to_target"]

        adj = edge_index
        if not adj.has_value():
            adj = adj.fill_value(1., dtype=dtype)
        if add_self_loops:
            adj = fill_diag(adj, loop_weight)

        row_sums = sparsesum(adj, dim=1)
        inv_sqrt = row_sums.pow_(-0.5)  # in place: aliases ``row_sums``
        # Zero-degree rows yield inf after the power; neutralize them.
        inv_sqrt.masked_fill_(inv_sqrt == float('inf'), 0.)

        adj = mul(adj, inv_sqrt.view(-1, 1))
        adj = mul(adj, inv_sqrt.view(1, -1))
        return adj

    assert flow in ["source_to_target", "target_to_source"]
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if edge_weight is None:
        n_edges = edge_index.size(1)
        edge_weight = torch.ones((n_edges, ), dtype=dtype,
                                 device=edge_index.device)

    if add_self_loops:
        edge_index, looped_weight = add_remaining_self_loops(
            edge_index, edge_weight, loop_weight, num_nodes)
        assert looped_weight is not None
        edge_weight = looped_weight

    rows, cols = edge_index[0], edge_index[1]
    if flow == "source_to_target":
        reduce_idx = cols
    else:
        reduce_idx = rows

    degree = scatter_add(edge_weight, reduce_idx, dim=0, dim_size=num_nodes)
    inv_sqrt = degree.pow_(-0.5)  # in place: aliases ``degree``
    inv_sqrt.masked_fill_(inv_sqrt == float('inf'), 0)

    normalized_weight = inv_sqrt[rows] * edge_weight * inv_sqrt[cols]
    return edge_index, normalized_weight
def _sogcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
                add_self_loops=True, dtype=None):
    """Symmetrically normalize a graph given as ``SparseTensor`` or edge index.

    * ``SparseTensor`` input: missing values are set to 1, self-loops are
      optionally filled on the diagonal, and the adjacency is scaled by
      ``deg^{-1/2}`` along rows and then along columns.
    * edge-index input: unit weights are created when ``edge_weight`` is
      ``None``, remaining self-loops are optionally added, degrees are
      accumulated with ``scatter_add`` over ``edge_index[1]``, and every edge
      weight is multiplied by ``deg^{-1/2}`` of both of its endpoints.

    Args:
        edge_index: ``torch_sparse.SparseTensor`` adjacency or edge-index
            tensor.
        edge_weight: Optional per-edge weights (edge-index path only).
        num_nodes: Optional node count, resolved via ``maybe_num_nodes``.
        improved: Use ``2.0`` instead of ``1.0`` as the self-loop weight.
        add_self_loops: Whether to add self-loops before normalizing.
        dtype: dtype for generated unit weights (edge-index path) and for the
            unit values assigned to a value-less ``SparseTensor``.

    Returns:
        The normalized ``SparseTensor``, or ``(edge_index, edge_weight)`` for
        the edge-index path.
    """
    fill_value = 2. if improved else 1.

    if isinstance(edge_index, torch_sparse.SparseTensor):
        adj_t = edge_index
        if not adj_t.has_value():
            # Honor ``dtype`` here as well, matching the edge-index branch.
            adj_t = adj_t.fill_value(1., dtype=dtype)
        if add_self_loops:
            adj_t = torch_sparse.fill_diag(adj_t, fill_value)

        # Use the module-qualified reduction: a bare ``sum`` would resolve to
        # the Python builtin (which rejects ``dim=``) unless the module
        # happens to rebind that name.
        deg = torch_sparse.sum(adj_t, dim=1)
        deg_inv_sqrt = deg.pow_(-0.5)
        # Zero-degree rows yield inf after the power; neutralize them.
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
        adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(-1, 1))
        adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(1, -1))
        return adj_t

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if edge_weight is None:
        edge_weight = torch.ones((edge_index.shape[1], ), dtype=dtype,
                                 device=edge_index.device)

    if add_self_loops:
        edge_index, tmp_edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)
        assert tmp_edge_weight is not None
        edge_weight = tmp_edge_weight

    row, col = edge_index[0], edge_index[1]
    deg = torch_scatter.scatter_add(edge_weight, col, dim=0,
                                    dim_size=num_nodes)
    deg_inv_sqrt = deg.pow_(-0.5)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
    return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
def forward(self, x, edge_index, edge_attr=None, batch=None):
    """Score nodes, mix each node with its weighted neighborhood, keep top-k.

    Steps, in order:

    1. ``k = relu(self.lin_2(x))`` gives a per-node gate.
    2. The node score is ``(A + self-loops) @ k``; ``topk`` picks the kept
       node indices ``perm`` from it using ``self.ratio`` and ``batch``.
    3. Neighbor features are aggregated as ``(self.norm(A) * k.T) @ x`` and
       then divided by the row sums of the normalized adjacency.
    4. The result is blended with the gated self features ``x * k`` using
       ``self.args.combine_ratio``.
    5. Features, batch vector and edges are restricted to ``perm``.

    Args:
        x: Node feature matrix; ``x.size(0)`` is taken as the node count.
        edge_index: Edge-index tensor of the graph.
        edge_attr: Optional edge attributes, forwarded to the sparse
            adjacency and to ``filter_adj``.
        batch: Optional per-node graph assignment; defaults to all zeros
            (a single graph).

    Returns:
        Tuple ``(x, edge_index, edge_attr, batch, perm)`` after pooling.
    """
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))
    num_node = x.size(0)

    # presumably k has shape [num_nodes, 1] so that ``x * k`` and ``A * k.T``
    # broadcast as intended -- TODO confirm against ``self.lin_2``.
    k = F.relu(self.lin_2(x))

    A = SparseTensor.from_edge_index(edge_index=edge_index,
                                     edge_attr=edge_attr,
                                     sparse_sizes=(num_node, num_node))

    # Score every node from its own gate plus those of its neighbors.
    A_wave = fill_diag(A, 1)
    s = A_wave @ k
    score = s.squeeze()
    perm = topk(score, self.ratio, batch)

    A = self.norm(A)
    K_neighbor = A * k.T
    x_neighbor = K_neighbor @ x

    # ----modified
    # Rescale the aggregated neighbor features by the inverse row sum of the
    # normalized adjacency; rows that sum to zero are mapped to 0, not inf.
    # NOTE(review): ``sum`` is called with ``dim=``, so it must be a sparse
    # reduction bound at module level rather than the builtin -- confirm the
    # file's imports.
    deg = sum(A, dim=1)
    deg_inv = deg.pow_(-1)
    deg_inv.masked_fill_(deg_inv == float('inf'), 0.)
    x_neighbor = x_neighbor * deg_inv.view(-1, 1)
    # ----

    x_self = x * k
    x = (x_neighbor * (1 - self.args.combine_ratio)
         + x_self * self.args.combine_ratio)

    # Keep only the selected nodes and the edges between them.
    x = x[perm]
    batch = batch[perm]
    edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
                                       num_nodes=s.size(0))

    return x, edge_index, edge_attr, batch, perm
def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
    """Run the layer: aggregate basis features and recombine them per head.

    Graph preprocessing depends on the configured aggregators:

    * If ``"symnorm"`` is among ``self.aggregators``, ``gcn_norm`` is applied
      (self-loops controlled by ``self.add_self_loops``) and, for an
      edge-index input, the resulting edge weights are passed to
      ``propagate`` as ``symnorm_weight``.
    * Otherwise, if ``self.add_self_loops`` is set, only self-loops are
      added.

    In both cases the preprocessed graph is stored in
    ``self._cached_edge_index`` / ``self._cached_adj_t`` when ``self.cached``
    is true and reused on later calls.

    Args:
        x: Node features; ``x.size(self.node_dim)`` is used as node count.
        edge_index: Graph connectivity as a ``Tensor`` edge index or a
            ``SparseTensor``.

    Returns:
        Output node features viewed as ``[-1, self.out_channels]``, with
        ``self.bias`` added when it is not ``None``.
    """
    symnorm_weight: OptTensor = None
    if "symnorm" in self.aggregators:
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if cache is None:
                edge_index, symnorm_weight = gcn_norm(  # yapf: disable
                    edge_index, None, num_nodes=x.size(self.node_dim),
                    improved=False, add_self_loops=self.add_self_loops)
                if self.cached:
                    self._cached_edge_index = (edge_index, symnorm_weight)
            else:
                edge_index, symnorm_weight = cache

        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is None:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, None, num_nodes=x.size(self.node_dim),
                    improved=False, add_self_loops=self.add_self_loops)
                if self.cached:
                    self._cached_adj_t = edge_index
            else:
                edge_index = cache

    elif self.add_self_loops:
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if self.cached and cache is not None:
                edge_index = cache[0]
            else:
                # Pass the node count explicitly: inferring it from the
                # largest index in ``edge_index`` would skip self-loops for
                # isolated nodes with higher indices, unlike the symnorm
                # branch above.
                edge_index, _ = add_remaining_self_loops(
                    edge_index, num_nodes=x.size(self.node_dim))
                if self.cached:
                    self._cached_edge_index = (edge_index, None)

        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if self.cached and cache is not None:
                edge_index = cache
            else:
                edge_index = fill_diag(edge_index, 1.0)
                if self.cached:
                    self._cached_adj_t = edge_index

    # [num_nodes, (out_channels // num_heads) * num_bases]
    bases = self.bases_lin(x)
    # [num_nodes, num_heads * num_bases * num_aggrs]
    weightings = self.comb_lin(x)

    # [num_nodes, num_aggregators, (out_channels // num_heads) * num_bases]
    # propagate_type: (x: Tensor, symnorm_weight: OptTensor)
    aggregated = self.propagate(edge_index, x=bases,
                                symnorm_weight=symnorm_weight, size=None)

    weightings = weightings.view(-1, self.num_heads,
                                 self.num_bases * len(self.aggregators))
    aggregated = aggregated.view(
        -1,
        len(self.aggregators) * self.num_bases,
        self.out_channels // self.num_heads,
    )

    # [num_nodes, num_heads, out_channels // num_heads]
    out = torch.matmul(weightings, aggregated)
    out = out.view(-1, self.out_channels)

    if self.bias is not None:
        out += self.bias

    return out