# The helpers used throughout these snippets (clamp, batch_multiply, batch_clamp,
# normalize_by_pnorm, clamp_by_pnorm, batch_l1_proj) come from advertorch.utils.
import numpy as np
import torch

from advertorch.utils import (batch_clamp, batch_l1_proj, batch_multiply,
                              clamp, clamp_by_pnorm, normalize_by_pnorm)


def perturb_iterative(xvar,
                      yvar,
                      predict,
                      nb_iter,
                      eps,
                      eps_iter,
                      loss_fn,
                      delta_init=None,
                      minimize=False,
                      ord=np.inf,
                      clip_min=0.0,
                      clip_max=1.0):
    """
    Iteratively maximize the loss over the input. It is a shared method for
    iterative attacks including IterativeGradientSign, LinfPGD, etc.

    :param xvar: input data.
    :param yvar: input labels.
    :param predict: forward pass function.
    :param nb_iter: number of iterations.
    :param eps: maximum distortion.
    :param eps_iter: attack step size.
    :param loss_fn: loss function.
    :param delta_init: (optional) tensor containing the random initialization.
    :param minimize: (optional bool) whether to minimize or maximize the loss.
    :param ord: (optional) the order of maximum distortion (inf or 2).
    :param clip_min: minimum value per input dimension.
    :param clip_max: maximum value per input dimension.

    :return: tensor containing the perturbed input.
    """
    if delta_init is not None:
        delta = delta_init
    else:
        delta = torch.zeros_like(xvar)

    delta.requires_grad_()
    for ii in range(nb_iter):
        outputs = predict(xvar + delta)
        loss = loss_fn(outputs, yvar)
        if minimize:
            loss = -loss

        loss.backward()
        if ord == np.inf:
            grad_sign = delta.grad.data.sign()
            delta.data = delta.data + batch_multiply(eps_iter, grad_sign)
            delta.data = batch_clamp(eps, delta.data)
            delta.data = clamp(xvar.data + delta.data, clip_min,
                               clip_max) - xvar.data

        elif ord == 2:
            grad = delta.grad.data
            grad = normalize_by_pnorm(grad)
            delta.data = delta.data + batch_multiply(eps_iter, grad)
            delta.data = clamp(xvar.data + delta.data, clip_min,
                               clip_max) - xvar.data
            if eps is not None:
                delta.data = clamp_by_pnorm(delta.data, ord, eps)
        else:
            error = "Only ord = inf and ord = 2 have been implemented"
            raise NotImplementedError(error)

        delta.grad.data.zero_()

    x_adv = clamp(xvar + delta, clip_min, clip_max)
    return x_adv
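
A minimal usage sketch for the function above; the model and data below are placeholders, not part of the original code:

import torch.nn as nn

# toy classifier and batch, standing in for a real model and dataset
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x = torch.rand(8, 1, 28, 28)           # inputs already scaled to [0, 1]
y = torch.randint(0, 10, (8,))

# Linf PGD: 40 steps of size 0.01 inside an eps = 0.3 ball
x_adv = perturb_iterative(x, y, predict=model, nb_iter=40, eps=0.3,
                          eps_iter=0.01,
                          loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                          ord=np.inf, clip_min=0.0, clip_max=1.0)
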
Example #2
def perturb_iterative(xvar, yvar, predict, nb_iter, eps, eps_iter, loss_fn,
                      delta_init=None, minimize=False, ord=np.inf,
                      clip_min=0.0, clip_max=1.0,
                      l1_sparsity=None, accumulate_param_grad_prob=0., return_intermediate=False):
    """
    Iteratively maximize the loss over the input. It is a shared method for
    iterative attacks including IterativeGradientSign, LinfPGD, etc.

    :param xvar: input data.
    :param yvar: input labels.
    :param predict: forward pass function.
    :param nb_iter: number of iterations.
    :param eps: maximum distortion.
    :param eps_iter: attack step size.
    :param loss_fn: loss function.
    :param delta_init: (optional) tensor containing the random initialization.
    :param minimize: (optional bool) whether to minimize or maximize the loss.
    :param ord: (optional) the order of maximum distortion (inf, 1 or 2).
    :param clip_min: minimum value per input dimension.
    :param clip_max: maximum value per input dimension.
    :param l1_sparsity: sparsity value for L1 projection.
                  - if None, then perform regular L1 projection.
                  - if float value, then perform sparse L1 descent from
                    Algorithm 1 in https://arxiv.org/pdf/1904.13000v1.pdf
    :param accumulate_param_grad_prob: probability, at each iteration, of
                  running a full backward pass (which also accumulates
                  gradients into the model parameters) instead of computing
                  the gradient with respect to delta only.
    :param return_intermediate: if True, return the perturbed input from every
                  iteration, stacked along dimension 1.
    :return: tensor containing the perturbed input.
    """
    if delta_init is not None:
        delta = delta_init
    else:
        delta = torch.zeros_like(xvar)
    deltas = []
    delta.requires_grad_()
    for ii in range(nb_iter):
        outputs = predict(xvar + delta)
        loss = loss_fn(outputs, yvar)
        if minimize:
            loss = -loss
        # With probability accumulate_param_grad_prob, run a full backward pass
        # (which also accumulates gradients into the model parameters);
        # otherwise compute the gradient with respect to delta only.
        store_grad = bool(np.random.binomial(1, accumulate_param_grad_prob))
        if store_grad:
            loss.backward()
        else:
            delta.grad = torch.autograd.grad(loss, [delta])[0]
        if ord == np.inf:
            grad_sign = delta.grad.data.sign()
            delta.data = delta.data + batch_multiply(eps_iter, grad_sign)
            delta.data = batch_clamp(eps, delta.data)
            delta.data = clamp(xvar.data + delta.data, clip_min, clip_max
                               ) - xvar.data

        elif ord == 2:
            grad = delta.grad.data
            grad = normalize_by_pnorm(grad)
            delta.data = delta.data + batch_multiply(eps_iter, grad)
            delta.data = clamp(xvar.data + delta.data, clip_min, clip_max
                               ) - xvar.data
            if eps is not None:
                delta.data = clamp_by_pnorm(delta.data, ord, eps)

        elif ord == 1:
            grad = delta.grad.data
            abs_grad = torch.abs(grad)

            batch_size = grad.size(0)
            view = abs_grad.view(batch_size, -1)
            view_size = view.size(1)
            if l1_sparsity is None:
                vals, idx = view.topk(1)
            else:
                vals, idx = view.topk(
                    int(np.round((1 - l1_sparsity) * view_size)))

            out = torch.zeros_like(view).scatter_(1, idx, vals)
            out = out.view_as(grad)
            grad = grad.sign() * (out > 0).float()
            grad = normalize_by_pnorm(grad, p=1)
            delta.data = delta.data + batch_multiply(eps_iter, grad)

            delta.data = batch_l1_proj(delta.data.cpu(), eps)
            if xvar.is_cuda:
                delta.data = delta.data.cuda()
            delta.data = clamp(xvar.data + delta.data, clip_min, clip_max
                               ) - xvar.data
        else:
            error = "Only ord = inf, ord = 1 and ord = 2 have been implemented"
            raise NotImplementedError(error)
        delta.grad.data.zero_()
        deltas.append(delta.detach().clone())  # snapshot of the current perturbation
    if return_intermediate:
        x_adv = torch.stack([clamp(xvar + d, clip_min, clip_max) for d in deltas], dim=1)
    else:
        x_adv = clamp(xvar + delta, clip_min, clip_max)
    return x_adv
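
A brief sketch of calling this variant with the sparse L1 descent option and intermediate outputs kept (again with placeholder model and data):

import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
x = torch.rand(4, 3, 32, 32)
y = torch.randint(0, 10, (4,))

# sparse L1 descent (keep roughly the top 1% of coordinates per step) and
# return the adversarial input after every iteration: shape (4, 20, 3, 32, 32)
x_adv_all = perturb_iterative(x, y, predict=model, nb_iter=20, eps=10.0,
                              eps_iter=1.0,
                              loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                              ord=1, l1_sparsity=0.99,
                              return_intermediate=True)
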
Example #3
    def perturb_iterative_fool_many(xvar,
                                    embvar,
                                    indlistvar,
                                    yvar,
                                    predict,
                                    nb_iter,
                                    eps,
                                    epscand,
                                    eps_iter,
                                    loss_fn,
                                    rayon,
                                    delta_init=None,
                                    minimize=False,
                                    ord=np.inf,
                                    clip_min=0.0,
                                    clip_max=1.0,
                                    l1_sparsity=None):
        """
      Iteratively maximize the loss over the input. It is a shared method for
      iterative attacks including IterativeGradientSign, LinfPGD, etc.
      :param xvar: input data.
      :param yvar: input labels.
      :param predict: forward pass function.
      :param nb_iter: number of iterations.
      :param eps: maximum distortion.
      :param eps_iter: attack step size.
      :param loss_fn: loss function.
      :param delta_init: (optional) tensor contains the random initialization.
      :param minimize: (optional bool) whether to minimize or maximize the loss.
      :param ord: (optional) the order of maximum distortion (inf or 2).
      :param clip_min: mininum value per input dimension.
      :param clip_max: maximum value per input dimension.
      :param l1_sparsity: sparsity value for L1 projection.
                    - if None, then perform regular L1 projection.
                    - if float value, then perform sparse L1 descent from
                      Algorithm 1 in https://arxiv.org/pdf/1904.13000v1.pdf
      :return: tensor containing the perturbed input.
      """

        #will contain all words encountered during PGD
        nb = len(indlistvar)
        tablist = []
        for t in range(nb):
            tablist += [[]]
        fool = False

        # record the loss on the embedding and the loss difference on the
        # nearest-neighbor word at each iteration
        loss_memory = np.zeros((nb_iter, ))
        word_balance_memory = np.zeros((nb_iter, ))

        candid = [torch.empty(0)] * nb
        convers = [[]] * nb
        for u in range(nb):
            #prepare all potential candidates, once and for all
            candidates = torch.empty([0, 768]).to(device)
            conversion = []
            emb_matrix = model.roberta.embeddings.word_embeddings.weight
            normed_emb_matrix = F.normalize(emb_matrix, p=2, dim=1)
            normed_emb_word = F.normalize(embvar[0][indlistvar[u]], p=2, dim=0)
            cosine_similarity = torch.matmul(
                normed_emb_word, torch.transpose(normed_emb_matrix, 0, 1))
            for t in range(
                    len(cosine_similarity)):  # TODO: avoid doing TWO loops
                if cosine_similarity[t] > epscand:
                    if levenshtein(
                            tokenizer.decode(
                                torch.tensor([xvar[0][indlistvar[u]]])),
                            tokenizer.decode(torch.tensor([t]))) != 1:
                        candidates = torch.cat(
                            (candidates, normed_emb_matrix[t].unsqueeze(0)), 0)
                        conversion += [t]
            candid[u] = candidates
            convers[u] = conversion
            print("nb of candidates :")
            print(len(conversion))

        #U, S, V = torch.svd(model.roberta.embeddings.word_embeddings.weight)

        if delta_init is not None:
            delta = delta_init
        else:
            delta = torch.zeros_like(embvar)

        #PGD
        delta.requires_grad_()
        ii = 0
        while ii < nb_iter and not (fool):
            outputs = predict(xvar, embvar + delta)
            loss = loss_fn(outputs, yvar)
            if minimize:
                loss = -loss

            loss.backward()
            if ord == np.inf:
                grad_sign = delta.grad.data.sign()
                grad_sign = tozerolist(grad_sign, indlistvar)
                delta.data = delta.data + batch_multiply(eps_iter, grad_sign)
                delta.data = batch_clamp(eps, delta.data)
                delta.data = clamp(
                    embvar.data + delta.data,
                    clip_min,
                    clip_max  # remove this clipping?
                ) - embvar.data
                with torch.no_grad():
                    delta.data = tozero(delta.data, indlistvar)
                    if (ii % 300) == 0:
                        adverslist = []
                        for t in range(nb):
                            advers, nb_vois = neighboors_np_dens_cand(
                                (embvar + delta)[0][indlistvar[t]], rayon,
                                candid[t])
                            advers = int(advers[0])
                            advers = torch.tensor(convers[t][advers])
                            if len(tablist[t]) == 0:
                                tablist[t] += [
                                    (tokenizer.decode(advers.unsqueeze(0)), ii,
                                     nb_vois)
                                ]
                            elif not (first(
                                    tablist[t][-1]) == tokenizer.decode(
                                        advers.unsqueeze(0))):
                                tablist[t] += [
                                    (tokenizer.decode(advers.unsqueeze(0)), ii,
                                     nb_vois)
                                ]
                            adverslist += [advers]
                        word_balance_memory[ii] = float(
                            model(replacelist(xvar, indlistvar, adverslist),
                                  labels=1 - yvar)[0]) - float(
                                      model(replacelist(
                                          xvar, indlistvar, adverslist),
                                            labels=yvar)[0])
                        if word_balance_memory[ii] < 0:
                            fool = True

            elif ord == 0:
                grad = delta.grad.data
                grad = tozero(grad, indlistvar)
                # keep only the first 50 components of the gradient in the
                # SVD basis `v` (presumably V from the commented-out torch.svd
                # of the embedding matrix above), then map back
                proj = torch.matmul(grad, v)
                proj[:, :, 50:] = 0
                grad = torch.matmul(proj, v.t())
                delta.data = delta.data + batch_multiply(eps_iter, grad)
                delta.data[0] = my_proj_all(embvar.data[0] + delta.data[0],
                                            embvar[0], indlistvar,
                                            eps) - embvar.data[0]
                delta.data = clamp(embvar.data + delta.data, clip_min,
                                   clip_max) - embvar.data  # probably unnecessary
                with torch.no_grad():
                    delta.data = tozero(delta.data, indlistvar)
                    if (ii % 300) == 0:
                        adverslist = []
                        for t in range(nb):
                            advers, nb_vois = neighboors_np_dens_cand(
                                (embvar + delta)[0][indlistvar[t]], rayon,
                                candid[t])
                            advers = int(advers[0])
                            advers = torch.tensor(convers[t][advers])
                            if len(tablist[t]) == 0:
                                tablist[t] += [
                                    (tokenizer.decode(advers.unsqueeze(0)), ii,
                                     nb_vois)
                                ]
                            elif not (first(
                                    tablist[t][-1]) == tokenizer.decode(
                                        advers.unsqueeze(0))):
                                tablist[t] += [
                                    (tokenizer.decode(advers.unsqueeze(0)), ii,
                                     nb_vois)
                                ]
                            adverslist += [advers]
                        word_balance_memory[ii] = float(
                            model(replacelist(xvar, indlistvar, adverslist),
                                  labels=1 - yvar)[0]) - float(
                                      model(replacelist(
                                          xvar, indlistvar, adverslist),
                                            labels=yvar)[0])
                        if word_balance_memory[ii] < 0:
                            fool = True

            elif ord == 2:
                grad = delta.grad.data
                grad = tozero(grad, indlistvar)
                grad = normalize_by_pnorm(grad)
                delta.data = delta.data + batch_multiply(eps_iter, grad)
                delta.data = clamp(embvar.data + delta.data, clip_min,
                                   clip_max) - embvar.data
                if eps is not None:
                    delta.data = clamp_by_pnorm(delta.data, ord, eps)
                with torch.no_grad():
                    delta.data = tozero(delta.data, indlistvar)
                    if (ii % 300) == 0:
                        adverslist = []
                        for t in range(nb):
                            advers, nb_vois = neighboors_np_dens_cand(
                                (embvar + delta)[0][indlistvar[t]], rayon,
                                candid[t])
                            advers = int(advers[0])
                            advers = torch.tensor(convers[t][advers])
                            if len(tablist[t]) == 0:
                                tablist[t] += [
                                    (tokenizer.decode(advers.unsqueeze(0)), ii,
                                     nb_vois)
                                ]
                            elif not (first(
                                    tablist[t][-1]) == tokenizer.decode(
                                        advers.unsqueeze(0))):
                                tablist[t] += [
                                    (tokenizer.decode(advers.unsqueeze(0)), ii,
                                     nb_vois)
                                ]
                            adverslist += [advers]
                        word_balance_memory[ii] = float(
                            model(replacelist(xvar, indlistvar, adverslist),
                                  labels=1 - yvar)[0]) - float(
                                      model(replacelist(
                                          xvar, indlistvar, adverslist),
                                            labels=yvar)[0])
                        if word_balance_memory[ii] < 0:
                            fool = True

            elif ord == 1:
                grad = delta.grad.data
                grad = tozero(grad, indlistvar)
                abs_grad = torch.abs(grad)

                batch_size = grad.size(0)
                view = abs_grad.view(batch_size, -1)
                view_size = view.size(1)
                if l1_sparsity is None:
                    vals, idx = view.topk(1)
                else:
                    vals, idx = view.topk(
                        int(np.round((1 - l1_sparsity) * view_size)))

                out = torch.zeros_like(view).scatter_(1, idx, vals)
                out = out.view_as(grad)
                grad = grad.sign() * (out > 0).float()
                grad = normalize_by_pnorm(grad, p=1)
                delta.data = delta.data + batch_multiply(eps_iter, grad)

                delta.data = batch_l1_proj(delta.data.cpu(), eps)
                if embvar.is_cuda:
                    delta.data = delta.data.cuda()
                delta.data = clamp(embvar.data + delta.data, clip_min,
                                   clip_max) - embvar.data
            else:
                error = "Only ord = inf, ord = 1 and ord = 2 have been implemented"
                raise NotImplementedError(error)
            delta.grad.data.zero_()
            with torch.no_grad():
                loss_memory[ii] = float(loss)

            ii += 1

        #plt.plot(loss_memory)
        #plt.title("evolution of embed loss")
        #plt.show()
        #plt.plot(word_balance_memory)
        #plt.title("evolution of word loss difference")
        #plt.show()
        emb_adv = clamp(embvar + delta, clip_min, clip_max)
        return emb_adv, word_balance_memory, loss_memory, tablist, fool
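
The candidate search in the function above scans the whole vocabulary in a Python loop per perturbed word; a rough vectorized sketch of the same selection (cosine-similarity threshold, then the Levenshtein filter on the survivors), relying on the same globals as the original (model, tokenizer, levenshtein, F), could look like this:

def build_candidates(emb_word, token_id, epscand):
    # cosine similarity of one word embedding against the whole vocabulary
    emb_matrix = model.roberta.embeddings.word_embeddings.weight
    normed_matrix = F.normalize(emb_matrix, p=2, dim=1)
    normed_word = F.normalize(emb_word, p=2, dim=0)
    sims = torch.matmul(normed_matrix, normed_word)          # (vocab_size,)
    ids = torch.nonzero(sims > epscand, as_tuple=False).squeeze(1)
    # drop candidates whose surface form is exactly one edit away, as above
    src = tokenizer.decode(torch.tensor([token_id]))
    keep = [int(t) for t in ids
            if levenshtein(src, tokenizer.decode(torch.tensor([int(t)]))) != 1]
    return normed_matrix[keep], keep
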
def perturb_russian_roulette(xvar,
                             yvar,
                             predict,
                             stop_prob,
                             eps,
                             eps_iter,
                             loss_fn,
                             delta_init=None,
                             minimize=False,
                             ord=np.inf,
                             clip_min=0.0,
                             clip_max=1.0,
                             l1_sparsity=None):
    """
    Iteratively maximize the loss over the input. It is a shared method for
    iterative attacks including IterativeGradientSign, LinfPGD, etc.

    :param xvar: input data.
    :param yvar: input labels.
    :param predict: forward pass function.
    :param stop_prob: probability of stopping at each iteration.
    :param eps: maximum distortion.
    :param eps_iter: attack step size.
    :param loss_fn: loss function.
    :param delta_init: (optional) tensor containing the random initialization.
    :param minimize: (optional bool) whether to minimize or maximize the loss.
    :param ord: (optional) the order of maximum distortion (inf, 1 or 2).
    :param clip_min: minimum value per input dimension.
    :param clip_max: maximum value per input dimension.
    :param l1_sparsity: sparsity value for L1 projection.
                  - if None, then perform regular L1 projection.
                  - if float value, then perform sparse L1 descent from
                    Algorithm 1 in https://arxiv.org/pdf/1904.13000v1.pdf
    :return: tensor containing the perturbed input.
    """
    if delta_init is not None:
        delta = delta_init
    else:
        delta = torch.zeros_like(xvar)

    delta.requires_grad_()
    continue_prob = 1.
    delta_rr = delta.clone().detach()
    # `bernoulli` is presumably scipy.stats.bernoulli: keep iterating while the
    # coin (success probability 1 - stop_prob) comes up 1
    coin = bernoulli(1 - stop_prob)
    while coin.rvs():
        outputs = predict(xvar + delta)
        loss = loss_fn(outputs, yvar)
        if minimize:
            loss = -loss

        loss.backward()

        delta_prev = delta.clone().detach()
        if ord == np.inf:
            grad_sign = delta.grad.data.sign()
            delta.data = delta.data + batch_multiply(eps_iter, grad_sign)
            delta.data = batch_clamp(eps, delta.data)
            delta.data = clamp(xvar.data + delta.data, clip_min,
                               clip_max) - xvar.data

        elif ord == 2:
            grad = delta.grad.data
            grad = normalize_by_pnorm(grad)
            delta.data = delta.data + batch_multiply(eps_iter, grad)
            delta.data = clamp(xvar.data + delta.data, clip_min,
                               clip_max) - xvar.data
            if eps is not None:
                delta.data = clamp_by_pnorm(delta.data, ord, eps)

        elif ord == 1:
            grad = delta.grad.data
            abs_grad = torch.abs(grad)

            batch_size = grad.size(0)
            view = abs_grad.view(batch_size, -1)
            view_size = view.size(1)
            if l1_sparsity is None:
                vals, idx = view.topk(1)
            else:
                vals, idx = view.topk(
                    int(np.round((1 - l1_sparsity) * view_size)))

            out = torch.zeros_like(view).scatter_(1, idx, vals)
            out = out.view_as(grad)
            grad = grad.sign() * (out > 0).float()
            grad = normalize_by_pnorm(grad, p=1)
            delta.data = delta.data + batch_multiply(eps_iter, grad)

            delta.data = batch_l1_proj(delta.data.cpu(), eps)
            if xvar.is_cuda:
                delta.data = delta.data.cuda()
            delta.data = clamp(xvar.data + delta.data, clip_min,
                               clip_max) - xvar.data
        else:
            error = "Only ord = inf, ord = 1 and ord = 2 have been implemented"
            raise NotImplementedError(error)
        delta.grad.data.zero_()

        # accumulate the latest update, discounted by the probability of having
        # reached this iteration; detach so no graph is kept across iterations
        continue_prob *= (1 - stop_prob)
        delta_rr += continue_prob * (delta.detach() - delta_prev)

    x_adv = clamp(xvar + delta_rr, clip_min, clip_max)
    return x_adv
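
With stopping probability stop_prob per iteration, the number of executed iterations is geometric, so on average (1 - stop_prob) / stop_prob steps are taken. A minimal call sketch (placeholder model and data; scipy's bernoulli assumed imported):

from scipy.stats import bernoulli
import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x = torch.rand(8, 1, 28, 28)
y = torch.randint(0, 10, (8,))

# roughly 19 iterations in expectation with stop_prob = 0.05
x_adv = perturb_russian_roulette(x, y, predict=model, stop_prob=0.05,
                                 eps=0.3, eps_iter=0.01,
                                 loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                 ord=np.inf)
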
Example #5
def masked_perturb_iterative(xvar, yvar, predict, nb_iter, eps, eps_iter, loss_fn,
                      delta_init=None, minimize=False, ord=np.inf,
                      clip_min=0.0, clip_max=1.0,
                      l1_sparsity=None, mask_steps=100, device="cuda:0"):
    """
    Iteratively maximize the loss over the input. It is a shared method for
    iterative attacks including IterativeGradientSign, LinfPGD, etc.
    :param xvar: input data.
    :param yvar: input labels.
    :param predict: forward pass function.
    :param nb_iter: number of iterations.
    :param eps: maximum distortion.
    :param eps_iter: attack step size.
    :param loss_fn: loss function.
    :param delta_init: (optional) tensor containing the random initialization.
    :param minimize: (optional bool) whether to minimize or maximize the loss.
    :param ord: (optional) the order of maximum distortion (inf, 1 or 2).
    :param clip_min: minimum value per input dimension.
    :param clip_max: maximum value per input dimension.
    :param l1_sparsity: sparsity value for L1 projection.
                  - if None, then perform regular L1 projection.
                  - if float value, then perform sparse L1 descent from
                    Algorithm 1 in https://arxiv.org/pdf/1904.13000v1.pdf
    :param mask_steps: number of times a mask should be drawn and a delta computed.
    :return: tensor containing the perturbed input.
    """
    if delta_init is not None:
        delta = delta_init
    else:
        delta = torch.zeros_like(xvar)

    delta.requires_grad_()
    for ii in tqdm(range(nb_iter)):
        new_delta = 0 # added
        for jj in range(mask_steps): # added

            outputs = predict(xvar + delta)

            # MASKED part
            mask = torch.Tensor(np.random.randint(0,2,size=outputs.shape[1])) # added
            mask = torch.stack([mask for i in range(outputs.shape[0])])

            # force true label to not be masked
            for i in range(len(yvar)):
                mask[i][yvar[i]] = 1

            # allow for the multiplication in log space
            mask[mask == 0] = -100000

            mask = mask.to(device)

            outputs = outputs * mask

            loss = loss_fn(outputs, yvar)
            if minimize:
                loss = -loss

            loss.backward()
            if ord == np.inf:
                grad_sign = delta.grad.data.sign()
                delta.data = delta.data + batch_multiply(eps_iter, grad_sign)
                delta.data = batch_clamp(eps, delta.data)
                delta.data = clamp(xvar.data + delta.data, clip_min, clip_max
                                   ) - xvar.data

            elif ord == 2:
                grad = delta.grad.data
                grad = normalize_by_pnorm(grad)
                delta.data = delta.data + batch_multiply(eps_iter, grad)
                delta.data = clamp(xvar.data + delta.data, clip_min, clip_max
                                   ) - xvar.data
                if eps is not None:
                    delta.data = clamp_by_pnorm(delta.data, ord, eps)

            elif ord == 1:
                grad = delta.grad.data
                abs_grad = torch.abs(grad)

                batch_size = grad.size(0)
                view = abs_grad.view(batch_size, -1)
                view_size = view.size(1)
                if l1_sparsity is None:
                    vals, idx = view.topk(1)
                else:
                    vals, idx = view.topk(
                        int(np.round((1 - l1_sparsity) * view_size)))

                out = torch.zeros_like(view).scatter_(1, idx, vals)
                out = out.view_as(grad)
                grad = grad.sign() * (out > 0).float()
                grad = normalize_by_pnorm(grad, p=1)
                delta.data = delta.data + batch_multiply(eps_iter, grad)

                delta.data = batch_l1_proj(delta.data.cpu(), eps)
                if xvar.is_cuda:
                    delta.data = delta.data.to(device)
                delta.data = clamp(xvar.data + delta.data, clip_min, clip_max
                                   ) - xvar.data
            else:
                error = "Only ord = inf, ord = 1 and ord = 2 have been implemented"
                raise NotImplementedError(error)

            new_delta += delta.data # added
            delta.grad.data.zero_()

        delta.data = new_delta / mask_steps # added

    x_adv = clamp(xvar + delta, clip_min, clip_max)
    return x_adv, delta
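
The per-step mask above is built with Python loops; a vectorized sketch of the same construction (one random class mask shared by the batch, true labels never masked, masked entries pushed to a large negative value) might look like this:

def draw_logit_mask(batch_size, num_classes, yvar, device):
    # one random 0/1 mask over the classes, repeated for the whole batch
    mask = torch.randint(0, 2, (num_classes,), device=device).float()
    mask = mask.unsqueeze(0).repeat(batch_size, 1)
    # the true label is never masked
    mask[torch.arange(batch_size, device=device), yvar.to(device)] = 1.0
    # masked classes get a large negative multiplier, as in the loop above
    mask[mask == 0] = -100000.0
    return mask
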
Example #6
def _batch_l2_scale(x, rownorm):
    from advertorch.utils import clamp_by_pnorm
    return clamp_by_pnorm(x, 2., rownorm)
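
A quick illustration of the helper above: each row (sample) of the batch is rescaled so that its L2 norm does not exceed the given bound. A scalar bound is used here for simplicity:

x = torch.randn(4, 3, 8, 8)
x_scaled = _batch_l2_scale(x, rownorm=1.0)
# per-sample L2 norms are now at most 1.0
print(x_scaled.view(4, -1).norm(p=2, dim=1))
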
Example #7
def perturb_iterative(xvar,
                      yvar,
                      predict,
                      nb_iter,
                      eps,
                      eps_iter,
                      loss_fn,
                      delta_init=None,
                      minimize=False,
                      ord=np.inf,
                      clip_min=0.0,
                      clip_max=1.0,
                      beta=0.5,
                      early_stop=True):
    """
    Iteratively maximize the loss over the input. It is a shared method for
    iterative attacks including IterativeGradientSign, LinfPGD, etc.

    :param xvar: input data.
    :param yvar: input labels.
    :param predict: forward pass function.
    :param nb_iter: number of iterations.
    :param eps: maximum distortion.
    :param eps_iter: attack step size per iteration.
    :param loss_fn: loss function.
    :param delta_init: (optional) tensor containing the random initialization.
    :param minimize: (optional bool) whether to minimize or maximize the loss.
    :param ord: (optional) the order of maximum distortion (inf, 1 or 2).
    :param clip_min: (optional float) minimum value per input dimension.
    :param clip_max: (optional float) maximum value per input dimension.
    :param beta: weight of the auxiliary loss term; divided by 10 every 5 iterations.
    :param early_stop: (optional bool) stop as soon as every input is misclassified.
    :return: tensor containing the perturbed input, and the number of iterations run.
    """
    if delta_init is not None:
        delta = delta_init
    else:
        delta = torch.zeros_like(xvar)
    count = 0
    delta.requires_grad_()
    for ii in range(nb_iter):
        count += 1
        loss, w_loss = loss_fn(predict, yvar, xvar, xvar + delta)
        outputs = predict(xvar + delta)
        p = torch.argmax(outputs, dim=1)
        if torch.max(p == yvar) != 1 and early_stop:
            break  # attack succeeded on every input; stop iterating early
        predict.zero_grad()
        loss.backward(retain_graph=True)
        g1 = torch.mean(delta.grad.data.abs().reshape(-1, 28 * 28)).float()  # hard-coded for 28 x 28 inputs
        delta.grad.data.zero_()
        w_loss.backward(retain_graph=True)
        g2 = torch.mean(delta.grad.data.abs().reshape(-1, 28 * 28)).float()
        g = g1 / g2
        g = torch.min(g, torch.tensor(1e6))
        if count % 5 == 0:  # may not converge otherwise
            beta = beta / 10
        delta.grad.data.zero_()
        # print('loss',loss)
        # print('w_loss', w_loss)
        # print(count)
        # print((p == yvar).sum())
        final_loss = loss + beta * g * w_loss
        final_loss.backward(retain_graph=True)

        if ord == np.inf:
            grad_sign = delta.grad.data.sign()
            delta.data = delta.data + batch_multiply(eps_iter, grad_sign)
            delta.data = batch_clamp(eps, delta.data)
            delta.data = clamp(xvar.data + delta.data, clip_min,
                               clip_max) - xvar.data
        elif ord == 1:
            grad = delta.grad.data
            grad = normalize_by_pnorm(grad, 1)
            grad = grad * 28 * 28
            delta.data = delta.data + batch_multiply(eps_iter, grad)
            delta.data = batch_clamp(eps, delta.data)
            delta.data = clamp(xvar.data + delta.data, clip_min,
                               clip_max) - xvar.data
        elif ord == 2:
            grad = delta.grad.data
            grad = normalize_by_pnorm(grad)
            delta.data = delta.data + batch_multiply(eps_iter, grad)
            delta.data = clamp(xvar.data + delta.data, clip_min,
                               clip_max) - xvar.data
            if eps is not None:
                delta.data = clamp_by_pnorm(delta.data, ord, eps)
        else:
            error = "Only ord = inf and ord = 2 have been implemented"
            raise NotImplementedError(error)

        delta.grad.data.zero_()

    x_adv = clamp(xvar + delta, clip_min, clip_max)
    iter_count = count
    return x_adv, iter_count
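
This last variant expects a loss_fn taking (predict, yvar, xvar, x_adv) and returning two terms, and it hard-codes 28 * 28 inputs. A hypothetical sketch of such a callable and a call; the loss composition below is illustrative only, not the original author's:

import torch.nn as nn
import torch.nn.functional as F

def two_term_loss(predict, yvar, xvar, x_adv):
    # placeholder: classification loss on the adversarial input plus a simple
    # auxiliary term; the real loss_fn is defined elsewhere in the repository
    loss = F.cross_entropy(predict(x_adv), yvar, reduction="sum")
    w_loss = F.mse_loss(x_adv, xvar, reduction="sum")
    return loss, w_loss

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x = torch.rand(8, 1, 28, 28)   # 28 x 28 (MNIST-sized) inputs are assumed
y = torch.randint(0, 10, (8,))

x_adv, n_steps = perturb_iterative(x, y, predict=model, nb_iter=40, eps=0.3,
                                   eps_iter=0.01, loss_fn=two_term_loss,
                                   ord=np.inf, beta=0.5, early_stop=True)
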