Ejemplo n.º 1
0
    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        """Aggregate messages over the graph and pass the result to the MLP.

        Args:
            x: Node features, either a single tensor (then used for both
                entries of the pair) or a pair of tensors whose second entry
                may be ``None``.
            edge_index: Graph connectivity as a ``Tensor`` or a
                ``SparseTensor``.
            edge_attr: Optional edge features. When ``edge_index`` is a
                ``SparseTensor`` this argument is replaced by the values
                stored in that tensor.
            size: Forwarded unchanged to ``propagate``.

        Returns:
            ``self.mlp`` applied to the propagated output, after the optional
            message normalisation and the addition of ``x[1]`` (when it is
            not ``None``).
        """
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # A SparseTensor adjacency carries its own edge features.
        if isinstance(edge_index, SparseTensor):
            edge_attr = edge_index.storage.value()
            check_widths = True
        else:
            check_widths = isinstance(edge_index, Tensor)

        # Node and edge feature dimensionalities need to match.
        if check_widths and edge_attr is not None:
            assert x[0].size(-1) == edge_attr.size(-1)

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        if self.msg_norm is not None:
            out = self.msg_norm(x[0], out)

        # Residual connection, applied in place on the aggregated output.
        residual = x[1]
        if residual is not None:
            out += residual

        return self.mlp(out)
Ejemplo n.º 2
0
def add_self_loops(
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        fill_value: Union[float, Tensor, str] = None,
        num_nodes: Optional[int] = None) -> Tuple[Tensor, OptTensor]:
    r"""Appends one self-loop :math:`(i,i) \in \mathcal{E}` per node
    :math:`i \in \mathcal{V}` to the graph described by :attr:`edge_index`.
    If edge weights or multi-dimensional edge features are given
    (:obj:`edge_attr != None`), features for the new self-loops are generated
    as specified by :obj:`fill_value` and appended as well.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
            features. (default: :obj:`None`)
        fill_value (float or Tensor or str, optional): How to generate the
            edge features of self-loops (only used if
            :obj:`edge_attr != None`).
            A :obj:`float` (or :obj:`int`) fills every self-loop feature with
            that constant.
            A :class:`torch.Tensor` is moved to the device and dtype of
            :obj:`edge_attr` and repeated once per node.
            A :obj:`str` names the reduce operation (:obj:`"add"`,
            :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`) used to
            aggregate the features of all edges pointing to each node.
            (default: :obj:`1.`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    Raises:
        AttributeError: If :obj:`edge_attr` is given and :obj:`fill_value`
            is of none of the supported types.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    N = maybe_num_nodes(edge_index, num_nodes)

    # One (i, i) column per node.
    loop_index = torch.arange(
        0, N, dtype=torch.long,
        device=edge_index.device).unsqueeze(0).repeat(2, 1)

    if edge_attr is not None:
        # Self-loop features share every trailing dimension of `edge_attr`.
        loop_shape = (N, ) + edge_attr.size()[1:]

        if fill_value is None:
            loop_attr = edge_attr.new_full(loop_shape, 1.)
        elif isinstance(fill_value, (float, int)):
            loop_attr = edge_attr.new_full(loop_shape, fill_value)
        elif isinstance(fill_value, Tensor):
            loop_attr = fill_value.to(edge_attr.device, edge_attr.dtype)
            if loop_attr.dim() != edge_attr.dim():
                # Add the leading per-loop dimension before tiling.
                loop_attr = loop_attr.unsqueeze(0)
            repeats = [N] + [1] * (loop_attr.dim() - 1)
            loop_attr = loop_attr.repeat(*repeats)
        elif isinstance(fill_value, str):
            # Reduce the features of all edges that point to each node.
            loop_attr = scatter(edge_attr, edge_index[1], dim=0, dim_size=N,
                                reduce=fill_value)
        else:
            raise AttributeError("No valid 'fill_value' provided")

        edge_attr = torch.cat([edge_attr, loop_attr], dim=0)

    return torch.cat([edge_index, loop_index], dim=1), edge_attr
Ejemplo n.º 3
0
    def message(self, x_j: Tensor, edge_attr: OptTensor):
        """Weight neighbour features with Gaussian kernels of the edge features.

        Args:
            x_j: Neighbour features; viewed as ``[E, K, M]`` in the shared
                kernel case and as ``[E, F, 1]`` with separate Gaussians.
            edge_attr: Edge features of shape ``[E, D]``; must not be
                ``None``.

        Returns:
            Tensor of shape ``[E, M]``, where ``M`` is ``self.out_channels``.
        """
        assert edge_attr is not None

        eps = 1e-15  # keeps the division finite when sigma is zero
        in_dim, out_dim = self.rel_in_channels, self.out_channels
        num_edges, attr_dim = edge_attr.size()
        num_kernels = self.kernel_size

        if self.separate_gaussians:
            # One Gaussian per (input channel, output channel, kernel).
            param_shape = (1, in_dim, out_dim, num_kernels, attr_dim)
            exponent = -0.5 * (edge_attr.view(num_edges, 1, 1, 1, attr_dim) -
                               self.mu.view(*param_shape)).pow(2)
            exponent = exponent / (eps + self.sigma.view(*param_shape).pow(2))
            kernel = torch.exp(exponent.sum(dim=-1))  # [E, F, M, K]

            kernel = kernel * self.g.view(1, in_dim, out_dim, num_kernels)
            kernel = kernel.sum(dim=-1)  # [E, F, M]

            return (x_j.view(num_edges, in_dim, 1) * kernel).sum(dim=-2)

        # Shared Gaussians: one per kernel.
        param_shape = (1, num_kernels, attr_dim)
        exponent = -0.5 * (edge_attr.view(num_edges, 1, attr_dim) -
                           self.mu.view(*param_shape)).pow(2)
        exponent = exponent / (eps + self.sigma.view(*param_shape).pow(2))
        kernel = torch.exp(exponent.sum(dim=-1))  # [E, K]

        weighted = (x_j.view(num_edges, num_kernels, out_dim) *
                    kernel.view(num_edges, num_kernels, 1))
        return weighted.sum(dim=-2)
Ejemplo n.º 4
0
    def forward(self,
                t,
                x: Union[Tensor, OptPairTensor],
                edge_index: Adj,
                edge_attr: OptTensor = None,
                size: Size = None) -> Tensor:
        """Propagate messages, add the scaled root features and call ``self.nn``.

        Args:
            t: Extra argument handed to ``self.nn`` together with the
                aggregated features; not used otherwise in this method.
            x: Node features, either a single tensor (then used for both
                entries of the pair) or a pair whose second entry may be
                ``None``.
            edge_index: Graph connectivity as a ``Tensor`` or a
                ``SparseTensor``.
            edge_attr: Edge features; required when ``edge_index`` is a
                ``Tensor``.
            size: Forwarded unchanged to ``propagate``.

        Returns:
            The result of ``self.nn(t, out)``.
        """
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # Node and edge feature dimensionalities need to match.
        if isinstance(edge_index, SparseTensor):
            assert x[0].size(-1) == edge_index.size(-1)
        elif isinstance(edge_index, Tensor):
            assert edge_attr is not None
            assert x[0].size(-1) == edge_attr.size(-1)

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        # Add the root features, scaled by (1 + eps), in place.
        root = x[1]
        if root is not None:
            out += (1 + self.eps) * root

        return self.nn(t, out)
    def message(self, x_j: Tensor, edge_attr: OptTensor,
                edge_attr_len: OptTensor, duplicates_idx: OptTensor) -> Tensor:
        """Build messages, giving every repeated row index its own row copy.

        ``edge_attr`` is used as a tensor of row indices into ``x_j``. When
        ``duplicates_idx`` is ``None`` the repeated values of ``edge_attr``
        are collected, one extra copy of the addressed ``x_j`` row is appended
        per additional occurrence, and the surplus occurrences in
        ``edge_attr`` are redirected (in place) to the appended rows.

        Args:
            x_j: Features the messages are built from; rows are addressed by
                ``edge_attr``.
            edge_attr: Integer row indices into ``x_j`` (presumably 1-D --
                TODO confirm). Modified in place when ``duplicates_idx`` is
                ``None``.
            edge_attr_len: Input given to ``self.edge_enc``; only read when
                ``self.edge_enc`` is truthy.
            duplicates_idx: Row indices of ``x_j`` to append. ``None`` means
                they are derived here from ``edge_attr``.

        Returns:
            A ``(messages, duplicates_idx)`` tuple, where ``messages`` is
            ``relu(x_j) + self.eps`` on the extended ``x_j``.
            NOTE(review): the ``-> Tensor`` annotation does not match the
            tuple that is actually returned; it is left untouched so that
            no new name has to be brought into scope.

        Side effects:
            Stores ``duplicates_idx`` on ``self.duplicates_idx``.
        """
        just_made_dup = duplicates_idx is None

        if just_made_dup:  # 1st layer, so make duplicates_idx
            edge_indice, cnt_ = torch.unique(edge_attr, return_counts=True)
            repeated = cnt_ != 1
            # One entry per *additional* occurrence of each repeated index.
            duplicates_idx = [
                idx
                for idx, c in zip(edge_indice[repeated], cnt_[repeated])
                for _ in range(int(c) - 1)
            ]

        # Append one copy of the addressed row per duplicate.
        x_j = torch.cat((x_j, x_j[duplicates_idx, ]),
                        dim=0)  # ((E+num_duplicates),H)

        if just_made_dup:
            first_new_row = x_j.size(0) - len(duplicates_idx)
            # Created on the device of `edge_attr`: a CPU range cannot be
            # indexed with a boolean mask that lives on another device.
            positions = torch.arange(edge_attr.size(0),
                                     device=edge_attr.device)
            for i, idx in enumerate(duplicates_idx):
                # Redirect the last remaining occurrence of `idx` to its
                # freshly appended row.
                last_pos = positions[edge_attr == idx][-1]
                edge_attr[last_pos] = first_new_row + i

        if self.edge_enc:
            edge_pos_emb = self.edge_enc(edge_attr_len)
            x_j[edge_attr] += edge_pos_emb.squeeze()

        self.duplicates_idx = duplicates_idx

        return F.relu(x_j) + self.eps, duplicates_idx
Ejemplo n.º 6
0
    def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Compute relation-specific, attention-weighted messages.

        The method (1) builds one weight matrix per edge from the relation
        weights (optionally via basis or block-diagonal decomposition),
        (2) derives attention logits from the transformed endpoint features
        and the optional edge features, (3) normalises them with ``softmax``
        within or across relations, and (4) combines the coefficients with
        the transformed ``x_j`` according to ``self.mod``.

        Args:
            x_i: Features of one edge endpoint, used for the query term.
            x_j: Features of the other endpoint, used for the key term and
                for the returned messages.
            edge_type: Relation id per edge; selects the weight matrix.
            edge_attr: Optional edge features; requires ``self.lin_edge``.
            index: Per-edge index given to ``softmax`` and ``scatter_add``.
            ptr: Optional pointer tensor forwarded to ``softmax``.
            size_i: Optional size forwarded to ``softmax`` / ``scatter_add``.

        Returns:
            The messages; also stores the coefficients in ``self._alpha``.

        Raises:
            ValueError: If block-diagonal decomposition is active and both
                ``x_i`` and ``x_j`` have dtype ``torch.long``.
        """
        heads, out_channels = self.heads, self.out_channels
        hidden = heads * out_channels

        if self.num_bases is not None:  # Basis-decomposition
            weight = torch.matmul(self.att,
                                  self.basis.view(self.num_bases, -1))
            weight = weight.view(self.num_relations, self.in_channels, hidden)

        if self.num_blocks is not None:  # Block-diagonal-decomposition
            if x_i.dtype == torch.long and x_j.dtype == torch.long:
                raise ValueError('Block-diagonal decomposition not supported '
                                 'for non-continuous input features.')
            weight = self.weight
            block_shape = (-1, 1, weight.size(1), weight.size(2))
            x_i = x_i.view(*block_shape)
            x_j = x_j.view(*block_shape)
            weight = torch.index_select(weight, 0, edge_type)
            out_i = torch.einsum('abcd,acde->ace', x_i, weight)
            out_i = out_i.contiguous().view(-1, hidden)
            out_j = torch.einsum('abcd,acde->ace', x_j, weight)
            out_j = out_j.contiguous().view(-1, hidden)
        else:  # No regularization/Basis-decomposition
            if self.num_bases is None:
                weight = self.weight
            weight = torch.index_select(weight, 0, edge_type)
            out_i = torch.bmm(x_i.unsqueeze(1), weight).squeeze(-2)
            out_j = torch.bmm(x_j.unsqueeze(1), weight).squeeze(-2)

        qi = torch.matmul(out_i, self.q)
        kj = torch.matmul(out_j, self.k)

        alpha_edge, alpha = 0, torch.tensor([0])
        if edge_attr is not None:
            if edge_attr.dim() == 1:
                edge_attr = edge_attr.view(-1, 1)
            assert self.lin_edge is not None, (
                "Please set 'edge_dim = edge_attr.size(-1)' while calling the "
                "RGATConv layer")
            edge_attributes = self.lin_edge(edge_attr).view(-1, hidden)
            if edge_attributes.size(0) != edge_attr.size(0):
                edge_attributes = torch.index_select(edge_attributes, 0,
                                                     edge_type)
            alpha_edge = torch.matmul(edge_attributes, self.e)

        additive_attn = self.attention_mode == "additive-self-attention"
        multiplicative_attn = (
            self.attention_mode == "multiplicative-self-attention")

        # Raw attention logits.
        if additive_attn:
            alpha = torch.add(qi, kj)
            if edge_attr is not None:
                alpha = alpha + alpha_edge
            alpha = F.leaky_relu(alpha, self.negative_slope)
        elif multiplicative_attn:
            alpha = qi * kj
            if edge_attr is not None:
                alpha = alpha * alpha_edge

        # Normalise the logits.
        if self.attention_mechanism == "within-relation":
            normalized = torch.zeros_like(alpha)
            for rel in range(self.num_relations):
                rel_mask = edge_type == rel
                normalized[rel_mask] = softmax(alpha[rel_mask],
                                               index[rel_mask])
            alpha = normalized
        elif self.attention_mechanism == "across-relation":
            alpha = softmax(alpha, index, ptr, size_i)

        self._alpha = alpha

        if self.mod == "additive":
            if additive_attn:
                feats = out_j.view(-1, heads, out_channels)
                skip = feats * torch.ones_like(alpha).view(-1, heads, 1)
                skip = torch.mul(self.w, skip)
                return feats * alpha.view(-1, heads, 1) + skip
            elif multiplicative_attn:
                feats = out_j.view(-1, heads, 1, out_channels)
                skip = feats * torch.ones_like(alpha).view(
                    -1, heads, self.dim, 1)
                skip = torch.mul(self.w, skip)
                return feats * alpha.view(-1, heads, self.dim, 1) + skip

        elif self.mod == "scaled":
            if additive_attn or multiplicative_attn:
                # Per-edge count of entries sharing the same `index`, passed
                # through a two-layer transform.
                ones = alpha.new_ones(index.size())
                degree = scatter_add(ones, index,
                                     dim_size=size_i)[index].unsqueeze(-1)
                degree = torch.matmul(degree, self.l1) + self.b1
                degree = self.activation(degree)
                degree = torch.matmul(degree, self.l2) + self.b2

                if additive_attn:
                    return torch.mul(
                        out_j.view(-1, heads, out_channels) *
                        alpha.view(-1, heads, 1),
                        degree.view(-1, 1, out_channels))
                return torch.mul(
                    out_j.view(-1, heads, 1, out_channels) *
                    alpha.view(-1, heads, self.dim, 1),
                    degree.view(-1, 1, 1, out_channels))

        elif self.mod == "f-additive":
            alpha = torch.where(alpha > 0, alpha + 1, alpha)

        elif self.mod == "f-scaled":
            ones = alpha.new_ones(index.size())
            degree = scatter_add(ones, index,
                                 dim_size=size_i)[index].unsqueeze(-1)
            alpha = alpha * degree

        # NOTE: this branch belongs to the `self.mod` chain above, so dropout
        # is only applied when `self.mod` matches none of the named modes.
        elif self.training and self.dropout > 0:
            alpha = F.dropout(alpha, p=self.dropout, training=True)

        if additive_attn:
            return alpha.view(-1, heads, 1) * out_j.view(
                -1, heads, out_channels)
        return (alpha.view(-1, heads, self.dim, 1) *
                out_j.view(-1, heads, 1, out_channels))