def forward(self, x: Tensor, edge_index: Adj,
            edge_attr: OptTensor = None) -> Tensor:
    """Propagate node features using edge weights produced by ``self.mlp``.

    Args:
        x: Node features; passed through ``self.lin`` before propagation.
        edge_index: Connectivity, either a ``Tensor`` or a ``SparseTensor``.
        edge_attr: Optional edge features. Replaced by the stored values of
            ``edge_index`` when that is a ``SparseTensor``.

    Returns:
        The propagated features, with ``self.bias`` added when it is set.
    """
    sparse_input = isinstance(edge_index, SparseTensor)

    # A SparseTensor carries its edge features as stored values.
    if sparse_input:
        edge_attr = edge_index.storage.value()

    if edge_attr is not None:
        # Map the edge features through the MLP and drop the last dimension.
        edge_attr = self.mlp(edge_attr).squeeze(-1)
        if sparse_input:
            edge_index = edge_index.set_value(edge_attr, layout='coo')

    if self.normalize:
        num_nodes = x.size(self.node_dim)
        if isinstance(edge_index, Tensor):
            edge_index, edge_attr = gcn_norm(edge_index, edge_attr,
                                             num_nodes, False,
                                             self.add_self_loops)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(edge_index, None, num_nodes, False,
                                  self.add_self_loops)

    x = self.lin(x)

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_attr, size=None)

    if self.bias is not None:
        out += self.bias
    return out
def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
    """Propagate with inverse-degree edge weights and boosted self-loops.

    Every edge is weighted by the inverse (clamped) degree of the node it is
    indexed by; self-loop edges additionally receive
    ``self.diag_lambda * deg_inv`` of their own node.

    Args:
        x: Node features.
        edge_index: Connectivity, either a ``Tensor`` (``[row, col]``) or a
            ``SparseTensor``.

    Returns:
        ``self.lin_out(propagated) + self.lin_root(x)``.
    """
    edge_weight: OptTensor = None
    if isinstance(edge_index, Tensor):
        num_nodes = x.size(self.node_dim)
        if self.add_self_loops:
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        row, col = edge_index[0], edge_index[1]
        deg_inv = 1. / degree(col, num_nodes=num_nodes).clamp_(1.)

        edge_weight = deg_inv[col]
        # Look up `deg_inv` per self-loop edge instead of adding the whole
        # vector: the latter only lines up when every node has exactly one
        # self-loop in node order (e.g. `add_self_loops=False` would fail).
        loop_mask = row == col
        edge_weight[loop_mask] += self.diag_lambda * deg_inv[col[loop_mask]]

    elif isinstance(edge_index, SparseTensor):
        if self.add_self_loops:
            edge_index = set_diag(edge_index)

        col, row, _ = edge_index.coo()  # Transposed.
        deg_inv = 1. / sparsesum(edge_index, dim=1).clamp_(1.)

        edge_weight = deg_inv[col]
        # Same per-edge lookup as in the dense branch above.
        loop_mask = row == col
        edge_weight[loop_mask] += self.diag_lambda * deg_inv[col[loop_mask]]
        edge_index = edge_index.set_value(edge_weight, layout='coo')

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)
    out = self.lin_out(out) + self.lin_root(x)

    return out
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            edge_type: Optional[torch.Tensor] = None, size: Size = None,
            return_attention_weights=None):
    r"""Run edge-type aware attention over ``edge_index``.

    Args:
        x: Node features. ``x.dim()`` is asserted to be 2 and ``x`` is
            reshaped to ``(-1, heads, out_channels)``.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).
        edge_type: Optional per-edge type tensor, forwarded to
            :meth:`propagate`.
        size: Optional size forwarded to :meth:`propagate`.
        return_attention_weights (bool, optional): If a ``bool`` is given
            (regardless of its value), the attention weights read from
            ``self._alpha`` are returned in addition to the output, as
            ``(edge_index, alpha)`` for a ``Tensor`` ``edge_index`` or as a
            ``SparseTensor`` holding ``alpha`` otherwise.
            (default: :obj:`None`)
    """
    assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
    heads, channels = self.heads, self.out_channels
    x = x.view(-1, heads, channels)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=x, edge_type=edge_type, size=size)

    # Take the attention weights left on the instance and clear them.
    alpha = self._alpha
    self._alpha = None

    out = out.view(-1, heads * channels) if self.concat else out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if not isinstance(return_attention_weights, bool):
        return out

    assert alpha is not None
    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
            edge_attr: OptTensor = None, return_attention_weights=None):
    # type: (Union[Tensor, PairTensor], Tensor, OptTensor, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], Tensor, OptTensor, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Run query/key/value attention over ``edge_index``.

    Args:
        x: Node features, or a ``(source, target)`` pair. A single tensor is
            used for both roles.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).
        edge_attr: Optional edge features forwarded to :meth:`propagate`.
        return_attention_weights (bool, optional): If a ``bool`` is given
            (regardless of its value), the attention weights read from
            ``self._alpha`` are returned in addition to the output, as
            ``(edge_index, alpha)`` for a ``Tensor`` ``edge_index`` or as a
            ``SparseTensor`` holding ``alpha`` otherwise.
            (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    if isinstance(x, Tensor):
        x: PairTensor = (x, x)

    # Queries come from the second entry of the pair, keys and values from
    # the first.
    query = self.lin_query(x[1]).view(-1, H, C)
    key = self.lin_key(x[0]).view(-1, H, C)
    value = self.lin_value(x[0]).view(-1, H, C)

    # propagate_type: (query: Tensor, key:Tensor, value: Tensor, edge_attr: OptTensor) # noqa
    out = self.propagate(edge_index, query=query, key=key, value=value,
                         edge_attr=edge_attr, size=None)

    # Take the attention weights left on the instance and clear them.
    alpha = self._alpha
    self._alpha = None

    out = out.view(-1, H * C) if self.concat else out.mean(dim=1)

    if self.root_weight:
        skip = self.lin_skip(x[1])
        if self.lin_beta is None:
            out += skip
        else:
            # Gated mix of the skip connection and the aggregated messages.
            gate = self.lin_beta(torch.cat([out, skip, out - skip], dim=-1))
            gate = gate.sigmoid()
            out = gate * skip + (1 - gate) * out

    if not isinstance(return_attention_weights, bool):
        return out

    assert alpha is not None
    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None):
    r"""Run the attention layer on linearly transformed node features.

    Args:
        x: Node features (2-D tensor) or a ``(source, target)`` pair whose
            target entry may be ``None``.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).
        size: Optional size forwarded to :meth:`propagate`; when given, it
            also determines how many self-loops are added.
        return_attention_weights (bool, optional): If a ``bool`` is given
            (regardless of its value), the attention weights read from
            ``self._alpha`` are returned in addition to the output, as
            ``(edge_index, alpha)`` for a ``Tensor`` ``edge_index`` or as a
            ``SparseTensor`` holding ``alpha`` otherwise.
            (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        # A single input shares one transformed tensor for both roles.
        x_l = x_r = self.lin_l(x).view(-1, H, C)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = self.lin_l(x_l).view(-1, H, C)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)

    assert x_l is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # An explicit `size` wins; otherwise use the smaller node set.
            if size is not None:
                num_nodes = min(size[0], size[1])
            elif x_r is not None:
                num_nodes = min(x_l.size(0), x_r.size(0))
            else:
                num_nodes = x_l.size(0)
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=(x_l, x_r), size=size)

    # Take the attention weights left on the instance and clear them.
    alpha = self._alpha
    self._alpha = None

    out = out.view(-1, H * C) if self.concat else out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if not isinstance(return_attention_weights, bool):
        return out

    assert alpha is not None
    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """Run ``self.K`` propagation steps, mixing the input back in each step.

    Each step computes ``x = (1 - alpha) * propagate(x) + alpha * x_input``.

    Args:
        x: Node features.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).
        edge_weight: Optional edge weights for a ``Tensor`` ``edge_index``.

    Returns:
        The node features after ``self.K`` steps.
    """
    if self.normalize:
        if isinstance(edge_index, Tensor):
            cached_pair = self._cached_edge_index
            if cached_pair is not None:
                edge_index, edge_weight = cached_pair[0], cached_pair[1]
            else:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, self.flow, dtype=x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)

        elif isinstance(edge_index, SparseTensor):
            cached_adj = self._cached_adj_t
            if cached_adj is not None:
                edge_index = cached_adj
            else:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, self.flow, dtype=x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index

    x_input = x
    for _ in range(self.K):
        if self.dropout > 0 and self.training:
            # The dropped weights replace the current ones, so they are
            # carried over into the following iterations.
            if isinstance(edge_index, Tensor):
                assert edge_weight is not None
                edge_weight = F.dropout(edge_weight, p=self.dropout)
            else:
                stored = edge_index.storage.value()
                assert stored is not None
                stored = F.dropout(stored, p=self.dropout)
                edge_index = edge_index.set_value(stored, layout='coo')

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                           size=None)
        x = x * (1 - self.alpha)
        x += self.alpha * x_input

    return x
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None):
    r"""Run attention directly on the (reshaped) input features.

    No linear transform is applied; the inputs are only viewed as
    ``(-1, heads, out_channels)`` before the attention scores are computed.

    Args:
        x: Node features (2-D tensor) or a ``(source, target)`` pair whose
            target entry may be ``None``.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).
        size: Optional size forwarded to :meth:`propagate`; when given, it
            also determines how many self-loops are added.
        return_attention_weights (bool, optional): If a ``bool`` is given
            (regardless of its value), the attention weights read from
            ``self._alpha`` are returned in addition to the output.
            (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_r = x.view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        alpha_r = (x_r * self.att_r).sum(dim=-1)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_l.view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        if x_r is not None:
            x_r = x_r.view(-1, H, C)
            alpha_r = (x_r * self.att_r).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # An explicit `size` wins; otherwise use the smaller node set.
            if size is not None:
                num_nodes = min(size[0], size[1])
            elif x_r is not None:
                num_nodes = min(x_l.size(0), x_r.size(0))
            else:
                num_nodes = x_l.size(0)
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=(x_l, x_r),
                         alpha=(alpha_l, alpha_r), size=size)

    # Take the attention weights left on the instance and clear them.
    alpha = self._alpha
    self._alpha = None

    out = out.view(-1, H * C) if self.concat else out.mean(dim=1)

    if not isinstance(return_attention_weights, bool):
        return out

    assert alpha is not None
    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(self, x: Tensor, edge_index: Adj, edge_type: OptTensor = None,
            edge_attr: OptTensor = None, size: Size = None,
            return_attention_weights=None):
    r"""Propagate over typed edges and optionally return attention weights.

    Args:
        x (Tensor): The input node features. Can be either a
            :obj:`[num_nodes, in_channels]` node feature matrix, or an
            optional one-dimensional node index tensor (in which case
            input features are treated as trainable node embeddings).
        edge_index (LongTensor or SparseTensor): The edge indices.
        edge_type: The one-dimensional relation type/index for each edge in
            :obj:`edge_index`. Should be only :obj:`None` in case
            :obj:`edge_index` is of type
            :class:`torch_sparse.tensor.SparseTensor`.
            (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge feature matrix.
            (default: :obj:`None`)
        size: Optional size forwarded to :meth:`propagate`.
        return_attention_weights (bool, optional): If a ``bool`` is given
            (regardless of its value), the attention weights read from
            ``self._alpha`` are returned in addition to the output, as
            ``(edge_index, alpha)`` for a ``Tensor`` ``edge_index`` or as a
            ``SparseTensor`` holding ``alpha`` otherwise.
            (default: :obj:`None`)
    """
    # propagate_type: (x: Tensor, edge_type: OptTensor, edge_attr: OptTensor)  # noqa
    out = self.propagate(edge_index=edge_index, edge_type=edge_type, x=x,
                         size=size, edge_attr=edge_attr)

    # The attention weights must be present after propagation; they are
    # cleared from the instance once read.
    alpha = self._alpha
    assert alpha is not None
    self._alpha = None

    if not isinstance(return_attention_weights, bool):
        return out

    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(self, x: Tensor,
            edge_index: Adj) -> Tuple[Tensor, SparseTensor]:
    """Propagate over a degree-normalized adjacency built from ``edge_index``.

    The (value-free) adjacency is passed through ``self.panentropy`` and
    then scaled on both sides by the inverse square root of its per-row
    entry counts.

    Args:
        x: Node features.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).

    Returns:
        A tuple of ``self.lin(propagated)`` and the normalized adjacency that
        was used for propagation.
    """
    adj_t: Optional[SparseTensor] = None
    if isinstance(edge_index, Tensor):
        num_nodes = x.size(0)
        # Row/col are swapped on purpose: the adjacency is built transposed.
        adj_t = SparseTensor(row=edge_index[1], col=edge_index[0],
                             sparse_sizes=(num_nodes, num_nodes))
    elif isinstance(edge_index, SparseTensor):
        adj_t = edge_index.set_value(None)
    assert adj_t is not None

    adj_t = self.panentropy(adj_t, dtype=x.dtype)

    row_count = adj_t.storage.rowcount().to(x.dtype)
    inv_sqrt = row_count.pow_(-0.5)
    # Rows without entries would otherwise contribute an infinite scale.
    inv_sqrt.masked_fill_(inv_sqrt == float('inf'), 0.)

    norm_adj = inv_sqrt.view(1, -1) * adj_t * inv_sqrt.view(-1, 1)

    out = self.propagate(norm_adj, x=x, edge_weight=None, size=None)
    out = self.lin(out)
    return out, norm_adj
def forward(self, x: Union[Tensor, PairTensor], edge_attr: Tensor,
            edge_index: Adj, size: Size = None,
            return_attention_weights: bool = None,
            propagate_messages: bool = True):
    # type: (Union[Tensor, PairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Update node and edge features with attention.

    Args:
        x: Node features (2-D tensor) or a ``(source, target)`` pair.
        edge_attr: Edge features; transformed by ``self.lin_e``.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).
        size: Optional size forwarded to :meth:`propagate`.
        return_attention_weights (bool, optional): If a ``bool`` is given
            (regardless of its value), the attention weights read from
            ``self._alpha`` are returned as a third element, as
            ``(edge_index, alpha)`` for a ``Tensor`` ``edge_index`` or as a
            ``SparseTensor`` holding ``alpha`` otherwise.
            (default: :obj:`None`)
        propagate_messages: When ``False``, :meth:`propagate` is skipped and
            the transformed source features are used as the node output.

    Returns:
        ``(out, edge_attr)``, or ``(out, edge_attr, attention)`` when
        ``return_attention_weights`` is a ``bool``.
    """
    # NOTE(review): the nesting below was reconstructed from a
    # whitespace-collapsed source; confirm that the bias / `self.pna` steps
    # are meant to run after (not inside) the concat/mean branches.
    H, C, E = self.heads, self.node_out_channels, self.edge_out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2
        x_l = self.lin_l(x).view(-1, H, C)
        if self.share_weights:
            x_r = x_l
        else:
            x_r = self.lin_r(x).view(-1, H, C)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2
        x_l = self.lin_l(x_l).view(-1, H, C)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)

    assert x_l is not None
    assert x_r is not None

    edge_attr = self.lin_e(edge_attr).view(-1, H, E)

    if propagate_messages:
        # propagate_type: (x: PairTensor)
        out = self.propagate(edge_index, x=(x_l, x_r), edge_ij=edge_attr,
                             size=size)
    else:
        out = x_l

    # Take the attention weights left on the instance and clear them.
    alpha = self._alpha
    self._alpha = None

    if self.concat:
        if self.pna and not propagate_messages:
            # Without propagation, tile the features to the width that the
            # aggregator/scaler combinations would have produced.
            out = out.repeat(
                1, 1, (len(self.aggregators) * len(self.scalers)))
        out = out.view(-1, self.heads * self.node_pna_channels)
        out = self.node_lin_out(out)
        edge_attr = edge_attr.view(-1, self.heads * self.edge_out_channels)
        edge_attr = self.edge_lin_out(edge_attr)
    else:
        out = out.mean(dim=1)
        edge_attr = edge_attr.mean(dim=1)

    if self.node_bias is not None:
        out += self.node_bias
    if self.edge_bias is not None:
        # The edge bias belongs on the edge output; it was previously added
        # to the node output, leaving `edge_attr` without any bias.
        edge_attr += self.edge_bias

    if self.pna:
        out = self.node_lin_out(out)

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, edge_attr, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_attr, edge_index.set_value(alpha, layout='coo')
    else:
        return out, edge_attr
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            edge_attr, size: Size = None, return_attention_weights=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Run attention with an additional per-edge attention term.

    Args:
        x: Node features (2-D tensor) or a ``(source, target)`` pair whose
            target entry may be ``None``.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).
        edge_attr: Edge features; transformed by ``self.lin_e`` and scored
            against ``self.att_e``.
        size: Optional size forwarded to :meth:`propagate`.
        return_attention_weights (bool, optional): If a ``bool`` is given
            (regardless of its value), the attention weights read from
            ``self._alpha`` are returned in addition to the output, as
            ``(edge_index, alpha)`` for a ``Tensor`` ``edge_index`` or as a
            ``SparseTensor`` holding ``alpha`` otherwise.
            (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        # A single input shares one transformed tensor for both roles.
        x_l = x_r = self.lin_l(x).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        alpha_r = (x_r * self.att_r).sum(dim=-1)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = self.lin_l(x_l).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)
            alpha_r = (x_r * self.att_r).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    # Edge features get their own attention score.
    edge_feat = self.lin_e(edge_attr).view(-1, H, C)
    alpha_e = (edge_feat * self.att_e).sum(dim=-1)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=(x_l, x_r),
                         alpha=(alpha_l, alpha_r), alpha_e=alpha_e,
                         size=size)

    # Take the attention weights left on the instance and clear them.
    alpha = self._alpha
    self._alpha = None

    out = out.view(-1, H * C) if self.concat else out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if not isinstance(return_attention_weights, bool):
        return out

    assert alpha is not None
    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            edge_attr: OptTensor = None, size: Size = None,
            return_attention_weights=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, OptTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, OptTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, OptTensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, OptTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Run graph attention, optionally using edge features.

    Args:
        x: Node features (2-D tensor) or a ``(source, target)`` pair whose
            target entry may be ``None``.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).
        edge_attr: Optional edge features forwarded to :meth:`propagate`.
        size: Optional size forwarded to :meth:`propagate`; when given, it
            also determines how many self-loops are added.
        return_attention_weights (bool, optional): If a ``bool`` is given,
            the attention weights read from ``self._alpha`` are returned in
            addition to the output, as ``(edge_index, alpha)`` for a
            ``Tensor`` ``edge_index`` or as a ``SparseTensor`` holding
            ``alpha`` otherwise. (default: :obj:`None`)

    Raises:
        NotImplementedError: If self-loops are requested for a
            ``SparseTensor`` ``edge_index`` while ``self.edge_dim`` is set.
    """
    # NOTE: the extra return value is keyed on the *type* of
    # `return_attention_weights` (`None` vs. `bool`), not on its value, so
    # passing `False` also returns the weights. This is a workaround for
    # TorchScript support via `torch.jit._overload`, where the output
    # signature can only depend on argument types.

    heads, channels = self.heads, self.out_channels

    # Transform the node features; a tuple gets separate weights for its
    # source and target entries.
    if isinstance(x, Tensor):
        assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
        x_src = x_dst = self.lin_src(x).view(-1, heads, channels)
    else:
        x_src, x_dst = x
        assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
        x_src = self.lin_src(x_src).view(-1, heads, channels)
        if x_dst is not None:
            x_dst = self.lin_dst(x_dst).view(-1, heads, channels)

    x = (x_src, x_dst)

    # Node-level attention coefficients for source and (if present) target.
    alpha_src = (x_src * self.att_src).sum(dim=-1)
    if x_dst is None:
        alpha_dst = None
    else:
        alpha_dst = (x_dst * self.att_dst).sum(-1)
    alpha = (alpha_src, alpha_dst)

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # Self-loops are only added for nodes that appear both as
            # source and as target; an explicit `size` takes precedence.
            if size is not None:
                num_nodes = min(size)
            elif x_dst is not None:
                num_nodes = min(x_src.size(0), x_dst.size(0))
            else:
                num_nodes = x_src.size(0)
            edge_index, edge_attr = remove_self_loops(
                edge_index, edge_attr)
            edge_index, edge_attr = add_self_loops(
                edge_index, edge_attr, fill_value=self.fill_value,
                num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            if self.edge_dim is not None:
                raise NotImplementedError(
                    "The usage of 'edge_attr' and 'add_self_loops' "
                    "simultaneously is currently not yet supported for "
                    "'edge_index' in a 'SparseTensor' form")
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor, edge_attr: OptTensor)  # noqa
    out = self.propagate(edge_index, x=x, alpha=alpha, edge_attr=edge_attr,
                         size=size)

    # The attention weights must be present after propagation; they are
    # cleared from the instance once read.
    alpha = self._alpha
    assert alpha is not None
    self._alpha = None

    if self.concat:
        out = out.view(-1, heads * channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if not isinstance(return_attention_weights, bool):
        return out

    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None,
            edge_weight=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Run graph attention with an additional per-edge weight.

    Args:
        x: Node features (2-D tensor) or a ``(source, target)`` pair whose
            target entry may be ``None``.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).
        size: Optional size forwarded to :meth:`propagate`.
        return_attention_weights (bool, optional): If a ``bool`` is given
            (regardless of its value), the attention weights read from
            ``self._alpha`` are returned in addition to the output, as
            ``(edge_index, alpha)`` for a ``Tensor`` ``edge_index`` or as a
            ``SparseTensor`` holding ``alpha`` otherwise.
            (default: :obj:`None`)
        edge_weight: Optional per-edge weights forwarded to
            :meth:`propagate`. Defaults to a vector of ones with one entry
            per edge.
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_r = self.lin_l(x).view(-1, H, C)
        # NOTE(review): both coefficients are computed with `att_l` here,
        # whereas the tuple branch uses `att_r` for the target side.
        # Left unchanged; confirm this is intended.
        alpha_l = alpha_r = (x_l * self.att_l).sum(dim=-1)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = self.lin_l(x_l).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)
            alpha_r = (x_r * self.att_r).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            num_nodes = x_l.size(0)
            num_nodes = size[1] if size is not None else num_nodes
            num_nodes = x_r.size(0) if x_r is not None else num_nodes
            edge_index, edge_weight = remove_self_loops(
                edge_index, edge_attr=edge_weight)
            edge_index, edge_weight = add_self_loops(
                edge_index, edge_weight=edge_weight, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    if edge_weight is None:
        # Default to unit weights. `x` may be a tuple (no `.dtype`), and a
        # SparseTensor reports its edge count via `nnz()` and its device via
        # a method, so both cases are handled explicitly.
        feats = x if isinstance(x, Tensor) else x[0]
        if isinstance(edge_index, SparseTensor):
            num_edges = edge_index.nnz()
            device = edge_index.device()
        else:
            num_edges = edge_index.size(1)
            device = edge_index.device
        edge_weight = torch.ones((num_edges, ), dtype=feats.dtype,
                                 device=device)

    out = self.propagate(edge_index, x=(x_l, x_r),
                         alpha=(alpha_l, alpha_r), size=size,
                         edge_weight=edge_weight)

    # Take the attention weights left on the instance and clear them.
    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None, layer=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Run graph attention, optionally propagating channel groups separately.

    Args:
        x: Node features (2-D tensor) or a ``(source, target)`` pair whose
            target entry may be ``None``.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).
        size: Optional size forwarded to :meth:`propagate`; when given, it
            also determines how many self-loops are added.
        return_attention_weights (bool, optional): If a ``bool`` is given
            (regardless of its value), the attention weights read from
            ``self._alpha`` are returned in addition to the output, as
            ``(edge_index, alpha)`` for a ``Tensor`` ``edge_index`` or as a
            ``SparseTensor`` holding ``alpha`` otherwise.
            (default: :obj:`None`)
        layer (int, optional): When greater than 1, the channel dimension is
            split into ``layer`` groups (the last group takes the
            remainder), each group is propagated on its own and the results
            are concatenated along the channel dimension.
            (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_r = self.lin_l(x).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        alpha_r = (x_r * self.att_r).sum(dim=-1)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = self.lin_l(x_l).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)
            alpha_r = (x_r * self.att_r).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            num_nodes = x_l.size(0)
            if x_r is not None:
                num_nodes = min(num_nodes, x_r.size(0))
            if size is not None:
                num_nodes = min(size[0], size[1])
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    if layer is not None and layer > 1:
        # Equal-width groups, with the last one absorbing the remainder.
        width = self.out_channels // layer
        parts = [width] * (layer - 1)
        parts.append(x_l.shape[-1] - width * (layer - 1))

        xl = x_l.split(parts, dim=-1)
        # A missing target tensor is passed on as `None` for every group,
        # matching what the unsplit path hands to `propagate`.
        if x_r is not None:
            xr = x_r.split(parts, dim=-1)
        else:
            xr = (None, ) * layer

        # Each group sees the same attention inputs; `self._alpha` ends up
        # holding the value left by the last call.
        group_outs = [
            self.propagate(edge_index, x=(xl[i], xr[i]),
                           alpha=(alpha_l, alpha_r), size=size)
            for i in range(layer)
        ]
        out = torch.cat(group_outs, dim=-1)
    else:
        out = self.propagate(edge_index, x=(x_l, x_r),
                             alpha=(alpha_l, alpha_r), size=size)

    # Take the attention weights left on the instance and clear them.
    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def forward(self, x: Tensor, x_0: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None, return_attention_weights=None):
    # type: (Tensor, Tensor, Tensor, OptTensor, NoneType) -> Tensor  # noqa
    # type: (Tensor, Tensor, SparseTensor, OptTensor, NoneType) -> Tensor  # noqa
    # type: (Tensor, Tensor, Tensor, OptTensor, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Tensor, Tensor, SparseTensor, OptTensor, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Run attention-weighted propagation with a residual to ``x_0``.

    Args:
        x: Current node features.
        x_0: Features added back as ``self.eps * x_0`` when ``self.eps`` is
            non-zero.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).
        edge_weight: Edge weights. Must be ``None`` when ``self.normalize``
            is set (weights then come from ``gcn_norm``) and must be given
            otherwise; a ``SparseTensor`` must accordingly be value-free or
            carry values.
        return_attention_weights (bool, optional): If a ``bool`` is given
            (regardless of its value), the attention weights read from
            ``self._alpha`` are returned in addition to the output, as
            ``(edge_index, alpha)`` for a ``Tensor`` ``edge_index`` or as a
            ``SparseTensor`` holding ``alpha`` otherwise.
            (default: :obj:`None`)
    """
    if not self.normalize:
        # Without normalization the caller has to supply the weights.
        if isinstance(edge_index, Tensor):
            assert edge_weight is not None
        elif isinstance(edge_index, SparseTensor):
            assert edge_index.has_value()
    elif isinstance(edge_index, Tensor):
        assert edge_weight is None
        cached_pair = self._cached_edge_index
        if cached_pair is not None:
            edge_index, edge_weight = cached_pair[0], cached_pair[1]
        else:
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, None, x.size(self.node_dim), False,
                self.add_self_loops, dtype=x.dtype)
            if self.cached:
                self._cached_edge_index = (edge_index, edge_weight)
    elif isinstance(edge_index, SparseTensor):
        assert not edge_index.has_value()
        cached_adj = self._cached_adj_t
        if cached_adj is not None:
            edge_index = cached_adj
        else:
            edge_index = gcn_norm(  # yapf: disable
                edge_index, None, x.size(self.node_dim), False,
                self.add_self_loops, dtype=x.dtype)
            if self.cached:
                self._cached_adj_t = edge_index

    alpha_l = self.att_l(x)
    alpha_r = self.att_r(x)

    # propagate_type: (x: Tensor, alpha: PairTensor, edge_weight: OptTensor)  # noqa
    out = self.propagate(edge_index, x=x, alpha=(alpha_l, alpha_r),
                         edge_weight=edge_weight, size=None)

    # Take the attention weights left on the instance and clear them.
    alpha = self._alpha
    self._alpha = None

    if self.eps != 0.0:
        out += self.eps * x_0

    if not isinstance(return_attention_weights, bool):
        return out

    assert alpha is not None
    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
            edge_attr: OptTensor = None, size: Size = None,
            return_attention_weights: bool = None):
    # type: (Union[Tensor, PairTensor], Tensor, OptTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], Tensor, OptTensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Run the attention layer on separately transformed source/target
    features.

    Args:
        x: Node features (2-D tensor) or a ``(source, target)`` pair. Both
            transformed tensors must be present after the transform.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).
        edge_attr: Optional edge features forwarded to :meth:`propagate`.
        size: Optional size forwarded to :meth:`propagate`; when given, it
            also determines how many self-loops are added.
        return_attention_weights (bool, optional): If a ``bool`` is given
            (regardless of its value), the attention weights read from
            ``self._alpha`` are returned in addition to the output, as
            ``(edge_index, alpha)`` for a ``Tensor`` ``edge_index`` or as a
            ``SparseTensor`` holding ``alpha`` otherwise.
            (default: :obj:`None`)

    Raises:
        NotImplementedError: If self-loops are requested for a
            ``SparseTensor`` ``edge_index`` while ``self.edge_dim`` is set.
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2
        x_l = self.lin_l(x).view(-1, H, C)
        # With shared weights the target side reuses the source transform.
        x_r = x_l if self.share_weights else self.lin_r(x).view(-1, H, C)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2
        x_l = self.lin_l(x_l).view(-1, H, C)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)

    assert x_l is not None
    assert x_r is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # An explicit `size` wins; otherwise use the smaller node set.
            if size is not None:
                num_nodes = min(size[0], size[1])
            elif x_r is not None:
                num_nodes = min(x_l.size(0), x_r.size(0))
            else:
                num_nodes = x_l.size(0)
            edge_index, edge_attr = remove_self_loops(
                edge_index, edge_attr)
            edge_index, edge_attr = add_self_loops(
                edge_index, edge_attr, fill_value=self.fill_value,
                num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            if self.edge_dim is not None:
                raise NotImplementedError(
                    "The usage of 'edge_attr' and 'add_self_loops' "
                    "simultaneously is currently not yet supported for "
                    "'edge_index' in a 'SparseTensor' form")
            edge_index = set_diag(edge_index)

    # propagate_type: (x: PairTensor, edge_attr: OptTensor)
    out = self.propagate(edge_index, x=(x_l, x_r), edge_attr=edge_attr,
                         size=size)

    # Take the attention weights left on the instance and clear them.
    alpha = self._alpha
    self._alpha = None

    out = out.view(-1, H * C) if self.concat else out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if not isinstance(return_attention_weights, bool):
        return out

    assert alpha is not None
    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Run graph attention and compute a "global importance" score.

    Args:
        x: Node features (2-D tensor) or a ``(source, target)`` pair whose
            target entry may be ``None``.
        edge_index: Connectivity (``Tensor`` or ``SparseTensor``).
        size: Optional size forwarded to :meth:`propagate`; when given, it
            also determines how many self-loops are added.
        return_attention_weights (bool, optional): If a ``bool`` is given
            (regardless of its value), the attention weights read from
            ``self._alpha`` are returned in addition to the output, as
            ``(edge_index, alpha)`` for a ``Tensor`` ``edge_index`` or as a
            ``SparseTensor`` holding ``alpha`` otherwise.
            (default: :obj:`None`)
    """
    # NOTE(review): the nesting below was reconstructed from a
    # whitespace-collapsed source; in particular, confirm that the
    # `self.gi_by_gcn` block runs outside of `if self.add_self_loops`.
    # NOTE(review): `alpha_gi` is computed but never passed to `propagate`
    # or used afterwards in this method -- confirm whether that is intended.
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_r = self.lin_l(x).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(
            dim=-1
        )  # author's shape note: (num_nodes, heads, out_channels) * (1, heads, out_channels)
        alpha_r = (x_r * self.att_r).sum(dim=-1)
        if not self.gi_by_gcn:
            # Here the scorer is used as a weight that is multiplied
            # element-wise with the target features.
            alpha_gi = (x_r * self.global_importance_scorer).sum(dim=-1)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = self.lin_l(x_l).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)
            alpha_r = (x_r * self.att_r).sum(dim=-1)
            # The non-GCN global-importance score is disabled for tuple
            # inputs (the original code is commented out):
            # if not self.gi_by_gcn:
            #     alpha_gi = (x_r * self.global_importance_scorer).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            num_nodes = x_l.size(0)
            if x_r is not None:
                num_nodes = min(num_nodes, x_r.size(0))
            if size is not None:
                # An explicit `size` overrides the feature-based count.
                num_nodes = min(size[0], size[1])
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    if self.gi_by_gcn:
        # Here the scorer is called as a module on the features and graph.
        alpha_gi = self.global_importance_scorer(
            x_r, edge_index)  # author's shape note: num_nodes, num_heads

    # Author's note: the target is A_ij + GI_j, so GI should be expanded the
    # same way alpha_j is; open question whether alpha_i and alpha_j could be
    # obtained along the same lines.
    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=(x_l, x_r),
                         alpha=(alpha_l, alpha_r),
                         size=size)  # author's note: the base class's propagate() calls message()

    # Take the attention weights left on the instance and clear them.
    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Apply the attention convolution to the given graph.

    Args:
        x: Either a single 2-D node feature tensor, or a
            ``(source, target)`` pair whose target entry may be
            :obj:`None`.
        edge_index: Graph connectivity, as an edge index tensor or a
            :obj:`SparseTensor`.
        size: Optional ``(num_source, num_target)`` sizes; forwarded to
            :meth:`propagate` and, when given, used to bound the number
            of self-loops that are inserted.
        return_attention_weights (bool, optional): If a :obj:`bool` is
            passed (the check is on the type, not the value), the
            attention weights are returned alongside the output: as
            :obj:`(edge_index, attention_weights)` for tensor
            connectivity, or as a :obj:`SparseTensor` carrying the
            weights as values. (default: :obj:`None`)
    """
    heads, channels = self.heads, self.out_channels

    # Step 1: linear projection. A lone tensor shares one projection for
    # both roles; a pair uses separate source and target weights.
    if isinstance(x, Tensor):
        assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
        shared = self.lin_src(x).view(-1, heads, channels)
        x_src = shared
        x_dst = shared
    else:
        x_src, x_dst = x
        assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
        x_src = self.lin_src(x_src).view(-1, heads, channels)
        if x_dst is not None:
            x_dst = self.lin_dst(x_dst).view(-1, heads, channels)

    x = (x_src, x_dst)

    # Step 2: node-level attention logits for the source side and, when
    # target features exist, for the target side.
    alpha_src = (x_src * self.att_src).sum(dim=-1)
    alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
    alpha = (alpha_src, alpha_dst)

    # Step 3: optionally replace any existing self-loops with a fresh set.
    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # Only nodes present on both the source and the target side
            # get a self-loop; an explicit `size` takes precedence.
            if size is not None:
                num_nodes = min(size)
            else:
                num_nodes = x_src.size(0)
                if x_dst is not None:
                    num_nodes = min(num_nodes, x_dst.size(0))
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=x, alpha=alpha, size=size)

    # Take the cached attention weights and clear the cache.
    attention = self._alpha
    assert attention is not None
    self._alpha = None

    # Step 4: merge the heads, by concatenation or by averaging.
    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if not isinstance(return_attention_weights, bool):
        return out

    if isinstance(edge_index, Tensor):
        return out, (edge_index, attention)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(attention, layout='coo')
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, cmd,
            batch, edge_attr=None, size: Size = None,
            return_attention_weights=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Run command-conditioned attention message passing.

    Node features are projected by ``lin_l`` / ``lin_r``; the per-graph
    command ``cmd`` is projected by ``proj_cmd`` / ``cal_cmd`` and
    broadcast to the nodes through ``batch``. The target-side attention
    input is the element-wise product of the broadcast ``proj_cmd``
    output with the ``lin_r`` features.

    NOTE(review): the ``# type:`` overload comments above still describe
    a four-argument signature and do not mention ``cmd``, ``batch`` or
    ``edge_attr``; confirm before scripting this module.

    Args:
        x: 2-D node feature tensor. Tuple inputs are not supported by
            this variant and trip the ``x_l is not None`` assertion.
        edge_index: Graph connectivity, as an edge index tensor or a
            :obj:`SparseTensor`.
        cmd: Per-graph conditioning input; ``self.proj_cmd(cmd)`` and
            ``self.cal_cmd(cmd)`` must each yield one row per graph that
            can be viewed as ``(heads, out_channels)``.
        batch: Integer index tensor mapping every node to its row in
            the projected ``cmd`` (i.e. to its graph).
        edge_attr: Currently unused by this method.
        size: Optional ``(num_source, num_target)`` sizes forwarded to
            :meth:`propagate`.
        return_attention_weights (bool, optional): If a :obj:`bool` is
            passed (the check is on the type, not the value), the
            attention weights are returned alongside the output: as
            :obj:`(edge_index, attention_weights)` for tensor
            connectivity, or as a :obj:`SparseTensor` carrying the
            weights as values. (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = self.lin_l(x).view(-1, H, C)
        x_r = self.lin_r(x).view(-1, H, C)

        # Broadcast the per-graph projections to the nodes with a row
        # gather. This selects the same rows as multiplying by a dense
        # one-hot(batch) matrix, but avoids materialising a
        # (num_nodes, batch_size) matrix, does not require
        # batch.max() + 1 to equal the number of cmd rows, and is
        # independent of the projection dtype.
        proj_cmd = self.proj_cmd(cmd)[batch].view(-1, H, C)
        cal_cmd = self.cal_cmd(cmd)[batch].view(-1, H, C)

        # Attention inputs: raw source features on the left,
        # command-modulated target features on the right.
        alpha_l = x_l
        alpha_r = proj_cmd * x_r

    assert x_l is not None
    assert alpha_l is not None

    # NOTE: `x` is forwarded to propagate() unprojected, exactly as in the
    # original implementation; the projected features only enter through
    # `alpha`.
    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=x, cmd=cmd,
                         cal_cmd=(cal_cmd, cal_cmd),
                         alpha=(alpha_l, alpha_r), size=size)

    # Take the cached attention weights and clear the cache.
    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out