def forward(self, x, edge_index):
    # type: (Tensor, Tensor) -> Tensor
    """Concatenate node features with their neighbourhood mean, project, normalise.

    Existing self loops are stripped and exactly one self loop per node is
    added, so the mean of ``x[col]`` grouped by ``row`` includes each node's
    own features. The node features and that mean are concatenated along
    dim 1, multiplied by ``self.weight``, shifted by ``self.bias`` (always
    applied in this variant) and scaled to unit L2 norm along the last dim.

    NOTE(review): presumably ``x`` is (num_nodes, in_channels) and
    ``self.weight`` is (2 * in_channels, out_channels) — confirm in __init__.
    """
    edge_index, _ = remove_self_loops(edge_index)
    edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
    row = edge_index[0]
    col = edge_index[1]

    # dim_size is intentionally not passed: the aggregate has row.max() + 1
    # rows. With a self loop on every node this is expected to match
    # x.size(0), so the slice below normally keeps every row of x.
    neigh_mean = scatter_mean(x[col], row, dim=0)
    own = x[0:neigh_mean.size(0)]

    combined = torch.cat([own, neigh_mean], dim=1)
    projected = torch.matmul(combined, self.weight) + self.bias
    return F.normalize(projected, p=2.0, dim=-1)
    def forward(self, x, edge_index):
        # type: (Tensor, Tensor) -> Tensor
        """Project node features and average them over each node's neighbourhood.

        Self loops are rebuilt first (existing ones removed, one per node
        added), so every node's own projected features take part in its mean.
        A 1-D ``x`` is treated as one feature per node. The per-node mean is
        sized to ``x.size(0)`` rows, then ``self.bias`` is added when it is
        not None, and the result is L2-normalised along the last dimension
        when ``self.normalize`` is truthy.
        """
        num_nodes = x.size(0)
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        if x.dim() == 1:
            # Promote (N,) to (N, 1) so the matmul with self.weight works.
            x = x.unsqueeze(-1)

        projected = torch.matmul(x, self.weight)
        # Gather source features (edge_index[1]) and average them per target
        # (edge_index[0]); dim_size keeps one row per node.
        out = scatter_mean(projected[edge_index[1]], edge_index[0], dim=0,
                           dim_size=num_nodes)

        if self.bias is not None:
            out = out + self.bias

        return F.normalize(out, p=2.0, dim=-1) if self.normalize else out
示例#3
0
    def forward(self, x, edge_index, edge_weight=None):
        # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
        """Apply a K-term Chebyshev spectral graph convolution.

        Computes ``sum_{k<K} T_k(L_hat) x @ self.weight[k]`` (plus ``self.bias``
        when it is not None). Here ``K = self.weight.size(0)`` and ``T_k`` are
        Chebyshev polynomials evaluated through the recurrence
        ``T_0 x = x``, ``T_1 x = L_hat x``, ``T_k x = 2 L_hat T_{k-1} x - T_{k-2} x``.

        Args:
            x: node feature matrix; used with ``torch.mm`` so it must be 2-D
                (presumably (num_nodes, in_channels) — confirm).
            edge_index: 2-row edge index; self loops are removed first.
            edge_weight: optional per-edge weights; defaults to all ones.

        Returns:
            The filtered node features.
        """
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)

        row, col = edge_index[0], edge_index[1]
        num_nodes, num_edges, K = x.size(0), row.size(0), self.weight.size(0)

        if edge_weight is None:
            edge_weight = torch.ones((num_edges,),
                                     dtype=x.dtype,
                                     device=edge_index.device)
        edge_weight = edge_weight.view(-1)

        # Weighted degree, summed over each edge's row (source) index.
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)

        # Normalized Laplacian L = I - D^{-1/2} A D^{-1/2}, rescaled assuming
        # lambda_max = 2: L_hat = L - I = -D^{-1/2} A D^{-1/2}. Self loops were
        # removed above, so only the off-diagonal entries (one per edge) remain.
        deg_inv_sqrt = deg.pow(-0.5)
        # 0 ** -0.5 == inf. Zero it so nodes with no outgoing (or only
        # zero-weight) edges do not turn lap into inf/NaN and poison the output.
        deg_inv_sqrt = deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0.0)
        lap = -deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

        # Chebyshev recurrence. spmm multiplies the sparse (num_nodes x
        # num_nodes) matrix given by (edge_index, lap) with a dense matrix
        # (assumes torch_sparse's spmm(index, value, m, matrix) signature).
        Tx_0 = x
        out = torch.mm(Tx_0, self.weight[0])

        if K > 1:
            # T_1 x is only needed when there is a second filter term.
            Tx_1 = spmm(edge_index, lap, num_nodes, x)
            out = out + torch.mm(Tx_1, self.weight[1])
            for k in range(2, K):
                Tx_2 = 2 * spmm(edge_index, lap, num_nodes, Tx_1) - Tx_0
                out = out + torch.mm(Tx_2, self.weight[k])
                Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out = out + self.bias

        return out
示例#4
0
 def forward(self, x, edge_index):
     # type: (Tensor, Tensor) -> Tensor
     """Project node features per head and hand them to message passing.

     Existing self loops are removed and one self loop per node (``x.size(0)``
     nodes) is added. ``x`` is multiplied by ``self.weight`` and reshaped to
     (-1, self.heads, self.out_channels) before ``self.propagate`` runs.
     """
     without_loops, _ = remove_self_loops(edge_index)
     looped_index, _ = add_self_loops(without_loops, num_nodes=x.size(0))

     projected = torch.mm(x, self.weight)
     per_head = projected.view(-1, self.heads, self.out_channels)
     # num_nodes is taken from the reshaped tensor, exactly as before.
     return self.propagate(looped_index, x=per_head, num_nodes=per_head.size(0))
示例#5
0
	def remove_self_loops(self):
		"""Drop self-loop edges from ``self.graph`` via ``utils.remove_self_loops``.

		Intended to run first, because algorithms tend to behave better on a
		graph without self loops. The helper's return value is discarded, so
		presumably it edits the graph in place — confirm against ``utils``.
		"""
		graph = self.graph
		utils.remove_self_loops(graph)