Beispiel #1
0
 def __enter__(self):
     """Install a CPU-sized thread pool as the event loop's default executor.

     A ``ThreadPoolExecutor`` with ``multiprocessing.cpu_count()`` workers is
     created, registered as the default executor of the loop returned by
     ``asyncio.get_event_loop()``, and then entered as a context manager.

     :return: this object, so it can be bound by ``with ... as``.
     """
     workers = ThreadPoolExecutor(multiprocessing.cpu_count())
     self._loop = asyncio.get_event_loop()
     self._loop.set_default_executor(workers)
     # The attribute name "_thread_poll" (sic) is kept unchanged because code
     # outside this method may read it.
     self._thread_poll = workers.__enter__()
     return self
Beispiel #2
0
class BoundedExecutor(concurrent.futures.Executor):
    """Thread-pool executor that limits the number of in-flight tasks.

    ``submit()`` blocks while ``max_inflight`` submitted tasks have not yet
    finished. Integer task results are additionally summed into a running
    total that is reported through the module logger.

    :param max_workers: Forwarded to the wrapped ``ThreadPoolExecutor``.
    :param max_inflight: Maximum number of submitted-but-unfinished tasks.
    """

    def __init__(self, max_workers=None, max_inflight=4):
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        # Guards _total_processed, which is updated from worker threads.
        self._lock = threading.Lock()
        self._total_processed = 0
        # One permit per task that may be in flight at the same time.
        self._sem = threading.BoundedSemaphore(max_inflight)

    def _acquire(self):
        """Block until an in-flight slot is free and take it."""
        return self._sem.acquire(True)

    def _done_cb(self, future):
        """Release the task's slot and account for its result.

        Runs as a done-callback of every submitted future. An exception
        raised from here does not reach the submitter; concurrent.futures
        only logs exceptions escaping a callback.
        """
        self._sem.release()
        if future.cancelled():
            # A cancelled future has neither a result nor an exception;
            # asking for either would raise CancelledError in the callback.
            return
        error = future.exception()
        if error is not None:
            log.error("Got exception in done callback!")
            raise error
        processed = future.result()
        # Exact type check on purpose: bool results are not counted.
        if type(processed) is int:
            with self._lock:
                self._total_processed += processed
                total = self._total_processed
            log.info("Successfully made %s updates (%s cumulatively)!" %
                     (processed, total))

    def __enter__(self):
        self._executor.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Shut the pool down; exit the process quietly on KeyboardInterrupt."""
        if exc_type is KeyboardInterrupt:
            logging.warning("Got interrupt signal!  Exiting..")
            sys.exit(0)
        # Executor.__exit__ calls self.shutdown(wait=True).
        return super(BoundedExecutor, self).__exit__(exc_type, exc_value,
                                                     traceback)

    def submit(self, fn, *args, **kwargs):
        """Schedule ``fn(*args, **kwargs)``, blocking while the limit is hit.

        :return: The future returned by the wrapped executor.
        :raises RuntimeError: Propagated from the wrapped executor, e.g. when
            it has already been shut down.
        """
        self._acquire()
        try:
            future = self._executor.submit(fn, *args, **kwargs)
        except BaseException:
            # Nothing was scheduled, so no done-callback will ever hand the
            # slot back; release it here or it is lost for good.
            self._sem.release()
            raise
        future.add_done_callback(self._done_cb)
        return future

    def shutdown(self, wait=True):
        """Shut down the wrapped thread pool."""
        return self._executor.shutdown(wait=wait)
Beispiel #3
0
class AsyncThread:
    """Runs an asyncio event loop on a dedicated background thread.

    The constructor starts the thread and returns once the loop object
    exists. Coroutines are handed to the loop with ``put()``; blocking
    callables can be pushed to a thread pool with ``execute()``.

    Used as a context manager, only the thread pool is managed: ``stop()``
    must still be called to end the loop thread.
    """

    def __init__(self):
        self.loop = None
        self.pool = ThreadPoolExecutor()
        # Set by run() once loop creation has succeeded or failed, so the
        # constructor can wait without polling.
        self._loop_ready = threading.Event()
        self.thread = threading.Thread(target=self.run)
        self.thread.start()
        self._loop_ready.wait()
        if self.loop is None:
            raise RuntimeError("AsyncThread: event loop failed to start")

    def run(self):
        """Thread body: create the loop, publish it, and run it forever."""
        try:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            self.loop = loop
        finally:
            # Always wake the constructor, even when loop creation failed;
            # otherwise it would wait forever.
            self._loop_ready.set()
        self.loop.run_forever()

    def stop(self):
        """Ask the loop to stop; safe to call from any thread."""
        self.loop.call_soon_threadsafe(self.loop.stop)

    def _add_task(self, future, coro):
        # Runs on the loop thread: wrap the coroutine in a task and hand the
        # task back to the thread blocked in put().
        task = self.loop.create_task(coro)
        future.set_result(task)

    def put(self, coro):
        """Schedule ``coro`` on the loop and return the created task.

        Blocks the calling thread until the loop thread has created the task.
        """
        future = Future()
        p = fn.partial(self._add_task, future, coro)
        self.loop.call_soon_threadsafe(p)
        return future.result()

    def wait(self, futures, timeout):
        """Schedule ``asyncio.wait`` over ``futures`` and block on it.

        The blocking itself is delegated to the external ``block`` helper,
        called with ``throw=False``.
        """
        fut = self.put(asyncio.wait(list(futures), timeout=timeout))
        return block(fut, timeout=timeout, throw=False)

    def cancel(self, task):
        """Request cancellation of ``task`` from any thread."""
        self.loop.call_soon_threadsafe(task.cancel)

    def execute(self, fun, *args, **kwargs):
        """Run ``fun(*args, **kwargs)`` on the thread pool via the loop.

        :return: The future returned by ``loop.run_in_executor``.
        """
        return self.loop.run_in_executor(self.pool,
                                         fn.partial(fun, *args, **kwargs))

    def __enter__(self):
        self.pool.__enter__()
        return self

    def __exit__(self, cls, value, traceback):
        # Only the pool is shut down here; the loop thread keeps running
        # until stop() is called.
        self.pool.__exit__(cls, value, traceback)
Beispiel #4
0
class OpaExecutor:
    """Thread pool wrapper that tracks submitted tasks and shows progress.

    Every task goes through ``launcher``, which prints the traceback of a
    failing task and lets it yield ``None`` instead of raising. Leaving the
    ``with`` block waits for all tracked tasks while drawing a progress bar.

    :param num: Number of worker threads.
    """

    def __init__(self, num):
        self.executor = ThreadPoolExecutor(num)
        self.tasks = []
        # Set by __enter__; submit() and map() require an entered executor.
        self.ex = None

    def __enter__(self):
        self.ex = self.executor.__enter__()
        return self

    def __exit__(self, typ, value, tcb):
        try:
            self.do_wait()
        finally:
            # Shut the pool down even when waiting or the progress bar
            # fails, so worker threads are not leaked.
            self.executor.__exit__(typ, value, tcb)

    def launcher(self, data):
        """Call ``data[0](*data[1:])``; on error print the traceback.

        :return: The callable's result, or ``None`` when it raised.
        """
        try:
            return data[0](*data[1:])
        except Exception:
            # Only ordinary errors are swallowed; SystemExit and
            # KeyboardInterrupt are left to propagate.
            tb.print_exc()
            return None

    def submit(self, func, *args):
        """Schedule ``func(*args)`` and return its future."""
        data = [func]
        data.extend(args)
        res = self.ex.submit(self.launcher, data)
        self.tasks.append(res)
        return res

    def map(self, func, data):
        """Schedule ``func(x)`` for every ``x`` in ``data``; returns nothing."""
        for x in data:
            self.tasks.append(self.ex.submit(self.launcher, [func, x]))

    def results(self):
        """Return the results of all tracked tasks in submission order."""
        return [task.result() for task in self.tasks]

    def do_wait(self):
        """Block until every tracked task is done, updating a progress bar."""
        total = len(self.tasks)
        pb = ProgressBar(max_value=total).start()
        pb.update(0)
        for finished, _ in enumerate(as_completed(self.tasks), 1):
            pb.update(min(finished, total))
        pb.finish()
        print('')
class BiasedBoundaryAttack:
    """
     Like BoundaryAttack, but uses biased sampling from various sources.
     This implementation is optimized for speed: it can query the model and, while waiting, already prepare the next perturbation candidate.
     Ideally, there is zero overhead over the prediction speed of the model under attack.

     However, we do NOT run predictions in parallel (as the Foolbox BoundaryAttack does).
     This attack is completely sequential to keep the number of queries minimal.

     Activate various biases in bench_settings.py.
    """
    def __init__(self,
                 blackbox_model,
                 sample_gen,
                 dm_main,
                 substitute_model=None):
        """
        Creates a reusable instance.
        :param blackbox_model: The model to attack.
        :param sample_gen: Random sample generator.
        :param dm_main: Distance measure. Its to_range_01() result is stored and later used through calc(a, b).
        :param substitute_model: A Foolbox differentiable surrogate model for gradients. If None, then the surrogate bias will not be used.
        :raises ValueError: If substitute_model is given but is not a TensorFlowModel.
        """

        self.blackbox_model = blackbox_model
        self.sample_gen = sample_gen
        # Images are normed to 0/1 inside of run_attack()
        self.dm_main = dm_main.to_range_01()

        # A substitute model that provides batched gradients.
        if substitute_model is not None:
            if not isinstance(substitute_model, TensorFlowModel):
                raise ValueError(
                    "Substitute Model must provide gradients! (Foolbox: TensorFlowModel)"
                )
            self.substitute_model = BatchTensorflowModel(
                substitute_model._images,
                substitute_model._batch_logits,
                session=substitute_model.session)
        else:
            self.substitute_model = None

        # We use ThreadPools to calculate candidates and surrogate gradients while we're waiting for the model's next prediction.
        self.pg_thread_pool = ThreadPoolExecutor(max_workers=1)
        self.candidate_thread_pool = ThreadPoolExecutor(max_workers=1)

    def __enter__(self):
        self.pg_thread_pool.__enter__()
        self.candidate_thread_pool.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Will block until the futures are calculated. Thankfully they're not very complicated.
        try:
            self.pg_thread_pool.__exit__(exc_type, exc_value, traceback)
        finally:
            # Shut the second pool down even if the first shutdown raised.
            self.candidate_thread_pool.__exit__(exc_type, exc_value,
                                                traceback)
        print("BiasedBoundaryAttack: all threads stopped.")

    def run_attack(self,
                   X_orig,
                   label,
                   is_targeted,
                   X_start,
                   n_calls_left_fn,
                   n_max_per_batch=50,
                   n_seconds=None,
                   source_step=1e-2,
                   spherical_step=1e-2,
                   mask=None,
                   recalc_mask_every=None):
        """
        Runs the Biased Boundary Attack against a single image.
        The loop ends when n_calls_left_fn() is 3 or less, or when n_seconds have elapsed.

        :param X_orig: The original (clean) image to perturb. Must be 3-dimensional; divided by 255 internally.
        :param label: The target label (if targeted), or the original label (if untargeted).
        :param is_targeted: True if targeted.
        :param X_start: The starting point (must be of target class). Must be 3-dimensional; divided by 255 internally.
        :param n_calls_left_fn: A function that returns the currently remaining number of queries against the model.
        :param n_max_per_batch: How many samples are drawn per "batch". Samples are processed serially (the challenge doesn't allow
                                batching), but for each "batch", the attack dynamically adjusts hyperparams based on the success of
                                previous samples. This "batch" size is the max number of samples after which hyperparams are reset, and
                                a new "batch" is started. See generate_candidate().
        :param n_seconds: Maximum seconds allowed for the attack to complete.
        :param source_step: source step hyperparameter (see Boundary Attack)
        :param spherical_step: orthogonal step hyperparameter (see Boundary Attack)
        :param mask: Optional. If not none, a predefined mask (expert knowledge?) can be defined that will be applied to the perturbations.
                     Must have the shape of X_orig, be non-negative and have a maximum of (almost exactly) 1.
        :param recalc_mask_every: If not none, automatically calculates a mask from the current image difference.
                                  Will recalculate this mask every (n) steps.
        :return: The best adversarial example so far, scaled back by a factor of 255.
        """

        assert len(X_orig.shape) == 3
        assert len(X_start.shape) == 3
        if mask is not None:
            assert mask.shape == X_orig.shape
            assert np.sum(mask < 0) == 0 and 1. - np.max(
                mask
            ) < 1e-4, "Mask must be scaled to [0,1]. At least one value must be 1."
        else:
            mask = np.ones(X_orig.shape, dtype=np.float32)

        time_start = timeit.default_timer()

        pg_future = None
        try:
            # WARN: Inside this function, image space is normed to [0,1]!
            X_orig = np.float32(X_orig) / 255.
            X_start = np.float32(X_start) / 255.

            label_current, dist_best = self._eval_sample(X_start, X_orig)
            if (label_current == label) != is_targeted:
                print(
                    "WARN: Starting point is not a valid adversarial example! Continuing for now."
                )

            X_adv_best = np.copy(X_start)

            last_mask_recalc_calls = n_calls_left_fn()

            # Abort if we're running out of queries
            while n_calls_left_fn() > 3:

                # Mask Bias: recalculate mask from current diff (hopefully this reduces the search space)
                if recalc_mask_every is not None and last_mask_recalc_calls - n_calls_left_fn(
                ) >= recalc_mask_every:
                    new_mask = np.abs(X_adv_best - X_orig)
                    new_mask /= np.max(new_mask)  # scale to [0,1]
                    new_mask = new_mask**0.5  # weaken the effect a bit.
                    print(
                        "Recalculated mask. Weighted dimensionality of search space is now {:.0f} (diff: {:.2%}). "
                        .format(np.sum(new_mask),
                                1. - np.sum(new_mask) / np.sum(mask)))
                    mask = new_mask
                    last_mask_recalc_calls = n_calls_left_fn()

                # Draw n candidates at the current position (before resetting hyperparams or before reaching the limit)
                n_candidates = min(n_max_per_batch, n_calls_left_fn())

                # Calculate the projected adversarial surrogate gradient at the current position.
                #  Putting this into a ThreadPoolExecutor. While this is processing, we can already start drawing the first sample.
                # Also cancel any pending requests from previous steps.
                if pg_future is not None:
                    pg_future.cancel()
                pg_future = self.pg_thread_pool.submit(
                    self.get_projected_gradients,
                    x_current=X_adv_best,
                    x_orig=X_orig,
                    label=label,
                    is_targeted=is_targeted)

                # Arguments shared by every candidate of this batch. Neither the position nor the mask
                # changes inside the batch: the position only advances right before the batch is left.
                candidate_kwargs = {
                    "n": n_candidates,
                    "x_orig": X_orig,
                    "x_current": X_adv_best,
                    "mask": mask,
                    "source_step": source_step,
                    "spherical_step": spherical_step,
                    "pg_future": pg_future
                }

                # Also do candidate generation with a ThreadPoolExecutor.
                # Queue the first candidate.
                candidate_future = self.candidate_thread_pool.submit(
                    self.generate_candidate, i=0, **candidate_kwargs)

                for i in range(n_candidates):
                    # Get candidate and queue the next one.
                    candidate, stats = candidate_future.result()
                    if i < n_candidates - 1:
                        candidate_future = self.candidate_thread_pool.submit(
                            self.generate_candidate,
                            i=i + 1,
                            **candidate_kwargs)

                    time_elapsed = timeit.default_timer() - time_start
                    if n_seconds is not None and time_elapsed >= n_seconds:
                        print("WARN: Running out of time! Aborting attack!")
                        return X_adv_best * 255.

                    # Test if successful. NOTE: dist is rounded here!
                    self.blackbox_model.adv_set_stats(stats)
                    candidate_label, rounded_dist = self._eval_sample(
                        candidate, X_orig)
                    unrounded_dist = self.dm_main.calc(candidate, X_orig)
                    if (candidate_label == label) == is_targeted:
                        if unrounded_dist < dist_best:
                            print(
                                "@ {:.3f}: After {} samples, found something @ {:.3f} (rounded {:.3f})! (reduced by {:.1%})"
                                .format(dist_best, i, unrounded_dist,
                                        rounded_dist,
                                        1. - unrounded_dist / dist_best))

                            # Terminate this batch (don't try the other candidates) and advance.
                            X_adv_best = candidate
                            dist_best = unrounded_dist
                            break

            return X_adv_best * 255.

        finally:
            # Be safe and wait for the gradient future. We want to be sure that no BG worker is blocking the GPU before returning.
            if pg_future is not None:
                futures.wait([pg_future])

    def generate_candidate(self, i, n, x_orig, x_current, mask, source_step,
                           spherical_step, pg_future):
        """
        Generates candidate number i of a batch of n. Runs on the candidate worker thread.

        :param i: Index of the candidate inside the batch, i.e. the number of previously unsuccessful samples.
        :param n: Number of candidates in the batch.
        :param x_orig: The original image.
        :param x_current: The current adversarial position.
        :param mask: Mask that is applied to the perturbation.
        :param source_step: Base source step; scaled down with i.
        :param spherical_step: Base spherical step; scaled down with i.
        :param pg_future: Future of get_projected_gradients(). This method blocks until it has a result.
        :return: Tuple (candidate, stats); stats additionally carries "i_sample" and "mask_sum".
        """

        # This runs in a loop (while i<n) per "batch".
        # Whenever a candidate is successful, a new batch is started. Therefore, i is the number of previously unsuccessful samples.
        # Trying to use this in our favor, we progressively reduce step size for the next candidate.
        # When the batch is through, hyperparameters are reset for the next batch.

        # Scale both spherical and source step with i.
        scale = (1. - i / n) + 0.3
        c_source_step = source_step * scale
        c_spherical_step = spherical_step * scale

        # Get the adversarial projected gradient from the (other) BG worker.
        pg_factor = 0.3
        pgs = pg_future.result()
        pgs = pgs if i % 2 == 0 else None  # Only use gradient bias on every 2nd iteration, but always try it at first..

        if bench_settings.USE_PERLIN_BIAS:
            sampling_fn = self.sample_gen.get_perlin
        else:
            sampling_fn = self.sample_gen.get_normal

        candidate, stats = self.generate_boundary_sample(
            X_orig=x_orig,
            X_adv_current=x_current,
            mask=mask,
            source_step=c_source_step,
            spherical_step=c_spherical_step,
            sampling_fn=sampling_fn,
            pgs_current=pgs,
            pg_factor=pg_factor)

        stats["i_sample"] = int(i)
        stats["mask_sum"] = float(np.sum(mask))
        return candidate, stats

    def generate_boundary_sample(self,
                                 X_orig,
                                 X_adv_current,
                                 mask,
                                 source_step,
                                 spherical_step,
                                 sampling_fn,
                                 pgs_current=None,
                                 pg_factor=0.5):
        """
        Draws one candidate: an orthogonal (spherical) step, followed by a step towards the original image.

        :param X_orig: The original image.
        :param X_adv_current: The current adversarial position.
        :param mask: Mask that is multiplied onto the sampled direction (and onto the gradient, if used).
        :param source_step: Length of the step towards X_orig, relative to the current distance.
        :param spherical_step: Length of the orthogonal step, relative to the current distance.
        :param sampling_fn: Called as sampling_fn(return_stats=True); must return (direction, stats).
                            NOTE(review): the returned direction is modified in place - presumably it is a fresh array; confirm.
        :param pgs_current: Optional batch of projected gradients; one of them is picked at random. It is not modified.
        :param pg_factor: Blend weight of the gradient against the sampled direction.
        :return: Tuple (candidate as float32, clipped to [0,1]; stats as returned by sampling_fn).
        """
        # Adapted from FoolBox BoundaryAttack.

        unnormalized_source_direction = np.float32(X_orig) - np.float32(
            X_adv_current)
        source_norm = np.linalg.norm(unnormalized_source_direction)
        source_direction = unnormalized_source_direction / source_norm

        # Get perturbation from provided distribution
        sampling_dir, stats = sampling_fn(return_stats=True)

        # ===========================================================
        # calculate candidate on sphere
        # ===========================================================
        dot = np.vdot(sampling_dir, source_direction)
        sampling_dir -= dot * source_direction  # Project orthogonal to source direction
        sampling_dir *= mask  # Apply regional mask
        sampling_dir /= np.linalg.norm(
            sampling_dir)  # Norming increases magnitude of masked regions

        # If available: Bias the spherical dirs in direction of the adversarial gradient, which is projected onto the sphere
        if pgs_current is not None:
            # We have a bunch of gradients that we can try. Randomly select one.
            # NOTE: we found this to perform better than simply averaging the gradients.
            #
            # Indexing yields a view into the batch, so masking must create a new array: the batch is shared
            # by all candidates of the position and an in-place update would mask the same gradient again
            # every time it is drawn.
            pg_current = pgs_current[np.random.randint(
                0, len(pgs_current))] * mask
            pg_current /= np.linalg.norm(pg_current)

            sampling_dir = (1. -
                            pg_factor) * sampling_dir + pg_factor * pg_current
            sampling_dir /= np.linalg.norm(sampling_dir)

        sampling_dir *= spherical_step * source_norm  # Norm to length stepsize*(dist from src)

        D = 1 / np.sqrt(spherical_step**2 + 1)
        direction = sampling_dir - unnormalized_source_direction
        spherical_candidate = X_orig + D * direction

        np.clip(spherical_candidate, 0., 1., out=spherical_candidate)

        # ===========================================================
        # step towards source
        # ===========================================================
        new_source_direction = X_orig - spherical_candidate

        new_source_direction_norm = np.linalg.norm(new_source_direction)
        new_source_direction /= new_source_direction_norm
        spherical_candidate = X_orig - source_norm * new_source_direction  # Snap sph.c. onto sphere

        # From there, take a step towards the target.
        candidate = spherical_candidate + (source_step *
                                           source_norm) * new_source_direction

        np.clip(candidate, 0., 1., out=candidate)
        return np.float32(candidate), stats

    def get_projected_gradients(self, x_current, x_orig, label, is_targeted):
        """
        Calculates a batch of surrogate gradients, projected orthogonal to the source direction and normed to length 1.

        :param x_current: The current adversarial position.
        :param x_orig: The original image.
        :param label: Label that is passed to the substitute model for every sample.
        :param is_targeted: If True, the gradients are negated.
        :return: float32 array holding 4 projected gradients, or None if there is no substitute model.
        """
        # Idea is: we have a direction (spherical candidate) in which we want to sample.
        # We know that the gradient of a substitute model, projected onto the sphere, usually points to an adversarial region.
        # Even if we are already adversarial, it should point "deeper" into that region.
        # If we sample in that direction, we should move toward the center of the adversarial cone.
        # Here, we simply project the gradient onto the same hyperplane as the spherical samples.
        #
        # Instead of a single projected gradient, this method returns an entire batch of them:
        # - Surrogate gradients are unreliable, so we sample them in a region around the current position.
        # - This gives us a similar benefit as observed in "PGD with random restarts".

        if self.substitute_model is None:
            return None

        source_direction = x_orig - x_current
        source_norm = np.linalg.norm(source_direction)
        source_direction = source_direction / source_norm

        # Take a tiny step towards the source before calculating the gradient. This marginally improves our results.
        step_inside = 1e-2 * source_norm
        x_inside = x_current + step_inside * source_direction

        # Perturb the current position before calc'ing gradient
        n_samples = 4
        radius_max = 1e-2 * source_norm  # deactivated for now
        x_perturb = sample_hypersphere(n_samples=n_samples,
                                       sample_shape=x_orig.shape,
                                       radius=1,
                                       sample_gen=self.sample_gen)
        # A single random radius is drawn and applied to the whole batch.
        x_perturb *= np.random.uniform(0., radius_max)

        x_inside_batch = x_inside + x_perturb
        gradient_batch = np.empty(x_inside_batch.shape, dtype=np.float32)

        gradients = (self.substitute_model.batch_gradients(
            x_inside_batch * 255., [label] * n_samples) / 255.)
        if is_targeted:
            gradients = -gradients

        for i in range(n_samples):
            # Project the gradients.
            dot = np.vdot(gradients[i], source_direction)
            projected_gradient = gradients[
                i] - dot * source_direction  # Project orthogonal to source direction
            projected_gradient /= np.linalg.norm(
                projected_gradient)  # Norm to length 1
            gradient_batch[i] = projected_gradient

        return gradient_batch

    def _eval_sample(self, x, x_orig_normed=None):
        """
        Queries the black box model with the rounded uint8 version of x.

        :param x: Image in [0,1] space.
        :param x_orig_normed: Optional original image in [0,1] space.
        :return: The predicted label; or (label, distance of the rounded image to x_orig_normed) if x_orig_normed is given.
        """
        # Round, then get label and distance.
        x_rounded = np.round(np.clip(x * 255., 0, 255))
        preds = self.blackbox_model.predictions(np.uint8(x_rounded))
        label = np.argmax(preds)

        if x_orig_normed is None:
            return label
        else:
            dist = self.dm_main.calc(x_rounded / 255., x_orig_normed)
            return label, dist
Beispiel #6
0
    from progressbar import ProgressBar
except:
    pass
# True when running on a Python 2 interpreter.
is_python2 = sys.version_info < (3, 0)
if is_python2:
    # Best effort: pull in everything from `builtins` when that module can be
    # imported (presumably the `future` backport -- confirm); any failure is
    # ignored.
    try:
        from builtins import *
    except:
        pass

# Module-level handles. Each one stays None when its setup does not run or
# fails.
misc_backend = None
chdrft_executor = None
try:
    from concurrent.futures import ThreadPoolExecutor, as_completed
    # Shared thread pool created at import time. It is entered here and never
    # exited within this block.
    chdrft_executor = ThreadPoolExecutor(max_workers=100)
    chdrft_executor = chdrft_executor.__enter__()
except:
    # Any failure is swallowed. Depending on where it happened,
    # chdrft_executor is None or a pool that was created but not entered.
    pass

if sys.version_info >= (3, 0):
    # Python 3 only: set up a jsonpickle JSON backend that emits sorted keys
    # with an indent of 4, and prefers the 'json' module.
    misc_backend = None
    import jsonpickle
    from jsonpickle import handlers
    misc_backend = jsonpickle.backend.JSONBackend()
    misc_backend.set_encoder_options('json', sort_keys=True, indent=4)
    misc_backend.set_preferred_backend('json')

    class BinaryHandler(jsonpickle.handlers.BaseHandler):
        """jsonpickle handler that stores an object as base64 text."""
        def flatten(self, obj, data):
            """Put the base64 encoding of ``obj`` (as ASCII text) into ``data['data']``.

            ``obj`` is passed to ``base64.b64encode``, so it is presumably a
            bytes-like object -- confirm against where the handler is
            registered.

            :return: the same ``data`` dict.
            """
            data['data'] = base64.b64encode(obj).decode('ascii')
            return data
Beispiel #7
0
class ProcessingSession:
    def __init__(self, config, logger):
        self.running = True
        self.scan_finished = False
        self.reads_queued = self.reads_found = 0
        self.reads_processed = 0
        self.next_batch_id = 0
        self.reads_done = set()
        self.active_batches = 0
        self.error_status_counts = defaultdict(int)
        self.jobstack = []

        self.config = config
        self.logger = logger

        self.executor_compute = ProcessPoolExecutor(config['parallel'])
        self.executor_io = ThreadPoolExecutor(2)
        self.executor_mon = ThreadPoolExecutor(2)

        self.loop = self.fastq_writer = self.fast5_writer = \
            self.alignment_writer = self.npreaddb_writer = None
        self.dashboard = self.pbar = None

    def __enter__(self):
        """Acquire the event loop, enter the pools and open the output writers.

        Writers for FASTQ, FAST5, the nanopolish read DB and alignments are
        only created when the matching config entry is truthy; the sequencing
        summary writer and the final summary tracker are always created.

        :return: this session.
        """
        cfg = self.config

        self.loop = asyncio.get_event_loop()
        for pool in (self.executor_compute, self.executor_io,
                     self.executor_mon):
            pool.__enter__()

        # Route SIGINT/SIGTERM to stop(), passing the signal's name along.
        for signame in ('SIGINT', 'SIGTERM'):
            self.loop.add_signal_handler(
                getattr(signal, signame), self.stop, signame)

        if cfg['fastq_output']:
            self.fastq_writer = FASTQWriter(
                cfg['outputdir'], cfg['output_layout'])

        if cfg['fast5_output']:
            self.fast5_writer = FAST5Writer(
                cfg['outputdir'], cfg['output_layout'], cfg['inputdir'],
                cfg['fast5_batch_size'])

        if cfg['nanopolish_output']:
            self.npreaddb_writer = NanopolishReadDBWriter(
                cfg['outputdir'], cfg['output_layout'])

        self.seqsummary_writer = SequencingSummaryWriter(
            cfg, cfg['outputdir'], cfg['label_names'], cfg['barcode_names'])
        self.finalsummary_tracker = FinalSummaryTracker(
            cfg['label_names'], cfg['barcode_names'])

        if cfg['minimap2_index']:
            self.show_message('==> Loading a minimap2 index file')
            bam_pattern = os.path.join(cfg['outputdir'], 'bam', '{}.bam')
            self.alignment_writer = AlignmentWriter(
                cfg['minimap2_index'], bam_pattern, cfg['output_layout'])

        return self

    def __exit__(self, *args):
        if self.fastq_writer is not None:
            self.fastq_writer.close()
            self.fastq_writer = None

        if self.fast5_writer is not None:
            self.fast5_writer.close()
            self.fast5_writer = None

        if self.npreaddb_writer is not None:
            self.npreaddb_writer.close()
            self.npreaddb_writer = None

        if self.seqsummary_writer is not None:
            self.seqsummary_writer.close()
            self.seqsummary_writer = None

        if self.alignment_writer is not None:
            self.alignment_writer.close()
            self.alignment_writer = None

        self.executor_mon.__exit__(*args)
        self.executor_io.__exit__(*args)
        self.executor_compute.__exit__(*args)
        self.loop.close()

    def errx(self, message):
        if self.running:
            errprint(message, end='')
            self.stop('ERROR')

    def show_message(self, message):
        if not self.config['quiet']:
            print(message)

    def stop(self, signalname='unknown'):
        if self.running:
            if signalname in ['SIGTERM', 'SIGINT']:
                errprint("\nTermination in process. Please wait for a moment.")
            self.running = False
        for task in asyncio.Task.all_tasks():
            task.cancel()

        self.loop.stop()

    def run_in_executor_compute(self, *args):
        return self.loop.run_in_executor(self.executor_compute, *args)

    def run_in_executor_io(self, *args):
        return self.loop.run_in_executor(self.executor_io, *args)

    def run_in_executor_mon(self, *args):
        return self.loop.run_in_executor(self.executor_mon, *args)

    async def run_process_batch(self, batchid, files):
        """Process one batch on the compute pool and write out its results.

        After an optional start delay, ``process_batch`` runs in the compute
        pool. Results for reads that were already completed are dropped; the
        rest are passed to the enabled writers on the I/O pool and to the
        final summary tracker. The session is stopped through ``errx()`` when
        the worker reports an unhandled exception, when an unexpected error
        occurs here, or when the "not basecalled" early-stop condition is met.

        :param batchid: Identifier passed on to ``process_batch``.
        :param files: Work items of the batch, passed on to ``process_batch``.
        """
        # Wait until the input files become ready if needed
        if self.config['analysis_start_delay'] > 0:
            try:
                await asyncio.sleep(self.config['analysis_start_delay'])
            except CancelledError:
                return

        # Counted only after the delay; the `finally` below undoes it on
        # every path out of the `try`.
        self.active_batches += 1
        try:

            results = await self.run_in_executor_compute(
                process_batch, batchid, files, self.config)

            # A leading -1 marks a failure report from the worker:
            # results[1] is the message, results[2] a multi-line text that is
            # logged line by line.
            if len(results
                   ) > 0 and results[0] == -1:  # Unhandled exception occurred
                error_message = results[1]
                self.logger.error(error_message)
                for line in results[2].splitlines():
                    self.logger.error(line)
                self.errx("ERROR: " + error_message)
                return

            # Remove duplicated results that could be fed multiple times in live monitoring
            nd_results = []
            for result in results:
                readpath = result['filename'], result['read_id']
                if readpath not in self.reads_done:
                    # Only reads with status 'okay' are remembered as done.
                    if result['status'] == 'okay':
                        self.reads_done.add(readpath)
                    elif 'error_message' in result:
                        self.logger.error(result['error_message'])
                    nd_results.append(result)
                else:  # Cancel the duplicated result
                    self.reads_queued -= 1
                    self.reads_found -= 1

                # Statuses are counted for duplicates as well.
                self.error_status_counts[result['status']] += 1

            if nd_results:
                # The writers run one after another on the I/O pool.
                if self.config['fastq_output']:
                    await self.run_in_executor_io(
                        self.fastq_writer.write_sequences, nd_results)

                if self.config['fast5_output']:
                    await self.run_in_executor_io(
                        self.fast5_writer.transfer_reads, nd_results)

                if self.config['nanopolish_output']:
                    await self.run_in_executor_io(
                        self.npreaddb_writer.write_sequences, nd_results)

                if self.config['minimap2_index']:
                    rescounts = await self.run_in_executor_io(
                        self.alignment_writer.process, nd_results)
                    if self.dashboard is not None:
                        self.dashboard.feed_mapped(rescounts)

                await self.run_in_executor_io(
                    self.seqsummary_writer.write_results, nd_results)

                self.finalsummary_tracker.feed_results(nd_results)

            # Early stop: no read has been 'okay' so far while the number of
            # 'not_basecalled' reads has reached the configured trigger.
            if (self.error_status_counts['okay'] == 0 and self.running
                    and self.error_status_counts['not_basecalled'] >=
                    self.config['nobasecall_stop_trigger']):

                stopmsg = (
                    'Early stopping: {} out of {} reads are not basecalled. '
                    'Please check if the files are correctly analyzed, or '
                    'add `--basecall\' to the command line.'.format(
                        self.error_status_counts['not_basecalled'],
                        sum(self.error_status_counts.values())))
                self.logger.error(stopmsg)
                self.errx(stopmsg)

        except (CancelledError, BrokenProcessPool):
            return
        except Exception as exc:
            self.logger.error('Unhandled error during processing reads',
                              exc_info=exc)
            return self.errx('ERROR: Unhandled error ' + str(exc))
        finally:
            self.active_batches -= 1

        # Reached only when the `try` body ran to its end, so nd_results is
        # always bound here.
        self.reads_processed += len(nd_results)
        self.reads_queued -= len(nd_results)

    def queue_processing(self, readpath):
        self.jobstack.append(readpath)
        self.reads_queued += 1
        self.reads_found += 1
        if len(self.jobstack) >= self.config['batch_chunk_size']:
            self.flush_jobstack()

    def flush_jobstack(self):
        """Submit the pending job stack as a new processing batch.

        Does nothing unless the session is running and the stack is
        non-empty. A batch id is always consumed; the stack is emptied; reads
        already present in ``reads_done`` are dropped and subtracted from the
        queued/found counters. A task is only created when at least one read
        remains.
        """
        if not (self.running and self.jobstack):
            return

        batch_id = self.next_batch_id
        self.next_batch_id = batch_id + 1

        # Remove files already processed successfully. The same file can be
        # fed into the stack while making transition from the existing
        # files to newly updated files from the live monitoring.
        fresh_reads = []
        for candidate in self.jobstack:
            if candidate not in self.reads_done:
                fresh_reads.append(candidate)

        already_done = len(self.jobstack) - len(fresh_reads)
        if already_done:
            self.reads_queued -= already_done
            self.reads_found -= already_done
        self.jobstack.clear()

        if not fresh_reads:
            return
        self.loop.create_task(self.run_process_batch(batch_id, fresh_reads))

    async def scan_dir_recursive(self, topdir, dirname=''):
        """Scan ``topdir``/``dirname`` recursively and queue every read found.

        The directory listing is obtained through ``run_in_executor_mon``.
        Each listed file is expanded with ``get_read_ids`` and the resulting
        reads are queued; sub-directories are then scanned by recursion.
        Cancellation is swallowed at the top-level call (empty ``dirname``)
        and re-raised from nested calls. Any other listing error is reported
        through ``errx``. When the top-level call completes, the job stack is
        flushed and ``scan_finished`` is set.
        """
        if not self.running:
            return

        at_top_level = (dirname == '')
        failure = None

        try:
            dirs, files = await self.run_in_executor_mon(
                scan_dir_recursive_worker, os.path.join(topdir, dirname))
        except CancelledError:
            if at_top_level:
                return
            raise
        except Exception as exc:
            failure = str(exc)

        if failure is not None:
            return self.errx('ERROR: ' + str(failure))

        for filename in files:
            relative_file = os.path.join(dirname, filename)
            for readpath in get_read_ids(relative_file, topdir):
                self.queue_processing(readpath)

        try:
            for subdir in dirs:
                await self.scan_dir_recursive(
                    topdir, os.path.join(dirname, subdir))
        except CancelledError:
            if at_top_level:
                return
            raise

        if not at_top_level:
            return
        self.flush_jobstack()
        self.scan_finished = True

    async def live_watch_inputs(self, topdir, suffix=FAST5_SUFFIX):
        """Watch ``topdir`` with inotify and queue reads from new files.

        Events are pulled from an ``InotifyTree`` generator through
        ``run_in_executor_mon``. Directory events, hidden files (leading dot)
        and files not ending in ``suffix`` (case-insensitive) are ignored.
        Reads already present in ``reads_done`` are not queued again. The
        loop only ends on cancellation, which is swallowed.
        """
        from inotify.adapters import InotifyTree
        from inotify.constants import IN_CLOSE_WRITE, IN_MOVED_TO

        watch_flags = IN_CLOSE_WRITE | IN_MOVED_TO
        # The trailing slash is needed for the commonprefix check below.
        topdir = os.path.abspath(topdir + '/') + '/'

        def should_analyze(fn):
            return fn[:1] != '.' and fn.lower().endswith(suffix)

        try:
            events = InotifyTree(topdir, mask=watch_flags).event_gen()
            while True:
                event = await self.run_in_executor_mon(next, events)
                if event is None:
                    continue

                header, type_names, path, filename = event
                if 'IN_ISDIR' in type_names:
                    continue
                if not (header.mask & watch_flags
                        and should_analyze(filename)):
                    continue

                common = os.path.commonprefix([topdir, path])
                if common != topdir:
                    errprint(
                        "ERROR: Change of {} detected, which is outside "
                        "{}.".format(path, topdir))
                    continue

                relpath = os.path.join(path[len(common):], filename)
                for readpath in get_read_ids(relpath, topdir):
                    if readpath in self.reads_done:
                        continue
                    self.queue_processing(readpath)

        except CancelledError:
            pass

    async def wait_until_finish(self):
        """Poll every half second until all queued reads are handled.

        Returns when the session stops running, when the wait is cancelled,
        or when the scan has finished and no reads remain queued.
        """
        while self.running:
            try:
                await asyncio.sleep(0.5)
            except CancelledError:
                return

            nothing_left = self.scan_finished and self.reads_queued <= 0
            if nothing_left:
                return

    async def show_progresses_offline(self):
        """Drive a console progress bar while the session is running.

        A counter-style bar is shown first. On the first iteration that sees
        ``scan_finished`` set, it is replaced by a percentage/ETA bar whose
        maximum is ``reads_found``. On every other iteration the current bar
        is updated with ``reads_processed``. The loop sleeps 0.3 s between
        refreshes and stops on cancellation.
        """
        from progressbar import ProgressBar, widgets

        # Layout used until the scan has finished.
        scanning_layout = [
            widgets.AnimatedMarker(), ' ',
            widgets.Counter(), ' ',
            widgets.BouncingBar(), ' ',
            widgets.Timer(),
        ]

        class LooseAdaptiveETA(widgets.AdaptiveETA):
            # Stabilize the ETA when results from large batches rush in
            NUM_SAMPLES = 100

        # Layout used after the scan has finished.
        final_layout = [
            widgets.AnimatedMarker(), ' ',
            widgets.Percentage(), ' of ',
            widgets.FormatLabel('%(max)d'), ' ',
            widgets.Bar(), ' ',
            widgets.Timer('Elapsed: %s'), ' ',
            LooseAdaptiveETA(),
        ]

        self.pbar = ProgressBar(widgets=scanning_layout)
        self.pbar.start()
        switched_to_final = False

        while self.running:
            if self.scan_finished and not switched_to_final:
                switched_to_final = True
                self.pbar = ProgressBar(maxval=self.reads_found,
                                        widgets=final_layout)
                self.pbar.currval = self.reads_processed
                self.pbar.start()
            else:
                self.pbar.maxval = self.reads_found
                self.pbar.update(self.reads_processed)

            try:
                await asyncio.sleep(0.3)
            except CancelledError:
                return

    async def show_progresses_live(self):
        """Print a single-line, self-overwriting status for live mode.

        Three introductory messages are shown first. While the session runs,
        the status line is rewritten every 0.3 s whenever any of the
        processed/queued/found counters changed or batches are active. The
        line is padded with spaces so a shorter line fully overwrites the
        previous one. The loop stops on cancellation.
        """
        self.show_message('==> Entering LIVE mode.')
        self.show_message(
            '\nPress Ctrl-C when the sequencing run is finished.')
        delay_notice = (
            '(!) An analysis starts at least {} seconds after the file '
            'is discovered.'.format(self.config['analysis_start_delay']))
        self.show_message(delay_notice)

        last_counts = (-1, -1, -1)
        last_width = 0
        spinner = cycle(r'/-\|')

        while self.running:
            counts = (self.reads_processed, self.reads_queued,
                      self.reads_found)

            if counts != last_counts or self.active_batches > 0:
                line = (
                    "\rLIVE [{}] {} processed, {} queued ({} total reads)"
                    .format(next(spinner), *counts))
                # Pad to the previous width to erase leftover characters.
                line = line.ljust(last_width)

                sys.stdout.write(line)
                sys.stdout.flush()

                last_width = len(line)
                last_counts = counts

            try:
                await asyncio.sleep(0.3)
            except CancelledError:
                return

    async def force_flushing_stalled_queue(self):
        """Flush the job stack when queued reads stop making progress.

        Wakes up every ``max(10, analysis_start_delay // 2)`` seconds. If
        ``reads_found`` changed since the last wake-up the stall counter is
        reset. Otherwise, while reads are still queued, the counter is
        incremented and the job stack is flushed once it reaches two. The
        loop stops on cancellation.
        """
        last_found = -1
        interval = max(10, int(self.config['analysis_start_delay'] // 2))
        stalled_beats = 0
        beats_before_flush = 2

        while self.running:
            try:
                await asyncio.sleep(interval)
            except CancelledError:
                return

            if self.reads_found != last_found:
                # New reads arrived: not stalled.
                stalled_beats = 0
                last_found = self.reads_found
                continue

            if self.reads_queued <= 0:
                continue

            stalled_beats += 1
            if stalled_beats < beats_before_flush:
                continue
            stalled_beats = 0
            self.flush_jobstack()

    def start_dashboard(self):
        """Create, start and return the dashboard view for this session.

        Contig aliases are loaded only when both ``contig_aliases`` and
        ``minimap2_index`` are set in the configuration; otherwise an empty
        mapping is passed to the view.
        """
        from . import dashboard

        cfg = self.config
        aliases = {}
        if cfg['contig_aliases'] and cfg['minimap2_index']:
            aliases = dashboard.load_aliases(cfg['contig_aliases'])

        view = dashboard.DashboardView(
            self, cfg['barcode_names'], 'progress', 'mapped_rate',
            cfg['analysis_start_delay'], aliases)
        has_mapping_index = bool(cfg['minimap2_index'])
        view.start(self.loop, has_mapping_index)
        return view

    def terminate_executors(self):
        """Hand the compute executor to ``force_terminate_executor``."""
        compute_pool = self.executor_compute
        force_terminate_executor(compute_pool)

    def finalize_results(self):
        """Build inventory files for the optional per-part HDF5 dumps.

        For each enabled dump option (``dump_adapter_signals``,
        ``dump_basecalls``) a message is shown and the matching inventory
        builder is called with ``<outputdir>/<subdir>/inventory.h5`` and the
        ``part-*.h5`` pattern in the same directory.
        """
        if self.config['dump_adapter_signals']:
            self.show_message(
                "==> Creating an inventory for adapter signal dumps")
            dumps_dir = os.path.join(self.config['outputdir'],
                                     'adapter-dumps')
            inventory_path = os.path.join(dumps_dir, 'inventory.h5')
            parts_pattern = os.path.join(dumps_dir, 'part-*.h5')
            create_adapter_dumps_inventory(inventory_path, parts_pattern)

        if self.config['dump_basecalls']:
            self.show_message(
                "==> Creating an inventory for basecalled events")
            events_dir = os.path.join(self.config['outputdir'], 'events')
            inventory_path = os.path.join(events_dir, 'inventory.h5')
            parts_pattern = os.path.join(events_dir, 'part-*.h5')
            create_events_inventory(inventory_path, parts_pattern)

    @classmethod
    def run(kls, config, logging):
        """Run a complete processing session.

        Opens a session with ``kls(config, logging)``, schedules the
        monitoring, progress, scanning and (in live mode) watching tasks on
        the session's event loop, and drives the loop until the monitoring
        task finishes. Afterwards the executors and the dashboard are shut
        down and every task still pending on the loop is cancelled and
        drained.

        Args:
            config: Mapping of options; this method reads ``live``,
                ``quiet``, ``dashboard`` and ``inputdir``.
            logging: Logger-like object, forwarded to ``kls`` and used to
                record the traceback of unexpected errors.

        Returns:
            The ``sess.finalsummary_tracker.print_results`` attribute (it is
            returned, not called) when output is not quiet, the scan finished
            and every read found was processed; ``None`` otherwise.
        """
        def loop_was_stopped(exc):
            # A RuntimeError with this message is deliberately ignored. The
            # args are checked defensively so an argument-less or non-string
            # RuntimeError is reported instead of raising a secondary error.
            return (isinstance(exc, RuntimeError) and bool(exc.args)
                    and isinstance(exc.args[0], str)
                    and exc.args[0].startswith(
                        'Event loop stopped before Future'))

        with kls(config, logging) as sess:
            sess.show_message("==> Processing FAST5 files")

            if config['live']:
                # Start monitoring stalled queue
                mon_task = sess.loop.create_task(
                    sess.force_flushing_stalled_queue())
            else:
                # Start monitoring finished processing
                mon_task = sess.loop.create_task(sess.wait_until_finish())

            # Start a progress updater for the user
            if config['quiet']:
                pass
            elif config['dashboard']:
                sess.dashboard = sess.start_dashboard()
            elif config['live']:
                sess.loop.create_task(sess.show_progresses_live())
            else:
                sess.loop.create_task(sess.show_progresses_offline())

            # Start the directory scanner
            scanjob = sess.scan_dir_recursive(config['inputdir'])
            sess.loop.create_task(scanjob)

            # Start the directory change watcher in the live mode
            if config['live']:
                livewatcher = sess.live_watch_inputs(config['inputdir'])
                sess.loop.create_task(livewatcher)

            try:
                sess.loop.run_until_complete(mon_task)
            except CancelledError:
                errprint('\nInterrupted')
            except Exception as exc:
                if not loop_was_stopped(exc):
                    import traceback
                    errprint('\nERROR: ' + str(exc))
                    for line in traceback.format_exc().splitlines():
                        logging.error(line)

            sess.terminate_executors()
            if sess.dashboard is not None:
                sess.dashboard.stop()

            # asyncio.Task.all_tasks() was removed in Python 3.9; use the
            # module-level asyncio.all_tasks() when it exists and fall back
            # to the classmethod on older interpreters. The loop is passed
            # explicitly because no loop is running at this point.
            list_tasks = getattr(asyncio, 'all_tasks', None)
            if list_tasks is None:
                list_tasks = asyncio.Task.all_tasks

            for task in list_tasks(sess.loop):
                if task.done() or task.cancelled():
                    continue
                try:
                    try:
                        task.cancel()
                    except Exception:
                        # Best effort: still try to drain the task below.
                        pass
                    sess.loop.run_until_complete(task)
                except CancelledError:
                    errprint('\nInterrupted')
                except Exception as exc:
                    if not loop_was_stopped(exc):
                        errprint('\nERROR: ' + str(exc))

            if not config['quiet'] and sess.scan_finished:
                if sess.pbar is not None:
                    sess.pbar.finish()
                    sess.show_message('')

                if sess.reads_found == sess.reads_processed:
                    sess.finalize_results()
                    sess.show_message('==> Finished.')
                    return sess.finalsummary_tracker.print_results
                else:
                    sess.show_message('==> Terminated.')
Beispiel #8
0
class EnvironmentManager:
    """Provide container images for benchmark environments, building on demand.

    Builds run on an internal thread pool. ``buildInProgress`` maps the id of
    each environment that is currently being built to a ``Condition`` that is
    notified when that build ends (successfully or not).
    """

    def __init__(self):
        self.mutex = RLock()  # guards buildInProgress
        self.buildInProgress = {}  # env.id -> Condition notified on build end
        self.builder = ThreadPoolExecutor(max_workers=3)

    def __enter__(self):
        # NOTE(review): this yields the underlying executor, not the manager.
        # Kept as is because callers may rely on it - confirm.
        return self.builder.__enter__()

    def __exit__(self, *args, **kwargs):
        return self.builder.__exit__(*args, **kwargs)

    @staticmethod
    def _envName(env):
        """
        Return image name for given environment.
        """
        # We use database ID + 8 character prefix from Dockerfile hash in order
        # to prevent situations when database changes and local images are
        # cached
        m = hashlib.sha256()
        m.update(env.dockerfile.encode(encoding="UTF-8"))
        return f"surveyor-env-{env.id}-{m.hexdigest()[:8]}"

    def _isEnvAvailable(self, envName):
        """Return whether the local image for ``envName`` exists."""
        return podman.imageExists(f"localhost/{envName}")

    def _buildFinished(self, envId, condition):
        """Return True once ``condition`` no longer marks a build of envId."""
        with self.mutex:
            return self.buildInProgress.get(envId) is not condition

    def _finishBuild(self, envId):
        """Unregister the build of ``envId`` and wake up everyone waiting."""
        with self.mutex:
            condition = self.buildInProgress.pop(envId, None)
        # Notify outside of the mutex: waiters take the mutex while holding
        # the condition's lock, so nesting them here could deadlock.
        if condition is not None:
            with condition:
                condition.notify_all()

    def _buildContainer(self, env, onNewBuildLog):
        """
        Build container for the given environment. Return container name and
        notify about completion via Condition.

        If onNewBuildLog is passed, it gets the output line by line.

        Raises EnvironmentBuildError when podman reports a build failure.
        """
        envName = self._envName(env)
        try:
            buildLog = podman.buildImage(dockerfile=env.dockerfile, tag=envName,
                args={x.key: x.value for x in env.params},
                cpuLimit=env.cpuLimit, memLimit=env.memoryLimit,
                onOutput=onNewBuildLog,
                noCache=True) # Force rebuilding the container when it downloads external dependencies
            if buildLog is not None:
                logging.info(buildLog)
        except podman.PodmanError as e:
            raise EnvironmentBuildError(
                f"Build of environment {env.id} has failed with:\n{e.log}\n\n{e}") from e
        finally:
            self._finishBuild(env.id)
        return envName

    def getImage(self, env, onNewBuildLog=None):
        """
        Return image name of an container for given BenchmarkEnvironment. The
        name is wrapped into a future as the container might be build. If
        corresponding container is not found, it is built. If the container
        cannot be built, raises EnvironmentBuildError via the future.

        If onNewBuildLog is passed, it gets the output line by line.
        """
        envName = self._envName(env)
        with self.mutex:
            if self._isEnvAvailable(envName):
                return asFuture(envName)
            conditionVariable = self.buildInProgress.get(env.id)
            buildInProgress = conditionVariable is not None
            if not buildInProgress:
                conditionVariable = Condition()
                self.buildInProgress[env.id] = conditionVariable
        if buildInProgress:
            with conditionVariable:
                # Wait on a predicate: a bare wait() would block forever if
                # the build finished (and notified) before we got here.
                conditionVariable.wait_for(
                    lambda: self._buildFinished(env.id, conditionVariable))
            # Note that build might have failed
            if self._isEnvAvailable(envName):
                return asFuture(envName)
            # Retry and keep forwarding the build log to the caller.
            return self.getImage(env, onNewBuildLog)
        logging.info(f"Environment {env.id} not available, building it")
        try:
            return self.builder.submit(self._buildContainer, env, onNewBuildLog)
        except Exception:
            # The build never started (e.g. the pool is shut down); do not
            # leave a stale entry that later callers would wait on forever.
            self._finishBuild(env.id)
            raise
class BiasedBoundaryAttack:
    """
     Like BoundaryAttack, but uses biased sampling from prior beliefs (lucky guesses).

     Apart from Perlin Noise and projected gradients, this implementation contains more work that is not in the paper:
     - We try additional patterns (single-pixel modification, jitter patterns) to escape local minima whenever the attack gets stuck
     - We dynamically tune hyperparameters according to the observed success of previous samples
     - At each step, multiple gradients are calculated to counter stochastic defenses
     - Optimized for speed: only use gradients if we can't progress without them.

    """
    def __init__(self, blackbox_model, sample_gen, substitute_model=None):
        """
        Creates a reusable instance.
        :param blackbox_model: The model to attack.
        :param sample_gen: Random sample generator.
        :param substitute_model: A surrogate model for gradients - either a TensorFlowModel, BatchTensorFlowModel or EnsembleTFModel.
        """

        self.blackbox_model = blackbox_model
        self.sample_gen = sample_gen

        self._jitter_mask = self.precalc_jitter_mask()

        # A substitute model that provides batched gradients.
        self.batch_sub_model = None
        if substitute_model is not None:
            if isinstance(substitute_model, foolbox.models.TensorFlowModel):
                self.batch_sub_model = BatchTensorflowModel(
                    substitute_model._images,
                    substitute_model._batch_logits,
                    session=substitute_model.session)
            else:
                assert isinstance(substitute_model,
                                  EnsembleTFModel) or isinstance(
                                      substitute_model, BatchTensorflowModel)
                self.batch_sub_model = substitute_model

        # We use ThreadPools to calculate candidates and surrogate gradients while we're waiting for the model's next prediction.
        self.pg_thread_pool = ThreadPoolExecutor(max_workers=1)
        self.candidate_thread_pool = ThreadPoolExecutor(max_workers=1)

    def __enter__(self):
        self.pg_thread_pool.__enter__()
        self.candidate_thread_pool.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Will block until the futures are calculated. Thankfully they're not very complicated.
        self.pg_thread_pool.__exit__(exc_type, exc_value, traceback)
        self.candidate_thread_pool.__exit__(exc_type, exc_value, traceback)
        print("BiasedBoundaryAttack: all threads stopped.")

    def run_attack(self,
                   X_orig,
                   label,
                   is_targeted,
                   X_start,
                   n_calls_left_fn,
                   n_max_per_batch=50,
                   n_seconds=None,
                   source_step=1e-2,
                   spherical_step=1e-2,
                   give_up_calls_left=0,
                   give_up_dist=9999):
        """
        Runs the Biased Boundary Attack against a single image.
        The attack terminates when n_calls_left_fn() returns 0, n_seconds have elapsed, or a "give up" condition is reached.

        Give-up functionality:
        - When few calls are remaining, but the distance is still high. Could use the additional time for other images.
        - Could theoretically be used to game the final score: spend more time on imgs that will reduce the median, and give up on others
        - Largely unused (didn't get to finish this)

        :param X_orig: The original (clean) image to perturb.
        :param label: The target label (if targeted), or the original label (if untargeted).
        :param is_targeted: True if targeted.
        :param X_start: The starting point (must be of target class).
        :param n_calls_left_fn: A function that returns the currently remaining number of queries against the model.
        :param n_max_per_batch: How many samples are drawn per "batch". Samples are processed serially (the challenge doesn't allow
                                batching), but for each "batch", the attack dynamically adjusts hyperparams based on the success of
                                previous samples. This "batch" size is the max number of samples after which hyperparams are reset, and
                                a new "batch" is started. See generate_candidate().
        :param n_seconds: Maximum seconds allowed for the attack to complete.
        :param source_step: source step hyperparameter (see Boundary Attack)
        :param spherical_step: orthogonal step hyperparameter (see Boundary Attack)
        :param give_up_calls_left: give-up condition: if less than this number of calls is left
        :param give_up_dist: give-up condition: if the current L2 distance is higher than this
        :return: The best adversarial example so far.
        """

        assert len(X_orig.shape) == 3
        assert len(X_start.shape) == 3
        assert X_orig.dtype == np.float32

        time_start = timeit.default_timer()

        pg_future = None
        try:
            # WARN: Inside this function, image space is normed to [0,1]!
            X_orig = np.float32(X_orig) / 255.
            X_start = np.float32(X_start) / 255.

            label_current, dist_best = self._eval_sample(X_start, X_orig)
            if (label_current == label) != is_targeted:
                print(
                    "WARN: Starting point is not a valid adversarial example! Continuing for now."
                )

            X_adv_best = np.copy(X_start)

            # Abort if we're running out of queries
            while n_calls_left_fn() > 3:

                # Determine how many samples to draw at the current position.
                n_candidates = min(n_max_per_batch, n_calls_left_fn())

                # Calculate the projected adversarial gradient at the current position.
                #  Putting this into a ThreadPoolExecutor. While this is processing, we can already draw ~2 samples without waiting for the
                #  gradient. If the first 2 samples were unsuccessful, then the later ones can be biased with the gradient.
                # Also cancel any pending requests from previous steps.
                if pg_future is not None:
                    pg_future.cancel()
                pg_future = self.pg_thread_pool.submit(
                    self.get_projected_gradients, **{
                        "x_current": X_adv_best,
                        "x_orig": X_orig,
                        "label": label,
                        "is_targeted": is_targeted
                    })

                # Also do candidate generation with a ThreadPoolExecutor. We need to squeeze out every bit of runtime.
                # Queue the first candidate.
                candidate_future = self.candidate_thread_pool.submit(
                    self.generate_candidate, **{
                        "i": 0,
                        "n": n_candidates,
                        "x_orig": X_orig,
                        "x_current": X_adv_best,
                        "source_step": source_step,
                        "spherical_step": spherical_step,
                        "pg_future": pg_future
                    })

                for i in range(n_candidates):
                    # Get candidate and queue the next one.
                    candidate = candidate_future.result()
                    if i < n_candidates - 1:
                        candidate_future = self.candidate_thread_pool.submit(
                            self.generate_candidate, **{
                                "i": i + 1,
                                "n": n_candidates,
                                "x_orig": X_orig,
                                "x_current": X_adv_best,
                                "source_step": source_step,
                                "spherical_step": spherical_step,
                                "pg_future": pg_future
                            })

                    time_elapsed = timeit.default_timer() - time_start
                    if n_seconds is not None and time_elapsed >= n_seconds:
                        print("WARN: Running out of time! Aborting attack!")
                        return X_adv_best * 255.

                    if dist_best > give_up_dist and n_calls_left_fn(
                    ) < give_up_calls_left:
                        print(
                            "Distance is way too high, aborting attack to save time."
                        )
                        return X_adv_best * 255.

                    # Test if successful. NOTE: dist is rounded here!
                    candidate_label, rounded_dist = self._eval_sample(
                        candidate, X_orig)
                    unrounded_dist = np.linalg.norm(candidate - X_orig)
                    if (candidate_label == label) == is_targeted:
                        if unrounded_dist < dist_best:
                            print(
                                "@ {:.3f}: After {} samples, found something @ {:.3f} (rounded {:.3f})! (reduced by {:.1%})"
                                .format(dist_best, i, unrounded_dist,
                                        rounded_dist,
                                        1. - rounded_dist / dist_best))

                            # Terminate this batch (don't try the other candidates) and advance.
                            X_adv_best = candidate
                            dist_best = unrounded_dist
                            break

            return X_adv_best * 255.

        finally:
            # Be safe and wait for the gradient future. We want to be sure that no BG worker is blocking the GPU before returning.
            if pg_future is not None:
                futures.wait([pg_future])

    def generate_candidate(self, i, n, x_orig, x_current, source_step,
                           spherical_step, pg_future):

        # This runs in a loop (while i<n) per "batch".
        # Whenever a candidate is successful, a new batch is started. Therefore, i is the number of previously unsuccessful samples.
        # Trying to use this in our favor, we tune our hyperparameters based on i:
        # - As i gets higher, progressively reduce step size for the next candidate
        # - When i gets high, try to blend jitter patterns and single pixels

        # Try this only once: blend a jitter pattern that brings us closer to the source,
        # but should be invisible to the defender (if they use denoising).
        if i == int(0.7 * n):
            candidate = x_current
            fade_eps = 0.005
            while np.sum(
                    np.abs(
                        np.round(candidate * 255.) -
                        np.round(x_current * 255.))) < 0.0001:
                #print("jitter at i={} with fade_eps={}".format(i, fade_eps))
                candidate = self.generate_jitter_sample(x_orig,
                                                        x_current,
                                                        fade_eps=fade_eps)
                fade_eps += 0.005
            return candidate

        # Last resort: change single pixels to rip us out of the local minimum.
        i_pixel_start = int(0.9 * n)
        if i >= i_pixel_start:
            l0_pixel_index = i - i_pixel_start
            #print("pixel at {}".format(l0_pixel_index))
            candidate = self.generate_l0_sample(x_orig,
                                                x_current,
                                                n_px_to_change=1,
                                                px_index=l0_pixel_index)
            return candidate

        # Default: use the BBA. Scale both spherical and source step with i.
        scale = (1. - i / n) + 0.3
        c_source_step = source_step * scale
        c_spherical_step = spherical_step * scale

        # Get the adversarial projected gradient from the (other) BG worker.
        #  Create the first 2 candidates without it, so we can already start querying the model. The BG worker can finish the gradients
        #  while we're waiting for those first 2 results.
        pg_factor = 0.5
        pgs = None
        if i >= 2:
            # if pg_future.running():
            #     print("Waiting for gradients...")
            pgs = pg_future.result()
        pgs = pgs if i % 2 == 0 else None  # Only use gradient bias on every 2nd iteration.

        candidate, spherical_candidate = self.generate_boundary_sample(
            X_orig=x_orig,
            X_adv_current=x_current,
            source_step=c_source_step,
            spherical_step=c_spherical_step,
            sampling_fn=self.sample_gen.get_perlin,
            pgs_current=pgs,
            pg_factor=pg_factor)

        return candidate

    def generate_l0_sample(self, X_orig, X_aex, n_px_to_change=1, px_index=0):
        """
        Return a copy of X_aex in which some values are reset to those of X_orig.

        Values are ranked by absolute difference to X_orig (largest first); the n_px_to_change entries starting at rank
        px_index are replaced. X_aex itself is not modified.
        """
        # Modified copypasta from refinement_tricks.refine_jitter().
        # TODO: try color-triples?

        # Flat indices, ordered by descending difference to the original.
        ranked_flat = np.argsort(np.abs(X_aex - X_orig), axis=None)[::-1]
        chosen_flat = ranked_flat[px_index:px_index + n_px_to_change]

        # Convert all chosen flat indices to coordinates at once and replace them in one assignment.
        coords = np.unravel_index(chosen_flat, X_orig.shape)
        X_candidate = X_aex.copy()
        X_candidate[coords] = X_orig[coords]

        return X_candidate

    def precalc_jitter_mask(self, shape=(64, 64, 3), jitter_width=5):
        """
        Build the boolean jitter mask used by generate_jitter_sample().

        An entry (i, j, :) is True when exactly one of "i is a multiple of jitter_width" and "j is a multiple of
        jitter_width" holds (XOR); the value is the same across the last axis.

        :param shape: Shape (rows, cols, channels) of the mask. Defaults to the previously hard-coded (64, 64, 3).
        :param jitter_width: Spacing of the marked rows/columns. Defaults to the previously hard-coded 5.
        :return: Boolean ndarray of the given shape.
        """
        # Prepare a jitter mask with XOR (alternating). TODO: we could really improve this pattern. S&P noise, anyone?
        # NOTE: uses the builtin bool as dtype; the np.bool alias was removed in NumPy 1.24.
        row_hit = np.arange(shape[0]) % jitter_width == 0
        col_hit = np.arange(shape[1]) % jitter_width == 0
        grid = row_hit[:, np.newaxis] ^ col_hit[np.newaxis, :]

        jitter_mask = np.empty(shape, dtype=bool)
        jitter_mask[...] = grid[:, :, np.newaxis]  # Same pattern on every channel
        return jitter_mask

    def generate_jitter_sample(self, X_orig, X_aex, fade_eps=0.01):
        # Modified copypasta from refinement_tricks.refine_pixels().

        jitter_mask = self._jitter_mask

        jitter_diff = np.zeros(X_orig.shape, dtype=np.float32)
        jitter_diff[jitter_mask] = (X_aex - X_orig)[jitter_mask]

        X_candidate = X_aex - fade_eps * jitter_diff
        return X_candidate

    def generate_boundary_sample(self,
                                 X_orig,
                                 X_adv_current,
                                 source_step,
                                 spherical_step,
                                 sampling_fn,
                                 pgs_current=None,
                                 pg_factor=0.3):
        """
        Draw one Boundary Attack candidate around the current adversarial point.

        First an orthogonal ("spherical") step is taken: the direction from sampling_fn() is projected orthogonal to
        the source direction, optionally blended with a randomly chosen projected gradient, scaled to length
        spherical_step * (distance to X_orig) and applied so that the result stays at the same distance from X_orig.
        Then a step of relative size source_step is taken towards X_orig. Both results are clipped to [0, 1].

        :param X_orig: The original image (normed to [0, 1]).
        :param X_adv_current: The current adversarial point (normed to [0, 1]).
        :param source_step: Relative step size towards X_orig.
        :param spherical_step: Relative step size orthogonal to the source direction.
        :param sampling_fn: Callable without arguments returning the raw perturbation direction. NOTE: the returned
                            array is modified in place.
        :param pgs_current: Optional sequence of projected gradients; one is picked at random to bias the direction.
        :param pg_factor: Blend weight of the gradient (0 = ignore gradient, 1 = gradient only).
        :return: Tuple (candidate, spherical_candidate), both float32.
        """
        # Partially adapted from FoolBox BoundaryAttack.

        unnormalized_source_direction = np.float64(X_orig) - np.float64(
            X_adv_current)
        source_norm = np.linalg.norm(unnormalized_source_direction)
        source_direction = unnormalized_source_direction / source_norm

        # Get perturbation from provided distribution
        sampling_dir = sampling_fn()

        # ===========================================================
        # calculate candidate on sphere
        # ===========================================================
        dot = np.vdot(sampling_dir, source_direction)
        sampling_dir -= dot * source_direction  # Project orthogonal to source direction
        sampling_dir /= np.linalg.norm(sampling_dir)

        # If available: Bias the spherical dirs in direction of the adversarial gradient, which is projected onto the sphere
        if pgs_current is not None:

            # We have a bunch of gradients that we can try. Randomly select one.
            # NOTE: we found this to perform better than simply averaging the gradients.
            pg_current = pgs_current[np.random.randint(0, len(pgs_current))]

            sampling_dir = (1. -
                            pg_factor) * sampling_dir + pg_factor * pg_current
            sampling_dir /= np.linalg.norm(sampling_dir)

        # Norm to length stepsize*(dist from src). This must happen with AND without gradient bias: the factor D
        # below only puts the candidate back onto the sphere if the perturbation has exactly this length.
        sampling_dir *= spherical_step * source_norm

        D = 1 / np.sqrt(spherical_step**2 + 1)
        direction = sampling_dir - unnormalized_source_direction
        spherical_candidate = X_orig + D * direction

        np.clip(spherical_candidate, 0., 1., out=spherical_candidate)

        # ===========================================================
        # step towards source
        # ===========================================================
        new_source_direction = X_orig - spherical_candidate
        new_source_direction_norm = np.linalg.norm(new_source_direction)

        # length if spherical_candidate would be exactly on the sphere
        length = source_step * source_norm
        # length including correction for deviation from sphere
        deviation = new_source_direction_norm - source_norm
        length += deviation

        # make sure the step size is positive
        length = max(0, length)

        # normalize the length
        length = length / new_source_direction_norm

        candidate = spherical_candidate + length * new_source_direction
        np.clip(candidate, 0., 1., out=candidate)

        return np.float32(candidate), np.float32(spherical_candidate)

    def get_projected_gradients(self, x_current, x_orig, label, is_targeted):
        """Return a batch of surrogate gradients projected orthogonal to the source direction.

        Idea: we have a direction (spherical candidate) in which we want to sample.
        The gradient of a substitute model, projected onto the sphere, usually points
        to an adversarial region. Even if we are already adversarial, it should point
        "deeper" into that region, so sampling in that direction should move us toward
        the center of the adversarial cone. Here, the gradient is simply projected onto
        the same hyperplane as the spherical samples.

        Instead of a single projected gradient, an entire batch is returned:
        - Surrogate gradients are unreliable, so they are sampled in a region around
          the current position.
        - This gives a similar benefit as observed in "PGD with random restarts".

        Args:
            x_current: Current position. Must differ from ``x_orig``, because the
                direction between the two is divided by its norm.
            x_orig: Original (source) sample; its ``shape`` is used for the perturbations.
            label: Label handed to the surrogate model, repeated once per sample.
            is_targeted: If truthy, the surrogate gradients are negated.

        Returns:
            The gradient batch produced by ``self.batch_sub_model.gradient``, where each
            entry has been replaced by its projection orthogonal to the source direction,
            scaled to unit length. An entry whose projection has zero norm is returned as
            a zero vector rather than NaN.
        """
        source_direction = x_orig - x_current
        source_norm = np.linalg.norm(source_direction)
        source_direction = source_direction / source_norm

        # Take a tiny step towards the source before calculating the gradient. This marginally improves our results.
        step_inside = 0.002 * source_norm
        x_inside = x_current + step_inside * source_direction

        # Perturb the current position before calc'ing gradients.
        # NOTE(review): sample_hypersphere is defined elsewhere; presumably it returns
        # n_samples arrays of shape x_orig.shape - confirm against its definition.
        n_samples = 8
        radius_max = 0.01 * source_norm
        x_perturb = sample_hypersphere(n_samples=n_samples,
                                       sample_shape=x_orig.shape,
                                       radius=1,
                                       sample_gen=self.sample_gen)
        # A single random radius is drawn here and shared by all perturbations of the batch.
        x_perturb *= np.random.uniform(0., radius_max)

        x_inside_batch = x_inside + x_perturb

        # The surrogate is queried at 255x scale; its gradient is scaled back by 1/255.
        gradients = (self.batch_sub_model.gradient(x_inside_batch * 255.,
                                                   [label] * n_samples) / 255.)
        if is_targeted:
            gradients = -gradients

        # Project the gradients.
        for i in range(n_samples):
            dot = np.vdot(gradients[i], source_direction)
            projected_gradient = gradients[
                i] - dot * source_direction  # Project orthogonal to source direction
            projected_norm = np.linalg.norm(projected_gradient)
            # A zero gradient (or one exactly parallel to the source direction) has no
            # orthogonal component. Dividing by its zero norm would fill the entry with
            # NaN, so keep the zero vector instead.
            if projected_norm > 0:
                projected_gradient /= projected_norm  # Norm to length 1
            gradients[i] = projected_gradient
        return gradients

    def _eval_sample(self, x, x_orig_normed=None):
        """Query the black-box model on the 8-bit rounded version of ``x``.

        ``x`` is scaled by 255, clipped to [0, 255], rounded and cast to uint8 before
        being handed to ``self.blackbox_model.predictions``.

        Args:
            x: Sample to evaluate.
            x_orig_normed: Optional reference sample. When given, the distance of the
                rounded sample (rescaled by 1/255) to this reference is returned as well.

        Returns:
            The argmax of the predictions if ``x_orig_normed`` is None, otherwise a
            ``(label, dist)`` tuple.
        """
        # Quantise first, so both label and distance refer to what the model really sees.
        quantised = np.round(np.clip(x * 255., 0, 255))
        scores = self.blackbox_model.predictions(np.uint8(quantised))
        label = np.argmax(scores)

        if x_orig_normed is not None:
            offset = quantised / 255. - x_orig_normed
            return label, np.linalg.norm(offset)
        return label