def meta_grad(self):
    """Differentiate the attack loss w.r.t. the perturbation parameters.

    The forward pass is run on the (optionally) perturbed adjacency and
    feature tensors, and the attack loss is a ``self.lambda_``-weighted mix
    of the loss on ``self.train_nodes`` and the loss on
    ``self.unlabeled_nodes`` (scored against ``self.self_training_labels``).

    Returns:
        tuple: ``(x_grad, adj_grad)``. An entry is ``None`` when the
        corresponding attack type is switched off.
    """
    attack_structure = self.structure_attack
    attack_feature = self.feature_attack
    clean_adj, clean_x = self.adj_tensor, self.x_tensor
    # Whichever side is not attacked is fed to the model unmodified.
    current_adj, current_x = clean_adj, clean_x

    # The tape is only queried twice when both attack types are active.
    with tf.GradientTape(persistent=attack_structure and attack_feature) as tape:
        if attack_structure:
            current_adj = self.get_perturbed_adj(clean_adj, self.adj_changes)
        if attack_feature:
            current_x = self.get_perturbed_x(clean_x, self.x_changes)

        normalized_adj = gf.normalize_adj_tensor(current_adj)
        output = self.forward(current_x, normalized_adj)

        labeled_loss = self.loss_fn(
            self.labels_train, tf.gather(output, self.train_nodes))
        unlabeled_loss = self.loss_fn(
            self.self_training_labels, tf.gather(output, self.unlabeled_nodes))
        attack_loss = (self.lambda_ * labeled_loss
                       + (1 - self.lambda_) * unlabeled_loss)

    x_grad = tape.gradient(attack_loss, self.x_changes) if attack_feature else None
    adj_grad = tape.gradient(attack_loss, self.adj_changes) if attack_structure else None
    return x_grad, adj_grad
def compute_structure_gradients(self, adj, x, target_index, target_label):
    """Return the gradient of the target loss with respect to ``adj``.

    Args:
        adj: Adjacency tensor; it is explicitly watched by the tape.
        x: Feature tensor passed to the surrogate together with the
            normalized adjacency.
        target_index: Index (or indices) gathered from the surrogate output.
        target_label: Label(s) handed to ``self.loss_fn`` as the first
            argument.

    Returns:
        The result of ``tape.gradient(loss, adj)``.
    """
    with tf.GradientTape() as tape:
        # Explicit watch so the tape records operations on `adj`.
        tape.watch(adj)
        normalized_adj = gf.normalize_adj_tensor(adj)
        prediction = self.surrogate([x, normalized_adj])
        target_logit = tf.gather(prediction, target_index)
        target_loss = self.loss_fn(target_label, target_logit)

    return tape.gradient(target_loss, adj)
def inner_train(self, x, adj):
    """Re-initialize and train the inner model with momentum updates.

    Runs ``self.epochs`` steps. Each step obtains weight gradients from
    ``self.train_step`` on the training nodes, folds them into the velocity
    buffers (``v <- momentum * v + g``) and then moves every weight against
    its velocity (``w <- w - lr * v``).

    Args:
        x: Feature tensor forwarded to ``self.train_step``.
        adj: Adjacency tensor; it is normalized once before the loop.
    """
    self.initialize()
    normalized_adj = gf.normalize_adj_tensor(adj)

    for _ in range(self.epochs):
        grads = self.train_step(
            x, normalized_adj, self.train_nodes, self.labels_train)

        # First refresh all velocities, then apply them to the weights.
        for velocity, grad in zip(self.w_velocities, grads):
            velocity.assign(self.momentum * velocity + grad)
        for weight, velocity in zip(self.weights, self.w_velocities):
            weight.assign_sub(self.lr * velocity)
def compute_gradients(self, modified_adj, adj_changes, target_index, target_label):
    """Return the gradient of the target loss with respect to ``adj_changes``.

    The surrogate is evaluated on ``self.x_tensor`` and the normalized sum
    ``modified_adj + adj_changes``; the loss is taken on the output gathered
    at ``target_index``.

    Args:
        modified_adj: Current adjacency tensor.
        adj_changes: Additive perturbation tensor; explicitly watched by the
            tape and used as the differentiation source.
        target_index: Index (or indices) gathered from the surrogate output.
        target_label: Label(s) passed to ``self.loss_fn``.

    Returns:
        The result of ``tape.gradient(loss, adj_changes)``.
    """
    with tf.GradientTape() as tape:
        tape.watch(adj_changes)
        perturbed_adj = modified_adj + adj_changes
        normalized_adj = gf.normalize_adj_tensor(perturbed_adj)
        prediction = self.surrogate([self.x_tensor, normalized_adj])
        target_logit = tf.gather(prediction, target_index)
        target_loss = self.loss_fn(target_label, target_logit, from_logits=True)

    return tape.gradient(target_loss, adj_changes)
def compute_loss(self, victim_nodes):
    """Compute the attack loss on ``victim_nodes`` for the perturbed graph.

    The surrogate is evaluated on ``self.x_tensor`` and the normalized
    perturbed adjacency, and its output is restricted to ``victim_nodes``.

    * When ``self.CW_loss`` is truthy, a margin loss is computed on the
      log-softmax scores: for every victim node the score of its real class
      is compared with the best-scoring other class, a margin of ``0.2`` is
      added, the result is clamped at zero, negated and averaged.
    * Otherwise ``self.loss_fn(logit, self.victim_labels)`` is returned.

    Args:
        victim_nodes: Index used to select rows of the surrogate output.

    Returns:
        A scalar tensor in the margin-loss branch; whatever ``self.loss_fn``
        returns otherwise.
    """
    adj = self.get_perturbed_adj()
    adj_norm = gf.normalize_adj_tensor(adj)
    logit = self.surrogate(self.x_tensor, adj_norm)[victim_nodes]

    if not self.CW_loss:
        return self.loss_fn(logit, self.victim_labels)

    logit = F.log_softmax(logit, dim=1)
    # Subtracting 1000 * label_matrix pushes the masked entries out of the
    # argmax (presumably label_matrix is a one-hot mask of the real class;
    # verify against where it is built).
    best_wrong_class = (logit - 1000 * self.label_matrix).argmax(1)

    # Element-wise (row, class) selection needs a *pair* of index tensors.
    # Indexing with one stacked (2, N) tensor would instead pick whole rows
    # along dim 0 and yield a (2, N, C) tensor.
    # NOTE(review): assumes `self.indices_real` is a (rows, classes) pair or
    # a tensor stacked along dim 0, mirroring the attack indices -- confirm.
    real_score = logit[tuple(self.indices_real)]
    attack_score = logit[self.range_idx, best_wrong_class]

    margin = real_score - attack_score + 0.2
    loss = -torch.clamp(margin, min=0.)
    return loss.mean()
def compute_gradients(self, modified_adj, modified_nx, victim_nodes, victim_labels):
    """Compute loss gradients for the adjacency and/or feature tensors.

    Args:
        modified_adj: Adjacency tensor fed (after normalization) to the
            surrogate; differentiation source for ``adj_grad``.
        modified_nx: Feature tensor fed to the surrogate; differentiation
            source for ``x_grad``.
        victim_nodes: Indices gathered from the surrogate output.
        victim_labels: Labels passed to ``self.loss_fn``.

    Returns:
        tuple: ``(adj_grad, x_grad)``. An entry is ``None`` when the
        corresponding attack type is switched off.
    """
    want_adj_grad = self.structure_attack
    want_x_grad = self.feature_attack

    # TODO persistent=False
    # The tape is queried twice only when both attack types are enabled.
    with tf.GradientTape(persistent=want_adj_grad and want_x_grad) as tape:
        normalized_adj = gf.normalize_adj_tensor(modified_adj)
        prediction = self.surrogate([modified_nx, normalized_adj])
        victim_logit = tf.gather(prediction, victim_nodes)
        victim_loss = self.loss_fn(victim_labels, victim_logit, from_logits=True)

    adj_grad = tape.gradient(victim_loss, modified_adj) if want_adj_grad else None
    x_grad = tape.gradient(victim_loss, modified_nx) if want_x_grad else None
    return adj_grad, x_grad
def compute_gradients(self, modified_adj, modified_nx, target_index, target_label):
    """Compute target-loss gradients for the adjacency and/or features.

    Args:
        modified_adj: Adjacency tensor fed (after normalization) to the
            surrogate; differentiation source for ``adj_grad``.
        modified_nx: Feature tensor fed to the surrogate; differentiation
            source for ``x_grad``.
        target_index: Index (or indices) gathered from the surrogate output.
        target_label: Label(s) passed to ``self.loss_fn``.

    Returns:
        tuple: ``(adj_grad, x_grad)``. An entry is ``None`` when the
        corresponding attack type is switched off.
    """
    # A persistent tape allows both gradient queries below.
    with tf.GradientTape(persistent=True) as tape:
        normalized_adj = gf.normalize_adj_tensor(modified_adj)
        prediction = self.surrogate([modified_nx, normalized_adj])
        target_logit = tf.gather(prediction, target_index)
        target_loss = self.loss_fn(target_label, target_logit, from_logits=True)

    gradients = {"adj": None, "x": None}
    if self.structure_attack:
        gradients["adj"] = tape.gradient(target_loss, modified_adj)
    if self.feature_attack:
        gradients["x"] = tape.gradient(target_loss, modified_nx)

    return gradients["adj"], gradients["x"]
def compute_loss(self, victim_nodes):
    """Compute the attack loss on ``victim_nodes`` for the perturbed graph.

    The surrogate output for ``victim_nodes`` is passed through ``softmax``.

    * When ``self.CW_loss`` is truthy, a margin loss is returned: the score
      of the best class after subtracting ``self.label_matrix`` minus the
      score at ``self.indices_real`` minus ``0.2``, capped at zero from
      above and summed over the victim nodes.
    * Otherwise the mean of ``self.loss_fn(self.victim_labels, scores)`` is
      returned.

    Args:
        victim_nodes: Indices gathered from the surrogate output.

    Returns:
        A scalar tensor holding the loss.
    """
    perturbed_adj = self.get_perturbed_adj()
    normalized_adj = gf.normalize_adj_tensor(perturbed_adj)
    prediction = self.surrogate([self.x_tensor, normalized_adj])
    scores = softmax(tf.gather(prediction, victim_nodes))

    if not self.CW_loss:
        return tf.reduce_mean(self.loss_fn(self.victim_labels, scores))

    best_wrong_class = tf.argmax(
        scores - self.label_matrix, axis=1, output_type=self.intx)
    # Pair every row position with its selected class for gather_nd.
    attack_indices = tf.stack([self.range_idx, best_wrong_class], axis=1)
    attack_score = tf.gather_nd(scores, attack_indices)
    real_score = tf.gather_nd(scores, self.indices_real)
    margin = attack_score - real_score - 0.2
    return tf.reduce_sum(tf.minimum(margin, 0.))
def process(self, surrogate, reset=True):
    """Attach the surrogate model and cache graph tensors on ``self.device``.

    Args:
        surrogate: A model, or a ``gg.gallery.nodeclas.Trainer`` whose
            ``.model`` attribute is used instead.
        reset: When truthy, ``self.reset()`` is called after the tensors
            have been prepared.

    Returns:
        ``self``, so calls can be chained.
    """
    # Unwrap a trainer to get at the underlying model.
    model = surrogate.model if isinstance(
        surrogate, gg.gallery.nodeclas.Trainer) else surrogate

    adj, x = self.graph.adj_matrix, self.graph.node_attr
    self.nodes_set = set(range(self.num_nodes))
    self.features_set = np.arange(self.num_attrs)

    with tf.device(self.device):
        self.surrogate = model
        self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
        self.x_tensor = gf.astensor(x)
        # `.A` yields the dense form of the adjacency matrix.
        self.adj_tensor = gf.astensor(adj.A)
        self.adj_norm = gf.normalize_adj_tensor(self.adj_tensor)

    if reset:
        self.reset()
    return self
def meta_grad(self):
    """Accumulate attack-loss gradients over ``self.epochs`` training steps.

    After ``self.initialize()``, every iteration runs one forward pass on the
    (optionally) perturbed graph, updates ``self.weights`` with the gradient
    of the labeled loss via ``self.optimizer``, and adds the gradient of the
    attack loss w.r.t. ``self.adj_changes`` / ``self.x_changes`` to the
    running sums ``self.adj_grad_sum`` / ``self.x_grad_sum``.

    Returns:
        tuple: ``(x_grad_sum, adj_grad_sum)`` -- the accumulator objects
        themselves (note the order: features first, structure second).
    """
    self.initialize()
    # Start from the clean tensors; they are replaced below by perturbed
    # versions only for the attack types that are enabled.
    modified_adj, modified_nx = self.adj_tensor, self.x_tensor
    adj_tensor, x_tensor = self.adj_tensor, self.x_tensor
    adj_grad_sum, x_grad_sum = self.adj_grad_sum, self.x_grad_sum
    optimizer = self.optimizer

    for it in tf.range(self.epochs):
        # Persistent: the tape is queried up to three times per iteration
        # (weights, adj_changes, x_changes).
        with tf.GradientTape(persistent=True) as tape:
            if self.structure_attack:
                modified_adj = self.get_perturbed_adj(adj_tensor, self.adj_changes)
            if self.feature_attack:
                modified_nx = self.get_perturbed_x(x_tensor, self.x_changes)

            adj_norm = gf.normalize_adj_tensor(modified_adj)
            output = self.forward(modified_nx, adj_norm)
            logit_labeled = tf.gather(output, self.train_nodes)
            logit_unlabeled = tf.gather(output, self.unlabeled_nodes)

            loss_labeled = self.loss_fn(self.labels_train, logit_labeled)
            loss_unlabeled = self.loss_fn(self.self_training_labels, logit_unlabeled)
            # Attack objective: lambda_-weighted mix of both losses.
            attack_loss = self.lambda_ * loss_labeled + (1 - self.lambda_) * loss_unlabeled

        adj_grad, x_grad = None, None

        # Training step: only the labeled loss drives the weight update.
        gradients = tape.gradient(loss_labeled, self.weights)
        optimizer.apply_gradients(zip(gradients, self.weights))

        if self.structure_attack:
            adj_grad = tape.gradient(attack_loss, self.adj_changes)
            adj_grad_sum.assign_add(adj_grad)

        if self.feature_attack:
            x_grad = tape.gradient(attack_loss, self.x_changes)
            x_grad_sum.assign_add(x_grad)

        # Drop the reference to the persistent tape before the next iteration.
        del tape

    return x_grad_sum, adj_grad_sum
def attack(self, num_budgets=0.05, structure_attack=True, feature_attack=False, ll_constraint=False, ll_cutoff=0.004, disable=False):
    """Greedily flip one edge or one attribute per budget step.

    Each of the ``self.num_budgets`` iterations rebuilds the perturbed
    graph, retrains via ``self.inner_train``, obtains meta-gradients from
    ``self.meta_grad``, scores all candidate flips and applies the
    best-scoring one to ``self.adj_changes`` (symmetrically) or
    ``self.x_changes``. Applied flips are appended to ``self.adj_flips`` /
    ``self.nattr_flips``.

    Args:
        num_budgets: Forwarded to ``super().attack``.
        structure_attack: Whether edge flips are allowed.
        feature_attack: Whether attribute flips are allowed.
        ll_constraint: Must be falsy; the constraint is not supported.
        ll_cutoff: Forwarded to ``self.structure_score``.
        disable: Forwarded to ``tqdm`` to silence the progress bar.

    Raises:
        NotImplementedError: If ``ll_constraint`` is truthy.
        ValueError: If ``feature_attack`` is requested on a graph whose
            attributes are not binary (per ``self.graph.is_binary()``).
    """
    super().attack(num_budgets, structure_attack, feature_attack)

    if ll_constraint:
        raise NotImplementedError(
            "`log_likelihood_constraint` has not been well tested."
            " Please set `ll_constraint=False` to achieve a better performance."
        )

    if feature_attack and not self.graph.is_binary():
        raise ValueError(
            "Attacks on the node features are currently only supported for binary attributes."
        )

    modified_adj, modified_nx = self.adj_tensor, self.x_tensor
    adj_tensor, x_tensor = self.adj_tensor, self.x_tensor
    adj_changes, x_changes = self.adj_changes, self.x_changes
    adj_flips, nattr_flips = self.adj_flips, self.nattr_flips

    self.inner_train(modified_nx, modified_adj)

    for _ in tqdm(range(self.num_budgets), desc='Peturbing Graph', disable=disable):
        if structure_attack:
            modified_adj = self.get_perturbed_adj(adj_tensor, adj_changes)
        if feature_attack:
            modified_nx = self.get_perturbed_x(x_tensor, x_changes)

        adj_norm = gf.normalize_adj_tensor(modified_adj)
        self.inner_train(modified_nx, adj_norm)
        x_grad, adj_grad = self.meta_grad(modified_nx, adj_norm)

        # Placeholders for whichever attack type is disabled.
        x_meta_score = torch.tensor(0.0)
        adj_meta_score = torch.tensor(0.0)

        if structure_attack:
            adj_meta_score = self.structure_score(modified_adj, adj_grad, ll_constraint, ll_cutoff)
        if feature_attack:
            x_meta_score = self.feature_score(modified_nx, x_grad)

        # Only ever flip something belonging to an *enabled* attack type.
        # Comparing against the 0.0 placeholder alone could otherwise pick
        # the disabled side (e.g. when every real score is <= 0).
        if structure_attack and (not feature_attack
                                 or adj_meta_score.max() >= x_meta_score.max()):
            adj_meta_argmax = torch.argmax(adj_meta_score)
            row, col = unravel_index(adj_meta_argmax, self.num_nodes)
            # -2 * a + 1 maps 0 -> +1 and 1 -> -1, i.e. toggles the entry;
            # both (row, col) and (col, row) are updated to keep symmetry.
            self.adj_changes.data[row][col] += -2 * modified_adj[row][col] + 1
            self.adj_changes.data[col][row] += -2 * modified_adj[row][col] + 1
            adj_flips.append((row, col))
        elif feature_attack:
            x_meta_argmax = torch.argmax(x_meta_score)
            row, col = unravel_index(x_meta_argmax, self.num_attrs)
            self.x_changes.data[row][col] += -2 * modified_nx[row][col] + 1
            nattr_flips.append((row, col))