def optimize(self, _input: torch.Tensor, noise: torch.Tensor = None,
             pgd_alpha: float = None, pgd_eps: float = None,
             iteration: int = None,
             loss_fn: Callable[[torch.Tensor], torch.Tensor] = None,
             output: Union[int, list[str]] = None,
             add_noise_fn=None, random_init: bool = False,
             **kwargs) -> tuple[torch.Tensor, int]:
    """Run projected sign-gradient descent on ``_input`` to minimize ``loss_fn``.

    Each step moves ``noise`` by ``-pgd_alpha * sign(grad)``, projects it back
    with ``self.projector(noise, pgd_eps, norm=self.norm)`` and re-applies it
    to ``_input`` through ``add_noise_fn``.

    Args:
        _input: clean input tensor.
        noise: initial perturbation, updated in place through ``.data``.
            Defaults to zeros shaped like ``_input[0]`` when ``self.universal``
            (one perturbation shared by the batch), else like ``_input``.
            Ignored when ``random_init`` is set.
        pgd_alpha, pgd_eps, iteration, loss_fn: override the instance defaults.
        output: passed to ``self.get_output``; the result is tested for
            ``'start'``, ``'middle'`` and ``'end'`` to trigger reporting.
        add_noise_fn: callable ``(_input=, noise=, batch=)``; defaults to
            ``add_noise``.
        random_init: start from uniform noise in ``[-pgd_alpha, pgd_alpha]``,
            projected into the ``pgd_eps`` ball.
        **kwargs: forwarded to ``self.output_info`` and ``self.early_stop_check``.

    Returns:
        ``(X, n)``. ``X`` is the detached perturbed input, or ``_input`` itself
        when ``iteration``, ``pgd_alpha`` or ``pgd_eps`` is zero. ``n`` is
        ``(number of steps taken) + 1`` at the first check where
        ``self.early_stop_check`` succeeded, or ``None`` if it never did.
    """
    # ------------------------------ Parameter Initialization ---------------------------------- #
    pgd_alpha = pgd_alpha if pgd_alpha is not None else self.pgd_alpha
    pgd_eps = pgd_eps if pgd_eps is not None else self.pgd_eps
    iteration = iteration if iteration is not None else self.iteration
    loss_fn = loss_fn if loss_fn is not None else self.loss_fn
    add_noise_fn = add_noise_fn if add_noise_fn is not None else add_noise

    # A universal perturbation is shared by the whole batch, so it has the
    # shape of one sample; random and zero init must agree on that shape.
    noise_template = _input[0] if self.universal else _input
    if random_init:
        noise = pgd_alpha * (torch.rand_like(noise_template) * 2 - 1)
        # The uniform draw can leave the eps ball (pgd_alpha > pgd_eps, or a
        # non-inf norm); project so every iterate is a valid perturbation.
        noise = self.projector(noise, pgd_eps, norm=self.norm)
    elif noise is None:
        noise = torch.zeros_like(noise_template)
    output = self.get_output(output)
    # ----------------------------------------------------------------------------------------- #

    if 'start' in output:
        self.output_info(_input=_input, noise=noise, mode='start', loss_fn=loss_fn, **kwargs)
    if iteration == 0 or pgd_alpha == 0.0 or pgd_eps == 0.0:
        return _input, None

    X = add_noise_fn(_input=_input, noise=noise, batch=self.universal)
    # ----------------------------------------------------------------------------------------- #

    # Check before every step AND once more after the last one, so success
    # reached on the final step is not reported as a failure. The count
    # convention (steps taken + 1) is the same for every check.
    for _iter in range(iteration + 1):
        if self.early_stop_check(X=X, loss_fn=loss_fn, **kwargs):
            if 'end' in output:
                self.output_info(_input=_input, noise=noise, mode='end', loss_fn=loss_fn, **kwargs)
            return X.detach(), _iter + 1
        if _iter == iteration:
            break

        if self.grad_method == 'hess' and _iter % self.hess_p == 0:
            self.hess = self.calc_hess(loss_fn, X, sigma=self.sigma,
                                       hess_b=self.hess_b, hess_lambda=self.hess_lambda)
            self.hess /= self.hess.norm(p=2)
        grad = self.calc_grad(loss_fn, X)
        if self.grad_method != 'white' and 'middle' in output:
            real_grad = self.whitebox_grad(loss_fn, X)
            prints('cos<real, est> = ', cos_sim(grad.sign(), real_grad.sign()),
                   indent=self.indent + 2)
        if self.universal:
            grad = grad.mean(dim=0)

        noise.data = (noise - pgd_alpha * torch.sign(grad)).data
        noise.data = self.projector(noise, pgd_eps, norm=self.norm).data
        X = add_noise_fn(_input=_input, noise=noise, batch=self.universal)
        # Re-derive noise from X so it reflects whatever add_noise_fn did
        # (the per-position mode across the batch for universal noise).
        if self.universal:
            noise.data = (X - _input).mode(dim=0)[0].data
        else:
            noise.data = (X - _input).data

        if 'middle' in output:
            self.output_info(_input=_input, noise=noise, mode='middle',
                             _iter=_iter, iteration=iteration, loss_fn=loss_fn, **kwargs)

    if 'end' in output:
        self.output_info(_input=_input, noise=noise, mode='end', loss_fn=loss_fn, **kwargs)
    return X.detach(), None
def get_detect_result(self, seq_centers: torch.Tensor, target=None):
    """For each step of ``seq_centers``, pick the class label whose gradient best explains the move.

    For each consecutive pair ``(seq_centers[i], seq_centers[i + 1])`` and each
    class ``c``, the gradient of ``self.model.loss`` w.r.t. ``seq_centers[i]``
    with label ``c`` is computed and scaled by its max-abs value. When
    ``self.active`` is set, the gradient is also blended with a fixed 0/1
    pattern and signed. The class whose ``-grad`` has the highest
    ``cos_sim`` with the displacement ``seq_centers[i + 1] - seq_centers[i]``
    is recorded.

    Args:
        seq_centers: sequence of tensors. Each ``seq_centers[i]`` is assumed to
            have a leading batch dimension, because ``len(X)`` sizes the label
            tensor. TODO confirm against callers.
        target: class index whose similarities are printed when ``'middle'``
            is in ``self.output``; not used otherwise.

    Returns:
        A ``torch.long`` tensor on ``env['device']`` of length
        ``max(self.attack.iteration - 1, len(seq_centers) - 1, 0)``. Entry
        ``i`` is the chosen class for step ``i``; unfilled entries stay ``-1``.
    """
    num_classes = self.model.num_classes
    num_steps = len(seq_centers) - 1
    # Pad to attack.iteration - 1 (unused entries stay -1) but grow if more
    # centers are supplied, instead of raising IndexError.
    pair_seq = -torch.ones(max(self.attack.iteration - 1, num_steps, 0),
                           dtype=torch.long, device=env['device'])
    for i in range(num_steps):
        # detach() so toggling requires_grad below always acts on a leaf.
        X: torch.Tensor = seq_centers[i].detach().clone()
        vec = seq_centers[i + 1] - X
        dist_list = torch.zeros(num_classes)
        for _class in range(num_classes):
            _label = torch.full((len(X),), _class, dtype=torch.long, device=X.device)
            X.requires_grad_()
            loss = self.model.loss(X, _label)
            grad = torch.autograd.grad(loss, X)[0]
            X.requires_grad = False
            max_abs = grad.abs().max()
            if max_abs > 0:  # an all-zero gradient would otherwise become NaN
                grad /= max_abs
            if self.active:
                # Set every num_classes-th element starting at `offset`, but
                # only within the first numel // num_classes full blocks.
                noise_grad = torch.zeros(X.numel(), device=X.device)
                offset = (_class + i) % num_classes
                stop = (X.numel() // num_classes) * num_classes
                noise_grad[offset:stop:num_classes] = 1
                noise_grad = noise_grad.view(X.shape)
                grad = self.active_percent * noise_grad + \
                    (1 - self.active_percent) * grad
                grad.sign_()
            dist = cos_sim(-grad, vec)
            dist_list[_class] = dist
            if 'middle' in self.output and _class == target:
                print('sim <vec, real>: ', cos_sim(vec, -grad))
                print('sim <est, real>: ', cos_sim(self.attack_grad_list[i], grad))
                print('sim <vec, est>: ', cos_sim(vec, -self.attack_grad_list[i]))
        # The former softmax(log((1 + d) / (1 - d))) = softmax(2 * atanh(d)) is
        # monotone in d, so its argmax equals dist_list's. Taking it directly
        # also avoids the inf/NaN cases at d = +-1.
        pair_seq[i] = dist_list.argmax().item()
    return pair_seq
def get_center_class_pairs(self, candidate_centers: torch.Tensor, seq_centers: torch.Tensor, seq: torch.Tensor):
    """For each step's candidate points, pick the class whose signed gradient best matches the move.

    For step ``i`` and each ``point`` in ``candidate_centers[i]``, the signed
    gradient of ``self.model.loss(point, c)`` is computed for every class
    ``c``. When ``self.active`` is set, it is blended with a fixed 0/1 pattern
    and re-signed. The class maximizing ``cos_sim(-grad,
    seq_centers[i + 1] - point)`` is recorded.

    Args:
        candidate_centers: indexable per step; iterating
            ``candidate_centers[i]`` yields candidate points.
        seq_centers: per-step centers; ``seq_centers[i + 1]`` is the target of
            step ``i``.
        seq: accepted for interface compatibility; not used.

    Returns:
        A list of length ``len(candidate_centers) - 1``. Each element is a list
        of chosen class indices (ints), one per candidate point.
    """
    num_classes = self.model.num_classes
    pair_seq = []
    for i in range(len(candidate_centers) - 1):
        next_center = seq_centers[i + 1]
        sub_pair_seq = []
        for point in candidate_centers[i]:
            # detach() so toggling requires_grad below always acts on a leaf.
            x = point.detach().clone()
            vec = next_center - x
            dist_list = torch.zeros(num_classes)
            for _class in range(num_classes):
                x.requires_grad_()
                # NOTE(review): the label is a plain int here, while
                # get_detect_result passes a LongTensor; confirm that
                # self.model.loss accepts both.
                loss = self.model.loss(x, _class)
                grad = torch.autograd.grad(loss, x)[0]
                x.requires_grad = False
                grad.sign_()
                if self.active:
                    # Set every num_classes-th element starting at `offset`,
                    # but only within the first numel // num_classes full blocks.
                    noise_grad = torch.zeros_like(grad).flatten()
                    offset = (_class + i) % num_classes
                    stop = (noise_grad.numel() // num_classes) * num_classes
                    noise_grad[offset:stop:num_classes] = 1
                    noise_grad = noise_grad.view(grad.shape)
                    grad = self.active_percent * noise_grad + \
                        (1 - self.active_percent) * grad
                    grad.sign_()
                dist_list[_class] = cos_sim(-grad, vec)
            sub_pair_seq.append(dist_list.argmax().item())
        pair_seq.append(sub_pair_seq)
    return pair_seq