Ejemplo n.º 1
0
 def detect(self):
     """Evaluate attack vs. detection over the test set and print running stats.

     Processes up to 200 correctly-classified test images; for each one, runs
     `self.inference` and tallies per-iteration 'draw'/'win'/'lose' outcomes
     plus overall attack/detection success rates.
     """
     # One counter per attack iteration (excluding the last) for each outcome.
     template = torch.zeros(self.attack.iteration - 1)
     result_dict = {key: template.clone() for key in ('draw', 'win', 'lose')}
     counter = 0
     attack_succ_num = 0.
     detect_succ_num = 0.
     for i, data in enumerate(self.dataset.loader['test']):
         if counter >= 200:  # evaluate at most 200 usable images
             break
         print('img idx: ', i)
         _input, _label = self.model.remove_misclassify(data)
         if len(_label) == 0:
             print('misclassification')
             print()
             print('------------------------------------')
             continue
         target = self.model.generate_target(_input)
         result, _, _, attack_succ, detect_succ = self.inference(_input, target)
         for _iter, outcome in enumerate(result):
             result_dict[outcome][_iter] += 1
         counter += 1
         # A success index equal to self.attack.iteration is the sentinel
         # meaning "never succeeded" (see inference()).
         if attack_succ != self.attack.iteration:
             attack_succ_num += 1
         if detect_succ != self.attack.iteration:
             detect_succ_num += 1
         attack_succ_rate = attack_succ_num / counter
         detect_succ_rate = detect_succ_num / counter
         print('draw: ', to_list(result_dict['draw']))
         print('win : ', to_list(result_dict['win']))
         print('lose: ', to_list(result_dict['lose']))
         print()
         print('total: ', counter)
         print('attack succ rate: ', attack_succ_rate)
         print('detect succ rate: ', detect_succ_rate)
         print('------------------------------------')
Ejemplo n.º 2
0
    def inference(self, _input: torch.Tensor, target: torch.Tensor):
        """Run one attacker-vs-defender race on a batch and score each iteration.

        Returns:
            result: list of 'draw'/'win'/'lose' labels, one per iteration
                (length ``self.attack.iteration - 1``).
            detect_result: defender's detected class per iteration.
            attack_result: model prediction on the attacker's query center
                per iteration.
            attack_succ: first iteration at which the attack succeeds
                (``self.attack.iteration`` if it never does).
            detect_succ: first iteration at which detection succeeds
                (``self.attack.iteration`` if it never does).
        """
        # ------------------------------- Init --------------------------------- #
        torch.manual_seed(env['seed'])  # fixed seed so every call is reproducible
        if 'start' in self.output:
            self.attack.output_info(_input=_input, noise=torch.zeros_like(_input), target=target,
                                    loss_fn=lambda _X: self.model.loss(_X, target))
        self.attack_grad_list: list[torch.Tensor] = []
        # ------------------------ Attacker Seq -------------------------------- #
        seq = self.get_seq(_input, target)  # Attacker cluster sequences (iter, query_num+1, C, H, W)
        seq_centers, seq_bias = self.get_center_bias(seq)  # Defender cluster center estimate
        # seq_centers: (iter, 1, C, H, W)   seq_bias: (iter)
        # seq_centers = seq[:, 0]  # debug
        if 'start' in self.output:
            # Report how far the estimated centers drift from the true query centers.
            mean_error = (seq_centers[:, 0] - seq[:, 0]).abs().flatten(start_dim=1).max(dim=1)[0]
            print('Mean Shift Distance: '.ljust(25) +
                  f'avg {mean_error.mean():<10.5f} min {mean_error.min():<10.5f} max {mean_error.max():<10.5f}')
            print('Bias Estimation: '.ljust(25) +
                  f'avg {seq_bias.mean():<10.5f} min {seq_bias.min():<10.5f} max {seq_bias.max():<10.5f}')
        # candidate_centers = self.get_candidate_centers(seq, seq_centers, seq_bias)  # abandoned
        # candidate_centers = seq_centers
        detect_result = self.get_detect_result(seq_centers, target=target)
        attack_result = self.model(seq[:, 0].squeeze()).argmax(dim=1)

        # Sentinel: "never succeeded" is encoded as self.attack.iteration.
        attack_succ = self.attack.iteration
        detect_succ = self.attack.iteration

        detect_true = True
        for i in range(self.attack.iteration - 1):
            # Attack succeeds at i only when two consecutive iterations hit the
            # target (the min() clamp keeps the lookahead index in range at the end)
            # and no earlier success was already recorded.
            if attack_result[i] == target and \
                    attack_result[min(i + 1, self.attack.iteration - 2)] == target and \
                    attack_succ == self.attack.iteration:
                attack_succ = i
            # Detection succeeds at i only when two consecutive detections agree,
            # no earlier success was recorded, and no stable wrong guess has
            # poisoned the run (detect_true).
            if detect_result[i] == detect_result[min(i + 1, self.attack.iteration - 2)] and detect_succ == self.attack.iteration and detect_true:
                if detect_result[i] == target:
                    detect_succ = i
                else:
                    # A stable but wrong detection permanently disqualifies detection.
                    detect_true = False
        if 'end' in self.output:
            # print('candidate centers: ', [len(i) for i in candidate_centers])
            print('Detect Iter: ', detect_succ)
            prints(to_list(detect_result), indent=12)
            print('Attack Iter: ', attack_succ)
            prints(to_list(attack_result), indent=12)
            print()
        # Whoever succeeded first claims every iteration from that point onward;
        # a tie (including the both-never-succeeded case) stays 'draw'.
        result = ['draw'] * (self.attack.iteration - 1)
        if attack_succ < detect_succ:
            for i in range(attack_succ, self.attack.iteration - 1):
                result[i] = 'lose'
        elif attack_succ > detect_succ:
            for i in range(detect_succ, self.attack.iteration - 1):
                result[i] = 'win'
        elif attack_succ == detect_succ:
            pass
        else:
            # NOTE(review): unreachable — the three comparisons above are exhaustive.
            raise ValueError()
        return result, detect_result, attack_result, attack_succ, detect_succ
Ejemplo n.º 3
0
 def output_info(self, _input: torch.Tensor, noise: torch.Tensor, target: torch.Tensor, **kwargs):
     """Print base attack info plus the target class and its confidence on the perturbed input."""
     super(PGD, self).output_info(_input, noise, **kwargs)
     with torch.no_grad():
         # Confidence of each sample on its target class for the perturbed input.
         logits = self.model._model.classifier(_input + noise)
         _prob: torch.Tensor = self.model._model.softmax(logits)
         _confidence = _prob.gather(dim=1, index=target.unsqueeze(1)).flatten()
     prints('Target   class     : ', to_list(target), indent=self.indent)
     prints('Target   confidence: ', to_list(_confidence), indent=self.indent)
Ejemplo n.º 4
0
 def output_info(self, _input: torch.Tensor, noise: torch.Tensor,
                 target: torch.Tensor, **kwargs):
     """Print base attack info plus the target class and its confidence on the perturbed input."""
     super().output_info(_input, noise, **kwargs)
     with torch.no_grad():
         # Model confidence on the target class for the perturbed input.
         target_conf = self.model.get_target_prob(_input + noise, target)
     prints('Target   class     : ', to_list(target), indent=self.indent)
     prints('Target   confidence: ', to_list(target_conf), indent=self.indent)
Ejemplo n.º 5
0
    def check_neuron_jaccard(self, ratio=0.5) -> float:
        """Return the Jaccard similarity between the top-activated feature
        indices on clean validation inputs and on their poisoned counterparts.

        Args:
            ratio: fraction of feature dimensions counted as "top" (default 0.5).
        """
        clean_batches = []
        poison_batches = []
        with torch.no_grad():
            for data in self.dataset.loader['valid']:
                _input, _label = self.model.get_data(data)
                poison_input = self.add_mark(_input)
                clean_batches.append(self.model.get_final_fm(_input))
                poison_batches.append(self.model.get_final_fm(poison_input))
        # Mean activation per feature dimension across the whole valid set.
        clean_mean = torch.cat(clean_batches).mean(dim=0)
        poison_mean = torch.cat(poison_batches).mean(dim=0)
        top_k = int(len(clean_mean) * ratio)
        clean_top = set(to_list(clean_mean.argsort(descending=True))[:top_k])
        poison_top = set(to_list(poison_mean.argsort(descending=True))[:top_k])
        return len(clean_top & poison_top) / len(clean_top | poison_top)
Ejemplo n.º 6
0
 def benign_measure(self, validloader=None, batch_num=20):
     """Collect per-sample measure values over benign validation batches.

     Args:
         validloader: iterable of batches; defaults to the model's 'valid' loader.
         batch_num: maximum number of batches to process (default 20).

     Returns:
         Flat list of measure values across the processed batches.
     """
     if validloader is None:
         validloader = self.model.dataset.loader['valid']
     measure_list = []
     for i, data in enumerate(validloader):
         # Check the batch budget BEFORE preprocessing: the original code called
         # self.model.get_data() on batch `batch_num` only to discard it.
         if i >= batch_num:
             break
         _input, _label = self.model.get_data(data)
         measure = self.measure(_input, _label)
         measure_list.extend(to_list(measure))
     return measure_list
Ejemplo n.º 7
0
 def prune_step(self, mask: torch.Tensor, prune_num: int = 1):
     """Zero out the `prune_num` still-active mask entries with the lowest
     mean activation on the validation set.

     Args:
         mask: per-channel mask tensor, mutated in place.
         prune_num: number of entries to prune this step (default 1).
     """
     with torch.no_grad():
         batch_feats = []
         for data in self.dataset.loader['valid']:
             _input, _label = self.model.get_data(data)
             batch_feats.append(self.model.get_final_fm(_input))
         # Rank feature indices by ascending mean activation (weakest first).
         mean_feats = torch.cat(batch_feats).mean(dim=0)
         idx_rank = to_list(mean_feats.argsort())
     pruned = 0
     limit = min(prune_num, len(idx_rank))  # loop-invariant bound
     for idx in idx_rank:
         # Skip entries that were already pruned (near-zero L1 norm).
         if mask[idx].norm(p=1) > 1e-6:
             mask[idx] = 0.0
             pruned += 1
             print(f'    {output_iter(pruned, prune_num)} Prune {idx:4d} / {len(idx_rank):4d}')
             if pruned >= limit:
                 break
Ejemplo n.º 8
0
 def attack(self, epoch: int, **kwargs):
     """Evaluate the attack over 'normal'-difficulty validation samples.

     Filters samples by a PGD difficulty probe (skipping 'easy' and
     'difficult' ones), crafts a poisoned example for each remaining sample,
     and prints running statistics: target confidence/accuracy, clean
     accuracy drop, PGD noise norm, plus optional curvature, randomized
     smoothing, and MagNet defense measurements.
     """
     # model._validate()
     total = 0
     target_conf_list = []
     target_acc_list = []
     clean_acc_list = []
     pgd_norm_list = []
     # Default PGD budget; overridden per dataset family below.
     pgd_alpha = 1.0 / 255
     pgd_eps = 8.0 / 255
     if self.dataset.name in ['cifar10', 'gtsrb', 'isic2018']:
         pgd_alpha = 1.0 / 255
         pgd_eps = 8.0 / 255
     if self.dataset.name in ['sample_imagenet', 'sample_vggface2']:
         pgd_alpha = 0.25 / 255
         pgd_eps = 2.0 / 255
     # Probe attacker used only to grade sample difficulty (iterations to succeed).
     pgd_checker = PGD(pgd_alpha=pgd_alpha,
                       pgd_eps=pgd_eps,
                       iteration=8,
                       dataset=self.dataset,
                       model=self.model,
                       target_idx=self.target_idx,
                       stop_threshold=0.95)
     easy = 0
     difficult = 0
     normal = 0
     loader = self.dataset.get_dataloader(
         mode='valid', batch_size=self.dataset.test_batch_size)
     # Optional defense measurements, enabled only when the corresponding
     # attribute/flag is present on this instance.
     if 'curvature' in self.__dict__.keys():
         benign_curvature = self.curvature.benign_measure()
         tgt_curvature_list = []
         org_curvature_list = []
     if self.randomized_smooth:
         org_conf_list = []
         tgt_conf_list = []
     if 'magnet' in self.__dict__.keys():
         org_magnet_list = []
         tgt_magnet_list = []
     for data in loader:
         print(easy, normal, difficult)
         if normal >= 100:  # evaluate exactly 100 normal-difficulty samples
             break
         self.model.load()  # reset model weights before each sample
         _input, _label = self.model.remove_misclassify(data)
         if len(_label) == 0:
             continue
         target_label = self.model.generate_target(_input,
                                                   idx=self.target_idx)
         self.temp_input = _input
         self.temp_label = target_label
         # Difficulty grading: None = PGD never succeeded (difficult),
         # < 4 iterations = too easy; only 'normal' samples are evaluated.
         _, _iter = pgd_checker.craft_example(_input)
         if _iter is None:
             difficult += 1
             continue
         if _iter < 4:
             easy += 1
             continue
         normal += 1
         target_conf, target_acc, clean_acc = self.validate_fn()
         # `noise` is filled in place by craft_example; its inf-norm is recorded.
         noise = torch.zeros_like(_input)
         poison_input = self.craft_example(_input=_input,
                                           _label=target_label,
                                           epoch=epoch,
                                           noise=noise,
                                           **kwargs)
         pgd_norm = float(noise.norm(p=float('inf')))
         # Re-validate after crafting; pre-crafting results are discarded.
         target_conf, target_acc, clean_acc = self.validate_fn()
         target_conf_list.append(target_conf)
         target_acc_list.append(target_acc)
         clean_acc_list.append(max(self.clean_acc - clean_acc, 0.0))
         pgd_norm_list.append(pgd_norm)
         print(
             f'[{total+1} / 100]\n'
             f'target confidence: {np.mean(target_conf_list)}({np.std(target_conf_list)})\n'
             f'target accuracy: {np.mean(target_acc_list)}({np.std(target_acc_list)})\n'
             f'clean accuracy Drop: {np.mean(clean_acc_list)}({np.std(clean_acc_list)})\n'
             f'PGD Norm: {np.mean(pgd_norm_list)}({np.std(pgd_norm_list)})\n\n\n'
         )
         # Baseline confidences on the poisoned input for the original and
         # target labels, used as references by the defense checks below.
         org_conf = self.model.get_target_prob(_input=poison_input,
                                               target=_label)
         tgt_conf = self.model.get_target_prob(_input=poison_input,
                                               target=target_label)
         if 'curvature' in self.__dict__.keys():
             org_curvature_list.extend(
                 to_list(self.curvature.measure(poison_input,
                                                _label)))  # type: ignore
             tgt_curvature_list.extend(
                 to_list(self.curvature.measure(
                     poison_input, target_label)))  # type: ignore
             print('Curvature:')
             # KS test against the benign curvature distribution.
             print(
                 f'    org_curvature: {ks_2samp(org_curvature_list, benign_curvature)}'
             )  # type: ignore
             print(
                 f'    tgt_curvature: {ks_2samp(tgt_curvature_list, benign_curvature)}'
             )  # type: ignore
             print()
         if self.randomized_smooth:
             org_new = self.model.get_target_prob(_input=poison_input,
                                                  target=_label,
                                                  randomized_smooth=True)
             tgt_new = self.model.get_target_prob(_input=poison_input,
                                                  target=target_label,
                                                  randomized_smooth=True)
             org_increase = (org_new - org_conf).clamp(min=0.0)
             # NOTE(review): despite its name this is (new - old), i.e. an
             # increase — the MagNet branch below uses (old - new). Confirm
             # which direction is intended here.
             tgt_decrease = (tgt_new - tgt_conf).clamp(min=0.0)
             org_conf_list.extend(to_list(org_increase))  # type: ignore
             tgt_conf_list.extend(to_list(tgt_decrease))  # type: ignore
             print('Randomized Smooth:')
             print(f'    org_confidence: {np.mean(org_conf_list)}'
                   )  # type: ignore
             print(f'    tgt_confidence: {np.mean(tgt_conf_list)}'
                   )  # type: ignore
             print()
         if 'magnet' in self.__dict__.keys():
             # MagNet purifies the poisoned input before re-measuring.
             poison_input = self.magnet(poison_input)
             org_new = self.model.get_target_prob(_input=poison_input,
                                                  target=_label)
             tgt_new = self.model.get_target_prob(_input=poison_input,
                                                  target=target_label)
             org_increase = (org_new - org_conf).clamp(min=0.0)
             tgt_decrease = (tgt_conf - tgt_new).clamp(min=0.0)
             org_magnet_list.extend(to_list(org_increase))  # type: ignore
             tgt_magnet_list.extend(to_list(tgt_decrease))  # type: ignore
             print('MagNet:')
             print(f'    org_confidence: {np.mean(org_magnet_list)}'
                   )  # type: ignore
             print(f'    tgt_confidence: {np.mean(tgt_magnet_list)}'
                   )  # type: ignore
             print()
         total += 1