Example #1
    def __init__(self, processes=1):
        """
        Constructor

        Parameters
        ----------
        processes: int (default=1)
            Number of processes to start
        """
        self.__manager = multiprocessing.managers.SyncManager()
        self.__manager.start()
        self.__shared_objects = self.__manager.dict({})
        self.__pool = MemmappingPool(processes=processes)
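
Example #1 wraps joblib's MemmappingPool, a subclass of
multiprocessing.pool.Pool that memory-maps large numpy array arguments
instead of pickling them for each call. A minimal sketch of that pool on its
own, assuming a joblib version that provides joblib.pool.MemmappingPool (as
these examples do); the double() helper is made up for illustration:

import numpy as np
from joblib.pool import MemmappingPool

def double(a):
    # Large array arguments arrive as read-only memmaps rather than
    # pickled copies, which avoids per-call serialisation overhead.
    return 2 * a.sum()

if __name__ == '__main__':
    pool = MemmappingPool(processes=2, temp_folder='/tmp')
    big = np.ones(10 ** 7)  # large enough to trigger memmapping
    print(pool.apply_async(double, (big,)).get())
    pool.terminate()
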
Example #2
    def initialize(self, n_parallel):
        self.n_parallel = n_parallel
        if self.pool is not None:
            print("Warning: terminating existing pool")
            self.pool.terminate()
            self.queue.close()
            self.worker_queue.close()
            self.G = SharedGlobal()
        if n_parallel > 1:
            self.queue = mp.Queue()
            self.worker_queue = mp.Queue()
            self.pool = MemmappingPool(
                self.n_parallel,
                temp_folder="/tmp",
            )
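
For context, a hedged sketch of driving this initialize() method, assuming
the StatefulPool container shown in Examples #5 and #6:

pool = StatefulPool()
pool.initialize(n_parallel=4)  # creates the queues and a 4-process pool
pool.initialize(n_parallel=2)  # warns, terminates, and rebuilds the pool
# With n_parallel=1 no pool is created and work runs in-process.
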
Example #3
    def initialize(self, n_parallel):
        self.n_parallel = n_parallel
        if self.pool is not None:
            print('Warning: terminating existing pool')
            self.pool.terminate()
            self.queue.close()
            self.worker_queue.close()
            self.G = SharedGlobal()
        if n_parallel > 1:
            self.manager = mp.Manager()
            self.queue = mp.Queue()
            self.worker_queue = mp.Queue()
            self.pool = MemmappingPool(
                self.n_parallel,
                temp_folder='/tmp',
            )
        self.initialized = True
Example #4
class WorkerProcessPool(AbstractWorkerPool):
    """
    A pool of worker processes.

    Call the `work_stream` method to create a `WorkStream` instance that can
    be used to perform tasks in a separate process.

    The work stream is given a generator that yields tasks to be executed
    in the pool of processes. The work stream tries to keep a buffer of
    results from those tasks full; retrieving a result causes it to top up
    the buffer as necessary.
    """
    def __init__(self, processes=1):
        """
        Constructor

        Parameters
        ----------
        processes: int (default=1)
            Number of processes to start
        """
        self.__manager = multiprocessing.managers.SyncManager()
        self.__manager.start()
        self.__shared_objects = self.__manager.dict({})
        self.__pool = MemmappingPool(processes=processes)

    def shared_constant(self, x):
        """
        Create a shared constant; a constant value or object that is shared
        between workers. Parameters that are sent to a worker often have to
        be sent once for each job (e.g. mini-batch). If a large object is
        sent across processes, the pickling and unpickling can impose an
        overhead sufficient to eliminate the benefits of parallel processing.
        Shared constants are sent once and cached by the workers, reducing
        this overhead.

        Parameters
        ----------
        x
            The shared constant value

        Returns
        -------
        The wrapped shared constant
        """
        return _SharedConstant(x)

    def _apply_async(self, fn, args):
        serialised_args = _serialise_args(self.__shared_objects, args)
        return self.__pool.apply_async(_apply_async_helper_proc,
                                       (self.__shared_objects, fn) +
                                       serialised_args)
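
The point of shared_constant() is to pay the pickling cost for a large
object once per worker instead of once per job. A hedged sketch of its
intended use; the dataset and task function are made up for illustration,
and the work_stream() submission API mentioned in the class docstring is
not shown in this excerpt:

import numpy as np

def batch_mean(data, batch_indices):
    # Workers receive the cached constant, not a fresh copy per job.
    return data[batch_indices].mean()

pool = WorkerProcessPool(processes=4)
dataset = np.random.rand(100000, 128)        # large, expensive to pickle
shared_data = pool.shared_constant(dataset)  # shipped to each worker once
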
Example #5
class StatefulPool:
    def __init__(self):
        self.n_parallel = 1
        self.pool = None
        self.queue = None
        self.worker_queue = None
        self.G = SharedGlobal()
        self.manager = None
        self.initialized = False

    def initialize(self, n_parallel):
        self.n_parallel = n_parallel
        if self.pool is not None:
            print('Warning: terminating existing pool')
            self.pool.terminate()
            self.pool = None
            self.queue.close()
            self.worker_queue.close()
            self.G = SharedGlobal()
        if n_parallel > 1:
            self.manager = mp.Manager()
            self.queue = mp.Queue()
            self.worker_queue = mp.Queue()
            self.pool = MemmappingPool(
                self.n_parallel,
                temp_folder='/tmp',
            )
        self.initialized = True

    def close(self):
        if self.manager:
            self.manager.shutdown()
        if self.pool:
            self.pool.close()

    def run_each(self, runner, args_list=None):
        """
        Run the method on each worker process, and collect the result of
        execution.

        The runner method will receive 'g' as its first argument, followed
        by the arguments in the corresponding args_list entry, if any.

        :return: a list with one result per worker
        """
        assert not inspect.ismethod(runner), (
            'run_each() cannot run a class method. Please ensure that runner '
            'is a function with the prototype def foo(g, ...), where g is an '
            'object of type metarl.sampler.stateful_pool.SharedGlobal')

        if args_list is None:
            args_list = [tuple()] * self.n_parallel
        assert len(args_list) == self.n_parallel
        if self.n_parallel > 1:
            results = self.pool.map_async(_worker_run_each,
                                          [(runner, args)
                                           for args in args_list])
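            # Handshake: each worker announces itself on worker_queue; once
            # all n_parallel workers have checked in, they are released
            # together via queue, so every process runs the task exactly once.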
            for i in range(self.n_parallel):
                self.worker_queue.get()
            for i in range(self.n_parallel):
                self.queue.put(None)
            return results.get()
        return [runner(self.G, *args_list[0])]

    def run_map(self, runner, args_list):
        assert not inspect.ismethod(runner), (
            'run_map() cannot run a class method. Please ensure that runner '
            "is a function with the prototype 'def foo(g, ...)', where g is "
            'an object of type metarl.sampler.stateful_pool.SharedGlobal')

        if self.n_parallel > 1:
            return self.pool.map(_worker_run_map,
                                 [(runner, args) for args in args_list])
        else:
            ret = []
            for args in args_list:
                ret.append(runner(self.G, *args))
            return ret

    def run_imap_unordered(self, runner, args_list):
        assert not inspect.ismethod(runner), (
            'run_imap_unordered() cannot run a class method. Please ensure '
            "that runner is a function with the prototype 'def foo(g, ...)', "
            'where g is an object of type '
            'metarl.sampler.stateful_pool.SharedGlobal')

        if self.n_parallel > 1:
            for x in self.pool.imap_unordered(_worker_run_map,
                                              [(runner, args)
                                               for args in args_list]):
                yield x
        else:
            for args in args_list:
                yield runner(self.G, *args)

    def run_collect(self,
                    collect_once,
                    threshold,
                    args=None,
                    show_prog_bar=True):
        """
        Run the collector method using the worker pool. The collect_once method
        will receive 'g' as its first argument, followed by the provided args,
        if any. The method should return a pair of values. The first should be
        the object to be collected, and the second is the increment to be
        added.
        This will continue until the total increment reaches or exceeds the
        given threshold.

        Sample script:

        def collect_once(g):
            return 'a', 1

        stateful_pool.run_collect(collect_once, threshold=3)
        # should return ['a', 'a', 'a']

        :param collect_once: the collector function described above
        :param threshold: total increment at which collection stops
        :return: the list of collected objects
        """
        assert not inspect.ismethod(collect_once), (
            'run_collect() cannot run a class method. Please ensure that '
            "collect_once is a function with the prototype 'def foo(g, ...)', "
            'where g is an object of type '
            'metarl.sampler.stateful_pool.SharedGlobal')

        if args is None:
            args = tuple()
        if self.pool:
            counter = self.manager.Value('i', 0)
            lock = self.manager.RLock()
            results = self.pool.map_async(
                _worker_run_collect,
                [(collect_once, counter, lock, threshold, args)] *
                self.n_parallel)
            if show_prog_bar:
                pbar = ProgBarCounter(threshold)
            last_value = 0
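            # Poll the shared counter under its lock, topping up the
            # progress bar, until the workers push it past the threshold.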
            while True:
                time.sleep(0.1)
                with lock:
                    if counter.value >= threshold:
                        if show_prog_bar:
                            pbar.stop()
                        break
                    if show_prog_bar:
                        pbar.inc(counter.value - last_value)
                    last_value = counter.value
            return sum(results.get(), [])
        else:
            count = 0
            results = []
            if show_prog_bar:
                pbar = ProgBarCounter(threshold)
            while count < threshold:
                result, inc = collect_once(self.G, *args)
                results.append(result)
                count += inc
                if show_prog_bar:
                    pbar.inc(inc)
            if show_prog_bar:
                pbar.stop()
            return results
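
Putting Example #5's API together, a hedged end-to-end sketch; the runner
and collector functions must be module-level so the workers can unpickle
them by name:

def set_seed(g, seed):
    g.seed = seed  # runs exactly once on each worker

def collect_once(g):
    return 'a', 1  # (collected object, increment)

stateful_pool = StatefulPool()
stateful_pool.initialize(n_parallel=4)
stateful_pool.run_each(set_seed, [(i,) for i in range(4)])
samples = stateful_pool.run_collect(collect_once, threshold=3)
# samples == ['a', 'a', 'a'] (workers may overshoot the threshold slightly)
stateful_pool.close()
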
Example #6
class StatefulPool(object):
    def __init__(self):
        self.n_parallel = 1
        self.pool = None
        self.queue = None
        self.worker_queue = None
        self.G = SharedGlobal()

    def initialize(self, n_parallel):
        self.n_parallel = n_parallel
        if self.pool is not None:
            print("Warning: terminating existing pool")
            self.pool.terminate()
            self.queue.close()
            self.worker_queue.close()
            self.G = SharedGlobal()
        if n_parallel > 1:
            self.queue = mp.Queue()
            self.worker_queue = mp.Queue()
            self.pool = MemmappingPool(
                self.n_parallel,
                temp_folder="/tmp",
            )

    def run_each(self, runner, args_list=None):
        """
        Run the method on each worker process, and collect the result of execution.
        The runner method will receive 'G' as its first argument, followed by the arguments
        in the args_list, if any
        :return:
        """
        if args_list is None:
            args_list = [tuple()] * self.n_parallel
        assert len(args_list) == self.n_parallel
        if self.n_parallel > 1:
            results = self.pool.map_async(_worker_run_each,
                                          [(runner, args)
                                           for args in args_list])
            for i in range(self.n_parallel):
                self.worker_queue.get()
            for i in range(self.n_parallel):
                self.queue.put(None)
            return results.get()
        return [runner(self.G, *args_list[0])]

    def run_map(self, runner, args_list):
        if self.n_parallel > 1:
            return self.pool.map(_worker_run_map,
                                 [(runner, args) for args in args_list])
        else:
            ret = []
            for args in args_list:
                ret.append(runner(self.G, *args))
            return ret

    def run_imap_unordered(self, runner, args_list):
        if self.n_parallel > 1:
            for x in self.pool.imap_unordered(_worker_run_map,
                                              [(runner, args)
                                               for args in args_list]):
                yield x
        else:
            for args in args_list:
                yield runner(self.G, *args)

    def run_collect(self,
                    collect_once,
                    threshold,
                    args=None,
                    show_prog_bar=True,
                    multi_task=False):
        """
        Run the collector method using the worker pool. The collect_once method will receive 'G' as
        its first argument, followed by the provided args, if any. The method should return a pair of values.
        The first should be the object to be collected, and the second is the increment to be added.
        This will continue until the total increment reaches or exceeds the given threshold.

        Sample script:

        def collect_once(G):
            return 'a', 1

        stateful_pool.run_collect(collect_once, threshold=3) # => ['a', 'a', 'a']

        :param collector:
        :param threshold:
        :return:
        """
        if args is None:
            args = tuple()
        if self.pool and multi_task:
            manager = mp.Manager()
            counter = manager.Value('i', 0)
            lock = manager.RLock()

            inputs = [(collect_once, counter, lock, threshold, arg)
                      for arg in args]
            results = self.pool.map_async(
                _worker_run_collect,
                inputs,
            )
            if show_prog_bar:
                pbar = ProgBarCounter(threshold)
            last_value = 0
            while True:
                time.sleep(0.1)
                with lock:
                    if counter.value >= threshold:
                        if show_prog_bar:
                            pbar.stop()
                        break
                    if show_prog_bar:
                        pbar.inc(counter.value - last_value)
                    last_value = counter.value
            finished_results = results.get()
            # TODO - for some reason this is buggy.
            return {
                i: finished_results[i]
                for i in range(len(finished_results))
            }
        elif multi_task:
            raise NotImplementedError(
                'multi_task requires an initialized worker pool')
        elif self.pool:
            manager = mp.Manager()
            counter = manager.Value('i', 0)
            lock = manager.RLock()
            results = self.pool.map_async(
                _worker_run_collect,
                [(collect_once, counter, lock, threshold, args)] *
                self.n_parallel)
            if show_prog_bar:
                pbar = ProgBarCounter(threshold)
            last_value = 0
            while True:
                time.sleep(0.1)
                with lock:
                    if counter.value >= threshold:
                        if show_prog_bar:
                            pbar.stop()
                        break
                    if show_prog_bar:
                        pbar.inc(counter.value - last_value)
                    last_value = counter.value
            return sum(results.get(), [])
        else:
            count = 0
            results = []
            if show_prog_bar:
                pbar = ProgBarCounter(threshold)
            while count < threshold:
                result, inc = collect_once(self.G, *args)
                results.append(result)
                count += inc
                if show_prog_bar:
                    pbar.inc(inc)
            if show_prog_bar:
                pbar.stop()
            return results
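
Example #6 differs from #5 mainly in the multi_task branch: there, args is
treated as a list of per-task argument tuples, one task per entry, and the
results come back as a dict keyed by task index (note the author's own TODO
flags this path as buggy). A hedged sketch, with made-up per-task arguments:

def collect_once(G, task_id):
    return ('sample', task_id), 1  # (collected object, increment)

pool = StatefulPool()
pool.initialize(n_parallel=2)
per_task_args = [(0,), (1,)]  # one argument tuple per task
results = pool.run_collect(collect_once, threshold=2,
                           args=per_task_args, multi_task=True)
# results is a dict of the form {0: [...], 1: [...]}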