def message(self, x_j: Tensor, edge_attr: OptTensor):
    assert edge_attr is not None

    EPS = 1e-15
    F, M = self.rel_in_channels, self.out_channels
    (E, D), K = edge_attr.size(), self.kernel_size

    if not self.separate_gaussians:
        # Evaluate the K Gaussian kernels on the D-dimensional edge features.
        gaussian = -0.5 * (edge_attr.view(E, 1, D) -
                           self.mu.view(1, K, D)).pow(2)
        gaussian = gaussian / (EPS + self.sigma.view(1, K, D).pow(2))
        gaussian = torch.exp(gaussian.sum(dim=-1))  # [E, K]

        return (x_j.view(E, K, M) * gaussian.view(E, K, 1)).sum(dim=-2)

    else:
        # Use a separate set of Gaussians for every (in, out) feature pair.
        gaussian = -0.5 * (edge_attr.view(E, 1, 1, 1, D) -
                           self.mu.view(1, F, M, K, D)).pow(2)
        gaussian = gaussian / (EPS + self.sigma.view(1, F, M, K, D).pow(2))
        gaussian = torch.exp(gaussian.sum(dim=-1))  # [E, F, M, K]

        gaussian = gaussian * self.g.view(1, F, M, K)
        gaussian = gaussian.sum(dim=-1)  # [E, F, M]

        return (x_j.view(E, F, 1) * gaussian).sum(dim=-2)  # [E, M]
def aggregate(self, inputs: Tensor, index: Tensor,
              dim_size: Optional[int] = None,
              symnorm_weight: OptTensor = None) -> Tensor:
    outs = []
    for aggr in self.aggregators:
        if aggr == 'symnorm':
            assert symnorm_weight is not None
            out = scatter(inputs * symnorm_weight.view(-1, 1), index, 0,
                          None, dim_size, reduce='sum')
        elif aggr == 'var' or aggr == 'std':
            mean = scatter(inputs, index, 0, None, dim_size, reduce='mean')
            mean_squares = scatter(inputs * inputs, index, 0, None,
                                   dim_size, reduce='mean')
            out = mean_squares - mean * mean
            if aggr == 'std':
                out = torch.sqrt(out.relu_() + 1e-5)
        else:
            out = scatter(inputs, index, 0, None, dim_size, reduce=aggr)
        outs.append(out)

    return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]
def message(self, x_j, norm: OptTensor):
    """Tensors passed to propagate() can be mapped to the respective nodes
    i and j by appending _i or _j to the variable name, e.g. x_i and x_j.
    Note that we generally refer to i as the central nodes that aggregate
    information, and refer to j as the neighboring nodes, since this is the
    most common notation.
    """
    return x_j if norm is None else norm.view(-1, 1) * x_j
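# A minimal usage sketch (not from the library source) of the `_i`/`_j`
# mapping described in the docstring above: any tensor passed to propagate()
# under the name `x` is lifted per edge to `x_i` (central/target node) and
# `x_j` (neighboring/source node) inside message(). The layer name `EchoConv`
# and the tensor shapes are illustrative assumptions.
import torch
from torch_geometric.nn import MessagePassing


class EchoConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, x, edge_index):
        # `x` becomes `x_i`/`x_j` in message(), indexed via `edge_index`.
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # One row per edge: x_i holds target-node features, x_j source-node
        # features.
        return x_j - x_i


x = torch.randn(4, 8)                      # 4 nodes with 8 features each
edge_index = torch.tensor([[0, 1, 2, 3],   # source nodes (j)
                           [1, 2, 3, 0]])  # target nodes (i)
out = EchoConv()(x, edge_index)            # -> shape [4, 8]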
def message(self, x_j: Tensor, x_i: Tensor, edge_attr: OptTensor,
            index: Tensor, ptr: OptTensor,
            size_i: Optional[int]) -> Tensor:
    x = x_i + x_j

    if edge_attr is not None:
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.view(-1, 1)
        assert self.lin_edge is not None
        edge_attr = self.lin_edge(edge_attr)
        edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
        x += edge_attr

    x = F.leaky_relu(x, self.negative_slope)
    alpha = (x * self.att).sum(dim=-1)
    alpha = softmax(alpha, index, ptr, size_i)
    self._alpha = alpha
    alpha = F.dropout(alpha, p=self.dropout, training=self.training)
    return x_j * alpha.unsqueeze(-1)
def edge_update(self, alpha_j: Tensor, alpha_i: OptTensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
    # Given edge-level attention coefficients for source and target nodes,
    # we simply need to sum them up to "emulate" concatenation:
    alpha = alpha_j if alpha_i is None else alpha_j + alpha_i

    if edge_attr is not None and self.lin_edge is not None:
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.view(-1, 1)
        edge_attr = self.lin_edge(edge_attr)
        edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
        alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)
        alpha = alpha + alpha_edge

    alpha = F.leaky_relu(alpha, self.negative_slope)
    alpha = softmax(alpha, index, ptr, size_i)
    alpha = F.dropout(alpha, p=self.dropout, training=self.training)
    return alpha
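# A short aside (not part of the layer above) on why summing "emulates"
# concatenation: GAT scores an edge as a^T [W x_i || W x_j]. Splitting the
# learned attention vector as a = [a_dst || a_src] gives
#     a^T [W x_i || W x_j] = a_dst^T (W x_i) + a_src^T (W x_j),
# so the per-node terms alpha_i = a_dst^T (W x_i) and alpha_j = a_src^T (W x_j)
# can be precomputed once per node and simply added per edge, which is what
# `alpha = alpha_j + alpha_i` above does. The symbols a_dst and a_src are
# illustrative, not identifiers from the code.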
def message(self, x_i: Tensor, x_j: Tensor,
            edge_attr: OptTensor) -> Tensor:
    h: Tensor = x_i  # Dummy.
    if edge_attr is not None:
        edge_attr = self.edge_encoder(edge_attr)
        edge_attr = edge_attr.view(-1, 1, self.F_in)
        edge_attr = edge_attr.repeat(1, self.towers, 1)
        h = torch.cat([x_i, x_j, edge_attr], dim=-1)
    else:
        h = torch.cat([x_i, x_j], dim=-1)

    hs = [nn(h[:, i]) for i, nn in enumerate(self.pre_nns)]
    return torch.stack(hs, dim=1)
def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
    return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
def message(self, weight_j: Tensor, edge_weight: OptTensor) -> Tensor:
    if edge_weight is None:
        return weight_j
    else:
        return edge_weight.view(-1, 1) * weight_j
def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor,
            edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
            size_i: Optional[int]) -> Tensor:
    if self.num_bases is not None:  # Basis-decomposition =================
        w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
        w = w.view(self.num_relations, self.in_channels,
                   self.heads * self.out_channels)
    if self.num_blocks is not None:  # Block-diagonal-decomposition =======
        if (x_i.dtype == torch.long and x_j.dtype == torch.long
                and self.num_blocks is not None):
            raise ValueError('Block-diagonal decomposition not supported '
                             'for non-continuous input features.')

        w = self.weight
        x_i = x_i.view(-1, 1, w.size(1), w.size(2))
        x_j = x_j.view(-1, 1, w.size(1), w.size(2))
        w = torch.index_select(w, 0, edge_type)
        outi = torch.einsum('abcd,acde->ace', x_i, w)
        outi = outi.contiguous().view(-1, self.heads * self.out_channels)
        outj = torch.einsum('abcd,acde->ace', x_j, w)
        outj = outj.contiguous().view(-1, self.heads * self.out_channels)
    else:  # No regularization/Basis-decomposition ========================
        if self.num_bases is None:
            w = self.weight
        w = torch.index_select(w, 0, edge_type)
        outi = torch.bmm(x_i.unsqueeze(1), w).squeeze(-2)
        outj = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)

    qi = torch.matmul(outi, self.q)
    kj = torch.matmul(outj, self.k)

    alpha_edge, alpha = 0, torch.tensor([0])
    if edge_attr is not None:
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.view(-1, 1)
        assert self.lin_edge is not None, (
            "Please set 'edge_dim = edge_attr.size(-1)' while calling the "
            "RGATConv layer")
        edge_attributes = self.lin_edge(edge_attr).view(
            -1, self.heads * self.out_channels)
        if edge_attributes.size(0) != edge_attr.size(0):
            edge_attributes = torch.index_select(edge_attributes, 0,
                                                 edge_type)
        alpha_edge = torch.matmul(edge_attributes, self.e)

    if self.attention_mode == "additive-self-attention":
        if edge_attr is not None:
            alpha = torch.add(qi, kj) + alpha_edge
        else:
            alpha = torch.add(qi, kj)
        alpha = F.leaky_relu(alpha, self.negative_slope)
    elif self.attention_mode == "multiplicative-self-attention":
        if edge_attr is not None:
            alpha = (qi * kj) * alpha_edge
        else:
            alpha = qi * kj

    if self.attention_mechanism == "within-relation":
        across_out = torch.zeros_like(alpha)
        for r in range(self.num_relations):
            mask = edge_type == r
            across_out[mask] = softmax(alpha[mask], index[mask])
        alpha = across_out
    elif self.attention_mechanism == "across-relation":
        alpha = softmax(alpha, index, ptr, size_i)

    self._alpha = alpha

    if self.mod == "additive":
        if self.attention_mode == "additive-self-attention":
            ones = torch.ones_like(alpha)
            h = (outj.view(-1, self.heads, self.out_channels) *
                 ones.view(-1, self.heads, 1))
            h = torch.mul(self.w, h)

            return (outj.view(-1, self.heads, self.out_channels) *
                    alpha.view(-1, self.heads, 1) + h)
        elif self.attention_mode == "multiplicative-self-attention":
            ones = torch.ones_like(alpha)
            h = (outj.view(-1, self.heads, 1, self.out_channels) *
                 ones.view(-1, self.heads, self.dim, 1))
            h = torch.mul(self.w, h)

            return (outj.view(-1, self.heads, 1, self.out_channels) *
                    alpha.view(-1, self.heads, self.dim, 1) + h)

    elif self.mod == "scaled":
        if self.attention_mode == "additive-self-attention":
            ones = alpha.new_ones(index.size())
            degree = scatter_add(ones, index,
                                 dim_size=size_i)[index].unsqueeze(-1)
            degree = torch.matmul(degree, self.l1) + self.b1
            degree = self.activation(degree)
            degree = torch.matmul(degree, self.l2) + self.b2

            return torch.mul(
                outj.view(-1, self.heads, self.out_channels) *
                alpha.view(-1, self.heads, 1),
                degree.view(-1, 1, self.out_channels))
        elif self.attention_mode == "multiplicative-self-attention":
            ones = alpha.new_ones(index.size())
            degree = scatter_add(ones, index,
                                 dim_size=size_i)[index].unsqueeze(-1)
            degree = torch.matmul(degree, self.l1) + self.b1
            degree = self.activation(degree)
            degree = torch.matmul(degree, self.l2) + self.b2

            return torch.mul(
                outj.view(-1, self.heads, 1, self.out_channels) *
                alpha.view(-1, self.heads, self.dim, 1),
                degree.view(-1, 1, 1, self.out_channels))

    elif self.mod == "f-additive":
        alpha = torch.where(alpha > 0, alpha + 1, alpha)

    elif self.mod == "f-scaled":
        ones = alpha.new_ones(index.size())
        degree = scatter_add(ones, index,
                             dim_size=size_i)[index].unsqueeze(-1)
        alpha = alpha * degree

    elif self.training and self.dropout > 0:
        alpha = F.dropout(alpha, p=self.dropout, training=True)

    else:
        alpha = alpha  # original

    if self.attention_mode == "additive-self-attention":
        return alpha.view(-1, self.heads, 1) * outj.view(
            -1, self.heads, self.out_channels)
    else:
        return (alpha.view(-1, self.heads, self.dim, 1) *
                outj.view(-1, self.heads, 1, self.out_channels))
def message(self, a_j: Tensor, b_i: Tensor,
            edge_weight: OptTensor) -> Tensor:
    out = a_j - b_i
    return out if edge_weight is None else out * edge_weight.view(-1, 1)
def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
    assert edge_weight is not None
    return edge_weight.view(-1, 1) * x_j
def message(self, x_j: Tensor, norm: OptTensor) -> Tensor:
    return norm.view(-1, 1) * x_j if norm is not None else x_j