コード例 #1
0
ファイル: autoattack.py プロジェクト: iamgroot42/auto-attack
class AutoAttack():
    """Run an ensemble of adversarial attacks and report robust accuracy.

    Attacks listed in ``attacks_to_run`` are applied in order; each one is run
    only on the points that are still classified correctly after the previous
    attacks, so a point counts as robust only if no attack flips its
    prediction. ``plus=True`` additionally appends the targeted variants
    'apgd-t' and 'fab-t'.
    """

    def __init__(self, model, norm='Linf', eps=.3, seed=None, verbose=True,
                 attacks_to_run=('apgd-ce', 'apgd-dlr', 'fab', 'square'),
                 plus=False, is_tf_model=False, device='cuda', log_path=None):
        """Build the individual attack objects.

        Args:
            model: classifier; called directly for PyTorch models and through
                ``model.predict`` when ``is_tf_model`` is True.
            norm: threat model, 'Linf' or 'L2'.
            eps: perturbation budget handed to every attack.
            seed: fixed seed for the attacks; ``None`` means the current time
                is used (see ``get_seed``).
            verbose: log progress through ``self.logger``.
            attacks_to_run: attack names, run in this order. The sequence is
                copied, so the caller's object (and the default) is never
                modified.
            plus: also run 'apgd-t' and 'fab-t'.
            is_tf_model: use the TensorFlow implementations of APGD and FAB.
            device: device every batch is moved to before it is attacked.
            log_path: forwarded to ``utils.Logger``.
        """
        self.model = model
        self.norm = norm
        assert norm in ['Linf', 'L2']
        self.epsilon = eps
        self.seed = seed
        self.verbose = verbose
        # Private copy: extending the argument in place used to mutate the
        # caller's list or, worse, the shared default list.
        self.attacks_to_run = list(attacks_to_run)
        self.plus = plus
        self._add_plus_attacks()
        self.is_tf_model = is_tf_model
        self.device = device
        self.logger = utils.Logger(log_path)

        if not self.is_tf_model:
            from autopgd_pt import APGDAttack, APGDAttack_targeted
            from fab_pt import FABAttack
            square_model = self.model
        else:
            from autopgd_tf import APGDAttack, APGDAttack_targeted
            from fab_tf import FABAttack
            # Square only needs a forward function; for TF models that is
            # model.predict (the same call get_logits uses).
            square_model = self.model.predict
        from square import SquareAttack

        self.apgd = APGDAttack(self.model, n_restarts=5, n_iter=100, verbose=False,
            eps=self.epsilon, norm=self.norm, eot_iter=1, rho=.75, seed=self.seed, device=self.device)
        self.fab = FABAttack(self.model, n_restarts=5, n_iter=100, eps=self.epsilon, seed=self.seed,
            norm=self.norm, verbose=False, device=self.device)
        self.square = SquareAttack(square_model, p_init=.8, n_queries=5000, eps=self.epsilon, norm=self.norm,
            early_stop=True, n_restarts=1, seed=self.seed, verbose=False, device=self.device)
        self.apgd_targeted = APGDAttack_targeted(self.model, n_restarts=1, n_iter=100, verbose=False,
            eps=self.epsilon, norm=self.norm, eot_iter=1, rho=.75, seed=self.seed, device=self.device)

    def _add_plus_attacks(self):
        """Append 'apgd-t' and 'fab-t' (each at most once) when ``self.plus`` is set."""
        if self.plus:
            for name in ('apgd-t', 'fab-t'):
                if name not in self.attacks_to_run:
                    self.attacks_to_run.append(name)

    def get_logits(self, x):
        """Return the model's logits for ``x`` (``model.predict`` for TF models)."""
        if not self.is_tf_model:
            return self.model(x)
        else:
            return self.model.predict(x)

    def get_seed(self):
        """Return the configured seed, or the current time when none was given."""
        return time.time() if self.seed is None else self.seed

    def _run_attack(self, attack, x, y):
        """Run the attack named ``attack`` on the batch (x, y) and return the adversarial batch.

        Raises:
            ValueError: if ``attack`` is not one of the supported names.
        """
        if attack == 'apgd-ce':
            # APGD on the cross-entropy loss
            self.apgd.loss = 'ce'
            self.apgd.seed = self.get_seed()
            _, adv_curr = self.apgd.perturb(x, y, cheap=True)
        elif attack == 'apgd-dlr':
            # APGD on the DLR loss
            self.apgd.loss = 'dlr'
            self.apgd.seed = self.get_seed()
            _, adv_curr = self.apgd.perturb(x, y, cheap=True)
        elif attack == 'fab':
            self.fab.targeted = False
            self.fab.seed = self.get_seed()
            adv_curr = self.fab.perturb(x, y)
        elif attack == 'square':
            self.square.seed = self.get_seed()
            _, adv_curr = self.square.perturb(x, y)
        elif attack == 'apgd-t':
            self.apgd_targeted.seed = self.get_seed()
            _, adv_curr = self.apgd_targeted.perturb(x, y, cheap=True)
        elif attack == 'fab-t':
            # Targeted FAB runs a single restart. Restore the configured count
            # afterwards so a later untargeted 'fab' run is not silently weakened.
            saved_restarts = self.fab.n_restarts
            self.fab.targeted = True
            self.fab.n_restarts = 1
            self.fab.seed = self.get_seed()
            try:
                adv_curr = self.fab.perturb(x, y)
            finally:
                self.fab.n_restarts = saved_restarts
        else:
            raise ValueError('Attack not supported')
        return adv_curr

    def run_standard_evaluation(self, x_orig, y_orig, bs=250):
        """Attack (x_orig, y_orig) with every attack in ``self.attacks_to_run``.

        Args:
            x_orig: input batch; batches are moved to ``self.device``.
            y_orig: labels compared against the argmax of the logits.
            bs: batch size.

        Returns:
            A copy of ``x_orig`` in which every point that some attack managed
            to misclassify is replaced by its adversarial example; all other
            points are unchanged.
        """
        # update attacks list if plus was activated after initialization
        self._add_plus_attacks()

        with torch.no_grad():
            n_total = x_orig.shape[0]

            # clean accuracy: initially a point is "robust" iff it is
            # classified correctly
            n_batches = int(np.ceil(n_total / bs))
            robust_flags = torch.zeros(n_total, dtype=torch.bool, device=x_orig.device)
            for batch_idx in range(n_batches):
                start_idx = batch_idx * bs
                end_idx = min((batch_idx + 1) * bs, n_total)

                x = x_orig[start_idx:end_idx, :].clone().to(self.device)
                y = y_orig[start_idx:end_idx].clone().to(self.device)
                output = self.get_logits(x)
                correct_batch = y.eq(output.max(dim=1)[1])
                robust_flags[start_idx:end_idx] = correct_batch.detach().to(robust_flags.device)

            robust_accuracy = torch.sum(robust_flags).item() / n_total
            if self.verbose:
                self.logger.log('initial accuracy: {:.2%}'.format(robust_accuracy))

            x_adv = x_orig.clone().detach()
            startt = time.time()
            for attack in self.attacks_to_run:
                # item() matters: dividing torch integer tensors floors
                num_robust = torch.sum(robust_flags).item()
                if num_robust == 0:
                    break

                n_batches = int(np.ceil(num_robust / bs))
                # 1-D index tensor of the still-robust points; flatten() keeps
                # the batch dimension even when only one point is left
                robust_lin_idcs = torch.nonzero(robust_flags, as_tuple=False).flatten()

                for batch_idx in range(n_batches):
                    start_idx = batch_idx * bs
                    end_idx = min((batch_idx + 1) * bs, num_robust)

                    batch_datapoint_idcs = robust_lin_idcs[start_idx:end_idx]
                    x = x_orig[batch_datapoint_idcs, :].clone().to(self.device)
                    y = y_orig[batch_datapoint_idcs].clone().to(self.device)

                    adv_curr = self._run_attack(attack, x, y)

                    output = self.get_logits(adv_curr)
                    false_batch = ~y.eq(output.max(dim=1)[1]).to(robust_flags.device)
                    non_robust_lin_idcs = batch_datapoint_idcs[false_batch]
                    robust_flags[non_robust_lin_idcs] = False
                    x_adv[non_robust_lin_idcs] = adv_curr[false_batch].detach().to(x_adv.device)

                    if self.verbose:
                        num_non_robust_batch = torch.sum(false_batch)
                        self.logger.log('{} - {}/{} - {} out of {} successfully perturbed'.format(
                            attack, batch_idx + 1, n_batches, num_non_robust_batch, x.shape[0]))

                robust_accuracy = torch.sum(robust_flags).item() / n_total
                if self.verbose:
                    self.logger.log('robust accuracy after {}: {:.2%} (total time {:.1f} s)'.format(
                        attack.upper(), robust_accuracy, time.time() - startt))

            # final sanity check on the size of the perturbations
            if self.verbose:
                if self.norm == 'Linf':
                    res = (x_adv - x_orig).abs().view(n_total, -1).max(1)[0]
                elif self.norm == 'L2':
                    res = ((x_adv - x_orig) ** 2).view(n_total, -1).sum(-1).sqrt()
                self.logger.log('max {} perturbation: {:.5f}, nan in tensor: {}, max: {:.5f}, min: {:.5f}'.format(
                    self.norm, res.max(), (x_adv != x_adv).sum(), x_adv.max(), x_adv.min()))
                self.logger.log('robust accuracy: {:.2%}'.format(robust_accuracy))

        return x_adv

    def clean_accuracy(self, x_orig, y_orig, bs=250):
        """Return the fraction of ``x_orig`` whose argmax logit equals ``y_orig``.

        All samples are evaluated, including a trailing batch smaller than
        ``bs`` (previously that remainder was skipped but still counted in the
        denominator).
        """
        n_total = x_orig.shape[0]
        n_batches = int(np.ceil(n_total / bs))
        acc = 0.
        with torch.no_grad():
            for counter in range(n_batches):
                start_idx = counter * bs
                end_idx = min((counter + 1) * bs, n_total)
                x = x_orig[start_idx:end_idx].clone().to(self.device)
                y = y_orig[start_idx:end_idx].clone().to(self.device)
                output = self.get_logits(x)
                acc += (output.max(1)[1] == y).float().sum().item()

        if self.verbose:
            self.logger.log('clean accuracy: {:.2%}'.format(acc / n_total))

        return acc / n_total

    def run_standard_evaluation_individual(self, x_orig, y_orig, bs=250):
        """Run each attack on its own and return ``{attack_name: adversarial_batch}``.

        ``plus``, ``verbose`` and ``attacks_to_run`` are overridden only for the
        duration of the call and restored afterwards, even if an attack raises.
        """
        # update attacks list if plus was activated after initialization
        self._add_plus_attacks()

        l_attacks = list(self.attacks_to_run)
        adv = {}
        saved_plus = self.plus
        saved_verbose = self.verbose
        saved_attacks = self.attacks_to_run
        # keep inner runs quiet and stop them re-appending the plus attacks
        self.plus = False
        self.verbose = False
        try:
            for c in l_attacks:
                startt = time.time()
                self.attacks_to_run = [c]
                adv[c] = self.run_standard_evaluation(x_orig, y_orig, bs=bs)
                if saved_verbose:
                    acc_indiv = self.clean_accuracy(adv[c], y_orig, bs=bs)
                    space = '\t \t' if c == 'fab' else '\t'
                    self.logger.log('robust accuracy by {} {} {:.2%} \t (time attack: {:.1f} s)'.format(
                        c.upper(), space, acc_indiv, time.time() - startt))
        finally:
            self.plus = saved_plus
            self.verbose = saved_verbose
            self.attacks_to_run = saved_attacks

        return adv

    def cheap(self):
        """Switch to a faster, weaker setting: one restart per attack, 1000 Square queries, no plus attacks."""
        self.apgd.n_restarts = 1
        self.fab.n_restarts = 1
        self.apgd_targeted.n_restarts = 1
        self.square.n_queries = 1000
        self.plus = False
コード例 #2
0
ファイル: attack.py プロジェクト: val-iisc/FLSS
class Attack():
    """Run one adversarial attack and sort examples by a noisy majority vote.

    The model is expected to accept ``noi`` and ``noi_sample`` keyword
    arguments (see ``_majority_vote``); the final prediction for each example
    is the class most often predicted across the noise samples.
    """

    # Accept the hyphenated names used by ``attacks_to_run`` as well as the
    # underscore names dispatched on in ``_run_attack``.
    _ATTACK_ALIASES = {'apgd-ce': 'apgd_ce', 'apgd-dlr': 'apgd_dlr'}

    def __init__(self,
                 model,
                 eot_iter,
                 norm='Linf',
                 eps=.3,
                 restarts=5,
                 seed=None,
                 verbose=True,
                 attacks_to_run=('apgd-ce', 'apgd-dlr', 'fab', 'square', 'MM'),
                 plus=False,
                 is_tf_model=False,
                 device='cuda'):
        """Store the configuration and build the APGD, FAB and Square attacks.

        Args:
            model: PyTorch classifier.
            eot_iter: number of gradient samples averaged per step by the
                'pgd' attack.
            norm: 'Linf' or 'L2'.
            eps: perturbation budget.
            restarts: restarts for APGD and FAB.
            seed: fixed seed; ``None`` means the current time (see ``get_seed``).
            verbose: print per-batch progress.
            attacks_to_run: attack names; copied, so the caller's object is
                never modified.
            plus: also list 'apgd-t' and 'fab-t'.
            is_tf_model: stored only.
            device: device batches and the noise matrix are moved to.
        """
        self.model = model
        self.norm = norm
        self.eot_iter = eot_iter
        assert norm in ['Linf', 'L2']
        self.epsilon = eps
        self.restarts = restarts
        self.seed = seed
        self.verbose = verbose
        # Private copy: extending the argument in place used to mutate the
        # caller's list or the shared default list.
        self.attacks_to_run = list(attacks_to_run)
        if plus:
            self.attacks_to_run.extend(['apgd-t', 'fab-t'])
        self.plus = plus
        self.is_tf_model = is_tf_model
        self.device = device

        from autopgd_pt import APGDAttack
        self.apgd = APGDAttack(self.model,
                               n_restarts=self.restarts,
                               n_iter=100,
                               verbose=False,
                               eps=self.epsilon,
                               norm=self.norm,
                               eot_iter=1,
                               rho=.75,
                               seed=self.seed,
                               device=self.device)

        from fab_pt import FABAttack
        self.fab = FABAttack(self.model,
                             n_restarts=self.restarts,
                             n_iter=100,
                             eps=self.epsilon,
                             seed=self.seed,
                             norm=self.norm,
                             verbose=False,
                             device=self.device)

        from square import SquareAttack
        self.square = SquareAttack(self.model,
                                   p_init=.8,
                                   n_queries=5000,
                                   eps=self.epsilon,
                                   norm=self.norm,
                                   n_restarts=1,
                                   seed=self.seed,
                                   verbose=False,
                                   device=self.device,
                                   resc_schedule=False)

        from autopgd_pt import APGDAttack_targeted
        self.apgd_targeted = APGDAttack_targeted(self.model,
                                                 n_restarts=1,
                                                 n_iter=100,
                                                 verbose=False,
                                                 eps=self.epsilon,
                                                 norm=self.norm,
                                                 eot_iter=1,
                                                 rho=.75,
                                                 seed=self.seed,
                                                 device=self.device)

    def get_logits(self, x):
        """Return the model's logits for ``x``."""
        return self.model(x)

    def _pgd_whitebox(self,
                      X,
                      y,
                      epsilon=0.03137254,
                      num_steps=100,
                      step_size=0.007843137,
                      eot_iter=1):
        """Untargeted L-inf PGD on the cross-entropy loss with a random start.

        Each step sums ``eot_iter`` gradients of the loss, moves by
        ``step_size`` along their sign, projects back into the
        ``epsilon``-ball around ``X`` and clips to [0, 1].

        Returns:
            Leaf tensor with ``requires_grad=True``, same shape and device as ``X``.
        """
        # Random start drawn from the CPU generator, then moved to X's device
        # (was hard-coded to .cuda()).
        random_noise = torch.FloatTensor(*X.shape).uniform_(-epsilon, epsilon).to(X.device)
        X_pgd = Variable(X.data + random_noise, requires_grad=True)
        for _ in range(num_steps):
            summed_grad = torch.zeros_like(X_pgd.data)
            for _ in range(eot_iter):
                # fresh gradient for every sample (replaces a throwaway optimizer)
                X_pgd.grad = None
                with torch.enable_grad():
                    loss = nn.CrossEntropyLoss()(self.model(X_pgd), y)
                loss.backward()
                summed_grad = summed_grad + X_pgd.grad.data
            X_step = X_pgd.data + step_size * summed_grad.sign()
            eta = torch.clamp(X_step - X.data, -epsilon, epsilon)
            X_pgd = Variable(torch.clamp(X.data + eta, 0, 1.0), requires_grad=True)
        return X_pgd

    def get_seed(self):
        """Return the configured seed, or the current time when none was given."""
        return time.time() if self.seed is None else self.seed

    def _run_attack(self, attack, x, y, RR):
        """Run the attack named ``attack`` on the batch (x, y) and return the adversarial batch.

        Raises:
            ValueError: if ``attack`` is not a supported name.
        """
        bounds = np.array([[0, 1], [0, 1], [0, 1]])
        if attack == 'apgd_ce':
            # APGD on the cross-entropy loss
            print("running apgd_ce")
            self.apgd.loss = 'ce'
            self.apgd.seed = self.get_seed()
            _, adv_curr = self.apgd.perturb(x, y, cheap=True)

        elif attack == 'apgd_dlr':
            # APGD on the DLR loss
            print("running apgd_dlr")
            self.apgd.loss = 'dlr'
            self.apgd.seed = self.get_seed()
            _, adv_curr = self.apgd.perturb(x, y, cheap=True)

        elif attack == 'fab':
            print("running fab")
            self.fab.targeted = False
            self.fab.seed = self.get_seed()
            adv_curr = self.fab.perturb(x, y)

        elif attack == 'square':
            print("running square")
            self.square.seed = self.get_seed()
            adv_curr = self.square.perturb(x, y)

        elif attack == 'clean':
            # no attack: evaluate the unperturbed inputs
            print("running clean")
            adv_curr = x

        elif attack == 'Gama_pgd':
            print("running Gama_pgd")
            with torch.enable_grad():
                adv_curr = GAMA_PGD(self.model,
                                    x,
                                    y,
                                    eps=self.epsilon,
                                    eps_iter=2 * self.epsilon,
                                    bounds=bounds,
                                    steps=100,
                                    w_reg=50,
                                    lin=25,
                                    SCHED=[60, 85],
                                    drop=10)
            adv_curr = Variable(adv_curr).to(self.device)

        elif attack == 'Gama_MT':
            print("running GAMA_MT")
            with torch.enable_grad():
                out_clean = self.model(x)
                # The target is the class ranked RR + 1 by the clean logits,
                # so at least RR + 2 classes must be retrieved (a fixed k = 6
                # raised IndexError for RR > 4, including the default RR = 10).
                topk = torch.topk(out_clean, max(5 + 1, RR + 2))[1]
                adv_curr = GAMA_MT(self.model,
                                   x,
                                   y,
                                   eps=self.epsilon,
                                   eps_iter=2 * self.epsilon,
                                   bounds=bounds,
                                   steps=100,
                                   w_reg=50,
                                   lin=25,
                                   SCHED=[60, 85],
                                   drop=10,
                                   rr=RR + 1,
                                   new_targ=topk[range(len(y)), RR + 1])

        elif attack == 'MM':
            print("running MM")
            with torch.enable_grad():
                adv_curr = MT(self.model,
                              x,
                              y,
                              eps=self.epsilon,
                              eps_iter=2 * self.epsilon,
                              bounds=bounds,
                              steps=100,
                              w_reg=0,
                              lin=0,
                              SCHED=[50, 75],
                              drop=10,
                              multi_tar=10)

        elif attack == 'MT':
            # same MT helper as 'MM', with multi_tar derived from RR
            print("running MT")
            with torch.enable_grad():
                adv_curr = MT(self.model,
                              x,
                              y,
                              eps=self.epsilon,
                              eps_iter=2 * self.epsilon,
                              bounds=bounds,
                              steps=100,
                              w_reg=0,
                              lin=0,
                              SCHED=[50, 75],
                              drop=10,
                              multi_tar=RR + 1)

        elif attack == 'pgd':
            print("running pgd")
            adv_curr = self._pgd_whitebox(x,
                                          y,
                                          eot_iter=self.eot_iter,
                                          epsilon=self.epsilon,
                                          step_size=self.epsilon / 4)
        else:
            raise ValueError('Attack not supported')
        return adv_curr

    def _majority_vote(self, adv_curr, noise_mat, n_votes=100):
        """Classify ``adv_curr`` once per noise sample and take a per-example majority vote.

        Row ``i`` of ``noise_mat`` is passed to the model as ``noi`` for vote ``i``.

        Returns:
            ``(maxa, output)``: float numpy arrays holding, per example, the
            winning vote count and the winning class index. On a tie, the
            class that first reached the tied count wins.
        """
        n = len(adv_curr)
        rows = np.arange(n)
        maxa = np.zeros(n)
        output = np.zeros(n)
        votes = None
        for i in range(n_votes):
            answer = self.model(adv_curr, noi=noise_mat[i], noi_sample=0)
            ans = answer.max(dim=1)[1].detach().cpu().numpy()
            if votes is None:
                # one counter per class (was fixed at 100 columns)
                votes = np.zeros([n, max(100, answer.shape[1])])
            # each example gets exactly one vote per round, so this vectorized
            # update matches the original per-element loop exactly
            votes[rows, ans] += 1
            counts = votes[rows, ans]
            improved = counts > maxa
            maxa[improved] = counts[improved]
            output[improved] = ans[improved]
        return maxa, output

    def run_standard_evaluation(self,
                                x_orig,
                                y_orig,
                                lister_attack,
                                threshold,
                                bs=1000,
                                type_of_thresholder="original",
                                RR=10,
                                noise_path='./mat.npy'):
        """Attack every example with ``lister_attack[0]`` and bucket it by the noisy vote.

        Args:
            x_orig, y_orig: inputs and integer labels.
            lister_attack: sequence whose first element names the attack.
            threshold: with ``type_of_thresholder == "original"``, examples
                whose winning vote count is ``<= int(threshold)`` are rejected.
            bs: batch size.
            type_of_thresholder: only "original" fills the rejection-based arrays.
            RR: rank parameter used (as ``RR + 1``) by 'Gama_MT' and 'MT'.
            noise_path: ``.npy`` file holding the noise matrix (previously
                hard-coded to './mat.npy').

        Returns:
            Five numpy arrays of indices into ``x_orig``: vote correct and vote
            incorrect (both ignoring the threshold), rejected, accepted and
            correct, accepted and incorrect.
        """
        with torch.no_grad():
            attack = self._ATTACK_ALIASES.get(lister_attack[0], lister_attack[0])
            n_total = x_orig.shape[0]
            n_batches = int(np.ceil(n_total / bs))

            index_array_accepted_false_0 = []
            index_array_accepted_true_0 = []
            index_array_accepted_false = []
            index_array_accepted_true = []
            index_array_rejected = []

            noise_mat = torch.Tensor(np.load(noise_path)).to(self.device)
            for batch_idx in range(n_batches):
                start_idx = batch_idx * bs
                end_idx = min(start_idx + bs, n_total)
                x = x_orig[start_idx:end_idx].clone().to(self.device)
                y = y_orig[start_idx:end_idx].clone().to(self.device)

                adv_curr = self._run_attack(attack, x, y, RR)
                maxa, output = self._majority_vote(adv_curr, noise_mat)
                y_label = y.detach().cpu().long().numpy()

                for i in range(len(adv_curr)):
                    sample_idx = start_idx + i
                    # correctness of the vote, ignoring the threshold
                    if output[i] == y_label[i]:
                        index_array_accepted_true_0.append(sample_idx)
                    else:
                        index_array_accepted_false_0.append(sample_idx)
                    # rejection based on the winning vote count
                    if type_of_thresholder == "original":
                        if maxa[i] <= int(threshold):
                            index_array_rejected.append(sample_idx)
                        elif output[i] == y_label[i]:
                            index_array_accepted_true.append(sample_idx)
                        else:
                            index_array_accepted_false.append(sample_idx)

                if self.verbose:
                    false_batch = ~y.eq(torch.as_tensor(output, device=y.device))
                    num_non_robust_batch = torch.sum(false_batch).item()
                    print('{} - {}/{} - {} out of {} successfully perturbed'.
                          format(attack, batch_idx + 1, n_batches,
                                 num_non_robust_batch, x.shape[0]))
        return (np.array(index_array_accepted_true_0),
                np.array(index_array_accepted_false_0),
                np.array(index_array_rejected),
                np.array(index_array_accepted_true),
                np.array(index_array_accepted_false))

    def cheap(self):
        """Switch to a faster setting: single restarts, 1000 Square queries with rescaling, no plus attacks."""
        self.apgd.n_restarts = 1
        self.fab.n_restarts = 1
        self.apgd_targeted.n_restarts = 1
        self.square.n_queries = 1000
        self.square.resc_schedule = True
        self.plus = False