Example 1
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_sparse import SparseTensor

# `device` is assumed to be set by the surrounding script; a typical definition:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Net(torch.nn.Module):
    """Two-layer GCN that stores its adjacency as a torch_sparse.SparseTensor."""

    def __init__(self, dataset, data, args, adj=()):
        super(Net, self).__init__()
        self.data = data
        self.conv1 = GCNConv(dataset.num_features,
                             16,
                             normalize=not args.use_gdc)
        self.conv2 = GCNConv(16,
                             dataset.num_classes,
                             normalize=not args.use_gdc)
        # Fall back to unit edge weights when the dataset provides none.
        if data.edge_attr is None:
            data.edge_attr = torch.ones(data.edge_index[0].size(0))
        if len(adj) == 0:
            # Build the sparse adjacency from the edge index once, up front.
            self.adj1 = SparseTensor(row=data.edge_index[0],
                                     col=data.edge_index[1],
                                     value=torch.clone(
                                         data.edge_attr)).to(device)
        else:
            # Reuse an adjacency supplied by the caller.
            self.adj1 = adj
        # Each convolution gets its own copy of the adjacency.
        self.adj2 = self.adj1.clone()

    def forward(self):
        x = self.data.x
        x = F.relu(self.conv1(x, self.adj1))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, self.adj2)
        return F.log_softmax(x, dim=1)
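A minimal driver sketch for the class above, assuming the standard Planetoid/Cora setup that snippets like this usually come from; the Args container, the optimizer settings, and the epoch count are illustrative assumptions, not part of the original example.

# Hypothetical driver for Example 1 (Planetoid/Cora and the Args fields are assumptions).
from torch_geometric.datasets import Planetoid

class Args:
    use_gdc = False

dataset = Planetoid(root="/tmp/Cora", name="Cora")
data = dataset[0].to(device)          # `device` as defined above
model = Net(dataset, data, Args()).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model()                     # forward() reads x from self.data
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()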
Example 2
# Imports and `device` as in Example 1.
class Net(torch.nn.Module):
    """Variant that keeps the adjacency as a torch sparse COO tensor with
    explicit self-loops; the identity is stripped again before each GCNConv call."""

    def __init__(self, dataset, data, args, adj=()):
        super(Net, self).__init__()
        self.data = data
        # Wider hidden layer for the "F" dataset, the default 16 otherwise.
        if args.dataset == "F":
            hidden = 128
        else:
            hidden = 16
        self.conv1 = GCNConv(dataset.num_features, hidden,
                             normalize=not args.use_gdc)
        self.conv2 = GCNConv(hidden, dataset.num_classes,
                             normalize=not args.use_gdc)
        # Fall back to unit edge weights when the dataset provides none.
        if data.edge_attr is None:
            data.edge_attr = torch.ones(data.edge_index[0].size(0))
        if len(adj) == 0:
            # Build the adjacency as a sparse COO tensor and add explicit self-loops.
            self.adj1 = SparseTensor(
                row=data.edge_index[0],
                col=data.edge_index[1],
                value=torch.clone(data.edge_attr)).to_torch_sparse_coo_tensor().to(device)
            self.adj1 = self.adj1 + torch.eye(self.adj1.shape[0]).to_sparse().to(device)
        else:
            # A caller-supplied adjacency is expected to already contain self-loops.
            self.adj1 = adj
        self.id = torch.eye(self.adj1.shape[0]).to_sparse().to(device)
        self.adj2 = self.adj1.clone()

    def forward(self):
        x = self.data.x
        # Strip the identity before the convolution; GCNConv adds its own self-loops.
        adj1 = SparseTensor.from_torch_sparse_coo_tensor(self.adj1 - self.id)
        x = F.relu(self.conv1(x, adj1))
        x = F.dropout(x, training=self.training)
        adj2 = SparseTensor.from_torch_sparse_coo_tensor(self.adj2 - self.id)
        x = self.conv2(x, adj2)
        return F.log_softmax(x, dim=1)
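The add-then-remove identity pattern used above can be checked in isolation. The sketch below uses only plain PyTorch on a made-up 3-node graph; the edge index and values are illustrative, not taken from the example.

# Toy check of the self-loop round trip used in Example 2 (illustrative values only).
import torch

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
adj = torch.sparse_coo_tensor(edge_index, torch.ones(3), (3, 3)).coalesce()

eye = torch.eye(3).to_sparse()
with_loops = adj + eye                          # what __init__ stores in self.adj1
without_loops = (with_loops - eye).coalesce()   # what forward() hands to GCNConv

# The round trip leaves the off-diagonal entries unchanged.
print(torch.allclose(without_loops.to_dense(), adj.to_dense()))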
Example 3
def laplace(adj: SparseTensor, lap_type=None):
    """Build the graph Laplacian of a (self-loop free) adjacency matrix.

    lap_type: "sym" / None -> I - D^-1/2 A D^-1/2, "rw" -> I - D^-1 A,
    "comb" -> D - A.
    """
    M, N = adj.sizes()
    assert M == N
    row, col, val = adj.clone().coo()
    # Unweighted graphs get unit edge weights.
    val = col.new_ones(col.shape, dtype=adj.dtype()) if val is None else val
    deg = adj.sum(0)

    # Indices of the diagonal entries appended below.
    loop_index = torch.arange(N, device=adj.device()).unsqueeze_(0)
    if lap_type in (None, "sym"):
        # Symmetric normalization: off-diagonal -d_i^-1/2 a_ij d_j^-1/2, diagonal 1.
        deg05 = deg.pow(-0.5)
        deg05[deg05 == float("inf")] = 0
        wgt = deg05[row] * val * deg05[col]
        wgt = torch.cat([-wgt.unsqueeze_(0), val.new_ones(1, N)], 1).squeeze_()

    elif lap_type == "rw":
        # Random-walk normalization: off-diagonal -a_ij / d_i, diagonal 1.
        deg_inv = 1.0 / deg
        deg_inv[deg_inv == float("inf")] = 0
        wgt = deg_inv[row] * val
        wgt = torch.cat([-wgt.unsqueeze_(0), val.new_ones(1, N)], 1).squeeze_()

    elif lap_type == "comb":
        # Combinatorial Laplacian: off-diagonal -a_ij, diagonal d_i.
        wgt = torch.cat([-val.unsqueeze_(0), deg.unsqueeze_(0)], 1).squeeze_()

    else:
        raise ValueError("Invalid laplace type: {}".format(lap_type))

    # Append the diagonal entries to the original edge list.
    row = torch.cat([row.unsqueeze_(0), loop_index], 1).squeeze_()
    col = torch.cat([col.unsqueeze_(0), loop_index], 1).squeeze_()
    lap = SparseTensor(row=row, col=col, value=wgt, sparse_sizes=(M, N))
    return lap
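A quick sanity check of laplace() is straightforward; the unweighted 3-node cycle below is made up for illustration.

# Sanity check of laplace() on an unweighted 3-node cycle (illustrative only).
import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 1, 2, 1, 2, 0])
col = torch.tensor([1, 2, 0, 0, 1, 2])
adj = SparseTensor(row=row, col=col, value=torch.ones(6), sparse_sizes=(3, 3))

L_sym = laplace(adj)                    # default: symmetric normalization
L_comb = laplace(adj, lap_type="comb")

# Rows of the combinatorial Laplacian D - A sum to zero.
print(L_comb.to_dense().sum(dim=1))     # tensor([0., 0., 0.])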
Example 4
def normalize_laplace(L: SparseTensor, lam_max: float = 2.0):
    """Rescale a Laplacian to 2L / lam_max - I, working in place on a clone of L."""
    Ln = L.clone()
    row, col, val = Ln.coo()
    diag_mask = row == col
    # Scale all stored entries; the in-place write keeps Ln's storage up to date.
    val[...] = (2.0 * val) / lam_max
    val.masked_fill_(val == float("inf"), 0)
    # Subtract the identity (only diagonal entries that are actually stored are touched).
    val[diag_mask] -= 1
    return Ln
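Combined with laplace() above, this gives the usual rescaling used by Chebyshev-polynomial filters, which maps the spectrum of the symmetric Laplacian into [-1, 1]. A short sketch, reusing the toy adjacency from the previous check:

# Rescale the symmetric Laplacian of the toy graph above (sketch).
L = laplace(adj, lap_type="sym")
L_hat = normalize_laplace(L, lam_max=2.0)

# With lam_max = 2 the eigenvalues of 2L/lam_max - I fall in [-1, 1].
print(torch.linalg.eigvalsh(L_hat.to_dense()))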