Exemplo n.º 1
0
class Promise:
    '''A one-shot container for a value (or an exception) that tasks can await.

    A producer either calls ``set(data)`` or uses the promise as an async
    context manager: an exception escaping the ``async with`` body is stored
    in the promise (and suppressed) instead of a value. Consumers call
    ``get()``, which waits until the promise is set and then returns the
    value or raises the stored exception. ``clear()`` resets it.
    '''

    def __init__(self):
        self._event = Event()
        self._data = None
        self._exception = None

    def __repr__(self):
        res = super().__repr__()
        if not self.is_set():
            extra = 'unset'
        elif self._exception is not None:
            extra = repr(self._exception)
        else:
            extra = repr(self._data)
        # Reuse the default "<module.Promise object at 0x...>" text and
        # append the state inside the angle brackets.
        return f'<{res[1:-1]} [{extra}]>'

    def is_set(self):
        '''Return `True` if the promise is set'''
        return self._event.is_set()

    def clear(self):
        '''Clear the promise (value, exception and event) so it can be reused.'''
        self._data = None
        self._exception = None
        self._event.clear()

    async def set(self, data):
        '''Set the promise to *data*. Wake all waiting tasks (if any).

        Any exception stored earlier by a failed ``async with`` block is
        discarded, so ``get()`` returns *data* instead of re-raising it.
        '''
        self._data = data
        self._exception = None
        await self._event.set()

    async def get(self):
        '''Wait for the promise to be set, and return the data.

        If an exception was set, it will be raised.'''
        await self._event.wait()

        if self._exception is not None:
            raise self._exception

        return self._data

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        '''Store an exception raised in the ``async with`` body and suppress it.

        Returns True (suppressing the exception) only when one was raised; a
        clean exit leaves the promise untouched.
        '''
        if exc_type is not None:
            self._exception = exc
            await self._event.set()

            return True
Exemplo n.º 2
0
 async def wait_for_flow_control(self, stream_id):
     """
     Blocks until the flow control window for a given stream is opened.

     Concurrent waiters on the same stream share one pending event: if an
     unset event is already registered for ``stream_id`` it is reused.
     Replacing it would leave the earlier waiter blocked on an event that is
     no longer reachable through ``flow_control_events``. An event that is
     already set is not reused, so a stale entry cannot make this return
     immediately.
     """
     evt = self.flow_control_events.get(stream_id)
     if evt is None or evt.is_set():
         evt = Event()
         self.flow_control_events[stream_id] = evt
     await evt.wait()
Exemplo n.º 3
0
 def __init__(self, worker_nums, worker_timeout, getter_queue, putter_queue,
              task_module_path, logger):
     # Settings supplied by the caller.
     self.worker_nums = worker_nums
     self.worker_timeout = worker_timeout
     self.getter_queue = getter_queue
     self.putter_queue = putter_queue
     self.task_module_path = task_module_path
     self.logger = logger
     # Worker bookkeeping; every container starts empty.
     self.workers = {}
     self.idle_workers = []
     self.busy_workers = {}
     self.watch_tasks = {}
     # Signalling for callers that wait for an idle worker.
     self.idle_available = Event()
     self.wait_for_idle = False
     self.alive = True
     self.logger.debug('+++new worker pool instance++++')
Exemplo n.º 4
0
 def __init__(self,
              queues,
              logger,
              getter_queue,
              putter_queue,
              qos=None,
              amqp_url='amqp://*****:*****@localhost:5672//'):
     # queues = ['q1', 'q2', 'q3', ...]; names are upper-cased and de-duplicated
     self.logger = logger
     self.queues = {name.upper() for name in queues}
     self.amqp_url = amqp_url
     self.getter_queue = getter_queue
     self.putter_queue = putter_queue
     # parse_amqp_url() presumably reads self.amqp_url, so it runs after it is set
     self.parse_amqp_url()
     self.status = ConnectionStatus.INITAL
     self.reconnect_done_event = Event()
     # Without an explicit qos, default to one per (unique) queue.
     if qos is None:
         qos = len(self.queues)
     self.qos = qos
     self.logger.debug('connection initial~~~~')
Exemplo n.º 5
0
    def _get_hub_class(self, hub_type, sensor, sensor_name, capabilities):
        """Build a minimal subclass of ``hub_type`` (apparently for tests).

        The generated ``TestHub`` is decorated with ``attach`` using the given
        sensor, name and capabilities. Its ``sensor_change()`` does nothing and
        its ``run()`` only waits on the returned event.

        Returns:
            tuple: ``(TestHub, stop_event)``; set ``stop_event`` to let
            ``run()`` return.
        """
        stop_event = Event()

        @attach(sensor, name=sensor_name, capabilities=capabilities)
        class TestHub(hub_type):
            async def sensor_change(self):
                pass

            async def run(self):
                await stop_event.wait()

        return TestHub, stop_event
Exemplo n.º 6
0
 def __init__(self):
     # One queue named for each direction, plus close-state tracking.
     self.incoming = Queue()
     self.outgoing = Queue()
     self.closing = Event()
     self.closure = None
Exemplo n.º 7
0
class WebsocketPrototype(ABC):
    '''Base for a websocket endpoint driven by a socket and a protocol object.

    Subclasses are expected to provide ``socket`` (with async ``recv`` and
    ``sendall``) and ``protocol`` (exposing ``events()``, ``receive_data()``,
    ``send()`` and ``state``; presumably a wsproto connection, to be
    confirmed against the subclass). ``flow()`` runs the reader and writer
    loops until the connection closes.
    '''

    __slots__ = ('socket', 'protocol', 'outgoing', 'incoming', 'closure',
                 'closing')

    def __init__(self):
        self.outgoing = Queue()
        self.incoming = Queue()
        self.closure = None
        self.closing = Event()

    @property
    def closed(self):
        '''True once the ``closing`` event has been set.'''
        return self.closing.is_set()

    async def send(self, data):
        '''Queue *data* as a Message; raise WebsocketClosedError if closed.'''
        if self.closed:
            raise WebsocketClosedError()
        await self.outgoing.put(Message(data=data))

    async def recv(self):
        '''Return the next incoming payload, or None if the socket closes first.'''
        if not self.closed:
            # Race the next incoming item against the closing event.
            async with TaskGroup(wait=any) as g:
                receiver = await g.spawn(self.incoming.get)
                await g.spawn(self.closing.wait)
            if g.completed is receiver:
                return receiver.result

    async def __aiter__(self):
        async for msg in self.incoming:
            yield msg

    async def close(self, code=1000, reason='Closed.'):
        '''Queue a CloseConnection; the writer loop sends it.'''
        await self.outgoing.put(CloseConnection(code=code, reason=reason))

    async def _handle_incoming(self):
        '''Read from the socket and dispatch protocol events until closed.'''
        # NOTE(review): `events` is created once. If protocol.events() is a
        # generator that finishes when its buffer runs dry (wsproto's does),
        # the first recv that yields no complete event ends this loop, and at
        # most one event is consumed per recv. Confirm this is intended.
        events = self.protocol.events()
        while not self.closed:
            try:
                data = await self.socket.recv(4096)
            except ConnectionResetError:
                return await self.closing.set()

            self.protocol.receive_data(data)
            try:
                event = next(events)
            except StopIteration:
                # Connection dropped unexpectedly
                return await self.closing.set()

            if isinstance(event, CloseConnection):
                # Peer-initiated close: record it and queue our reply.
                self.closure = event
                await self.outgoing.put(event.response())
                await self.closing.set()
            elif isinstance(event, Message):
                await self.incoming.put(event.data)
            elif isinstance(event, Ping):
                await self.outgoing.put(event.response())

    async def _handle_outgoing(self):
        '''Send queued events until a close is sent, requested, or fails.'''
        async for event in self.outgoing:

            # None is the stop sentinel queued by flow().
            if event is None or self.protocol.state is ConnectionState.CLOSED:
                return await self.closing.set()

            data = self.protocol.send(event)
            try:
                await self.socket.sendall(data)
            except socket.error:
                return await self.closing.set()
            # Test the *event*: `data` is whatever protocol.send() serialized
            # it into (bytes for wsproto), not the event object itself.
            if isinstance(event, CloseConnection):
                # Keep a close already recorded by the reader (peer-initiated).
                if self.closure is None:
                    self.closure = event
                return await self.closing.set()

    async def flow(self, *tasks):
        '''Run the reader, the writer and any extra *tasks* in one TaskGroup.

        When the reader or one of *tasks* finishes first, the writer is told
        to stop (via the None sentinel or a close frame) and is joined.
        '''
        async with TaskGroup(tasks=tasks) as ws:
            incoming = await ws.spawn(self._handle_incoming)
            outgoing = await ws.spawn(self._handle_outgoing)
            finished = await ws.next_done()

            if finished is incoming:
                await self.outgoing.put(None)
                await outgoing.join()
            elif finished in tasks:
                # Task is finished.
                # We ask for the outgoing to finish
                if finished.exception:
                    await self.close(1011, 'Task died prematurely.')
                else:
                    await self.close()
                await outgoing.join()
Exemplo n.º 8
0
class MagneConnection:
    '''AMQP 0-9-1 consumer connection built on a curio socket and pika frames.

    Delivered messages are decoded and put on ``putter_queue`` as JSON
    strings. Delivery tags taken from ``getter_queue`` are acked back to the
    broker. On ConnectionResetError (or EOF) the connection is re-established
    and consumption resumes.
    '''

    # rabbitmq frame max size is 131072
    MAX_DATA_SIZE = 131072

    def __init__(self,
                 queues,
                 logger,
                 getter_queue,
                 putter_queue,
                 qos=None,
                 amqp_url='amqp://*****:*****@localhost:5672//'):
        # queues = ['q1', 'q2', 'q3', ...]; names are upper-cased and
        # de-duplicated, and each is used as both exchange and queue name
        self.logger = logger
        self.queues = {i.upper() for i in queues}
        self.amqp_url = amqp_url
        self.getter_queue, self.putter_queue = getter_queue, putter_queue
        self.parse_amqp_url()
        self.status = ConnectionStatus.INITAL
        self.reconnect_done_event = Event()
        # default prefetch: one message per consumed queue
        self.qos = qos if qos is not None else len(self.queues)
        self.logger.debug('connection initial~~~~')

    def parse_amqp_url(self):
        '''Split ``amqp://user:pwd@host:port/vhost`` into connection fields.

        Sets ``username``, ``pwd``, ``host``, ``port`` (kept as a string) and
        ``vhost``. A URL ending in ``//`` selects the vhost ``/``.
        '''
        protocol, address = self.amqp_url.split('://')
        assert protocol == 'amqp'
        name_pwd, ip_address = address.split('@')
        self.username, self.pwd = name_pwd.split(':')
        self.host, port = ip_address.split(':')
        if '//' in port:
            self.port, self.vhost = port.split('//')[0], '/'
        else:
            self.port, self.vhost = port.split('/')

    async def assert_recv_method(self, method_class):
        '''Read one chunk and check that its first frame carries *method_class*.

        Returns the decoded frame. If the frame's method is missing or of
        another type, the problem is logged and the exception re-raised.
        '''
        data = await self.sock.recv(self.MAX_DATA_SIZE)
        frame_obj = pika.frame.decode_frame(data)[1]
        try:
            assert isinstance(frame_obj.method, method_class)
        except Exception as e:
            self.logger.error('assert_recv_method : %s, %s, %s' %
                              (method_class, frame_obj, e),
                              exc_info=True)
            raise
        return frame_obj

    async def send_amqp_procotol_header(self):
        amqp_header_frame = pika.frame.ProtocolHeader()
        await self.sock.sendall(amqp_header_frame.marshal())

    async def send_start_ok(self):
        # PLAIN-style response: \0username\0password
        start_ok_response = b'\0' + pika.compat.as_bytes(
            self.username) + b'\0' + pika.compat.as_bytes(self.pwd)
        start_ok_obj = pika.spec.Connection.StartOk(
            client_properties=CLIENT_INFO, response=start_ok_response)
        frame_value = pika.frame.Method(0, start_ok_obj)
        await self.sock.sendall(frame_value.marshal())

    async def send_tune_ok(self):
        # TODO: for now, do not want a heartbeat
        tunk = pika.spec.Connection.TuneOk(frame_max=self.MAX_DATA_SIZE)
        frame_value = pika.frame.Method(0, tunk)
        await self.sock.sendall(frame_value.marshal())

    async def send_connection_open(self):
        connection_open = pika.spec.Connection.Open(insist=True)
        frame_value = pika.frame.Method(0, connection_open)
        await self.sock.sendall(frame_value.marshal())
        # got openok
        await self.assert_recv_method(pika.spec.Connection.OpenOk)

    async def connect(self):
        '''Open the socket, run the AMQP handshake and open channel 1.'''
        self.sock = await curio.open_connection(self.host, self.port)
        self.logger.debug('open amqp connection')
        # send amqp header frame
        await self.send_amqp_procotol_header()
        self.logger.debug('send amqp header')
        # got start
        await self.assert_recv_method(pika.spec.Connection.Start)
        self.logger.debug('get amqp connection.Start')
        # send start ok back
        await self.send_start_ok()
        self.logger.debug('send amqp connection.StartOk')
        # got tune
        await self.assert_recv_method(pika.spec.Connection.Tune)
        self.logger.debug('get amqp connection.Tune')
        # send tune ok
        await self.send_tune_ok()
        self.logger.debug('send amqp connection.TuneOk')
        # and we send open
        await self.send_connection_open()
        self.logger.debug(
            'send amqp connection.Open and get connection.OpenOk')
        # open channel
        await self.open_channel()

    async def open_channel(self):
        '''Open channel 1, declare/bind one exchange+queue per name, set QoS.'''
        # send Channel.Open
        channel_open = pika.spec.Channel.Open()
        self.logger.debug('send channel.Open')
        frame_value = pika.frame.Method(1, channel_open)
        await self.sock.sendall(frame_value.marshal())
        # got Channel.Open-Ok
        frame_obj = await self.assert_recv_method(pika.spec.Channel.OpenOk)
        self.logger.debug('get channel.OpenOk')
        self.channel_number = channel_number = frame_obj.channel_number
        assert frame_obj.channel_number == 1
        self.channel_obj = Channel(channel_number=channel_number)
        # create exchange, queue, and update QOS
        for queue_name in self.queues:
            exchange_name = queue_name
            await self.declare_exchange(channel_number, exchange_name)
            self.logger.debug('declare exchange %s' % exchange_name)
            await self.declare_queue(channel_number, queue_name)
            self.logger.debug('declare queue %s' % queue_name)
            await self.bind_queue_exchange(channel_number,
                                           exchange_name,
                                           queue_name,
                                           routing_key=queue_name)
            self.logger.debug('bind exchange %s and queue %s' %
                              (exchange_name, queue_name))
        await self.update_qos(channel_number, self.qos)
        self.logger.info('update qos %s' % self.qos)

    async def declare_exchange(self, channel_number, name):
        exchange_declare = pika.spec.Exchange.Declare(exchange=name)
        frame_value = pika.frame.Method(channel_number, exchange_declare)
        await self.sock.sendall(frame_value.marshal())
        await self.assert_recv_method(pika.spec.Exchange.DeclareOk)
        return Exchange(name=name)

    async def declare_queue(self, channel_number, name):
        queue_declare = pika.spec.Queue.Declare(queue=name)
        frame_value = pika.frame.Method(channel_number, queue_declare)
        await self.sock.sendall(frame_value.marshal())
        await self.assert_recv_method(pika.spec.Queue.DeclareOk)
        return Queue(name=name)

    async def bind_queue_exchange(self, channel_number, exchange, queue,
                                  routing_key):
        queue_bind = pika.spec.Queue.Bind(queue=queue,
                                          exchange=exchange,
                                          routing_key=routing_key)
        frame_value = pika.frame.Method(channel_number, queue_bind)
        await self.sock.sendall(frame_value.marshal())
        await self.assert_recv_method(pika.spec.Queue.BindOk)

    async def update_qos(self, channel_number, qos, global_=True):
        qos_obj = pika.spec.Basic.Qos(prefetch_count=qos, global_=global_)
        frame_value = pika.frame.Method(channel_number, qos_obj)
        await self.sock.sendall(frame_value.marshal())
        await self.assert_recv_method(pika.spec.Basic.QosOk)

    async def ack(self, channel_number, delivery_tag):
        self.logger.debug('ack: %s, %s' % (channel_number, delivery_tag))
        ack = pika.spec.Basic.Ack(delivery_tag=delivery_tag)
        frame_value = pika.frame.Method(channel_number, ack)
        await self.sock.sendall(frame_value.marshal())

    async def start_consume(self):
        '''Send Basic.Consume for every queue and mark the connection RUNNING.'''
        # create amqp consumers
        for tag, queue_name in enumerate(self.queues):
            start_comsume = pika.spec.Basic.Consume(queue=queue_name,
                                                    consumer_tag=str(tag))
            self.logger.debug('send basic.Consume %s %s' %
                              (queue_name, str(tag)))
            frame_value = pika.frame.Method(self.channel_obj.channel_number,
                                            start_comsume)
            await self.sock.sendall(frame_value.marshal())
            data = await self.sock.recv(self.MAX_DATA_SIZE)
            count, frame_obj = pika.frame.decode_frame(data)
            # TODO: Deliver frame may comes by
            if isinstance(frame_obj.method,
                          pika.spec.Basic.ConsumeOk) is False:
                if isinstance(frame_obj.method, pika.spec.Basic.Deliver):
                    # A delivery arrived first: treat the whole chunk as messages.
                    count = 0
                else:
                    raise Exception('got basic.ConsumeOk error, frame_obj %s' %
                                    frame_obj)
            self.logger.debug('get basic.ConsumeOk')
            # message data after ConsumeOk
            if len(data) > count:
                await self.send_msg(data[count:])
        self.logger.debug('start consume done!')
        self.status = ConnectionStatus.RUNNING

    async def run(self):
        # send start consume
        await self.start_consume()
        # spawn wait_queue task
        self.fetch_amqp_task = await curio.spawn(self.fetch_from_amqp)
        self.logger.debug('spawn fetch_from_amqp')
        self.wait_ack_queue_task = await curio.spawn(self.wait_ack_queue)
        self.logger.debug('spawn wait_ack_queue')

    async def reconnect(self):
        '''Reconnect and resume consuming, retrying with a growing back-off.'''
        self.logger.info('starting reconnect~~~')
        self.reconnect_done_event.clear()
        reconnect_count = 1
        sleep_time = 0
        # NOTE(review): a socket opened by a failed attempt is not closed here.
        while True:
            try:
                await self.connect()
                await self.start_consume()
            except ConnectionRefusedError:
                sleep_time += 2 * reconnect_count
                self.logger.info('reconnect %s fail! sleep: %s seconds' %
                                 (reconnect_count, sleep_time))
            except Exception as e:
                self.logger.error('got unexcept exception: %s',
                                  e,
                                  exc_info=True)
                # Back off here as well: retrying at once would spin the
                # event loop while the failure persists.
                sleep_time += 2 * reconnect_count
            else:
                self.logger.info('reconnect success!')
                break
            await curio.sleep(sleep_time)
            reconnect_count += 1
        await self.reconnect_done_event.set()

    async def handle_connection_error(self):
        '''Run (or wait for) a single reconnect shared by all failing tasks.

        The first caller switches the status to ERROR and runs the reconnect;
        later callers wait for it to finish. Cancellation cancels a reconnect
        started by this call and is then re-raised to the caller.
        '''
        reconnect_task = None
        try:
            if self.status != ConnectionStatus.ERROR:
                self.status = ConnectionStatus.ERROR
                # Clear before spawning, so tasks that see ERROR from now on
                # wait for *this* reconnect rather than returning on an event
                # left set by a previous one.
                self.reconnect_done_event.clear()
                # should perform reconnect
                # spawn and join!
                self.logger.info('spawn reconnect task')
                reconnect_task = await curio.spawn(self.reconnect)
                await reconnect_task.join()
            else:
                # wait for reconnect done
                await self.reconnect_done_event.wait()
        except curio.CancelledError:
            # need manually cancel reconnect_task
            if reconnect_task is not None:
                self.logger.info('reconnect_task cancel')
                await reconnect_task.cancel()
            # Propagate, otherwise the cancelled caller keeps looping.
            raise

    async def fetch_from_amqp(self):
        '''Forward broker data to putter_queue until cancelled.'''
        self.logger.info('staring fetch_from_amqp')
        while True:
            try:
                data = await self.sock.recv(self.MAX_DATA_SIZE)
                if not data:
                    # An empty read means the peer closed the socket (standard
                    # socket semantics); without this the loop would spin.
                    raise ConnectionResetError('amqp connection closed by peer')
                await self.send_msg(data)
            except ConnectionResetError:
                self.logger.error(
                    'fetch_from_amqp ConnectionResetError, wait for reconnect...'
                )
                await self.handle_connection_error()
                self.logger.debug('go on fetch_from_amqp')
            except curio.CancelledError:
                self.logger.info('fetch_from_amqp cancel')
                break
            except Exception as e:
                self.logger.error('fetch_from_amqp error: %s' % e,
                                  exc_info=True)

    async def send_msg(self, data):
        '''Decode Deliver/Header/Body frame triples in *data* and enqueue them.'''
        # [Basic.Deliver, frame.Header, frame.Body, ...]
        bodys = []
        while data:
            count, frame_obj = pika.frame.decode_frame(data)
            data = data[count:]
            if isinstance(frame_obj.method, pika.spec.Basic.Deliver):
                body = {
                    'channel_number': frame_obj.channel_number,
                    'delivery_tag': frame_obj.method.delivery_tag,
                    'consumer_tag': frame_obj.method.consumer_tag,
                    'exchange': frame_obj.method.exchange,
                    'routing_key': frame_obj.method.routing_key,
                }
                count, frame_obj = pika.frame.decode_frame(data)
                if isinstance(frame_obj, pika.frame.Header):
                    data = data[count:]
                    count, frame_obj = pika.frame.decode_frame(data)
                    if isinstance(frame_obj, pika.frame.Body):
                        data = data[count:]
                        body['data'] = frame_obj.fragment.decode("utf-8")
                        bodys.append(json.dumps(body))
        await self.send_queue(bodys)

    async def send_queue(self, datas):
        for data in datas:
            await self.putter_queue.put(data)

    async def wait_ack_queue(self):
        '''Ack every delivery tag taken from getter_queue until cancelled.'''
        # TODO: cancel while await ack, what should we do?
        self.logger.info('staring wait_ack_queue')
        try:
            while True:
                ack_delivery_tag = await self.getter_queue.get()
                try:
                    await self.ack(self.channel_number, ack_delivery_tag)
                except ConnectionResetError:
                    self.logger.error(
                        'wait_queue ConnectionResetError, wait for reconnect...'
                    )
                    await self.handle_connection_error()
                    self.logger.debug('go on wait_ack_queue')
                    # reinsert failed ack_delivery_tag into queue
                    # NOTE(review): relies on private attributes of the queue.
                    self.getter_queue._queue.appendleft(ack_delivery_tag)
                    self.getter_queue._task_count += 1
        except curio.CancelledError:
            self.logger.info('wait_ack_queue cancel')

    async def send_close_connection(self):
        '''Close channel then connection, waiting up to 1s for each CloseOk.'''
        try:
            # 302: An operator intervened to close the connection for some reason. The client may retry at some later date.
            # close channel first
            close_channel_frame = pika.spec.Channel.Close(
                reply_code=302,
                reply_text='close connection',
                class_id=0,
                method_id=0)
            close_channel_frame_value = pika.frame.Method(
                self.channel_number, close_channel_frame)
            await self.sock.sendall(close_channel_frame_value.marshal())
            await curio.timeout_after(1, self.assert_recv_method,
                                      pika.spec.Channel.CloseOk)
            # Connection-level methods travel on channel 0.
            self.channel_number = 0
            self.logger.info('closed channel')

            close_connection_frame = pika.spec.Connection.Close(
                reply_code=302,
                reply_text='close connection',
                class_id=0,
                method_id=0)
            frame_value = pika.frame.Method(self.channel_number,
                                            close_connection_frame)
            await self.sock.sendall(frame_value.marshal())
            await curio.timeout_after(1, self.assert_recv_method,
                                      pika.spec.Connection.CloseOk)
            self.logger.info('closed connection')
        except curio.TaskTimeout:
            self.logger.error(
                'send close connection frame got CloseOk TaskTimeout')
        except ConnectionResetError:
            self.logger.error(
                'send close connection frame ConnectionResetError')
        except Exception as e:
            self.logger.error('send close connection frame exception: %s' % e,
                              exc_info=True)
        self.logger.info('closed amqp connection')

    async def close_amqp_connection(self):
        '''Flush pending acks and close the AMQP connection if it is running.'''
        # close connection if necessarily
        if self.status & ConnectionStatus.RUNNING:
            try:
                # last ack
                self.logger.info('last ack...')
                last_ack_delivery_tags = []
                while self.getter_queue.empty() is False:
                    delivery_tag = await self.getter_queue.get()
                    last_ack_delivery_tags.append(delivery_tag)
                self.logger.debug('%s wait for last ack' %
                                  last_ack_delivery_tags)
                for d in last_ack_delivery_tags:
                    await self.ack(self.channel_number, d)
            except ConnectionResetError:
                self.logger.error('last ack occur ConnectionResetError')
            except Exception as e:
                self.logger.error('last ack occur exception: %s' % e,
                                  exc_info=True)
            else:
                self.logger.info('closing amqp connection')
                await self.send_close_connection()
        # self.status = ConnectionStatus.CLOSE means we would not accept any ack msg more
        self.status = ConnectionStatus.CLOSED

    async def pre_close(self):
        '''Drop undelivered messages and stop fetching from the broker.'''
        # would not put any msg into queue more
        self.logger.debug('preclosing...')
        self.logger.debug('empty putter_queue')
        # Must be a deque *instance*; assigning the class itself would break
        # every later queue operation.
        self.putter_queue._queue = deque()
        # would not fetch any amqp msg more
        self.logger.debug('cancel fetch_amqp_task...')
        await self.fetch_amqp_task.cancel()
        self.status = self.status | ConnectionStatus.PRECLOSE
        self.logger.debug('status %s, preclose done' % self.status)

    async def close(self):
        '''
        connection should not be closed independently, it should be coordinated by master
        so, before connection close, worker pool have closed already
        '''
        self.logger.debug('closing connection')
        if not (self.status & ConnectionStatus.PRECLOSE):
            self.logger.warning(
                'should pre close connection!!, now will preclose')
            await self.pre_close()
        self.logger.debug('cancel wait_ack_queue_task')
        await self.wait_ack_queue_task.cancel()
        self.logger.debug('cancel close_amqp_connection')
        await self.close_amqp_connection()
        self.logger.debug('close connection done')
Exemplo n.º 9
0
 def __init__(self):
     # Nothing recorded yet: no result, no exception, and a fresh event.
     self._result = None
     self._exception = None
     self._event = Event()
Exemplo n.º 10
0
class MagneWorkerPool:
    '''Pool of MagneWorker processes fed from an AMQP message queue.

    JSON messages read from ``getter_queue`` are dispatched to idle workers.
    A delivery tag is put on ``putter_queue`` once its task finishes, times
    out, or is rejected, so the connection side can ack it.
    '''

    def __init__(self, worker_nums, worker_timeout, getter_queue, putter_queue,
                 task_module_path, logger):
        self.worker_nums = worker_nums
        self.worker_timeout = worker_timeout
        self.getter_queue, self.putter_queue = getter_queue, putter_queue
        self.task_module_path = task_module_path
        self.workers = {}
        self.idle_workers = []
        self.busy_workers = {}
        self.watch_tasks = {}
        self.idle_available = Event()
        self.wait_for_idle = False
        self.logger = logger
        self.alive = True
        self.logger.debug('+++new worker pool instance++++')

    def __str__(self):
        return 'WorkerPool<ws:%s, iws:%s>' % (self.workers, self.idle_workers)

    def manage_worker(self):
        '''Spawn or shut down workers so the pool size matches worker_nums.

        Only idle workers are shut down; busy ones are left alone.
        '''
        nc = len(self.workers) - self.worker_nums
        if nc < 0:
            # spawn extra workers
            while -nc:
                w = MagneWorker(self.task_module_path)
                self.workers[w.ident] = w
                self.logger.info('create new worker: %s' % w.ident)
                self.idle_workers.append(w.ident)
                nc += 1
        elif nc > 0:
            # kill idle worker
            while self.idle_workers and nc:
                w = self.idle_workers.pop()
                wobj = self.workers.pop(w)
                wobj.shutdown()
                nc -= 1

    async def start(self):
        try:
            self.manage_worker()
        except Exception as e:
            self.logger.error('worker pool start exception, %s' % e,
                              exc_info=True)
            self.kill_all_workers()
            raise
        self.wait_amqp_msg_task = await curio.spawn(self.wait_amqp_msg)
        self.logger.debug('spawn wait_amqp_msg')
        self.logger.info('worker pool started!')

    async def apply(self, func_name, args, delivery_tag):
        '''Hand a task to an idle worker (waiting for one) and watch it.'''
        # if we do not spawn watch task, and await worker recv,
        # we would block in recv, and can not recv next amqp msg!
        while not self.idle_workers:
            self.wait_for_idle = True
            # there is no idle worker for ready, just wait
            self.logger.debug('waiting for any idle worker...')
            await self.idle_available.wait()
            self.logger.debug('a idle worker avaliable...')
            self.idle_available.clear()
        self.wait_for_idle = False
        w = self.idle_workers.pop(0)
        wobj = self.workers[w]
        # apply msg to worker process
        try:
            await wobj.apply(func_name, args, delivery_tag)
        except Exception as e:
            self.logger.error('apply worker %s exception!: %s' %
                              (wobj.ident, e),
                              exc_info=True)
            self.idle_workers.append(w)
            await self.send_ack_queue(delivery_tag)
            return
        self.logger.debug('worker pool apply worker %s: %s %s(%s)' %
                          (w, delivery_tag, func_name, args))
        self.busy_workers[w] = wobj
        # watching task be set to daemon, it is a good idea?
        watch_worker_task = await curio.spawn(self.watch_worker,
                                              wobj,
                                              daemon=True)
        # save watch task for closing
        self.watch_tasks[wobj.ident] = watch_worker_task

    async def wait_amqp_msg(self):
        '''Parse messages from getter_queue and apply them until cancelled.'''
        self.logger.info('staring wait_amqp_msg')
        try:
            while True:
                msg = await self.getter_queue.get()
                self.logger.debug('wait_amqp_msg got msg %s' % msg)
                # Reset per message so a malformed message is never acked
                # with the previous message's delivery tag.
                delivery_tag = None
                try:
                    msg_dict = json.loads(msg)
                    self.logger.debug('msg_dict %s' % msg_dict)
                    # msg_dict must contains delivery_tag!
                    delivery_tag = msg_dict['delivery_tag']
                    data = json.loads(msg_dict['data'])
                    func_name, args = data['func'], data['args']
                except Exception as e:
                    self.logger.error('invalid msg, %s, %s' % (msg, e),
                                      exc_info=True)
                    # Only ack when this message's own tag was parsed.
                    if delivery_tag is not None:
                        await self.send_ack_queue(delivery_tag)
                else:
                    self.logger.info('got a task %s, %s(%s)' %
                                     (delivery_tag, func_name, args))
                    await self.apply(func_name, args, delivery_tag)
        except curio.CancelledError:
            # cancel while await getter_queue.get, it`s fine, just discarding all msgs
            # cancel while await self.apply?
            # if had not apply to worker process yet, discarding msg is fine
            # if had apply to worker process, we will wait a least one worker timeout while close
            self.logger.info('wait_amqp_msg cancel')
        except Exception as e:
            self.logger.error('wait_amqp_msg error: %s' % e, exc_info=True)
            raise

    async def send_ack_queue(self, delivery_tag):
        self.logger.debug('send ack %s' % delivery_tag)
        try:
            await self.putter_queue.put(delivery_tag)
        except Exception as e:
            self.logger.error('worker pool send_ack_queue error, %s' % e,
                              exc_info=True)
            raise

    async def watch_worker(self, wobj):
        '''Wait for a worker's result, killing and replacing it on timeout.'''
        func_name, args = wobj.func, wobj.args
        self.logger.debug('watching worker %s for %s(%s)' %
                          (wobj.ident, func_name, args))
        success, res = False, None
        canceled = False
        try:
            # timeout will cancel coro
            success, res = await curio.timeout_after(self.worker_timeout,
                                                     wobj.recv)
        except curio.TaskTimeout:
            # got timeout
            self.logger.error('worker %s run %s(%s) timeout!' %
                              (wobj.ident, func_name, args))
            self.kill_worker(wobj)
            self.logger.info('shutdown worker %s...' % wobj.ident)
            if self.alive is True:
                # do not create new worker process while closing worker pool
                self.manage_worker()
        except curio.CancelledError:
            self.logger.info('watch %s cancel' % wobj.ident)
            canceled = True
        else:
            self.logger.info('worker %s run %s(%s) return %s, %s' %
                             (wobj.ident, func_name, args, success, res))
            del self.busy_workers[wobj.ident]
            self.idle_workers.append(wobj.ident)
        del self.watch_tasks[wobj.ident]
        # cancel would not send ack!!!!
        if canceled is False:
            await self.send_ack_queue(wobj.delivery_tag)
            if self.wait_for_idle is True:
                await self.idle_available.set()

    def kill_all_workers(self):
        # unkindly kill every single worker!
        for wobj in list(self.workers.values()):
            self.kill_worker(wobj)

    def kill_worker(self, wobj):
        '''Shut a worker down and drop it from every bookkeeping container.'''
        self.logger.info('killing worker %s' % wobj.ident)
        wobj.shutdown()
        if wobj.ident in self.busy_workers:
            del self.busy_workers[wobj.ident]
        else:
            self.idle_workers.remove(wobj.ident)
        del self.workers[wobj.ident]

    def reap_workers(self):
        '''Reap exited child processes without blocking and refill the pool.'''
        try:
            while True:
                self.logger.info('reaping workers')
                wpid, status = os.waitpid(-1, os.WNOHANG)
                if not wpid:
                    self.logger.info('there is not any worker wait for reap')
                    break
                self.logger.info('reap worker %s, exit with %s' %
                                 (wpid, status))
                if wpid not in self.workers:
                    # '%s' (not a bare '%') so formatting cannot raise here.
                    self.logger.error(
                        'worker pool do not contains reapd worker %s' % wpid)
                    # TODO: how?
                else:
                    self.kill_worker(self.workers[wpid])
                self.manage_worker()
        except OSError as e:
            # ECHILD just means there are no children left to reap.
            if e.errno != errno.ECHILD:
                self.logger.error('reap worker error signal %s' % e.errno,
                                  exc_info=True)

    async def close(self, warm=True):
        '''Stop taking messages, optionally wait for running tasks, kill workers.

        With ``warm=True`` running watch tasks get up to ``worker_timeout``
        seconds to finish; otherwise they are cancelled immediately.
        '''
        # do not get amqp msg
        self.alive = False
        self.getter_queue._queue = deque()
        await self.wait_amqp_msg_task.cancel()
        # wait for worker done
        if warm is True:
            try:
                self.logger.info('watching tasks join, wait %ss' %
                                 self.worker_timeout)
                async with curio.timeout_after(self.worker_timeout):
                    async with curio.TaskGroup(
                            self.watch_tasks.values()) as wtg:
                        await wtg.join()
            except curio.TaskTimeout:
                # task_group will cancel all remaining tasks while catch TaskTimeout(CancelError), yes, that is true
                # so, we do not have to cancel all remaining tasks by ourself
                self.logger.info('watch_tasks join timeout...')
        else:
            # cold close, just cancel all watch tasks
            for watch_task_obj in list(self.watch_tasks.values()):
                await watch_task_obj.cancel()
        self.kill_all_workers()
Exemplo n.º 11
0
 def __init__(self):
     # Nothing recorded yet: no data, no exception, and a fresh event.
     self._data = None
     self._exception = None
     self._event = Event()
Exemplo n.º 12
0
 async def wait_for_flow_control(self, stream_id):
     """
     Block until the flow control window for ``stream_id`` is opened.

     Concurrent waiters on the same stream share one pending event: if an
     unset event is already registered for ``stream_id`` it is reused.
     Replacing it would leave the earlier waiter blocked on an event that is
     no longer reachable through ``flow_control_events``. An event that is
     already set is not reused, so a stale entry cannot make this return
     immediately.
     """
     evt = self.flow_control_events.get(stream_id)
     if evt is None or evt.is_set():
         evt = Event()
         self.flow_control_events[stream_id] = evt
     await evt.wait()