示例#1
0
    def forward(self, inputs):
        """Run one layer with separate self and neighbour transforms.

        Inputs:
            inputs          4-tuple ``(feat_in, adj, is_normed, dropedge)``
                feat_in     2D matrix of input node features
                adj         subgraph adjacency. When ``is_normed`` is False and
                            ``adj`` is not None it must be a
                            ``scipy.sparse.csr_matrix``; it is converted to a
                            torch sparse tensor on ``feat_in.device`` and
                            normalized with ``adj_norm_rw``. Otherwise it is
                            used as-is and must be None, a ``torch.Tensor``
                            (or subclass) or a tuple.
                is_normed   whether ``adj`` is already normalized
                dropedge    edge-dropout rate forwarded to ``adj_norm_rw``
                            (only used when normalization happens here)

        Outputs:
            4-tuple ``(feat_out, adj_norm, True, 0.)``, kept in the input
            format to facilitate nn.Sequential:
                feat_out    2D matrix of output node features: the sum of the
                            self path and the neighbour path
                adj_norm    the (normalized) adjacency used by this layer
                True        marks ``adj_norm`` as already normalized
                0.          edge-dropout rate handed to the next layer
        """
        feat_in, adj, is_normed, dropedge = inputs
        if not is_normed and adj is not None:
            assert isinstance(adj, sp.csr_matrix)
            adj = coo_scipy2torch(adj.tocoo()).to(feat_in.device)
            adj_norm = adj_norm_rw(adj, dropedge=dropedge)
        else:
            assert adj is None or isinstance(adj, (torch.Tensor, tuple))
            adj_norm = adj
        # Pipeline: dropout -> linear -> act -> norm, done separately for the
        # self path (norm index 0) and the aggregated-neighbour path (index 1).
        feat_in = self.f_dropout(feat_in)
        feat_neigh = self._spmm(adj_norm, feat_in)
        feat_self_trans = self._f_norm(self.act(self.f_lin_self(feat_in)), 0)
        feat_neigh_trans = self._f_norm(self.act(self.f_lin_neigh(feat_neigh)),
                                        1)
        feat_out = feat_self_trans + feat_neigh_trans
        return feat_out, adj_norm, True, 0.
示例#2
0
 def forward(self, inputs):
     """Aggregate neighbour features, then apply f_lin -> act -> f_norm.

     ``inputs`` is the 4-tuple ``(feat_in, adj, is_normed, dropedge)``.
     Dropout is applied to ``feat_in`` first. If ``adj`` still needs
     normalizing (``is_normed`` is False and ``adj`` is not None), it must be a
     ``scipy.sparse.csr_matrix``. It is normalized with ``adj_norm_sym``
     (forwarding ``dropedge``) and converted to a torch sparse tensor on the
     features' device. Otherwise ``adj`` must be None or a ``torch.Tensor`` and
     is used unchanged.

     NOTE(review): the pass-through branch accepts ``adj=None``, but
     ``torch.sparse.mm(None, ...)`` below would then fail. Confirm that None
     never reaches this layer.

     Returns ``(feat_out, adj_norm, True, 0.)``. The next layer thus receives
     the normalized adjacency, flagged as normalized, with an edge-dropout rate
     of 0.
     """
     x, adj, already_normed, drop_rate = inputs
     x = self.f_dropout(x)
     if already_normed or adj is None:
         assert adj is None or type(adj) == torch.Tensor
         adj_norm = adj
     else:
         assert type(adj) == sp.csr_matrix
         # self-edges are already added by C++ sampler
         normed_csr = adj_norm_sym(adj, dropedge=drop_rate)
         adj_norm = coo_scipy2torch(normed_csr.tocoo()).to(x.device)
     hidden = self.f_lin(torch.sparse.mm(adj_norm, x))
     return self.f_norm(self.act(hidden)), adj_norm, True, 0.
示例#3
0
 def _adj_drop(self, adj, is_normed, device, dropedge=0):
     """
     Will perform edge dropout only when is_normed == False
     """
     if type(adj) == sp.csr_matrix:
         assert not is_normed
         adj_norm = coo_scipy2torch(adj.tocoo()).to(device)
         # here we don't normalize adj (data = 1,1,1...). In DGL, it is sym normed
         if dropedge > 0:
             masked_indices = torch.floor(
                 torch.rand(int(adj_norm._values().size()[0] * dropedge)) *
                 adj_norm._values().size()[0]).long()
             adj_norm._values()[masked_indices] = 0
     else:
         assert type(adj) == torch.Tensor and is_normed
         adj_norm = adj
     return adj_norm
示例#4
0
 def forward(self, inputs):
     """Aggregate neighbours, add a weighted self term, then mlp -> act -> f_norm.

     ``inputs`` is the 4-tuple ``(feat_in, adj, is_normed, dropedge)``, and
     ``is_normed`` must be False. Dropout is applied to ``feat_in`` first.

     If ``adj`` is a ``scipy.sparse.csr_matrix``, it is converted to a torch
     sparse tensor on the features' device. Then ``int(nnz * dropedge)``
     randomly drawn stored values are zeroed. Each row's remaining values are
     rescaled by ``deg_orig / max(deg_dropped, 1)`` (degrees from
     ``get_deg_torch_sparse``). Any other ``adj`` is used as-is.

     The layer output is
     ``f_norm(act(mlp(adj @ feat_in + (1 + self.eps) * feat_in)))``.

     Returns ``(feat_out, adj, False, 0.)``. ``adj`` is the torch tensor built
     here when the input was CSR. Another call of this method therefore skips
     the conversion and edge dropout for it.
     """
     h, adj, normed, drop_rate = inputs
     assert not normed
     h = self.f_dropout(h)
     if type(adj) == sp.csr_matrix:
         adj = coo_scipy2torch(adj.tocoo()).to(h.device)
         deg_before = get_deg_torch_sparse(adj)
         nnz = adj._values().size()[0]
         drop_idx = torch.floor(torch.rand(int(nnz * drop_rate)) * nnz).long()
         adj._values()[drop_idx] = 0
         # clamp keeps the division finite for rows whose edges were all dropped
         deg_after = torch.clamp(get_deg_torch_sparse(adj), min=1)
         # NOTE(review): repeating each row's factor deg_before.long() times
         # assumes that degree equals the row's stored-entry count (unit edge
         # weights) and that COO values are ordered by row. Confirm against
         # coo_scipy2torch / get_deg_torch_sparse.
         row_scale = torch.repeat_interleave(deg_before / deg_after,
                                             deg_before.long())
         adj._values()[:] = adj._values() * row_scale
     agg = torch.sparse.mm(adj, h)
     agg += (1 + self.eps) * h
     out = self.f_norm(self.act(self.mlp(agg)))
     return out, adj, False, 0.
示例#5
0
 def _smooth_signals_subg(self,
                          adj,
                          signal,
                          target,
                          order: int,
                          pbar,
                          type_norm: str,
                          reduction_orders: str,
                          args: dict,
                          add_self_edge: bool = True,
                          is_normed: bool = False):
     if not is_normed:
         if type_norm == 'sym':
             adj_norm = adj_norm_sym(adj, add_self_edge=add_self_edge)
         elif type_norm == 'rw':  # NOTE: we haven't supported add_self_edge for rw norm yet.
             adj_norm = adj_norm_rw(adj)
         elif type_norm == 'ppr':
             assert order == 1
             adj_norm = adj_norm_sym(adj, add_self_edge=True)  # see APPNP
         else:
             raise NotImplementedError
         adj_norm = coo_scipy2torch(adj_norm.tocoo()).to(signal.device)
     else:
         adj_norm = adj
     if type_norm == 'ppr':
         signal_converged = self._ppr(adj_norm, signal, pbar, target.size,
                                      **args)[target]
         signal_orig = signal[target]
         if reduction_orders in ['cat', 'concat']:
             signal_out = torch.cat([signal_orig, signal_converged],
                                    dim=1).to(signal.device)
         elif reduction_orders == 'sum':
             signal_out = signal_orig + signal_converged
         elif reduction_orders == 'last':
             signal_out = signal_converged
     elif type_norm in ['sym', 'rw']:
         signal_order = signal
         if reduction_orders in ['cat', 'concat']:
             F = signal_order.shape[1]
             F_new = F + order * F
             signal_out = torch.zeros(target.size,
                                      F_new).to(signal_order.device)
             signal_out[:, :F] = signal_order[target]
             for _k in range(order):
                 signal_order = torch.sparse.mm(adj_norm, signal_order)
                 signal_out[:, (_k + 1) * F:(_k + 2) *
                            F] = signal_order[target]
                 if pbar is not None:
                     pbar.update(target.size)
         elif reduction_orders == 'sum':
             F_new = signal_order.shape[1]
             signal_out = signal_order[target].copy()
             for _k in range(order):
                 signal_order = torch.sparse.mm(adj_norm, signal_order)
                 signal_out += signal_order[target]
                 if pbar is not None:
                     pbar.update(target.size)
         elif reduction_orders == 'last':
             for _k in range(order):
                 signal_order = torch.sparse.mm(adj_norm, signal_order)
                 if pbar is not None:
                     pbar.update(target.size)
             signal_out = signal_order[target]
     return signal_out, adj_norm