Example 1
def _iter_zmq(self):
    with Socket(self.receiver, zmq.PULL, 'bind') as socket:
        backurl = 'tcp://%s:%s' % (config.dbserver.host, socket.port)
        task_in_url = 'tcp://%s:%s' % (config.dbserver.host,
                                       config.zworkers.task_in_port)
        with Socket(task_in_url, zmq.PUSH, 'connect') as sender:
            num_results = 0
            for args in self._genargs(backurl):
                sender.send((self.task_func, args))
                num_results += 1
        yield num_results
        yield from self.loop(range(num_results), iter(socket))
Example 2
def _iter_zmq(self):
    with Socket(self.receiver, zmq.PULL, 'bind') as socket:
        task_in_url = ('tcp://%(master_host)s:%(task_in_port)s' %
                       config.zworkers)
        with Socket(task_in_url, zmq.PUSH, 'connect') as sender:
            n = 0
            for args in self._genargs(socket.backurl):
                sender.send((self.task_func, args))
                n += 1
        yield n
        for _ in range(n):
            obj = socket.zsocket.recv_pyobj()
            # receive n responses for the n requests sent
            yield obj
Example 3
def _iter_celery(self):
    with Socket(self.receiver, zmq.PULL, 'bind') as socket:
        logging.info('Using receiver %s', socket.backurl)
        results = []
        for piks in self._genargs(socket.backurl):
            res = safetask.delay(self.task_func, piks)
            # populating Starmap.task_ids, used in celery_cleanup
            self.task_ids.append(res.task_id)
            results.append(res)
        num_results = len(results)
        yield num_results
        it = self.iter_native(results)
        isocket = iter(socket)
        while num_results:
            res = next(isocket)
            if self.calc_id and self.calc_id != res.mon.calc_id:
                logging.warning(
                    'Discarding a result from job %d, since this '
                    'is job %d', res.mon.calc_id, self.calc_id)
                continue
            err = next(it)
            if isinstance(err, Exception):  # TaskRevokedError
                raise err
            num_results -= 1
            yield res
Example 4
def _iter_celery_zmq(self):
    with Socket(self.receiver, zmq.PULL, 'bind') as socket:
        logging.info('Using receiver %s', socket.backurl)
        it = self._iter_celery(socket.backurl)
        yield next(it)  # number of results
        isocket = iter(socket)
        for _ in it:
            yield next(isocket)
Example 5
def submit(self, *args, func=None, monitor=None):
    """
    Submit the given arguments to the underlying task
    """
    monitor = monitor or self.monitor
    func = func or self.task_func
    if not hasattr(self, 'socket'):  # first time
        self.__class__.running_tasks = self.tasks
        self.socket = Socket(self.receiver, zmq.PULL, 'bind').__enter__()
        monitor.backurl = 'tcp://%s:%s' % (
            config.dbserver.host, self.socket.port)
    assert not isinstance(args[-1], Monitor)  # sanity check
    dist = 'no' if self.num_tasks == 1 else self.distribute
    if dist != 'no':
        args = pickle_sequence(args)
        self.sent += numpy.array([len(p) for p in args])
    res = getattr(self, dist + '_submit')(func, args, monitor)
    self.tasks.append(res)
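
A minimal sketch of how this submit variant might be driven; the task
function double and the Starmap constructor call below are assumptions
for illustration (the constructor signature matches the one in Example
12), not something taken verbatim from this listing:

def double(xs, monitor):
    # hypothetical task: engine tasks receive a Monitor as last argument
    return [x * 2 for x in xs]

smap = Starmap(double, monitor=Monitor('double'))
for block in [[1, 2], [3, 4]]:
    smap.submit(block)
results = smap.get_results()  # an IterResult over the task results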
Example 6
def submit(self, *args):
    """
    Submit the given arguments to the underlying task
    """
    global running_tasks
    if not hasattr(self, 'socket'):  # first time
        running_tasks = self.tasks
        self.socket = Socket(self.receiver, zmq.PULL, 'bind').__enter__()
        self.monitor.backurl = 'tcp://%s:%s' % (config.dbserver.host,
                                                self.socket.port)
    assert not isinstance(args[-1], Monitor)  # sanity check
    # add incremental task number and task weight
    self.monitor.task_no = len(self.tasks) + 1
    dist = 'no' if self.num_tasks == 1 else self.distribute
    if dist != 'no':
        args = pickle_sequence(args)
        self.sent += numpy.array([len(p) for p in args])
    res = getattr(self, dist + '_submit')(args)
    self.tasks.append(res)
Example 7
def safely_call(func, args, monitor=dummy_mon):
    """
    Call the given function with the given arguments safely, i.e.
    by trapping the exceptions. Return a pair (result, exc_type)
    where exc_type is None if no exceptions occur, otherwise it
    is the exception class and the result is a string containing
    error message and traceback.

    :param func: the function to call
    :param args: the arguments
    """
    isgenfunc = inspect.isgeneratorfunction(func)
    monitor.operation = 'total ' + func.__name__
    if hasattr(args[0], 'unpickle'):
        # args is a list of Pickled objects
        args = [a.unpickle() for a in args]
    if monitor is dummy_mon:  # in the DbServer
        assert not isgenfunc, func
        return Result.new(func, args, monitor)

    mon = args[-1]
    mon.operation = 'total ' + func.__name__
    mon.measuremem = True
    if mon is not monitor:
        mon.children.append(monitor)  # monitor is a child of mon
    mon.weight = getattr(args[0], 'weight', 1.)  # used in task_info
    with Socket(monitor.backurl, zmq.PUSH, 'connect') as zsocket:
        msg = check_mem_usage()  # warn if too much memory is used
        if msg:
            zsocket.send(Result(None, mon, msg=msg))
        if inspect.isgeneratorfunction(func):
            gfunc = func
        else:

            def gfunc(*args):
                yield func(*args)

        gobj = gfunc(*args)
        for count in itertools.count():
            res = Result.new(next, (gobj, ), mon, count=count)
            # StopIteration -> TASK_ENDED
            try:
                zsocket.send(res)
            except Exception:  # like OverflowError
                _etype, exc, tb = sys.exc_info()
                err = Result(exc,
                             mon,
                             ''.join(traceback.format_tb(tb)),
                             count=count)
                zsocket.send(err)
            mon.duration = 0
            mon.counts = 0
            mon.children.clear()
            if res.msg == 'TASK_ENDED':
                break
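
For reference, a sketch of the DbServer path of this version, where the
default dummy monitor is used and the Result is returned directly rather
than streamed over zmq; add is a made-up function, and it is assumed that
Result.new traps exceptions the same way the version in Example 10 does:

def add(x, y):
    return x + y

res = safely_call(add, (1, 2))  # monitor is dummy_mon: no socket involved
# res wraps the value 3; an exception raised inside add would be trapped
# and wrapped in the Result instead of propagating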
Example 8
def _iter_zmq(self):
    with Socket(self.receiver, zmq.PULL, 'bind') as socket:
        task_in_url = ('tcp://%(master_host)s:%(task_in_port)s' %
                       config.zworkers)
        with Socket(task_in_url, zmq.PUSH, 'connect') as sender:
            num_results = 0
            for args in self._genargs(socket.backurl):
                sender.send((self.task_func, args))
                num_results += 1
        yield num_results
        isocket = iter(socket)
        while num_results:
            res = next(isocket)
            if self.calc_id and self.calc_id != res.mon.calc_id:
                logging.warning(
                    'Discarding a result from job %d, since this '
                    'is job %d', res.mon.calc_id, self.calc_id)
                continue
            num_results -= 1
            yield res
Example 9
def _iter_celery(self):
    with Socket(self.receiver, zmq.PULL, 'bind') as socket:
        backurl = 'tcp://%s:%s' % (config.dbserver.host, socket.port)
        logging.debug('Using receiver %s', backurl)
        results = []
        for piks in self._genargs(backurl):
            res = safetask.delay(self.task_func, piks)
            # populating Starmap.task_ids, used in celery_cleanup
            self.task_ids.append(res.task_id)
            results.append(res)
        num_results = len(results)
        yield num_results
        yield from self.loop(self.iter_native(results), iter(socket))
Example 10
def safely_call(func, args):
    """
    Call the given function with the given arguments safely, i.e.
    by trapping the exceptions. Return a pair (result, exc_type)
    where exc_type is None if no exceptions occur, otherwise it
    is the exception class and the result is a string containing
    error message and traceback.

    :param func: the function to call
    :param args: the arguments
    """
    with Monitor('total ' + func.__name__, measuremem=True) as child:
        if args and hasattr(args[0], 'unpickle'):
            # args is a list of Pickled objects
            args = [a.unpickle() for a in args]
        if args and isinstance(args[-1], Monitor):
            mon = args[-1]
            mon.operation = func.__name__
            mon.children.append(child)  # child is a child of mon
        else:  # in the DbServer
            mon = child
        try:
            res = Result(func(*args), mon)
        except Exception:
            _etype, exc, tb = sys.exc_info()
            res = Result(exc, mon, ''.join(traceback.format_tb(tb)))
    # FIXME: check_mem_usage is disabled here because it was causing
    # deadlocks in threads when log messages were raised.
    # The check is done anyway in other parts of the code;
    # further investigation is needed.
    # check_mem_usage(mon)  # check if too much memory is used
    backurl = getattr(mon, 'backurl', None)
    if backurl is None:
        return res
    with Socket(backurl, zmq.PUSH, 'connect') as zsocket:
        try:
            zsocket.send(res)
        except Exception:  # like OverflowError
            _etype, exc, tb = sys.exc_info()
            err = Result(exc, mon, ''.join(traceback.format_tb(tb)))
            zsocket.send(err)
    return zsocket.num_sent
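
A sketch of the exception-trapping behaviour of this version; boom is a
made-up function, and since no Monitor is passed in the arguments the
child monitor is used, which has no backurl, so the Result is returned
directly:

def boom():
    raise ValueError('expected failure')

res = safely_call(boom, ())
# the ValueError does not propagate: it ends up inside the Result,
# together with the formatted traceback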
Example 11
def zmq_submit(self, func, args, monitor):
    if not hasattr(self, 'sender'):
        task_in_url = 'tcp://%s:%s' % (config.dbserver.host,
                                       config.zworkers.task_in_port)
        self.sender = Socket(task_in_url, zmq.PUSH, 'connect').__enter__()
    return self.sender.send((func, args, self.task_no, monitor))
Example 12
class Starmap(object):
    calc_id = None
    hdf5 = None
    pids = ()
    running_tasks = []  # currently running tasks

    @classmethod
    def init(cls, poolsize=None, distribute=OQ_DISTRIBUTE):
        if distribute == 'processpool' and not hasattr(cls, 'pool'):
            orig_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
            # we use spawn here to avoid deadlocks with logging, see
            # https://github.com/gem/oq-engine/pull/3923 and
            # https://codewithoutrules.com/2018/09/04/python-multiprocessing/
            cls.pool = multiprocessing.get_context('spawn').Pool(
                poolsize, init_workers)
            signal.signal(signal.SIGINT, orig_handler)
            cls.pids = [proc.pid for proc in cls.pool._pool]
        elif distribute == 'threadpool' and not hasattr(cls, 'pool'):
            cls.pool = multiprocessing.dummy.Pool(poolsize)
        elif distribute == 'no' and hasattr(cls, 'pool'):
            cls.shutdown()
        elif distribute == 'dask':
            cls.dask_client = Client(config.distribution.dask_scheduler)

    @classmethod
    def shutdown(cls):
        if hasattr(cls, 'pool'):
            cls.pool.close()
            cls.pool.terminate()
            cls.pool.join()
            del cls.pool
            cls.pids = []
        if hasattr(cls, 'dask_client'):
            del cls.dask_client

    @classmethod
    def apply(cls, task, args, concurrent_tasks=cpu_count * 3,
              maxweight=None, weight=lambda item: 1,
              key=lambda item: 'Unspecified',
              distribute=None, progress=logging.info):
        r"""
        Apply a task to a tuple of the form (sequence, \*other_args)
        by first splitting the sequence in chunks, according to the weight
        of the elements and possibly to a key (see
        :func:`openquake.baselib.general.split_in_blocks`).

        :param task: a task to run in parallel
        :param args: the arguments to be passed to the task function
        :param concurrent_tasks: hint about how many tasks to generate
        :param maxweight: if not None, used to split the tasks
        :param weight: function to extract the weight of an item in arg0
        :param key: function to extract the kind of an item in arg0
        :param distribute: if not given, inferred from OQ_DISTRIBUTE
        :param progress: logging function to use (default logging.info)
        :returns: an :class:`IterResult` object
        """
        arg0 = args[0]  # this is assumed to be a sequence
        mon = args[-1]
        args = args[1:-1]
        if maxweight:  # block_splitter is lazy
            task_args = ((blk,) + args for blk in block_splitter(
                arg0, maxweight, weight, key))
        else:  # split_in_blocks is eager
            task_args = [(blk,) + args for blk in split_in_blocks(
                arg0, concurrent_tasks or 1, weight, key)]
        return cls(task, task_args, mon, distribute, progress).submit_all()

    def __init__(self, task_func, task_args=(), monitor=None, distribute=None,
                 progress=logging.info):
        self.__class__.init(distribute=distribute or OQ_DISTRIBUTE)
        self.task_func = task_func
        self.monitor = monitor or Monitor(task_func.__name__)
        self.calc_id = getattr(self.monitor, 'calc_id', None)
        self.name = self.monitor.operation or task_func.__name__
        self.task_args = task_args
        self.distribute = distribute or oq_distribute(task_func)
        self.progress = progress
        try:
            self.num_tasks = len(self.task_args)
        except TypeError:  # generators have no len
            self.num_tasks = None
        # a task can be a function, a class or an instance with a __call__
        if inspect.isfunction(task_func):
            self.argnames = inspect.getfullargspec(task_func).args
        elif inspect.isclass(task_func):
            self.argnames = inspect.getfullargspec(task_func.__init__).args[1:]
        else:  # instance with a __call__ method
            self.argnames = inspect.getfullargspec(task_func.__call__).args[1:]
        self.receiver = 'tcp://%s:%s' % (
            config.dbserver.listen, config.dbserver.receiver_ports)
        self.sent = numpy.zeros(len(self.argnames) - 1)
        self.monitor.backurl = None  # overridden later
        self.tasks = []  # populated by .submit
        h5 = self.monitor.hdf5
        task_info = 'task_info/' + self.name
        if h5 and task_info not in h5:  # first time
            # task_info and performance_data should be generated in advance
            hdf5.create(h5, task_info, task_info_dt)
        if h5 and 'performance_data' not in h5:
            hdf5.create(h5, 'performance_data', perf_dt)

    @property
    def hdf5(self):
        return self.monitor.hdf5

    def log_percent(self):
        """
        Log the progress of the computation in percentage
        """
        done = self.total - self.todo
        percent = int(float(done) / self.total * 100)
        if not hasattr(self, 'prev_percent'):  # first time
            self.prev_percent = 0
            self.progress('Sent %s of data in %d %s task(s)',
                          humansize(self.sent.sum()), self.total, self.name)
        elif percent > self.prev_percent:
            self.progress('%s %3d%% [of %d]',
                          self.name, percent, len(self.tasks))
            self.prev_percent = percent
        return done

    def submit(self, *args, func=None, monitor=None):
        """
        Submit the given arguments to the underlying task
        """
        monitor = monitor or self.monitor
        func = func or self.task_func
        if not hasattr(self, 'socket'):  # first time
            self.__class__.running_tasks = self.tasks
            self.socket = Socket(self.receiver, zmq.PULL, 'bind').__enter__()
            monitor.backurl = 'tcp://%s:%s' % (
                config.dbserver.host, self.socket.port)
        assert not isinstance(args[-1], Monitor)  # sanity check
        dist = 'no' if self.num_tasks == 1 else self.distribute
        if dist != 'no':
            args = pickle_sequence(args)
            self.sent += numpy.array([len(p) for p in args])
        res = getattr(self, dist + '_submit')(func, args, monitor)
        self.tasks.append(res)

    @property
    def task_no(self):
        """
        :returns: number of the last submitted task, starting from 0
        """
        return len(self.tasks)

    def submit_all(self):
        """
        :returns: an IterResult object
        """
        for args in self.task_args:
            self.submit(*args)
        return self.get_results()

    def get_results(self):
        """
        :returns: an :class:`IterResult` instance
        """
        return IterResult(self._loop(), self.name, self.argnames,
                          self.sent, self.monitor.hdf5)

    def reduce(self, agg=operator.add, acc=None):
        """
        Submit all tasks and reduce the results
        """
        return self.submit_all().reduce(agg, acc)

    def __iter__(self):
        return iter(self.submit_all())

    def no_submit(self, func, args, monitor):
        return safely_call(func, args, self.task_no, monitor)

    def processpool_submit(self, func, args, monitor):
        return self.pool.apply_async(
            safely_call, (func, args, self.task_no, monitor))

    threadpool_submit = processpool_submit

    def celery_submit(self, func, args, monitor):
        return safetask.delay(func, args, self.task_no, monitor)

    def zmq_submit(self, func, args, monitor):
        if not hasattr(self, 'sender'):
            task_in_url = 'tcp://%s:%s' % (config.dbserver.host,
                                           config.zworkers.task_in_port)
            self.sender = Socket(task_in_url, zmq.PUSH, 'connect').__enter__()
        return self.sender.send((func, args, self.task_no, monitor))

    def dask_submit(self, func, args, monitor):
        return self.dask_client.submit(safely_call, func, args, self.task_no)

    def _loop(self):
        if not hasattr(self, 'socket'):  # no submit was ever made
            return ()
        if hasattr(self, 'sender'):
            self.sender.__exit__(None, None, None)
        isocket = iter(self.socket)
        self.total = self.todo = len(self.tasks)
        while self.todo:
            res = next(isocket)
            if self.calc_id and self.calc_id != res.mon.calc_id:
                logging.warning('Discarding a result from job %s, since this '
                                'is job %d', res.mon.calc_id, self.calc_id)
                continue
            elif res.msg == 'TASK_ENDED':
                self.log_percent()
                self.todo -= 1
            elif res.msg:
                logging.warning(res.msg)
            elif res.func_args:  # resubmit subtask
                func, *args = res.func_args
                self.submit(*args, func=func, monitor=res.mon)
                yield res
                self.todo += 1
            else:
                yield res
        self.log_percent()
        self.socket.__exit__(None, None, None)
        self.tasks.clear()
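
The docstring of Starmap.apply above describes how the first argument is
split in blocks; a minimal usage sketch, assuming a task that takes a
block of items plus the monitor, as in the other examples (count and the
argument values are illustrative only):

def count(items, monitor):
    return len(items)

mon = Monitor('count')
total = Starmap.apply(count, (list(range(1000)), mon),
                      concurrent_tasks=4).reduce(acc=0)
# with the default agg=operator.add this sums the partial counts,
# giving back 1000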