Example 1
    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Criterion, T],
        *,
        early_stop: Optional[float] = None,
        starting_points: Optional[ep.Tensor] = None,
        **kwargs: Any,
    ) -> T:
        originals, restore_type = ep.astensor_(inputs)

        self._nqueries = {i: 0 for i in range(len(originals))}
        self._set_cos_sin_function(originals)
        self.theta_max = ep.ones(originals, len(originals)) * self._theta_max
        criterion = get_criterion(criterion)
        self._criterion_is_adversarial = get_is_adversarial(criterion, model)

        # Get Starting Point
        if starting_points is not None:
            best_advs = starting_points
        else:
            init_attack: MinimizationAttack = LinearSearchBlendedUniformNoiseAttack(steps=50)
            best_advs = init_attack.run(model, originals, criterion, early_stop=early_stop)

        assert self._is_adversarial(best_advs).all()

        # Initialize the direction orthogonalized with the first direction
        fd = best_advs - originals
        norm = ep.norms.l2(fd.flatten(1), axis=1)
        fd = fd / atleast_kd(norm, fd.ndim)
        self._directions_ortho = {i: v.expand_dims(0) for i, v in enumerate(fd)}

        # Load Basis
        if "basis_params" in kwargs:
            self._basis = Basis(originals, **kwargs["basis_params"])
        else:
            self._basis = Basis(originals)

        for _ in range(self._steps):
            # Get candidates. Shape: (n_candidates, batch_size, image_size)
            candidates = self._get_candidates(originals, best_advs)
            candidates = candidates.transpose((1, 0, 2, 3, 4))

            
            best_candidates = ep.zeros_like(best_advs).raw
            for i, o in enumerate(originals):
                o_repeated = ep.concatenate([o.expand_dims(0)] * len(candidates[i]), axis=0)
                index = ep.argmax(self.distance(o_repeated, candidates[i])).raw
                best_candidates[i] = candidates[i][index].raw

            is_success = self.distance(best_candidates, originals) < self.distance(best_advs, originals)
            best_advs = ep.where(atleast_kd(is_success, best_candidates.ndim), ep.astensor(best_candidates), best_advs)

            if all(v > self._max_queries for v in self._nqueries.values()):
                print("Max queries attained for all the images.")
                break
        return restore_type(best_advs)
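
Example 1 relies on a `distance` helper and an `_is_adversarial` wrapper that are not shown. A minimal sketch of what they are assumed to look like for an l2-based, query-counting attack (names and placement are illustrative, not the original implementation):

import eagerpy as ep

def distance(a: ep.Tensor, b: ep.Tensor) -> ep.Tensor:
    # batched l2 distance between two image batches
    return ep.norms.l2((a - b).flatten(1), axis=1)

def _is_adversarial(self, perturbed: ep.Tensor) -> ep.Tensor:
    # count one query per image, then delegate to the criterion
    for i in range(len(perturbed)):
        self._nqueries[i] += 1
    return self._criterion_is_adversarial(perturbed)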
Example 2
    def _get_best_theta(
            self, 
            function_evolution: Callable[[ep.Tensor], ep.Tensor], 
            best_params: ep.Tensor) -> ep.Tensor:
        v_type = function_evolution(best_params)
        coefficients = ep.zeros(v_type, 2 * self.T).raw
        for i in range(self.T):
            coefficients[2 * i] = 1 - (i / self.T)
            coefficients[2 * i + 1] = -coefficients[2 * i]

        for i, coeff in enumerate(coefficients):
            params = coeff * self.theta_max
            x_evol = function_evolution(params)
            x = ep.where(
                atleast_kd(best_params == 0, v_type.ndim), 
                x_evol, 
                ep.zeros_like(v_type))

            is_advs = self._is_adversarial(x)

            best_params = ep.where(
                (best_params == 0) * is_advs,
                params,
                best_params
            )
            if (best_params != 0).all():
                break
        
        return best_params  
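
The schedule built above alternates signs while shrinking in magnitude: entry 2i is 1 - i/T and entry 2i + 1 is its negation, so the search tries large rotations first, on both sides of the current direction. A standalone check (T = 4 is illustrative):

T = 4
coefficients = []
for i in range(T):
    coefficients.append(1 - i / T)     # shrinking positive candidate
    coefficients.append(-(1 - i / T))  # mirrored negative candidate
print(coefficients)  # [1.0, -1.0, 0.75, -0.75, 0.5, -0.5, 0.25, -0.25]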
Example 3
    def _binary_search(self, originals: ep.Tensor, perturbed: ep.Tensor, boost: bool = False) -> ep.Tensor:
        # Choose upper thresholds in binary search based on constraint.
        highs = ep.ones(perturbed, len(perturbed))
        d = np.prod(perturbed.shape[1:])
        thresholds = self._BS_gamma / (d * math.sqrt(d))
        lows = ep.zeros_like(highs)

        # Boost Binary search
        if boost:
            boost_vec = 0.1 * originals + 0.9 * perturbed
            is_advs = self._is_adversarial(boost_vec)
            is_advs = atleast_kd(is_advs, originals.ndim)
            originals = ep.where(is_advs.logical_not(), boost_vec, originals)
            perturbed = ep.where(is_advs, boost_vec, perturbed)

        # use this variable to check when mids stays constant and the BS has converged
        old_mids = highs
        iteration = 0
        while ep.any(highs - lows > thresholds) and iteration < self._BS_max_iteration:
            iteration += 1
            mids = (lows + highs) / 2
            mids_perturbed = self._project(originals, perturbed, mids)
            is_adversarial_ = self._is_adversarial(mids_perturbed)

            highs = ep.where(is_adversarial_, mids, highs)
            lows = ep.where(is_adversarial_, lows, mids)

            # check if there is no more progress due to numerical imprecision
            reached_numerical_precision = (old_mids == mids).all()
            old_mids = mids
            if reached_numerical_precision:
                break
        
        results = self._project(originals, perturbed, highs)
        return results
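
The same bracketing scheme on a scalar toy problem, to make the loop above concrete: `highs` converges to the smallest blend factor that is still adversarial. The oracle below is a stand-in for `self._is_adversarial`:

def toy_binary_search(is_adv, low=0.0, high=1.0, threshold=1e-3):
    # shrink [low, high] until high is (up to threshold) the smallest adversarial blend
    while high - low > threshold:
        mid = (low + high) / 2
        if is_adv(mid):
            high = mid
        else:
            low = mid
    return high

print(toy_binary_search(lambda eps: eps >= 0.3))  # ~0.3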
Example 4
 def _add_step_in_circular_direction(degree: ep.Tensor) -> ep.Tensor:
     degree = atleast_kd(degree, direction1.ndim).raw * np.pi / 180
     results = self._cos(degree) * direction1 + self._sin(degree) * direction2
     results = (originals + ep.astensor(results * distances * self._cos(degree))).clip(0, 1)
     if self.with_quantification:
         results = self._quantify(results)
     return results
Example 5
 def _gram_schmidt(self, v: ep.Tensor, ortho_with: ep.Tensor) -> ep.Tensor:
     v_repeated = ep.concatenate([v.expand_dims(0)] * len(ortho_with), axis=0)
     
     # inner product
     gs_coeff = (ortho_with * v_repeated).flatten(1).sum(1)
     proj = atleast_kd(gs_coeff, ortho_with.ndim) * ortho_with
     v = v - proj.sum(0)
     return v / ep.norms.l2(v)
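
A numpy sketch of the same projection step. Note that subtracting all projections in one pass, as `_gram_schmidt` does, yields an exactly orthogonal result only when the rows of `ortho_with` are themselves mutually orthonormal, which the attack maintains by construction:

import numpy as np

rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.normal(size=(8, 3)))
ortho_with = q.T                               # 3 mutually orthonormal directions
v = rng.normal(size=8)

coeffs = ortho_with @ v                        # inner products, shape (3,)
v = v - (coeffs[:, None] * ortho_with).sum(0)  # subtract all projections at once
v = v / np.linalg.norm(v)

print(np.abs(ortho_with @ v).max())            # ~0 up to float precision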
Example 6
    def _get_candidates(self, originals: ep.Tensor, best_advs: ep.Tensor) -> ep.Tensor:
        """
        Find the lowest epsilon to misclassified x following the direction: q of class 1 / q + eps*direction of class 0
        """
        epsilons = ep.zeros(originals, len(originals))
        direction_2 = ep.zeros_like(originals)
        while (epsilons == 0).any():
            # if epsilon == 0, we are still searching for a good direction
            direction_2 = ep.where(
                atleast_kd(epsilons == 0, direction_2.ndim),
                self._basis.get_vector(self._directions_ortho),
                direction_2
            )
            
            for i, eps_i in enumerate(epsilons):
                if eps_i == 0:
                    self._directions_ortho[i] = ep.concatenate((self._directions_ortho[i], direction_2[i].expand_dims(0)), axis=0)
                    if len(self._directions_ortho[i]) > self.n_ortho + 1:
                        self._directions_ortho[i] = ep.concatenate((self._directions_ortho[i][:1], self._directions_ortho[i][self.n_ortho:]))
                        
            function_evolution = self._get_evolution_function(originals, best_advs, direction_2)
            new_epsilons = self._get_best_theta(function_evolution, epsilons)

            self.theta_max = ep.where(new_epsilons == 0, self.theta_max * self.rho, self.theta_max)
            self.theta_max = ep.where((new_epsilons != 0) * (epsilons == 0), self.theta_max / self.rho, self.theta_max)
            epsilons = new_epsilons

        function_evolution = self._get_evolution_function(originals, best_advs, direction_2)
        if self.with_alpha_line_search:
            epsilons = self._binary_search_on_alpha(function_evolution, epsilons)

        epsilons = epsilons.expand_dims(0)
        if self.with_interpolation:
            epsilons = ep.concatenate((epsilons, epsilons / 2), axis=0)

        candidates = ep.concatenate([function_evolution(eps).expand_dims(0) for eps in epsilons], axis=0)

        if self.with_interpolation:
            d = self.distance(best_advs, originals)
            delta = self.distance(self._binary_search(originals, candidates[1], boost=True), originals)
            theta_star = epsilons[0]

            num = theta_star * (4 * delta - d * (self._cos(theta_star.raw) + 3))
            den = 4 * (2 * delta - d * (self._cos(theta_star.raw) + 1))

            theta_hat = num / den
            q_interp = function_evolution(theta_hat)
            if self.with_distance_line_search:
                q_interp = self._binary_search(originals, q_interp, boost=True)
            candidates = ep.concatenate((candidates, q_interp.expand_dims(0)), axis=0)

        return candidates
Example 7
    def _get_evolution_function(self, originals: ep.Tensor, best_advs: ep.Tensor, direction2: ep.Tensor) -> Callable[[ep.Tensor], ep.Tensor]:
        distances = self.distance(best_advs, originals)
        direction1 = (best_advs - originals).flatten(start=1) / distances.reshape((-1, 1))
        direction1 = direction1.reshape(originals.shape)
        distances = atleast_kd(distances, direction1.ndim)

        def _add_step_in_circular_direction(degree: ep.Tensor) -> ep.Tensor:
            degree = atleast_kd(degree, direction1.ndim).raw * np.pi / 180
            results = self._cos(degree) * direction1 + self._sin(degree) * direction2
            results = (originals + ep.astensor(results * distances * self._cos(degree))).clip(0, 1)
            if self.with_quantification:
                results = self._quantify(results)
            return results

        return _add_step_in_circular_direction
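
A scalar sanity check of the returned closure, with plain numpy in place of `self._cos`/`self._sin`: at 0 degrees the step reproduces `best_advs` (cos 0 = 1, sin 0 = 0), and at angle theta the candidate lies at distance d * cos(theta) from the original, which is what makes the circular search a distance-reducing move:

import numpy as np

original = np.array([0.2, 0.5])
best_adv = np.array([0.8, 0.5])
distance = np.linalg.norm(best_adv - original)
direction1 = (best_adv - original) / distance  # unit direction towards the adversarial
direction2 = np.array([0.0, 1.0])              # unit direction orthogonal to direction1

def add_step(degree):
    rad = degree * np.pi / 180
    result = np.cos(rad) * direction1 + np.sin(rad) * direction2
    return np.clip(original + result * distance * np.cos(rad), 0, 1)

print(add_step(0.0))                              # [0.8 0.5] -> recovers best_adv
print(np.linalg.norm(add_step(45.0) - original))  # distance * cos(45 deg) ~ 0.424 < 0.6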
Example 8
    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Criterion, T],
        perlin_param,
        mask_param,
        *,
        early_stop: Optional[float] = None,
        starting_points: Optional[T] = None,
        **kwargs: Any,
    ) -> T:
        raise_if_kwargs(kwargs)
        originals, restore_type = ep.astensor_(inputs)
        initial_pict = ep.astensor(inputs)
        del inputs, kwargs

        criterion = get_criterion(criterion)
        is_adversarial = get_is_adversarial(criterion, model)

        if starting_points is None:
            init_attack: MinimizationAttack
            if self.init_attack is None:
                init_attack = LinearSearchBlendedUniformNoiseAttack(steps=50)
                logging.info(
                    f"Neither starting_points nor init_attack given. Falling"
                    f" back to {init_attack!r} for initialization.")
            else:
                init_attack = self.init_attack
            # TODO: use call and support all types of attacks (once early_stop is
            # possible in __call__)
            best_advs = init_attack.run(model,
                                        originals,
                                        criterion,
                                        early_stop=early_stop)
        else:
            best_advs = ep.astensor(starting_points)

        is_adv = is_adversarial(best_advs)
        if not is_adv.all():
            failed = is_adv.logical_not().float32().sum()
            if starting_points is None:
                raise ValueError(
                    f"init_attack failed for {failed} of {len(is_adv)} inputs")
            else:
                raise ValueError(
                    f"{failed} of {len(is_adv)} starting_points are not adversarial"
                )
        del starting_points

        tb = TensorBoard(logdir=self.tensorboard)

        N = len(originals)
        ndim = originals.ndim
        spherical_steps = ep.ones(originals, N) * self.spherical_step
        source_steps = ep.ones(originals, N) * self.source_step

        tb.scalar("batchsize", N, 0)

        # create two queues for each sample to track success rates
        # (used to update the hyper parameters)
        stats_spherical_adversarial = ArrayQueue(maxlen=100, N=N)
        stats_step_adversarial = ArrayQueue(maxlen=30, N=N)

        bounds = model.bounds

        for step in range(1, self.steps + 1):
            converged = source_steps < self.source_step_convergance
            if converged.all():
                break  # pragma: no cover
            converged = atleast_kd(converged, ndim)

            # TODO: performance: ignore those that have converged
            # (we could select the non-converged ones, but we currently
            # cannot easily invert this in the end using EagerPy)

            unnormalized_source_directions = originals - best_advs
            source_norms = ep.norms.l2(flatten(unnormalized_source_directions),
                                       axis=-1)
            source_directions = unnormalized_source_directions / atleast_kd(
                source_norms, ndim)

            # only check spherical candidates every k steps
            check_spherical_and_update_stats = step % self.update_stats_every_k == 0

            #-------------START----------------
            # MASK
            new_mask = ep.abs(originals - best_advs)
            new_mask /= ep.max(new_mask)
            new_mask = new_mask**mask_param
            mask = new_mask

            # Perlin Noise
            perlin_noise = ep.astensor(
                torch.tensor([
                    get_perlin(originals.numpy()[0].transpose((1, 2, 0)),
                               perlin_param)
                ]).to('cuda'))
            #-----------END-----------------

            candidates, spherical_candidates = draw_proposals(
                bounds,
                originals,
                best_advs,
                unnormalized_source_directions,
                source_directions,
                source_norms,
                spherical_steps,
                source_steps,
                mask,
                perlin_noise,
            )
            assert candidates.dtype == originals.dtype
            assert spherical_candidates.dtype == originals.dtype

            is_adv = is_adversarial(candidates)

            spherical_is_adv: Optional[ep.Tensor]
            if check_spherical_and_update_stats:
                spherical_is_adv = is_adversarial(spherical_candidates)
                stats_spherical_adversarial.append(spherical_is_adv)
                # TODO: algorithm: the original implementation ignores those samples
                # for which spherical is not adversarial and continues with the
                # next iteration -> we estimate different probabilities (conditional vs. unconditional)
                # TODO: thoughts: should we always track this because we compute it anyway
                stats_step_adversarial.append(is_adv)
            else:
                spherical_is_adv = None

            # in theory, we are closer per construction
            # but limited numerical precision might break this
            distances = ep.norms.l2(flatten(originals - candidates), axis=-1)
            closer = distances < source_norms
            is_best_adv = ep.logical_and(is_adv, closer)
            is_best_adv = atleast_kd(is_best_adv, ndim)

            cond = converged.logical_not().logical_and(is_best_adv)
            best_advs = ep.where(cond, candidates, best_advs)

            self.distances_iter[step - 1] = ep.norms.l2(
                flatten(initial_pict - best_advs)).numpy() / (3 * 32 * 32)

            tb.probability("converged", converged, step)
            tb.scalar("updated_stats", check_spherical_and_update_stats, step)
            tb.histogram("norms", source_norms, step)
            tb.probability("is_adv", is_adv, step)
            if spherical_is_adv is not None:
                tb.probability("spherical_is_adv", spherical_is_adv, step)
            tb.histogram("candidates/distances", distances, step)
            tb.probability("candidates/closer", closer, step)
            tb.probability("candidates/is_best_adv", is_best_adv, step)
            tb.probability("new_best_adv_including_converged", is_best_adv,
                           step)
            tb.probability("new_best_adv", cond, step)

            if check_spherical_and_update_stats:
                full = stats_spherical_adversarial.isfull()
                tb.probability("spherical_stats/full", full, step)
                if full.any():
                    probs = stats_spherical_adversarial.mean()
                    cond1 = ep.logical_and(probs > 0.5, full)
                    spherical_steps = ep.where(
                        cond1, spherical_steps * self.step_adaptation,
                        spherical_steps)
                    source_steps = ep.where(
                        cond1, source_steps * self.step_adaptation,
                        source_steps)
                    cond2 = ep.logical_and(probs < 0.2, full)
                    spherical_steps = ep.where(
                        cond2, spherical_steps / self.step_adaptation,
                        spherical_steps)
                    source_steps = ep.where(
                        cond2, source_steps / self.step_adaptation,
                        source_steps)
                    stats_spherical_adversarial.clear(
                        ep.logical_or(cond1, cond2))
                    tb.conditional_mean(
                        "spherical_stats/isfull/success_rate/mean", probs,
                        full, step)
                    tb.probability_ratio("spherical_stats/isfull/too_linear",
                                         cond1, full, step)
                    tb.probability_ratio(
                        "spherical_stats/isfull/too_nonlinear", cond2, full,
                        step)

                full = stats_step_adversarial.isfull()
                tb.probability("step_stats/full", full, step)
                if full.any():
                    probs = stats_step_adversarial.mean()
                    # TODO: algorithm: changed the two values because we are currently tracking p(source_step_sucess)
                    # instead of p(source_step_success | spherical_step_sucess) that was tracked before
                    cond1 = ep.logical_and(probs > 0.25, full)
                    source_steps = ep.where(
                        cond1, source_steps * self.step_adaptation,
                        source_steps)
                    cond2 = ep.logical_and(probs < 0.1, full)
                    source_steps = ep.where(
                        cond2, source_steps / self.step_adaptation,
                        source_steps)
                    stats_step_adversarial.clear(ep.logical_or(cond1, cond2))
                    tb.conditional_mean("step_stats/isfull/success_rate/mean",
                                        probs, full, step)
                    tb.probability_ratio(
                        "step_stats/isfull/success_rate_too_high", cond1, full,
                        step)
                    tb.probability_ratio(
                        "step_stats/isfull/success_rate_too_low", cond2, full,
                        step)

            tb.histogram("spherical_step", spherical_steps, step)
            tb.histogram("source_step", source_steps, step)
        tb.close()

        return restore_type(best_advs)
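
For reference, the mask built in the loop above concentrates sampling on pixels where the current adversarial already differs from the original, and `mask_param` controls how sharply. A small numeric sketch of that effect (values are illustrative):

import numpy as np

diff = np.array([0.02, 0.10, 0.40])   # per-pixel |originals - best_advs|
mask = diff / diff.max()              # [0.05, 0.25, 1.0]
print(mask ** 1)                      # mask_param = 1: mild weighting
print(mask ** 4)                      # mask_param = 4: almost all weight on the largest diffs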
Example 9
    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Criterion, T],
        *,
        early_stop: Optional[float] = None,
        starting_points: Optional[T] = None,
        **kwargs: Any,
    ) -> T:
        raise_if_kwargs(kwargs)
        originals, restore_type = ep.astensor_(inputs)
        device = inputs.device
        del inputs, kwargs

        criterion = get_criterion(criterion)
        is_adversarial = get_is_adversarial(criterion, model)

        self.qcount = 0
        self.normHistory = np.zeros(self.steps // 100 + 1)

        if starting_points is None:
            init_attack: MinimizationAttack
            if self.init_attack is None:
                init_attack = LinearSearchBlendedUniformNoiseAttack(steps=50)
                logging.info(
                    f"Neither starting_points nor init_attack given. Falling"
                    f" back to {init_attack!r} for initialization.")
            else:
                init_attack = self.init_attack
            # TODO: use call and support all types of attacks (once early_stop is
            # possible in __call__)
            best_advs = init_attack.run(model,
                                        originals,
                                        criterion,
                                        early_stop=early_stop)
            self.qcount += init_attack.qcount
        else:  # move starting points to the boundary
            epsilons = np.linspace(0, 1, num=50 + 1, dtype=np.float32)
            best = ep.ones(originals, (len(originals), ))
            for epsilon in epsilons:
                x = (1 - epsilon) * originals + epsilon * starting_points
                is_adv = is_adversarial(x)
                self.qcount += 1

                epsilon = epsilon.item()

                best = ep.minimum(ep.where(is_adv, epsilon, 1.0), best)
                if (best < 1).all():
                    break

            best = atleast_kd(best, originals.ndim)
            x = (1 - best) * originals + best * starting_points
            best_advs = ep.astensor(x)

        self.normHistory[0:] = ep.norms.l2(flatten(best_advs - originals),
                                           axis=-1).numpy()

        is_adv = is_adversarial(best_advs)
        self.qcount += 1
        if not is_adv.all():
            failed = is_adv.logical_not().float32().sum()
            if starting_points is None:
                raise ValueError(
                    f"init_attack failed for {failed} of {len(is_adv)} inputs")
            else:
                raise ValueError(
                    f"{failed} of {len(is_adv)} starting_points are not adversarial"
                )
        del starting_points

        tb = TensorBoard(logdir=self.tensorboard)

        N = len(originals)
        ndim = originals.ndim
        spherical_steps = ep.ones(originals, N) * self.spherical_step
        source_steps = ep.ones(originals, N) * self.source_step

        tb.scalar("batchsize", N, 0)

        # create two queues for each sample to track success rates
        # (used to update the hyper parameters)
        stats_spherical_adversarial = ArrayQueue(maxlen=100, N=N)
        stats_step_adversarial = ArrayQueue(maxlen=30, N=N)

        bounds = model.bounds

        for step in range(1, self.steps + 1):
            converged = source_steps < self.source_step_convergance
            if converged.all():
                break  # pragma: no cover
            converged = atleast_kd(converged, ndim)

            # TODO: performance: ignore those that have converged
            # (we could select the non-converged ones, but we currently
            # cannot easily invert this in the end using EagerPy)

            unnormalized_source_directions = originals - best_advs
            source_norms = ep.norms.l2(flatten(unnormalized_source_directions),
                                       axis=-1)
            source_directions = unnormalized_source_directions / atleast_kd(
                source_norms, ndim)

            # only check spherical candidates every k steps
            check_spherical_and_update_stats = step % self.update_stats_every_k == 0

            candidates, spherical_candidates = draw_proposals(
                bounds, originals, best_advs, unnormalized_source_directions,
                source_directions, source_norms, spherical_steps, source_steps,
                self.surrogate_models, self.ODS, device)
            assert candidates.dtype == originals.dtype
            assert spherical_candidates.dtype == originals.dtype

            is_adv = is_adversarial(candidates)

            self.qcount += 1
            if self.qcount % 100 == 0:
                self.normHistory[self.qcount // 100:] = ep.norms.l2(
                    flatten(best_advs - originals), axis=-1).numpy()
                if self.qcount >= self.steps:
                    break

            spherical_is_adv: Optional[ep.Tensor]
            if check_spherical_and_update_stats:
                spherical_is_adv = is_adversarial(spherical_candidates)
                self.qcount += 1
                if self.qcount % 100 == 0:
                    self.normHistory[self.qcount // 100:] = ep.norms.l2(
                        flatten(best_advs - originals), axis=-1).numpy()
                    if self.qcount >= self.steps:
                        break

                stats_spherical_adversarial.append(spherical_is_adv)
                # TODO: algorithm: the original implementation ignores those samples
                # for which spherical is not adversarial and continues with the
                # next iteration -> we estimate different probabilities (conditional vs. unconditional)
                # TODO: thoughts: should we always track this because we compute it anyway
                stats_step_adversarial.append(is_adv)
            else:
                spherical_is_adv = None

            # in theory, we are closer per construction
            # but limited numerical precision might break this
            distances = ep.norms.l2(flatten(originals - candidates), axis=-1)
            closer = distances < source_norms
            is_best_adv = ep.logical_and(is_adv, closer)
            is_best_adv = atleast_kd(is_best_adv, ndim)

            cond = converged.logical_not().logical_and(is_best_adv)
            best_advs = ep.where(cond, candidates, best_advs)

            tb.probability("converged", converged, step)
            tb.scalar("updated_stats", check_spherical_and_update_stats, step)
            tb.histogram("norms", source_norms, step)
            tb.probability("is_adv", is_adv, step)
            if spherical_is_adv is not None:
                tb.probability("spherical_is_adv", spherical_is_adv, step)
            tb.histogram("candidates/distances", distances, step)
            tb.probability("candidates/closer", closer, step)
            tb.probability("candidates/is_best_adv", is_best_adv, step)
            tb.probability("new_best_adv_including_converged", is_best_adv,
                           step)
            tb.probability("new_best_adv", cond, step)

            if check_spherical_and_update_stats:
                full = stats_spherical_adversarial.isfull()
                tb.probability("spherical_stats/full", full, step)
                if full.any():
                    probs = stats_spherical_adversarial.mean()
                    cond1 = ep.logical_and(probs > 0.5, full)
                    spherical_steps = ep.where(
                        cond1, spherical_steps * self.step_adaptation,
                        spherical_steps)
                    source_steps = ep.where(
                        cond1, source_steps * self.step_adaptation,
                        source_steps)
                    cond2 = ep.logical_and(probs < 0.2, full)
                    spherical_steps = ep.where(
                        cond2, spherical_steps / self.step_adaptation,
                        spherical_steps)
                    source_steps = ep.where(
                        cond2, source_steps / self.step_adaptation,
                        source_steps)
                    stats_spherical_adversarial.clear(
                        ep.logical_or(cond1, cond2))
                    tb.conditional_mean(
                        "spherical_stats/isfull/success_rate/mean", probs,
                        full, step)
                    tb.probability_ratio("spherical_stats/isfull/too_linear",
                                         cond1, full, step)
                    tb.probability_ratio(
                        "spherical_stats/isfull/too_nonlinear", cond2, full,
                        step)

                full = stats_step_adversarial.isfull()
                tb.probability("step_stats/full", full, step)
                if full.any():
                    probs = stats_step_adversarial.mean()
                    # TODO: algorithm: changed the two values because we are currently tracking p(source_step_sucess)
                    # instead of p(source_step_success | spherical_step_sucess) that was tracked before
                    cond1 = ep.logical_and(probs > 0.25, full)
                    source_steps = ep.where(
                        cond1, source_steps * self.step_adaptation,
                        source_steps)
                    cond2 = ep.logical_and(probs < 0.1, full)
                    source_steps = ep.where(
                        cond2, source_steps / self.step_adaptation,
                        source_steps)
                    stats_step_adversarial.clear(ep.logical_or(cond1, cond2))
                    tb.conditional_mean("step_stats/isfull/success_rate/mean",
                                        probs, full, step)
                    tb.probability_ratio(
                        "step_stats/isfull/success_rate_too_high", cond1, full,
                        step)
                    tb.probability_ratio(
                        "step_stats/isfull/success_rate_too_low", cond2, full,
                        step)

            tb.histogram("spherical_step", spherical_steps, step)
            tb.histogram("source_step", source_steps, step)
        tb.close()
        return restore_type(best_advs)
Example 10
 def _project(self, originals: ep.Tensor, perturbed: ep.Tensor, epsilons: ep.Tensor) -> ep.Tensor:
     epsilons = atleast_kd(epsilons, originals.ndim)
     return (1.0 - epsilons) * originals + epsilons * perturbed
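
`_project` is a plain convex blend: epsilon 0 returns the original, epsilon 1 the perturbed input, which is exactly what lets `_binary_search` in Example 3 bisect on epsilon. A usage sketch with numpy stand-ins for the eagerpy tensors:

import numpy as np

def project(originals, perturbed, epsilons):
    # broadcast epsilons over the trailing dimensions, as atleast_kd does
    epsilons = epsilons.reshape((-1,) + (1,) * (originals.ndim - 1))
    return (1.0 - epsilons) * originals + epsilons * perturbed

originals = np.zeros((2, 4))
perturbed = np.ones((2, 4))
print(project(originals, perturbed, np.array([0.0, 0.5])))
# first row stays at the original, second row is the midpoint blend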
Example 11
    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Criterion, Any] = None,
        *,
        early_stop: Optional[float] = None,
        **kwargs: Any,
    ) -> T:
        raise_if_kwargs(kwargs)
        x, restore_type = ep.astensor_(inputs)
        criterion_ = get_criterion(criterion)
        del inputs, criterion, kwargs

        is_adversarial = get_is_adversarial(criterion_, model)

        min_, max_ = model.bounds

        N = len(x)
        self.qcount = 0

        for j in range(self.directions):
            # random noise inputs tend to be classified into the same class,
            # so we might need to make very many draws if the original class
            # is that one
            random_ = ep.uniform(x, x.shape, min_, max_)
            is_adv_ = atleast_kd(is_adversarial(random_), x.ndim)
            self.qcount += 1
            if j == 0:
                random = random_
                is_adv = is_adv_
            else:
                random = ep.where(is_adv, random, random_)
                is_adv = is_adv.logical_or(is_adv_)

            if is_adv.all():
                break

        if not is_adv.all():
            warnings.warn(
                f"{self.__class__.__name__} failed to draw sufficient random"
                f" inputs that are adversarial ({is_adv.sum()} / {N}).")

        x0 = x

        epsilons = np.linspace(0, 1, num=self.steps + 1, dtype=np.float32)
        best = ep.ones(x, (N, ))

        for epsilon in epsilons:
            x = (1 - epsilon) * x0 + epsilon * random
            # TODO: due to limited floating point precision, clipping can be required
            is_adv = is_adversarial(x)
            self.qcount += 1

            epsilon = epsilon.item()

            best = ep.minimum(ep.where(is_adv, epsilon, 1.0), best)

            if (best < 1).all():
                break

        best = atleast_kd(best, x0.ndim)
        x = (1 - best) * x0 + best * random

        return restore_type(x)
Example 12
    def _binary_search_on_alpha(
            self, 
            function_evolution: Callable[[ep.Tensor], ep.Tensor], 
            lower: ep.Tensor) -> ep.Tensor:    
        # Upper --> not adversarial /  Lower --> adversarial
        v_type = function_evolution(lower)
        def get_alpha(theta: ep.Tensor) -> ep.Tensor:
            return 1 - ep.astensor(self._cos(theta.raw * np.pi / 180))

        check_opposite = lower > 0 # if param < 0: abs(param) doesn't work
        
        # Get the upper range
        upper = ep.where(
            abs(lower) != self.theta_max, 
            lower + ep.sign(lower) * self.theta_max / self.T,
            ep.zeros_like(lower)
            )

        mask_upper = (upper == 0)
        while mask_upper.any():
            # Find the correct lower/upper range
            # where mask_upper is True, the range hasn't been found yet
            new_upper = lower + ep.sign(lower) * self.theta_max / self.T
            potential_x = function_evolution(new_upper)
            x = ep.where(
                atleast_kd(mask_upper, potential_x.ndim),
                potential_x,
                ep.zeros_like(potential_x)
            )

            is_advs = self._is_adversarial(x)
            lower = ep.where(ep.logical_and(mask_upper, is_advs), new_upper, lower) 
            upper = ep.where(ep.logical_and(mask_upper, is_advs.logical_not()), new_upper, upper) 
            mask_upper = mask_upper * is_advs

        step = 0
        over_gamma = abs(get_alpha(upper) - get_alpha(lower)) > self._BS_gamma
        while step < self._BS_max_iteration and over_gamma.any(): 
            mid_bound = (upper + lower) / 2
            mid = ep.where(
                atleast_kd(ep.logical_and(mid_bound != 0, over_gamma), v_type.ndim),
                function_evolution(mid_bound),
                ep.zeros_like(v_type)
            )
            is_adv = self._is_adversarial(mid)

            mid_opp = ep.where(
                atleast_kd(ep.logical_and(ep.astensor(check_opposite), over_gamma), mid.ndim),
                function_evolution(-mid_bound),
                ep.zeros_like(mid)
            )
            is_adv_opp = self._is_adversarial(mid_opp)

            lower = ep.where(over_gamma * is_adv, mid_bound, lower)
            lower = ep.where(over_gamma * is_adv.logical_not() * check_opposite * is_adv_opp, -mid_bound, lower)
            upper = ep.where(over_gamma * is_adv.logical_not() * check_opposite * is_adv_opp, -upper, upper)
            upper = ep.where(over_gamma * (abs(lower) != abs(mid_bound)), mid_bound, upper)

            check_opposite = over_gamma * check_opposite * is_adv_opp * (lower > 0)
            over_gamma = abs(get_alpha(upper) - get_alpha(lower)) > self._BS_gamma

            step += 1
        return ep.astensor(lower)
Example 13
    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Misclassification, TargetedMisclassification, T],
        *,
        starting_points: Optional[ep.Tensor] = None,
        early_stop: Optional[float] = None,
        **kwargs: Any,
    ) -> T:
        raise_if_kwargs(kwargs)
        criterion_ = get_criterion(criterion)

        if isinstance(criterion_, Misclassification):
            targeted = False
            classes = criterion_.labels
        elif isinstance(criterion_, TargetedMisclassification):
            targeted = True
            classes = criterion_.target_classes
        else:
            raise ValueError("unsupported criterion")

        def loss_fn(
            inputs: ep.Tensor, labels: ep.Tensor
        ) -> Tuple[ep.Tensor, Tuple[ep.Tensor, ep.Tensor]]:

            logits = model(inputs)

            if targeted:
                c_minimize = best_other_classes(logits, labels)
                c_maximize = labels  # target_classes
            else:
                c_minimize = labels  # labels
                c_maximize = best_other_classes(logits, labels)

            loss = logits[rows, c_minimize] - logits[rows, c_maximize]

            return -loss.sum(), (logits, loss)

        x, restore_type = ep.astensor_(inputs)
        del inputs, criterion, kwargs
        N = len(x)

        # start from initialization points/attack
        if starting_points is not None:
            x1 = starting_points
        else:
            if self.init_attack is not None:
                x1 = self.init_attack.run(model, x, criterion_)
            else:
                x1 = None

        # if initial points or initialization attacks are provided,
        #   search for the boundary
        if x1 is not None:
            is_adv = get_is_adversarial(criterion_, model)
            assert is_adv(x1).all()
            lower_bound = ep.zeros(x, shape=(N, ))
            upper_bound = ep.ones(x, shape=(N, ))
            for _ in range(self.binary_search_steps):
                epsilons = (lower_bound + upper_bound) / 2
                mid_points = self.mid_points(x, x1, epsilons, model.bounds)
                is_advs = is_adv(mid_points)
                lower_bound = ep.where(is_advs, lower_bound, epsilons)
                upper_bound = ep.where(is_advs, epsilons, upper_bound)
            starting_points = self.mid_points(x, x1, upper_bound, model.bounds)
            delta = starting_points - x
        else:
            # start from x0
            delta = ep.zeros_like(x)

        if classes.shape != (N, ):
            name = "target_classes" if targeted else "labels"
            raise ValueError(
                f"expected {name} to have shape ({N},), got {classes.shape}")

        min_, max_ = model.bounds
        rows = range(N)
        grad_and_logits = ep.value_and_grad_fn(x, loss_fn, has_aux=True)

        if self.p != 0:
            epsilon = ep.inf * ep.ones(x, len(x))
        else:
            epsilon = ep.ones(x, len(x)) if x1 is None \
                else ep.norms.l0(flatten(delta), axis=-1)
        if self.p != 0:
            worst_norm = ep.norms.lp(flatten(ep.maximum(x - min_, max_ - x)),
                                     p=self.p,
                                     axis=-1)
        else:
            worst_norm = flatten(ep.ones_like(x)).bool().sum(axis=1).float32()

        best_lp = worst_norm
        best_delta = delta
        adv_found = ep.zeros(x, len(x)).bool()

        for i in range(self.steps):
            # perform cosine annealing of learning rates
            stepsize = (self.min_stepsize +
                        (self.max_stepsize - self.min_stepsize) *
                        (1 + math.cos(math.pi * i / self.steps)) / 2)
            gamma = (0.001 + (self.gamma - 0.001) *
                     (1 + math.cos(math.pi * (i / self.steps))) / 2)

            x_adv = x + delta

            loss, (logits,
                   loss_batch), gradients = grad_and_logits(x_adv, classes)
            is_adversarial = criterion_(x_adv, logits)

            lp = ep.norms.lp(flatten(delta), p=self.p, axis=-1)
            is_smaller = lp <= best_lp
            is_both = ep.logical_and(is_adversarial, is_smaller)
            adv_found = ep.logical_or(adv_found, is_adversarial)
            best_lp = ep.where(is_both, lp, best_lp)
            best_delta = ep.where(atleast_kd(is_both, x.ndim), delta,
                                  best_delta)

            # update epsilon
            if self.p != 0:
                distance_to_boundary = abs(loss_batch) / ep.norms.lp(
                    flatten(gradients), p=self.dual, axis=-1)
                epsilon = ep.where(
                    is_adversarial,
                    ep.minimum(
                        epsilon * (1 - gamma),
                        ep.norms.lp(flatten(best_delta), p=self.p, axis=-1)),
                    ep.where(
                        adv_found, epsilon * (1 + gamma),
                        ep.norms.lp(flatten(delta), p=self.p, axis=-1) +
                        distance_to_boundary))
            else:
                epsilon = ep.where(
                    is_adversarial,
                    ep.minimum(
                        ep.minimum(epsilon - 1,
                                   (epsilon * (1 - gamma)).astype(int).astype(
                                       epsilon.dtype)),
                        ep.norms.lp(flatten(best_delta), p=self.p, axis=-1)),
                    ep.maximum(epsilon + 1,
                               (epsilon * (1 + gamma)).astype(int).astype(
                                   epsilon.dtype)))
                epsilon = ep.maximum(0, epsilon).astype(epsilon.dtype)

            # clip epsilon
            epsilon = ep.minimum(epsilon, worst_norm)

            # computes normalized gradient update
            grad_ = self.normalize(gradients, x=x,
                                   bounds=model.bounds) * stepsize

            # do step
            delta = delta + grad_

            # project according to the given norm
            delta = self.project(x=x + delta, x0=x, epsilon=epsilon) - x

            # clip to valid bounds
            delta = ep.clip(x + delta, *model.bounds) - x

        x_adv = x + best_delta
        return restore_type(x_adv)
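
The stepsize used in the loop above follows half-cosine annealing from `max_stepsize` down to `min_stepsize` (and `gamma` decays the same way towards 0.001). A standalone sketch of the schedule, with illustrative parameter values:

import math

def annealed_stepsize(i, steps, min_stepsize=0.001, max_stepsize=0.1):
    # maximal at i = 0, minimal as i approaches steps
    return min_stepsize + (max_stepsize - min_stepsize) * (1 + math.cos(math.pi * i / steps)) / 2

for i in (0, 50, 100):
    print(annealed_stepsize(i, steps=100))  # 0.1, ~0.0505, 0.001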
Example 14
    def run(self, model, inputs, criterion, *, early_stop, **kwargs):
        raise_if_kwargs(kwargs)
        x, restore_type = ep.astensor_(inputs)
        criterion_ = get_criterion(criterion)
        del inputs, criterion, kwargs

        N = len(x)

        if isinstance(criterion_, Misclassification):
            targeted = False
            classes = criterion_.labels
            change_classes_logits = self.confidence
        elif isinstance(criterion_, TargetedMisclassification):
            targeted = True
            classes = criterion_.target_classes
            change_classes_logits = -self.confidence
        else:
            raise ValueError("unsupported criterion")

        def is_adversarial(perturbed: ep.Tensor, logits: ep.Tensor) -> ep.Tensor:
            if change_classes_logits != 0:
                logits += ep.onehot_like(logits, classes, value=change_classes_logits)
            return criterion_(perturbed, logits)

        if classes.shape != (N,):
            name = "target_classes" if targeted else "labels"
            raise ValueError(
                f"expected {name} to have shape ({N},), got {classes.shape}"
            )

        bounds = model.bounds
        to_attack_space = partial(_to_attack_space, bounds=bounds)
        to_model_space = partial(_to_model_space, bounds=bounds)

        x_attack = to_attack_space(x)
        reconstructed_x = to_model_space(x_attack)

        rows = range(N)

        def loss_fun(delta, consts):
            assert delta.shape == x_attack.shape
            assert consts.shape == (N,)

            x = to_model_space(x_attack + delta)
            logits = model(x)

            if targeted:
                c_minimize = best_other_classes(logits, classes)
                c_maximize = classes  # target_classes
            else:
                c_minimize = classes  # labels
                c_maximize = best_other_classes(logits, classes)

            is_adv_loss = logits[rows, c_minimize] - logits[rows, c_maximize]
            assert is_adv_loss.shape == (N,)

            is_adv_loss = is_adv_loss + self.confidence
            is_adv_loss = ep.maximum(0, is_adv_loss)
            is_adv_loss = is_adv_loss * consts

            squared_norms = flatten(x - reconstructed_x).square().sum(axis=-1)
            loss = is_adv_loss.sum() + squared_norms.sum()
            return loss, (x, logits)

        loss_aux_and_grad = ep.value_and_grad_fn(x, loss_fun, has_aux=True)

        consts = self.initial_const * np.ones((N,))
        lower_bounds = np.zeros((N,))
        upper_bounds = np.inf * np.ones((N,))

        best_advs = ep.zeros_like(x)
        best_advs_norms = ep.full(x, (N,), ep.inf)

        self._consts = []
        self._steps_per_iter = []
        self._best_const = -1
        # the binary search searches for the smallest consts that produce adversarials
        for binary_search_step in range(self.binary_search_steps):
            if (
                    binary_search_step == self.binary_search_steps - 1
                    and self.binary_search_steps >= 10
            ):
                # in the last binary search step, repeat the search once
                consts = np.minimum(upper_bounds, 1e10)

            iter_step = 0

            # create a new optimizer find the delta that minimizes the loss
            delta = ep.zeros_like(x_attack)
            optimizer = AdamOptimizer(delta)

            # tracks whether adv with the current consts was found
            found_advs = np.full((N,), fill_value=False)
            loss_at_previous_check = np.inf

            consts_ = ep.from_numpy(x, consts.astype(np.float32))

            for step in range(self.steps):
                loss, (perturbed, logits), gradient = loss_aux_and_grad(delta, consts_)
                delta += optimizer(gradient, self.stepsize)

                if self.abort_early and step % (np.ceil(self.steps / 10)) == 0:
                    # after each tenth of the overall steps, check progress
                    if not (loss <= 0.9999 * loss_at_previous_check):
                        break  # stop Adam if there has been no progress
                    loss_at_previous_check = loss

                iter_step += 1

                found_advs_iter = is_adversarial(perturbed, logits)
                found_advs = np.logical_or(found_advs, found_advs_iter.numpy())

                norms = flatten(perturbed - x).norms.l2(axis=-1)
                closer = norms < best_advs_norms
                new_best = ep.logical_and(closer, found_advs_iter)
                if new_best.any():
                    self._best_const = binary_search_step

                new_best_ = atleast_kd(new_best, best_advs.ndim)
                best_advs = ep.where(new_best_, perturbed, best_advs)
                best_advs_norms = ep.where(new_best, norms, best_advs_norms)
                self._consts.append(consts_.numpy().tolist())

            self._steps_per_iter.append(iter_step)

            upper_bounds = np.where(found_advs, consts, upper_bounds)
            lower_bounds = np.where(found_advs, lower_bounds, consts)

            consts_exponential_search = consts * 10
            consts_binary_search = (lower_bounds + upper_bounds) / 2
            consts = np.where(
                np.isinf(upper_bounds), consts_exponential_search, consts_binary_search
            )

        return restore_type(best_advs)
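
The const update at the end of each binary search step combines exponential growth (while the upper bound is still infinite, i.e. no adversarial has been found yet) with bisection once the bracket is closed. A scalar sketch of that rule:

import numpy as np

def update_const(const, lower, upper, found_adv):
    # tighten the bracket with the last outcome, then pick the next const
    upper = const if found_adv else upper
    lower = lower if found_adv else const
    if np.isinf(upper):
        return const * 10, lower, upper       # keep growing until an adversarial appears
    return (lower + upper) / 2, lower, upper  # then bisect

const, lower, upper = 1e-2, 0.0, np.inf
for found in (False, False, True):
    const, lower, upper = update_const(const, lower, upper, found)
    print(const, lower, upper)  # 0.1, then 1.0, then 0.55 after the first success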
Example 15
def draw_proposals(bounds: Bounds, originals: ep.Tensor, perturbed: ep.Tensor,
                   unnormalized_source_directions: ep.Tensor,
                   source_directions: ep.Tensor, source_norms: ep.Tensor,
                   spherical_steps: ep.Tensor, source_steps: ep.Tensor,
                   surrogate_models, ODS,
                   device) -> Tuple[ep.Tensor, ep.Tensor]:
    # remember the actual shape
    shape = originals.shape
    assert perturbed.shape == shape
    assert unnormalized_source_directions.shape == shape
    assert source_directions.shape == shape

    perturbed_org = perturbed
    # flatten everything to (batch, size)
    originals = flatten(originals)
    perturbed = flatten(perturbed)
    unnormalized_source_directions = flatten(unnormalized_source_directions)
    source_directions = flatten(source_directions)
    N, D = originals.shape

    assert source_norms.shape == (N, )
    assert spherical_steps.shape == (N, )
    assert source_steps.shape == (N, )

    if ODS:
        eta = draw_ODS(perturbed_org, surrogate_models, device)
        eta, _ = ep.astensor_(eta)
        eta = eta.reshape((-1, 1))
    else:
        # draw from an iid Gaussian (we can share this across the whole batch)
        eta = ep.normal(perturbed, (D, 1))

    # make orthogonal (source_directions are normalized)
    eta = eta.T - ep.matmul(source_directions, eta) * source_directions
    assert eta.shape == (N, D)

    # rescale
    norms = ep.norms.l2(eta, axis=-1)
    assert norms.shape == (N, )
    eta = eta * atleast_kd(spherical_steps * source_norms / norms, eta.ndim)

    # project on the sphere using Pythagoras
    distances = atleast_kd((spherical_steps.square() + 1).sqrt(), eta.ndim)
    directions = eta - unnormalized_source_directions
    spherical_candidates = originals + directions / distances

    # clip
    min_, max_ = bounds
    spherical_candidates = spherical_candidates.clip(min_, max_)

    # step towards the original inputs
    new_source_directions = originals - spherical_candidates
    assert new_source_directions.ndim == 2
    new_source_directions_norms = ep.norms.l2(flatten(new_source_directions),
                                              axis=-1)

    # length if spherical_candidates would be exactly on the sphere
    lengths = source_steps * source_norms

    # length including correction for numerical deviation from sphere
    lengths = lengths + new_source_directions_norms - source_norms

    # make sure the step size is positive
    lengths = ep.maximum(lengths, 0)

    # normalize the length
    lengths = lengths / new_source_directions_norms
    lengths = atleast_kd(lengths, new_source_directions.ndim)

    candidates = spherical_candidates + lengths * new_source_directions

    # clip
    candidates = candidates.clip(min_, max_)

    # restore shape
    candidates = candidates.reshape(shape)
    spherical_candidates = spherical_candidates.reshape(shape)
    return candidates, spherical_candidates
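
A numpy check of the orthogonalization step used above: after subtracting the projection onto the row-normalized source directions, each row of `eta` has zero inner product with its source direction:

import numpy as np

rng = np.random.default_rng(0)
N, D = 4, 16
source_directions = rng.normal(size=(N, D))
source_directions /= np.linalg.norm(source_directions, axis=1, keepdims=True)

eta = rng.normal(size=(D, 1))                 # one shared draw for the whole batch
eta = eta.T - (source_directions @ eta) * source_directions
assert eta.shape == (N, D)
print(np.abs((eta * source_directions).sum(axis=1)).max())  # ~0
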
def draw_proposals(bounds: Bounds, originals: ep.Tensor, perturbed: ep.Tensor,
                   unnormalized_source_directions: ep.Tensor,
                   source_directions: ep.Tensor, source_norms: ep.Tensor,
                   spherical_steps: ep.Tensor, source_steps: ep.Tensor,
                   surrogate_model: Model) -> Tuple[ep.Tensor, ep.Tensor]:
    # remember the actual shape
    shape = originals.shape
    assert perturbed.shape == shape
    assert unnormalized_source_directions.shape == shape
    assert source_directions.shape == shape

    # flatten everything to (batch, size)
    originals = flatten(originals)
    perturbed = flatten(perturbed)
    unnormalized_source_directions = flatten(unnormalized_source_directions)
    source_directions = flatten(source_directions)
    N, D = originals.shape

    assert source_norms.shape == (N, )
    assert spherical_steps.shape == (N, )
    assert source_steps.shape == (N, )

    # draw from an iid Gaussian (we can share this across the whole batch)
    eta = ep.normal(perturbed, (D, 1))

    # make orthogonal (source_directions are normalized)
    eta = eta.T - ep.matmul(source_directions, eta) * source_directions
    assert eta.shape == (N, D)

    pg_factor = 0.5

    if surrogate_model is not None:
        device = surrogate_model.device
        projected_gradient = get_projected_gradients(perturbed.reshape(shape),
                                                     originals.reshape(shape),
                                                     0, surrogate_model)
        projected_gradient = projected_gradient.reshape((N, D))
        projected_gradient = torch.tensor(projected_gradient, device=device)

        projected_gradient = ep.astensor(projected_gradient)

        eta = (1. - pg_factor) * eta + pg_factor * projected_gradient

    # rescale
    norms = ep.norms.l2(eta, axis=-1)
    assert norms.shape == (N, )
    eta = eta * atleast_kd(spherical_steps * source_norms / norms, eta.ndim)

    # project on the sphere using Pythagoras
    distances = atleast_kd((spherical_steps.square() + 1).sqrt(), eta.ndim)
    directions = eta - unnormalized_source_directions
    spherical_candidates = originals + directions / distances

    # clip
    min_, max_ = bounds
    spherical_candidates = spherical_candidates.clip(min_, max_)

    # step towards the original inputs
    new_source_directions = originals - spherical_candidates
    assert new_source_directions.ndim == 2
    new_source_directions_norms = ep.norms.l2(flatten(new_source_directions),
                                              axis=-1)

    # length if spherical_candidates would be exactly on the sphere
    lengths = source_steps * source_norms

    # length including correction for numerical deviation from sphere
    lengths = lengths + new_source_directions_norms - source_norms

    # make sure the step size is positive
    lengths = ep.maximum(lengths, 0)

    # normalize the length
    lengths = lengths / new_source_directions_norms
    lengths = atleast_kd(lengths, new_source_directions.ndim)

    candidates = spherical_candidates + lengths * new_source_directions

    # clip
    candidates = candidates.clip(min_, max_)

    # restore shape
    candidates = candidates.reshape(shape)
    spherical_candidates = spherical_candidates.reshape(shape)

    return candidates, spherical_candidates
    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Criterion, T],
        *,
        early_stop: Optional[float] = None,
        starting_points: Optional[T] = None,
        **kwargs: Any,
    ) -> T:
        raise_if_kwargs(kwargs)
        originals, restore_type = ep.astensor_(inputs)
        del inputs, kwargs

        criterion = get_criterion(criterion)
        is_adversarial = get_is_adversarial(criterion, model)

        if starting_points is None:
            init_attack: MinimizationAttack
            if self.init_attack is None:
                init_attack = LinearSearchBlendedUniformNoiseAttack(steps=50)
                logging.info(
                    f"Neither starting_points nor init_attack given. Falling"
                    f" back to {init_attack!r} for initialization.")
            else:
                init_attack = self.init_attack
            # TODO: use call and support all types of attacks (once early_stop is
            # possible in __call__)
            best_advs = init_attack.run(model,
                                        originals,
                                        criterion,
                                        early_stop=early_stop)
        else:
            best_advs = ep.astensor(starting_points)

        is_adv = is_adversarial(best_advs)
        if not is_adv.all():
            failed = is_adv.logical_not().float32().sum()
            if starting_points is None:
                raise ValueError(
                    f"init_attack failed for {failed} of {len(is_adv)} inputs")
            else:
                raise ValueError(
                    f"{failed} of {len(is_adv)} starting_points are not adversarial"
                )
        del starting_points

        tb = TensorBoard(logdir=self.tensorboard)

        N = len(originals)
        ndim = originals.ndim
        spherical_steps = ep.ones(originals, N) * self.spherical_step
        source_steps = ep.ones(originals, N) * self.source_step

        tb.scalar("batchsize", N, 0)

        # create two queues for each sample to track success rates
        # (used to update the hyper parameters)
        stats_spherical_adversarial = ArrayQueue(maxlen=100, N=N)
        stats_step_adversarial = ArrayQueue(maxlen=30, N=N)

        bounds = model.bounds

        self.class_1 = []
        self.class_2 = []

        self.surrogate_model = None
        device = model.device
        train_step = 500

        for step in tqdm(range(1, self.steps + 1)):
            converged = source_steps < self.source_step_convergance
            if converged.all():
                break  # pragma: no cover
            converged = atleast_kd(converged, ndim)

            # TODO: performance: ignore those that have converged
            # (we could select the non-converged ones, but we currently
            # cannot easily invert this in the end using EagerPy)

            unnormalized_source_directions = originals - best_advs
            source_norms = ep.norms.l2(flatten(unnormalized_source_directions),
                                       axis=-1)
            source_directions = unnormalized_source_directions / atleast_kd(
                source_norms, ndim)

            # only check spherical candidates every k steps
            check_spherical_and_update_stats = step % self.update_stats_every_k == 0

            candidates, spherical_candidates = draw_proposals(
                bounds, originals, best_advs, unnormalized_source_directions,
                source_directions, source_norms, spherical_steps, source_steps,
                self.surrogate_model)
            assert candidates.dtype == originals.dtype
            assert spherical_candidates.dtype == originals.dtype

            is_adv = is_adversarial(candidates)
            is_adv_spherical_candidates = is_adversarial(spherical_candidates)

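            # NOTE: the .item() calls below assume a batch size of 1;
            # adversarial candidates are collected as positives (class_1),
            # non-adversarial spherical candidates as negatives (class_2)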
            if is_adv.item():
                self.class_1.append(candidates)

            if not is_adv_spherical_candidates.item():
                self.class_2.append(spherical_candidates)

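            # every `train_step` steps, (re)train the surrogate on the
            # candidates collected so far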
            if step % train_step == 0:

                start_time = time()

                class_1 = self.class_1
                class_2 = self.class_2

                class_1 = np.array([image.numpy()[0] for image in class_1])
                class_2 = np.array([image.numpy()[0] for image in class_2])

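                # balance the dataset: keep at most as many negatives
                # (class_2) as positives (class_1); assumes class_1 is
                # non-empty by now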
                class_2 = class_2[:len(class_1)]
                data = np.concatenate([class_1, class_2])
                labels = np.append(np.ones(len(class_1)),
                                   np.zeros(len(class_2)))

                X = torch.tensor(data).to(device)
                y = torch.tensor(labels, dtype=torch.long).to(device)

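                # first training round: start from an ImageNet-pretrained
                # ResNet-18 with a fresh 2-class head; later rounds continue
                # fine-tuning the surrogate trained in the previous round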
                if self.surrogate_model is None:
                    model_sur = torchvision.models.resnet18(pretrained=True)
                    model_sur.fc = torch.nn.Linear(in_features=512,
                                                   out_features=2,
                                                   bias=True)
                    model_sur = model_sur.to(device)
                else:
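                    # reuse the torch model trained in an earlier round
                    # (`model_surrogate` is still bound from the previous
                    # training step of this loop)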
                    model_sur = model_surrogate

                X_train, X_test, y_train, y_test = train_test_split(
                    X, y, test_size=0.2, random_state=42)

                optimizer = torch.optim.Adam(model_sur.parameters(), lr=3e-4)
                loss = torch.nn.CrossEntropyLoss()

                model_surrogate, accuracy_history_test, accuracy_history_train = train(
                    model_sur, optimizer, loss, X_train, y_train, X_test,
                    y_test)
                model_surrogate = model_surrogate.eval()

                self.surrogate_model = fb.PyTorchModel(model_surrogate,
                                                       bounds=(0, 1),
                                                       device=device)

                end_time = time()
                # print('Time for train:', np.round(end_time - start_time, 2))

            spherical_is_adv: Optional[ep.Tensor]
            if check_spherical_and_update_stats:
                # reuse the result computed above instead of querying the
                # model a second time
                spherical_is_adv = is_adv_spherical_candidates
                stats_spherical_adversarial.append(spherical_is_adv)
                # TODO: algorithm: the original implementation ignores those samples
                # for which spherical is not adversarial and continues with the
                # next iteration -> we estimate different probabilities (conditional vs. unconditional)
                # TODO: thoughts: should we always track this because we compute it anyway
                stats_step_adversarial.append(is_adv)
            else:
                spherical_is_adv = None

            # in theory, we are closer per construction
            # but limited numerical precision might break this
            distances = ep.norms.l2(flatten(originals - candidates), axis=-1)
            closer = distances < source_norms
            is_best_adv = ep.logical_and(is_adv, closer)
            is_best_adv = atleast_kd(is_best_adv, ndim)

            cond = converged.logical_not().logical_and(is_best_adv)
            best_advs = ep.where(cond, candidates, best_advs)

            tb.probability("converged", converged, step)
            tb.scalar("updated_stats", check_spherical_and_update_stats, step)
            tb.histogram("norms", source_norms, step)
            tb.probability("is_adv", is_adv, step)
            if spherical_is_adv is not None:
                tb.probability("spherical_is_adv", spherical_is_adv, step)
            tb.histogram("candidates/distances", distances, step)
            tb.probability("candidates/closer", closer, step)
            tb.probability("candidates/is_best_adv", is_best_adv, step)
            tb.probability("new_best_adv_including_converged", is_best_adv,
                           step)
            tb.probability("new_best_adv", cond, step)

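            # step-size adaptation as in the original Boundary Attack: grow
            # the steps when the spherical success rate is high (> 0.5, the
            # boundary is locally too linear) and shrink them when it is low
            # (< 0.2, too nonlinear)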
            if check_spherical_and_update_stats:
                full = stats_spherical_adversarial.isfull()
                tb.probability("spherical_stats/full", full, step)
                if full.any():
                    probs = stats_spherical_adversarial.mean()
                    cond1 = ep.logical_and(probs > 0.5, full)
                    spherical_steps = ep.where(
                        cond1, spherical_steps * self.step_adaptation,
                        spherical_steps)
                    source_steps = ep.where(
                        cond1, source_steps * self.step_adaptation,
                        source_steps)
                    cond2 = ep.logical_and(probs < 0.2, full)
                    spherical_steps = ep.where(
                        cond2, spherical_steps / self.step_adaptation,
                        spherical_steps)
                    source_steps = ep.where(
                        cond2, source_steps / self.step_adaptation,
                        source_steps)
                    stats_spherical_adversarial.clear(
                        ep.logical_or(cond1, cond2))
                    tb.conditional_mean(
                        "spherical_stats/isfull/success_rate/mean", probs,
                        full, step)
                    tb.probability_ratio("spherical_stats/isfull/too_linear",
                                         cond1, full, step)
                    tb.probability_ratio(
                        "spherical_stats/isfull/too_nonlinear", cond2, full,
                        step)

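                # adapt the source step from the unconditional candidate
                # success rate (thresholds 0.25 / 0.1, see the TODO below)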
                full = stats_step_adversarial.isfull()
                tb.probability("step_stats/full", full, step)
                if full.any():
                    probs = stats_step_adversarial.mean()
                    # TODO: algorithm: changed the two values because we are currently tracking p(source_step_success)
                    # instead of p(source_step_success | spherical_step_success) that was tracked before
                    cond1 = ep.logical_and(probs > 0.25, full)
                    source_steps = ep.where(
                        cond1, source_steps * self.step_adaptation,
                        source_steps)
                    cond2 = ep.logical_and(probs < 0.1, full)
                    source_steps = ep.where(
                        cond2, source_steps / self.step_adaptation,
                        source_steps)
                    stats_step_adversarial.clear(ep.logical_or(cond1, cond2))
                    tb.conditional_mean("step_stats/isfull/success_rate/mean",
                                        probs, full, step)
                    tb.probability_ratio(
                        "step_stats/isfull/success_rate_too_high", cond1, full,
                        step)
                    tb.probability_ratio(
                        "step_stats/isfull/success_rate_too_low", cond2, full,
                        step)

            tb.histogram("spherical_step", spherical_steps, step)
            tb.histogram("source_step", source_steps, step)
        tb.close()
        return restore_type(best_advs)
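The `train` helper called in the surrogate-training block above is not part of this excerpt. The following is a minimal sketch that matches its call signature and its three return values (trained model, test-accuracy history, train-accuracy history); the epoch count and batch size are assumptions, not taken from the original code:

def train(model, optimizer, loss_fn, X_train, y_train, X_test, y_test,
          epochs=10, batch_size=32):
    # hypothetical minimal training loop matching the call site above
    accuracy_history_test, accuracy_history_train = [], []
    for _ in range(epochs):
        model.train()
        perm = torch.randperm(len(X_train), device=X_train.device)
        for i in range(0, len(X_train), batch_size):
            idx = perm[i:i + batch_size]
            optimizer.zero_grad()
            batch_loss = loss_fn(model(X_train[idx]), y_train[idx])
            batch_loss.backward()
            optimizer.step()
        model.eval()
        with torch.no_grad():
            pred_train = model(X_train).argmax(dim=1)
            pred_test = model(X_test).argmax(dim=1)
        accuracy_history_train.append((pred_train == y_train).float().mean().item())
        accuracy_history_test.append((pred_test == y_test).float().mean().item())
    return model, accuracy_history_test, accuracy_history_train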