import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class posLinear(nn.Module):
    '''Linear layer with positive weights.

    Weights can take any value, but only impact the model if they are
    greater than a minimum value.
    '''

    def __init__(self, in_features, out_features, minval=0.0, hard_constraint=False):
        super(posLinear, self).__init__()
        self.hard_constraint = hard_constraint
        self.weight = Parameter(torch.ones(in_features, out_features))
        self.bias = Parameter(torch.zeros(out_features))
        if self.hard_constraint:
            # weightConstraint is assumed to be defined elsewhere in this repo;
            # applied via self.apply(), it clamps the weights to >= minval in place.
            self.posConstraint = weightConstraint(minval)
        else:
            self.register_buffer("minval", torch.tensor(minval))

    def forward(self, x):
        if self.hard_constraint:
            self.apply(self.posConstraint)
            x = x @ self.weight + self.bias
        else:
            # soft constraint: weights below minval are replaced by minval in the forward pass
            x = x @ self.weight.max(self.minval) + self.bias
        return x
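# A minimal usage sketch for posLinear (shapes and values are illustrative, not from
# the original file). With the default soft constraint, every weight keeps receiving
# gradients, but values below `minval` are replaced by `minval` in the forward pass,
# so only weights above the threshold influence the output.
layer = posLinear(in_features=4, out_features=2, minval=0.0, hard_constraint=False)
x = torch.randn(8, 4)   # a batch of 8 samples with 4 features each
y = layer(x)            # shape: (8, 2)
print(y.shape)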
# Imports required by the attack classes below (assuming the DeepRobust package layout).
# calc_acc is assumed to be a small helper (loss/accuracy over an index set) defined
# elsewhere in this repo.
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from tqdm import tqdm

from deeprobust.graph import utils
from deeprobust.graph.global_attack import BaseAttack


class SpectralAttack(BaseAttack):
    """PGD-style spectral attack for graph data: the classification loss is
    optionally regularized by the spectral distance between the original and the
    perturbed (normalized) adjacency matrices, weighted by `regularization_weight`.
    """

    def __init__(self, model=None, nnodes=None, loss_type='CE', feature_shape=None,
                 attack_structure=True, attack_features=False,
                 regularization_weight=0.0, device='cpu'):
        super(SpectralAttack, self).__init__(model, nnodes, attack_structure,
                                             attack_features, device)

        assert attack_structure or attack_features, \
            'attack_features and attack_structure cannot both be False'

        self.loss_type = loss_type
        self.modified_adj = None
        self.modified_features = None
        self.regularization_weight = regularization_weight

        if attack_features:
            raise NotImplementedError('Current Spectral Attack does not support attacking features')

        if attack_structure:
            assert nnodes is not None, 'Please give nnodes='
            self.adj_changes = Parameter(torch.FloatTensor(int(nnodes * (nnodes - 1) / 2)))
            self.adj_changes.data.fill_(0)

        self.complementary = None

    def attack(self, ori_features, ori_adj, labels, idx_target, idx_train, idx_test,
               n_perturbations, att_lr, epochs=200, verbose=False, reduction='mean', **kwargs):
        """Generate perturbations on the input graph."""
        victim_model = self.surrogate
        self.sparse_features = sp.issparse(ori_features)
        # ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device)

        ori_adj_norm = utils.normalize_adj_tensor(ori_adj, device=self.device)
        ori_e, ori_v = torch.symeig(ori_adj_norm, eigenvectors=True)

        victim_model.eval()
        for t in tqdm(range(epochs), desc='Perturb Adj'):
            modified_adj = self.get_modified_adj(ori_adj)
            adj_norm = utils.normalize_adj_tensor(modified_adj, device=self.device)
            # forward of GCN needs the normalized adjacency
            output = victim_model(ori_features, adj_norm)
            self.loss = self._loss(output[idx_target], labels[idx_target])

            # New: add regularization term for spectral distance
            eigen_mse = 0
            reg_loss = 0
            eigen_norm = self.norm = torch.norm(ori_e)
            if self.regularization_weight != 0:
                e, v = torch.symeig(adj_norm, eigenvectors=True)
                # eigen_mse = F.mse_loss(ori_e, e, reduction=reduction)
                eigen_mse = torch.norm(ori_e - e)
                reg_loss = eigen_mse / eigen_norm * self.regularization_weight

            if verbose and t % 20 == 0:
                loss_target, acc_target = calc_acc(output, labels, idx_target)
                print('-- Epoch {}, '.format(t),
                      'class loss = {:.4f} | '.format(self.loss.item()),
                      'reg loss = {:.8f} | '.format(reg_loss),
                      'eigen_mse = {:.8f} | '.format(eigen_mse),
                      'eigen_norm = {:.4f} | '.format(eigen_norm),
                      'acc = {}'.format(acc_target))

            self.loss += reg_loss
            adj_grad = torch.autograd.grad(self.loss, self.adj_changes)[0]

            if self.loss_type == 'CE':
                # lr = 200 / np.sqrt(t+1)
                lr = att_lr / np.sqrt(t + 1)
                self.adj_changes.data.add_(lr * adj_grad)
            if self.loss_type == 'CW':
                # lr = 0.1 / np.sqrt(t+1)
                lr = att_lr / np.sqrt(t + 1)
                self.adj_changes.data.add_(lr * adj_grad)

            self.projection(n_perturbations)

        self.random_sample(ori_adj, ori_features, labels, idx_target, n_perturbations)
        self.modified_adj = self.get_modified_adj(ori_adj).detach()
        self.check_adj_tensor(self.modified_adj)

        # for sanity check
        ori_adj_norm = utils.normalize_adj_tensor(ori_adj, device=self.device)
        ori_e, ori_v = torch.symeig(ori_adj_norm, eigenvectors=True)
        adj_norm = utils.normalize_adj_tensor(self.modified_adj, device=self.device)
        e, v = torch.symeig(adj_norm, eigenvectors=True)

        self.adj = ori_adj.detach()
        self.labels = labels.detach()
        self.ori_e = ori_e
        self.ori_v = ori_v
        self.e = e
        self.v = v

    def random_sample(self, ori_adj, ori_features, labels, idx_target, n_perturbations):
        K = 10
        best_loss = -1000
        victim_model = self.surrogate
        with torch.no_grad():
            s = self.adj_changes.cpu().detach().numpy()
            for i in range(K):
                sampled = np.random.binomial(1, s)
                # randm = np.random.uniform(size=s.shape[0])
                # sampled = np.where(s > randm, 1, 0)
                # resample until the perturbation budget is respected
                while sampled.sum() > n_perturbations:
                    sampled = np.random.binomial(1, s)

                self.adj_changes.data.copy_(torch.tensor(sampled))
                modified_adj = self.get_modified_adj(ori_adj)
                adj_norm = utils.normalize_adj_tensor(modified_adj, device=self.device)
                output = victim_model(ori_features, adj_norm)
                loss = self._loss(output[idx_target], labels[idx_target])
                # loss = F.nll_loss(output[idx_target], labels[idx_target])
                if best_loss < loss:
                    best_loss = loss
                    best_s = sampled
            self.adj_changes.data.copy_(torch.tensor(best_s))

    def get_modified_adj(self, ori_adj):
        if self.complementary is None:
            self.complementary = (torch.ones_like(ori_adj)
                                  - torch.eye(self.nnodes).to(self.device)
                                  - ori_adj) - ori_adj

        m = torch.zeros((self.nnodes, self.nnodes)).to(self.device)
        tril_indices = torch.tril_indices(row=self.nnodes, col=self.nnodes, offset=-1)
        m[tril_indices[0], tril_indices[1]] = self.adj_changes
        m = m + m.t()
        modified_adj = self.complementary * m + ori_adj
        return modified_adj

    def projection(self, n_perturbations):
        if torch.clamp(self.adj_changes, 0, 1).sum() > n_perturbations:
            left = (self.adj_changes - 1).min()
            right = self.adj_changes.max()
            miu = self.bisection(left, right, n_perturbations, epsilon=1e-5)
            self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data - miu, min=0, max=1))
        else:
            self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data, min=0, max=1))

    def _loss(self, output, labels):
        if self.loss_type == "CE":
            loss = F.nll_loss(output, labels)
        if self.loss_type == "CW":
            onehot = utils.tensor2onehot(labels)
            best_second_class = (output - 1000 * onehot).argmax(1).detach()
            margin = output[np.arange(len(output)), labels] - \
                     output[np.arange(len(output)), best_second_class]
            k = 0
            loss = -torch.clamp(margin, min=k).mean()
            # loss = torch.clamp(margin.sum()+50, min=k)
        return loss

    def bisection(self, a, b, n_perturbations, epsilon):
        def func(x):
            return torch.clamp(self.adj_changes - x, 0, 1).sum() - n_perturbations

        miu = a
        while (b - a) >= epsilon:
            miu = (a + b) / 2
            # Check if the middle point is a root
            if func(miu) == 0.0:
                break
            # Decide the side to repeat the steps
            if func(miu) * func(a) < 0:
                b = miu
            else:
                a = miu
        # print("The value of root is : ", "%.4f" % miu)
        return miu
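# A minimal usage sketch for SpectralAttack, adapted from the PGDAttack docstring
# example below. The dataset/victim-model setup relies on the DeepRobust APIs; the
# choices idx_target=idx_train, att_lr=200 and regularization_weight=1.0 are
# illustrative assumptions, not values taken from the original file.
def example_spectral_attack():
    from deeprobust.graph.data import Dataset
    from deeprobust.graph.defense import GCN
    from deeprobust.graph.utils import preprocess

    data = Dataset(root='/tmp/', name='cora')
    adj, features, labels = data.adj, data.features, data.labels
    adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)  # convert to tensors
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

    # victim model
    victim_model = GCN(nfeat=features.shape[1], nclass=labels.max().item() + 1, nhid=16,
                       dropout=0.5, weight_decay=5e-4, device='cpu').to('cpu')
    victim_model.fit(features, adj, labels, idx_train)

    # attack model
    attack = SpectralAttack(model=victim_model, nnodes=adj.shape[0], loss_type='CE',
                            regularization_weight=1.0, device='cpu')
    attack.attack(features, adj, labels, idx_train, idx_train, idx_test,
                  n_perturbations=10, att_lr=200, epochs=200)
    return attack.modified_adj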
class PGDAttack(BaseAttack):
    """PGD attack for graph data.

    Parameters
    ----------
    model :
        model to attack. Default: `None`.
    nnodes : int
        number of nodes in the input graph
    loss_type : str
        attack loss type, chosen from ['CE', 'CW']
    feature_shape : tuple
        shape of the input node features
    attack_structure : bool
        whether to attack graph structure
    attack_features : bool
        whether to attack node features
    device : str
        'cpu' or 'cuda'

    Examples
    --------
    >>> from deeprobust.graph.data import Dataset
    >>> from deeprobust.graph.defense import GCN
    >>> from deeprobust.graph.global_attack import PGDAttack
    >>> from deeprobust.graph.utils import preprocess
    >>> data = Dataset(root='/tmp/', name='cora')
    >>> adj, features, labels = data.adj, data.features, data.labels
    >>> adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False) # convert to tensor
    >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    >>> # Setup Victim Model
    >>> victim_model = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1, nhid=16, dropout=0.5, weight_decay=5e-4, device='cpu').to('cpu')
    >>> victim_model.fit(features, adj, labels, idx_train)
    >>> # Setup Attack Model
    >>> model = PGDAttack(model=victim_model, nnodes=adj.shape[0], loss_type='CE', device='cpu').to('cpu')
    >>> model.attack(features, adj, labels, idx_train, n_perturbations=10)
    >>> modified_adj = model.modified_adj
    """

    def __init__(self, model=None, nnodes=None, loss_type='CE', feature_shape=None,
                 attack_structure=True, attack_features=False, device='cpu'):
        super(PGDAttack, self).__init__(model, nnodes, attack_structure, attack_features, device)

        assert attack_features or attack_structure, \
            'attack_features and attack_structure cannot both be False'

        self.loss_type = loss_type
        self.modified_adj = None
        self.modified_features = None

        if attack_structure:
            assert nnodes is not None, 'Please give nnodes='
            self.adj_changes = Parameter(torch.FloatTensor(int(nnodes*(nnodes-1)/2)))
            self.adj_changes.data.fill_(0)

        if attack_features:
            raise NotImplementedError('Topology Attack does not support attacking features')

        self.complementary = None

    def attack(self, ori_features, ori_adj, labels, idx_train, n_perturbations, epochs=200, **kwargs):
        """Generate perturbations on the input graph.

        Parameters
        ----------
        ori_features :
            Original (unperturbed) node feature matrix
        ori_adj :
            Original (unperturbed) adjacency matrix
        labels :
            node labels
        idx_train :
            node training indices
        n_perturbations : int
            Number of perturbations on the input graph. Perturbations could
            be edge removals/additions or feature removals/additions.
        epochs :
            number of training epochs
        """
        victim_model = self.surrogate

        self.sparse_features = sp.issparse(ori_features)
        ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device)

        victim_model.eval()
        for t in tqdm(range(epochs)):
            modified_adj = self.get_modified_adj(ori_adj)
            adj_norm = utils.normalize_adj_tensor(modified_adj)
            output = victim_model(ori_features, adj_norm)
            # loss = F.nll_loss(output[idx_train], labels[idx_train])
            loss = self._loss(output[idx_train], labels[idx_train])
            adj_grad = torch.autograd.grad(loss, self.adj_changes)[0]

            if self.loss_type == 'CE':
                lr = 200 / np.sqrt(t+1)
                self.adj_changes.data.add_(lr * adj_grad)

            if self.loss_type == 'CW':
                lr = 0.1 / np.sqrt(t+1)
                self.adj_changes.data.add_(lr * adj_grad)

            self.projection(n_perturbations)

        self.random_sample(ori_adj, ori_features, labels, idx_train, n_perturbations)
        self.modified_adj = self.get_modified_adj(ori_adj).detach()

    def random_sample(self, ori_adj, ori_features, labels, idx_train, n_perturbations):
        K = 20
        best_loss = -1000
        victim_model = self.surrogate
        with torch.no_grad():
            s = self.adj_changes.cpu().detach().numpy()
            for i in range(K):
                sampled = np.random.binomial(1, s)
                print(sampled.sum())
                if sampled.sum() > n_perturbations:
                    continue
                self.adj_changes.data.copy_(torch.tensor(sampled))
                modified_adj = self.get_modified_adj(ori_adj)
                adj_norm = utils.normalize_adj_tensor(modified_adj)
                output = victim_model(ori_features, adj_norm)
                loss = self._loss(output[idx_train], labels[idx_train])
                # loss = F.nll_loss(output[idx_train], labels[idx_train])
                print(loss)
                if best_loss < loss:
                    best_loss = loss
                    best_s = sampled
            self.adj_changes.data.copy_(torch.tensor(best_s))

    def _loss(self, output, labels):
        if self.loss_type == "CE":
            loss = F.nll_loss(output, labels)
        if self.loss_type == "CW":
            onehot = utils.tensor2onehot(labels)
            best_second_class = (output - 1000*onehot).argmax(1)
            margin = output[np.arange(len(output)), labels] - \
                     output[np.arange(len(output)), best_second_class]
            k = 0
            loss = -torch.clamp(margin, min=k).mean()
            # loss = torch.clamp(margin.sum()+50, min=k)
        return loss

    def projection(self, n_perturbations):
        # projected = torch.clamp(self.adj_changes, 0, 1)
        if torch.clamp(self.adj_changes, 0, 1).sum() > n_perturbations:
            left = (self.adj_changes - 1).min()
            right = self.adj_changes.max()
            miu = self.bisection(left, right, n_perturbations, epsilon=1e-5)
            self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data - miu, min=0, max=1))
        else:
            self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data, min=0, max=1))

    def get_modified_adj(self, ori_adj):
        if self.complementary is None:
            self.complementary = (torch.ones_like(ori_adj) - torch.eye(self.nnodes).to(self.device) - ori_adj) - ori_adj

        m = torch.zeros((self.nnodes, self.nnodes)).to(self.device)
        # strictly lower-triangular indices: one adj_changes entry per node pair
        tril_indices = torch.tril_indices(row=self.nnodes, col=self.nnodes, offset=-1)
        m[tril_indices[0], tril_indices[1]] = self.adj_changes
        m = m + m.t()
        modified_adj = self.complementary * m + ori_adj
        return modified_adj

    def bisection(self, a, b, n_perturbations, epsilon):
        def func(x):
            return torch.clamp(self.adj_changes-x, 0, 1).sum() - n_perturbations

        miu = a
        while (b-a) >= epsilon:
            miu = (a+b)/2
            # Check if middle point is root
            if func(miu) == 0.0:
                break
            # Decide the side to repeat the steps
            if func(miu)*func(a) < 0:
                b = miu
            else:
                a = miu
        # print("The value of root is : ", "%.4f" % miu)
        return miu
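# Sketch of what projection()/bisection() compute, on a toy vector (illustrative only,
# not part of the original file). The projection step finds a shift mu such that
# clamping (adj_changes - mu) into [0, 1] spends exactly the perturbation budget:
# sum(clamp(s - mu, 0, 1)) == n_perturbations.
def projection_demo():
    import torch

    s = torch.tensor([0.9, 0.7, 0.4, 0.2, 0.05])   # relaxed edge-change scores
    budget = 2                                      # n_perturbations

    def func(mu):
        return torch.clamp(s - mu, 0, 1).sum() - budget

    lo, hi = (s - 1).min(), s.max()
    mu = lo
    while (hi - lo) >= 1e-5:
        mu = (lo + hi) / 2
        if func(mu) * func(lo) < 0:
            hi = mu
        else:
            lo = mu

    projected = torch.clamp(s - mu, 0, 1)
    print(mu, projected, projected.sum())   # the sum is (approximately) the budget
    return projected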
class PGDAttack(BaseAttack):
    """PGD attack for graph data with spectral-distance regularization
    (spectral attack variant)."""

    def __init__(self, model=None, nnodes=None, loss_type='CE', feature_shape=None,
                 attack_structure=True, attack_features=False, loss_weight=1.0,
                 regularization_weight=0.0, device='cpu'):
        super(PGDAttack, self).__init__(model, nnodes, attack_structure, attack_features, device)

        assert attack_structure or attack_features, \
            'attack_features and attack_structure cannot both be False'

        self.loss_type = loss_type
        self.modified_adj = None
        self.modified_features = None
        self.loss_weight = loss_weight
        self.regularization_weight = regularization_weight

        if attack_features:
            raise NotImplementedError('Current Spectral Attack does not support attacking features')

        if attack_structure:
            assert nnodes is not None, 'Please give nnodes='
            self.adj_changes = Parameter(torch.FloatTensor(int(nnodes * (nnodes - 1) / 2)))
            torch.nn.init.uniform_(self.adj_changes, 0.0, 0.001)
            # self.adj_changes.data.fill_(0)

        self.complementary = None

    def set_model(self, model):
        self.surrogate = model

    def attack(self, ori_features, ori_adj, labels, idx_target, n_perturbations, att_lr,
               epochs=200, distance_type='l2', sample_type='sample', opt_type='max',
               verbose=True, **kwargs):
        """Generate perturbations on the input graph."""
        victim_model = self.surrogate
        self.sparse_features = sp.issparse(ori_features)
        # ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device)

        ori_adj_norm = utils.normalize_adj_tensor(ori_adj, device=self.device)
        ori_e, ori_v = torch.symeig(ori_adj_norm, eigenvectors=True)

        l, r, m = 0, 0, 0
        victim_model.eval()
        # for t in tqdm(range(epochs), desc='Perturb Adj'):
        for t in tqdm(range(epochs)):
            modified_adj = self.get_modified_adj(ori_adj)
            adj_norm = utils.normalize_adj_tensor(modified_adj, device=self.device)
            # forward of GCN needs the normalized adjacency
            output = victim_model(ori_features, adj_norm)
            task_loss = self._loss(output[idx_target], labels[idx_target])

            # spectral distance term
            eigen_mse = torch.tensor(0)
            eigen_self = torch.tensor(0)
            eigen_gf = torch.tensor(0)
            eigen_norm = self.norm = torch.norm(ori_e)
            if self.regularization_weight != 0:
                # add noise to make the graph asymmetric
                modified_adj_noise = modified_adj
                # modified_adj_noise = self.add_random_noise(modified_adj)
                adj_norm_noise = utils.normalize_adj_tensor(modified_adj_noise, device=self.device)
                e, v = torch.symeig(adj_norm_noise, eigenvectors=True)
                eigen_mse = torch.norm(ori_e - e)
                eigen_self = torch.norm(e)

                # low-rank loss in GF-attack
                idx = torch.argsort(e)[:128]
                mask = torch.zeros_like(e).bool()
                mask[idx] = True
                eigen_gf = torch.pow(torch.norm(e * mask, p=2), 2) \
                    * torch.pow(torch.norm(torch.matmul(v.detach() * mask, ori_features), p=2), 2)

            reg_loss = 0
            if distance_type == 'l2':
                reg_loss = eigen_mse / eigen_norm
            elif distance_type == 'normDiv':
                reg_loss = eigen_self / eigen_norm
            elif distance_type == 'gf':
                reg_loss = eigen_gf
            else:
                exit(f'unknown distance metric: {distance_type}')

            if verbose and t % 20 == 0:
                loss_target, acc_target = calc_acc(output, labels, idx_target)
                print('-- Epoch {}, '.format(t),
                      'ptb budget/true = {:.1f}/{:.1f}'.format(
                          n_perturbations, torch.clamp(self.adj_changes, 0, 1).sum()),
                      'l/r/m = {:.4f}/{:.4f}/{:.4f}'.format(l, r, m),
                      'class loss = {:.4f} | '.format(task_loss.item()),
                      'reg loss = {:.4f} | '.format(reg_loss.item()),
                      'mse_norm = {:.4f} | '.format(eigen_norm),
                      'eigen_mse = {:.4f} | '.format(eigen_mse),
                      'eigen_self = {:.4f} | '.format(eigen_self),
                      'acc/mis = {:.4f}/{:.4f}'.format(acc_target, 1 - acc_target))

            self.loss = self.loss_weight * task_loss + self.regularization_weight * reg_loss
            adj_grad = torch.autograd.grad(self.loss, self.adj_changes)[0]

            if self.loss_type == 'CE':
                lr = att_lr / np.sqrt(t + 1)
                self.adj_changes.data.add_(lr * adj_grad)
            if self.loss_type == 'CW':
                lr = att_lr / np.sqrt(t + 1)
                self.adj_changes.data.add_(lr * adj_grad)
            # return self.adj_changes.cpu().detach().numpy()

            if verbose and t % 20 == 0:
                print('budget/true={:.1f}/{:.1f}'.format(
                    n_perturbations, torch.clamp(self.adj_changes, 0, 1).sum()))

            if sample_type == 'sample':
                l, r, m = self.projection(n_perturbations)
            elif sample_type == 'greedy':
                self.greedy(n_perturbations)
            elif sample_type == 'greedy2':
                self.greedy2(n_perturbations)
            elif sample_type == 'greedy3':
                self.greedy3(n_perturbations)
            else:
                exit(f"unknown sample type {sample_type}")

            if verbose and t % 20 == 0:
                print('budget/true={:.1f}/{:.1f}'.format(
                    n_perturbations, torch.clamp(self.adj_changes, 0, 1).sum()))

        if sample_type == 'sample':
            self.random_sample(ori_adj, ori_features, labels, idx_target, n_perturbations)
        elif sample_type == 'greedy':
            self.greedy(n_perturbations)
        elif sample_type == 'greedy2':
            self.greedy2(n_perturbations)
        elif sample_type == 'greedy3':
            self.greedy3(n_perturbations)
        else:
            exit(f"unknown sample type {sample_type}")

        print("final ptb budget/true= {:.1f}/{:.1f}".format(n_perturbations, self.adj_changes.sum()))

        self.modified_adj = self.get_modified_adj(ori_adj).detach()
        self.check_adj_tensor(self.modified_adj)

        # for sanity check
        ori_adj_norm = utils.normalize_adj_tensor(ori_adj, device=self.device)
        ori_e, ori_v = torch.symeig(ori_adj_norm, eigenvectors=True)
        adj_norm = utils.normalize_adj_tensor(self.modified_adj, device=self.device)
        e, v = torch.symeig(adj_norm, eigenvectors=True)

        self.adj = ori_adj.detach()
        self.labels = labels.detach()
        self.ori_e = ori_e
        self.ori_v = ori_v
        self.e = e
        self.v = v

    def greedy(self, n_perturbations):
        s = self.adj_changes.cpu().detach().numpy()
        # l = min(s)
        # r = max(s)
        # noise = np.random.normal((l+r)/2, 0.1*(r-l), s.shape)
        # s += noise
        s_vec = np.squeeze(np.reshape(s, (1, -1)))
        # max_index = (-np.absolute(s_vec)).argsort()[:n_perturbations]
        max_index = (-s_vec).argsort()[:n_perturbations]

        mask = np.zeros_like(s_vec)
        mask[max_index] = 1.0
        best_s = np.reshape(mask, s.shape)
        self.adj_changes.data.copy_(torch.clamp(torch.tensor(best_s), min=0, max=1))

    def greedy3(self, n_perturbations):
        s = self.adj_changes.cpu().detach().numpy()
        s_vec = np.squeeze(np.reshape(s, (1, -1)))
        # max_index = (-np.absolute(s_vec)).argsort()[:n_perturbations]
        max_index = (s_vec).argsort()[:n_perturbations]

        mask = np.zeros_like(s_vec)
        mask[max_index] = 1.0
        best_s = np.reshape(mask, s.shape)
        self.adj_changes.data.copy_(torch.clamp(torch.tensor(best_s), min=0, max=1))

    def greedy2(self, n_perturbations):
        s = self.adj_changes.cpu().detach().numpy()
        l = min(s)
        r = max(s)
        noise = np.random.normal((l + r) / 2, 0.4 * (r - l), s.shape)
        s += noise
        s_vec = np.squeeze(np.reshape(s, (1, -1)))
        max_index = (-np.absolute(s_vec)).argsort()[:n_perturbations]

        mask = np.zeros_like(s_vec)
        mask[max_index] = 1.0
        best_s = np.reshape(mask, s.shape)
        self.adj_changes.data.copy_(torch.clamp(torch.tensor(best_s), min=0, max=1))

    def random_sample(self, ori_adj, ori_features, labels, idx_target, n_perturbations):
        K = 10
        best_loss = -1000
        victim_model = self.surrogate
        with torch.no_grad():
            s = self.adj_changes.cpu().detach().numpy()
            for i in range(K):
                sampled = np.random.binomial(1, s)
                # randm = np.random.uniform(size=s.shape[0])
                # sampled = np.where(s > randm, 1, 0)
                # resample until the perturbation budget is respected
                while sampled.sum() > n_perturbations:
                    sampled = np.random.binomial(1, s)
                # if sampled.sum() > n_perturbations:
                #     indices = np.transpose(np.nonzero(sampled))
                #     candidate_idx = [m for m in range(indices.shape[0])]
                #     chosen_idx = np.random.choice(candidate_idx, n_perturbations, replace=False)
                #     chosen_indices = indices[chosen_idx, :]
                #     sampled = np.zeros_like(sampled)
                #     for idx in chosen_indices:
                #         sampled[idx] = 1

                self.adj_changes.data.copy_(torch.tensor(sampled))
                modified_adj = self.get_modified_adj(ori_adj)
                adj_norm = utils.normalize_adj_tensor(modified_adj, device=self.device)
                output = victim_model(ori_features, adj_norm)
                loss = self._loss(output[idx_target], labels[idx_target])
                # loss = F.nll_loss(output[idx_target], labels[idx_target])
                if best_loss < loss:
                    best_loss = loss
                    best_s = sampled
            self.adj_changes.data.copy_(torch.tensor(best_s))

    def get_modified_adj(self, ori_adj):
        if self.complementary is None:
            self.complementary = (torch.ones_like(ori_adj)
                                  - torch.eye(self.nnodes).to(self.device)
                                  - ori_adj) - ori_adj

        m = torch.zeros((self.nnodes, self.nnodes)).to(self.device)
        tril_indices = torch.tril_indices(row=self.nnodes, col=self.nnodes, offset=-1)
        m[tril_indices[0], tril_indices[1]] = self.adj_changes
        m = m + m.t()
        modified_adj = self.complementary * m + ori_adj
        return modified_adj

    def add_random_noise(self, ori_adj):
        noise = 1e-4 * torch.rand(self.nnodes, self.nnodes).to(self.device)
        return (noise + torch.transpose(noise, 0, 1)) / 2.0 + ori_adj

    def projection2(self, n_perturbations):
        s = self.adj_changes.cpu().detach().numpy()
        n = np.squeeze(np.reshape(s, (1, -1))).shape[0]
        self.adj_changes.data.copy_(
            torch.clamp(self.adj_changes.data, min=0, max=n_perturbations / n))
        return 0, 0, 0

    def projection(self, n_perturbations):
        l, r, m = 0, 0, 0
        if torch.clamp(self.adj_changes, 0, 1).sum() > n_perturbations:
            left = (self.adj_changes).min()
            right = self.adj_changes.max()
            miu = self.bisection(left, right, n_perturbations, epsilon=1e-5)
            l = left.cpu().detach()
            r = right.cpu().detach()
            m = miu.cpu().detach()
            self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data - miu, min=0, max=1))
        else:
            self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data, min=0, max=1))
        return l, r, m

    def _loss(self, output, labels):
        if self.loss_type == "CE":
            loss = F.nll_loss(output, labels)
        if self.loss_type == "CW":
            onehot = utils.tensor2onehot(labels)
            best_second_class = (output - 1000 * onehot).argmax(1).detach()
            margin = output[np.arange(len(output)), labels] - \
                     output[np.arange(len(output)), best_second_class]
            k = 0
            loss = -torch.clamp(margin, min=k).mean()
            # loss = torch.clamp(margin.sum()+50, min=k)
        return loss

    def bisection(self, a, b, n_perturbations, epsilon):
        def func(x):
            return torch.clamp(self.adj_changes - x, 0, 1).sum() - n_perturbations

        miu = a
        while (b - a) >= epsilon:
            miu = (a + b) / 2
            # Check if the middle point is a root
            if func(miu) == 0.0:
                b = miu
                break
            # Decide the side to repeat the steps
            if func(miu) * func(a) < 0:
                b = miu
            else:
                a = miu
        # print("The value of root is : ", "%.4f" % miu)
        return miu
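# Sketch of the 'l2' spectral term used above, computed standalone on a toy graph.
# This is illustrative only: the symmetric normalization below re-implements what
# utils.normalize_adj_tensor is assumed to do (A_hat = D^{-1/2} (A + I) D^{-1/2}),
# and torch.linalg.eigvalsh is the non-deprecated equivalent of the torch.symeig
# calls used in the attack code (both return eigenvalues in ascending order).
def spectral_distance_demo():
    import torch

    def sym_normalize(adj):
        a = adj + torch.eye(adj.shape[0])
        d_inv_sqrt = torch.diag(a.sum(1).pow(-0.5))
        return d_inv_sqrt @ a @ d_inv_sqrt

    ori_adj = torch.tensor([[0., 1., 0.],
                            [1., 0., 1.],
                            [0., 1., 0.]])
    mod_adj = ori_adj.clone()
    mod_adj[0, 2] = mod_adj[2, 0] = 1.   # a single edge addition

    ori_e = torch.linalg.eigvalsh(sym_normalize(ori_adj))
    e = torch.linalg.eigvalsh(sym_normalize(mod_adj))

    eigen_mse = torch.norm(ori_e - e)            # || lambda_ori - lambda ||_2
    reg_loss = eigen_mse / torch.norm(ori_e)     # the 'l2' distance_type above
    print(reg_loss)
    return reg_loss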
class PGDAttack(BaseAttack):

    def __init__(self, model=None, nnodes=None, loss_type='CE', feature_shape=None,
                 attack_structure=True, attack_features=False, device='cpu'):
        super(PGDAttack, self).__init__(model, nnodes, attack_structure, attack_features, device)

        assert attack_features or attack_structure, \
            'attack_features and attack_structure cannot both be False'

        self.loss_type = loss_type
        self.modified_adj = None
        self.modified_features = None

        if attack_structure:
            assert nnodes is not None, 'Please give nnodes='
            self.adj_changes = Parameter(torch.FloatTensor(int(nnodes * (nnodes - 1) / 2)))
            self.adj_changes.data.fill_(0)

        if attack_features:
            raise NotImplementedError('Topology Attack does not support attacking features')

        self.complementary = None

    def attack(self, ori_features, ori_adj, labels, idx_train, n_perturbations, epochs=200, **kwargs):
        victim_model = self.surrogate

        self.sparse_features = sp.issparse(ori_features)
        ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device)

        victim_model.eval()
        for t in tqdm(range(epochs)):
            modified_adj = self.get_modified_adj(ori_adj)
            adj_norm = utils.normalize_adj_tensor(modified_adj)
            output = victim_model(ori_features, adj_norm)
            # loss = F.nll_loss(output[idx_train], labels[idx_train])
            loss = self._loss(output[idx_train], labels[idx_train])
            adj_grad = torch.autograd.grad(loss, self.adj_changes)[0]

            if self.loss_type == 'CE':
                lr = 200 / np.sqrt(t + 1)
                self.adj_changes.data.add_(lr * adj_grad)

            if self.loss_type == 'CW':
                lr = 0.1 / np.sqrt(t + 1)
                self.adj_changes.data.add_(lr * adj_grad)

            self.projection(n_perturbations)

        self.random_sample(ori_adj, ori_features, labels, idx_train, n_perturbations)
        self.modified_adj = self.get_modified_adj(ori_adj).detach()

    def random_sample(self, ori_adj, ori_features, labels, idx_train, n_perturbations):
        K = 20
        best_loss = -1000
        victim_model = self.surrogate
        with torch.no_grad():
            s = self.adj_changes.cpu().detach().numpy()
            for i in range(K):
                sampled = np.random.binomial(1, s)
                print(sampled.sum())
                if sampled.sum() > n_perturbations:
                    continue
                self.adj_changes.data.copy_(torch.tensor(sampled))
                modified_adj = self.get_modified_adj(ori_adj)
                adj_norm = utils.normalize_adj_tensor(modified_adj)
                output = victim_model(ori_features, adj_norm)
                loss = self._loss(output[idx_train], labels[idx_train])
                # loss = F.nll_loss(output[idx_train], labels[idx_train])
                print(loss)
                if best_loss < loss:
                    best_loss = loss
                    best_s = sampled
            self.adj_changes.data.copy_(torch.tensor(best_s))

    def _loss(self, output, labels):
        if self.loss_type == "CE":
            loss = F.nll_loss(output, labels)
        if self.loss_type == "CW":
            onehot = utils.tensor2onehot(labels)
            best_second_class = (output - 1000 * onehot).argmax(1)
            margin = output[np.arange(len(output)), labels] - \
                     output[np.arange(len(output)), best_second_class]
            k = 0
            loss = -torch.clamp(margin, min=k).mean()
            # loss = torch.clamp(margin.sum()+50, min=k)
        return loss

    def projection(self, n_perturbations):
        # projected = torch.clamp(self.adj_changes, 0, 1)
        if torch.clamp(self.adj_changes, 0, 1).sum() > n_perturbations:
            left = (self.adj_changes - 1).min()
            right = self.adj_changes.max()
            miu = self.bisection(left, right, n_perturbations, epsilon=1e-5)
            self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data - miu, min=0, max=1))
        else:
            self.adj_changes.data.copy_(torch.clamp(self.adj_changes.data, min=0, max=1))

    def get_modified_adj(self, ori_adj):
        if self.complementary is None:
            self.complementary = (torch.ones_like(ori_adj)
                                  - torch.eye(self.nnodes).to(self.device)
                                  - ori_adj) - ori_adj

        m = torch.zeros((self.nnodes, self.nnodes)).to(self.device)
        # strictly lower-triangular indices: one adj_changes entry per node pair
        tril_indices = torch.tril_indices(row=self.nnodes, col=self.nnodes, offset=-1)
        m[tril_indices[0], tril_indices[1]] = self.adj_changes
        # m += m.t()
        m = m + m.t()
        modified_adj = self.complementary * m + ori_adj
        return modified_adj

    def bisection(self, a, b, n_perturbations, epsilon):
        def func(x):
            return torch.clamp(self.adj_changes - x, 0, 1).sum() - n_perturbations

        miu = a
        while (b - a) >= epsilon:
            miu = (a + b) / 2
            # Check if middle point is root
            if func(miu) == 0.0:
                break
            # Decide the side to repeat the steps
            if func(miu) * func(a) < 0:
                b = miu
            else:
                a = miu
        # print("The value of root is : ", "%.4f" % miu)
        return miu
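# Sketch of how get_modified_adj() maps the length-n(n-1)/2 change vector onto a
# symmetric edge-flip matrix (toy 3-node example, illustrative only). A value of 1
# in adj_changes adds the corresponding edge where complementary is +1 (edge absent
# in the original graph) and removes it where complementary is -1 (edge present).
def get_modified_adj_demo():
    import torch

    n = 3
    ori_adj = torch.tensor([[0., 1., 0.],
                            [1., 0., 1.],
                            [0., 1., 0.]])
    adj_changes = torch.tensor([0., 1., 0.])   # one entry per unordered node pair

    complementary = torch.ones_like(ori_adj) - torch.eye(n) - 2 * ori_adj
    m = torch.zeros(n, n)
    tril = torch.tril_indices(row=n, col=n, offset=-1)   # pairs (1,0), (2,0), (2,1)
    m[tril[0], tril[1]] = adj_changes
    m = m + m.t()

    modified_adj = complementary * m + ori_adj
    print(modified_adj)   # the path graph becomes a triangle: edge (0,2) is added
    return modified_adj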