    def perturb(self, source, guide, delta=None):
        """
        Given source, returns their adversarial counterparts
        with representations close to that of the guide.

        :param source: input tensor which we want to perturb.
        :param guide: targeted input.
        :param delta: tensor contains the random initialization.
        :return: tensor containing perturbed inputs.
        """
        # Initialization
        if delta is None:
            delta = torch.zeros_like(source)
            if self.rand_init:
                delta = delta.uniform_(-self.eps, self.eps)
        else:
            delta = delta.detach()

        delta.requires_grad_()

        source = replicate_input(source)
        guide = replicate_input(guide)
        guide_ftr = self.predict(guide).detach()

        xadv = perturb_iterative(source, guide_ftr, self.predict,
                                 self.nb_iter, eps_iter=self.eps_iter,
                                 loss_fn=self.loss_fn, minimize=True,
                                 ord=np.inf, eps=self.eps,
                                 clip_min=self.clip_min,
                                 clip_max=self.clip_max,
                                 delta_init=delta)

        xadv = clamp(xadv, self.clip_min, self.clip_max)

        return xadv.data
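# The method above follows the guided feature-space attack pattern: find a
# perturbation of source, bounded in L-inf norm, whose internal representation
# is as close as possible to that of guide. A minimal self-contained sketch of
# the same loop in plain PyTorch (feature_net and the hyperparameter names are
# illustrative, not taken from the class above):

import torch
import torch.nn as nn

def feature_attack(feature_net, source, guide, eps=0.3, step=0.01, n_iter=40):
    """Minimize ||f(source + delta) - f(guide)||^2 with ||delta||_inf <= eps."""
    with torch.no_grad():
        guide_ftr = feature_net(guide)          # fixed target representation
    delta = torch.empty_like(source).uniform_(-eps, eps).requires_grad_()
    for _ in range(n_iter):
        loss = nn.functional.mse_loss(feature_net(source + delta), guide_ftr)
        grad, = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta -= step * grad.sign()         # signed step, minimizing the loss
            delta.clamp_(-eps, eps)             # project back onto the eps-ball
    return (source + delta).clamp(0.0, 1.0).detach()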
Example 2
    def _verify_and_process_inputs(self, x, y):
        if self.targeted:
            assert y is not None
        elif y is None:
            y = self._get_predicted_label(x)

        x = replicate_input(x)
        y = replicate_input(y)
        return x, y
Example 3
    def perturb_single(self, x, y):
        # x shape: [C, H, W]
        if self.comply_with_foolbox:
            np.random.seed(233333)
            rand_np = np.random.permutation(x.shape[1] * x.shape[2])
            pixels = torch.from_numpy(rand_np)
        else:
            pixels = torch.randperm(x.shape[1] * x.shape[2])
        pixels = pixels.to(x.device)
        pixels = pixels[:self.max_pixels]
        for ii in range(self.max_pixels):
            row = pixels[ii] % x.shape[2]
            col = pixels[ii] // x.shape[2]
            for val in [self.clip_min, self.clip_max]:
                adv = replicate_input(x)
                for mm in range(x.shape[0]):
                    adv[mm, row, col] = val
                out_label = self._get_predicted_label(adv.unsqueeze(0))
                if self.targeted:
                    if int(out_label[0]) == int(y):
                        return adv
                else:
                    if int(out_label[0]) != int(y):
                        return adv
        return x
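# perturb_single operates on one unbatched image of shape [C, H, W]. A minimal
# batch driver, assuming `attack` is an instance of the class this method
# belongs to:

import torch

def perturb_batch(attack, xs, ys):
    """Run a per-image attack such as perturb_single over a whole batch."""
    return torch.stack([attack.perturb_single(x, y) for x, y in zip(xs, ys)])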
Example 4
    def perturb(self, x, y=None):
        x, y = self._verify_and_process_inputs(x, y)

        # Initialization
        if y is None:
            y = self._get_predicted_label(x)
        x = replicate_input(x)
        batch_size = len(x)
        coeff_lower_bound = x.new_zeros(batch_size)
        coeff_upper_bound = x.new_ones(batch_size) * CARLINI_COEFF_UPPER
        loss_coeffs = torch.ones_like(y).float() * self.initial_const
        final_l2distsqs = [CARLINI_L2DIST_UPPER] * batch_size
        final_labels = [INVALID_LABEL] * batch_size
        final_advs = x
        x_atanh = self._get_arctanh_x(x)
        y_onehot = to_one_hot(y, self.num_classes).float()

        final_l2distsqs = torch.FloatTensor(final_l2distsqs).to(x.device)
        final_labels = torch.LongTensor(final_labels).to(x.device)

        # Start binary search
        for outer_step in range(self.binary_search_steps):
            delta = nn.Parameter(torch.zeros_like(x))
            optimizer = optim.Adam([delta], lr=self.learning_rate)
            cur_l2distsqs = [CARLINI_L2DIST_UPPER] * batch_size
            cur_labels = [INVALID_LABEL] * batch_size
            cur_l2distsqs = torch.FloatTensor(cur_l2distsqs).to(x.device)
            cur_labels = torch.LongTensor(cur_labels).to(x.device)
            prevloss = PREV_LOSS_INIT

            # record current output
            cur_output = torch.zeros(x.size()[0],
                                     self.num_classes).float().to(x.device)

            if (self.repeat and outer_step == (self.binary_search_steps - 1)):
                loss_coeffs = coeff_upper_bound
            for ii in range(self.max_iterations):
                loss, l2distsq, output, adv_img = \
                    self._forward_and_update_delta(
                        optimizer, x_atanh, delta, y_onehot, loss_coeffs)
                if self.abort_early:
                    if ii % (self.max_iterations // NUM_CHECKS or 1) == 0:
                        if loss > prevloss * ONE_MINUS_EPS:
                            break
                        prevloss = loss

                self._update_if_smaller_dist_succeed(adv_img, y, output,
                                                     l2distsq, batch_size,
                                                     cur_l2distsqs, cur_labels,
                                                     final_l2distsqs,
                                                     final_labels, final_advs,
                                                     cur_output)

            self._update_loss_coeffs(y, cur_labels, batch_size, loss_coeffs,
                                     coeff_upper_bound, coeff_lower_bound,
                                     cur_output)

        return final_advs
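# The outer loop is the usual C&W binary search over the trade-off constant:
# a successful attack lets the constant shrink, a failed one forces it to
# grow, and the constant is bisected once a finite upper bound exists. A
# compact sketch of that update rule (the general scheme, not the literal
# body of _update_loss_coeffs):

import torch

def update_coeffs(success, c, lower, upper, coeff_upper=1e10):
    """One binary-search step per sample; success is a bool tensor [batch]."""
    # Success: the current c is feasible, so tighten the upper bound.
    upper = torch.where(success, torch.min(upper, c), upper)
    # Failure: c was too small, so raise the lower bound.
    lower = torch.where(success, lower, torch.max(lower, c))
    # Bisect where a finite upper bound exists; otherwise grow c tenfold.
    c = torch.where(upper < coeff_upper, (lower + upper) / 2, c * 10)
    return c, lower, upper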
Example 5
    def _perturb_seed_pixel(self, x, p, row, col):
        x_pert = replicate_input(x)
        for ii in range(x.shape[0]):
            if x[ii, row, col] > 0:
                x_pert[ii, row, col] = p
            elif x[ii, row, col] < 0:
                x_pert[ii, row, col] = -1 * p
            else:
                x_pert[ii, row, col] = 0
        return x_pert
Example 6
    def perturb(self, x, y=None):
        x, y = self._verify_and_process_inputs(x, y)

        # Initialization
        if y is None:
            y = self._get_predicted_label(x)
        x = replicate_input(x)
        # batch_size = len(x)

        final_advs = x
        x_atanh = self._get_arctanh_x(x)
        y_onehot = to_one_hot(y, self.num_classes).float()

        delta = nn.Parameter(torch.zeros_like(x))
        optimizer = optim.Adam([delta], lr=self.learning_rate)

        for ii in range(self.max_iterations):
            optimizer.zero_grad()
            adv = tanh_rescale(delta + x_atanh, self.clip_min, self.clip_max)
            transimgs_rescale = tanh_rescale(x_atanh, self.clip_min, self.clip_max)
            output = self.predict(adv)
            l2distsq = calc_l2distsq(adv, transimgs_rescale)
            loss, l2dist, adv_loss = self._loss_fn(output, y_onehot, l2distsq, self.c)
            loss.backward()
            optimizer.step()

            if ii % 1000 == 1:
                print('step: {}, dist: {:.2f}, loss: {:.2f}'.format(
                    ii, l2dist.item(), adv_loss.item()))

            final_advs = adv.data
        return final_advs
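# Both C&W-style loops optimize in arctanh space, so any real-valued delta
# maps back into the valid pixel box. A minimal sketch of that change of
# variables, mirroring what _get_arctanh_x and tanh_rescale presumably do
# (the helper names below are mine):

import torch

def to_atanh(x, vmin=0.0, vmax=1.0, eps=1e-6):
    """Map x in [vmin, vmax] to an unconstrained variable via arctanh."""
    x01 = (x - vmin) / (vmax - vmin)                    # rescale to [0, 1]
    x11 = torch.clamp(x01 * 2 - 1, -1 + eps, 1 - eps)   # to (-1, 1), off poles
    return torch.atanh(x11)

def from_atanh(w, vmin=0.0, vmax=1.0):
    """Inverse map: any real-valued w lands back inside [vmin, vmax]."""
    return (torch.tanh(w) + 1) / 2 * (vmax - vmin) + vmin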
Example 7
    def _verify_and_process_inputs(self, x, y):
        x = replicate_input_withgrad(x)
        y = replicate_input(y)
        return x, y

    def perturb(self, x, y=None):
        """
        :param x:    clean images
        :param y:    clean labels, if None we use the predicted labels
        """

        self.device = x.device
        self.orig_dim = list(x.shape[1:])
        self.ndims = len(self.orig_dim)

        x = x.detach().clone().float().to(self.device)
        # assert next(self.predict.parameters()).device == x.device

        y_pred = self._get_predicted_label(x)
        if y is None:
            y = y_pred.detach().clone().long().to(self.device)
        else:
            y = y.detach().clone().long().to(self.device)
        pred = y_pred == y
        corr_classified = pred.float().sum()
        if self.verbose:
            print('Clean accuracy: {:.2%}'.format(pred.float().mean()))
        if pred.sum() == 0:
            return x
        pred = self.check_shape(pred.nonzero().squeeze())

        startt = time.time()
        # runs the attack only on correctly classified points
        im2 = replicate_input(x[pred])
        la2 = replicate_input(y[pred])
        if len(im2.shape) == self.ndims:
            im2 = im2.unsqueeze(0)
        bs = im2.shape[0]
        u1 = torch.arange(bs)
        adv = im2.clone()
        adv_c = x.clone()
        res2 = 1e10 * torch.ones([bs]).to(self.device)
        res_c = torch.zeros([x.shape[0]]).to(self.device)
        x1 = im2.clone()
        x0 = im2.clone().reshape([bs, -1])
        counter_restarts = 0

        while counter_restarts < self.n_restarts:
            if counter_restarts > 0:
                if self.norm == 'Linf':
                    t = 2 * torch.rand(x1.shape).to(self.device) - 1
                    x1 = im2 + (
                        torch.min(
                            res2,
                            self.eps * torch.ones(res2.shape).to(self.device)
                        ).reshape([-1, *([1] * self.ndims)])
                    ) * t / (t.reshape([t.shape[0], -1]).abs()
                             .max(dim=1, keepdim=True)[0]
                             .reshape([-1, *([1] * self.ndims)])) * .5
                elif self.norm == 'L2':
                    t = torch.randn(x1.shape).to(self.device)
                    x1 = im2 + (
                        torch.min(
                            res2,
                            self.eps * torch.ones(res2.shape).to(self.device)
                        ).reshape([-1, *([1] * self.ndims)])
                    ) * t / ((t ** 2)
                             .view(t.shape[0], -1)
                             .sum(dim=-1)
                             .sqrt()
                             .view(t.shape[0], *([1] * self.ndims))) * .5
                elif self.norm == 'L1':
                    t = torch.randn(x1.shape).to(self.device)
                    x1 = im2 + (torch.min(
                        res2,
                        self.eps * torch.ones(res2.shape).to(self.device)
                    ).reshape([-1, *([1] * self.ndims)])
                    ) * t / (t.abs().view(t.shape[0], -1)
                             .sum(dim=-1)
                             .view(t.shape[0], *([1] * self.ndims))) / 2

                x1 = x1.clamp(0.0, 1.0)

            counter_iter = 0
            while counter_iter < self.n_iter:
                with torch.no_grad():
                    df, dg = self.get_diff_logits_grads_batch(x1, la2)
                    if self.norm == 'Linf':
                        dist1 = df.abs() / (1e-12 +
                                            dg.abs()
                                            .view(dg.shape[0], dg.shape[1], -1)
                                            .sum(dim=-1))
                    elif self.norm == 'L2':
                        dist1 = df.abs() / (1e-12 + (dg ** 2)
                                            .view(dg.shape[0], dg.shape[1], -1)
                                            .sum(dim=-1).sqrt())
                    elif self.norm == 'L1':
                        dist1 = df.abs() / (1e-12 + dg.abs().reshape(
                            [df.shape[0], df.shape[1], -1]).max(dim=2)[0])
                    else:
                        raise ValueError('norm not supported')
                    ind = dist1.min(dim=1)[1]
                    dg2 = dg[u1, ind]
                    b = (- df[u1, ind] +
                         (dg2 * x1).view(x1.shape[0], -1).sum(dim=-1))
                    w = dg2.reshape([bs, -1])

                    if self.norm == 'Linf':
                        d3 = self.projection_linf(
                            torch.cat((x1.reshape([bs, -1]), x0), 0),
                            torch.cat((w, w), 0),
                            torch.cat((b, b), 0))
                    elif self.norm == 'L2':
                        d3 = self.projection_l2(
                            torch.cat((x1.reshape([bs, -1]), x0), 0),
                            torch.cat((w, w), 0),
                            torch.cat((b, b), 0))
                    elif self.norm == 'L1':
                        d3 = self.projection_l1(
                            torch.cat((x1.reshape([bs, -1]), x0), 0),
                            torch.cat((w, w), 0),
                            torch.cat((b, b), 0))
                    d1 = torch.reshape(d3[:bs], x1.shape)
                    d2 = torch.reshape(d3[-bs:], x1.shape)
                    if self.norm == 'Linf':
                        a0 = d3.abs().max(dim=1, keepdim=True)[0]\
                            .view(-1, *([1] * self.ndims))
                    elif self.norm == 'L2':
                        a0 = (d3 ** 2).sum(dim=1, keepdim=True).sqrt()\
                            .view(-1, *([1] * self.ndims))
                    elif self.norm == 'L1':
                        a0 = d3.abs().sum(dim=1, keepdim=True)\
                            .view(-1, *([1] * self.ndims))
                    a0 = torch.max(a0, 1e-8 * torch.ones(
                        a0.shape).to(self.device))
                    a1 = a0[:bs]
                    a2 = a0[-bs:]
                    alpha = torch.min(torch.max(a1 / (a1 + a2),
                                                torch.zeros(a1.shape)
                                                .to(self.device)),
                                      self.alpha_max * torch.ones(a1.shape)
                                      .to(self.device))
                    x1 = ((x1 + self.eta * d1) * (1 - alpha) +
                          (im2 + d2 * self.eta) * alpha).clamp(0.0, 1.0)

                    is_adv = self._get_predicted_label(x1) != la2

                    if is_adv.sum() > 0:
                        ind_adv = is_adv.nonzero().squeeze()
                        ind_adv = self.check_shape(ind_adv)
                        if self.norm == 'Linf':
                            t = (x1[ind_adv] - im2[ind_adv]).reshape(
                                [ind_adv.shape[0], -1]).abs().max(dim=1)[0]
                        elif self.norm == 'L2':
                            t = ((x1[ind_adv] - im2[ind_adv]) ** 2)\
                                .view(ind_adv.shape[0], -1).sum(dim=-1).sqrt()
                        elif self.norm == 'L1':
                            t = (x1[ind_adv] - im2[ind_adv])\
                                .abs().view(ind_adv.shape[0], -1).sum(dim=-1)
                        adv[ind_adv] = x1[ind_adv] * (t < res2[ind_adv]).\
                            float().reshape([-1, *([1] * self.ndims)]) \
                            + adv[ind_adv]\
                            * (t >= res2[ind_adv]).float().reshape(
                            [-1, *([1] * self.ndims)])
                        res2[ind_adv] = t * (t < res2[ind_adv]).float()\
                            + res2[ind_adv] * (t >= res2[ind_adv]).float()
                        x1[ind_adv] = im2[ind_adv] + (
                            x1[ind_adv] - im2[ind_adv]) * self.beta

                    counter_iter += 1

            counter_restarts += 1

        ind_succ = res2 < 1e10
        if self.verbose:
            print('success rate: {:.0f}/{:.0f}'
                  .format(ind_succ.float().sum(), corr_classified) +
                  ' (on correctly classified points) in {:.1f} s'
                  .format(time.time() - startt))

        res_c[pred] = res2 * ind_succ.float() + 1e10 * (1 - ind_succ.float())
        ind_succ = self.check_shape(ind_succ.nonzero().squeeze())
        adv_c[pred[ind_succ]] = adv[ind_succ].clone()

        return adv_c
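# Each inner iteration above linearizes the decision boundary into a
# hyperplane <w, x> = b, projects both the current iterate x1 and the original
# point onto it, and blends the two moves with a weight capped at alpha_max.
# A self-contained sketch of the L2 case (the real projection_l2 additionally
# respects the [0, 1] box, which this sketch ignores):

import torch

def fab_l2_step(x1, x0, w, b, eta=1.05, alpha_max=0.1):
    """One FAB-style step on flattened inputs: x1, x0, w are [bs, n], b is [bs]."""
    def to_plane(x):
        # Displacement reaching the hyperplane <w, x> = b along its normal.
        resid = b.unsqueeze(-1) - (w * x).sum(-1, keepdim=True)
        return resid * w / (w * w).sum(-1, keepdim=True).clamp_min(1e-12)
    d1, d2 = to_plane(x1), to_plane(x0)
    a1 = d1.norm(dim=-1, keepdim=True).clamp_min(1e-8)
    a2 = d2.norm(dim=-1, keepdim=True).clamp_min(1e-8)
    alpha = (a1 / (a1 + a2)).clamp(0.0, alpha_max)      # favor the cheaper move
    return ((x1 + eta * d1) * (1 - alpha) + (x0 + eta * d2) * alpha).clamp(0.0, 1.0)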
Example 9
    def perturb_single(self, x, y):
        # x shape: [C, H, W]
        rescaled_x = replicate_input(x)
        best_img = None
        best_dist = np.inf
        rescaled_x, lb, ub = self._rescale_to_m0d5_to_0d5(
            rescaled_x, vmin=self.clip_min, vmax=self.clip_max)

        if self.comply_with_foolbox:
            np.random.seed(233333)
            init_rand = np.random.permutation(x.shape[1] * x.shape[2])
        else:
            init_rand = None

        # Algorithm 3 in v1

        pxy = self._random_sample_seeds(
            x.shape[1], x.shape[2], seed_ratio=self.seed_ratio,
            max_nb_seeds=self.max_nb_seeds, init_rand=init_rand)
        pxy = pxy.to(x.device)
        ii = 0
        if self.comply_with_foolbox:
            adv = rescaled_x
        while ii < self.round_ub:
            if not self.comply_with_foolbox:
                adv = replicate_input(rescaled_x)
            # Computing the function g using the neighbourhood
            if self.comply_with_foolbox:
                rand_np = np.random.permutation(len(pxy))[:self.max_nb_seeds]
                pxy = pxy[torch.from_numpy(rand_np)]
            else:
                pxy = pxy[torch.randperm(len(pxy))[:self.max_nb_seeds]]

            pert_lst = [
                self._perturb_seed_pixel(
                    adv, self.p, int(row), int(col)) for row, col in pxy]
            # Compute the score for each pert in the list
            scores, curr_best_img, curr_best_dist = self._rescale_x_score(
                self.predict, pert_lst, y, x, best_dist)
            if curr_best_img is not None:
                best_img = curr_best_img
                best_dist = curr_best_dist
            _, indices = torch.sort(scores)
            indices = indices[:self.t]
            pxy_star = pxy[indices.data.cpu()]
            # Generation of the perturbed image adv
            for row, col in pxy_star:
                for b in range(x.shape[0]):
                    adv[b, int(row), int(col)] = self._cyclic(
                        self.r, lb, ub, adv[b, int(row), int(col)])
            # Check whether the perturbed image is an adversarial image
            revert_adv = self._revert_rescale(adv)
            curr_lb = self._get_predicted_label(revert_adv.unsqueeze(0))
            curr_dist = torch.sum((x - revert_adv) ** 2)
            if (is_successful(int(curr_lb), y, self.targeted) and
                    curr_dist < best_dist):
                best_img = revert_adv
                best_dist = curr_dist
                return best_img
            elif is_successful(curr_lb, y, self.targeted):
                return best_img
            pxy = [
                (row, col)
                for rowcenter, colcenter in pxy_star
                for row in range(
                    int(rowcenter) - self.d, int(rowcenter) + self.d + 1)
                for col in range(
                    int(colcenter) - self.d, int(colcenter) + self.d + 1)]
            pxy = list(set((row, col) for row, col in pxy if (
                0 <= row < x.shape[1] and 0 <= col < x.shape[2])))
            pxy = torch.FloatTensor(pxy)
            ii += 1
        if best_img is None:
            return x
        return best_img
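# The inner update rescales each selected pixel by r and wraps the result
# back into [lb, ub], which is presumably what self._cyclic implements. A
# minimal version under that assumption:

def cyclic(r, lb, ub, p):
    """Scale pixel value p by r and wrap the result back into [lb, ub]."""
    result = r * p
    if result < lb:
        result += ub - lb
    elif result > ub:
        result -= ub - lb
    return result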
Example 10
    def perturb(self, x, y=None):

        x, y = self._verify_and_process_inputs(x, y)

        # Initialization
        if y is None:
            y = self._get_predicted_label(x)

        x = replicate_input(x)
        batch_size = len(x)
        coeff_lower_bound = x.new_zeros(batch_size)
        coeff_upper_bound = x.new_ones(batch_size) * COEFF_UPPER
        loss_coeffs = torch.ones_like(y).float() * self.initial_const

        final_dist = [DIST_UPPER] * batch_size
        final_labels = [INVALID_LABEL] * batch_size

        final_advs = x.clone()
        y_onehot = to_one_hot(y, self.num_classes).float()

        final_dist = torch.FloatTensor(final_dist).to(x.device)
        final_labels = torch.LongTensor(final_labels).to(x.device)

        # Start binary search
        for outer_step in range(self.binary_search_steps):

            self.global_step = 0

            # slack vector from the paper
            yy_k = nn.Parameter(x.clone())
            xx_k = x.clone()

            cur_dist = [DIST_UPPER] * batch_size
            cur_labels = [INVALID_LABEL] * batch_size

            cur_dist = torch.FloatTensor(cur_dist).to(x.device)
            cur_labels = torch.LongTensor(cur_labels).to(x.device)

            prevloss = PREV_LOSS_INIT

            if (self.repeat and outer_step == (self.binary_search_steps - 1)):
                loss_coeffs = coeff_upper_bound

            lr = self.learning_rate

            for ii in range(self.max_iterations):

                # reset gradient
                if yy_k.grad is not None:
                    yy_k.grad.detach_()
                    yy_k.grad.zero_()

                # loss over yy_k with only L2 same as C&W
                # we don't update L1 loss with SGD because we use ISTA
                output = self.predict(yy_k)
                l2distsq = calc_l2distsq(yy_k, x)
                loss_opt = self._loss_fn(output,
                                         y_onehot,
                                         None,
                                         l2distsq,
                                         loss_coeffs,
                                         opt=True)
                loss_opt.backward()

                # gradient step
                yy_k.data.add_(yy_k.grad.data, alpha=-lr)
                self.global_step += 1

                # polynomial decay of learning rate
                lr = self.init_learning_rate * \
                    (1 - self.global_step / self.max_iterations)**0.5

                yy_k, xx_k = self._fast_iterative_shrinkage_thresholding(
                    x, yy_k, xx_k)

                # loss ElasticNet or L1 over xx_k
                output = self.predict(xx_k)
                l2distsq = calc_l2distsq(xx_k, x)
                l1dist = calc_l1dist(xx_k, x)

                if self.decision_rule == 'EN':
                    dist = l2distsq + (l1dist * self.beta)
                elif self.decision_rule == 'L1':
                    dist = l1dist
                else:
                    raise ValueError(
                        'decision_rule must be one of: EN, L1')
                loss = self._loss_fn(output, y_onehot, l1dist, l2distsq,
                                     loss_coeffs)

                if self.abort_early:
                    if ii % (self.max_iterations // NUM_CHECKS or 1) == 0:
                        if loss > prevloss * ONE_MINUS_EPS:
                            break
                        prevloss = loss

                self._update_if_smaller_dist_succeed(xx_k.data, y, output,
                                                     dist, batch_size,
                                                     cur_dist, cur_labels,
                                                     final_dist, final_labels,
                                                     final_advs)

            self._update_loss_coeffs(y, cur_labels, batch_size, loss_coeffs,
                                     coeff_upper_bound, coeff_lower_bound)

        return final_advs
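# Between gradient steps, _fast_iterative_shrinkage_thresholding applies the
# ISTA proximal step for the L1 term: components of yy_k within beta of x are
# snapped to x, the rest are shifted toward it. (The real method also adds the
# FISTA momentum combination and box clipping.) A sketch of the core:

import torch

def ista_shrink(x, yy, beta):
    """Proximal step for beta * ||z - x||_1: soft-threshold yy around x."""
    diff = yy - x
    return x + torch.sign(diff) * torch.clamp(diff.abs() - beta, min=0)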
Example 11
    def perturb(self, x, y=None):
        x, y = self._verify_and_process_inputs(x, y)

        # Initialization
        if y is None:
            y = self._get_predicted_label(x)

        x = replicate_input(x)
        batch_size = len(x)
        final_adversarials = x.clone()

        # An array of booleans that stores which samples have not converged
        # yet
        active = torch.ones((batch_size, ), dtype=torch.bool, device=x.device)

        initial_const = self.initial_const
        taus = torch.ones((batch_size, ), device=x.device) * self.initial_tau

        # The previous adversarials. This is used to perform a "warm start"
        # during optimisation
        prev_adversarials = x.clone()

        max_tau = self.initial_tau

        i = 0

        while max_tau > self.min_tau:
            new_adversarials = self._run_attack(x, y, initial_const, taus,
                                                prev_adversarials.clone(),
                                                active).detach()

            # Store the adversarials for the next iteration,
            # even if they failed
            prev_adversarials = new_adversarials

            adversarial_outputs = self._outputs(new_adversarials,
                                                filter_=active)
            successful = self._successful(adversarial_outputs, y).detach()

            # If the Linf distance is lower than tau and the adversarial
            # is successful, use it as the new tau
            linf_distances = torch.max(torch.abs(new_adversarials -
                                                 x).flatten(1),
                                       dim=1)[0]
            linf_lower = linf_distances < taus

            taus = utils.fast_boolean_choice(taus, linf_distances,
                                             linf_lower & successful)

            # Save the remaining adversarials
            replace = successful

            if not self.update_inactive:
                replace = replace & active

            final_adversarials = utils.fast_boolean_choice(
                final_adversarials, new_adversarials, replace)

            taus *= self.tau_factor
            max_tau = taus.max()

            if self.reduce_const:
                initial_const /= 2

            # Drop samples that failed or whose tau has reached the minimum
            low_tau = taus <= self.min_tau
            drop = low_tau | (~successful)
            active = active & (~drop)

            if self.tau_check != 0 and (i + 1) % self.tau_check == 0:
                # Causes an implicit sync point
                if not active.any():
                    break

            i += 1

        return final_adversarials
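# utils.fast_boolean_choice is used as a per-sample select between the old
# and the new value; a minimal equivalent under that assumed semantics:

import torch

def fast_boolean_choice(old, new, mask):
    """Return new[i] where mask[i] is True, else old[i]; mask is 1-D [batch]."""
    mask = mask.reshape(-1, *([1] * (old.dim() - 1)))   # broadcast trailing dims
    return torch.where(mask, new, old)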
Example 12
    def perturb(self, x, y=None):
        x, y = self._verify_and_process_inputs(x, y)

        # Initialization
        if y is None:
            y = self._get_predicted_label(x)

        x = replicate_input(x)
        batch_size = len(x)
        final_adversarials = x.clone()

        # An array of booleans that stores which samples have not converged
        # yet
        active = torch.ones((batch_size, ), dtype=torch.bool, device=x.device)

        initial_const = self.initial_const
        taus = torch.ones((batch_size, ), device=x.device) * self.initial_tau

        # The previous adversarials. This is used to perform a "warm start"
        # during optimisation
        prev_adversarials = x.clone()

        while torch.any(active):
            new_adversarials = self._run_attack(
                x[active],
                y[active],
                initial_const,
                taus[active],
                prev_adversarials[active].clone(),
                outer_active_mask=active).detach()

            # Store the adversarials for the next iteration,
            # even if they failed
            prev_adversarials[active] = new_adversarials

            adversarial_outputs = self._outputs(new_adversarials,
                                                active_mask=active)
            successful = self._successful(adversarial_outputs,
                                          y[active]).detach()

            # If the Linf distance is lower than tau and the adversarial
            # is successful, use it as the new tau
            linf_distances = torch.max(torch.abs(new_adversarials -
                                                 x[active]).flatten(1),
                                       dim=1)[0]
            assert linf_distances.shape == (len(new_adversarials), )

            linf_lower = linf_distances < taus[active]

            utils.replace_active(linf_distances, taus, active,
                                 linf_lower & successful)

            # Save the remaining adversarials
            utils.replace_active(new_adversarials, final_adversarials, active,
                                 successful)

            taus *= self.tau_factor

            if self.reduce_const:
                initial_const /= 2

            # Drop samples that failed or whose tau has reached the minimum
            low_tau = taus[active] <= self.min_tau
            drop = low_tau | (~successful)
            active[active] = ~drop

        return final_adversarials
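# utils.replace_active appears to write values computed on the active subset
# back into the full batch, guarded by a per-active-sample mask; a minimal
# equivalent under that assumption:

import torch

def replace_active(source, destination, active_mask, replace_mask):
    """In-place: destination[active][i] <- source[i] wherever replace_mask[i]."""
    active_indices = torch.nonzero(active_mask, as_tuple=False).squeeze(1)
    destination[active_indices[replace_mask]] = source[replace_mask]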