Ejemplo n.º 1
0
    def forward(self, x, adj, edge_type):
        """Run one relational message-passing step.

        The output starts from a self-connection term
        ``x @ self.root + self.bias``. Then, for every relation id ``i`` in
        ``range(self.num_relations)``, the following happens:

        * the adjacency is restricted to the entries whose ``edge_type``
          equals ``i``;
        * messages are propagated over that sub-graph;
        * the result is projected with ``self.weight[i]`` and added to the
          output.

        Relations that have no edges are skipped entirely.

        Parameters
        ----------
        x : torch.Tensor
            The node features. The number of nodes is read from
            ``x.size(-2)``, and the adjacency is treated as square
            ``(num_nodes, num_nodes)``.
        adj : torch_sparse.SparseTensor
            Graph's adjacency matrix. Its non-zeros are filtered per relation
            with ``masked_select_nnz(..., layout="coo")``.
        edge_type : torch.LongTensor
            The type of each edge. It presumably has one entry per non-zero
            of ``adj``, in the same order, because it is used as the
            ``masked_select_nnz`` mask (TODO confirm against callers).

        Returns
        -------
        torch.Tensor
            The new node features.
        """
        num_nodes = x.size(-2)
        # Self-connection (root) term: it does not depend on the edge types.
        # NOTE(review): no adjacency normalization happens in this method. If
        # any is applied, it must live in message()/propagate() (confirm).
        out = x @ self.root + self.bias
        for rel in range(self.num_relations):
            mask = edge_type == rel
            # Skip relations without edges rather than propagating over an
            # empty adjacency.
            if mask.any():
                rel_adj = masked_select_nnz(adj, mask, layout="coo")
                h = self.propagate(rel_adj, x=x, size=(num_nodes, num_nodes))
                out = out + h @ self.weight[rel]
        return out
Ejemplo n.º 2
0
    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                edge_type: OptTensor = None) -> Tensor:
        """FiLM-modulated relational message passing.

        The method works in two steps:

        1. It computes a skip term ``gamma * self.lin_skip(x[1]) + beta``,
           where ``beta`` and ``gamma`` come from ``self.film_skip(x[1])``.
           The term is optionally passed through ``self.act``.
        2. It adds one ``self.propagate`` result per relation. Each relation
           uses its own ``lin`` / ``film`` pair from ``self.lins`` and
           ``self.films``.

        Parameters
        ----------
        x : Tensor or PairTensor
            Node features, or a ``(source, target)`` pair of feature tensors.
            A single tensor is used for both roles. Messages are built from
            ``x[0]``; the skip term and FiLM parameters come from ``x[1]``.
        edge_index : Adj
            Connectivity: either a dense index tensor, filtered column-wise
            as ``edge_index[:, mask]``, or a ``SparseTensor``.
        edge_type : OptTensor, optional
            Relation id per edge. It is needed only when
            ``self.num_relations > 1`` and ``edge_index`` is dense. For a
            ``SparseTensor``, this argument is overwritten by the tensor's
            stored values, which then serve as the edge types.

        Returns
        -------
        Tensor
            The updated features, built on top of the ``x[1]`` skip term.
        """
        # Promote a single feature tensor to a (source, target) pair.
        # NOTE(review): the annotated reassignment looks intended for
        # TorchScript typing; keep it as-is.
        if isinstance(x, Tensor):
            x: PairTensor = (x, x)

        # Skip path: split the FiLM output into (beta, gamma) chunks of
        # size out_channels and modulate the linear skip of the targets.
        beta, gamma = self.film_skip(x[1]).split(self.out_channels, dim=-1)
        out = gamma * self.lin_skip(x[1]) + beta
        if self.act is not None:
            out = self.act(out)

        # NOTE(review): the line below appears to be a machine-read type
        # annotation for propagate(); do not edit its text.
        # propagate_type: (x: Tensor, beta: Tensor, gamma: Tensor)
        if self.num_relations <= 1:
            # Single relation: use the first lin/film pair over all edges.
            beta, gamma = self.films[0](x[1]).split(self.out_channels, dim=-1)
            out = out + self.propagate(edge_index, x=self.lins[0](x[0]),
                                       beta=beta, gamma=gamma, size=None)
        else:
            # One propagation per relation, each restricted to its own edges.
            for i, (lin, film) in enumerate(zip(self.lins, self.films)):
                beta, gamma = film(x[1]).split(self.out_channels, dim=-1)
                # propagate() is called separately in each branch so that
                # dense and sparse adjacencies are never bound to the same
                # variable.
                if isinstance(edge_index, SparseTensor):
                    # Edge types live in the SparseTensor's stored values;
                    # the edge_type argument is replaced here.
                    edge_type = edge_index.storage.value()
                    # The assert also narrows Optional[Tensor] to Tensor.
                    # It is stripped under `python -O`.
                    assert edge_type is not None
                    mask = edge_type == i
                    out = out + self.propagate(
                        masked_select_nnz(edge_index, mask, layout='coo'),
                        x=lin(x[0]), beta=beta, gamma=gamma, size=None)
                else:
                    # A dense edge_index needs an explicit edge_type.
                    assert edge_type is not None
                    mask = edge_type == i
                    out = out + self.propagate(edge_index[:, mask], x=lin(
                        x[0]), beta=beta, gamma=gamma, size=None)

        return out
Ejemplo n.º 3
0
def masked_edge_index(edge_index, edge_mask):
    """Keep only the edges selected by ``edge_mask``.

    Parameters
    ----------
    edge_index : torch.Tensor or SparseTensor
        Connectivity. A dense tensor is filtered along its second dimension
        (``edge_index[:, edge_mask]``). Anything else is handed to
        ``masked_select_nnz`` with a COO layout.
    edge_mask : torch.Tensor
        Selector over the edges. It is presumably boolean with one entry per
        edge / non-zero; confirm against callers.

    Returns
    -------
    The filtered connectivity.
    """
    if not isinstance(edge_index, Tensor):
        # Sparse adjacency: drop the non-zeros the mask rejects.
        return masked_select_nnz(edge_index, edge_mask, layout='coo')
    # Dense index tensor: edges are laid out along dim 1.
    return edge_index[:, edge_mask]