Example #1
    def _fast_iterative_shrinkage_thresholding(self, x, yy_k, xx_k):

        zt = self.global_step / (self.global_step + 3)

        upper = clamp(yy_k - self.beta, max=self.clip_max)
        lower = clamp(yy_k + self.beta, min=self.clip_min)

        diff = yy_k - x
        cond1 = (diff > self.beta).float()
        cond2 = (torch.abs(diff) <= self.beta).float()
        cond3 = (diff < -self.beta).float()

        xx_k_p_1 = (cond1 * upper) + (cond2 * x) + (cond3 * lower)
        yy_k.data = xx_k_p_1 + (zt * (xx_k_p_1 - xx_k))
        return yy_k, xx_k_p_1
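The three masks above implement an element-wise soft-thresholding (shrinkage) step toward the original input x, followed by the momentum extrapolation of yy_k. A minimal stand-alone sketch of just the shrinkage rule, assuming inputs live in [clip_min, clip_max] and using plain torch.where:

import torch

def soft_threshold(x, y_k, beta, clip_min=0.0, clip_max=1.0):
    # Shrink y_k toward x by beta, element-wise, while keeping the result
    # inside the [clip_min, clip_max] box; equivalent to the cond1/2/3 masks.
    diff = y_k - x
    upper = torch.clamp(y_k - beta, max=clip_max)
    lower = torch.clamp(y_k + beta, min=clip_min)
    return torch.where(diff > beta, upper,
                       torch.where(diff < -beta, lower, x))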
Example #2
    def perturb(self, x, y=None):
        x, y = self._verify_and_process_inputs(x, y)
        xadv = x
        batch_size = x.shape[0]
        dim_x = int(np.prod(x.shape[1:]))
        max_iters = int(dim_x * self.gamma / 2)
        search_space = x.new_ones(batch_size, dim_x).int()
        curr_step = 0
        yadv = self._get_predicted_label(xadv)

        # Algorithm 1
        while ((y != yadv).any() and curr_step < max_iters):

            grads_target, grads_other = self._compute_forward_derivative(
                xadv, y)

            # Algorithm 3
            p1, p2, valid = self._saliency_map(search_space, grads_target,
                                               grads_other, y)

            cond = (y != yadv) & valid

            self._update_search_space(search_space, p1, p2, cond)

            xadv = self._modify_xadv(xadv, batch_size, cond, p1, p2)
            yadv = self._get_predicted_label(xadv)

            curr_step += 1

        xadv = clamp(xadv, min=self.clip_min, max=self.clip_max)
        return xadv
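The iteration budget follows the saliency-map-attack convention: gamma is the fraction of input features the attack may modify, and each step changes two features, hence the division by two. A toy calculation (the 28x28 shape is only an assumption for illustration):

import numpy as np

x_shape = (1, 28, 28)               # hypothetical single-channel 28x28 input
gamma = 0.1                         # the attack may touch 10% of the features
dim_x = int(np.prod(x_shape))       # 784 features
max_iters = int(dim_x * gamma / 2)  # 39 iterations, two features per step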
Example #3
    def _modify_xadv(self, xadv, batch_size, cond, p1, p2):
        ori_shape = xadv.shape
        xadv = xadv.view(batch_size, -1)
        for idx in range(batch_size):
            if cond[idx] != 0:
                xadv[idx, p1[idx]] += self.theta
                xadv[idx, p2[idx]] += self.theta
        xadv = clamp(xadv, min=self.clip_min, max=self.clip_max)
        xadv = xadv.view(ori_shape)
        return xadv
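For reference, a sketch of a vectorized alternative to the per-sample loop, under the assumption that p1 and p2 hold one flat feature index per sample and cond is a 0/1 mask; the toy tensors below exist only to make the sketch runnable:

import torch

xadv_flat = torch.zeros(4, 10)     # toy batch of 4 flattened inputs
p1 = torch.tensor([0, 3, 5, 7])    # first selected feature per sample
p2 = torch.tensor([1, 4, 6, 8])    # second selected feature per sample
cond = torch.tensor([1, 0, 1, 1])  # samples still being attacked
theta = 1.0

rows = torch.arange(xadv_flat.size(0))[cond.bool()]
xadv_flat[rows, p1[cond.bool()]] += theta
xadv_flat[rows, p2[cond.bool()]] += theta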
Example #4
    def _loss_fn(self, output, y_onehot, l1dist, l2distsq, const, opt=False):

        real = (y_onehot * output).sum(dim=1)
        other = ((1.0 - y_onehot) * output -
                 (y_onehot * TARGET_MULT)).max(1)[0]

        if self.targeted:
            loss_logits = clamp(other - real + self.confidence, min=0.)
        else:
            loss_logits = clamp(real - other + self.confidence, min=0.)
        loss_logits = torch.sum(const * loss_logits)

        loss_l2 = l2distsq.sum()

        if opt:
            loss = loss_logits + loss_l2
        else:
            loss_l1 = self.beta * l1dist.sum()
            loss = loss_logits + loss_l2 + loss_l1
        return loss
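When opt is True the L1 term is dropped, which is consistent with an elastic-net (EAD-style) attack where the L1 part is handled by the shrinkage step of Example #1 rather than by the optimizer. Both this loss and the one in the next example consume y_onehot rather than integer labels; a hedged sketch of how a caller might build it (torch.nn.functional.one_hot is a standard PyTorch helper, not something shown on this page):

import torch
import torch.nn.functional as F

y = torch.tensor([2, 0])                        # integer class labels
y_onehot = F.one_hot(y, num_classes=5).float()  # shape (2, 5), one 1.0 per row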
Example #5
    def _loss_fn(self, output, y_onehot, l2distsq, const):
        # TODO: move this out of the class and make this the default loss_fn
        #   after having targeted tests implemented


        # real = (y_onehot * logits).sum(dim=1)
        real = (y_onehot * output).sum(dim=1)

        # TODO: make loss modular, write a loss class
        other = ((1.0 - y_onehot) * output - (y_onehot * TARGET_MULT)
                 ).max(1)[0]
        # subtracting (y_onehot * TARGET_MULT) keeps the true label from being
        #   selected by the max

        if self.targeted:
            loss1 = clamp(other - real + self.confidence, min=0.)
        else:
            loss1 = clamp(real - other + self.confidence, min=0.)
        loss2 = (l2distsq).sum()
        loss1 = torch.sum(const * loss1)
        loss = loss1 + loss2
        return loss
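A tiny worked example of the margin ("real" vs. "other") term shared by Examples #4 and #5, on hand-written logits; the value of TARGET_MULT is an assumption (any large constant that pushes the true class out of the max would do):

import torch

output = torch.tensor([[2.0, 5.0, 1.0]])    # toy logits for one sample
y_onehot = torch.tensor([[0.0, 1.0, 0.0]])  # true (or target) class is 1
TARGET_MULT = 10000.0
confidence = 0.0

real = (y_onehot * output).sum(dim=1)                                   # 5.0
other = ((1.0 - y_onehot) * output - y_onehot * TARGET_MULT).max(1)[0]  # 2.0
loss1 = torch.clamp(real - other + confidence, min=0.)  # 3.0: still positive,
                                                        # the model still predicts the true class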
Example #6
def rand_init_delta(delta, x, ord, eps, clip_min, clip_max):
    # TODO: currently only one way of "uniform" sampling is considered per norm
    # for Linf, there are 3 ways:
    #   1) true uniform sampling by first computing the rectangle, then sampling
    #   2) uniform in the eps box, then truncate to the data domain (implemented)
    #   3) uniform in the data domain, then truncate to the eps box
    # for L2, true uniform sampling is hard, since it requires sampling uniformly
    #   inside the intersection of a cube and a ball, so there are 2 ways:
    #   1) uniform in the data domain, then truncate to the L2 ball (implemented)
    #   2) uniform in the L2 ball, then truncate to the data domain
    # for L1: uniform init in the L1 ball, then truncate to the data domain

    if isinstance(eps, torch.Tensor):
        assert len(eps) == len(delta)

    if ord == np.inf:
        delta.data.uniform_(-1, 1)
        delta.data = batch_multiply(eps, delta.data)
    elif ord == 2:
        delta.data.uniform_(clip_min, clip_max)
        delta.data = delta.data - x
        delta.data = clamp_by_pnorm(delta.data, ord, eps)
    elif ord == 1:
        ini = laplace.Laplace(0, 1)
        delta.data = ini.sample(delta.data.shape)
        delta.data = normalize_by_pnorm(delta.data, p=1)
        ray = uniform.Uniform(0, eps).sample()
        delta.data *= ray
        delta.data = clamp(x.data + delta.data, clip_min, clip_max) - x.data
    else:
        error = "Only ord = inf, ord = 1 and ord = 2 have been implemented"
        raise NotImplementedError(error)

    delta.data = clamp(x + delta.data, min=clip_min, max=clip_max) - x
    return delta.data
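A minimal usage sketch of rand_init_delta for an Linf initialization, mirroring how the PGD-style perturb method further below calls it; the batch shape and budget are assumptions:

import numpy as np
import torch
import torch.nn as nn

x = torch.rand(8, 3, 32, 32)               # toy image batch in [0, 1]
delta = nn.Parameter(torch.zeros_like(x))  # perturbation to be initialized in place
rand_init_delta(delta, x, np.inf, 8. / 255, 0.0, 1.0)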
Example #7
    def _loss_fn_spatial(self, grid, x, y, const, grid_ori):
        imgs = x.clone()
        grid = torch.from_numpy(grid.reshape(grid_ori.shape)).float().to(
            x.device).requires_grad_()
        delta = grid_ori - grid

        adv_img = F.grid_sample(imgs, grid)
        output = self.predict(adv_img)
        real = (y * output).sum(dim=1)
        other = ((1.0 - y) * output - (y * TARGET_MULT)).max(1)[0]
        if self.targeted:
            loss1 = clamp(other - real + self.confidence, min=0.)
        else:
            loss1 = clamp(real - other + self.confidence, min=0.)
        loss2 = self.initial_const * (torch.sqrt(
            (((delta[:, :, 1:] - delta[:, :, :-1] + 1e-10)**2)).view(
                delta.shape[0], -1).sum(1)) + torch.sqrt(
                    ((delta[:, 1:, :] - delta[:, :-1, :] + 1e-10)**2).view(
                        delta.shape[0], -1).sum(1)))
        loss = torch.sum(loss1) + torch.sum(loss2)
        loss.backward()
        grad_ret = grid.grad.data.cpu().numpy().flatten().astype(float)
        grid.grad.data.zero_()
        return loss.data.cpu().numpy().astype(float), grad_ret
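The second term above is a smoothness penalty on the flow field: per sample, it sums the root of squared differences between neighbouring grid offsets along each spatial axis (the 1e-10 in the original presumably keeps the square root differentiable near zero). A toy version of just that term, assuming an (N, H, W, 2) flow:

import torch

delta = torch.randn(2, 5, 5, 2)  # hypothetical flow-field offsets
tv_w = torch.sqrt(((delta[:, :, 1:] - delta[:, :, :-1]) ** 2).view(2, -1).sum(1))
tv_h = torch.sqrt(((delta[:, 1:, :] - delta[:, :-1, :]) ** 2).view(2, -1).sum(1))
smoothness = tv_w + tv_h         # one penalty value per sample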
Example #8
    def _rescale_x_score(self, predict, x, y, ori, best_dist):
        x = torch.stack(x)
        x = self._revert_rescale(x)

        batch_logits = predict(x)
        scores = nn.Softmax(dim=1)(batch_logits)[:, y]

        if not self.comply_with_foolbox:
            x = clamp(x, self.clip_min, self.clip_max)
            batch_logits = predict(x)

        _, bests = torch.max(batch_logits, dim=1)
        best_img = None
        for ii in range(len(bests)):
            curr_dist = torch.sum((x[ii] - ori)**2)
            if (is_successful(int(bests[ii]), y, self.targeted)
                    and curr_dist < best_dist):
                best_img = x[ii]
                best_dist = curr_dist
        scores = nn.Softmax(dim=1)(batch_logits)[:, y]
        return scores, best_img, best_dist
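The helper relies on is_successful to decide whether a candidate already fools the model; a hedged stand-in with the semantics implied by its use here (not necessarily the library's exact implementation):

def is_successful_sketch(pred_label, y, targeted):
    # Targeted: success means predicting the target class y.
    # Untargeted: success means predicting anything other than the true label y.
    return pred_label == y if targeted else pred_label != y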
Example #9
    def perturb(self, source, guide, delta=None):
        """
        Given source inputs, returns their adversarial counterparts
        with representations close to those of the guide.

        :param source: input tensor that we want to perturb.
        :param guide: targeted input.
        :param delta: tensor containing the random initialization.
        :return: tensor containing perturbed inputs.
        """
        # Initialization
        if delta is None:
            delta = torch.zeros_like(source)
            if self.rand_init:
                delta = delta.uniform_(-self.eps, self.eps)
        else:
            delta = delta.detach()

        delta.requires_grad_()

        source = replicate_input(source)
        guide = replicate_input(guide)
        guide_ftr = self.predict(guide).detach()

        xadv = perturb_iterative(source,
                                 guide_ftr,
                                 self.predict,
                                 self.nb_iter,
                                 eps_iter=self.eps_iter,
                                 loss_fn=self.loss_fn,
                                 minimize=True,
                                 ord=np.inf,
                                 eps=self.eps,
                                 clip_min=self.clip_min,
                                 clip_max=self.clip_max,
                                 delta_init=delta)

        xadv = clamp(xadv, self.clip_min, self.clip_max)

        return xadv.data
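Note that perturb_iterative is called here with minimize=True and the guide's representation in place of labels, so the chosen loss_fn should measure the distance between feature vectors. A plausible choice, assumed rather than shown on this page:

import torch.nn as nn

feature_loss = nn.MSELoss(reduction="sum")  # distance between perturbed-source and guide features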
Example #10
    def perturb(self, x, y=None):
        """
        Given examples (x, y), returns their adversarial counterparts with
        a perturbation of magnitude at most eps.

        :param x: input tensor.
        :param y: label tensor.
                  - if None and self.targeted=False, compute y as predicted
                    labels.
                  - if self.targeted=True, then y must be the targeted labels.
        :return: tensor containing perturbed inputs.
        """
        x, y = self._verify_and_process_inputs(x, y)

        delta = torch.zeros_like(x)
        delta = nn.Parameter(delta)
        if self.rand_init:
            rand_init_delta(delta, x, self.ord, self.eps, self.clip_min,
                            self.clip_max)
            delta.data = clamp(
                x + delta.data, min=self.clip_min, max=self.clip_max) - x

        rval = perturb_iterative(
            x,
            y,
            self.predict,
            nb_iter=self.nb_iter,
            eps=self.eps,
            eps_iter=self.eps_iter,
            loss_fn=self.loss_fn,
            minimize=self.targeted,
            ord=self.ord,
            clip_min=self.clip_min,
            clip_max=self.clip_max,
            delta_init=delta,
            l1_sparsity=self.l1_sparsity,
        )

        return rval.data
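If the surrounding class is advertorch's PGDAttack (an assumption; only the perturb method is shown here), a typical call might look roughly like this, with a toy classifier standing in for a real model:

import numpy as np
import torch
import torch.nn as nn
from advertorch.attacks import PGDAttack  # assumption: these snippets come from advertorch

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # toy classifier
x_batch = torch.rand(4, 3, 32, 32)
y_batch = torch.randint(0, 10, (4,))

adversary = PGDAttack(model, eps=8. / 255, eps_iter=2. / 255, nb_iter=40,
                      rand_init=True, ord=np.inf, targeted=False)
x_adv = adversary.perturb(x_batch, y_batch)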
Example #11
    def _get_arctanh_x(self, x):
        result = clamp((x - self.clip_min) / (self.clip_max - self.clip_min),
                       min=0., max=1.) * 2 - 1
        return torch_arctanh(result * ONE_MINUS_EPS)
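This helper maps box-constrained inputs into an unconstrained tanh space (the change of variables used by Carlini-Wagner-style attacks); ONE_MINUS_EPS shrinks the argument slightly so arctanh stays finite at the box boundary. A round-trip sketch for clip_min=0 and clip_max=1, where the value of ONE_MINUS_EPS is an assumption:

import torch

ONE_MINUS_EPS = 0.999999                             # assumed value of the module constant
x = torch.rand(2, 3)                                 # inputs already in [0, 1]
w = torch.atanh((x * 2 - 1) * ONE_MINUS_EPS)         # unconstrained variable
x_back = ((torch.tanh(w) / ONE_MINUS_EPS) + 1) / 2   # recovers x up to rounding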
Example #12
    def perturb(self, x, y=None):
        """
        Given examples (x, y), returns their adversarial counterparts with
        a perturbation of magnitude at most eps.

        :param x: input tensor.
        :param y: label tensor.
                  - if None and self.targeted=False, compute y as predicted
                    labels.
                  - if self.targeted=True, then y must be the targeted labels.
        :return: tensor containing perturbed inputs.
        """
        x, y = self._verify_and_process_inputs(x, y)

        delta = torch.zeros_like(x)
        g = torch.zeros_like(x)

        delta = nn.Parameter(delta)

        for i in range(self.nb_iter):

            if delta.grad is not None:
                delta.grad.detach_()
                delta.grad.zero_()

            imgadv = x + delta
            outputs = self.predict(imgadv)

            if isinstance(outputs, tuple):
                logits = outputs[-2]
            else:
                logits = outputs
            loss = self.loss_fn(logits, y)
            # loss = self.loss_fn(outputs, y)
            if self.targeted:
                loss = -loss
            loss.backward()

            g = self.decay_factor * g + normalize_by_pnorm(delta.grad.data,
                                                           p=1)
            # according to the paper it should be .sum(); the reference
            #   implementations (both cleverhans and the code linked from the
            #   paper) use .mean(), but in practice it should not matter
            if self.ord == np.inf:
                delta.data += self.eps_iter * torch.sign(g)
                delta.data = clamp(delta.data, min=-self.eps, max=self.eps)
                delta.data = clamp(
                    x + delta.data, min=self.clip_min, max=self.clip_max) - x
            elif self.ord == 2:
                delta.data += self.eps_iter * normalize_by_pnorm(g, p=2)
                delta.data *= clamp(
                    (self.eps * normalize_by_pnorm(delta.data, p=2) /
                     delta.data),
                    max=1.)
                delta.data = clamp(
                    x + delta.data, min=self.clip_min, max=self.clip_max) - x
            else:
                error = "Only ord = inf and ord = 2 have been implemented"
                raise NotImplementedError(error)

        rval = x + delta.data
        return rval
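The update of g above is a momentum accumulation in the style of MI-FGSM: each new gradient is L1-normalized before being folded into the running direction. A stripped-down numeric sketch:

import torch

g = torch.zeros(3)
decay_factor = 1.0
for grad in (torch.tensor([0.2, -0.4, 0.1]), torch.tensor([0.3, 0.3, -0.1])):
    g = decay_factor * g + grad / grad.abs().sum()  # same effect as normalize_by_pnorm(grad, p=1)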
Example #13
def perturb_iterative(xvar,
                      yvar,
                      predict,
                      nb_iter,
                      eps,
                      eps_iter,
                      loss_fn,
                      delta_init=None,
                      minimize=False,
                      ord=np.inf,
                      clip_min=0.0,
                      clip_max=1.0,
                      l1_sparsity=None):
    """
    Iteratively maximize the loss over the input. It is a shared method for
    iterative attacks including IterativeGradientSign, LinfPGD, etc.

    :param xvar: input data.
    :param yvar: input labels.
    :param predict: forward pass function.
    :param nb_iter: number of iterations.
    :param eps: maximum distortion.
    :param eps_iter: attack step size.
    :param loss_fn: loss function.
    :param delta_init: (optional) tensor containing the random initialization.
    :param minimize: (optional bool) whether to minimize or maximize the loss.
    :param ord: (optional) the order of maximum distortion (inf, 1 or 2).
    :param clip_min: minimum value per input dimension.
    :param clip_max: maximum value per input dimension.
    :param l1_sparsity: sparsity value for L1 projection.
                  - if None, then perform regular L1 projection.
                  - if float value, then perform sparse L1 descent from
                    Algorithm 1 in https://arxiv.org/pdf/1904.13000v1.pdf
    :return: tensor containing the perturbed input.
    """
    if delta_init is not None:
        delta = delta_init
    else:
        delta = torch.zeros_like(xvar)

    delta.requires_grad_()
    for ii in range(nb_iter):
        outputs = predict(xvar + delta)

        if isinstance(outputs, tuple):
            logits = outputs[-2]
        else:
            logits = outputs
        loss = loss_fn(logits, yvar)
        # loss = loss_fn(outputs, yvar)
        if minimize:
            loss = -loss

        loss.backward()
        if ord == np.inf:
            grad_sign = delta.grad.data.sign()
            delta.data = delta.data + batch_multiply(eps_iter, grad_sign)
            delta.data = batch_clamp(eps, delta.data)
            delta.data = clamp(xvar.data + delta.data, clip_min,
                               clip_max) - xvar.data

        elif ord == 2:
            grad = delta.grad.data
            grad = normalize_by_pnorm(grad)
            delta.data = delta.data + batch_multiply(eps_iter, grad)
            delta.data = clamp(xvar.data + delta.data, clip_min,
                               clip_max) - xvar.data
            if eps is not None:
                delta.data = clamp_by_pnorm(delta.data, ord, eps)

        elif ord == 1:
            grad = delta.grad.data
            abs_grad = torch.abs(grad)

            batch_size = grad.size(0)
            view = abs_grad.view(batch_size, -1)
            view_size = view.size(1)
            if l1_sparsity is None:
                vals, idx = view.topk(1)
            else:
                vals, idx = view.topk(
                    int(np.round((1 - l1_sparsity) * view_size)))

            out = torch.zeros_like(view).scatter_(1, idx, vals)
            out = out.view_as(grad)
            grad = grad.sign() * (out > 0).float()
            grad = normalize_by_pnorm(grad, p=1)
            delta.data = delta.data + batch_multiply(eps_iter, grad)

            delta.data = batch_l1_proj(delta.data.cpu(), eps)
            if xvar.is_cuda:
                delta.data = delta.data.cuda()
            delta.data = clamp(xvar.data + delta.data, clip_min,
                               clip_max) - xvar.data
        else:
            error = "Only ord = inf, ord = 1 and ord = 2 have been implemented"
            raise NotImplementedError(error)
        delta.grad.data.zero_()

    x_adv = clamp(xvar + delta, clip_min, clip_max)
    return x_adv
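A hedged usage sketch of calling perturb_iterative directly for an Linf attack with a toy model; in the examples above it is normally invoked through an attack class, and the argument values here are assumptions:

import numpy as np
import torch
import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # toy classifier
x = torch.rand(4, 3, 32, 32)
y = torch.randint(0, 10, (4,))

x_adv = perturb_iterative(x, y, model, nb_iter=10, eps=8. / 255,
                          eps_iter=2. / 255,
                          loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                          ord=np.inf)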