def simba_batch(self,
                 images_batch,
                 labels_batch,
                 max_iters,
                 freq_dims,
                 stride,
                 epsilon,
                 linf_bound=0.0,
                 order='rand',
                 targeted=False,
                 pixel_attack=False,
                 log_every=10):
     batch_size = images_batch.size(0)
     image_size = images_batch.size(2)
     assert self.image_size == image_size
      # sample a random ordering of the coordinates (one ordering shared by the whole batch)
     if order == 'rand':
         indices = torch.randperm(3 * freq_dims * freq_dims)[:max_iters]
     elif order == 'diag':
         indices = utils.diagonal_order(image_size, 3)[:max_iters]
     elif order == 'strided':
         indices = utils.block_order(image_size,
                                     3,
                                     initial_size=freq_dims,
                                     stride=stride)[:max_iters]
     else:
         indices = utils.block_order(image_size, 3)[:max_iters]
     if order == 'rand':
         expand_dims = freq_dims
     else:
         expand_dims = image_size
     n_dims = 3 * expand_dims * expand_dims
     x = torch.zeros(batch_size, n_dims)
     # logging tensors
     probs = torch.zeros(batch_size, max_iters)
     succs = torch.zeros(batch_size, max_iters)
     queries = torch.zeros(batch_size, max_iters)
     l2_norms = torch.zeros(batch_size, max_iters)
     linf_norms = torch.zeros(batch_size, max_iters)
     prev_probs = self.get_probs(images_batch, labels_batch)
     preds = self.get_preds(images_batch)
     if pixel_attack:
         trans = lambda z: z
     else:
         trans = lambda z: utils.block_idct(
             z, block_size=image_size, linf_bound=linf_bound)
     remaining_indices = torch.arange(0, batch_size).long()
     for k in range(max_iters):
         dim = indices[k]
         expanded = (images_batch[remaining_indices] + trans(
             self.expand_vector(x[remaining_indices], expand_dims))).clamp(
                 0, 1)
         perturbation = trans(self.expand_vector(x, expand_dims))
         l2_norms[:, k] = perturbation.view(batch_size, -1).norm(2, 1)
         linf_norms[:, k] = perturbation.view(batch_size,
                                              -1).abs().max(1)[0]
         preds_next = self.get_preds(expanded)
         preds[remaining_indices] = preds_next
         if targeted:
             remaining = preds.ne(labels_batch)
         else:
             remaining = preds.eq(labels_batch)
          # stop early once the attack has succeeded on every image in the batch
         if remaining.sum() == 0:
             adv = (images_batch +
                    trans(self.expand_vector(x, expand_dims))).clamp(0, 1)
             probs_k = self.get_probs(adv, labels_batch)
             probs[:, k:] = probs_k.unsqueeze(1).repeat(1, max_iters - k)
             succs[:, k:] = torch.ones(batch_size, max_iters - k)
             queries[:, k:] = torch.zeros(batch_size, max_iters - k)
             break
         remaining_indices = torch.arange(0, batch_size)[remaining].long()
         if k > 0:
             succs[:, k - 1] = ~remaining
         diff = torch.zeros(remaining.sum(), n_dims)
         diff[:, dim] = epsilon
         left_vec = x[remaining_indices] - diff
         right_vec = x[remaining_indices] + diff
          # try the negative direction
         adv = (images_batch[remaining_indices] +
                trans(self.expand_vector(left_vec, expand_dims))).clamp(
                    0, 1)
         left_probs = self.get_probs(adv, labels_batch[remaining_indices])
         queries_k = torch.zeros(batch_size)
          # every still-remaining image spends one query on the negative direction
         queries_k[remaining_indices] += 1
         if targeted:
             improved = left_probs.gt(prev_probs[remaining_indices])
         else:
             improved = left_probs.lt(prev_probs[remaining_indices])
         # only increase query count further by 1 for images that did not improve in adversarial loss
         if improved.sum() < remaining_indices.size(0):
             queries_k[remaining_indices[~improved]] += 1
          # try the positive direction
         adv = (images_batch[remaining_indices] +
                trans(self.expand_vector(right_vec, expand_dims))).clamp(
                    0, 1)
         right_probs = self.get_probs(adv, labels_batch[remaining_indices])
         if targeted:
             right_improved = right_probs.gt(
                 torch.max(prev_probs[remaining_indices], left_probs))
         else:
             right_improved = right_probs.lt(
                 torch.min(prev_probs[remaining_indices], left_probs))
         probs_k = prev_probs.clone()
         # update x depending on which direction improved
         if improved.sum() > 0:
             left_indices = remaining_indices[improved]
             left_mask_remaining = improved.unsqueeze(1).repeat(1, n_dims)
             x[left_indices] = left_vec[left_mask_remaining].view(
                 -1, n_dims)
             probs_k[left_indices] = left_probs[improved]
         if right_improved.sum() > 0:
             right_indices = remaining_indices[right_improved]
             right_mask_remaining = right_improved.unsqueeze(1).repeat(
                 1, n_dims)
             x[right_indices] = right_vec[right_mask_remaining].view(
                 -1, n_dims)
             probs_k[right_indices] = right_probs[right_improved]
         probs[:, k] = probs_k
         queries[:, k] = queries_k
         prev_probs = probs[:, k]
         if (k + 1) % log_every == 0 or k == max_iters - 1:
             print(
                 'Iteration %d: queries = %.4f, prob = %.4f, remaining = %.4f'
                 % (k + 1, queries.sum(1).mean(), probs[:, k].mean(),
                    remaining.float().mean()))
     expanded = (images_batch +
                 trans(self.expand_vector(x, expand_dims))).clamp(0, 1)
     preds = self.get_preds(expanded)
     if targeted:
         remaining = preds.ne(labels_batch)
     else:
         remaining = preds.eq(labels_batch)
     succs[:, max_iters - 1] = ~remaining
     return expanded, probs, succs, queries, l2_norms, linf_norms
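
For orientation, a minimal usage sketch of simba_batch follows. It assumes a SimBA-style wrapper class (here called SimBA) that stores the model and image_size and provides the get_probs, get_preds, and expand_vector helpers used above; the constructor, model choice, and hyperparameter values are illustrative assumptions, not part of the excerpt.

import torch
import torchvision.models as models

# Hypothetical wrapper: any object exposing simba_batch plus the helpers above.
model = models.resnet50(pretrained=True).eval()
attacker = SimBA(model, dataset='imagenet', image_size=224)

# Stand-in batch of images in [0, 1]; attack the model's own predictions.
images = torch.rand(8, 3, 224, 224)
labels = model(images).argmax(dim=1)

adv, probs, succs, queries, l2_norms, linf_norms = attacker.simba_batch(
    images, labels,
    max_iters=10000, freq_dims=28, stride=7, epsilon=0.2,
    order='strided', targeted=False, pixel_attack=False, log_every=100)

print('success rate: %.4f' % succs[:, -1].mean().item())
print('average queries: %.1f' % queries.sum(1).mean().item())
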
Example #2
    # project the low-dimensional solution z back to pixel space through the
    # rows of Random_Matrix, build the adversarial image, and report the query
    # count, success flag, and L2 norm of the perturbation
    z = torch.from_numpy(z).float()
    perturbation = (z @ Random_Matrix).view(1, 3, image_size, image_size)
    new_image = (x + perturbation).clamp(0, 1)
    return current_q, is_success, perturbation.view(1, -1).norm(2, 1).item()


if MODEL.startswith("inception"):
    image_size = 299
    testset = dset.ImageFolder(DATA_ROOT, utils.INCEPTION_TRANSFORM)
else:
    image_size = 224
    testset = dset.ImageFolder(DATA_ROOT, utils.IMAGENET_TRANSFORM)
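
The utils.INCEPTION_TRANSFORM and utils.IMAGENET_TRANSFORM constants are not part of this excerpt; a plausible sketch, assuming the usual resize-and-center-crop preprocessing with outputs left in [0, 1] (no normalization, since the attack clamps to [0, 1]), is:

import torchvision.transforms as trans

# Assumed definitions only; the actual utils module may differ.
IMAGENET_TRANSFORM = trans.Compose([
    trans.Resize(256),
    trans.CenterCrop(224),
    trans.ToTensor(),
])

INCEPTION_TRANSFORM = trans.Compose([
    trans.Resize(342),
    trans.CenterCrop(299),
    trans.ToTensor(),
])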

# each row of Random_Matrix marks one DCT coordinate, ordered low-frequency first
Random_Matrix = np.zeros((LOW_DIM, 3 * image_size * image_size))
indices = utils.block_order(image_size,
                            3,
                            initial_size=FREQ_DIM,
                            stride=STRIDE)
for i in range(LOW_DIM):
    Random_Matrix[i][indices[i]] = 1
# apply a 2-D inverse DCT so each row becomes a pixel-space basis direction
Random_Matrix = (torch.from_numpy(
    idct(
        idct(
            Random_Matrix.reshape(-1, 3, image_size, image_size),
            axis=3,
            norm="ortho",
        ),
        axis=2,
        norm="ortho",
    )).view(-1, 3, image_size, image_size).float()).view(LOW_DIM, -1)
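
Each row of Random_Matrix is now the pixel-space image of a single low-frequency DCT basis vector, so a LOW_DIM-dimensional vector z maps to a full-resolution, low-frequency perturbation via z @ Random_Matrix, exactly as in the fragment above. A small shape check, reusing the script's LOW_DIM and image_size:

# Shape check for the low-dimensional-to-pixel-space mapping (illustrative values).
z = torch.from_numpy(np.random.randn(LOW_DIM).astype(np.float32)) * 0.1
perturbation = (z @ Random_Matrix).view(1, 3, image_size, image_size)
print(perturbation.shape)                   # (1, 3, image_size, image_size)
print(perturbation.view(1, -1).norm(2, 1))  # L2 norm, as reported by the attack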

# Attack the 1st image
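A sketch of how the first test image might be fed to the attack; run_attack is a hypothetical name for the truncated function whose tail appears above, assumed to return the query count, a success flag, and the perturbation's L2 norm.

import torchvision.models as models

model = getattr(models, MODEL)(pretrained=True).eval()          # e.g. "resnet50"
x, y = testset[0]
x = x.unsqueeze(0)                                               # add batch dimension
queries_used, is_success, l2_norm = run_attack(model, x, y)      # hypothetical helper
print('queries = %d, success = %s, L2 = %.4f' % (queries_used, is_success, l2_norm))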