def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            edge_attr: OptTensor = None, size: Size = None) -> Tensor:
    """"""
    if isinstance(x, Tensor):
        x: OptPairTensor = (x, x)

    # Node and edge feature dimensionalities need to match.
    if isinstance(edge_index, Tensor):
        if edge_attr is not None:
            assert x[0].size(-1) == edge_attr.size(-1)
    elif isinstance(edge_index, SparseTensor):
        edge_attr = edge_index.storage.value()
        if edge_attr is not None:
            assert x[0].size(-1) == edge_attr.size(-1)

    # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
    out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

    if self.msg_norm is not None:
        out = self.msg_norm(x[0], out)

    x_r = x[1]
    if x_r is not None:
        out += x_r

    return self.mlp(out)
def message(self, x_j: Tensor, edge_attr: OptTensor):
    assert edge_attr is not None

    EPS = 1e-15
    F, M = self.rel_in_channels, self.out_channels
    (E, D), K = edge_attr.size(), self.kernel_size

    if not self.separate_gaussians:
        gaussian = -0.5 * (edge_attr.view(E, 1, D) -
                           self.mu.view(1, K, D)).pow(2)
        gaussian = gaussian / (EPS + self.sigma.view(1, K, D).pow(2))
        gaussian = torch.exp(gaussian.sum(dim=-1))  # [E, K]

        return (x_j.view(E, K, M) * gaussian.view(E, K, 1)).sum(dim=-2)

    else:
        gaussian = -0.5 * (edge_attr.view(E, 1, 1, 1, D) -
                           self.mu.view(1, F, M, K, D)).pow(2)
        gaussian = gaussian / (EPS + self.sigma.view(1, F, M, K, D).pow(2))
        gaussian = torch.exp(gaussian.sum(dim=-1))  # [E, F, M, K]

        gaussian = gaussian * self.g.view(1, F, M, K)
        gaussian = gaussian.sum(dim=-1)  # [E, F, M]

        return (x_j.view(E, F, 1) * gaussian).sum(dim=-2)  # [E, M]
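# Hedged sketch (toy shapes, not from the original source): the shared-kernel
# branch above evaluates K unnormalized Gaussians over the D-dimensional edge
# pseudo-coordinates and uses them as per-edge mixture weights. A standalone
# reproduction of that weighting step:
import torch

E, D, K = 10, 2, 4                              # edges, pseudo-dim, kernels
edge_attr = torch.rand(E, D)                    # pseudo-coordinates per edge
mu, sigma = torch.rand(K, D), torch.rand(K, D)  # kernel means and scales

g = -0.5 * (edge_attr.view(E, 1, D) - mu.view(1, K, D)).pow(2)
g = g / (1e-15 + sigma.view(1, K, D).pow(2))
weights = torch.exp(g.sum(dim=-1))              # [E, K] weight per kernel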
def add_self_loops(
        edge_index: Tensor, edge_attr: OptTensor = None,
        fill_value: Union[float, Tensor, str] = None,
        num_nodes: Optional[int] = None) -> Tuple[Tensor, OptTensor]:
    r"""Adds a self-loop :math:`(i,i) \in \mathcal{E}` to every node
    :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
    In case the graph is weighted or has multi-dimensional edge features
    (:obj:`edge_attr != None`), edge features of self-loops will be added
    according to :obj:`fill_value`.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
            features. (default: :obj:`None`)
        fill_value (float or Tensor or str, optional): The way to generate
            edge features of self-loops (in case :obj:`edge_attr != None`).
            If given as :obj:`float` or :class:`torch.Tensor`, edge features
            of self-loops will be directly given by :obj:`fill_value`.
            If given as :obj:`str`, edge features of self-loops are computed
            by aggregating all features of edges that point to the specific
            node, according to a reduce operation. (:obj:`"add"`,
            :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`).
            (default: :obj:`1.`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    N = maybe_num_nodes(edge_index, num_nodes)

    loop_index = torch.arange(0, N, dtype=torch.long,
                              device=edge_index.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)

    if edge_attr is not None:
        if fill_value is None:
            loop_attr = edge_attr.new_full((N, ) + edge_attr.size()[1:], 1.)

        elif isinstance(fill_value, (int, float)):
            loop_attr = edge_attr.new_full((N, ) + edge_attr.size()[1:],
                                           fill_value)
        elif isinstance(fill_value, Tensor):
            loop_attr = fill_value.to(edge_attr.device, edge_attr.dtype)
            if edge_attr.dim() != loop_attr.dim():
                loop_attr = loop_attr.unsqueeze(0)
            sizes = [N] + [1] * (loop_attr.dim() - 1)
            loop_attr = loop_attr.repeat(*sizes)

        elif isinstance(fill_value, str):
            loop_attr = scatter(edge_attr, edge_index[1], dim=0, dim_size=N,
                                reduce=fill_value)
        else:
            raise AttributeError("No valid 'fill_value' provided")

        edge_attr = torch.cat([edge_attr, loop_attr], dim=0)

    edge_index = torch.cat([edge_index, loop_index], dim=1)
    return edge_index, edge_attr
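# Hedged usage sketch for `add_self_loops` (toy graph, not from the original
# source), showing how a string `fill_value` aggregates incoming edge weights
# to fill the self-loop weights:
import torch

edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]])
edge_weight = torch.tensor([0.5, 1.0, 2.0])
edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                         fill_value='mean', num_nodes=3)
# Self-loop weights become the mean of each node's incoming edge weights:
# node 0 -> 1.0, node 1 -> 0.5, node 2 -> 2.0.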
def __norm__(self, edge_index, num_nodes: Optional[int],
             edge_weight: OptTensor, normalization: Optional[str],
             lambda_max, dtype: Optional[int] = None,
             batch: OptTensor = None):
    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    edge_index, edge_weight = get_laplacian(edge_index, edge_weight,
                                            normalization, dtype, num_nodes)

    if batch is not None and lambda_max.numel() > 1:
        lambda_max = lambda_max[batch[edge_index[0]]]

    edge_weight = (2.0 * edge_weight) / lambda_max
    edge_weight.masked_fill_(edge_weight == float('inf'), 0)

    edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                             fill_value=-1.,
                                             num_nodes=num_nodes)
    assert edge_weight is not None
    return edge_index, edge_weight
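# Hedged sketch (toy graph, not from the original source): `__norm__` above
# implements the ChebNet rescaling L_hat = 2 * L / lambda_max - I. The
# division rescales the Laplacian's edge weights, and `add_self_loops` (as
# defined earlier in this file) with `fill_value=-1.` appends the `-I` term
# as negative self-loop weights:
import torch
from torch_geometric.utils import get_laplacian

edge_index = torch.tensor([[0, 1], [1, 0]])
edge_index, edge_weight = get_laplacian(edge_index, normalization='sym')
lambda_max = 2.0  # Spectral upper bound of the normalized Laplacian.
edge_weight = (2.0 * edge_weight) / lambda_max
edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                         fill_value=-1., num_nodes=2)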
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """"""
    if self.normalize and edge_weight is None:
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if cache is None:
                edge_index, edge_weight = gnn.conv.gcn_conv.gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, dtype=x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
            else:
                edge_index, edge_weight = cache[0], cache[1]

        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is None:
                edge_index = gnn.conv.gcn_conv.gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, dtype=x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index
            else:
                edge_index = cache

    # Make edge weights differentiable so that gradients with respect to
    # individual edges can be inspected (no-op if `edge_weight` is `None`,
    # e.g., in the `SparseTensor` case):
    if edge_weight is not None:
        edge_weight.requires_grad_(True)

    x = torch.matmul(x, self.weight)

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                         size=None)

    if self.bias is not None:
        out += self.bias

    # Record `edge_weight` for later inspection:
    self.edge_weight = edge_weight
    return out
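# Hedged sketch (toy graph, not from the original source): `gcn_norm` used
# above computes the symmetrically normalized adjacency
# D^{-1/2} (A + I) D^{-1/2}; a standalone call:
import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm

edge_index = torch.tensor([[0, 1], [1, 0]])
edge_index, edge_weight = gcn_norm(edge_index, None, num_nodes=2,
                                   add_self_loops=True)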
def message(self, x_i: Tensor, x_j: Tensor,
            edge_attr: OptTensor) -> Tensor:
    h: Tensor = x_i  # Dummy.
    if edge_attr is not None:
        edge_attr = self.edge_encoder(edge_attr)
        edge_attr = edge_attr.view(-1, 1, self.F_in)
        edge_attr = edge_attr.repeat(1, self.towers, 1)
        h = torch.cat([x_i, x_j, edge_attr], dim=-1)
    else:
        h = torch.cat([x_i, x_j], dim=-1)

    hs = [nn(h[:, i]) for i, nn in enumerate(self.pre_nns)]
    return torch.stack(hs, dim=1)
def aggregate(self, inputs: Tensor, index: Tensor,
              dim_size: Optional[int] = None,
              symnorm_weight: OptTensor = None) -> Tensor:
    outs = []
    for aggr in self.aggregators:
        if aggr == 'symnorm':
            assert symnorm_weight is not None
            out = scatter(inputs * symnorm_weight.view(-1, 1), index, 0,
                          None, dim_size, reduce='sum')
        elif aggr == 'var' or aggr == 'std':
            mean = scatter(inputs, index, 0, None, dim_size, reduce='mean')
            mean_squares = scatter(inputs * inputs, index, 0, None,
                                   dim_size, reduce='mean')
            out = mean_squares - mean * mean
            if aggr == 'std':
                out = torch.sqrt(out.relu_() + 1e-5)
        else:
            out = scatter(inputs, index, 0, None, dim_size, reduce=aggr)
        outs.append(out)

    return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]
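# Hedged sketch (toy data, not from the original source): the 'var'/'std'
# branches above rely on the identity Var[x] = E[x^2] - E[x]^2, computed with
# two scatter-mean calls over the target index:
import torch
from torch_scatter import scatter

inputs = torch.randn(6, 4)                 # 6 messages with 4 features each
index = torch.tensor([0, 0, 1, 1, 1, 2])   # target node of each message

mean = scatter(inputs, index, 0, None, 3, reduce='mean')
mean_squares = scatter(inputs * inputs, index, 0, None, 3, reduce='mean')
var = mean_squares - mean * mean           # per-node feature variance
std = torch.sqrt(var.relu_() + 1e-5)       # relu_ clamps FP-error negatives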
def forward(self, t, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            edge_attr: OptTensor = None, size: Size = None) -> Tensor:
    """"""
    if isinstance(x, Tensor):
        x: OptPairTensor = (x, x)

    # Node and edge feature dimensionalities need to match.
    if isinstance(edge_index, Tensor):
        assert edge_attr is not None
        assert x[0].size(-1) == edge_attr.size(-1)
    elif isinstance(edge_index, SparseTensor):
        assert x[0].size(-1) == edge_index.size(-1)

    # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
    out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

    x_r = x[1]
    if x_r is not None:
        out += (1 + self.eps) * x_r

    return self.nn(t, out)
def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
    """"""
    if batch is None:
        x = x - x.mean()
        out = x / (x.std(unbiased=False) + self.eps)
    else:
        batch_size = int(batch.max()) + 1

        norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
        norm = norm.mul_(x.size(-1)).view(-1, 1)

        mean = scatter(x, batch, dim=0, dim_size=batch_size,
                       reduce='add').sum(dim=-1, keepdim=True) / norm
        x = x - mean[batch]

        var = scatter(x * x, batch, dim=0, dim_size=batch_size,
                      reduce='add').sum(dim=-1, keepdim=True)
        var = var / norm
        out = x / (var.sqrt()[batch] + self.eps)

    if self.weight is not None and self.bias is not None:
        out = out * self.weight + self.bias

    return out
def message(self, x_j, norm: OptTensor):
    """Tensors passed to :meth:`propagate` can be mapped to the respective
    nodes :math:`i` and :math:`j` by appending ``_i`` or ``_j`` to the
    variable name, e.g., ``x_i`` and ``x_j``. Note that we generally refer
    to :math:`i` as the central node that aggregates information, and to
    :math:`j` as the neighboring nodes, since this is the most common
    notation.
    """
    return x_j if norm is None else norm.view(-1, 1) * x_j
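# Hedged sketch (minimal layer, not from the original source) illustrating
# the `_i`/`_j` suffix convention described in the docstring above: tensors
# passed to `propagate()` as `x=...` arrive in `message()` as `x_i` (central
# nodes) and `x_j` (neighboring nodes).
import torch
from torch import Tensor
from torch_geometric.nn import MessagePassing


class MeanNeighborConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr='mean')  # Average all incoming messages.

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        return self.propagate(edge_index, x=x)

    def message(self, x_j: Tensor) -> Tensor:
        return x_j  # Features of each edge's source (neighboring) node.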
def message(self, x_j: Tensor, x_i: Tensor, edge_attr: OptTensor,
            index: Tensor, ptr: OptTensor,
            size_i: Optional[int]) -> Tensor:
    x = x_i + x_j

    if edge_attr is not None:
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.view(-1, 1)
        assert self.lin_edge is not None
        edge_attr = self.lin_edge(edge_attr)
        edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
        x += edge_attr

    x = F.leaky_relu(x, self.negative_slope)
    alpha = (x * self.att).sum(dim=-1)
    alpha = softmax(alpha, index, ptr, size_i)
    self._alpha = alpha
    alpha = F.dropout(alpha, p=self.dropout, training=self.training)
    return x_j * alpha.unsqueeze(-1)
def edge_update(self, alpha_j: Tensor, alpha_i: OptTensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
    # Given edge-level attention coefficients for source and target nodes,
    # we simply need to sum them up to "emulate" concatenation:
    alpha = alpha_j if alpha_i is None else alpha_j + alpha_i

    if edge_attr is not None and self.lin_edge is not None:
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.view(-1, 1)
        edge_attr = self.lin_edge(edge_attr)
        edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
        alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)
        alpha = alpha + alpha_edge

    alpha = F.leaky_relu(alpha, self.negative_slope)
    alpha = softmax(alpha, index, ptr, size_i)
    alpha = F.dropout(alpha, p=self.dropout, training=self.training)
    return alpha
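# Hedged sketch (toy values, not from the original source): the
# `softmax(alpha, index, ptr, size_i)` call above normalizes attention
# coefficients *per target node*, i.e. scores sharing the same `index` entry
# sum to one:
import torch
from torch_geometric.utils import softmax

alpha = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 0, 1, 1])  # target node of each edge
out = softmax(alpha, index)         # out[:2] and out[2:] each sum to 1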
def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
            index: Tensor, ptr: OptTensor, edge_weight: OptTensor,
            size_i: Optional[int]) -> Tensor:
    alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
    alpha = F.leaky_relu(alpha, self.negative_slope)
    alpha = softmax(alpha, index, ptr, size_i)
    alpha = self.att_norm(alpha * edge_weight.unsqueeze(-1), index)
    self._alpha = alpha
    alpha = F.dropout(alpha, p=self.dropout, training=self.training)
    return x_j * alpha.unsqueeze(-1)
def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
    """"""
    if batch is None:
        out = F.instance_norm(
            x.t().unsqueeze(0), self.running_mean, self.running_var,
            self.weight, self.bias, self.training
            or not self.track_running_stats, self.momentum, self.eps)
        return out.squeeze(0).t()

    batch_size = int(batch.max()) + 1

    mean = var = unbiased_var = x  # Dummies.

    if self.training or not self.track_running_stats:
        norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
        norm = norm.view(-1, 1)
        unbiased_norm = (norm - 1).clamp_(min=1)

        mean = scatter(x, batch, dim=0, dim_size=batch_size,
                       reduce='add') / norm

        x = x - mean[batch]

        var = scatter(x * x, batch, dim=0, dim_size=batch_size,
                      reduce='add')
        unbiased_var = var / unbiased_norm
        var = var / norm

        momentum = self.momentum
        if self.running_mean is not None:
            self.running_mean = (
                1 - momentum) * self.running_mean + momentum * mean.mean(0)
        if self.running_var is not None:
            self.running_var = (
                1 - momentum
            ) * self.running_var + momentum * unbiased_var.mean(0)
    else:
        if self.running_mean is not None:
            mean = self.running_mean.view(1, -1).expand(batch_size, -1)
        if self.running_var is not None:
            var = self.running_var.view(1, -1).expand(batch_size, -1)

        x = x - mean[batch]

    out = x / (var + self.eps).sqrt()[batch]

    if self.weight is not None and self.bias is not None:
        out = out * self.weight.view(1, -1) + self.bias.view(1, -1)

    return out
def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage,
                       row: Tensor, col: Tensor, index: Tensor,
                       perm: OptTensor = None) -> EdgeStorage:
    # Filters an edge storage object to only hold the edges in `index`,
    # which represents the new graph as denoted by `(row, col)`:
    for key, value in store.items():
        if key == 'edge_index':
            edge_index = torch.stack([row, col], dim=0)
            out_store.edge_index = edge_index.to(value.device)

        elif key == 'adj_t':
            # NOTE: We expect `(row, col)` to be sorted by `col` (CSC layout).
            row = row.to(value.device())
            col = col.to(value.device())
            edge_attr = value.storage.value()
            if edge_attr is not None:
                index = index.to(edge_attr.device)
                edge_attr = edge_attr[index]
            sparse_sizes = out_store.size()[::-1]
            # TODO Currently, we set `is_sorted=False`, see:
            # https://github.com/pyg-team/pytorch_geometric/issues/4346
            out_store.adj_t = SparseTensor(row=col, col=row, value=edge_attr,
                                           sparse_sizes=sparse_sizes,
                                           is_sorted=False, trust_data=True)

        elif store.is_edge_attr(key):
            dim = store._parent().__cat_dim__(key, value, store)
            if perm is None:
                index = index.to(value.device)
                out_store[key] = index_select(value, index, dim=dim)
            else:
                perm = perm.to(value.device)
                index = index.to(value.device)
                out_store[key] = index_select(value, perm[index], dim=dim)

    return store
def message(self, x_j: Tensor, edge_attr: OptTensor,
            edge_attr_len: OptTensor, duplicates_idx: OptTensor):
    just_made_dup = False
    if duplicates_idx is None:
        # First layer, so build `duplicates_idx`: for every edge index that
        # occurs more than once in `edge_attr`, repeat it (count - 1) times
        # and flatten into a single list.
        just_made_dup = True
        edge_indices, cnt_ = torch.unique(edge_attr, return_counts=True)
        duplicates_idx = [
            [idx] * (c - 1)
            for idx, c in zip(edge_indices[cnt_ != 1], cnt_[cnt_ != 1])
        ]
        duplicates_idx = [_ for li in duplicates_idx for _ in li]

    # `duplicates_idx` is not `None` here. Append the duplicated messages:
    x_j = torch.cat((x_j, x_j[duplicates_idx]), dim=0)  # ((E + num_duplicates), H)

    if just_made_dup:
        # Redirect the last occurrence of each duplicated edge index to its
        # newly appended copy:
        for i, idx in enumerate(duplicates_idx):
            edge_attr[torch.arange(edge_attr.size(0))[edge_attr == idx]
                      [-1]] = x_j.size(0) - len(duplicates_idx) + i

    if self.edge_enc:
        # Add edge positional encodings:
        edge_pos_emb = self.edge_enc(edge_attr_len)
        x_j[edge_attr] += edge_pos_emb.squeeze()

    self.duplicates_idx = duplicates_idx
    return F.relu(x_j) + self.eps, duplicates_idx
def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage,
                       row: Tensor, col: Tensor, index: Tensor,
                       perm: OptTensor = None) -> EdgeStorage:
    # Filters an edge storage object to only hold the edges in `index`,
    # which represents the new graph as denoted by `(row, col)`:
    for key, value in store.items():
        if key == 'edge_index':
            edge_index = torch.stack([row, col], dim=0)
            out_store.edge_index = edge_index.to(value.device)

        elif key == 'adj_t':
            # NOTE: We expect `(row, col)` to be sorted by `col` (CSC layout).
            row = row.to(value.device())
            col = col.to(value.device())
            edge_attr = value.storage.value()
            if edge_attr is not None:
                index = index.to(edge_attr.device)
                edge_attr = edge_attr[index]
            sparse_sizes = store.size()[::-1]
            out_store.adj_t = SparseTensor(row=col, col=row, value=edge_attr,
                                           sparse_sizes=sparse_sizes,
                                           is_sorted=True)

        elif store.is_edge_attr(key):
            if perm is None:
                index = index.to(value.device)
                out_store[key] = index_select(value, index, dim=0)
            else:
                perm = perm.to(value.device)
                index = index.to(value.device)
                out_store[key] = index_select(value, perm[index], dim=0)

    return store
def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
    """"""
    if self.mode == 'graph':
        if batch is None:
            x = x - x.mean()
            out = x / (x.std(unbiased=False) + self.eps)
        else:
            batch_size = int(batch.max()) + 1

            norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
            norm = norm.mul_(x.size(-1)).view(-1, 1)

            mean = scatter(x, batch, dim=0, dim_size=batch_size,
                           reduce='add').sum(dim=-1, keepdim=True) / norm

            x = x - mean.index_select(0, batch)

            var = scatter(x * x, batch, dim=0, dim_size=batch_size,
                          reduce='add').sum(dim=-1, keepdim=True)
            var = var / norm

            out = x / (var + self.eps).sqrt().index_select(0, batch)

        if self.weight is not None and self.bias is not None:
            out = out * self.weight + self.bias

        return out

    if self.mode == 'node':
        return F.layer_norm(x, (self.in_channels, ), self.weight,
                            self.bias, self.eps)

    raise ValueError(f"Unknown normalization mode: {self.mode}")
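# Hedged sketch (toy check, not from the original source): in 'graph' mode
# with `batch=None`, the code above standardizes over *all* entries of `x`
# with a single scalar mean and std, unlike `F.layer_norm` in 'node' mode,
# which normalizes each node's feature vector separately:
import torch

x = torch.randn(5, 8)
out = (x - x.mean()) / (x.std(unbiased=False) + 1e-5)
assert abs(out.mean().item()) < 1e-4  # globally centered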
def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
            edge_index: Adj, edge_type: OptTensor = None):
    r"""
    Args:
        x: The input node features. Can be either a :obj:`[num_nodes,
            in_channels]` node feature matrix, or an optional
            one-dimensional node index tensor (in which case input features
            are treated as trainable node embeddings).
            Furthermore, :obj:`x` can be of type :obj:`tuple` denoting
            source and destination node features.
        edge_index (LongTensor or SparseTensor): The edge indices.
        edge_type: The one-dimensional relation type/index for each edge in
            :obj:`edge_index`.
            Should be only :obj:`None` in case :obj:`edge_index` is of type
            :class:`torch_sparse.tensor.SparseTensor`.
            (default: :obj:`None`)
    """
    # Convert input features to a pair of node features or node indices.
    x_l: OptTensor = None
    if isinstance(x, tuple):
        x_l = x[0]
    else:
        x_l = x
    if x_l is None:
        x_l = torch.arange(self.in_channels_l, device=self.weight.device)

    x_r: Tensor = x_l
    if isinstance(x, tuple):
        x_r = x[1]

    size = (x_l.size(0), x_r.size(0))

    if isinstance(edge_index, SparseTensor):
        edge_type = edge_index.storage.value()
    assert edge_type is not None

    # propagate_type: (x: Tensor, edge_type_ptr: OptTensor)
    out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

    weight = self.weight
    if self.num_bases is not None:  # Basis-decomposition =================
        weight = (self.comp @ weight.view(self.num_bases, -1)).view(
            self.num_relations, self.in_channels_l, self.out_channels)

    if self.num_blocks is not None:  # Block-diagonal-decomposition =====
        if x_l.dtype == torch.long and self.num_blocks is not None:
            raise ValueError('Block-diagonal decomposition not supported '
                             'for non-continuous input features.')

        for i in range(self.num_relations):
            tmp = masked_edge_index(edge_index, edge_type == i)
            h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size)
            h = h.view(-1, weight.size(1), weight.size(2))
            h = torch.einsum('abc,bcd->abd', h, weight[i])
            out += h.contiguous().view(-1, self.out_channels)

    else:  # No regularization/Basis-decomposition ========================
        if self._WITH_PYG_LIB and isinstance(edge_index, Tensor):
            if not self.is_sorted:
                if (edge_type[1:] < edge_type[:-1]).any():
                    edge_type, perm = edge_type.sort()
                    edge_index = edge_index[:, perm]
            edge_type_ptr = torch.ops.torch_sparse.ind2ptr(
                edge_type, self.num_relations)
            out = self.propagate(edge_index, x=x_l,
                                 edge_type_ptr=edge_type_ptr, size=size)
        else:
            for i in range(self.num_relations):
                tmp = masked_edge_index(edge_index, edge_type == i)

                if x_l.dtype == torch.long:
                    out += self.propagate(tmp, x=weight[i, x_l],
                                          edge_type_ptr=None, size=size)
                else:
                    h = self.propagate(tmp, x=x_l, edge_type_ptr=None,
                                       size=size)
                    out = out + (h @ weight[i])

    root = self.root
    if root is not None:
        out += root[x_r] if x_r.dtype == torch.long else x_r @ root

    if self.bias is not None:
        out += self.bias

    return out
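# Hedged sketch (toy shapes, not from the original source): the
# basis-decomposition branch above composes each relation's weight matrix
# from `num_bases` shared bases, mirroring
# `(self.comp @ weight.view(num_bases, -1)).view(...)`:
import torch

num_relations, num_bases, in_ch, out_ch = 4, 2, 8, 16
basis = torch.randn(num_bases, in_ch, out_ch)   # shared bases
comp = torch.randn(num_relations, num_bases)    # per-relation coefficients
weight = (comp @ basis.view(num_bases, -1)).view(num_relations, in_ch,
                                                 out_ch)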
def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
    return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
def message(self, weight_j: Tensor, edge_weight: OptTensor) -> Tensor:
    if edge_weight is None:
        return weight_j
    else:
        return edge_weight.view(-1, 1) * weight_j
def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor,
            edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
            size_i: Optional[int]) -> Tensor:
    if self.num_bases is not None:  # Basis-decomposition =================
        w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
        w = w.view(self.num_relations, self.in_channels,
                   self.heads * self.out_channels)
    if self.num_blocks is not None:  # Block-diagonal-decomposition =======
        if (x_i.dtype == torch.long and x_j.dtype == torch.long
                and self.num_blocks is not None):
            raise ValueError('Block-diagonal decomposition not supported '
                             'for non-continuous input features.')
        w = self.weight
        x_i = x_i.view(-1, 1, w.size(1), w.size(2))
        x_j = x_j.view(-1, 1, w.size(1), w.size(2))
        w = torch.index_select(w, 0, edge_type)
        outi = torch.einsum('abcd,acde->ace', x_i, w)
        outi = outi.contiguous().view(-1, self.heads * self.out_channels)
        outj = torch.einsum('abcd,acde->ace', x_j, w)
        outj = outj.contiguous().view(-1, self.heads * self.out_channels)
    else:  # No regularization/Basis-decomposition ========================
        if self.num_bases is None:
            w = self.weight
        w = torch.index_select(w, 0, edge_type)
        outi = torch.bmm(x_i.unsqueeze(1), w).squeeze(-2)
        outj = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)

    qi = torch.matmul(outi, self.q)
    kj = torch.matmul(outj, self.k)

    alpha_edge, alpha = 0, torch.tensor([0])
    if edge_attr is not None:
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.view(-1, 1)
        assert self.lin_edge is not None, (
            "Please set 'edge_dim = edge_attr.size(-1)' while calling the "
            "RGATConv layer")
        edge_attributes = self.lin_edge(edge_attr).view(
            -1, self.heads * self.out_channels)
        if edge_attributes.size(0) != edge_attr.size(0):
            edge_attributes = torch.index_select(edge_attributes, 0,
                                                 edge_type)
        alpha_edge = torch.matmul(edge_attributes, self.e)

    if self.attention_mode == "additive-self-attention":
        if edge_attr is not None:
            alpha = torch.add(qi, kj) + alpha_edge
        else:
            alpha = torch.add(qi, kj)
        alpha = F.leaky_relu(alpha, self.negative_slope)
    elif self.attention_mode == "multiplicative-self-attention":
        if edge_attr is not None:
            alpha = (qi * kj) * alpha_edge
        else:
            alpha = qi * kj

    if self.attention_mechanism == "within-relation":
        across_out = torch.zeros_like(alpha)
        for r in range(self.num_relations):
            mask = edge_type == r
            across_out[mask] = softmax(alpha[mask], index[mask])
        alpha = across_out
    elif self.attention_mechanism == "across-relation":
        alpha = softmax(alpha, index, ptr, size_i)

    self._alpha = alpha

    if self.mod == "additive":
        if self.attention_mode == "additive-self-attention":
            ones = torch.ones_like(alpha)
            h = (outj.view(-1, self.heads, self.out_channels) *
                 ones.view(-1, self.heads, 1))
            h = torch.mul(self.w, h)
            return (outj.view(-1, self.heads, self.out_channels) *
                    alpha.view(-1, self.heads, 1) + h)
        elif self.attention_mode == "multiplicative-self-attention":
            ones = torch.ones_like(alpha)
            h = (outj.view(-1, self.heads, 1, self.out_channels) *
                 ones.view(-1, self.heads, self.dim, 1))
            h = torch.mul(self.w, h)
            return (outj.view(-1, self.heads, 1, self.out_channels) *
                    alpha.view(-1, self.heads, self.dim, 1) + h)

    elif self.mod == "scaled":
        if self.attention_mode == "additive-self-attention":
            ones = alpha.new_ones(index.size())
            degree = scatter_add(ones, index,
                                 dim_size=size_i)[index].unsqueeze(-1)
            degree = torch.matmul(degree, self.l1) + self.b1
            degree = self.activation(degree)
            degree = torch.matmul(degree, self.l2) + self.b2
            return torch.mul(
                outj.view(-1, self.heads, self.out_channels) *
                alpha.view(-1, self.heads, 1),
                degree.view(-1, 1, self.out_channels))
        elif self.attention_mode == "multiplicative-self-attention":
            ones = alpha.new_ones(index.size())
            degree = scatter_add(ones, index,
                                 dim_size=size_i)[index].unsqueeze(-1)
            degree = torch.matmul(degree, self.l1) + self.b1
            degree = self.activation(degree)
            degree = torch.matmul(degree, self.l2) + self.b2
            return torch.mul(
                outj.view(-1, self.heads, 1, self.out_channels) *
                alpha.view(-1, self.heads, self.dim, 1),
                degree.view(-1, 1, 1, self.out_channels))

    elif self.mod == "f-additive":
        alpha = torch.where(alpha > 0, alpha + 1, alpha)

    elif self.mod == "f-scaled":
        ones = alpha.new_ones(index.size())
        degree = scatter_add(ones, index,
                             dim_size=size_i)[index].unsqueeze(-1)
        alpha = alpha * degree

    elif self.training and self.dropout > 0:
        alpha = F.dropout(alpha, p=self.dropout, training=True)

    else:
        alpha = alpha  # original

    if self.attention_mode == "additive-self-attention":
        return alpha.view(-1, self.heads, 1) * outj.view(
            -1, self.heads, self.out_channels)
    else:
        return (alpha.view(-1, self.heads, self.dim, 1) *
                outj.view(-1, self.heads, 1, self.out_channels))
def message(self, a_j: Tensor, b_i: Tensor,
            edge_weight: OptTensor) -> Tensor:
    out = a_j - b_i
    return out if edge_weight is None else out * edge_weight.view(-1, 1)
def homophily(edge_index: Adj, y: Tensor, batch: OptTensor = None,
              method: str = 'edge') -> Union[float, Tensor]:
    r"""The homophily of a graph characterizes how likely nodes with the
    same label are near each other in a graph.
    There are many measures of homophily that fit this definition.
    In particular:

    - In the `"Beyond Homophily in Graph Neural Networks: Current
      Limitations and Effective Designs"
      <https://arxiv.org/abs/2006.11468>`_ paper, the homophily is the
      fraction of edges in a graph which connect nodes that have the same
      class label:

      .. math::
        \frac{| \{ (v,w) : (v,w) \in \mathcal{E} \wedge y_v = y_w \} | }
        {|\mathcal{E}|}

      That measure is called the *edge homophily ratio*.

    - In the `"Geom-GCN: Geometric Graph Convolutional Networks"
      <https://arxiv.org/abs/2002.05287>`_ paper, edge homophily is
      normalized across neighborhoods:

      .. math::
        \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \frac{ | \{ (w,v) :
        w \in \mathcal{N}(v) \wedge y_v = y_w \} |  } { |\mathcal{N}(v)| }

      That measure is called the *node homophily ratio*.

    - In the `"Large-Scale Learning on Non-Homophilous Graphs: New
      Benchmarks and Strong Simple Methods"
      <https://arxiv.org/abs/2110.14446>`_ paper, edge homophily is modified
      to be insensitive to the number of classes and size of each class:

      .. math::
        \frac{1}{C-1} \sum_{k=1}^{C} \max \left(0, h_k -
        \frac{|\mathcal{C}_k|}{|\mathcal{V}|} \right),

      where :math:`C` denotes the number of classes,
      :math:`|\mathcal{C}_k|` denotes the number of nodes of class
      :math:`k`, and :math:`h_k` denotes the edge homophily ratio of nodes
      of class :math:`k`.
      Thus, that measure is called the *class insensitive edge homophily
      ratio*.

    Args:
        edge_index (Tensor or SparseTensor): The graph connectivity.
        y (Tensor): The labels.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
            each node to a specific example. (default: :obj:`None`)
        method (str, optional): The method used to calculate the homophily,
            either :obj:`"edge"` (first formula), :obj:`"node"` (second
            formula) or :obj:`"edge_insensitive"` (third formula).
            (default: :obj:`"edge"`)
    """
    assert method in {'edge', 'node', 'edge_insensitive'}
    y = y.squeeze(-1) if y.dim() > 1 else y

    if isinstance(edge_index, SparseTensor):
        row, col, _ = edge_index.coo()
    else:
        row, col = edge_index

    if method == 'edge':
        out = torch.zeros(row.size(0), device=row.device)
        out[y[row] == y[col]] = 1.
        if batch is None:
            return float(out.mean())
        else:
            dim_size = int(batch.max()) + 1
            return scatter_mean(out, batch[col], dim=0, dim_size=dim_size)

    elif method == 'node':
        out = torch.zeros(row.size(0), device=row.device)
        out[y[row] == y[col]] = 1.
        out = scatter_mean(out, col, 0, dim_size=y.size(0))
        if batch is None:
            return float(out.mean())
        else:
            return scatter_mean(out, batch, dim=0)

    elif method == 'edge_insensitive':
        assert y.dim() == 1
        num_classes = int(y.max()) + 1
        assert num_classes >= 2
        batch = torch.zeros_like(y) if batch is None else batch
        num_nodes = degree(batch, dtype=torch.int64)
        num_graphs = num_nodes.numel()
        batch = num_classes * batch + y

        h = homophily(edge_index, y, batch, method='edge')
        h = h.view(num_graphs, num_classes)

        counts = batch.bincount(minlength=num_classes * num_graphs)
        counts = counts.view(num_graphs, num_classes)
        proportions = counts / num_nodes.view(-1, 1)

        out = (h - proportions).clamp_(min=0).sum(dim=-1)
        out /= num_classes - 1
        return out if out.numel() > 1 else float(out)

    else:
        raise NotImplementedError
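# Hedged usage sketch (toy graph, not from the original source): a 4-edge
# graph in which two of the four directed edges connect same-label nodes, so
# the edge homophily ratio is 0.5:
import torch

edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
y = torch.tensor([0, 0, 0, 1])
print(homophily(edge_index, y, method='edge'))  # 0.5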
def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
    assert edge_weight is not None
    return x_j * edge_weight.unsqueeze(1)
def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
    assert edge_weight is not None
    return edge_weight.view(-1, 1) * x_j
def message(self, x_j: Tensor, norm: OptTensor) -> Tensor:
    return norm.view(-1, 1) * x_j if norm is not None else x_j
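# Hedged sketch (toy tensors, not from the original source): the
# `norm.view(-1, 1) * x_j` pattern shared by the message functions above
# broadcasts one scalar coefficient per edge across all feature channels:
import torch

x_j = torch.ones(3, 4)                 # 3 messages, 4 features each
norm = torch.tensor([0.5, 1.0, 2.0])   # one coefficient per edge
out = norm.view(-1, 1) * x_j           # rows scaled by 0.5, 1.0, 2.0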