Example #1
0
    def meta_grad(self):
        """Differentiate the attack loss with respect to the perturbation variables.

        The forward pass is run on the (optionally) perturbed adjacency and
        feature tensors.  The attack loss is a ``lambda_``-weighted mix of the
        loss on the training nodes and the loss on the unlabeled nodes
        (evaluated against ``self_training_labels``).

        Returns
        -------
        tuple
            ``(x_grad, adj_grad)``.  An entry is ``None`` when the matching
            attack mode (``feature_attack`` / ``structure_attack``) is off.
        """
        # The tape is queried twice only when both attack modes are enabled,
        # so it has to be persistent exactly in that case.
        both_modes = self.structure_attack and self.feature_attack

        with tf.GradientTape(persistent=both_modes) as tape:
            perturbed_adj = (
                self.get_perturbed_adj(self.adj_tensor, self.adj_changes)
                if self.structure_attack else self.adj_tensor
            )
            perturbed_x = (
                self.get_perturbed_x(self.x_tensor, self.x_changes)
                if self.feature_attack else self.x_tensor
            )

            predictions = self.forward(perturbed_x,
                                       gf.normalize_adj_tensor(perturbed_adj))
            train_logit = tf.gather(predictions, self.train_nodes)
            unlabeled_logit = tf.gather(predictions, self.unlabeled_nodes)

            train_loss = self.loss_fn(self.labels_train, train_logit)
            unlabeled_loss = self.loss_fn(self.self_training_labels,
                                          unlabeled_logit)

            attack_loss = (self.lambda_ * train_loss
                           + (1 - self.lambda_) * unlabeled_loss)

        x_grad = (tape.gradient(attack_loss, self.x_changes)
                  if self.feature_attack else None)
        adj_grad = (tape.gradient(attack_loss, self.adj_changes)
                    if self.structure_attack else None)

        return x_grad, adj_grad
    def compute_structure_gradients(self, adj, x, target_index, target_label):
        """Return the gradient of the surrogate loss with respect to ``adj``.

        ``adj`` is explicitly watched by the tape, normalized, and fed with
        ``x`` through ``self.surrogate``; the loss is evaluated on the rows
        selected by ``target_index`` against ``target_label``.
        """
        with tf.GradientTape() as tape:
            tape.watch(adj)
            surrogate_out = self.surrogate([x, gf.normalize_adj_tensor(adj)])
            target_logit = tf.gather(surrogate_out, target_index)
            target_loss = self.loss_fn(target_label, target_logit)

        return tape.gradient(target_loss, adj)
Example #3
0
    def inner_train(self, x, adj):
        """Re-initialize and train the weights with a momentum update rule.

        ``adj`` is normalized once up front.  For ``self.epochs`` iterations
        the gradients returned by ``self.train_step`` are folded into the
        velocity buffers, after which every weight is moved against its
        velocity scaled by ``self.lr``.
        """
        self.initialize()
        normalized_adj = gf.normalize_adj_tensor(adj)

        for _ in range(self.epochs):
            grads = self.train_step(x, normalized_adj,
                                    self.train_nodes, self.labels_train)

            # velocity <- momentum * velocity + gradient
            for velocity, grad in zip(self.w_velocities, grads):
                velocity.assign(self.momentum * velocity + grad)

            # weight <- weight - lr * velocity
            for weight, velocity in zip(self.weights, self.w_velocities):
                weight.assign_sub(self.lr * velocity)
Example #4
0
    def compute_gradients(self, modified_adj, adj_changes, target_index,
                          target_label):
        """Return the gradient of the target loss with respect to ``adj_changes``.

        The adjacency fed to the surrogate is ``modified_adj + adj_changes``
        (normalized); ``adj_changes`` is explicitly watched so the gradient is
        taken with respect to it.  The loss is computed on the rows selected
        by ``target_index`` with ``from_logits=True``.
        """
        with tf.GradientTape() as tape:
            tape.watch(adj_changes)
            perturbed_norm = gf.normalize_adj_tensor(modified_adj + adj_changes)
            surrogate_out = self.surrogate([self.x_tensor, perturbed_norm])
            target_logit = tf.gather(surrogate_out, target_index)
            target_loss = self.loss_fn(target_label, target_logit,
                                       from_logits=True)

        return tape.gradient(target_loss, adj_changes)
Example #5
0
    def compute_loss(self, victim_nodes):
        """Compute the attack loss on ``victim_nodes`` for the perturbed graph.

        The surrogate is evaluated on the perturbed, normalized adjacency and
        its output rows for ``victim_nodes`` are selected.

        * When ``self.CW_loss`` is falsy, ``self.loss_fn(logit, victim_labels)``
          is returned unchanged.
        * Otherwise a Carlini-Wagner style margin is used: for each victim the
          log-probability of the real class minus the log-probability of the
          best wrong class, plus ``0.2``, clamped at zero from below, negated
          and averaged.

        Returns
        -------
        torch.Tensor
            The loss value.
        """
        adj = self.get_perturbed_adj()
        adj_norm = gf.normalize_adj_tensor(adj)
        logit = self.surrogate(self.x_tensor, adj_norm)[victim_nodes]

        if not self.CW_loss:
            return self.loss_fn(logit, self.victim_labels)

        logit = F.log_softmax(logit, dim=1)
        # Subtracting a large constant on the true-class entries (as marked by
        # label_matrix) keeps argmax from picking the true class.
        best_wrong_class = (logit - 1000 * self.label_matrix).argmax(1)

        # Indexing a 2-D tensor with a single stacked (2, N) index tensor
        # selects whole rows along dim 0 (result shape (2, N, C)) rather than
        # one (row, class) entry per victim.  Index with separate row and
        # column tensors so each victim contributes exactly one score.
        # NOTE(review): assumes self.indices_real unpacks into
        # (row_indices, class_indices), e.g. a stacked (2, N) tensor -- confirm
        # against the place where it is built.
        real_rows, real_classes = self.indices_real
        real_score = logit[real_rows, real_classes]
        wrong_score = logit[self.range_idx, best_wrong_class]

        margin = real_score - wrong_score + 0.2
        loss = -torch.clamp(margin, min=0.)
        return loss.mean()
Example #6
0
    def compute_gradients(self, modified_adj, modified_nx, victim_nodes, victim_labels):
        """Return gradients of the victim loss w.r.t. adjacency and features.

        The surrogate is run on ``modified_nx`` and the normalized
        ``modified_adj``; the loss is taken on the rows selected by
        ``victim_nodes`` with ``from_logits=True``.

        Returns
        -------
        tuple
            ``(adj_grad, x_grad)``; an entry is ``None`` when the matching
            attack mode is disabled.

        NOTE(review): neither input is explicitly watched here, so they are
        presumably ``tf.Variable`` objects -- confirm against the caller.
        """
        # TODO persistent=False
        want_adj, want_x = self.structure_attack, self.feature_attack

        # Two gradient queries happen only when both modes are on.
        with tf.GradientTape(persistent=want_adj and want_x) as tape:
            surrogate_out = self.surrogate(
                [modified_nx, gf.normalize_adj_tensor(modified_adj)])
            victim_logit = tf.gather(surrogate_out, victim_nodes)
            victim_loss = self.loss_fn(victim_labels, victim_logit,
                                       from_logits=True)

        adj_grad = tape.gradient(victim_loss, modified_adj) if want_adj else None
        x_grad = tape.gradient(victim_loss, modified_nx) if want_x else None

        return adj_grad, x_grad
Example #7
0
    def compute_gradients(self, modified_adj, modified_nx, target_index,
                          target_label):
        """Return gradients of the target loss w.r.t. adjacency and features.

        The surrogate is run on ``modified_nx`` and the normalized
        ``modified_adj``; the loss is taken on the rows selected by
        ``target_index`` with ``from_logits=True``.

        Returns
        -------
        tuple
            ``(adj_grad, x_grad)``; an entry is ``None`` when the matching
            attack mode (``structure_attack`` / ``feature_attack``) is off.

        NOTE(review): neither input is explicitly watched here, so they are
        presumably ``tf.Variable`` objects -- confirm against the caller.
        """
        # A persistent tape keeps its recorded resources alive until it is
        # garbage-collected; it is only required when the tape is queried
        # more than once, i.e. when both attack modes are enabled.
        persistent = self.structure_attack and self.feature_attack

        with tf.GradientTape(persistent=persistent) as tape:
            adj_norm = gf.normalize_adj_tensor(modified_adj)
            logit = self.surrogate([modified_nx, adj_norm])
            logit = tf.gather(logit, target_index)
            loss = self.loss_fn(target_label, logit, from_logits=True)

        adj_grad, x_grad = None, None

        if self.structure_attack:
            adj_grad = tape.gradient(loss, modified_adj)

        if self.feature_attack:
            x_grad = tape.gradient(loss, modified_nx)

        return adj_grad, x_grad
Example #8
0
    def compute_loss(self, victim_nodes):
        """Compute the attack loss on ``victim_nodes`` for the perturbed graph.

        The surrogate output for ``victim_nodes`` is passed through
        ``softmax``.  With ``self.CW_loss`` set, the loss is the sum over
        victims of ``min(p_wrong - p_real - 0.2, 0)``, where ``p_wrong`` is
        the score of the best class after subtracting ``label_matrix``.
        Otherwise the mean of ``self.loss_fn(victim_labels, scores)`` is
        returned.

        NOTE(review): ``loss_fn`` receives softmax output in the non-CW branch
        -- confirm it is not configured with ``from_logits=True``.
        """
        perturbed_adj = self.get_perturbed_adj()
        surrogate_out = self.surrogate(
            [self.x_tensor, gf.normalize_adj_tensor(perturbed_adj)])
        scores = softmax(tf.gather(surrogate_out, victim_nodes))

        if not self.CW_loss:
            return tf.reduce_mean(self.loss_fn(self.victim_labels, scores))

        best_wrong_class = tf.argmax(scores - self.label_matrix, axis=1,
                                     output_type=self.intx)
        # One (row, class) pair per victim, consumed by gather_nd.
        attack_pairs = tf.stack([self.range_idx, best_wrong_class], axis=1)
        wrong_score = tf.gather_nd(scores, attack_pairs)
        real_score = tf.gather_nd(scores, self.indices_real)

        clipped_margin = tf.minimum(wrong_score - real_score - 0.2, 0.)
        return tf.reduce_sum(clipped_margin)
    def process(self, surrogate, reset=True):
        """Attach a surrogate model and cache tensors derived from the graph.

        Parameters
        ----------
        surrogate
            A model, or a ``gg.gallery.nodeclas.Trainer`` whose ``.model``
            attribute is used instead.
        reset : bool
            When true, ``self.reset()`` is called before returning.

        Returns
        -------
        self
        """
        is_trainer = isinstance(surrogate, gg.gallery.nodeclas.Trainer)
        if is_trainer:
            # Unwrap the trainer; only the underlying model is stored.
            surrogate = surrogate.model

        adjacency = self.graph.adj_matrix
        attributes = self.graph.node_attr
        self.nodes_set = set(range(self.num_nodes))
        self.features_set = np.arange(self.num_attrs)

        with tf.device(self.device):
            self.surrogate = surrogate
            self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
            self.x_tensor = gf.astensor(attributes)
            # `.A` yields the dense form of the adjacency matrix.
            self.adj_tensor = gf.astensor(adjacency.A)
            self.adj_norm = gf.normalize_adj_tensor(self.adj_tensor)

        if not reset:
            return self

        self.reset()
        return self
Example #10
0
    def meta_grad(self):
        """Train the weights while accumulating attack-loss gradients.

        After ``self.initialize()``, the loop runs ``self.epochs`` steps.
        Each step:

        1. runs the forward pass on the (optionally) perturbed adjacency and
           features,
        2. updates ``self.weights`` with ``self.optimizer`` using the gradient
           of the labeled-node loss,
        3. adds the gradient of the attack loss w.r.t. ``self.adj_changes`` /
           ``self.x_changes`` to ``self.adj_grad_sum`` / ``self.x_grad_sum``
           for whichever attack modes are enabled.

        Returns
        -------
        tuple
            ``(x_grad_sum, adj_grad_sum)`` -- the accumulator objects read
            from ``self`` at the start of the call.
        """
        self.initialize()

        # Defaults used when the corresponding attack mode is disabled.
        modified_adj, modified_nx = self.adj_tensor, self.x_tensor
        adj_tensor, x_tensor = self.adj_tensor, self.x_tensor
        adj_grad_sum, x_grad_sum = self.adj_grad_sum, self.x_grad_sum
        optimizer = self.optimizer

        # NOTE(review): iterating over tf.range (rather than range) presumably
        # targets AutoGraph conversion under tf.function -- confirm.
        for it in tf.range(self.epochs):

            # Persistent: the tape is queried up to three times per step
            # (weights, adj_changes, x_changes).
            with tf.GradientTape(persistent=True) as tape:
                if self.structure_attack:
                    modified_adj = self.get_perturbed_adj(adj_tensor, self.adj_changes)

                if self.feature_attack:
                    modified_nx = self.get_perturbed_x(x_tensor, self.x_changes)

                adj_norm = gf.normalize_adj_tensor(modified_adj)
                output = self.forward(modified_nx, adj_norm)
                logit_labeled = tf.gather(output, self.train_nodes)
                logit_unlabeled = tf.gather(output, self.unlabeled_nodes)

                loss_labeled = self.loss_fn(self.labels_train, logit_labeled)
                loss_unlabeled = self.loss_fn(self.self_training_labels, logit_unlabeled)

                # lambda_-weighted mix of the labeled and self-training losses.
                attack_loss = self.lambda_ * loss_labeled + (1 - self.lambda_) * loss_unlabeled

            adj_grad, x_grad = None, None

            # Training step: only the labeled loss drives the weight update.
            gradients = tape.gradient(loss_labeled, self.weights)
            optimizer.apply_gradients(zip(gradients, self.weights))

            # The attack-loss gradients are taken after the weight update, from
            # the same tape that recorded the forward pass above.
            if self.structure_attack:
                adj_grad = tape.gradient(attack_loss, self.adj_changes)
                adj_grad_sum.assign_add(adj_grad)

            if self.feature_attack:
                x_grad = tape.gradient(attack_loss, self.x_changes)
                x_grad_sum.assign_add(x_grad)

            # Drop the reference to the persistent tape once this step is done.
            del tape

        return x_grad_sum, adj_grad_sum
Example #11
0
    def attack(self,
               num_budgets=0.05,
               structure_attack=True,
               feature_attack=False,
               ll_constraint=False,
               ll_cutoff=0.004,
               disable=False):
        """Greedily flip one adjacency or feature entry per budget step.

        Each iteration rebuilds the perturbed inputs, retrains via
        ``inner_train``, obtains meta-gradients from ``meta_grad``, scores
        the candidate flips and applies the single best-scoring one.

        Parameters
        ----------
        num_budgets
            Forwarded to ``super().attack``; the loop then runs
            ``self.num_budgets`` iterations.
        structure_attack : bool
            Enable flips of adjacency entries.
        feature_attack : bool
            Enable flips of feature entries.
        ll_constraint : bool
            Must be falsy; a truthy value raises ``NotImplementedError``.
            Also forwarded to ``structure_score``.
        ll_cutoff : float
            Forwarded to ``structure_score``.
        disable : bool
            Forwarded to ``tqdm`` as ``disable``.

        Raises
        ------
        NotImplementedError
            If ``ll_constraint`` is truthy.
        ValueError
            If ``feature_attack`` is set and ``self.graph.is_binary()`` is
            falsy.
        """

        super().attack(num_budgets, structure_attack, feature_attack)

        if ll_constraint:
            raise NotImplementedError(
                "`log_likelihood_constraint` has not been well tested."
                " Please set `ll_constraint=False` to achieve a better performance."
            )

        if feature_attack and not self.graph.is_binary():
            raise ValueError(
                "Attacks on the node features are currently only supported for binary attributes."
            )

        # Defaults used when the corresponding attack mode is disabled.
        modified_adj, modified_nx = self.adj_tensor, self.x_tensor
        adj_tensor, x_tensor = self.adj_tensor, self.x_tensor
        adj_changes, x_changes = self.adj_changes, self.x_changes
        # Aliases of the flip records on self; appends below extend them.
        adj_flips, nattr_flips = self.adj_flips, self.nattr_flips

        # NOTE(review): this call passes the un-normalized adjacency while the
        # in-loop call passes `adj_norm` -- confirm which form inner_train
        # expects.
        self.inner_train(modified_nx, modified_adj)

        for it in tqdm(range(self.num_budgets),
                       desc='Peturbing Graph',
                       disable=disable):

            if structure_attack:
                modified_adj = self.get_perturbed_adj(adj_tensor, adj_changes)

            if feature_attack:
                modified_nx = self.get_perturbed_x(x_tensor, x_changes)

            adj_norm = gf.normalize_adj_tensor(modified_adj)

            self.inner_train(modified_nx, adj_norm)

            x_grad, adj_grad = self.meta_grad(modified_nx, adj_norm)

            # Zero scores act as the fallback for a disabled attack mode.
            x_meta_score = torch.tensor(0.0)
            adj_meta_score = torch.tensor(0.0)

            if structure_attack:
                adj_meta_score = self.structure_score(modified_adj, adj_grad,
                                                      ll_constraint, ll_cutoff)
            if feature_attack:
                x_meta_score = self.feature_score(modified_nx, x_grad)

            # Ties go to the structure flip.
            if adj_meta_score.max() >= x_meta_score.max():
                adj_meta_argmax = torch.argmax(adj_meta_score)
                row, col = unravel_index(adj_meta_argmax, self.num_nodes)
                # `-2 * a + 1` is +1 for a 0 entry and -1 for a 1 entry, i.e.
                # a flip of a binary entry; applied to both (row, col) and
                # (col, row).
                self.adj_changes.data[row][
                    col] += -2 * modified_adj[row][col] + 1
                self.adj_changes.data[col][
                    row] += -2 * modified_adj[row][col] + 1
                adj_flips.append((row, col))

            else:
                x_meta_argmax = torch.argmax(x_meta_score)
                row, col = unravel_index(x_meta_argmax, self.num_attrs)
                # Same flip expression, applied to a single feature entry.
                self.x_changes.data[row][col] += -2 * modified_nx[row][col] + 1
                nattr_flips.append((row, col))