def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
    """Propagate ``x`` with inverse-degree edge weights and a boosted diagonal.

    Every edge is weighted by ``1 / deg`` of its target (``col``) node, with
    degrees clamped to at least 1. Self-loop edges additionally receive
    ``self.diag_lambda / deg``. The aggregated result is passed through
    ``self.lin_out`` and added to ``self.lin_root(x)``.

    Args:
        x: Node feature matrix; nodes are counted along ``self.node_dim``.
        edge_index: Either a ``[2, num_edges]`` index tensor or a
            (transposed) ``SparseTensor`` adjacency.

    Returns:
        The updated node features.
    """
    edge_weight: OptTensor = None

    if isinstance(edge_index, Tensor):
        num_nodes = x.size(self.node_dim)
        if self.add_self_loops:
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        row, col = edge_index[0], edge_index[1]
        deg_inv = 1. / degree(col, num_nodes=num_nodes).clamp_(1.)

        edge_weight = deg_inv[col]
        # Index the diagonal boost by each self-loop's own node id. Adding the
        # full ``deg_inv`` vector only lines up when there is exactly one
        # self-loop per node in node order; that holds right after
        # add_self_loops, but not when ``self.add_self_loops`` is False.
        self_loop = row == col
        edge_weight[self_loop] += self.diag_lambda * deg_inv[row[self_loop]]

    elif isinstance(edge_index, SparseTensor):
        if self.add_self_loops:
            edge_index = set_diag(edge_index)

        col, row, _ = edge_index.coo()  # Transposed.
        deg_inv = 1. / sparsesum(edge_index, dim=1).clamp_(1.)

        edge_weight = deg_inv[col]
        # Same per-node indexing as above, robust to missing diagonal entries.
        self_loop = row == col
        edge_weight[self_loop] += self.diag_lambda * deg_inv[row[self_loop]]
        edge_index = edge_index.set_value(edge_weight, layout='coo')

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                         size=None)
    out = self.lin_out(out) + self.lin_root(x)
    return out
def forward(self, x: Union[OptTensor, PairOptTensor],
            pos: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:
    """Message passing over point positions with optional node features.

    A bare ``x`` is paired with ``None`` (no target-side features). A bare
    ``pos`` serves as both source and target positions. When
    ``self.add_self_loops`` is set, existing self-loops are replaced by
    one per node present on both sides. The aggregated output goes through
    ``self.global_nn`` when it is configured.
    """
    feats: PairOptTensor = x if isinstance(x, tuple) else (x, None)
    points: PairTensor = (pos, pos) if isinstance(pos, Tensor) else pos

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            shared = min(points[0].size(0), points[1].size(0))
            edge_index = remove_self_loops(edge_index)[0]
            edge_index = add_self_loops(edge_index, num_nodes=shared)[0]
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: PairOptTensor, pos: PairTensor)
    result = self.propagate(edge_index, x=feats, pos=points, size=None)
    return result if self.global_nn is None else self.global_nn(result)
def forward(self, x, edge_index, size=None):
    """Attention-style propagation with per-edge ReLU coefficients.

    Node-level attention logits are computed from ``self.att_l`` and
    ``self.att_r``. Self-loops are re-added when ``self.add_self_loops`` is
    set. The target-node degree and the ReLU coefficients are cached on
    ``self.degree`` and ``self.theta`` before ``propagate`` runs.

    Args:
        x: Node features, used for both source and target nodes.
        edge_index: ``[2, num_edges]`` index tensor or a (transposed)
            ``SparseTensor``.
        size: Optional ``(num_src, num_dst)`` forwarded to ``propagate``.

    Returns:
        The output of ``self.propagate``.
    """
    x_l = x_r = x
    alpha_l = (x_l * self.att_l).sum(dim=-1)
    alpha_r = (x_r * self.att_r).sum(dim=-1)

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            num_nodes = x_l.size(0)
            if x_r is not None:
                num_nodes = min(num_nodes, x_r.size(0))
            if size is not None:
                num_nodes = min(size[0], size[1])
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # Target index of every edge. A SparseTensor adjacency is stored
    # transposed, so its row indices are the targets. Indexing it with
    # [0]/[1] would slice rows rather than yield an index vector.
    if isinstance(edge_index, SparseTensor):
        col, _, _ = edge_index.coo()
    else:
        col = edge_index[1]
    # Give num_nodes explicitly so trailing nodes with no incoming edges still
    # get an entry (degree 0) instead of shortening the vector.
    self.degree = degree(col, num_nodes=x.size(0))

    theta = self.get_relu_coefs(x, edge_index)
    self.theta = theta.view(-1, self.channels, 2 * self.k)

    out = self.propagate(edge_index, x=(x_l, x_r),
                         alpha=(alpha_l, alpha_r), size=size)
    return out
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj):
    """Multi-head attention propagation with a single shared projection.

    ``x`` is projected by ``self.lin`` and reshaped to
    ``[-1, heads, out_channels]``. One attention logit per node and head
    comes from contracting with ``self.att``. Heads are concatenated when
    ``self.concat`` is set and averaged otherwise, then ``self.bias`` (if
    any) is added.
    """
    heads, channels = self.heads, self.out_channels
    h = self.lin(x).view(-1, heads, channels)
    scores = (h * self.att).sum(dim=-1)

    if self.add_self_loops and isinstance(edge_index, Tensor):
        edge_index = remove_self_loops(edge_index)[0]
        edge_index = add_self_loops(edge_index, num_nodes=h.size(0))[0]
    elif self.add_self_loops and isinstance(edge_index, SparseTensor):
        edge_index = set_diag(edge_index)

    out = self.propagate(edge_index, x=h, alpha=scores)
    if self.concat:
        out = out.view(-1, heads * channels)
    else:
        out = out.mean(dim=1)
    if self.bias is not None:
        out += self.bias
    return out
def forward(
    self,
    x: Union[Tensor, PairTensor],
    pos: Union[Tensor, PairTensor],
    edge_index: Adj,
) -> Tensor:
    """Point-transformer style propagation over ``pos``.

    Attention inputs are ``(lin_src(x_src), lin_dst(x_dst))``. ``x`` is
    handed to ``propagate`` as ``(lin(x_src), x_dst)``, so target
    features stay untransformed. A single tensor serves as both source and
    target.
    """
    if isinstance(x, Tensor):
        x_src = x_dst = x
    else:
        x_src, x_dst = x[0], x[1]
    alpha: PairTensor = (self.lin_src(x_src), self.lin_dst(x_dst))
    feats: PairTensor = (self.lin(x_src), x_dst)

    points: PairTensor = (pos, pos) if isinstance(pos, Tensor) else pos

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            shared = min(points[0].size(0), points[1].size(0))
            edge_index = remove_self_loops(edge_index)[0]
            edge_index = add_self_loops(edge_index, num_nodes=shared)[0]
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: PairTensor, pos: PairTensor, alpha: PairTensor)
    return self.propagate(edge_index, x=feats, pos=points, alpha=alpha,
                          size=None)
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None):
    r"""Multi-head graph attention with separate source/target projections.

    Source features go through ``self.lin_l`` and optional target features
    through ``self.lin_r``. Each is reshaped to ``[-1, heads, out_channels]``.
    No attention logits are passed to ``propagate``; the per-edge attention
    weights are read back from ``self._alpha`` afterwards.

    Args:
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    heads, channels = self.heads, self.out_channels
    src: OptTensor = None
    dst: OptTensor = None

    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        src = self.lin_l(x).view(-1, heads, channels)
        dst = src
    else:
        src, dst = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        src = self.lin_l(src).view(-1, heads, channels)
        if dst is not None:
            dst = self.lin_r(dst).view(-1, heads, channels)

    assert src is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # Only nodes present on both sides get a self-loop; an explicit
            # ``size`` takes precedence over the feature-matrix sizes.
            if size is not None:
                n_loops = min(size[0], size[1])
            elif dst is not None:
                n_loops = min(src.size(0), dst.size(0))
            else:
                n_loops = src.size(0)
            edge_index = remove_self_loops(edge_index)[0]
            edge_index = add_self_loops(edge_index, num_nodes=n_loops)[0]
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=(src, dst), size=size)
    alpha = self._alpha
    self._alpha = None

    out = out.view(-1, heads * channels) if self.concat else out.mean(dim=1)
    if self.bias is not None:
        out += self.bias

    if not isinstance(return_attention_weights, bool):
        return out
    assert alpha is not None
    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None):
    """Multi-head graph attention on already-projected features.

    ``x`` (or each side of a pair) is only reshaped to
    ``[-1, heads, out_channels]``; no linear projection is applied. Node
    logits come from ``self.att_l`` (sources) and ``self.att_r``
    (targets). Heads are concatenated or averaged; no bias is added.
    """
    heads, channels = self.heads, self.out_channels
    src: OptTensor = None
    dst: OptTensor = None
    score_src: OptTensor = None
    score_dst: OptTensor = None

    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        src = dst = x.view(-1, heads, channels)
        score_src = (src * self.att_l).sum(dim=-1)
        score_dst = (dst * self.att_r).sum(dim=-1)
    else:
        src, dst = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        src = src.view(-1, heads, channels)
        score_src = (src * self.att_l).sum(dim=-1)
        if dst is not None:
            dst = dst.view(-1, heads, channels)
            score_dst = (dst * self.att_r).sum(dim=-1)

    assert src is not None
    assert score_src is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            if size is not None:
                n_loops = min(size[0], size[1])
            elif dst is not None:
                n_loops = min(src.size(0), dst.size(0))
            else:
                n_loops = src.size(0)
            edge_index = remove_self_loops(edge_index)[0]
            edge_index = add_self_loops(edge_index, num_nodes=n_loops)[0]
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=(src, dst),
                         alpha=(score_src, score_dst), size=size)
    alpha = self._alpha
    self._alpha = None

    out = out.view(-1, heads * channels) if self.concat else out.mean(dim=1)

    if not isinstance(return_attention_weights, bool):
        return out
    assert alpha is not None
    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
    """Propagate ``x`` together with its row-wise L2-normalised copy.

    When ``self.add_self_loops`` is set, existing self-loops are replaced by
    one per node (nodes counted along ``self.node_dim``).
    """
    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            n = x.size(self.node_dim)
            edge_index = remove_self_loops(edge_index)[0]
            edge_index = add_self_loops(edge_index, num_nodes=n)[0]
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    unit_x = F.normalize(x, p=2., dim=-1)
    # propagate_type: (x: Tensor, x_norm: Tensor)
    return self.propagate(edge_index, x=x, x_norm=unit_x, size=None)
def forward(self, x: Union[Tensor, PairTensor],
            pos: Union[Tensor, PairTensor],
            normal: Union[Tensor, PairTensor],
            edge_index: Adj,
            faces: Optional[Tensor] = None,
            size: Optional[int] = None) -> Tensor:
    """Point-cloud message passing with optional geometric edge features.

    Args:
        x: Node features, or a ``(source, target)`` pair.
        pos: Node positions, either one tensor or a ``(source, target)``
            pair. Forwarded unchanged to ``edge_pair_features``.
        normal: Node normals, forwarded to ``edge_pair_features``.
        edge_index: Index tensor or ``SparseTensor`` connectivity.
        faces: Mesh faces, used only when ``self.use_geodesic_distance``
            is set.
        size: Accepted for interface compatibility; currently unused
            (``propagate`` is always called with ``size=None``).

    Returns:
        Aggregated node features, optionally weight-normalised and passed
        through ``self.global_nn``.
    """
    # propagate_type: (x: PairTensor, edge_attr: Tensor)  # noqa
    if not isinstance(x, tuple):
        x: PairTensor = (x, x)

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # ``pos`` may be a single tensor. In that case ``pos[1]`` is the
            # second *point* (its size(0) is the coordinate dimension), not
            # the target set, so pick the target node count explicitly.
            if isinstance(pos, Tensor):
                num_nodes = pos.size(0)
            else:
                num_nodes = pos[1].size(0)
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    edge_attr = None
    if self.use_edge_features:
        if self.use_geodesic_distance:
            edge_attr = edge_pair_features(edge_index, pos, normal,
                                           faces=faces,
                                           normalize_gd=self.normalize_gd)
        else:
            edge_attr = edge_pair_features(edge_index, pos, normal)
        if self.edge_bn is not None:
            edge_attr = self.edge_bn(edge_attr)

    out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)

    if self.weighted_average:
        # Divide all but the last channel by the last one (presumably an
        # accumulated weight produced during message passing — confirm
        # against message()); 1e-8 guards against division by zero.
        out = out[:, :-1] / (out[:, -1].view(-1, 1) + 1e-8)
    if self.global_nn is not None:
        out = self.global_nn(out)
    return out
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """Linearly transform ``x`` and propagate it with optional edge weights.

    When ``self.add_self_loops`` is set, existing self-loops are removed and
    one self-loop per node is appended. ``edge_weight`` is filtered and
    extended together with ``edge_index`` so the two stay aligned
    edge-for-edge. New self-loops get ``add_self_loops``' default fill
    value.

    Args:
        x: Node features; nodes are counted along ``self.node_dim``.
        edge_index: Index tensor or ``SparseTensor`` connectivity.
        edge_weight: Optional per-edge weights matching ``edge_index``.

    Returns:
        Propagated features, plus ``self.bias`` when it is set.
    """
    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # Thread edge_weight through both calls. Rewriting edge_index
            # alone would leave edge_weight with a different length.
            edge_index, edge_weight = remove_self_loops(edge_index,
                                                        edge_weight)
            edge_index, edge_weight = add_self_loops(
                edge_index, edge_weight, num_nodes=x.size(self.node_dim))
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    x = self.lin(x)

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                         size=None)
    if self.bias is not None:
        out += self.bias
    return out
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=False):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Multi-head graph attention with one shared projection.

    Args:
        x: Node feature matrix. Only a single 2-D tensor is supported.
        edge_index: Index tensor or ``SparseTensor`` connectivity.
        size: Optional ``(num_src, num_dst)``. When given, it also bounds the
            number of self-loops added.
        return_attention_weights (bool, optional): Returning attention
            weights is not supported by this layer; it must be falsy
            (:obj:`False` or :obj:`None`). (default: :obj:`False`)

    Returns:
        Tensor with ``numheads * size_per_head`` columns.
    """
    # ``not`` accepts both False and None, which the type comments above
    # allow; ``== False`` rejected None.
    assert not return_attention_weights, "Returning attention weights not supported. "
    H, C = self.numheads, self.size_per_head
    assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
    x_l = x_r = self.lin(x).view(x.size(0), H, C)

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # Only nodes that exist as both source and target get a
            # self-loop; an explicit size takes precedence.
            num_nodes = min(x_l.size(0), x_r.size(0))
            if size is not None:
                num_nodes = min(size[0], size[1])
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=(x_l, x_r), size=size)
    # Flatten heads with the same H/C used for the reshape above. The
    # previous ``self.heads`` disagreed with ``self.numheads``.
    out = out.view(-1, H * C)
    if self.bias is not None:
        out += self.bias
    return out
def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:
    """Propagate (source, target) node features, then add an optional bias.

    A single tensor is used for both sides. When ``self.add_self_loops`` is
    set, self-loops are replaced by one per target node.
    """
    pair: PairTensor = (x, x) if isinstance(x, Tensor) else x

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            edge_index = remove_self_loops(edge_index)[0]
            edge_index = add_self_loops(edge_index,
                                        num_nodes=pair[1].size(0))[0]
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: PairTensor)
    result = self.propagate(edge_index, x=pair, size=None)
    if self.bias is not None:
        result += self.bias
    return result
def forward(self, x: Union[OptTensor, PairOptTensor],
            pos: Union[Tensor, PairTensor],
            normal: Union[Tensor, PairTensor],
            edge_index: Adj) -> Tensor:  # yapf: disable
    """Propagate over point pairs and post-process the per-node histogram.

    Bare tensors are promoted to ``(source, target)`` pairs; bare features
    are paired with ``None``. When ``self.normalize`` is set, each histogram
    is divided by its sum (plus ``self.eps``). When ``self.global_nn`` is
    set, the histogram is flattened and passed through it.
    """
    feats: PairOptTensor = x if isinstance(x, tuple) else (x, None)
    points: PairTensor = (pos, pos) if isinstance(pos, Tensor) else pos
    normals: PairTensor = (
        (normal, normal) if isinstance(normal, Tensor) else normal)

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            edge_index = remove_self_loops(edge_index)[0]
            edge_index = add_self_loops(edge_index,
                                        num_nodes=points[1].size(0))[0]
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: PairOptTensor, pos: PairTensor, normal: PairTensor)  # noqa
    hist = self.propagate(edge_index, x=feats, pos=points, normal=normals,
                          size=None)
    shape = hist.size()

    if self.normalize:
        totals = hist.sum(dim=-1).view(-1, shape[1], 1)
        hist = hist / (totals + self.eps)

    if self.global_nn is not None:
        # Flatten the histogram before the global transform.
        hist = self.global_nn(hist.view(-1, shape[1] * shape[2]))

    return hist
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Multi-head graph attention.

    Args:
        x: Node features, or a ``(source, target)`` pair whose target entry
            may be ``None``.
        edge_index: Index tensor or ``SparseTensor`` connectivity.
        size: Optional ``(num_src, num_dst)``.
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    # Project features into H heads of C channels. A single tensor plays
    # both roles and goes through ``lin_src`` only; a pair uses ``lin_src``
    # for sources and ``lin_dst`` for (optional) targets.
    if isinstance(x, Tensor):
        assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
        x_src = self.lin_src(x).view(-1, H, C)
        x_dst = x_src
    else:
        x_src, x_dst = x
        assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
        x_src = self.lin_src(x_src).view(-1, H, C)
        if x_dst is not None:
            x_dst = self.lin_dst(x_dst).view(-1, H, C)

    # Per-node attention logits for each side.
    alpha_src = (x_src * self.att_src).sum(dim=-1)
    alpha_dst = (x_dst * self.att_dst).sum(-1) if x_dst is not None else None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # Self-loops only for nodes that exist as both source and target.
            if size is not None:
                n_loops = min(size)
            elif x_dst is None:
                n_loops = x_src.size(0)
            else:
                n_loops = min(x_src.size(0), x_dst.size(0))
            edge_index = remove_self_loops(edge_index)[0]
            edge_index = add_self_loops(edge_index, num_nodes=n_loops)[0]
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=(x_src, x_dst),
                         alpha=(alpha_src, alpha_dst), size=size)

    alpha = self._alpha
    assert alpha is not None
    self._alpha = None

    out = out.view(-1, H * C) if self.concat else out.mean(dim=1)
    if self.bias is not None:
        out += self.bias

    if not isinstance(return_attention_weights, bool):
        return out
    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(
    self,
    x,
    edge_index,
    edge_weights: Optional[Tensor] = None,
    size=None,
    return_attention_weights=None,
):
    """Adapted from the PyTorch Geometric `GATConv` `forward` method.

    The difference from upstream is that ``edge_weights`` is carried along
    with ``edge_index`` through self-loop replacement. It is stored on
    ``self.edge_weights`` just before propagation. See PyTorch Geometric for
    the remaining documentation.
    """
    heads, channels = self.heads, self.out_channels
    src = None
    dst = None
    score_src = None
    score_dst = None

    if isinstance(x, Tensor):
        assert x.dim() == 2, "Static graphs not supported in `GATConv`."
        src = dst = self.lin_l(x).view(-1, heads, channels)
        score_src = (src * self.att_l).sum(dim=-1)
        score_dst = (dst * self.att_r).sum(dim=-1)
    else:
        src, dst = x[0], x[1]
        assert x[0].dim() == 2, "Static graphs not supported in `GATConv`."
        src = self.lin_l(src).view(-1, heads, channels)
        score_src = (src * self.att_l).sum(dim=-1)
        if dst is not None:
            dst = self.lin_r(dst).view(-1, heads, channels)
            score_dst = (dst * self.att_r).sum(dim=-1)

    assert src is not None
    assert score_src is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            if size is not None:
                n_loops = min(size[0], size[1])
            elif dst is not None:
                n_loops = min(src.size(0), dst.size(0))
            else:
                n_loops = src.size(0)
            edge_index, edge_weights = remove_self_loops(
                edge_index, edge_weights)
            edge_index, edge_weights = add_self_loops(
                edge_index, edge_weights, num_nodes=n_loops)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    self.edge_weights = edge_weights
    out = self.propagate(edge_index, x=(src, dst),
                         alpha=(score_src, score_dst), size=size)
    alpha = self._alpha
    self._alpha = None

    out = out.view(-1, heads * channels) if self.concat else out.mean(dim=1)
    if self.bias is not None:
        out += self.bias

    if not isinstance(return_attention_weights, bool):
        return out
    assert alpha is not None
    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout="coo")
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            edge_attr: OptTensor = None, size: Size = None,
            return_attention_weights=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, OptTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, OptTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, OptTensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, OptTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Multi-head graph attention with optional edge features.

    Args:
        x: Node features, or a ``(source, target)`` pair whose target entry
            may be ``None``.
        edge_index: Index tensor or ``SparseTensor`` connectivity.
        edge_attr: Optional edge features forwarded to ``propagate``. When
            self-loops are added to an index tensor, the new loops are
            filled with ``self.fill_value``.
        size: Optional ``(num_src, num_dst)``.
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)

    Raises:
        NotImplementedError: If self-loops are requested for a
            ``SparseTensor`` ``edge_index`` while ``self.edge_dim`` is set.
    """
    # NOTE: the output form is chosen by the *type* of
    # `return_attention_weights` (`None` vs. `bool`), not by its value, so
    # passing `False` still returns attention weights. This workaround lets
    # TorchScript's `torch.jit._overload` dispatch on argument types.
    H, C = self.heads, self.out_channels

    # Project features: one tensor shares ``lin_src``; a pair uses
    # ``lin_src`` for sources and ``lin_dst`` for (optional) targets.
    if isinstance(x, Tensor):
        assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
        x_src = self.lin_src(x).view(-1, H, C)
        x_dst = x_src
    else:
        x_src, x_dst = x
        assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
        x_src = self.lin_src(x_src).view(-1, H, C)
        if x_dst is not None:
            x_dst = self.lin_dst(x_dst).view(-1, H, C)

    # Node-level attention logits for each side.
    alpha_src = (x_src * self.att_src).sum(dim=-1)
    alpha_dst = (x_dst * self.att_dst).sum(-1) if x_dst is not None else None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # Self-loops only for nodes that exist as both source and target.
            if size is not None:
                n_loops = min(size)
            elif x_dst is None:
                n_loops = x_src.size(0)
            else:
                n_loops = min(x_src.size(0), x_dst.size(0))
            edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
            edge_index, edge_attr = add_self_loops(
                edge_index, edge_attr, fill_value=self.fill_value,
                num_nodes=n_loops)
        elif isinstance(edge_index, SparseTensor):
            if self.edge_dim is not None:
                raise NotImplementedError(
                    "The usage of 'edge_attr' and 'add_self_loops' "
                    "simultaneously is currently not yet supported for "
                    "'edge_index' in a 'SparseTensor' form")
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor, edge_attr: OptTensor)  # noqa
    out = self.propagate(edge_index, x=(x_src, x_dst),
                         alpha=(alpha_src, alpha_dst), edge_attr=edge_attr,
                         size=size)

    alpha = self._alpha
    assert alpha is not None
    self._alpha = None

    out = out.view(-1, H * C) if self.concat else out.mean(dim=1)
    if self.bias is not None:
        out += self.bias

    if not isinstance(return_attention_weights, bool):
        return out
    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(self, data, size=None, return_attention_weights=None):
    """Edge-aware multi-head graph attention over a graph ``data`` object.

    Reads ``x``, ``edge_index`` and ``edge_attr`` from ``data``. Edge
    features are projected by ``self.lin_edge1`` and updated by
    ``self.edge_propagate``. Node features are then aggregated with
    attention through ``self.propagate``.

    Args:
        data: Object exposing ``x``, ``edge_index`` and ``edge_attr``.
        size: Optional ``(num_src, num_dst)`` forwarded to both propagation
            steps.
        return_attention_weights (bool, optional): If set to a bool, also
            return the attention weights, as ``(edge_index, attention)``
            for an index tensor or as a ``SparseTensor`` holding them.

    Returns:
        ``(out, edge_attr)``, or ``(out, edge_attr, attention)`` when
        ``return_attention_weights`` is a bool.
    """
    x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None  # attention logits of source nodes
    alpha_r: OptTensor = None  # attention logits of target nodes

    # Node-wise linear transform and attention logits.
    if isinstance(x, Tensor):
        x_l = x_r = self.lin_l(x).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        # Target logits use the target attention vector ``att_r``, as in the
        # bipartite branch below. Previously ``att_l`` was reused here.
        alpha_r = (x_r * self.att_r).sum(dim=-1)
    else:
        x_l, x_r = x[0], x[1]
        x_l = self.lin_l(x_l).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)
            alpha_r = (x_r * self.att_r).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    # Edge feature projection and edge-level propagation.
    edge_attr = self.lin_edge1(edge_attr).view(-1, H, C)
    edge_attr = self.edge_propagate(edge_index, edge_attr, x_l, size=size)

    # Add self-loops.
    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # Only nodes that exist as both source and target get a
            # self-loop. Letting the target count override the source count
            # could create out-of-range source indices.
            num_nodes = x_l.size(0)
            if x_r is not None:
                num_nodes = min(num_nodes, x_r.size(0))
            if size is not None:
                num_nodes = min(size[0], size[1])
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # Node propagation. propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=(x_l, x_r),
                         alpha=(alpha_l, alpha_r), size=size)
    alpha = self._alpha
    self._alpha = None

    # Multi-head aggregation: concatenate or average the heads.
    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)
    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, edge_attr, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_attr, edge_index.set_value(alpha, layout='coo')
    else:
        return out, edge_attr
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None, layer=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Multi-head graph attention with optional channel-chunked propagation.

    Args:
        x: Node features, or a ``(source, target)`` pair whose target entry
            may be ``None``.
        edge_index: Index tensor or ``SparseTensor`` connectivity.
        size: Optional ``(num_src, num_dst)``.
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
        layer (int, optional): If greater than 1, the ``out_channels`` of
            every head are split into ``layer`` chunks. Each chunk is
            propagated separately with the same node-level attention
            logits, and the results are concatenated. The returned
            attention weights are those of the last chunk.
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_r = self.lin_l(x).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        alpha_r = (x_r * self.att_r).sum(dim=-1)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = self.lin_l(x_l).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)
            alpha_r = (x_r * self.att_r).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            num_nodes = x_l.size(0)
            if x_r is not None:
                num_nodes = min(num_nodes, x_r.size(0))
            if size is not None:
                num_nodes = min(size[0], size[1])
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    if layer is not None and layer > 1:
        # ``layer - 1`` chunks of ``out_channels // layer`` channels, plus a
        # final chunk holding the remainder.
        chunk = self.out_channels // layer
        parts = [chunk] * (layer - 1)
        parts.append(x_l.size(-1) - chunk * (layer - 1))
        xl_parts = x_l.split(parts, dim=-1)
        # Target features are optional for bipartite input.
        xr_parts = (x_r.split(parts, dim=-1) if x_r is not None
                    else [None] * layer)
        # Collect all chunks and concatenate once, rather than re-copying
        # the growing result on every iteration.
        outs = [
            self.propagate(edge_index, x=(xl_i, xr_i),
                           alpha=(alpha_l, alpha_r), size=size)
            for xl_i, xr_i in zip(xl_parts, xr_parts)
        ]
        out = torch.cat(outs, dim=-1)
    else:
        out = self.propagate(edge_index, x=(x_l, x_r),
                             alpha=(alpha_l, alpha_r), size=size)

    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)
    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None,
            edge_weight=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Multi-head graph attention with per-edge weights.

    ``edge_weight`` is kept aligned with ``edge_index`` when self-loops are
    replaced, and is passed to ``propagate``. When it is not given, every
    edge gets weight ``1``.

    Args:
        x: Node features, or a ``(source, target)`` pair whose target entry
            may be ``None``.
        edge_index: Index tensor or ``SparseTensor`` connectivity.
        size: Optional ``(num_src, num_dst)``.
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)
        edge_weight (Tensor, optional): One weight per edge of
            ``edge_index``. (default: :obj:`None`)
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_r = self.lin_l(x).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        # Target logits use the target attention vector ``att_r``, as in the
        # bipartite branch. Previously ``att_l`` was reused here.
        alpha_r = (x_r * self.att_r).sum(dim=-1)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = self.lin_l(x_l).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)
            alpha_r = (x_r * self.att_r).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # Only nodes that exist as both source and target get a
            # self-loop. Letting the target count win could create
            # out-of-range source indices.
            num_nodes = x_l.size(0)
            if x_r is not None:
                num_nodes = min(num_nodes, x_r.size(0))
            if size is not None:
                num_nodes = min(size[0], size[1])
            # Pass the weights positionally: the keyword name of this
            # parameter differs between PyG releases.
            edge_index, edge_weight = remove_self_loops(edge_index,
                                                        edge_weight)
            edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                     num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    # This is the modified part: default to unit edge weights.
    if edge_weight is None:
        # ``x`` may be a (source, target) tuple, so read the dtype from a
        # tensor. A SparseTensor reports its edge count via ``nnz()`` and its
        # device via a ``device()`` method.
        dtype = x.dtype if isinstance(x, Tensor) else x[0].dtype
        if isinstance(edge_index, SparseTensor):
            num_edges, device = edge_index.nnz(), edge_index.device()
        else:
            num_edges, device = edge_index.size(1), edge_index.device
        edge_weight = torch.ones((num_edges, ), dtype=dtype, device=device)

    out = self.propagate(edge_index, x=(x_l, x_r),
                         alpha=(alpha_l, alpha_r), size=size,
                         edge_weight=edge_weight)

    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)
    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out
def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
            edge_attr: OptTensor = None, size: Size = None,
            return_attention_weights: bool = None):
    # type: (Union[Tensor, PairTensor], Tensor, OptTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, PairTensor], Tensor, OptTensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Multi-head attention with separate (or shared) left/right projections.

    A single tensor is projected by ``self.lin_l`` for sources. Targets
    reuse that projection when ``self.share_weights`` is set; otherwise they
    go through ``self.lin_r``. For a pair, each side uses its own
    projection and both sides are required.

    Args:
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)

    Raises:
        NotImplementedError: If self-loops are requested for a
            ``SparseTensor`` ``edge_index`` while ``self.edge_dim`` is set.
    """
    heads, channels = self.heads, self.out_channels
    left: OptTensor = None
    right: OptTensor = None

    if isinstance(x, Tensor):
        assert x.dim() == 2
        left = self.lin_l(x).view(-1, heads, channels)
        right = (left if self.share_weights
                 else self.lin_r(x).view(-1, heads, channels))
    else:
        left, right = x[0], x[1]
        assert x[0].dim() == 2
        left = self.lin_l(left).view(-1, heads, channels)
        if right is not None:
            right = self.lin_r(right).view(-1, heads, channels)

    assert left is not None
    assert right is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            if size is not None:
                n_loops = min(size[0], size[1])
            elif right is not None:
                n_loops = min(left.size(0), right.size(0))
            else:
                n_loops = left.size(0)
            edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
            edge_index, edge_attr = add_self_loops(
                edge_index, edge_attr, fill_value=self.fill_value,
                num_nodes=n_loops)
        elif isinstance(edge_index, SparseTensor):
            if self.edge_dim is not None:
                raise NotImplementedError(
                    "The usage of 'edge_attr' and 'add_self_loops' "
                    "simultaneously is currently not yet supported for "
                    "'edge_index' in a 'SparseTensor' form")
            edge_index = set_diag(edge_index)

    # propagate_type: (x: PairTensor, edge_attr: OptTensor)
    out = self.propagate(edge_index, x=(left, right), edge_attr=edge_attr,
                         size=size)
    alpha = self._alpha
    self._alpha = None

    out = out.view(-1, heads * channels) if self.concat else out.mean(dim=1)
    if self.bias is not None:
        out += self.bias

    if not isinstance(return_attention_weights, bool):
        return out
    assert alpha is not None
    if isinstance(edge_index, Tensor):
        return out, (edge_index, alpha)
    elif isinstance(edge_index, SparseTensor):
        return out, edge_index.set_value(alpha, layout='coo')
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None):
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
    # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
    r"""Multi-head graph attention with an experimental global-importance term.

    Args:
        x: Node features, or a ``(source, target)`` pair whose target entry
            may be ``None``.
        edge_index: Index tensor or ``SparseTensor`` connectivity.
        size: Optional ``(num_src, num_dst)``.
        return_attention_weights (bool, optional): If set to :obj:`True`,
            will additionally return the tuple
            :obj:`(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. (default: :obj:`None`)

    NOTE(review): ``alpha_gi`` (global importance) is computed, either from
    ``self.global_importance_scorer`` as a parameter or by calling it as a
    module when ``self.gi_by_gcn`` is set. It is never passed to
    ``propagate`` or otherwise used in this method. For pair input without
    ``gi_by_gcn`` it is not computed at all. Confirm the intended wiring
    before relying on it.
    """
    H, C = self.heads, self.out_channels

    x_l: OptTensor = None
    x_r: OptTensor = None
    alpha_l: OptTensor = None
    alpha_r: OptTensor = None
    if isinstance(x, Tensor):
        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_r = self.lin_l(x).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(
            dim=-1
        )  # (num_nodes, heads, out_channels) (1, heads, out_channels)
        alpha_r = (x_r * self.att_r).sum(dim=-1)
        if not self.gi_by_gcn:
            # Global-importance logits from a learned scorer (unused below).
            alpha_gi = (x_r * self.global_importance_scorer).sum(dim=-1)
    else:
        x_l, x_r = x[0], x[1]
        assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = self.lin_l(x_l).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        if x_r is not None:
            x_r = self.lin_r(x_r).view(-1, H, C)
            alpha_r = (x_r * self.att_r).sum(dim=-1)
            # if not self.gi_by_gcn:
            #     alpha_gi = (x_r * self.global_importance_scorer).sum(dim=-1)

    assert x_l is not None
    assert alpha_l is not None

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # Self-loops only for nodes present on both sides; an explicit
            # size takes precedence.
            num_nodes = x_l.size(0)
            if x_r is not None:
                num_nodes = min(num_nodes, x_r.size(0))
            if size is not None:
                num_nodes = min(size[0], size[1])
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    if self.gi_by_gcn:
        # Graph-based global-importance scorer, run on the self-looped graph.
        # Presumably returns [num_nodes, num_heads]; unused below.
        alpha_gi = self.global_importance_scorer(
            x_r, edge_index)  # num_nodes, num_heads

    # A_ij + GI_j: GI should therefore be expanded the same way as alpha_j.
    # I am wondering whether alpha_i and alpha_j could be obtained following
    # the same idea.
    # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
    out = self.propagate(edge_index, x=(x_l, x_r),
                         alpha=(alpha_l, alpha_r),
                         size=size)  # the base class's propagate() calls message()

    alpha = self._alpha
    self._alpha = None

    if self.concat:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out += self.bias

    if isinstance(return_attention_weights, bool):
        assert alpha is not None
        if isinstance(edge_index, Tensor):
            return out, (edge_index, alpha)
        elif isinstance(edge_index, SparseTensor):
            return out, edge_index.set_value(alpha, layout='coo')
    else:
        return out