Example #1
    def optimize(self, _input: torch.Tensor, noise: Optional[torch.Tensor] = None,
                 pgd_alpha: Optional[float] = None, pgd_eps: Optional[float] = None,
                 iteration: Optional[int] = None,
                 loss_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
                 output: Optional[Union[int, list[str]]] = None, add_noise_fn=None,
                 random_init: bool = False, **kwargs) -> tuple[torch.Tensor, Optional[int]]:
        # ------------------------------ Parameter Initialization ---------------------------------- #

        pgd_alpha = pgd_alpha if pgd_alpha is not None else self.pgd_alpha
        pgd_eps = pgd_eps if pgd_eps is not None else self.pgd_eps
        iteration = iteration if iteration is not None else self.iteration
        loss_fn = loss_fn if loss_fn is not None else self.loss_fn
        add_noise_fn = add_noise_fn if add_noise_fn is not None else add_noise
        if random_init:
            noise = pgd_alpha * (torch.rand_like(_input) * 2 - 1)
        else:
            noise = noise if noise is not None else torch.zeros_like(_input[0] if self.universal else _input)
        output = self.get_output(output)

        # ----------------------------------------------------------------------------------------- #

        if 'start' in output:
            self.output_info(_input=_input, noise=noise, mode='start', loss_fn=loss_fn, **kwargs)
        if iteration == 0 or pgd_alpha == 0.0 or pgd_eps == 0.0:
            return _input, None

        X = add_noise_fn(_input=_input, noise=noise, batch=self.universal)
        # ----------------------------------------------------------------------------------------- #

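        # Main PGD loop: estimate the gradient of loss_fn at X, take a signed step of
        # size pgd_alpha, then project the accumulated noise back into the pgd_eps ball.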
        for _iter in range(iteration):
            if self.early_stop_check(X=X, loss_fn=loss_fn, **kwargs):
                if 'end' in output:
                    self.output_info(_input=_input, noise=noise, mode='end', loss_fn=loss_fn, **kwargs)
                return X.detach(), _iter + 1
            if self.grad_method == 'hess' and _iter % self.hess_p == 0:
                self.hess = self.calc_hess(loss_fn, X, sigma=self.sigma,
                                           hess_b=self.hess_b, hess_lambda=self.hess_lambda)
                self.hess /= self.hess.norm(p=2)
            grad = self.calc_grad(loss_fn, X)
            if self.grad_method != 'white' and 'middle' in output:
                real_grad = self.whitebox_grad(loss_fn, X)
                prints('cos<real, est> = ', cos_sim(grad.sign(), real_grad.sign()),
                       indent=self.indent + 2)
            if self.universal:
                grad = grad.mean(dim=0)
            noise.data = (noise - pgd_alpha * torch.sign(grad)).data
            noise.data = self.projector(noise, pgd_eps, norm=self.norm).data
            X = add_noise_fn(_input=_input, noise=noise, batch=self.universal)
            if self.universal:
                noise.data = (X - _input).mode(dim=0)[0].data
            else:
                noise.data = (X - _input).data

            if 'middle' in output:
                self.output_info(_input=_input, noise=noise, mode='middle',
                                 _iter=_iter, iteration=iteration, loss_fn=loss_fn, **kwargs)
        if 'end' in output:
            self.output_info(_input=_input, noise=noise, mode='end', loss_fn=loss_fn, **kwargs)
        return X.detach(), None
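
The projection step self.projector(noise, pgd_eps, norm=self.norm) is not shown in this example. As an illustrative stand-in only, not the repository's implementation, a projector that keeps the perturbation inside the eps-ball of the chosen norm could look like the sketch below; for batched input a per-sample norm would replace the global one.

    import torch

    def project(noise: torch.Tensor, pgd_eps: float, norm: float = float('inf')) -> torch.Tensor:
        # Clamp (L-inf) or rescale (other p-norms) the noise into the eps-ball.
        if norm == float('inf'):
            return noise.clamp(-pgd_eps, pgd_eps)
        length = noise.norm(p=norm)
        if length > pgd_eps:
            noise = noise * (pgd_eps / length)
        return noise
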
Example #2
    def get_detect_result(self, seq_centers: torch.Tensor, target=None):
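        # For each pair of consecutive query centers, score every candidate class by the
        # cosine similarity between the step toward the next center and the negative
        # signed class gradient at the current center, and record the best-scoring class.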
        pair_seq = -torch.ones(
            self.attack.iteration - 1, dtype=torch.long, device=env['device'])
        detect_prob = torch.ones(
            self.model.num_classes) / self.model.num_classes
        for i in range(len(seq_centers) - 1):
            X: torch.Tensor = seq_centers[i].clone()
            dist_list = torch.zeros(self.model.num_classes)

            for _class in range(self.model.num_classes):
                _label = _class * torch.ones(
                    len(X), dtype=torch.long, device=X.device)
                X.requires_grad_()
                loss = self.model.loss(X, _label)
                grad = torch.autograd.grad(loss, X)[0]
                X.requires_grad = False
                grad /= grad.abs().max()
                if self.active:
                    noise_grad = torch.zeros(X.numel(), device=X.device)
                    offset = (_class + i) % self.model.num_classes

                    for multiplier in range(
                            len(noise_grad) // self.model.num_classes):
                        noise_grad[multiplier * self.model.num_classes +
                                   offset] = 1
                    noise_grad = noise_grad.view(X.shape)
                    # noise_grad /= noise_grad.abs().max()
                    grad = self.active_percent * noise_grad + \
                        (1 - self.active_percent) * grad
                grad.sign_()
                vec = seq_centers[i + 1] - X
                dist = cos_sim(-grad, vec)
                dist_list[_class] = dist
                if 'middle' in self.output and _class == target:
                    print('sim <vec, real>: ', cos_sim(vec, -grad))
                    print('sim <est, real>: ',
                          cos_sim(self.attack_grad_list[i], grad))
                    print('sim <vec, est>: ',
                          cos_sim(vec, -self.attack_grad_list[i]))
            # todo: Use atanh for normalization after pytorch 1.6
            detect_prob = torch.nn.functional.softmax(
                torch.log((2 / (1 - dist_list)).sub(1)), dim=0)
            # detect_prob.div_(detect_prob.norm(p=2))
            pair_seq[i] = detect_prob.argmax().item()
        return pair_seq
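
The expression torch.log((2 / (1 - dist_list)).sub(1)) equals log((1 + d) / (1 - d)) = 2 * atanh(d) for each cosine similarity d, which is the identity the todo comment refers to. Assuming dist_list holds values strictly inside (-1, 1), the same normalization can be written with torch.atanh on PyTorch 1.6 and later; a minimal equivalent sketch:

    import torch

    def detect_prob_from_dist(dist_list: torch.Tensor) -> torch.Tensor:
        # log((1 + d) / (1 - d)) == 2 * atanh(d), so this matches
        # torch.log((2 / (1 - dist_list)).sub(1)) up to floating-point error.
        return torch.softmax(2 * torch.atanh(dist_list), dim=0)
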
Example #3
    def get_center_class_pairs(self, candidate_centers: torch.Tensor,
                               seq_centers: torch.Tensor, seq: torch.Tensor):
        pair_seq = []
        for i in range(len(candidate_centers) - 1):
            sub_pair_seq = []
            for point in candidate_centers[i]:
                # if self.active:
                #     vec = seq_centers[i+1]-point
                #     _result = vec.view(-1)
                #     for j in range(len(_result)):
                #         if _result[j] < 0 and j > i:
                #             sub_pair_seq.append((j-i) % self.num_classes)
                # print(vec.view(-1)[:self.num_classes])
                x = point.clone()
                dist_list = torch.zeros(self.model.num_classes)
                # print('bound: ', estimate_error + shift_dist)
                for _class in range(self.model.num_classes):
                    x.requires_grad_()
                    loss = self.model.loss(x, _class)
                    grad = torch.autograd.grad(loss, x)[0]
                    x.requires_grad = False
                    grad.sign_()
                    if self.active:
                        noise_grad = torch.zeros_like(grad).flatten()
                        offset = (_class + i) % self.model.num_classes
                        for multiplier in range(
                                len(noise_grad) // self.model.num_classes):
                            noise_grad[multiplier * self.model.num_classes +
                                       offset] = 1
                        noise_grad = noise_grad.view(grad.shape)
                        grad = self.active_percent * noise_grad + \
                            (1 - self.active_percent) * grad
                        grad.sign_()
                    vec = seq_centers[i + 1] - point
                    dist = cos_sim(-grad, vec)
                    dist_list[_class] = dist
                sub_pair_seq.append(dist_list.argmax().item())
            pair_seq.append(sub_pair_seq)
        return pair_seq
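
Both detection methods build noise_grad with a Python loop that sets every num_classes-th element, starting at offset, to 1. Assuming the number of elements is divisible by num_classes, which the loop already presumes, the same probe pattern can be produced with one slice assignment; a small illustrative helper (make_noise_grad is a hypothetical name, not part of the code above):

    import torch

    def make_noise_grad(shape: torch.Size, num_classes: int, offset: int,
                        device=None) -> torch.Tensor:
        # Ones at positions offset, offset + num_classes, offset + 2 * num_classes, ...
        noise_grad = torch.zeros(shape, device=device).flatten()
        noise_grad[offset::num_classes] = 1
        return noise_grad.view(shape)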