Exemplo n.º 1
0
    def forward(self, x, edge_index, edge_weight=None):
        """Apply a K-term Chebyshev polynomial graph filter to batched features.

        Args:
            x: Node features. ``x.size(0)`` is used as the batch size and
                ``x.size(1)`` as the number of nodes, i.e. the layout is
                treated as ``(batch, num_nodes, channels)``.
            edge_index: Tensor unpacked as ``row, col``; one column per edge.
            edge_weight: Optional per-edge weights (flattened with
                ``view(-1)``). When ``None``, every edge gets weight 1.

        Returns:
            ``sum_k T_k @ self.weight[k]`` (plus ``self.bias`` when it is not
            ``None``), where ``T_0 = x``, ``T_1 = L x`` and
            ``T_k = 2 L T_{k-1} - T_{k-2}``.
            # assumes self.weight is (K, in_channels, out_channels) — TODO confirm

        Raises:
            AssertionError: If the number of edge weights does not match the
                number of edges in ``edge_index``.
        """
        row, col = edge_index
        batch, num_nodes, K = x.size(0), x.size(1), self.weight.size(0)

        # The signature advertises ``edge_weight=None``; fall back to an
        # unweighted graph instead of failing on ``None.view``.
        if edge_weight is None:
            edge_weight = torch.ones((row.size(0),),
                                     dtype=x.dtype,
                                     device=edge_index.device)
        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        # Weighted degree of every node, then D^{-1/2}.
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg = deg.pow(-0.5)
        # Zero-degree nodes produce inf; treat them as contributing nothing.
        deg = deg.masked_fill(torch.isinf(deg), 0)
        # Off-diagonal part of the normalized Laplacian: -D^{-1/2} A D^{-1/2}.
        lap = -deg[row] * edge_weight * deg[col]

        # Rescale the Laplacian (original note: "2L/lmax - I; lmax = 1.0").
        # NOTE(review): the diagonal fill of -0.05 (alternative noted in the
        # original: -0.5) looks empirically chosen rather than derived from
        # that formula — confirm before changing.
        fill_value = -0.05
        edge_index, lap = add_self_loops(edge_index, lap, fill_value,
                                         num_nodes)
        lap = lap * 2

        def propagate(feat):
            # spmm works on 2-D dense matrices: fold (batch, N, C) into
            # (N, C * batch), multiply by the sparse operator, and unfold.
            flat = feat.permute(1, 2, 0).contiguous().view((num_nodes, -1))
            mixed = spmm(edge_index, lap, num_nodes, flat)
            return mixed.view((num_nodes, -1, batch)).permute(2, 0, 1)

        # Evaluate the Chebyshev recurrence, accumulating each term's
        # contribution as soon as it is available.
        Tx_0 = x
        out = torch.matmul(Tx_0, self.weight[0])
        if K > 1:
            Tx_1 = propagate(x)
            out = out + torch.matmul(Tx_1, self.weight[1])

        for k in range(2, K):
            # T_k = 2 L T_{k-1} - T_{k-2}; the previous term (not the raw
            # input) must be propagated for the recurrence to hold.
            Tx_2 = 2 * propagate(Tx_1) - Tx_0
            out = out + torch.matmul(Tx_2, self.weight[k])
            Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out = out + self.bias

        return out
Exemplo n.º 2
0
    def forward(self, x, edge_index, edge_weight=None):
        """Apply a symmetric-normalized graph convolution to batched features.

        Self-loops of weight 1 are added to the graph, the edge weights are
        normalized as ``D^{-1/2} A D^{-1/2}`` (``D`` being the weighted degree
        after adding self-loops), and the result is multiplied with
        ``x @ self.weight``.

        Args:
            x: Node features. ``x.size(0)`` is used as the batch size and
                ``x.size(1)`` as the number of nodes, i.e. the layout is
                treated as ``(batch, num_nodes, channels)``.
            edge_index: Tensor unpacked as ``row, col``; one column per edge.
            edge_weight: Optional per-edge weights (flattened with
                ``view(-1)``). When ``None``, every edge gets weight 1.

        Returns:
            The propagated features, permuted back to batch-first layout,
            plus ``self.bias`` when it is not ``None``.

        Raises:
            AssertionError: If the number of edge weights does not match the
                number of edges after self-loops were added.
        """
        batch, num_nodes = x.size(0), x.size(1)

        # The signature advertises ``edge_weight=None``; fall back to an
        # unweighted graph instead of failing on ``None.view`` further down.
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),),
                                     dtype=x.dtype,
                                     device=edge_index.device)

        # Add the diagonal entries to the adjacency before normalizing.
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight, 1,
                                                 num_nodes)
        row, col = edge_index

        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        # Weighted degree of every node, then D^{-1/2}.
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg = deg.pow(-0.5)
        # Zero-degree nodes produce inf; treat them as contributing nothing.
        deg[torch.isinf(deg)] = 0
        norm = deg[row] * edge_weight * deg[col]

        # Transform the features first, then propagate along the graph.
        feat = torch.matmul(x, self.weight)
        # spmm works on 2-D dense matrices: fold (batch, N, C) into
        # (N, C * batch), multiply by the sparse operator, and unfold.
        flat = feat.permute(1, 2, 0).contiguous().view((num_nodes, -1))
        out = spmm(edge_index, norm, num_nodes, flat)
        out = out.view((num_nodes, -1, batch)).permute(2, 0, 1)

        if self.bias is not None:
            out = out + self.bias

        return out
Exemplo n.º 3
0
    def forward(self, x, edge_index, edge_weight=None):
        # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
        """Apply a K-term Chebyshev polynomial graph filter to node features.

        Self-loops are removed, the edge weights are normalized to
        ``-D^{-1/2} A D^{-1/2}`` and the Chebyshev terms ``T_0 = x``,
        ``T_1 = L x``, ``T_k = 2 L T_{k-1} - T_{k-2}`` are combined as
        ``sum_k T_k @ self.weight[k]``.

        Args:
            x: 2-D node feature matrix; ``x.size(0)`` is the number of nodes.
            edge_index: Tensor whose rows 0 and 1 hold the source/target
                index of each edge.
            edge_weight: Optional per-edge weights (flattened with
                ``view(-1)``). When ``None``, every edge gets weight 1.

        Returns:
            The filtered features, plus ``self.bias`` when it is not ``None``.
        """
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)

        row, col = edge_index[0], edge_index[1]
        num_nodes, num_edges, K = x.size(0), row.size(0), self.weight.size(0)

        if edge_weight is None:
            edge_weight = torch.ones((num_edges,),
                                     dtype=x.dtype,
                                     device=edge_index.device)
        edge_weight = edge_weight.view(-1)

        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)

        # Compute normalized and rescaled graph Laplacian
        deg = deg.pow(-0.5)
        # Nodes with zero degree give inf here; zero them so they cannot
        # inject inf/NaN into the Laplacian (out-of-place to keep the
        # function free of in-place indexing).
        deg = deg.masked_fill(torch.isinf(deg), 0.0)
        lap = -deg[row] * edge_weight * deg[col]

        # Perform filter operation recurrently
        Tx_0 = x
        out = torch.mm(Tx_0, self.weight[0])
        # Computed unconditionally so that Tx_1 is defined on every path
        # before the loop below reads it.
        Tx_1 = spmm(edge_index, lap, num_nodes, x)

        if K > 1:
            out = out + torch.mm(Tx_1, self.weight[1])

        # NOTE(review): the range(K) + guard form (instead of range(2, K))
        # is presumably kept for TorchScript compatibility — confirm before
        # simplifying.
        for k in range(K):
            if k >= 2:
                Tx_2 = 2 * spmm(edge_index, lap, num_nodes, Tx_1) - Tx_0
                out = out + torch.mm(Tx_2, self.weight[k])
                Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out = out + self.bias

        return out