def forward(self, inputs):
    """
    Apply this layer to one ``inputs`` tuple and return a tuple of the same
    layout (so that layers can be chained, e.g. in nn.Sequential).

    Inputs (4-tuple):
        feat_in     2D matrix of input node features
        adj         adjacency of the subgraph. When ``is_normed`` is False it
                    must be a scipy csr_matrix; otherwise it is used as-is and
                    may be None, a torch.Tensor or a tuple
        is_normed   True if ``adj`` needs no conversion / normalization here
        dropedge    forwarded to ``adj_norm_rw`` when normalization is done here
    Outputs (4-tuple):
        feat_out    2D matrix of output node features
        adj_norm    the adjacency that was actually used for aggregation
        True        marks the adjacency as already normalized
        0.          dropedge value passed on
    """
    feat_in, adj, is_normed, dropedge = inputs

    if is_normed or adj is None:
        # Nothing to normalize: pass the adjacency through untouched.
        assert adj is None or type(adj) in (torch.Tensor, tuple)
        adj_norm = adj
    else:
        assert type(adj) == sp.csr_matrix
        adj_torch = coo_scipy2torch(adj.tocoo()).to(feat_in.device)
        adj_norm = adj_norm_rw(adj_torch, dropedge=dropedge)

    # dropout-act-norm
    feat_in = self.f_dropout(feat_in)
    feat_neigh = self._spmm(adj_norm, feat_in)

    # Self and neighbor branches use separate linear maps and norm slots 0 / 1.
    self_branch = self._f_norm(self.act(self.f_lin_self(feat_in)), 0)
    neigh_branch = self._f_norm(self.act(self.f_lin_neigh(feat_neigh)), 1)

    return self_branch + neigh_branch, adj_norm, True, 0.
def forward(self, inputs):
    """
    Aggregate features over the adjacency, then apply linear -> act -> norm.

    Inputs (4-tuple):
        feat_in     2D matrix of input node features
        adj         scipy csr_matrix when ``is_normed`` is False; otherwise
                    None or a torch.Tensor that is used as-is
        is_normed   True if ``adj`` needs no normalization / conversion here
        dropedge    forwarded to ``adj_norm_sym`` when normalization is done here
    Outputs (4-tuple):
        feat_out    2D matrix of output node features
        adj_norm    the adjacency that was actually used for aggregation
        True        marks the adjacency as already normalized
        0.          dropedge value passed on
    """
    feat_in, adj, is_normed, dropedge = inputs
    feat_in = self.f_dropout(feat_in)

    if is_normed or adj is None:
        assert adj is None or type(adj) == torch.Tensor
        adj_norm = adj
    else:
        assert type(adj) == sp.csr_matrix
        # self-edges are already added by C++ sampler
        adj_scipy = adj_norm_sym(adj, dropedge=dropedge)
        adj_norm = coo_scipy2torch(adj_scipy.tocoo()).to(feat_in.device)

    aggregated = torch.sparse.mm(adj_norm, feat_in)
    activated = self.act(self.f_lin(aggregated))
    return self.f_norm(activated), adj_norm, True, 0.
def _adj_drop(self, adj, is_normed, device, dropedge=0):
    """
    Return a torch adjacency for ``adj``.

    A torch.Tensor (which must come with ``is_normed`` True) is returned
    unchanged. A scipy csr_matrix (which must come with ``is_normed`` False)
    is converted to a torch sparse tensor on ``device``; edge dropout is
    applied only in this case, and only when ``dropedge > 0``.
    """
    if type(adj) != sp.csr_matrix:
        assert type(adj) == torch.Tensor and is_normed
        return adj

    assert not is_normed
    # here we don't normalize adj (data = 1,1,1...). In DGL, it is sym normed
    adj_norm = coo_scipy2torch(adj.tocoo()).to(device)
    if dropedge > 0:
        num_vals = adj_norm._values().size()[0]
        # Random positions are drawn independently, so repeats are possible.
        drop_idx = torch.floor(
            torch.rand(int(num_vals * dropedge)) * num_vals).long()
        adj_norm._values()[drop_idx] = 0
    return adj_norm
def forward(self, inputs):
    """
    Aggregate over a randomly thinned adjacency, add the eps-weighted input
    features, and apply mlp -> act -> norm.

    Inputs (4-tuple):
        feat_in     2D matrix of input node features
        adj         adjacency; a scipy csr_matrix is converted to a torch
                    sparse tensor on the device of ``feat_in``
        is_normed   must be False (asserted)
        dropedge    fraction used to size the set of randomly zeroed entries
    Outputs (4-tuple):
        feat_out    2D matrix of output node features
        adj         the torch adjacency, with its values modified in place
        False       is_normed flag passed on
        0.          dropedge value passed on
    """
    feat_in, adj, is_normed, dropedge = inputs
    assert not is_normed
    feat_in = self.f_dropout(feat_in)
    if type(adj) == sp.csr_matrix:
        adj = coo_scipy2torch(adj.tocoo()).to(feat_in.device)

    # Degrees must be read before any entry is zeroed.
    deg_before = get_deg_torch_sparse(adj)

    num_vals = adj._values().size()[0]
    drop_idx = torch.floor(
        torch.rand(int(num_vals * dropedge)) * num_vals).long()
    adj._values()[drop_idx] = 0

    # Scale the stored values by (degree before) / (degree after, at least 1).
    # NOTE(review): repeat_interleave expands one factor per stored entry of
    # each row, so this presumably relies on values being grouped by row in
    # row order -- confirm against coo_scipy2torch / get_deg_torch_sparse.
    deg_after = get_deg_torch_sparse(adj).clamp(min=1)
    per_entry_scale = torch.repeat_interleave(deg_before / deg_after,
                                              deg_before.long())
    adj._values()[:] = adj._values() * per_entry_scale

    feat_aggr = torch.sparse.mm(adj, feat_in)
    feat_aggr += (1 + self.eps) * feat_in
    hidden = self.mlp(feat_aggr)
    return self.f_norm(self.act(hidden)), adj, False, 0.
def _smooth_signals_subg(self,
                         adj,
                         signal,
                         target,
                         order: int,
                         pbar,
                         type_norm: str,
                         reduction_orders: str,
                         args: dict,
                         add_self_edge: bool = True,
                         is_normed: bool = False):
    """
    Propagate ``signal`` over the adjacency and gather the rows in ``target``.

    Inputs:
        adj                 adjacency. When ``is_normed`` is False it is
                            normalized according to ``type_norm`` and converted
                            to a torch sparse tensor on the device of
                            ``signal``; otherwise it is used as-is
        signal              2D torch tensor of node signals
        target              indices of the rows to return; must expose ``.size``
                            (presumably a numpy array -- confirm with callers)
        order               number of propagation steps for 'sym' / 'rw'
                            (asserted to be 1 for 'ppr' when normalizing here)
        pbar                optional progress bar; ``update(target.size)`` is
                            called once per propagation step for 'sym' / 'rw',
                            and it is handed to ``self._ppr`` for 'ppr'
        type_norm           'sym', 'rw' or 'ppr'
        reduction_orders    'cat' / 'concat': concatenate the signals of all
                            orders along dim 1; 'sum': add them up;
                            'last': keep only the final one
        args                extra keyword arguments for ``self._ppr``
        add_self_edge       forwarded to ``adj_norm_sym`` for 'sym'
        is_normed           True if ``adj`` is already a normalized torch tensor
    Outputs:
        signal_out          reduced signal for the ``target`` rows
        adj_norm            the normalized adjacency that was used
    Raises:
        NotImplementedError for an unsupported ``type_norm`` or
                            ``reduction_orders``
    """
    # Fail fast (before any normalization / propagation work) instead of
    # hitting an unbound ``signal_out`` at the end.
    if type_norm not in ('sym', 'rw', 'ppr'):
        raise NotImplementedError(f"Unsupported type_norm: {type_norm!r}")
    if reduction_orders not in ('cat', 'concat', 'sum', 'last'):
        raise NotImplementedError(
            f"Unsupported reduction_orders: {reduction_orders!r}")

    if not is_normed:
        if type_norm == 'sym':
            adj_norm = adj_norm_sym(adj, add_self_edge=add_self_edge)
        elif type_norm == 'rw':
            # NOTE: we haven't supported add_self_edge for rw norm yet.
            adj_norm = adj_norm_rw(adj)
        else:  # 'ppr'
            assert order == 1
            adj_norm = adj_norm_sym(adj, add_self_edge=True)  # see APPNP
        adj_norm = coo_scipy2torch(adj_norm.tocoo()).to(signal.device)
    else:
        adj_norm = adj

    concat = reduction_orders in ('cat', 'concat')

    if type_norm == 'ppr':
        signal_converged = self._ppr(adj_norm, signal, pbar, target.size,
                                     **args)[target]
        signal_orig = signal[target]
        if concat:
            signal_out = torch.cat([signal_orig, signal_converged],
                                   dim=1).to(signal.device)
        elif reduction_orders == 'sum':
            signal_out = signal_orig + signal_converged
        else:  # 'last'
            signal_out = signal_converged
        return signal_out, adj_norm

    # type_norm in ('sym', 'rw'): explicit multi-hop propagation.
    signal_order = signal
    if concat:
        F = signal_order.shape[1]
        signal_out = torch.zeros(target.size,
                                 F + order * F).to(signal_order.device)
        signal_out[:, :F] = signal_order[target]
    elif reduction_orders == 'sum':
        # torch tensors have no ``.copy()``; ``.clone()`` gives an independent
        # accumulator for the in-place additions below.
        signal_out = signal_order[target].clone()

    for _k in range(order):
        signal_order = torch.sparse.mm(adj_norm, signal_order)
        if concat:
            signal_out[:, (_k + 1) * F:(_k + 2) * F] = signal_order[target]
        elif reduction_orders == 'sum':
            signal_out += signal_order[target]
        if pbar is not None:
            pbar.update(target.size)

    if reduction_orders == 'last':
        signal_out = signal_order[target]
    return signal_out, adj_norm