def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
    """"""
    edge_weight: OptTensor = None
    if isinstance(edge_index, Tensor):
        num_nodes = x.size(self.node_dim)
        if self.add_self_loops:
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        row, col = edge_index[0], edge_index[1]
        deg_inv = 1. / degree(col, num_nodes=num_nodes).clamp_(1.)

        edge_weight = deg_inv[col]
        edge_weight[row == col] += self.diag_lambda * deg_inv

    elif isinstance(edge_index, SparseTensor):
        if self.add_self_loops:
            edge_index = set_diag(edge_index)

        col, row, _ = edge_index.coo()  # Transposed.
        deg_inv = 1. / sparsesum(edge_index, dim=1).clamp_(1.)

        edge_weight = deg_inv[col]
        edge_weight[row == col] += self.diag_lambda * deg_inv
        edge_index = edge_index.set_value(edge_weight, layout='coo')

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                         size=None)
    out = self.lin_out(out) + self.lin_root(x)

    return out
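
# A self-contained sketch of the diagonal-enhanced normalization computed in the
# dense-tensor branch above, on a toy 3-node graph (plain torch plus
# torch_geometric.utils only; `diag_lambda = 0.2` is a hypothetical value, not taken
# from the snippet):
import torch
from torch_geometric.utils import add_self_loops, degree

diag_lambda = 0.2
edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])
edge_index, _ = add_self_loops(edge_index, num_nodes=3)

row, col = edge_index[0], edge_index[1]
deg_inv = 1. / degree(col, num_nodes=3).clamp_(1.)

# Normalize each edge by the in-degree of its target node, then strengthen the
# self-loop entries by `diag_lambda * deg_inv`, mirroring the forward pass above.
edge_weight = deg_inv[col]
edge_weight[row == col] += diag_lambda * deg_inv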
def adj_type_to_edge_tensor_type(layout: EdgeLayout,
                                 edge_index: Adj) -> EdgeTensorType:
    r"""Converts a PyG Adj tensor to an EdgeTensorType equivalent."""
    if isinstance(edge_index, Tensor):
        return (edge_index[0], edge_index[1])  # (row, col)
    if layout == EdgeLayout.COO:
        return edge_index.coo()[:-1]  # (row, col)
    elif layout == EdgeLayout.CSR:
        return edge_index.csr()[:-1]  # (rowptr, col)
    else:
        return edge_index.csr()[-2::-1]  # (row, colptr)
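
# A minimal usage sketch for the converter above. It assumes `EdgeLayout` and
# `adj_type_to_edge_tensor_type` are available in the current scope (they live in the
# same graph-store module); only `torch` and `torch_sparse.SparseTensor` are imported.
import torch
from torch_sparse import SparseTensor

edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])

# Plain `Tensor` input: the layout argument is ignored and (row, col) is returned.
row, col = adj_type_to_edge_tensor_type(EdgeLayout.COO, edge_index)

# `SparseTensor` input with CSR layout requested: (rowptr, col) is returned.
adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(3, 3))
rowptr, col = adj_type_to_edge_tensor_type(EdgeLayout.CSR, adj)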
def forward(self, y: Tensor, edge_index: Adj, mask: Tensor,
            edge_weight: OptTensor = None) -> Tensor:
    """"""
    if mask.dtype == torch.long:
        idx = mask
        mask = mask.new_zeros(y.size(0), dtype=torch.bool)
        mask[idx] = True

    y = F.one_hot(y.view(-1)).to(torch.float)
    y[~mask] = 0.

    if isinstance(edge_index, SparseTensor) and not edge_index.has_value():
        edge_index = gcn_norm(edge_index, add_self_loops=False)
    elif isinstance(edge_index, Tensor) and edge_weight is None:
        edge_index, edge_weight = gcn_norm(edge_index, num_nodes=y.size(0),
                                           add_self_loops=False)

    res = (1 - self.alpha) * y
    for _ in range(self.num_layers):
        # propagate_type: (y: Tensor, edge_weight: OptTensor)
        y = self.propagate(edge_index, y=y, edge_weight=edge_weight,
                           size=None)
        y.mul_(self.alpha).add_(res).clamp_(0, 1)

    return y
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            edge_weight=None) -> Tensor:
    """"""
    if isinstance(x, Tensor):
        x: OptPairTensor = (x, x)

    # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
    # start of modified code #########
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=x[0].dtype,
                                 device=edge_index.device)
    out = self.propagate(edge_index, x=x, size=None,
                         edge_weight=edge_weight)
    # end of modified code #########

    x_r = x[1]
    if x_r is not None:
        out += (1 + self.eps) * x_r

    return self.nn(out)
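
# The modified `propagate` call above only has an effect if the layer's `message`
# function consumes `edge_weight`. A minimal matching `message` could look like the
# sketch below (an assumption about the rest of the class, not part of the original
# snippet; it reuses the snippet's own `Tensor`/`OptTensor` aliases):
def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
    # Scale each neighbor message by its (possibly all-ones) edge weight.
    return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j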
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            edge_type: Optional[torch.Tensor] = None, size: Size = None,
            return_attention_weights=None):
    r"""
    Args:
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
    x = x.view(-1, self.heads, self.out_channels)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=x, edge_type=edge_type, size=size)

    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def forward(
    self, y: Tensor, edge_index: Adj, mask: Optional[Tensor] = None,
    edge_weight: OptTensor = None,
    post_step: Callable = lambda y: y.clamp_(0., 1.)
) -> Tensor:
    """"""
    if y.dtype == torch.long:
        y = F.one_hot(y.view(-1)).to(torch.float)

    out = y
    if mask is not None:
        out = torch.zeros_like(y)
        out[mask] = y[mask]

    if isinstance(edge_index, SparseTensor) and not edge_index.has_value():
        edge_index = gcn_norm(edge_index, add_self_loops=False)
    elif isinstance(edge_index, Tensor) and edge_weight is None:
        edge_index, edge_weight = gcn_norm(edge_index, num_nodes=y.size(0),
                                           add_self_loops=False)

    res = (1 - self.alpha) * out
    for _ in range(self.num_layers):
        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=out, edge_weight=edge_weight,
                             size=None)
        out.mul_(self.alpha).add_(res)
        out = post_step(out)

    return out
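
# A minimal usage sketch, assuming this `forward` belongs to a label-propagation
# module constructed like `torch_geometric.nn.LabelPropagation(num_layers, alpha)`:
import torch
from torch_geometric.nn import LabelPropagation

edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
y = torch.tensor([0, 0, 1, 1])                          # one class label per node
train_mask = torch.tensor([True, False, False, True])   # only propagate known labels

model = LabelPropagation(num_layers=3, alpha=0.9)
out = model(y, edge_index, mask=train_mask)  # soft label matrix of shape [4, 2]
pred = out.argmax(dim=-1)                    # hard labels after propagation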
def forward(self, x: Tensor, edge_index: Adj,
            edge_attr: OptTensor = None) -> Tensor:
    """"""
    if isinstance(edge_index, SparseTensor):
        edge_attr = edge_index.storage.value()

    if edge_attr is not None:
        edge_attr = self.mlp(edge_attr).squeeze(-1)

    if isinstance(edge_index, SparseTensor):
        edge_index = edge_index.set_value(edge_attr, layout='coo')

    if self.normalize:
        if isinstance(edge_index, Tensor):
            edge_index, edge_attr = gcn_norm(edge_index, edge_attr,
                                             x.size(self.node_dim), False,
                                             self.add_self_loops)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(edge_index, None, x.size(self.node_dim),
                                  False, self.add_self_loops)

    x = self.lin(x)

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_attr, size=None)

    if self.bias is not None:
        out += self.bias

    return out
def forward(self, t, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            edge_attr: OptTensor = None, size: Size = None) -> Tensor:
    """"""
    if isinstance(x, Tensor):
        x: OptPairTensor = (x, x)

    # Node and edge feature dimensionalities need to match.
    if isinstance(edge_index, Tensor):
        assert edge_attr is not None
        assert x[0].size(-1) == edge_attr.size(-1)
    elif isinstance(edge_index, SparseTensor):
        assert x[0].size(-1) == edge_index.size(-1)

    # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
    out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

    x_r = x[1]
    if x_r is not None:
        out += (1 + self.eps) * x_r

    return self.nn(t, out)
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, edge_weight=None) -> Tensor:
    """"""
    if isinstance(x, Tensor):
        x: OptPairTensor = (x, x)

    # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
    # start of modified code #########
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=x[0].dtype,
                                 device=edge_index.device)
    out = self.propagate(edge_index, x=x, size=size,
                         edge_weight=edge_weight)
    # end of modified code #########
    out = self.lin_l(out)

    x_r = x[1]
    if x_r is not None:
        out += self.lin_r(x_r)

    if self.normalize:
        out = F.normalize(out, p=2., dim=-1)

    return out
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None):
    r"""
    Args:
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_r = self.lin_l(x).view(-1, H, C)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = self.lin_l(x_l).view(-1, H, C)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)

    assert x_l is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            num_nodes = x_l.size(0)
            if x_r is not None:
                num_nodes = min(num_nodes, x_r.size(0))
            if size is not None:
                num_nodes = min(size[0], size[1])
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=(x_l, x_r), size=size)

    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
            edge_attr: OptTensor = None, return_attention_weights=None):
    # type: (Union[Tensor, PairTensor], Tensor, OptTensor, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], Tensor, OptTensor, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""
    Args:
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    if isinstance(x, Tensor):
        x: PairTensor = (x, x)

    query = self.lin_query(x[1]).view(-1, H, C)
    key = self.lin_key(x[0]).view(-1, H, C)
    value = self.lin_value(x[0]).view(-1, H, C)

    # propagate_type: (query: Tensor, key: Tensor, value: Tensor, edge_attr: OptTensor)  # noqa
    out = self.propagate(edge_index, query=query, key=key, value=value,
                         edge_attr=edge_attr, size=None)

    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.root_weight:
        x_r = self.lin_skip(x[1])
        if self.lin_beta is not None:
            beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
            beta = beta.sigmoid()
            out = beta * x_r + (1 - beta) * out
        else:
            out += x_r

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """"""
    if self.normalize:
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if cache is None:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, self.flow, dtype=x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
            else:
                edge_index, edge_weight = cache[0], cache[1]

        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is None:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, self.flow, dtype=x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index
            else:
                edge_index = cache

    h = x
    for k in range(self.K):
        if self.dropout > 0 and self.training:
            if isinstance(edge_index, Tensor):
                assert edge_weight is not None
                edge_weight = F.dropout(edge_weight, p=self.dropout)
            else:
                value = edge_index.storage.value()
                assert value is not None
                value = F.dropout(value, p=self.dropout)
                edge_index = edge_index.set_value(value, layout='coo')

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                           size=None)
        x = x * (1 - self.alpha)
        x += self.alpha * h

    return x
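
# A minimal usage sketch of this propagation scheme, assuming the layer is constructed
# like `torch_geometric.nn.APPNP` with `K` power-iteration steps and teleport
# probability `alpha`:
import torch
from torch_geometric.nn import APPNP

x = torch.randn(4, 16)                      # node features (e.g. the output of an MLP)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

prop = APPNP(K=10, alpha=0.1)
out = prop(x, edge_index)                   # x^{k+1} = (1 - alpha) * A_hat @ x^k + alpha * x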
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None):
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_r = x.view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        alpha_r = (x_r * self.att_r).sum(dim=-1)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_l.view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        if x_r is not None:
            x_r = x_r.view(-1, H, C)
            alpha_r = (x_r * self.att_r).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            num_nodes = x_l.size(0)
            if x_r is not None:
                num_nodes = min(num_nodes, x_r.size(0))
            if size is not None:
                num_nodes = min(size[0], size[1])
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=(x_l, x_r),
                         alpha=(alpha_l, alpha_r), size=size)

    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def to_csc(
    adj: Adj,
    edge_attr: EdgeAttr,
    device: Optional[torch.device] = None,
    share_memory: bool = False,
) -> Tuple[Tensor, Tensor, OptTensor]:
    # Convert the graph data into a suitable format for sampling (CSC format).
    # Returns the `colptr` and `row` indices of the graph, as well as a
    # `perm` vector that denotes the permutation of edges.
    # Since no permutation of edges is applied when using `SparseTensor`,
    # `perm` can be of type `None`.
    perm: Optional[Tensor] = None

    layout = edge_attr.layout
    is_sorted = edge_attr.is_sorted
    size = edge_attr.size

    if layout == EdgeLayout.CSR:
        colptr, row, _ = adj.csc()
    elif layout == EdgeLayout.CSC:
        colptr, row, _ = adj.csr()
    else:
        if size is None:
            raise ValueError(
                f"Edge {edge_attr.edge_type} cannot be converted "
                f"to a different type without specifying 'size' for "
                f"the source and destination node types (got {size}). "
                f"Please specify these parameters for successful execution.")
        (row, col) = adj
        if not is_sorted:
            perm = (col * size[0]).add_(row).argsort()
            row = row[perm]
            col = col[perm]
        colptr = torch.ops.torch_sparse.ind2ptr(col, size[1])

    colptr = colptr.to(device)
    row = row.to(device)
    perm = perm.to(device) if perm is not None else None

    if not colptr.is_cuda and share_memory:
        colptr.share_memory_()
        row.share_memory_()
        if perm is not None:
            perm.share_memory_()

    return colptr, row, perm
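
# To illustrate what the COO branch above computes, here is a self-contained sketch of
# the same COO -> CSC conversion using only `bincount`/`cumsum` (no torch_sparse op),
# for a hypothetical 4-node graph:
import torch

num_nodes = 4
edge_index = torch.tensor([[0, 2, 1, 3],    # row (source)
                           [1, 0, 3, 1]])   # col (destination)
row, col = edge_index[0], edge_index[1]

# Sort edges by destination column (ties broken by row), as in the COO branch.
perm = (col * num_nodes + row).argsort()
row, col = row[perm], col[perm]

# Build the column pointer: colptr[c]:colptr[c + 1] indexes all edges pointing into c.
colptr = torch.zeros(num_nodes + 1, dtype=torch.long)
colptr[1:] = torch.bincount(col, minlength=num_nodes).cumsum(0)
# `row` is now the CSC row vector aligned with `colptr`, matching (colptr, row, perm).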
def homophily(edge_index: Adj, y: Tensor, method: str = 'edge'):
    r"""The homophily of a graph characterizes how likely nodes with the same
    label are near each other in a graph.

    There are many measures of homophily that fit this definition.
    In particular:

    - In the `"Beyond Homophily in Graph Neural Networks: Current Limitations
      and Effective Designs" <https://arxiv.org/abs/2006.11468>`_ paper, the
      homophily is the fraction of edges in a graph which connect nodes that
      have the same class label:

      .. math::
        \text{homophily} = \frac{| \{ (v,w) : (v,w) \in \mathcal{E} \wedge
        y_v = y_w \} | } {|\mathcal{E}|}

      That measure is called the *edge homophily ratio*.

    - In the `"Geom-GCN: Geometric Graph Convolutional Networks"
      <https://arxiv.org/abs/2002.05287>`_ paper, edge homophily is normalized
      across neighborhoods:

      .. math::
        \text{homophily} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}}
        \frac{ | \{ (w,v) : w \in \mathcal{N}(v) \wedge y_v = y_w \} | }
        { |\mathcal{N}(v)| }

      That measure is called the *node homophily ratio*.

    Args:
        edge_index (Tensor or SparseTensor): The graph connectivity.
        y (Tensor): The labels.
        method (str, optional): The method used to calculate the homophily,
            either :obj:`"edge"` (first formula) or :obj:`"node"`
            (second formula). (default: :obj:`"edge"`)
    """
    assert method in ['edge', 'node']
    y = y.squeeze(-1) if y.dim() > 1 else y

    if isinstance(edge_index, SparseTensor):
        col, row, _ = edge_index.coo()
    else:
        row, col = edge_index

    if method == 'edge':
        return int((y[row] == y[col]).sum()) / row.size(0)
    else:
        out = torch.zeros_like(row, dtype=float)
        out[y[row] == y[col]] = 1.
        out = scatter_mean(out, col, 0, dim_size=y.size(0))
        return float(out.mean())
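
# A small worked example of the edge homophily ratio, computed by hand with plain
# tensor operations (independent of the helper's import path):
import torch

edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])
y = torch.tensor([0, 0, 0, 1])

row, col = edge_index[0], edge_index[1]
same_label = (y[row] == y[col]).float()
edge_homophily = same_label.mean().item()  # 2 of the 4 edges join same-label nodes -> 0.5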
def forward(self, x: Tensor, edge_index: Adj, edge_type: OptTensor = None,
            edge_attr: OptTensor = None, size: Size = None,
            return_attention_weights=None):
    r"""
    Args:
        x (Tensor): The input node features. Can be either a
            :obj:`[num_nodes, in_channels]` node feature matrix, or an
            optional one-dimensional node index tensor (in which case input
            features are treated as trainable node embeddings).
        edge_index (LongTensor or SparseTensor): The edge indices.
        edge_type: The one-dimensional relation type/index for each edge in
            :obj:`edge_index`.
            Should be only :obj:`None` in case :obj:`edge_index` is of type
            :class:`torch_sparse.tensor.SparseTensor`.
            (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge feature matrix.
            (default: :obj:`None`)
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    # propagate_type: (x: Tensor, edge_type: OptTensor, edge_attr: OptTensor)  # noqa
    out = self.propagate(edge_index=edge_index, edge_type=edge_type, x=x,
                         size=size, edge_attr=edge_attr)

    alpha = self._alpha
    assert alpha is not None
    self._alpha = None

    if isinstance(return_attention_weights, bool):
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
    """"""
    row, col, edge_attr = edge_index.coo()
    edge_index = torch.stack([row, col], dim=0)

    # calculate edge weights [edge_attr -> edge_weight]
    encodings = F.one_hot(torch.flatten(edge_attr).long(),
                          num_classes=self.num_edge_types).double()
    alpha_m = (encodings * self.att_m).sum(dim=-1).reshape(-1, 1)

    alpha_l = (x * self.att_l).sum(dim=-1).reshape(-1, 1)
    alpha_r = (x * self.att_r).sum(dim=-1).reshape(-1, 1)

    # propagate_type: (x: Tensor, x_norm: Tensor)
    return self.propagate(edge_index, x=x, alpha=(alpha_l, alpha_r),
                          alpha_m=alpha_m, size=None)
def forward(self, x: Union[Tensor, List[Tensor]], edge_index: Adj,
            edge_attr: OptTensor = None,
            size: Union[None, List[int]] = None) -> Tensor:
    if edge_index.nelement() == 0:
        # No edges: fall back to an all-zero aggregation output on the same
        # device as the target node features.
        out = torch.zeros([
            x[1].shape[0],
            (len(self.aggregators) * len(self.scalers)) * self.F_in[0]
        ]).to(x[1].device)
    else:
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

    out = torch.cat([x[1], out], dim=-1)
    out = self.post_nns(out)
    return self.lin(out)
def message(self, x_i, x_j, edge_index: Adj, edge_attr: OptTensor,
            size) -> Tensor:
    if self.debug:
        print('a x_j:', x_j.shape, 'x_i:', x_i.shape,
              'edge_attr:', edge_attr.shape)

    if self.step == 0:
        x_i = F.leaky_relu(self.atom_fc(x_i))  # code 3

        # neighbor_feature => neighbor_fc
        x_j = torch.cat([x_j, edge_attr], dim=-1)  # code 8
        if self.debug:
            print('b neighbor_feature i = 0', x_j.shape)
        x_j = F.leaky_relu(self.neighbor_fc(x_j))  # code 9
        if self.debug:
            print('c neighbor_feature i = 0', x_j.shape)

    # align score
    evu = F.leaky_relu(self.align(torch.cat([x_i, x_j], dim=-1)))  # code 10
    if self.debug:
        print('d align_score:', evu.shape)
    avu = EdgePooling.compute_edge_score_softmax(
        evu, edge_index, edge_index.max().item() + 1)  # code 11
    if self.debug:
        print('e attention_weight:', avu.shape)

    # to do downscaling 200 fp => 32
    c_i = F.elu(torch.mul(avu, self.attend(self.dropout(x_i))))  # code 12
    # to do upscaling 32 => 200 fp
    if self.debug:
        print('f context', c_i.shape)
    x_i = self.rnn(c_i, x_i)
    if self.debug:
        print('g gru', c_i.shape)

    return x_i
def forward(self, x: Tensor, edge_index: Adj) -> Tuple[Tensor, SparseTensor]:
    adj_t: Optional[SparseTensor] = None
    if isinstance(edge_index, Tensor):
        adj_t = SparseTensor(row=edge_index[1], col=edge_index[0],
                             sparse_sizes=(x.size(0), x.size(0)))
    elif isinstance(edge_index, SparseTensor):
        adj_t = edge_index.set_value(None)
    assert adj_t is not None

    adj_t = self.panentropy(adj_t, dtype=x.dtype)

    deg = adj_t.storage.rowcount().to(x.dtype)
    deg_inv_sqrt = deg.pow_(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0.
    M = deg_inv_sqrt.view(1, -1) * adj_t * deg_inv_sqrt.view(-1, 1)

    out = self.propagate(M, x=x, edge_weight=None, size=None)
    out = self.lin(out)
    return out, M
def forward(self, edge_index: Adj,
            edge_label_index: OptTensor = None) -> Tensor:
    r"""Computes rankings for pairs of nodes.

    Args:
        edge_index (Tensor or SparseTensor): Edge tensor specifying the
            connectivity of the graph.
        edge_label_index (Tensor, optional): Edge tensor specifying the node
            pairs for which to compute rankings or probabilities.
            If :obj:`edge_label_index` is set to :obj:`None`, all edges in
            :obj:`edge_index` will be used instead. (default: :obj:`None`)
    """
    if edge_label_index is None:
        if isinstance(edge_index, SparseTensor):
            edge_label_index = torch.stack(edge_index.coo()[:2], dim=0)
        else:
            edge_label_index = edge_index

    out = self.get_embedding(edge_index)

    out_src = out[edge_label_index[0]]
    out_dst = out[edge_label_index[1]]
    return (out_src * out_dst).sum(dim=-1)
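
# A minimal usage sketch, assuming this is the ranking `forward` of a LightGCN-style
# model as in `torch_geometric.nn.models.LightGCN` (the constructor arguments below
# are assumptions about that class, not taken from the snippet):
import torch
from torch_geometric.nn import LightGCN

num_nodes = 6
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 3, 4, 5, 0]])

model = LightGCN(num_nodes=num_nodes, embedding_dim=16, num_layers=2)

# Score a few candidate (src, dst) pairs; larger values indicate stronger links.
edge_label_index = torch.tensor([[0, 0, 1],
                                 [2, 3, 4]])
rank = model(edge_index, edge_label_index)  # shape [3]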
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            edge_attr: OptTensor = None, size: Size = None,
            return_attention_weights=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, OptTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, OptTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, OptTensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, OptTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""
    Args:
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    # NOTE: attention weights will be returned whenever
    # `return_attention_weights` is set to a value, regardless of its
    # actual value (might be `True` or `False`). This is a current somewhat
    # hacky workaround to allow for TorchScript support via the
    # `torch.jit._overload` decorator, as we can only change the output
    # arguments conditioned on type (`None` or `bool`), not based on its
    # actual value.
    H, C = self.heads, self.out_channels

    # We first transform the input node features. If a tuple is passed, we
    # transform source and target node features via separate weights:
    if isinstance(x, Tensor):
        assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
        x_src = x_dst = self.lin_src(x).view(-1, H, C)
    else:  # Tuple of source and target node features:
        x_src, x_dst = x
        assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
        x_src = self.lin_src(x_src).view(-1, H, C)
        if x_dst is not None:
            x_dst = self.lin_dst(x_dst).view(-1, H, C)

    x = (x_src, x_dst)

    # Next, we compute node-level attention coefficients, both for source
    # and target nodes (if present):
    alpha_src = (x_src * self.att_src).sum(dim=-1)
    alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
    alpha = (alpha_src, alpha_dst)

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # We only want to add self-loops for nodes that appear both as
            # source and target nodes:
            num_nodes = x_src.size(0)
            if x_dst is not None:
                num_nodes = min(num_nodes, x_dst.size(0))
            num_nodes = min(size) if size is not None else num_nodes
            edge_index, edge_attr = remove_self_loops(
                edge_index, edge_attr)
            edge_index, edge_attr = add_self_loops(
                edge_index, edge_attr, fill_value=self.fill_value,
                num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            if self.edge_dim is None:
                edge_index = set_diag(edge_index)
            else:
                raise NotImplementedError(
                    "The usage of 'edge_attr' and 'add_self_loops' "
                    "simultaneously is currently not yet supported for "
                    "'edge_index' in a 'SparseTensor' form")

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor, edge_attr: OptTensor)  # noqa
    out = self.propagate(edge_index, x=x, alpha=alpha, edge_attr=edge_attr,
                         size=size)

    alpha = self._alpha
    assert alpha is not None
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
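
# A minimal usage sketch of the attention-weight return path above, assuming the layer
# is constructed like `torch_geometric.nn.GATConv`:
import torch
from torch_geometric.nn import GATConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

conv = GATConv(16, 8, heads=2)
out, (ei, alpha) = conv(x, edge_index, return_attention_weights=True)
# out: [4, 16] (heads concatenated); alpha: one weight per head for every edge,
# including the self-loops added inside forward().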
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None, layer=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""
    Args:
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_r = self.lin_l(x).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        alpha_r = (x_r * self.att_r).sum(dim=-1)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = self.lin_l(x_l).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)
            alpha_r = (x_r * self.att_r).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            num_nodes = x_l.size(0)
            if x_r is not None:
                num_nodes = min(num_nodes, x_r.size(0))
            if size is not None:
                num_nodes = min(size[0], size[1])
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    if layer is not None and layer > 1:
        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        L = math.floor(self.out_channels / layer)
        parts = []
        for i in range(layer - 1):
            parts.append(L)
        parts.append(x_l.shape[-1] - L * (layer - 1))
        parts = tuple(parts)
        xl = x_l.split(parts, dim=-1)
        xr = x_r.split(parts, dim=-1)
        out = self.propagate(edge_index, x=(xl[0], xr[0]),
                             alpha=(alpha_l, alpha_r), size=size)
        for i in range(1, layer):
            out1 = self.propagate(edge_index, x=(xl[i], xr[i]),
                                  alpha=(alpha_l, alpha_r), size=size)
            out = torch.cat((out, out1), dim=-1)
    else:
        out = self.propagate(edge_index, x=(x_l, x_r),
                             alpha=(alpha_l, alpha_r), size=size)

    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None,
            edge_weight=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""
    Args:
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_r = self.lin_l(x).view(-1, H, C)
        alpha_l = alpha_r = (x_l * self.att_l).sum(dim=-1)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = self.lin_l(x_l).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)
            alpha_r = (x_r * self.att_r).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            num_nodes = x_l.size(0)
            num_nodes = size[1] if size is not None else num_nodes
            num_nodes = x_r.size(0) if x_r is not None else num_nodes
            edge_index, edge_weight = remove_self_loops(
                edge_index, edge_attr=edge_weight)
            edge_index, edge_weight = add_self_loops(
                edge_index, edge_weight=edge_weight, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor, edge_weight: OptTensor)  # noqa
    # This is the modified part:
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=x_l.dtype,
                                 device=edge_index.device)
    out = self.propagate(edge_index, x=(x_l, x_r),
                         alpha=(alpha_l, alpha_r), size=size,
                         edge_weight=edge_weight)

    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
            edge_attr: OptTensor = None, size: Size = None,
            return_attention_weights: bool = None):
    # type: (Union[Tensor, PairTensor], Tensor, OptTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], Tensor, OptTensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""
    Args:
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2
        x_l = self.lin_l(x).view(-1, H, C)
        if self.share_weights:
            x_r = x_l
        else:
            x_r = self.lin_r(x).view(-1, H, C)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2
        x_l = self.lin_l(x_l).view(-1, H, C)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)

    assert x_l is not None
    assert x_r is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            num_nodes = x_l.size(0)
            if x_r is not None:
                num_nodes = min(num_nodes, x_r.size(0))
            if size is not None:
                num_nodes = min(size[0], size[1])
            edge_index, edge_attr = remove_self_loops(
                edge_index, edge_attr)
            edge_index, edge_attr = add_self_loops(
                edge_index, edge_attr, fill_value=self.fill_value,
                num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            if self.edge_dim is None:
                edge_index = set_diag(edge_index)
            else:
                raise NotImplementedError(
                    "The usage of 'edge_attr' and 'add_self_loops' "
                    "simultaneously is currently not yet supported for "
                    "'edge_index' in a 'SparseTensor' form")

    # propagate_type: (x: PairTensor, edge_attr: OptTensor)
    out = self.propagate(edge_index, x=(x_l, x_r), edge_attr=edge_attr,
                         size=size)

    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""
    Args:
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_r = self.lin_l(x).view(-1, H, C)
        # (num_nodes, heads, out_channels) * (1, heads, out_channels):
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        alpha_r = (x_r * self.att_r).sum(dim=-1)
        if not self.gi_by_gcn:
            alpha_gi = (x_r * self.global_importance_scorer).sum(dim=-1)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = self.lin_l(x_l).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)
            alpha_r = (x_r * self.att_r).sum(dim=-1)
            # if not self.gi_by_gcn:
            #     alpha_gi = (x_r * self.global_importance_scorer).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            num_nodes = x_l.size(0)
            if x_r is not None:
                num_nodes = min(num_nodes, x_r.size(0))
            if size is not None:
                num_nodes = min(size[0], size[1])
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    if self.gi_by_gcn:
        alpha_gi = self.global_importance_scorer(
            x_r, edge_index)  # (num_nodes, num_heads)
        # A_ij + GI_j, so GI should be broadcast in the same way as alpha_j;
        # wondering whether alpha_i and alpha_j could be obtained via the
        # same idea.

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=(x_l, x_r),
                         alpha=(alpha_l, alpha_r), size=size)
    # the base class' propagate() calls message()

    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def homophily(edge_index: Adj, y: Tensor, batch: OptTensor = None,
              method: str = 'edge') -> Union[float, Tensor]:
    r"""The homophily of a graph characterizes how likely nodes with the same
    label are near each other in a graph.

    There are many measures of homophily that fit this definition.
    In particular:

    - In the `"Beyond Homophily in Graph Neural Networks: Current Limitations
      and Effective Designs" <https://arxiv.org/abs/2006.11468>`_ paper, the
      homophily is the fraction of edges in a graph which connect nodes that
      have the same class label:

      .. math::
        \frac{| \{ (v,w) : (v,w) \in \mathcal{E} \wedge y_v = y_w \} | }
        {|\mathcal{E}|}

      That measure is called the *edge homophily ratio*.

    - In the `"Geom-GCN: Geometric Graph Convolutional Networks"
      <https://arxiv.org/abs/2002.05287>`_ paper, edge homophily is normalized
      across neighborhoods:

      .. math::
        \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \frac{ | \{ (w,v) :
        w \in \mathcal{N}(v) \wedge y_v = y_w \} | } { |\mathcal{N}(v)| }

      That measure is called the *node homophily ratio*.

    - In the `"Large-Scale Learning on Non-Homophilous Graphs: New Benchmarks
      and Strong Simple Methods" <https://arxiv.org/abs/2110.14446>`_ paper,
      edge homophily is modified to be insensitive to the number of classes
      and size of each class:

      .. math::
        \frac{1}{C-1} \sum_{k=1}^{C} \max \left(0, h_k -
        \frac{|\mathcal{C}_k|}{|\mathcal{V}|} \right),

      where :math:`C` denotes the number of classes, :math:`|\mathcal{C}_k|`
      denotes the number of nodes of class :math:`k`, and :math:`h_k` denotes
      the edge homophily ratio of nodes of class :math:`k`.
      Thus, that measure is called the *class insensitive edge homophily
      ratio*.

    Args:
        edge_index (Tensor or SparseTensor): The graph connectivity.
        y (Tensor): The labels.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
            each node to a specific example. (default: :obj:`None`)
        method (str, optional): The method used to calculate the homophily,
            either :obj:`"edge"` (first formula), :obj:`"node"` (second
            formula) or :obj:`"edge_insensitive"` (third formula).
            (default: :obj:`"edge"`)
    """
    assert method in {'edge', 'node', 'edge_insensitive'}
    y = y.squeeze(-1) if y.dim() > 1 else y

    if isinstance(edge_index, SparseTensor):
        row, col, _ = edge_index.coo()
    else:
        row, col = edge_index

    if method == 'edge':
        out = torch.zeros(row.size(0), device=row.device)
        out[y[row] == y[col]] = 1.
        if batch is None:
            return float(out.mean())
        else:
            dim_size = int(batch.max()) + 1
            return scatter_mean(out, batch[col], dim=0, dim_size=dim_size)

    elif method == 'node':
        out = torch.zeros(row.size(0), device=row.device)
        out[y[row] == y[col]] = 1.
        out = scatter_mean(out, col, 0, dim_size=y.size(0))
        if batch is None:
            return float(out.mean())
        else:
            return scatter_mean(out, batch, dim=0)

    elif method == 'edge_insensitive':
        assert y.dim() == 1
        num_classes = int(y.max()) + 1
        assert num_classes >= 2
        batch = torch.zeros_like(y) if batch is None else batch
        num_nodes = degree(batch, dtype=torch.int64)
        num_graphs = num_nodes.numel()
        batch = num_classes * batch + y

        h = homophily(edge_index, y, batch, method='edge')
        h = h.view(num_graphs, num_classes)

        counts = batch.bincount(minlength=num_classes * num_graphs)
        counts = counts.view(num_graphs, num_classes)
        proportions = counts / num_nodes.view(-1, 1)

        out = (h - proportions).clamp_(min=0).sum(dim=-1)
        out /= num_classes - 1
        return out if out.numel() > 1 else float(out)

    else:
        raise NotImplementedError
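
# A short usage sketch covering the three methods, assuming this is the
# `torch_geometric.utils.homophily` helper:
import torch
from torch_geometric.utils import homophily

edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])
y = torch.tensor([0, 0, 1, 1])

h_edge = homophily(edge_index, y, method='edge')             # 1.0: every edge joins same-label nodes
h_node = homophily(edge_index, y, method='node')             # 1.0 on this toy graph as well
h_ins = homophily(edge_index, y, method='edge_insensitive')  # class-insensitive variant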
def forward(self, x: Tensor, x_0: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None, return_attention_weights=None):
    # type: (Tensor, Tensor, Tensor, OptTensor, NoneType) -> Tensor  # noqa
    # type: (Tensor, Tensor, SparseTensor, OptTensor, NoneType) -> Tensor  # noqa
    # type: (Tensor, Tensor, Tensor, OptTensor, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Tensor, Tensor, SparseTensor, OptTensor, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""
    Args:
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    if self.normalize:
        if isinstance(edge_index, Tensor):
            assert edge_weight is None
            cache = self._cached_edge_index
            if cache is None:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, None, x.size(self.node_dim), False,
                    self.add_self_loops, dtype=x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
            else:
                edge_index, edge_weight = cache[0], cache[1]

        elif isinstance(edge_index, SparseTensor):
            assert not edge_index.has_value()
            cache = self._cached_adj_t
            if cache is None:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, None, x.size(self.node_dim), False,
                    self.add_self_loops, dtype=x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index
            else:
                edge_index = cache
    else:
        if isinstance(edge_index, Tensor):
            assert edge_weight is not None
        elif isinstance(edge_index, SparseTensor):
            assert edge_index.has_value()

    alpha_l = self.att_l(x)
    alpha_r = self.att_r(x)

    # propagate_type: (x: Tensor, alpha: PairTensor, edge_weight: OptTensor)  # noqa
    out = self.propagate(edge_index, x=x, alpha=(alpha_l, alpha_r),
                         edge_weight=edge_weight, size=None)

    alpha = self._alpha
    self._alpha = None

    if self.eps != 0.0:
        out += self.eps * x_0

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def forward(self, x: Union[Tensor, PairTensor], edge_attr: Tensor,
            edge_index: Adj, size: Size = None,
            return_attention_weights: bool = None,
            propagate_messages: bool = True):
    # type: (Union[Tensor, PairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""
    Args:
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    H, C, E = self.heads, self.node_out_channels, self.edge_out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2
        x_l = self.lin_l(x).view(-1, H, C)
        if self.share_weights:
            x_r = x_l
        else:
            x_r = self.lin_r(x).view(-1, H, C)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2
        x_l = self.lin_l(x_l).view(-1, H, C)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)

    assert x_l is not None
    assert x_r is not None

    edge_attr = self.lin_e(edge_attr).view(-1, H, E)

    if propagate_messages:
        # propagate_type: (x: PairTensor)
        out = self.propagate(edge_index, x=(x_l, x_r), edge_ij=edge_attr,
                             size=size)
    else:
        out = x_l

    alpha = self._alpha
    self._alpha = None

    if self.concat:
        if self.pna and not propagate_messages:
            out = out.repeat(1, 1,
                             (len(self.aggregators) * len(self.scalers)))
        out = out.view(-1, self.heads * self.node_pna_channels)
        out = self.node_lin_out(out)
        edge_attr = edge_attr.view(-1, self.heads * self.edge_out_channels)
        edge_attr = self.edge_lin_out(edge_attr)
    else:
        out = out.mean(dim=1)
        edge_attr = edge_attr.mean(dim=1)

    if self.node_bias is not None:
        out += self.node_bias
    if self.edge_bias is not None:
        out += self.edge_bias
    if self.pna:
        out = self.node_lin_out(out)

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, edge_attr, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_attr, edge_index.set_value(alpha, layout='coo')
    else:
        return out, edge_attr
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, cmd,
            batch, edge_attr=None, size: Size = None,
            return_attention_weights=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""
    Args:
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        # print(x.shape)
        x_l = self.lin_l(x).view(-1, H, C)
        x_r = self.lin_r(x).view(-1, H, C)
        # print(x_l.shape)
        # sleep
        proj_cmd = self.proj_cmd(cmd)  # (bs, H * C)
        cal_cmd = self.cal_cmd(cmd)  # (bs, H * C)
        batch_onehot = F.one_hot(batch).float()  # (num_node, bs)
        # print(batch_onehot.shape, proj_cmd.shape)
        proj_cmd = batch_onehot.matmul(proj_cmd).view(-1, H, C)  # (num_node, H, C)
        cal_cmd = batch_onehot.matmul(cal_cmd).view(-1, H, C)  # (num_node, H, C)
        x_mul = proj_cmd * x_r
        alpha_l = x_l
        alpha_r = x_mul
    # else:
    #     x_l, x_r = x[0], x[1]
    #     assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
    #     x_l = self.lin_l(x_l).view(-1, H, C)
    #     alpha_l = (x_l * self.att_l).sum(dim=-1)
    #     if x_r is not None:
    #         x_r = self.lin_r(x_r).view(-1, H, C)
    #         alpha_r = (x_r * self.att_r).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    # # for edge features:
    # e = self.lin_e(edge_attr).view(-1, H, C)
    # alpha_e = (e * self.att_e).sum(dim=-1)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=x, cmd=cmd,
                         cal_cmd=(cal_cmd, cal_cmd),
                         alpha=(alpha_l, alpha_r), size=size)

    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out