def forward(self, x, adj, edge_type):
    """Forward pass.

    Parameters
    ----------
    x : torch.Tensor
        The node features.
    adj : torch_sparse.SparseMatrix
        The graph's adjacency matrix.
    edge_type : torch.LongTensor
        The types of the edges.

    Returns
    -------
    torch.Tensor
        The new node features.
    """
    # Root (self-loop) transform plus bias. NOTE: the adjacency
    # normalization does not take edge classes into account.
    out = x @ self.root + self.bias
    num_nodes = x.size(-2)
    for rel in range(self.num_relations):
        rel_mask = edge_type == rel
        # Relations with no edges in this graph contribute nothing.
        if not bool(rel_mask.any()):
            continue
        sub_adj = masked_select_nnz(adj, rel_mask, layout="coo")
        msg = self.propagate(sub_adj, x=x, size=(num_nodes, num_nodes))
        out = out + msg @ self.weight[rel]
    return out
def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
            edge_type: OptTensor = None) -> Tensor:
    """Run FiLM-conditioned message passing and update node embeddings.

    Args:
        x: Node features — a single tensor (used for both source and
            target nodes) or a ``(source, target)`` pair.
        edge_index: Graph connectivity, either a dense edge-index tensor
            or a :class:`SparseTensor`.
        edge_type: Relation type of each edge. Required (directly, or via
            the :class:`SparseTensor` values) when ``num_relations > 1``.

    Returns:
        The updated node feature matrix.
    """
    if isinstance(x, Tensor):
        x: PairTensor = (x, x)

    # Skip connection, modulated by its own FiLM (beta/gamma) parameters.
    beta, gamma = self.film_skip(x[1]).split(self.out_channels, dim=-1)
    out = gamma * self.lin_skip(x[1]) + beta
    if self.act is not None:
        out = self.act(out)

    # propagate_type: (x: Tensor, beta: Tensor, gamma: Tensor)
    if self.num_relations <= 1:
        beta, gamma = self.films[0](x[1]).split(self.out_channels, dim=-1)
        out = out + self.propagate(edge_index, x=self.lins[0](x[0]),
                                   beta=beta, gamma=gamma, size=None)
    else:
        if isinstance(edge_index, SparseTensor):
            # Relation types are carried as the sparse matrix's values;
            # extracting them is loop-invariant, so do it once up front.
            edge_type = edge_index.storage.value()
        assert edge_type is not None
        for i, (lin, film) in enumerate(zip(self.lins, self.films)):
            beta, gamma = film(x[1]).split(self.out_channels, dim=-1)
            # Restrict propagation to edges of relation `i`; reuse the
            # shared `masked_edge_index` helper instead of duplicating
            # the Tensor-vs-SparseTensor dispatch inline.
            out = out + self.propagate(
                masked_edge_index(edge_index, edge_type == i),
                x=lin(x[0]), beta=beta, gamma=gamma, size=None)

    return out
def masked_edge_index(edge_index, edge_mask):
    """Restrict the graph connectivity to the edges selected by *edge_mask*.

    Dispatches on the connectivity representation: column-masks a dense
    edge-index tensor, or masks the non-zero entries of a sparse matrix.
    """
    if isinstance(edge_index, Tensor):
        return edge_index[:, edge_mask]
    return masked_select_nnz(edge_index, edge_mask, layout='coo')