from typing import Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.typing import OptTensor
from torch_geometric.utils import scatter
from torch_geometric.utils.num_nodes import maybe_num_nodes


def add_self_loops(
    edge_index: Tensor,
    edge_attr: OptTensor = None,
    fill_value: Optional[Union[float, Tensor, str]] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, OptTensor]:
    r"""Adds a self-loop :math:`(i,i) \in \mathcal{E}` to every node
    :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
    In case the graph is weighted or has multi-dimensional edge features
    (:obj:`edge_attr != None`), edge features of self-loops will be added
    according to :obj:`fill_value`.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
            features. (default: :obj:`None`)
        fill_value (float or Tensor or str, optional): The way to generate
            edge features of self-loops (in case :obj:`edge_attr != None`).
            If given as :obj:`float` or :class:`torch.Tensor`, edge features
            of self-loops will be directly given by :obj:`fill_value`.
            If given as :obj:`str`, edge features of self-loops are computed
            by aggregating all features of edges that point to the specific
            node, according to a reduce operation (:obj:`"add"`,
            :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`).
            (default: :obj:`1.`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    N = maybe_num_nodes(edge_index, num_nodes)

    loop_index = torch.arange(0, N, dtype=torch.long,
                              device=edge_index.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)

    if edge_attr is not None:
        if fill_value is None:
            loop_attr = edge_attr.new_full((N, ) + edge_attr.size()[1:], 1.)

        elif isinstance(fill_value, (int, float)):
            loop_attr = edge_attr.new_full((N, ) + edge_attr.size()[1:],
                                           fill_value)

        elif isinstance(fill_value, Tensor):
            loop_attr = fill_value.to(edge_attr.device, edge_attr.dtype)
            if edge_attr.dim() != loop_attr.dim():
                loop_attr = loop_attr.unsqueeze(0)
            sizes = [N] + [1] * (loop_attr.dim() - 1)
            loop_attr = loop_attr.repeat(*sizes)

        elif isinstance(fill_value, str):
            # Aggregate the features of all incoming edges per node:
            loop_attr = scatter(edge_attr, edge_index[1], dim=0, dim_size=N,
                                reduce=fill_value)
        else:
            raise AttributeError("No valid 'fill_value' provided")

        edge_attr = torch.cat([edge_attr, loop_attr], dim=0)

    edge_index = torch.cat([edge_index, loop_index], dim=1)
    return edge_index, edge_attr
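
# A minimal usage sketch (added for illustration; the tiny 3-node graph and
# its values below are assumptions, not part of the original source). With
# fill_value='mean', self-loop weights become the mean of incoming edge
# weights:
#
#     edge_index = torch.tensor([[0, 1], [1, 2]])
#     edge_weight = torch.tensor([0.5, 0.5])
#     edge_index, edge_weight = add_self_loops(
#         edge_index, edge_weight, fill_value='mean', num_nodes=3)
#     # edge_index  -> [[0, 1, 0, 1, 2], [1, 2, 0, 1, 2]]
#     # edge_weight -> [0.5, 0.5, 0.0, 0.5, 0.5]  (node 0 has no in-edges)
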
def message(self, x_j: Tensor, x_i: Tensor, edge_attr: OptTensor,
            index: Tensor, ptr: OptTensor,
            size_i: Optional[int]) -> Tensor:
    # Attention is computed from the summed source and target features
    # (and, optionally, from transformed edge features):
    x = x_i + x_j

    if edge_attr is not None:
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.view(-1, 1)
        assert self.lin_edge is not None
        edge_attr = self.lin_edge(edge_attr)
        edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
        x = x + edge_attr

    x = F.leaky_relu(x, self.negative_slope)
    alpha = (x * self.att).sum(dim=-1)
    alpha = softmax(alpha, index, ptr, size_i)
    self._alpha = alpha
    alpha = F.dropout(alpha, p=self.dropout, training=self.training)
    return x_j * alpha.unsqueeze(-1)
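
# Usage sketch (an assumption for illustration: this message function matches
# the GATv2-style attention of torch_geometric.nn.GATv2Conv, where setting
# `edge_dim` enables the `lin_edge` projection used above; shapes are made
# up):
#
#     conv = GATv2Conv(in_channels=16, out_channels=32, heads=4, edge_dim=8)
#     out = conv(x, edge_index, edge_attr)  # x: [N, 16], edge_attr: [E, 8]
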
def edge_update(self, alpha_j: Tensor, alpha_i: OptTensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
    # Given edge-level attention coefficients for source and target nodes,
    # we simply need to sum them up to "emulate" concatenation:
    alpha = alpha_j if alpha_i is None else alpha_j + alpha_i

    if edge_attr is not None and self.lin_edge is not None:
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.view(-1, 1)
        edge_attr = self.lin_edge(edge_attr)
        edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
        alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)
        alpha = alpha + alpha_edge

    alpha = F.leaky_relu(alpha, self.negative_slope)
    alpha = softmax(alpha, index, ptr, size_i)
    alpha = F.dropout(alpha, p=self.dropout, training=self.training)
    return alpha
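
# Why summing emulates concatenation (added note): with a learned attention
# vector a = [a_src || a_dst], the GAT score a^T [W x_j || W x_i] factors
# into a_src^T W x_j + a_dst^T W x_i, so the two halves (alpha_j, alpha_i)
# can be precomputed per node and simply added per edge, avoiding an
# explicit [E, 2 * heads * out_channels] concatenation.
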
def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor,
            edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
            size_i: Optional[int]) -> Tensor:
    if self.num_bases is not None:  # Basis-decomposition =================
        w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
        w = w.view(self.num_relations, self.in_channels,
                   self.heads * self.out_channels)

    if self.num_blocks is not None:  # Block-diagonal-decomposition =======
        if x_i.dtype == torch.long and x_j.dtype == torch.long:
            raise ValueError('Block-diagonal decomposition not supported '
                             'for non-continuous input features.')

        w = self.weight
        x_i = x_i.view(-1, 1, w.size(1), w.size(2))
        x_j = x_j.view(-1, 1, w.size(1), w.size(2))
        w = torch.index_select(w, 0, edge_type)
        outi = torch.einsum('abcd,acde->ace', x_i, w)
        outi = outi.contiguous().view(-1, self.heads * self.out_channels)
        outj = torch.einsum('abcd,acde->ace', x_j, w)
        outj = outj.contiguous().view(-1, self.heads * self.out_channels)
    else:  # No regularization/Basis-decomposition ========================
        if self.num_bases is None:
            w = self.weight
        # Select the relation-specific weight matrix for every edge:
        w = torch.index_select(w, 0, edge_type)
        outi = torch.bmm(x_i.unsqueeze(1), w).squeeze(-2)
        outj = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)

    # Query/key projections of the transformed target/source features:
    qi = torch.matmul(outi, self.q)
    kj = torch.matmul(outj, self.k)

    alpha_edge, alpha = 0, torch.tensor([0])
    if edge_attr is not None:
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.view(-1, 1)
        assert self.lin_edge is not None, (
            "Please set 'edge_dim = edge_attr.size(-1)' while calling the "
            "RGATConv layer")
        edge_attributes = self.lin_edge(edge_attr).view(
            -1, self.heads * self.out_channels)
        if edge_attributes.size(0) != edge_attr.size(0):
            edge_attributes = torch.index_select(edge_attributes, 0,
                                                 edge_type)
        alpha_edge = torch.matmul(edge_attributes, self.e)

    if self.attention_mode == "additive-self-attention":
        if edge_attr is not None:
            alpha = torch.add(qi, kj) + alpha_edge
        else:
            alpha = torch.add(qi, kj)
        alpha = F.leaky_relu(alpha, self.negative_slope)
    elif self.attention_mode == "multiplicative-self-attention":
        if edge_attr is not None:
            alpha = (qi * kj) * alpha_edge
        else:
            alpha = qi * kj

    if self.attention_mechanism == "within-relation":
        # Normalize attention coefficients separately per relation type:
        across_out = torch.zeros_like(alpha)
        for r in range(self.num_relations):
            mask = edge_type == r
            across_out[mask] = softmax(alpha[mask], index[mask])
        alpha = across_out
    elif self.attention_mechanism == "across-relation":
        alpha = softmax(alpha, index, ptr, size_i)

    self._alpha = alpha

    if self.mod == "additive":
        if self.attention_mode == "additive-self-attention":
            ones = torch.ones_like(alpha)
            h = (outj.view(-1, self.heads, self.out_channels) *
                 ones.view(-1, self.heads, 1))
            h = torch.mul(self.w, h)

            return (outj.view(-1, self.heads, self.out_channels) *
                    alpha.view(-1, self.heads, 1) + h)
        elif self.attention_mode == "multiplicative-self-attention":
            ones = torch.ones_like(alpha)
            h = (outj.view(-1, self.heads, 1, self.out_channels) *
                 ones.view(-1, self.heads, self.dim, 1))
            h = torch.mul(self.w, h)

            return (outj.view(-1, self.heads, 1, self.out_channels) *
                    alpha.view(-1, self.heads, self.dim, 1) + h)

    elif self.mod == "scaled":
        if self.attention_mode == "additive-self-attention":
            # Scale messages by a learned transformation of node in-degree:
            ones = alpha.new_ones(index.size())
            degree = scatter_add(ones, index,
                                 dim_size=size_i)[index].unsqueeze(-1)
            degree = torch.matmul(degree, self.l1) + self.b1
            degree = self.activation(degree)
            degree = torch.matmul(degree, self.l2) + self.b2

            return torch.mul(
                outj.view(-1, self.heads, self.out_channels) *
                alpha.view(-1, self.heads, 1),
                degree.view(-1, 1, self.out_channels))
        elif self.attention_mode == "multiplicative-self-attention":
            ones = alpha.new_ones(index.size())
            degree = scatter_add(ones, index,
                                 dim_size=size_i)[index].unsqueeze(-1)
            degree = torch.matmul(degree, self.l1) + self.b1
            degree = self.activation(degree)
            degree = torch.matmul(degree, self.l2) + self.b2

            return torch.mul(
                outj.view(-1, self.heads, 1, self.out_channels) *
                alpha.view(-1, self.heads, self.dim, 1),
                degree.view(-1, 1, 1, self.out_channels))

    elif self.mod == "f-additive":
        alpha = torch.where(alpha > 0, alpha + 1, alpha)

    elif self.mod == "f-scaled":
        ones = alpha.new_ones(index.size())
        degree = scatter_add(ones, index,
                             dim_size=size_i)[index].unsqueeze(-1)
        alpha = alpha * degree

    elif self.training and self.dropout > 0:
        # Dropout is only reached when none of the 'mod' options above
        # applies:
        alpha = F.dropout(alpha, p=self.dropout, training=True)

    else:
        alpha = alpha  # original

    if self.attention_mode == "additive-self-attention":
        return alpha.view(-1, self.heads, 1) * outj.view(
            -1, self.heads, self.out_channels)
    else:
        return (alpha.view(-1, self.heads, self.dim, 1) *
                outj.view(-1, self.heads, 1, self.out_channels))
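
# Usage sketch (an assumption for illustration: this message function matches
# torch_geometric.nn.RGATConv; the hyper-parameter values below are made up):
#
#     conv = RGATConv(in_channels=16, out_channels=32, num_relations=4,
#                     heads=2, mod='f-scaled',
#                     attention_mechanism='across-relation',
#                     attention_mode='additive-self-attention')
#     out = conv(x, edge_index, edge_type)  # x: [N, 16], edge_type: [E]
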