def detect(self):
    """Evaluate attack vs. detection outcomes over the test loader.

    For each correctly-classified test image (capped at 200 samples), runs
    ``self.inference`` and tallies per-iteration 'draw'/'win'/'lose'
    outcomes, printing running attack / detection success rates.
    """
    num_slots = self.attack.iteration - 1
    result_dict = {key: torch.zeros(num_slots) for key in ('draw', 'win', 'lose')}
    counter = 0
    attack_succ_num = 0.
    detect_succ_num = 0.
    for i, data in enumerate(self.dataset.loader['test']):
        if counter >= 200:  # hard cap on evaluated samples
            break
        print('img idx: ', i)
        _input, _label = self.model.remove_misclassify(data)
        if len(_label) == 0:  # model already misclassifies this image; skip it
            print('misclassification')
            print()
            print('------------------------------------')
            continue
        target = self.model.generate_target(_input)
        result, _, _, attack_succ, detect_succ = self.inference(_input, target)
        for _iter, outcome in enumerate(result):
            result_dict[outcome][_iter] += 1
        counter += 1
        # an index equal to self.attack.iteration means "never succeeded"
        if attack_succ != self.attack.iteration:
            attack_succ_num += 1
        if detect_succ != self.attack.iteration:
            detect_succ_num += 1
        attack_succ_rate = attack_succ_num / counter
        detect_succ_rate = detect_succ_num / counter
        print('draw: ', to_list(result_dict['draw']))
        print('win : ', to_list(result_dict['win']))
        print('lose: ', to_list(result_dict['lose']))
        print()
        print('total: ', counter)
        print('attack succ rate: ', attack_succ_rate)
        print('detect succ rate: ', detect_succ_rate)
        print('------------------------------------')
def inference(self, _input: torch.Tensor, target: torch.Tensor):
    """Run one attacker query sequence against the defender's detection.

    Returns a 5-tuple:
        result        -- list of 'draw'/'win'/'lose' labels, one per
                         attack iteration (length ``iteration - 1``)
        detect_result -- defender's per-iteration predicted target class
        attack_result -- model's per-iteration prediction on cluster centers
        attack_succ   -- first iteration index where the attack sticks
                         (``self.attack.iteration`` if it never does)
        detect_succ   -- first iteration index where detection sticks
                         (``self.attack.iteration`` if it never does)
    """
    # ------------------------------- Init --------------------------------- #
    torch.manual_seed(env['seed'])
    if 'start' in self.output:
        self.attack.output_info(_input=_input, noise=torch.zeros_like(_input), target=target,
                                loss_fn=lambda _X: self.model.loss(_X, target))
    self.attack_grad_list: list[torch.Tensor] = []
    # ------------------------ Attacker Seq -------------------------------- #
    seq = self.get_seq(_input, target)  # Attacker cluster sequences (iter, query_num+1, C, H, W)
    seq_centers, seq_bias = self.get_center_bias(seq)  # Defender cluster center estimate
    # seq_centers: (iter, 1, C, H, W)    seq_bias: (iter)
    # seq_centers = seq[:, 0]  # debug
    if 'start' in self.output:
        # max-abs elementwise gap between the defender's estimated centers
        # and the attacker's true query centers seq[:, 0]
        mean_error = (seq_centers[:, 0] - seq[:, 0]).abs().flatten(start_dim=1).max(dim=1)[0]
        print('Mean Shift Distance: '.ljust(25) + f'avg {mean_error.mean():<10.5f} min {mean_error.min():<10.5f} max {mean_error.max():<10.5f}')
        print('Bias Estimation: '.ljust(25) + f'avg {seq_bias.mean():<10.5f} min {seq_bias.min():<10.5f} max {seq_bias.max():<10.5f}')
    # candidate_centers = self.get_candidate_centers(seq, seq_centers, seq_bias)  # abandoned
    # candidate_centers = seq_centers
    detect_result = self.get_detect_result(seq_centers, target=target)
    attack_result = self.model(seq[:, 0].squeeze()).argmax(dim=1)
    # Sentinel value meaning "never succeeded" is the full iteration count.
    attack_succ = self.attack.iteration
    detect_succ = self.attack.iteration
    detect_true = True
    for i in range(self.attack.iteration - 1):
        # Attack counts as successful at i only if it hits the target for two
        # consecutive checks (the min() clamps the lookahead at the last slot).
        if attack_result[i] == target and \
                attack_result[min(i + 1, self.attack.iteration - 2)] == target and \
                attack_succ == self.attack.iteration:
            attack_succ = i
        # Detection counts once its output stabilizes for two consecutive
        # checks; a stable-but-wrong detection permanently disables success.
        if detect_result[i] == detect_result[min(i + 1, self.attack.iteration - 2)] and detect_succ == self.attack.iteration and detect_true:
            if detect_result[i] == target:
                detect_succ = i
            else:
                detect_true = False
    if 'end' in self.output:
        # print('candidate centers: ', [len(i) for i in candidate_centers])
        print('Detect Iter: ', detect_succ)
        prints(to_list(detect_result), indent=12)
        print('Attack Iter: ', attack_succ)
        prints(to_list(attack_result), indent=12)
        print()
    # Per-iteration outcome: defender 'win's from detect_succ onward if it was
    # first, attacker 'lose's from attack_succ onward if it was first, 'draw'
    # elsewhere (and everywhere on a tie).
    result = ['draw'] * (self.attack.iteration - 1)
    if attack_succ < detect_succ:
        for i in range(attack_succ, self.attack.iteration - 1):
            result[i] = 'lose'
    elif attack_succ > detect_succ:
        for i in range(detect_succ, self.attack.iteration - 1):
            result[i] = 'win'
    elif attack_succ == detect_succ:
        pass
    else:
        raise ValueError()  # defensive: unreachable given the trichotomy above
    return result, detect_result, attack_result, attack_succ, detect_succ
def output_info(self, _input: torch.Tensor, noise: torch.Tensor,
                target: torch.Tensor, **kwargs):
    """Print target class and the model's confidence on the perturbed input.

    Extends the parent's ``output_info`` with target-class confidence,
    computed as softmax(classifier(_input + noise)) gathered at ``target``.

    Args:
        _input: clean input batch.
        noise: additive perturbation (same shape as ``_input`` — assumed,
            since they are summed elementwise).
        target: per-sample target class indices used for the gather.
    """
    # Py3 zero-argument super() replaces the legacy super(PGD, self) form.
    super().output_info(_input, noise, **kwargs)
    with torch.no_grad():
        _prob: torch.Tensor = self.model._model.softmax(
            self.model._model.classifier(_input + noise))
        # confidence of each sample at its own target class
        _confidence = _prob.gather(dim=1, index=target.unsqueeze(1)).flatten()
    prints('Target class : ', to_list(target), indent=self.indent)
    prints('Target confidence: ', to_list(_confidence), indent=self.indent)
def output_info(self, _input: torch.Tensor, noise: torch.Tensor,
                target: torch.Tensor, **kwargs):
    """Print target class and target-class confidence for the perturbed input.

    Calls the parent's ``output_info`` first, then reports the model's
    probability at ``target`` for ``_input + noise``.
    """
    super().output_info(_input, noise, **kwargs)
    with torch.no_grad():
        # probability assigned to the target class on the perturbed input
        target_conf = self.model.get_target_prob(_input + noise, target)
    prints('Target class : ', to_list(target), indent=self.indent)
    prints('Target confidence: ', to_list(target_conf), indent=self.indent)
def check_neuron_jaccard(self, ratio=0.5) -> float:
    """Jaccard similarity of top-activated neurons on clean vs. marked inputs.

    Averages final feature maps over the validation set for clean inputs and
    their trigger-marked counterparts, takes the top ``ratio`` fraction of
    neurons from each ranking, and returns the Jaccard index of the two sets.
    """
    clean_batches = []
    poison_batches = []
    with torch.no_grad():
        for data in self.dataset.loader['valid']:
            _input, _label = self.model.get_data(data)
            clean_batches.append(self.model.get_final_fm(_input))
            poison_batches.append(self.model.get_final_fm(self.add_mark(_input)))
    clean_mean = torch.cat(clean_batches).mean(dim=0)
    poison_mean = torch.cat(poison_batches).mean(dim=0)
    top_k = int(len(clean_mean) * ratio)
    # neuron indices ranked by mean activation, descending; keep the top slice
    top_clean = set(to_list(clean_mean.argsort(descending=True))[:top_k])
    top_poison = set(to_list(poison_mean.argsort(descending=True))[:top_k])
    return len(top_clean & top_poison) / len(top_clean | top_poison)
def benign_measure(self, validloader=None, batch_num=20):
    """Collect the measure statistic over benign validation batches.

    Args:
        validloader: data loader to draw batches from; defaults to the
            model's 'valid' loader.
        batch_num: maximum number of batches to evaluate.

    Returns:
        Flat list of per-sample measure values across the evaluated batches.
    """
    if validloader is None:
        validloader = self.model.dataset.loader['valid']
    measure_list = []
    for i, data in enumerate(validloader):
        # Check the limit BEFORE get_data: the original checked after, which
        # needlessly preprocessed / device-transferred the discarded batch.
        if i >= batch_num:
            break
        _input, _label = self.model.get_data(data)
        measure = self.measure(_input, _label)
        measure_list.extend(to_list(measure))
    return measure_list
def prune_step(self, mask: torch.Tensor, prune_num: int = 1):
    """Zero out up to ``prune_num`` least-activated, still-active mask entries.

    Ranks channels by their mean final-feature-map activation over the
    validation set (ascending) and prunes active ones (L1 norm > 1e-6)
    in that order, logging each pruned index.
    """
    with torch.no_grad():
        batch_feats = []
        for data in self.dataset.loader['valid']:
            _input, _label = self.model.get_data(data)
            batch_feats.append(self.model.get_final_fm(_input))
        mean_activation = torch.cat(batch_feats).mean(dim=0)
        idx_rank = to_list(mean_activation.argsort())  # ascending: weakest first
        counter = 0
        limit = min(prune_num, len(idx_rank))
        for idx in idx_rank:
            if mask[idx].norm(p=1) <= 1e-6:
                continue  # already pruned
            mask[idx] = 0.0
            counter += 1
            print(f' {output_iter(counter, prune_num)} Prune {idx:4d} / {len(idx_rank):4d}')
            if counter >= limit:
                break
def attack(self, epoch: int, **kwargs):
    """Run the attack over validation images and report defense statistics.

    Filters validation samples with a PGD "checker" into easy (< 4 iters),
    difficult (never succeeds), and normal buckets; runs the attack on up to
    100 normal samples, accumulating target confidence/accuracy, clean
    accuracy drop, and PGD noise norm.  Optionally also evaluates curvature,
    randomized-smoothing, and MagNet detection statistics when those
    components are present on ``self``.
    """
    # model._validate()
    total = 0
    target_conf_list = []
    target_acc_list = []
    clean_acc_list = []
    pgd_norm_list = []
    # PGD step size / budget; overridden per dataset family below.
    pgd_alpha = 1.0 / 255
    pgd_eps = 8.0 / 255
    if self.dataset.name in ['cifar10', 'gtsrb', 'isic2018']:
        pgd_alpha = 1.0 / 255
        pgd_eps = 8.0 / 255
    if self.dataset.name in ['sample_imagenet', 'sample_vggface2']:
        pgd_alpha = 0.25 / 255
        pgd_eps = 2.0 / 255
    # Separate PGD instance used only to classify sample difficulty.
    pgd_checker = PGD(pgd_alpha=pgd_alpha, pgd_eps=pgd_eps, iteration=8,
                      dataset=self.dataset, model=self.model,
                      target_idx=self.target_idx, stop_threshold=0.95)
    easy = 0
    difficult = 0
    normal = 0
    loader = self.dataset.get_dataloader(
        mode='valid', batch_size=self.dataset.test_batch_size)
    # Optional detection components: only set up when present on the instance.
    if 'curvature' in self.__dict__.keys():
        benign_curvature = self.curvature.benign_measure()
        tgt_curvature_list = []
        org_curvature_list = []
    if self.randomized_smooth:
        org_conf_list = []
        tgt_conf_list = []
    if 'magnet' in self.__dict__.keys():
        org_magnet_list = []
        tgt_magnet_list = []
    for data in loader:
        print(easy, normal, difficult)
        if normal >= 100:
            break
        # Reload weights so each sample attacks a fresh model.
        self.model.load()
        _input, _label = self.model.remove_misclassify(data)
        if len(_label) == 0:
            continue
        target_label = self.model.generate_target(_input, idx=self.target_idx)
        self.temp_input = _input
        self.temp_label = target_label
        # Difficulty triage via the checker: None -> never converged.
        _, _iter = pgd_checker.craft_example(_input)
        if _iter is None:
            difficult += 1
            continue
        if _iter < 4:
            easy += 1
            continue
        normal += 1
        # NOTE(review): validate_fn() is called both before and after
        # craft_example and the first result is discarded — presumably a
        # warm-up/baseline side effect; confirm it is intentional.
        target_conf, target_acc, clean_acc = self.validate_fn()
        noise = torch.zeros_like(_input)
        # noise is filled in-place by craft_example (its norm is read below).
        poison_input = self.craft_example(
            _input=_input, _label=target_label, epoch=epoch, noise=noise, **kwargs)
        pgd_norm = float(noise.norm(p=float('inf')))
        target_conf, target_acc, clean_acc = self.validate_fn()
        target_conf_list.append(target_conf)
        target_acc_list.append(target_acc)
        clean_acc_list.append(max(self.clean_acc - clean_acc, 0.0))
        pgd_norm_list.append(pgd_norm)
        print(
            f'[{total+1} / 100]\n'
            f'target confidence: {np.mean(target_conf_list)}({np.std(target_conf_list)})\n'
            f'target accuracy: {np.mean(target_acc_list)}({np.std(target_acc_list)})\n'
            f'clean accuracy Drop: {np.mean(clean_acc_list)}({np.std(clean_acc_list)})\n'
            f'PGD Norm: {np.mean(pgd_norm_list)}({np.std(pgd_norm_list)})\n\n\n'
        )
        # Baseline confidences on the poisoned input (original vs target class).
        org_conf = self.model.get_target_prob(_input=poison_input, target=_label)
        tgt_conf = self.model.get_target_prob(_input=poison_input, target=target_label)
        if 'curvature' in self.__dict__.keys():
            org_curvature_list.extend(
                to_list(self.curvature.measure(poison_input, _label)))  # type: ignore
            tgt_curvature_list.extend(
                to_list(self.curvature.measure(
                    poison_input, target_label)))  # type: ignore
            print('Curvature:')
            # KS test of attacked-sample curvature against the benign baseline
            print(
                f' org_curvature: {ks_2samp(org_curvature_list, benign_curvature)}'
            )  # type: ignore
            print(
                f' tgt_curvature: {ks_2samp(tgt_curvature_list, benign_curvature)}'
            )  # type: ignore
            print()
        if self.randomized_smooth:
            org_new = self.model.get_target_prob(_input=poison_input,
                                                 target=_label,
                                                 randomized_smooth=True)
            tgt_new = self.model.get_target_prob(_input=poison_input,
                                                 target=target_label,
                                                 randomized_smooth=True)
            org_increase = (org_new - org_conf).clamp(min=0.0)
            # NOTE(review): this computes (new - old), i.e. an *increase*,
            # despite the name; the MagNet branch below uses (old - new).
            # Confirm which sign is intended here.
            tgt_decrease = (tgt_new - tgt_conf).clamp(min=0.0)
            org_conf_list.extend(to_list(org_increase))  # type: ignore
            tgt_conf_list.extend(to_list(tgt_decrease))  # type: ignore
            print('Randomized Smooth:')
            print(f' org_confidence: {np.mean(org_conf_list)}'
                  )  # type: ignore
            print(f' tgt_confidence: {np.mean(tgt_conf_list)}'
                  )  # type: ignore
            print()
        if 'magnet' in self.__dict__.keys():
            # MagNet reconstructs the input; re-measure confidences on it.
            poison_input = self.magnet(poison_input)
            org_new = self.model.get_target_prob(_input=poison_input, target=_label)
            tgt_new = self.model.get_target_prob(_input=poison_input, target=target_label)
            org_increase = (org_new - org_conf).clamp(min=0.0)
            tgt_decrease = (tgt_conf - tgt_new).clamp(min=0.0)
            org_magnet_list.extend(to_list(org_increase))  # type: ignore
            tgt_magnet_list.extend(to_list(tgt_decrease))  # type: ignore
            print('MagNet:')
            print(f' org_confidence: {np.mean(org_magnet_list)}'
                  )  # type: ignore
            print(f' tgt_confidence: {np.mean(tgt_magnet_list)}'
                  )  # type: ignore
            print()
        total += 1